# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import shutil
from PIL import Image
from tempfile import TemporaryDirectory
import albumentations as A
import warnings
from tqdm import tqdm
import random
from sklearn.model_selection import train_test_split
import gc

warnings.filterwarnings('ignore')

# Set seeds for reproducibility
def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

set_seed()

# Configuration: paths and hyper-parameters shared by every stage below.
CONFIG = dict(
    data_dir="D:\\Work\\dont_plot_images\\IPI App Data-20250119T224817Z-001\\IPI App Data",
    augmented_data_dir="D:\\Work\\dont_plot_images\\augmented_data",
    batch_size=8,          # kept small to avoid out-of-memory errors
    num_epochs=150,
    learning_rate=0.0002,
    weight_decay=0.005,
    image_size=224,        # square input side length fed to the network
    train_ratio=0.7,
    val_ratio=0.15,
    test_ratio=0.15,
    num_workers=2,         # few DataLoader workers to limit memory usage
)

# Device configuration: start on the CPU; main() may switch to CUDA later.
device = torch.device("cpu")
print(f'Using device: {device}')

# Memory cleanup function
def clear_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

class SimplifiedAugmentation:
    """Produce an image plus a few albumentations-based variants of it.

    Three transform pipelines are defined: basic (rotate90 / flip /
    brightness-contrast), geometric (shift-scale-rotate) and mixed
    (flip / brightness-contrast). Calling the object returns
    ``[original, variant_1, ..., variant_k]`` where ``k`` is
    ``num_variations`` capped at the number of pipelines. Every transform
    fires with probability p < 1, so a variant can occasionally be an
    unchanged copy of the input.
    """

    def __init__(self, num_variations=3):
        """
        Args:
            num_variations: How many augmented copies to generate per image.
                Values larger than the number of defined pipelines (3) are
                capped at 3.

        Raises:
            ValueError: If ``num_variations`` is negative.
        """
        if num_variations < 0:
            raise ValueError(f"num_variations must be >= 0, got {num_variations}")
        self.num_variations = num_variations
        available_sets = [
            # Set 1: Basic Transformations
            A.Compose([
                A.RandomRotate90(p=0.7),
                A.HorizontalFlip(p=0.5),
                A.RandomBrightnessContrast(p=0.7),
            ]),

            # Set 2: Geometric Transformations
            A.Compose([
                A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=45, p=0.7),
            ]),

            # Set 3: Mixed Transformations
            A.Compose([
                A.HorizontalFlip(p=0.5),
                A.RandomBrightnessContrast(p=0.5),
            ]),
        ]
        # num_variations used to be stored but ignored (all sets always ran);
        # honour it by keeping only the first num_variations pipelines.
        self.transform_sets = available_sets[:num_variations]

    def __call__(self, img):
        """Return ``[img]`` followed by one augmented copy per active pipeline.

        Args:
            img: A PIL image. A single-channel (2-D) array is stacked to three
                channels before augmentation; the returned original is left
                untouched.

        Returns:
            list of PIL images, original first.
        """
        img_array = np.array(img)
        if img_array.ndim == 2:
            img_array = np.stack([img_array] * 3, axis=-1)

        augmented_images = [img]
        for transform_set in self.transform_sets:
            augmented = transform_set(image=img_array)['image']
            augmented_images.append(Image.fromarray(augmented))
        return augmented_images

def _collect_class_folders(original_data_dir):
    """List ``(class_path, class_name, source_split)`` for each class dir under
    ``<original_data_dir>/train`` then ``<original_data_dir>/test``."""
    class_folders = []
    for source_split in ('train', 'test'):
        split_root = os.path.join(original_data_dir, source_split)
        if not os.path.isdir(split_root):
            continue
        for class_name in os.listdir(split_root):
            class_path = os.path.join(split_root, class_name)
            if os.path.isdir(class_path):
                class_folders.append((class_path, class_name, source_split))
    return class_folders


def _unique_stem(stem, source_split, used_stems):
    """Return an output stem not yet written for this class and record it.

    A class's ``train`` and ``test`` source folders both write into the same
    output dir, so equal file names would silently overwrite each other. On a
    clash the source split is prefixed (plus a counter if still taken).
    """
    candidate = stem
    if candidate in used_stems:
        candidate = f"{source_split}_{stem}"
        counter = 1
        while candidate in used_stems:
            candidate = f"{source_split}_{stem}_{counter}"
            counter += 1
    used_stems.add(candidate)
    return candidate


