In [1]:
# loading dependencies

import torch
import torch.nn as nn
from torchvision import models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import tqdm
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# load resnet50 with IMAGENET weights

model = models.resnet50(weights = models.ResNet50_Weights.IMAGENET1K_V1)

In [3]:
# customize classification head

old_classes = model.fc.in_features
new_classes = 45

model.fc = nn.Linear(old_classes, new_classes)

In [4]:
# freeze all layers except layer 4  and FCL and moving to GPU

for name, param in model.named_parameters():
    if 'layer4' not in name and 'fc' not in name:
        param.requires_grad = False # frozen

model = model.to(device)

In [5]:
# set up optimizers and a LR scheduler

fc_params = list(model.fc.parameters())
layer4_params = list(model.layer4.parameters())

# optimizer with different LRs for FC & Layer4

optimizer = torch.optim.Adam([
    {'params': fc_params, 'lr': 1e-4}, # bigger steps for faster convergence
    {'params': layer4_params, 'lr': 1e-5} # slower steps for rich feature representations
], weight_decay = 1e-4 # Regularization technique - keeps weights small and increases val. acc. on small datasets by reducing risk of overfitting.
)

# LR scheduler

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( # for small datasets val. acc. plateaus fast, when that happens we reduce LR
    optimizer,
    mode = 'min',
    patience = 3,
    factor = 0.1
)

# loss function

criterion = torch.nn.CrossEntropyLoss()

In [6]:
# data augmentations and dataloaders

img_size = 224

train_transforms = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.RandomHorizontalFlip(p = 0.5),
    transforms.RandomVerticalFlip(p = 0.2),
    transforms.RandomRotation(degrees = 15),
    transforms.RandomResizedCrop(img_size, scale = (0.8, 1.0)),
    transforms.ColorJitter(
        brightness = 0.2,
        contrast = 0.2,
        saturation = 0.2,
        hue = 0.02
    ),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

valid_transforms = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

# data dirs

train_dir = 'data/train/'
valid_dir = 'data/val/'

train_ds = datasets.ImageFolder(train_dir, transform = train_transforms)
val_ds = datasets.ImageFolder(valid_dir, transform = valid_transforms)

# dataloaders

train_loader = DataLoader(
    train_ds,
    batch_size = 32,
    shuffle = True,
    num_workers = 4,
    pin_memory = True
)

valid_loader = DataLoader(
    val_ds,
    batch_size = 32,
    shuffle = False,
    num_workers = 4,
    pin_memory = True
)

In [7]:
# training loop

epochs = 10
patience = 5
best_val_loss = np.inf
no_improve = 0
scaler = torch.cuda.amp.GradScaler()

for epoch in range(epochs):

    model.train()
    train_loss = 0.0
    train_correct = 0

    for imgs, labels in tqdm.tqdm(train_loader, desc = f"Train Epoch {epoch+1}/{epochs}"):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad(set_to_none = True)

        with torch.cuda.amp.autocast():
            outputs = model(imgs)
            loss = criterion(outputs, labels)

        # step backward with gradient scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item() * imgs.size(0)
        preds = outputs.argmax(dim = 1)
        train_correct += (preds ==labels).sum().item()

    train_loss /= len(train_loader.dataset)
    train_acc = train_correct / len(train_loader.dataset)

    # validation

    model.eval()
    val_loss = 0.0
    val_correct = 0

    with torch.no_grad():
        for imgs, labels in tqdm.tqdm(valid_loader, desc = 'Valid'):
            imgs, labels = imgs.to(device), labels.to(device)

            with torch.cuda.amp.autocast():
                outputs = model(imgs)
                loss = criterion(outputs, labels)

            val_loss += loss.item() * imgs.size(0)
            preds = outputs.argmax(dim =  1)
            val_correct += (preds == labels).sum().item()

    val_loss /= len(valid_loader.dataset)
    val_acc = val_correct / len(valid_loader.dataset)

    scheduler.step(val_loss)

    print(f"\nEpoch {epoch+1}/{epochs}")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc:.4f}")

    # early stopping check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improve = 0

        torch.save(model.state_dict(), "best_model.pth")
        print("Saved best model checkpoint.")

    else:
        no_improve += 1
        print(f"No improvement for {no_improve}/{patience} epochs.")

        if no_improve >= patience:
            print("\nEarly stopping triggered.")
            break

  scaler = torch.cuda.amp.GradScaler()
