In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

class ResidualUnit(nn.Module):

  def __init__(self, in_channels, out_channels, stride = 1):
    super().__init__()

    DefaultConv2d = partial(
        nn.Conv2d, kernel_size = 3, stride = 1, padding = 1, bias = False)
    self.main_layers = nn.Sequential(
        DefaultConv2d(in_channels, out_channels, stride = stride),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        DefaultConv2d(out_channels, out_channels),
        nn.BatchNorm2d(out_channels)
    )

    if stride > 1:
      self.skip_connection = nn.Sequential(
          DefaultConv2d(in_channels, out_channels, kernel_size = 1,
                        stride = stride, padding = 0),
          nn.BatchNorm2d(out_channels)
      )
    else:
      self.skip_connection = nn.Identity()

  def forward(self, inputs):
    return F.relu(self.main_layers(inputs) + self.skip_connection(inputs))


  # We created Residual Unit which calculates the difference between each residual layers to learn the diff

In [2]:
class Resnet34(nn.Module):
  def __init__(self):
    super().__init__()
    layers = [
        nn.Conv2d(in_channels = 3, out_channels = 64, kernel_size = 7,
                  stride = 2, padding = 3, bias = False),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
    ]

    prev_filter = 64

    for filters in [64] * 3 + [128] * 4 + [256] * 6 + [512] * 3:
      stride = 1 if filters == prev_filter else 2
      layers.append(ResidualUnit(prev_filter, filters, stride))
      prev_filter = filters

    layers += [
        nn.AdaptiveAvgPool2d(output_size = 1),
        nn.Flatten(),
        nn.LazyLinear(10)
    ]

    self.resnet = nn.Sequential(*layers)

  def forward(self, inputs):
    return self.resnet(inputs)

In [3]:
import torchvision
device = "cuda" if torch.cuda.is_available() else "cpu"

weights = torchvision.models.ConvNeXt_Base_Weights.IMAGENET1K_V1
model = torchvision.models.convnext_base(weights=weights).to(device)

Downloading: "https://download.pytorch.org/models/convnext_base-6075fbad.pth" to /root/.cache/torch/hub/checkpoints/convnext_base-6075fbad.pth


100%|██████████| 338M/338M [00:01<00:00, 182MB/s]


In [4]:
Defaultflowers102 = partial(torchvision.datasets.Flowers102, root = 'datasets',
                            transform = weights.transforms(), download = True)

train_set = Defaultflowers102(split = 'train')
valid_set = Defaultflowers102(split = "val")
test_set = Defaultflowers102(split = 'val')

# We train flowers102 dataset using transfer Learning, this dataset has only 10 images per class which is not enough to learn about a class
# We use ConvNext model lower layers and from there we try to learn this dataset

100%|██████████| 345M/345M [00:17<00:00, 19.4MB/s]
100%|██████████| 502/502 [00:00<00:00, 2.14MB/s]
100%|██████████| 15.0k/15.0k [00:00<00:00, 22.8MB/s]


In [5]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_set, batch_size = 32, shuffle = True)
valid_loader = DataLoader(valid_set, batch_size = 32)
test_loader = DataLoader(test_set, batch_size = 32)

In [6]:
[name for name, child in model.named_children()]

['features', 'avgpool', 'classifier']

In [7]:
model.features
# the ConvNext Architecture for the lower layers

Sequential(
  (0): Conv2dNormActivation(
    (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
    (1): LayerNorm2d((128,), eps=1e-06, elementwise_affine=True)
  )
  (1): Sequential(
    (0): CNBlock(
      (block): Sequential(
        (0): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
        (1): Permute()
        (2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (3): Linear(in_features=128, out_features=512, bias=True)
        (4): GELU(approximate='none')
        (5): Linear(in_features=512, out_features=128, bias=True)
        (6): Permute()
      )
      (stochastic_depth): StochasticDepth(p=0.0, mode=row)
    )
    (1): CNBlock(
      (block): Sequential(
        (0): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
        (1): Permute()
        (2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (3): Linear(in_features=128, out_features=512, bias=True)
        (4): GELU(approx

In [8]:
model.classifier

# Here this model was trained on Imagenet which is 1000 class dataset we need to change this part to work with our dataset

Sequential(
  (0): LayerNorm2d((1024,), eps=1e-06, elementwise_affine=True)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=1024, out_features=1000, bias=True)
)

In [9]:
n_classes = 102
model.classifier[2] = nn.Linear(1024, n_classes).to(device)

In [10]:
model.classifier
# Now we changed the output head to match our class

Sequential(
  (0): LayerNorm2d((1024,), eps=1e-06, elementwise_affine=True)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=1024, out_features=102, bias=True)
)

In [11]:
for param in model.parameters():
  param.requires_grad = False

for param in model.classifier.parameters():
  param.requires_grad = True

# We freeze the model parameters except for the classifer we cahnged to learn the image for this dataset
# enabling earlier layers to learn will cause overfitting because of our small dataset

In [12]:
import torchvision.transforms.v2 as T

transforms = T.Compose([
    T.RandomHorizontalFlip(p=0.5),
    T.RandomRotation(degrees = 30),
    T.RandomResizedCrop(size = (224,224), scale = (0.8, 1.0)),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue = 0.1),
    T.ToImage(),
    T.ToDtype(torch.float32, scale = True),
    T.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
])

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.classifier.parameters(), lr=1e-3)

# Reduce LR if validation loss stops improving
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=2
)

In [14]:
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):

        if self.best_loss is None:
            self.best_loss = val_loss

        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1

            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0


In [15]:
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    running_loss = 0
    correct = 0
    total = 0

    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    return running_loss / total, correct / total

In [16]:
def validate(model, loader, criterion, device):
    model.eval()
    running_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    return running_loss / total, correct / total

In [17]:
num_epochs = 30
early_stopping = EarlyStopping(patience=5)

best_val_loss = float("inf")

for epoch in range(num_epochs):

    train_loss, train_acc = train_one_epoch(
        model, train_loader, optimizer, criterion, device
    )

    val_loss, val_acc = validate(
        model, valid_loader, criterion, device
    )

    # Step scheduler using validation loss
    scheduler.step(val_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc:.4f}")
    print("-" * 50)

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_model.pth")

    # Early stopping check
    early_stopping(val_loss)
    if early_stopping.early_stop:
        print("Early stopping triggered.")
        break

Epoch [1/30]
Train Loss: 4.3064 | Train Acc: 0.1422
Val   Loss: 3.4932 | Val   Acc: 0.5647
--------------------------------------------------
Epoch [2/30]
Train Loss: 2.9995 | Train Acc: 0.7157
Val   Loss: 2.5210 | Val   Acc: 0.7373
--------------------------------------------------
Epoch [3/30]
Train Loss: 2.0594 | Train Acc: 0.8333
Val   Loss: 1.7991 | Val   Acc: 0.8235
--------------------------------------------------
Epoch [4/30]
Train Loss: 1.3858 | Train Acc: 0.9167
Val   Loss: 1.3369 | Val   Acc: 0.8608
--------------------------------------------------
Epoch [5/30]
Train Loss: 0.9548 | Train Acc: 0.9461
Val   Loss: 1.0694 | Val   Acc: 0.8647
--------------------------------------------------
Epoch [6/30]
Train Loss: 0.6896 | Train Acc: 0.9520
Val   Loss: 0.8837 | Val   Acc: 0.8784
--------------------------------------------------
Epoch [7/30]
Train Loss: 0.5072 | Train Acc: 0.9667
Val   Loss: 0.7654 | Val   Acc: 0.8833
--------------------------------------------------
Epoch 