def create_augmented_dataset():
    """Augment every source image and write the results per class.

    Reads class folders from ``CONFIG['data_dir']/train`` and
    ``CONFIG['data_dir']/test`` (whichever exist) and writes each original
    plus its augmented copies to ``CONFIG['augmented_data_dir']/<class>/``.
    The original keeps ``<stem><ext>``; copies are ``<stem>_aug<i><ext>``.
    Unreadable images are reported and skipped.

    Returns:
        list of ``(written_path, class_name)`` tuples, one per image written.
    """
    original_data_dir = CONFIG['data_dir']
    augmented_data_dir = CONFIG['augmented_data_dir']
    os.makedirs(augmented_data_dir, exist_ok=True)

    class_folders = _collect_class_folders(original_data_dir)
    augmentation = SimplifiedAugmentation(num_variations=3)
    valid_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')

    all_images = []
    used_stems = {}  # class_name -> set of stems already written this run

    for class_path, class_name, source_split in class_folders:
        augmented_class_dir = os.path.join(augmented_data_dir, class_name)
        os.makedirs(augmented_class_dir, exist_ok=True)
        class_stems = used_stems.setdefault(class_name, set())

        for img_name in tqdm(os.listdir(class_path), desc=f"Augmenting {class_name}"):
            img_path = os.path.join(class_path, img_name)
            if not img_path.lower().endswith(valid_exts):
                continue

            try:
                # Context manager closes the source file handle promptly.
                with Image.open(img_path) as src:
                    img = src.convert('RGB')
                augmented_images = augmentation(img)

                base_name, ext = os.path.splitext(img_name)
                stem = _unique_stem(base_name, source_split, class_stems)
                if stem != base_name:
                    print(f"Name clash for {img_path}; saving as {stem}{ext}")

                for i, aug_img in enumerate(augmented_images):
                    suffix = f"_aug{i}" if i > 0 else ""
                    aug_img_path = os.path.join(augmented_class_dir, f"{stem}{suffix}{ext}")
                    aug_img.save(aug_img_path)
                    all_images.append((aug_img_path, class_name))

            except Exception as e:
                print(f"Error processing {img_path}: {e}")

    clear_memory()
    return all_images

def _original_stem(img_path):
    """Return the file stem with any trailing ``_aug<digits>`` removed.

    ``img.jpg`` and ``img_aug2.jpg`` both map to ``img``, so an original and
    all of its augmented copies share one group key.
    """
    stem = os.path.splitext(os.path.basename(img_path))[0]
    head, sep, tail = stem.rpartition("_aug")
    if sep and tail.isdigit():
        return head
    return stem


def split_dataset(all_images):
    """Split images into train/val/test without leaking augmentations.

    Images are grouped by ``(class_name, original stem)`` so that an original
    and its ``_aug<i>`` copies always land in the same split. Previously the
    original's key kept its extension (``img.jpg``) while its copies' key did
    not (``img``), so they could be scattered across splits. Groups are
    sorted before splitting so the result does not depend on ``os.listdir``
    order. Each split is copied to ``CONFIG['augmented_data_dir']/<split>/<class>/``.

    Args:
        all_images: list of ``(img_path, class_name)`` tuples.

    Returns:
        dict mapping ``'train'``/``'val'``/``'test'`` to lists of
        ``(img_path, class_name)`` tuples (paths of the source files copied).

    Raises:
        ValueError: If ``all_images`` is empty (or, from ``train_test_split``,
            if there are too few groups for the configured ratios).
    """
    if not all_images:
        raise ValueError("split_dataset() received no images to split")

    # Group every variation of one original image together.
    image_groups = {}
    for img_path, class_name in all_images:
        key = (class_name, _original_stem(img_path))
        image_groups.setdefault(key, []).append((img_path, class_name))

    unique_images = sorted(image_groups)

    train_val_images, test_images = train_test_split(
        unique_images,
        test_size=CONFIG['test_ratio'],
        random_state=42
    )
    train_images, val_images = train_test_split(
        train_val_images,
        test_size=CONFIG['val_ratio'] / (CONFIG['train_ratio'] + CONFIG['val_ratio']),
        random_state=42
    )

    splits = {
        split: [item for key in keys for item in image_groups[key]]
        for split, keys in (('train', train_images), ('val', val_images), ('test', test_images))
    }

    # Create split/class directories; warn about classes a split lacks, since
    # ImageFolder numbers classes from the folders it finds.
    all_classes = sorted({class_name for _, class_name in all_images})
    for split, images in splits.items():
        split_dir = os.path.join(CONFIG['augmented_data_dir'], split)
        os.makedirs(split_dir, exist_ok=True)
        present = {class_name for _, class_name in images}
        for class_name in present:
            os.makedirs(os.path.join(split_dir, class_name), exist_ok=True)
        missing = [c for c in all_classes if c not in present]
        if missing:
            print(f"Warning: no {split} images for classes {missing}; "
                  f"ImageFolder class indices for this split may not line up with other splits")

    # Copy images to their split directories.
    for split, images in splits.items():
        split_dir = os.path.join(CONFIG['augmented_data_dir'], split)
        for img_path, class_name in tqdm(images, desc=f"Copying {split} images"):
            dest_path = os.path.join(split_dir, class_name, os.path.basename(img_path))
            shutil.copy(img_path, dest_path)

    print(f"\nDataset split complete:")
    print(f"  Training: {len(splits['train'])} images")
    print(f"  Validation: {len(splits['val'])} images")
    print(f"  Test: {len(splits['test'])} images")

    clear_memory()
    return splits

