In [3]:
import os
import time
import torch
import random
import torchvision
import numpy as np
import torch.nn as nn
from torchvision import models
from torch.optim.lr_scheduler import StepLR

# Constants
IMGWIDTH = 224  # Smaller input image size
RANDOM_SEED = 888
BATCH_SIZE = 64  # Smaller batch size for faster training
NUM_EPOCHS = 20
SPLITRATIO = [0.80, 0.95, 0.1]
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Set seed for reproducibility
os.environ["PL_GLOBAL_SEED"] = str(RANDOM_SEED)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)

# Data transformations
img_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((IMGWIDTH, IMGWIDTH)),
    torchvision.transforms.RandomHorizontalFlip(p=0.5),  # Data augmentation
    torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Data augmentation
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load dataset
dataset = torchvision.datasets.ImageFolder(root="D:\\Abhishek\\Spring 24\\AML\\Project\\Data\\Final/", transform=img_transform)

# Split dataset
train_idxs = torch.arange(0, int(SPLITRATIO[0] * len(dataset)))
test_idxs = torch.arange(int(SPLITRATIO[0] * len(dataset)), int(SPLITRATIO[1] * len(dataset)))
valid_idxs = torch.arange(int(SPLITRATIO[1] * len(dataset)), len(dataset))

# Create data loaders
train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=4,
    sampler=torch.utils.data.SubsetRandomSampler(train_idxs),
)

test_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=4,
    sampler=torch.utils.data.SubsetRandomSampler(test_idxs))

val_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=4,
    sampler=torch.utils.data.SubsetRandomSampler(valid_idxs))

# Use a smaller model architecture (e.g., ResNet-18)
model = models.resnet18(pretrained=True)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 2)  # Binary classification
model = model.to(DEVICE)

# Use mixed precision training
scaler = torch.cuda.amp.GradScaler()

# Optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Learning rate scheduler
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)  # Reduce LR by 10x every 5 epochs

start_time = time.time()
train_acc_list, valid_acc_list = [], []

def calc_accuracy(model, data_loader, device):
    model.eval()
    correct_preds = 0
    total_preds = 0
    
    with torch.no_grad():
        for features, targets in data_loader:
            features = features.to(device)
            targets = targets.to(device)
            
            logits = model(features)
            _, predictions = torch.max(logits, 1)
            
            total_preds += targets.size(0)
            correct_preds += (predictions == targets).sum().item()
    
    accuracy = (correct_preds / total_preds) * 100
    return accuracy

# Early stopping variables
patience = 5  # Stop training after 5 epochs without validation improvement
best_val_acc = 0.0
epochs_without_improvement = 0

for epoch in range(NUM_EPOCHS):
    model.train()

    for features, targets in train_loader:
        features = features.to(DEVICE)
        targets = targets.to(DEVICE)

        with torch.cuda.amp.autocast():  # Mixed precision training
            logits = model(features)
            loss = criterion(logits, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
        scaler.step(optimizer)
        scaler.update()

    scheduler.step()  # Update learning rate

    model.eval()
    with torch.no_grad():
        train_acc = calc_accuracy(model, train_loader, device=DEVICE)
        valid_acc = calc_accuracy(model, val_loader, device=DEVICE)
        train_acc_list.append(train_acc)  # No need for .item()
        valid_acc_list.append(valid_acc)  # No need for .item()

    elapsed = (time.time() - start_time) / 60
    print(f"Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | Train: {train_acc :.2f}% | Validation: {valid_acc :.2f}% | Time: {elapsed:.2f} min")

    # Early stopping
    if valid_acc > best_val_acc:
        best_val_acc = valid_acc
        epochs_without_improvement = 0
        torch.save(model.state_dict(), 'best_model.pth')  # Save the best model
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"Early stopping after {epoch+1} epochs")
            break

elapsed = (time.time() - start_time) / 60
print(f"Total Training Time: {elapsed:.2f} min")

test_acc = calc_accuracy(model, test_loader, device=DEVICE)
print(f"Test accuracy {test_acc :.2f}%")

Epoch: 001/020 | Train: 60.57% | Validation: 27.21% | Time: 8.24 min
Epoch: 002/020 | Train: 79.75% | Validation: 68.50% | Time: 16.54 min
Epoch: 003/020 | Train: 68.87% | Validation: 41.54% | Time: 24.81 min
Epoch: 004/020 | Train: 60.96% | Validation: 25.86% | Time: 33.16 min
Epoch: 005/020 | Train: 66.86% | Validation: 99.88% | Time: 41.31 min
Epoch: 006/020 | Train: 92.17% | Validation: 97.06% | Time: 49.94 min
Epoch: 007/020 | Train: 94.23% | Validation: 93.87% | Time: 62.04 min
Epoch: 008/020 | Train: 93.65% | Validation: 95.10% | Time: 70.58 min
Epoch: 009/020 | Train: 94.34% | Validation: 97.06% | Time: 79.08 min
Epoch: 010/020 | Train: 95.78% | Validation: 93.75% | Time: 87.52 min
Early stopping after 10 epochs
Total Training Time: 87.52 min
Test accuracy 94.36%
