In [1]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
from torchvision.transforms import InterpolationMode

# Define Transforms 
# These numbers are the Mean and Std of the ImageNet dataset
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224), interpolation=InterpolationMode.BILINEAR),
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        normalize 
    ]),
    'val': transforms.Compose([
        transforms.Resize((224, 224), interpolation=InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        normalize
    ]),
}

data_dir = './final_split_dataset'

# Create Datasets
TRAIN_FOLDER_NAME = 'train' 
VAL_FOLDER_NAME = 'validate'      

image_datasets = {
    'train': datasets.ImageFolder(os.path.join(data_dir, TRAIN_FOLDER_NAME), data_transforms['train']),
    'val':   datasets.ImageFolder(os.path.join(data_dir, VAL_FOLDER_NAME), data_transforms['val'])
}

# Create DataLoaders
dataloaders = {
    'train': DataLoader(image_datasets['train'], batch_size=32, shuffle=True, num_workers=4),
    'val':   DataLoader(image_datasets['val'], batch_size=32, shuffle=False, num_workers=4)
}

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

# Load out Backbone here I have ResNet
weights = models.ResNet152_Weights.DEFAULT
model = models.resnet152(weights=weights)

# Freeze the "Body"
# We only want to train your new MLP, not the millions of parameters in the ResNet.
for param in model.parameters():
    param.requires_grad = False

# Construct the 3-Layer MLP Head with GELU
num_ftrs = model.fc.in_features 



model.fc = nn.Sequential(
    # Layer 1: Reduction
    nn.Linear(num_ftrs,4)
#    nn.Linear(num_ftrs, 1024),
#    nn.GELU(),                # You requested GELU instead of ReLU
#    nn.Dropout(0.3),          # Heavier dropout for a larger model
#    
#    # Layer 2: Reduction
#    nn.Linear(1024, 512),
#    nn.GELU(),
#    nn.Dropout(0.3),
#    
#    # Layer 3: Classification
#    nn.Linear(512, 4)         # Output = 4 categories
)

In [6]:
# Move to GPU (Must happen before defining optimizer)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Loss Function
criterion = nn.CrossEntropyLoss()

# Optimizer (AdamW)
optimizer = optim.AdamW(model.fc.parameters(), lr=0.001, weight_decay=0.01)

In [7]:
num_epochs = 10

for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    print('-' * 10)

    # --- TRAIN PHASE ---
    model.train()  # CRITICAL: Switches layers like Dropout to "Train" mode
    
    running_loss = 0.0
    
    # Iterate over the DataLoader
    for inputs, labels in dataloaders['train']:
        # Move data to the same device as the model
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero the parameter gradients
        # PyTorch accumulates gradients by default. We must reset them.
        optimizer.zero_grad()

        # Forward Pass
        # We need to track gradients here (enabled by default in train mode)
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward Pass & Optimize
        loss.backward()  # Calculates the gradients (dLoss/dx)
        optimizer.step() # Updates weights: w = w - lr * gradient

        # Statistics
        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(image_datasets['train'])
    print(f'Train Loss: {epoch_loss:.4f}')

    
    model.eval() 
    
    val_running_loss = 0.0 
    val_corrects = 0      
    
    with torch.no_grad(): # Disable Gradient Calculation
        for inputs, labels in dataloaders['val']:
            inputs = inputs.to(device)
            labels = labels.to(device)

        
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            
            
            # Statistics
            val_running_loss += loss.item() * inputs.size(0)
            
            # Calculate Accuracy
            # torch.max returns (max_value, index_of_max_value)
            _, preds = torch.max(outputs, 1) 
            val_corrects += torch.sum(preds == labels.data)

    val_loss = val_running_loss / len(image_datasets['val'])
    val_acc = val_corrects.double() / len(image_datasets['val'])
    
    print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

Epoch 1/10
----------
Train Loss: 0.8559
Val Loss: 0.6615 Acc: 0.7673
Epoch 2/10
----------
Train Loss: 0.5863
Val Loss: 0.5505 Acc: 0.8139
Epoch 3/10
----------
Train Loss: 0.5253
Val Loss: 0.5204 Acc: 0.8139
Epoch 4/10
----------
Train Loss: 0.4551
Val Loss: 0.4982 Acc: 0.8257
Epoch 5/10
----------
Train Loss: 0.4369
Val Loss: 0.4704 Acc: 0.8321
Epoch 6/10
----------
Train Loss: 0.4188
Val Loss: 0.4445 Acc: 0.8339
Epoch 7/10
----------
Train Loss: 0.3923
Val Loss: 0.4449 Acc: 0.8422
Epoch 8/10
----------
Train Loss: 0.3901
Val Loss: 0.4464 Acc: 0.8312
Epoch 9/10
----------
Train Loss: 0.3677
Val Loss: 0.4320 Acc: 0.8485
Epoch 10/10
----------
Train Loss: 0.3589
Val Loss: 0.4209 Acc: 0.8449