def _make_preprocess():
    """Build the shared preprocessing pipeline.

    The pipeline resizes to a square ``CONFIG['image_size']``, converts to a
    tensor and applies mean/std normalisation with the standard ImageNet
    channel statistics.
    """
    side = CONFIG['image_size']
    return transforms.Compose([
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])


# Data transforms: each split gets its own (identical) pipeline instance.
data_transforms = {split: _make_preprocess() for split in ('train', 'val', 'test')}

def _align_class_indices(datasets_dict, reference='train'):
    """Remap every split's labels onto the ``reference`` split's class indices.

    ImageFolder numbers classes by the sorted sub-directories it finds in its
    own root, so a split missing a class folder would silently shift labels
    relative to the reference split.

    Raises:
        ValueError: If a split contains a class absent from ``reference``.
    """
    ref_ds = datasets_dict[reference]
    ref_map = ref_ds.class_to_idx
    for split, ds in datasets_dict.items():
        if split == reference or ds.class_to_idx == ref_map:
            continue
        unknown = sorted(set(ds.class_to_idx) - set(ref_map))
        if unknown:
            raise ValueError(
                f"Split '{split}' has classes not present in '{reference}': {unknown}")
        idx_to_class = {idx: name for name, idx in ds.class_to_idx.items()}
        ds.samples = [(path, ref_map[idx_to_class[target]]) for path, target in ds.samples]
        ds.imgs = ds.samples
        ds.targets = [target for _, target in ds.samples]
        ds.classes = list(ref_ds.classes)
        ds.class_to_idx = dict(ref_map)


def load_data():
    """Build ImageFolder datasets and DataLoaders for train/val/test.

    Reads ``CONFIG['augmented_data_dir']/<split>`` for each split. Val/test
    labels are re-indexed to match train's class indices (see
    ``_align_class_indices``). Only the train loader shuffles.

    Returns:
        ``(dataloaders, dataset_sizes, class_names)``: dicts keyed by split
        name, plus the train split's class names.

    Raises:
        Any exception from dataset construction is printed and re-raised,
        including ValueError when val/test contain a class absent from train.
    """
    split_names = ['train', 'val', 'test']
    try:
        datasets_dict = {
            split: torchvision.datasets.ImageFolder(
                os.path.join(CONFIG['augmented_data_dir'], split),
                transform=data_transforms[split]
            )
            for split in split_names
        }
        _align_class_indices(datasets_dict)

        dataloaders = {
            split: torch.utils.data.DataLoader(
                datasets_dict[split],
                batch_size=CONFIG['batch_size'],
                shuffle=(split == 'train'),
                num_workers=CONFIG['num_workers'],
                pin_memory=False
            )
            for split in split_names
        }

        dataset_sizes = {split: len(datasets_dict[split]) for split in split_names}
        class_names = datasets_dict['train'].classes

        return dataloaders, dataset_sizes, class_names

    except Exception as e:
        print(f"Error loading data: {str(e)}")
        raise

def create_model(num_classes):
    """Return a ResNet-18 with ``IMAGENET1K_V1`` weights and a fresh classifier head.

    ResNet-18 is chosen over ResNet-50 because it is much smaller. Only the
    final fully-connected layer is replaced, with a ``num_classes``-way linear
    layer.
    """
    backbone = models.resnet18(weights='IMAGENET1K_V1')
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
    return backbone