Train Epoch 1/10:   0%|          | 0/344 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():
Train Epoch 1/10: 100%|██████████| 344/344 [02:28<00:00,  2.32it/s]
  with torch.cuda.amp.autocast():
Valid: 100%|██████████| 87/87 [00:37<00:00,  2.30it/s]



Epoch 1/10
Train Loss: 2.0768 | Train Acc: 0.6229
Val   Loss: 0.6054 | Val   Acc: 0.8962
Saved best model checkpoint.


Train Epoch 2/10: 100%|██████████| 344/344 [02:27<00:00,  2.33it/s]
Valid: 100%|██████████| 87/87 [00:52<00:00,  1.66it/s]



Epoch 2/10
Train Loss: 0.7508 | Train Acc: 0.8449
Val   Loss: 0.3032 | Val   Acc: 0.9306
Saved best model checkpoint.


Train Epoch 3/10: 100%|██████████| 344/344 [02:28<00:00,  2.32it/s]
Valid: 100%|██████████| 87/87 [00:37<00:00,  2.34it/s]



Epoch 3/10
Train Loss: 0.4990 | Train Acc: 0.8767
Val   Loss: 0.2371 | Val   Acc: 0.9403
Saved best model checkpoint.


Train Epoch 4/10: 100%|██████████| 344/344 [02:27<00:00,  2.34it/s]
Valid: 100%|██████████| 87/87 [00:37<00:00,  2.30it/s]



Epoch 4/10
Train Loss: 0.3966 | Train Acc: 0.8973
Val   Loss: 0.2040 | Val   Acc: 0.9439
Saved best model checkpoint.


Train Epoch 5/10: 100%|██████████| 344/344 [02:27<00:00,  2.32it/s]
Valid: 100%|██████████| 87/87 [00:44<00:00,  1.97it/s]



Epoch 5/10
Train Loss: 0.3367 | Train Acc: 0.9084
Val   Loss: 0.1879 | Val   Acc: 0.9447
Saved best model checkpoint.


Train Epoch 6/10: 100%|██████████| 344/344 [02:33<00:00,  2.25it/s]
Valid: 100%|██████████| 87/87 [00:42<00:00,  2.07it/s]



Epoch 6/10
Train Loss: 0.2885 | Train Acc: 0.9222
Val   Loss: 0.1673 | Val   Acc: 0.9512
Saved best model checkpoint.


Train Epoch 7/10: 100%|██████████| 344/344 [02:27<00:00,  2.33it/s]
Valid: 100%|██████████| 87/87 [00:37<00:00,  2.30it/s]



Epoch 7/10
Train Loss: 0.2484 | Train Acc: 0.9326
Val   Loss: 0.1648 | Val   Acc: 0.9490
Saved best model checkpoint.


Train Epoch 8/10: 100%|██████████| 344/344 [02:27<00:00,  2.33it/s]
Valid: 100%|██████████| 87/87 [00:37<00:00,  2.30it/s]



Epoch 8/10
Train Loss: 0.2223 | Train Acc: 0.9398
Val   Loss: 0.1567 | Val   Acc: 0.9497
Saved best model checkpoint.


Train Epoch 9/10: 100%|██████████| 344/344 [02:27<00:00,  2.33it/s]
Valid: 100%|██████████| 87/87 [00:37<00:00,  2.30it/s]



Epoch 9/10
Train Loss: 0.2055 | Train Acc: 0.9439
Val   Loss: 0.1540 | Val   Acc: 0.9512
Saved best model checkpoint.


Train Epoch 10/10: 100%|██████████| 344/344 [02:27<00:00,  2.34it/s]
Valid: 100%|██████████| 87/87 [00:37<00:00,  2.35it/s]


Epoch 10/10
Train Loss: 0.1825 | Train Acc: 0.9502
Val   Loss: 0.1407 | Val   Acc: 0.9559
Saved best model checkpoint.



