# Imports

In [120]:
# PyTorch core libraries
import torch                        # Core PyTorch package for tensor operations
import torch.nn as nn               # For building neural network layers
import torch.nn.functional as F
import torch.optim as optim         # Optimizers like SGD, Adam

# torchvision for image-specific utilities and models
from torchvision import models      # Common pretrained models

# Misc
import time                        # To measure training duration

In [121]:
import random
import numpy as np

torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

# Now, everything that uses randomness below will be reproducible

In [137]:
import sys
%reload_ext autoreload
%autoreload 2
# Add the parent directory to sys.path to import 'scripts'
sys.path.append("..")  

# ====== Load Data ======
from scripts.data_preprocessing import train_loader, val_loader, train_dataset

In [123]:
# Sanity check: Print dataset info

print(f"Total samples in dataset: {len(train_dataset)}")
print(f"Batch size: {train_loader.batch_size}")
print(f"Number of training batches: {len(train_loader)}")

Total samples in dataset: 6712
Batch size: 128
Number of training batches: 53


In [124]:
# Step 2: Setup device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [125]:
#Step 3: Load Pretrained Model (Transfer Learning)
from torchvision.models import resnet50, ResNet50_Weights

# Load pretrained ResNet50
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)

# Freeze all layers except layer3, layer4, and fc
for name, param in model.named_parameters():
    if "layer3" in name or "layer4" in name or "fc" in name:
        param.requires_grad = True
    else:
        param.requires_grad = False

# Replace the classifier for binary output
model.fc = nn.Sequential(
    nn.Linear(2048, 512),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(512, 2)  # Binary classification
)

# Move model to device
model = model.to(device)

In [126]:
class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=None, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha  # Set as 0.25 for class imbalance
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # prevents nans when probability is 0
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        if self.alpha is not None:
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            focal_loss *= alpha_t

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

In [127]:
#Step 4: Define Loss Function and Optimizer
criterion = FocalLoss(gamma=2, alpha=0.25)

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4, weight_decay=1e-4)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=2, factor=0.5)

In [128]:
# Step 5: Training Loop + Validation (Clean Version)
num_epochs = 30
best_val_acc = 0.0

epochs_no_improve = 0

overall_start = time.time()  # Start tracking total training time

class EarlyStopping:
    def __init__(self, patience=4, delta=0.001):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0
            
# Initialize EarlyStopping
early_stopping = EarlyStopping(patience=5)

for epoch in range(num_epochs):
    start_time = time.time()  # Start timer
    
    model.train()
    running_loss = 0.0
    train_correct = 0  # To track correct predictions during training
    train_total = 0    # To track total samples during training

    # Training loop
    for images, labels in train_loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Calculate training accuracy
        _, predicted = torch.max(outputs, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()

    avg_train_loss = running_loss / len(train_loader)
    train_acc = train_correct / train_total  # Calculate training accuracy

    # ====== Validation ======
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(images)
            loss = criterion(outputs, labels)  # Compute val loss
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss /= len(val_loader)  # Average val loss
    val_acc = correct / total
    scheduler.step(val_acc)  # Update learning rate based on validation accuracy
    
    # Print current learning rate
    for param_group in optimizer.param_groups:
        print(f"Current LR: {param_group['lr']}")
    
    epoch_time = time.time() - start_time  # End timer

    print(f"Epoch [{epoch+1}/{num_epochs}] - "
          f"Train Loss: {avg_train_loss:.4f} | "
          f"Train Accuracy: {train_acc * 100:.2f}% | "
          f"Val Accuracy: {val_acc * 100:.2f}% | "
          f"Val Loss: {val_loss:.4f} | "
          f"Time: {epoch_time:.2f}s")

    # Save best model by val_acc
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")

    # Check early stopping using val_loss
    early_stopping(val_loss)
    if early_stopping.early_stop:
        print(f"Early stopping triggered at epoch {epoch+1}")
        break
            
overall_end = time.time()
print(f"\n Total training time: {(overall_end - overall_start)/60:.2f} minutes")

Current LR: 0.0001
Epoch [1/30] - Train Loss: 0.0285 | Train Accuracy: 82.97% | Val Accuracy: 88.07% | Val Loss: 0.0353 | Time: 118.71s
Current LR: 0.0001
Epoch [2/30] - Train Loss: 0.0101 | Train Accuracy: 95.96% | Val Accuracy: 95.58% | Val Loss: 0.0135 | Time: 82.29s
Current LR: 0.0001
Epoch [3/30] - Train Loss: 0.0082 | Train Accuracy: 97.12% | Val Accuracy: 97.14% | Val Loss: 0.0109 | Time: 84.15s
Current LR: 0.0001
Epoch [4/30] - Train Loss: 0.0070 | Train Accuracy: 97.35% | Val Accuracy: 96.18% | Val Loss: 0.0139 | Time: 84.15s
Current LR: 0.0001
Epoch [5/30] - Train Loss: 0.0066 | Train Accuracy: 97.54% | Val Accuracy: 89.26% | Val Loss: 0.0226 | Time: 85.81s
Current LR: 5e-05
Epoch [6/30] - Train Loss: 0.0064 | Train Accuracy: 97.51% | Val Accuracy: 94.63% | Val Loss: 0.0121 | Time: 89.24s
Current LR: 5e-05
Epoch [7/30] - Train Loss: 0.0051 | Train Accuracy: 97.72% | Val Accuracy: 96.90% | Val Loss: 0.0096 | Time: 90.14s
Current LR: 5e-05
Epoch [8/30] - Train Loss: 0.0052 | Tr

In [129]:
# Load best model for testing
model.load_state_dict(torch.load("best_model.pth"))

<All keys matched successfully>