def _run_phase(model, criterion, optimizer, dataloader, phase, epoch):
    """Run one pass over ``dataloader`` in ``'train'`` or ``'val'`` mode.

    In ``'train'`` mode, gradients are computed and the optimizer steps after
    every batch.

    Returns:
        ``(summed_loss, correct_count)``: each batch's mean loss times its size,
        summed over batches, and the number of correct top-1 predictions.
    """
    is_train = phase == 'train'
    if is_train:
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    running_corrects = 0
    seen = 0

    pbar = tqdm(dataloader, desc=f'{phase} epoch {epoch}')
    for inputs, labels in pbar:
        inputs = inputs.to(device)
        labels = labels.to(device)

        if is_train:
            optimizer.zero_grad()
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()

        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        running_corrects += (preds == labels).sum().item()
        seen += batch_size
        # Average over the samples actually seen; the previous formula assumed
        # every batch had the current batch's size (wrong on the last batch).
        pbar.set_postfix({
            'loss': f'{running_loss / seen:.4f}',
            'acc': f'{running_corrects / seen:.4f}'
        })

        # Dropping references frees the tensors via refcounting; a full
        # gc.collect() per batch only adds overhead.
        del inputs, labels, outputs, loss, preds

    return running_loss, running_corrects


def train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes,
                num_epochs=25, patience=5):
    """Train with per-epoch validation, early stopping and a final test pass.

    Each epoch runs a train phase (then ``scheduler.step()``) and a val phase.
    The weights with the lowest validation loss are checkpointed to a
    temporary file. Training stops after ``patience`` consecutive epochs
    without improvement. The best weights are then restored and evaluated on
    ``dataloaders['test']``.

    Args:
        dataloaders / dataset_sizes: dicts keyed by ``'train'``, ``'val'``,
            ``'test'``; the sizes are used as metric denominators.
        num_epochs: Maximum number of epochs.
        patience: Epochs without val-loss improvement before stopping.

    Returns:
        ``(model, stats)``. ``model`` holds the best weights. ``stats`` holds
        per-epoch ``train_loss``/``train_acc``/``val_loss``/``val_acc`` lists
        and the scalar ``test_loss``/``test_acc``.
    """
    since = time.time()
    stats = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'test_loss': None,
        'test_acc': None
    }

    early_stopping_counter = 0
    best_val_loss = float('inf')

    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')
        torch.save(model.state_dict(), best_model_params_path)

        for epoch in range(num_epochs):
            print(f'\nEpoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            for phase in ['train', 'val']:
                if phase == 'train':
                    clear_memory()  # once per epoch rather than once per batch

                running_loss, running_corrects = _run_phase(
                    model, criterion, optimizer, dataloaders[phase], phase, epoch)

                if phase == 'train':
                    scheduler.step()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects / dataset_sizes[phase]
                stats[f'{phase}_loss'].append(epoch_loss)
                stats[f'{phase}_acc'].append(epoch_acc)

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Early stopping on validation loss.
            val_loss = stats['val_loss'][-1]
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                early_stopping_counter = 0
                torch.save(model.state_dict(), best_model_params_path)
            else:
                early_stopping_counter += 1

            if early_stopping_counter >= patience:
                print(f'Early stopping triggered at epoch {epoch}')
                break

        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        model.load_state_dict(torch.load(best_model_params_path))

        # Evaluate the restored best weights on the held-out test set.
        test_loss, test_acc = evaluate_model(model, criterion, dataloaders['test'], dataset_sizes['test'])
        stats['test_loss'] = test_loss
        stats['test_acc'] = test_acc

    clear_memory()
    return model, stats

def evaluate_model(model, criterion, dataloader, dataset_size):
    """Compute the average loss and top-1 accuracy of ``model`` on ``dataloader``.

    Runs in eval mode without gradients and prints the result.

    Args:
        model: Network to evaluate; inputs are moved to the global ``device``.
        criterion: Loss function returning a batch-mean loss.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        dataset_size: Number of samples in the loader's dataset; used as the
            denominator for both metrics.

    Returns:
        ``(loss, acc)`` as Python floats.

    Raises:
        ValueError: If ``dataset_size`` is not positive.
    """
    if dataset_size <= 0:
        raise ValueError(f"dataset_size must be positive, got {dataset_size}")

    model.eval()
    running_loss = 0.0
    # Plain int accumulator: the old code called .double() on it, which fails
    # with an AttributeError when the loader yields no batches.
    running_corrects = 0

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)

            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()

            # Refcounting frees these; no per-batch gc.collect() needed.
            del inputs, labels, outputs

    clear_memory()

    loss = running_loss / dataset_size
    acc = running_corrects / dataset_size

    print(f'Test Loss: {loss:.4f} Acc: {acc:.4f}')

    return loss, acc

