In [None]:
%pip install torch torchvision matplotlib numpy
%pip install torchvision  

import torch
import torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torch.optim import lr_scheduler
import time
import copy
import os

In [None]:
import os
import torch
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from torch import nn, optim
from torch.optim import lr_scheduler
from tqdm import tqdm
import time
import copy
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns

# --- Filesystem layout ------------------------------------------------------
# Training/test images live under DATASET_PATH/train and DATASET_PATH/test;
# weights, checkpoints and plots are written to MODEL_SAVE_PATH.
DATASET_PATH = '../dataset/raw'
TRAIN_PATH, TEST_PATH = (os.path.join(DATASET_PATH, split) for split in ('train', 'test'))
MODEL_SAVE_PATH = '../data/models'
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)  # no-op when it already exists

# --- Compute device ---------------------------------------------------------
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# --- Hyperparameters --------------------------------------------------------
IMAGE_SIZE = 224         # images are resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 32
NUM_EPOCHS = 25
LEARNING_RATE = 0.001    # base learning rate for the optimizer
WEIGHT_DECAY = 1e-4      # weight-decay regularization strength
FEATURE_EXTRACT = False  # True: freeze pretrained layers; False: fine-tune all layers

# 1. Per-channel dataset statistics used for input normalization.
def calculate_dataset_stats(data_loader, show_progress=True):
    """Return the exact per-channel mean and standard deviation of a dataset.

    Statistics are pooled over every pixel of every image yielded by
    ``data_loader``. The previous implementation averaged per-image standard
    deviations, which ignores variation *between* images and therefore
    underestimates the true dataset standard deviation.

    Args:
        data_loader: iterable of ``(images, labels)`` batches where ``images``
            is a 4-D tensor laid out as (batch, channels, height, width).
        show_progress: wrap the iteration in a tqdm progress bar (default True).

    Returns:
        ``(mean, std)`` as float32 tensors of shape ``(channels,)``; ``std`` is
        the population standard deviation (divides by the pixel count).

    Raises:
        ValueError: if ``data_loader`` yields no pixels at all.
    """
    batches = (
        tqdm(data_loader, desc="Calculating dataset statistics")
        if show_progress
        else data_loader
    )

    channel_sum = None
    channel_sq_sum = None
    pixels_per_channel = 0

    for images, _ in batches:
        # (B, C, H, W) -> (C, B*H*W); float64 keeps the sum of squares accurate
        # over many large images.
        flat = images.transpose(0, 1).reshape(images.size(1), -1).double()
        if channel_sum is None:
            channel_sum = torch.zeros(flat.size(0), dtype=torch.float64)
            channel_sq_sum = torch.zeros_like(channel_sum)
        channel_sum += flat.sum(dim=1)
        channel_sq_sum += (flat * flat).sum(dim=1)
        pixels_per_channel += flat.size(1)

    if pixels_per_channel == 0:
        raise ValueError("data_loader yielded no images; cannot compute statistics")

    mean = channel_sum / pixels_per_channel
    # Var = E[x^2] - E[x]^2; clamp guards against tiny negative rounding error.
    variance = (channel_sq_sum / pixels_per_channel - mean * mean).clamp_min(0)
    return mean.float(), variance.sqrt().float()

# Statistics are measured on resized, un-augmented, un-normalized tensors.
simple_transforms = transforms.Compose(
    [transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), transforms.ToTensor()]
)

initial_dataset = datasets.ImageFolder(root=TRAIN_PATH, transform=simple_transforms)
initial_loader = DataLoader(
    initial_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
)

# Best effort: any failure while scanning the images falls back to the
# ImageNet per-channel statistics.
try:
    print("Calculating dataset statistics...")
    mean, std = calculate_dataset_stats(initial_loader)
    print(f"Dataset mean: {mean}, std: {std}")
except Exception as exc:
    print(f"Error calculating statistics: {exc}. Using default values.")
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])

# 2. Training pipeline: resize, random flips / rotation / colour jitter /
#    small translate-and-scale, then tensor conversion and normalization
#    with the per-channel mean/std chosen above.
train_transforms = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=20),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

# Evaluation pipeline: deterministic resize + normalize, no augmentation.
test_transforms = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

# 3. Datasets and loaders. Training-set class imbalance is handled by
#    oversampling: each training sample is drawn with probability
#    proportional to 1 / (number of training images in its class).
train_dataset = datasets.ImageFolder(root=TRAIN_PATH, transform=train_transforms)
test_dataset = datasets.ImageFolder(root=TEST_PATH, transform=test_transforms)