def visualize_model(model, dataloaders, class_names, num_images=6):
    """Plot predictions for up to ``num_images`` images from the first test batch.

    Draws a 2-row grid of subplots into a new matplotlib figure; the caller is
    responsible for ``plt.show()``. Each subplot is titled with the predicted
    class name, and the image is un-normalised with the same mean/std used by
    the data transforms. The model's train/eval mode is restored afterwards,
    even if plotting raises.
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # Round up so an odd num_images still fits the 2-row grid (the old
    # num_images // 2 made the last subplot index out of range).
    ncols = (num_images + 1) // 2

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, _labels in dataloaders['test']:
                outputs = model(inputs.to(device))
                _, preds = torch.max(outputs, 1)
                pred_ids = preds.tolist()
                # Convert the whole batch once: (N, C, H, W) -> (N, H, W, C).
                images = inputs.cpu().numpy().transpose((0, 2, 3, 1))

                plt.figure(figsize=(15, 5))
                for j in range(min(num_images, len(images))):
                    plt.subplot(2, ncols, j + 1)
                    plt.axis('off')
                    plt.title(f'predicted: {class_names[pred_ids[j]]}')
                    plt.imshow(np.clip(std * images[j] + mean, 0, 1))
                plt.tight_layout()
                break  # only the first batch is shown
    finally:
        model.train(mode=was_training)

def main():
    """Run the full pipeline: augment, split, load, train, plot, evaluate and save.

    Dataset preparation (steps 1-2) is skipped only when
    ``CONFIG['augmented_data_dir']`` already contains ``train``, ``val`` and
    ``test`` folders, so a run that stopped before splitting is redone rather
    than reused. Any exception is printed with its traceback instead of being
    propagated.
    """
    # The module-level device is switched to CUDA below when usable; the
    # training/evaluation helpers read this global.
    global device

    try:
        print("Step 1: Creating augmented dataset...")
        aug_dir = CONFIG['augmented_data_dir']
        splits_ready = all(
            os.path.isdir(os.path.join(aug_dir, split)) for split in ('train', 'val', 'test')
        )
        if not splits_ready:
            all_images = create_augmented_dataset()
            print(f"Created {len(all_images)} augmented images")

            print("Step 2: Splitting dataset into train, validation, and test sets...")
            split_dataset(all_images)
        else:
            print("Using existing augmented dataset...")

        print("Step 3: Loading data...")
        dataloaders, dataset_sizes, class_names = load_data()

        print("\nDataset Statistics:")
        print(f"Total training samples: {dataset_sizes['train']}")
        print(f"Total validation samples: {dataset_sizes['val']}")
        print(f"Total test samples: {dataset_sizes['test']}")
        print(f"Number of classes: {len(class_names)}")

        print("\nStep 4: Initializing model...")
        model = create_model(len(class_names))

        # Try GPU, but fall back to CPU if a tiny allocation fails.
        if torch.cuda.is_available():
            try:
                probe = torch.ones(1).cuda()
                del probe
                device = torch.device("cuda:0")
                print(f"Using GPU: {torch.cuda.get_device_name(0)}")
            except RuntimeError:
                print("GPU memory issue detected, falling back to CPU")
                device = torch.device("cpu")

        model = model.to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(
            model.parameters(),
            lr=CONFIG['learning_rate'],
            weight_decay=CONFIG['weight_decay']
        )

        # Simple step LR scheduler: multiply LR by 0.1 every 7 epochs.
        scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

        print("\nStep 5: Starting training...")
        model, stats = train_model(
            model, criterion, optimizer, scheduler,
            dataloaders, dataset_sizes, num_epochs=CONFIG['num_epochs']
        )

        print("\nStep 6: Plotting results...")
        plt.figure(figsize=(12, 4))

        plt.subplot(1, 2, 1)
        plt.plot(stats['train_loss'], label='Train')
        plt.plot(stats['val_loss'], label='Validation')
        plt.title('Loss vs. Epoch')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(stats['train_acc'], label='Train')
        plt.plot(stats['val_acc'], label='Validation')
        plt.title('Accuracy vs. Epoch')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()

        plt.tight_layout()
        plt.savefig('training_results.png')
        plt.show()

        print("\nStep 7: Testing final model on test set...")
        test_loss, test_acc = stats['test_loss'], stats['test_acc']
        print(f"Final Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

        print("\nStep 8: Visualizing model predictions...")
        visualize_model(model, dataloaders, class_names)
        plt.show()

        # Record the number of epochs actually run; early stopping can end
        # training before CONFIG['num_epochs'].
        torch.save({
            'epoch': len(stats['train_loss']),
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'stats': stats,
            'class_names': class_names
        }, 'final_model.pth')

        print("Training pipeline complete!")

    except Exception as e:
        print(f"Error in main: {str(e)}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()