# Number of training images per class index, as a plain list of ints.
class_counts = torch.bincount(
    torch.tensor(train_dataset.targets, dtype=torch.long),
    minlength=len(train_dataset.classes),
).tolist()

# Inverse-frequency weight per class, looked up for every training sample.
class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)
sample_weights = [class_weights[target] for target in train_dataset.targets]
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

# The sampler replaces shuffle=True for training (the two are mutually exclusive).
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=2,
    pin_memory=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

class_names = train_dataset.classes
print(f"Classes: {class_names}")

# 4. Optionally freeze a model's parameters for feature extraction.
def set_parameter_requires_grad(model, feature_extracting):
    """Disable gradients for every parameter of *model* when *feature_extracting* is truthy.

    When *feature_extracting* is falsy the model is left untouched. Returns None.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False

# 5. This initiates the model (ResNet50 for better performance)
def initialize_model(num_classes, feature_extract=False):
    """Return an ImageNet-pretrained ResNet50 with a new classification head.

    Args:
        num_classes: number of logits produced by the new head.
        feature_extract: if True, freeze every pretrained parameter. The new
            head is created after freezing, so it stays trainable.

    Returns:
        The torchvision ResNet50 whose ``fc`` is ``Dropout(0.3) -> Linear``.
    """
    # ``weights=`` replaces the deprecated ``pretrained=True`` flag
    # (torchvision >= 0.13). IMAGENET1K_V1 is the checkpoint that
    # ``pretrained=True`` loaded.
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    set_parameter_requires_grad(model, feature_extract)

    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(num_ftrs, num_classes),
    )
    return model

# Initialize model
model = initialize_model(len(class_names), feature_extract=FEATURE_EXTRACT)
model = model.to(device)

# 6. Loss function.
# The training loader already rebalances classes: WeightedRandomSampler draws
# each sample with weight 1 / class_count, so batches are approximately
# class-balanced. Weighting the loss by the same inverse frequencies as well
# would correct the imbalance twice and bias the model towards minority
# classes, so the loss is left unweighted. If the sampler is ever removed,
# use nn.CrossEntropyLoss(weight=class_weights_tensor) instead.
class_weights_tensor = class_weights.to(device)
criterion = nn.CrossEntropyLoss()

# 7. Optimizer and scheduler.
if FEATURE_EXTRACT:
    # Backbone is frozen: optimize only parameters that still require grad.
    params_to_update = [p for p in model.parameters() if p.requires_grad]
else:
    # Discriminative learning rates: full LR for the new head, a tenth of it
    # for the pretrained backbone. Head parameters are identified by object
    # identity rather than a fragile "'fc' in name" substring test.
    head_params = list(model.fc.parameters())
    head_ids = {id(p) for p in head_params}
    backbone_params = [p for p in model.parameters() if id(p) not in head_ids]
    params_to_update = [
        {'params': head_params, 'lr': LEARNING_RATE},
        {'params': backbone_params, 'lr': LEARNING_RATE / 10},
    ]

optimizer = optim.AdamW(params_to_update, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Learning rate scheduler - halve the LR when validation loss plateaus for 3 epochs
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

# 8. Training loop with per-epoch validation, best-model tracking and checkpoints.
def _run_phase(model, dataloader, criterion, optimizer, phase):
    """Run one pass over *dataloader* in 'train' or evaluation mode.

    In the 'train' phase, gradients are computed, clipped to norm 1.0 and
    applied. Any other phase runs without gradients.

    Returns:
        ``(loss, accuracy)`` as Python floats, averaged over the samples
        actually yielded by the loader.

    Raises:
        ValueError: if the loader yields no samples.
    """
    is_train = phase == 'train'
    if is_train:
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    running_corrects = 0
    num_samples = 0

    loop = tqdm(dataloader, desc=f"{phase.capitalize()} Progress")
    for inputs, labels in loop:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            if is_train:
                loss.backward()
                # Gradient clipping to prevent exploding gradients
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

        batch_size = inputs.size(0)
        batch_loss = loss.item()
        running_loss += batch_loss * batch_size
        running_corrects += (preds == labels).sum().item()
        # Count what was actually seen instead of len(dataset), so metrics
        # stay correct if the sampler's num_samples or drop_last change.
        num_samples += batch_size
        loop.set_postfix(loss=batch_loss)

    if num_samples == 0:
        raise ValueError(f"{phase} dataloader yielded no samples")
    return running_loss / num_samples, running_corrects / num_samples


def _plot_training_history(history):
    """Plot train/val loss and accuracy curves, save and display the figure."""
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.legend()
    plt.title('Loss over epochs')

    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train Accuracy')
    plt.plot(history['val_acc'], label='Validation Accuracy')
    plt.legend()
    plt.title('Accuracy over epochs')

    plt.tight_layout()
    plt.savefig(os.path.join(MODEL_SAVE_PATH, 'training_history.png'))
    plt.show()


def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=25):
    """Train *model* and return it loaded with its best validation weights.

    Args:
        model: the network to train (already on ``device``).
        dataloaders: dict with ``'train'`` and ``'val'`` loaders.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepping the model parameters.
        scheduler: stepped once per epoch with the validation loss
            (e.g. ``ReduceLROnPlateau``).
        num_epochs: number of train+val epochs.

    Side effects:
        Writes ``best_model.pth`` whenever validation accuracy improves, a
        checkpoint every 5 epochs, and ``training_history.png``, all under
        ``MODEL_SAVE_PATH``.

    Returns:
        ``(model, history)``. ``history`` maps ``train_loss``/``train_acc``/
        ``val_loss``/``val_acc`` to per-epoch lists of floats.
    """
    start_time = time.time()

    # Track best model weights
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        for phase in ['train', 'val']:
            epoch_loss, epoch_acc = _run_phase(
                model, dataloaders[phase], criterion, optimizer, phase
            )
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc)

            if phase == 'val':
                # The LR scheduler is driven by validation loss.
                scheduler.step(epoch_loss)
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    torch.save(model.state_dict(), os.path.join(MODEL_SAVE_PATH, 'best_model.pth'))
                    print(f"Saved new best model with accuracy: {best_acc:.4f}")

        # Save checkpoint every 5 epochs for long training
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_acc': best_acc,  # plain float, not a device tensor
            }, os.path.join(MODEL_SAVE_PATH, f'checkpoint_epoch_{epoch+1}.pth'))

    time_elapsed = time.time() - start_time
    print(f'\nTraining completed in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best validation accuracy: {best_acc:.4f}')

    _plot_training_history(history)

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model, history

# 9. Train the model.
# NOTE(review): test_loader doubles as the validation loader here. Best-model
# selection and LR scheduling therefore both see the test data, which makes
# any later evaluation on test_loader optimistically biased. Consider carving
# a separate validation split out of the training data.
dataloaders = dict(train=train_loader, val=test_loader)

model, history = train_model(
    model, dataloaders, criterion, optimizer, scheduler, num_epochs=NUM_EPOCHS
)

# 10. Persist the returned (best-validation) weights together with what is
#     needed to reuse them: class order, normalization stats, training curves.
final_model_path = os.path.join(MODEL_SAVE_PATH, 'final_model.pth')
final_payload = {
    'model_state_dict': model.state_dict(),
    'class_names': class_names,
    'mean': mean,
    'std': std,
    'history': history,
}
torch.save(final_payload, final_model_path)

print("Final model saved to:", final_model_path)

# 11. Perform validation and create confusion matrix
def validate_model(model, dataloader):
    """Run *model* over *dataloader* and plot/save a confusion matrix.

    Inference runs on the module-level ``device``. Axis labels come from the
    module-level ``class_names``. The figure is saved to
    ``MODEL_SAVE_PATH/final_confusion_matrix.png`` and displayed.

    Returns:
        ``(all_preds, all_labels)``: lists of predicted and true class indices,
        in dataloader order.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Validating"):
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            # Labels are only needed on the CPU; no device round-trip.
            all_labels.extend(labels.cpu().numpy())

    # Pin rows/columns to every class index so the matrix always lines up with
    # class_names. Without labels=, sklearn keeps only classes that occur in
    # the labels or predictions, which shrinks the matrix and misaligns the
    # tick labels whenever a class is never seen.
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))
    plt.figure(figsize=(15, 13))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(os.path.join(MODEL_SAVE_PATH, 'final_confusion_matrix.png'))
    plt.show()

    return all_preds, all_labels

# Final evaluation pass over test_loader; the confusion matrix is plotted and saved.
print("\nValidating model on test set...")
predictions, true_labels = validate_model(
    model,
    test_loader,
)