In [None]:
import os
import time
import copy
from tqdm import tqdm
from pathlib import Path
import pickle
import random

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.amp import autocast, GradScaler

import torchvision
from torchvision import datasets, models, transforms

from torch.utils.data import Dataset, DataLoader, random_split

In [None]:
# #load files
if not os.path.exists("/content/drive"):
  from google.colab import drive
  drive.mount('/content/drive')
if not os.path.exists("/content/AppliedMLProject"):
  !git clone https://YasinKaryagdi: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@github.com/YasinKaryagdi/AppliedMLProject.git
if not os.path.exists("/content/augmented_set.zip"):
  !cp -r /content/drive/MyDrive/Machinelearning_files/augmented_set.zip /content/
  !unzip augmented_set.zip
if not os.path.exists("/content/validate_split.csv"):
  !cp -r /content/drive/MyDrive/Machinelearning_files/validate_split.csv /content/
  !cp -r /content/drive/MyDrive/Machinelearning_files/train_augmented.csv /content/
  !cp -r /content/drive/MyDrive/Machinelearning_files/train_split.csv /content/
  !cp -r /content/drive/MyDrive/Machinelearning_files/train_balanced.csv /content/

In [None]:
# All locations are resolved relative to the current working directory.
cwd = Path.cwd()
# Cloned project repository and the dataset folder inside it.
gitpath = cwd / "AppliedMLProject"
dirpath = gitpath / "aml-2025-feathers-in-focus"
train_images_csv = dirpath / "train_images.csv"
train_images_folder = dirpath / "train_images"
image_classes = dirpath / "class_names.npy"
# Drive folder holding the prepared CSV/zip files. The folder name now matches
# the one used by the copy commands ("Machinelearning_files", underscore); the
# previous value contained a space and pointed at a different directory.
drive_path = cwd / "drive" / "MyDrive" / "Machinelearning_files"
# Local copies of the split files.
val_images_csv = cwd / "validate_split.csv"
train_balanced_csv = cwd / "train_balanced.csv"


In [None]:
# Model and training configuration: every tunable of the notebook lives here.

# --- Data selection ---------------------------------------------------------
# Train on the augmented set, and if so on its balanced variant?
use_augmented = True
# Always defined: previously these two names only existed when use_augmented
# was True, so building model_save_name raised NameError for the plain run.
use_balanced = True
augmentations = []
if not use_augmented:
  use_balanced = False  # the balanced CSV is only used in the augmented branch
# Fraction of train_images.csv used for training when use_augmented is False.
# NOTE(review): this name was referenced by the dataset cell but never
# defined; 0.8 is an assumed value - confirm against the intended split.
split = 0.8

# --- Model ------------------------------------------------------------------
model_name = "MODERNRESDEEP" # <- modelname goes here
#possible models: "M3MAX", "SIMPLE1", "CLASSIC1", "CLASSICRES", "MODERNRES", "MODERNRESDEEP"
# Use model-specific transformations instead of the standard ones?
use_model_transforms = False
# Use a gradient scaler (mixed-precision training)?
use_scaler = True
# Use pretrained weights or not
use_pretrained = True

# --- Early stopping ---------------------------------------------------------
early_stopping = True
patience = 5
min_delta = 0

# --- Batching / schedule ----------------------------------------------------
# Training batch size
train_batch_size = 16
# Validation & testing batch size
val_batch_size = 32
num_epochs = 15

# --- Optimizer --------------------------------------------------------------
learning_rate = 0.001
moment = 0.9      # momentum
wd = 0.001        # weight decay

# --- Input ------------------------------------------------------------------
# Images are resized to this (height, width).
size = (256,256)
classes = np.load(image_classes, allow_pickle=True).item()
num_classes = len(classes)

# --- Bookkeeping ------------------------------------------------------------
# Save name such as "MODERNRESDEEP_aug_bal" or "MODERNRESDEEP_noaug". Parts are
# joined with a single underscore (the old concatenation produced "__aug").
model_save_name = "_".join(
    part
    for part in (model_name,
                 "aug" if use_augmented else "noaug",
                 "bal" if use_balanced else "")
    if part
)

# --- Reproducibility --------------------------------------------------------
use_seed = True
seed = 42
SEEDS = [42]

Test model classes go here

In [None]:
class ModelM3MAX(nn.Module):
    """Plain 10-layer CNN: conv-BN-ReLU layers with a 2x max-pool after every
    second convolution, followed by a bias-free linear layer and batch norm.

    Args:
        num_classes: Number of output logits. Defaults to 200, the value that
            used to be hard-coded, so ``ModelM3MAX()`` builds the same network.

    Note:
        ``fc1`` expects 176 * 8 * 8 = 11264 input features, which is what a
        256x256 image yields after the five poolings.
    """

    def __init__(self, num_classes=200):
        super(ModelM3MAX, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, 3, padding=1, stride = 1, bias=False)
        self.conv1_bn = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 48, 3, padding=1, bias=False)
        self.conv2_bn = nn.BatchNorm2d(48)

        self.conv3 = nn.Conv2d(48, 64, 3, padding=1, bias=False)
        self.conv3_bn = nn.BatchNorm2d(64)

        self.conv4 = nn.Conv2d(64, 80, 3, padding=1, bias=False)
        self.conv4_bn = nn.BatchNorm2d(80)

        self.conv5 = nn.Conv2d(80, 96, 3, padding=1, bias=False)
        self.conv5_bn = nn.BatchNorm2d(96)

        self.conv6 = nn.Conv2d(96, 112, 3, padding=1, bias=False)
        self.conv6_bn = nn.BatchNorm2d(112)

        self.conv7 = nn.Conv2d(112, 128, 3, padding=1, bias=False)
        self.conv7_bn = nn.BatchNorm2d(128)

        self.conv8 = nn.Conv2d(128, 144, 3, padding=1, bias=False)
        self.conv8_bn = nn.BatchNorm2d(144)

        self.conv9 = nn.Conv2d(144, 160, 3, padding=1, bias=False)
        self.conv9_bn = nn.BatchNorm2d(160)

        self.conv10 = nn.Conv2d(160, 176, 3, padding=1, bias=False)
        self.conv10_bn = nn.BatchNorm2d(176)

        # 176 channels * 8 * 8 spatial positions = 11264 features
        self.fc1 = nn.Linear(11264, num_classes, bias=False)
        self.fc1_bn = nn.BatchNorm1d(num_classes)

    def get_logits(self, x):
        """Run the network on ``x`` of shape (batch, 3, H, W) and return the
        batch-normalised logits of shape (batch, num_classes)."""
        # Maps [0, 1] to [-1, 1]; presumably inputs are expected in [0, 1].
        x = (x - 0.5) * 2.0

        conv1 = F.relu(self.conv1_bn(self.conv1(x)))
        conv2 = F.relu(self.conv2_bn(self.conv2(conv1)))
        conv2 = F.max_pool2d(conv2, 2)  # 256 -> 128

        conv3 = F.relu(self.conv3_bn(self.conv3(conv2)))
        conv4 = F.relu(self.conv4_bn(self.conv4(conv3)))
        conv4 = F.max_pool2d(conv4, 2)  # 128 -> 64

        conv5 = F.relu(self.conv5_bn(self.conv5(conv4)))
        conv6 = F.relu(self.conv6_bn(self.conv6(conv5)))
        conv6 = F.max_pool2d(conv6, 2)  # 64 -> 32

        conv7 = F.relu(self.conv7_bn(self.conv7(conv6)))
        conv8 = F.relu(self.conv8_bn(self.conv8(conv7)))
        conv8 = F.max_pool2d(conv8, 2)  # 32 -> 16

        conv9 = F.relu(self.conv9_bn(self.conv9(conv8)))
        conv10 = F.relu(self.conv10_bn(self.conv10(conv9)))
        conv10 = F.max_pool2d(conv10, 2)  # 16 -> 8

        # conv10 is (batch, 176, 8, 8) for a 256x256 input; channels are moved
        # last before flattening, so features are ordered (row, col, channel).
        flat = torch.flatten(conv10.permute(0, 2, 3, 1), 1)
        logits = self.fc1_bn(self.fc1(flat))
        return logits

    def forward(self, x):
        """Return the logits for ``x`` (thin wrapper around ``get_logits``)."""
        logits = self.get_logits(x)
        return logits

In [None]:
class SIMPLE1(nn.Module):
    """Five conv-BN-ReLU layers, each followed by a 2x max-pool, then a
    bias-free linear classifier with batch norm on the logits.

    Args:
        num_classes: Number of output logits. Defaults to 200, the value that
            used to be hard-coded, so ``SIMPLE1()`` builds the same network.

    Note:
        ``fc1`` expects 176 * 8 * 8 = 11264 input features, which is what a
        256x256 image yields after the five poolings.
    """

    def __init__(self, num_classes=200):
        super(SIMPLE1, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, 3, padding=1, stride = 1, bias=False)
        self.conv1_bn = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, 3, padding=1, bias=False)
        self.conv2_bn = nn.BatchNorm2d(64)

        self.conv3 = nn.Conv2d(64, 128, 3, padding=1, bias=False)
        self.conv3_bn = nn.BatchNorm2d(128)

        self.conv4 = nn.Conv2d(128, 164, 3, padding=1, bias=False)
        self.conv4_bn = nn.BatchNorm2d(164)

        self.conv5 = nn.Conv2d(164, 176, 3, padding=1, bias=False)
        self.conv5_bn = nn.BatchNorm2d(176)

        # 176 channels * 8 * 8 spatial positions = 11264 features
        self.fc1 = nn.Linear(11264, num_classes, bias=False)
        self.fc1_bn = nn.BatchNorm1d(num_classes)

    def get_logits(self, x):
        """Run the network on ``x`` of shape (batch, 3, H, W) and return the
        batch-normalised logits of shape (batch, num_classes)."""
        # Maps [0, 1] to [-1, 1]; presumably inputs are expected in [0, 1].
        x = (x - 0.5) * 2.0

        # conv 1
        x = F.relu(self.conv1_bn(self.conv1(x)))
        x = F.max_pool2d(x, 2)   # 256 -> 128

        # conv 2
        x = F.relu(self.conv2_bn(self.conv2(x)))
        x = F.max_pool2d(x, 2)   # 128 -> 64

        # conv 3
        x = F.relu(self.conv3_bn(self.conv3(x)))
        x = F.max_pool2d(x, 2)   # 64 -> 32

        # conv 4
        x = F.relu(self.conv4_bn(self.conv4(x)))
        x = F.max_pool2d(x, 2)   # 32 -> 16

        # conv 5
        x = F.relu(self.conv5_bn(self.conv5(x)))
        x = F.max_pool2d(x, 2)   # 16 -> 8

        # x is (batch, 176, 8, 8) for a 256x256 input
        x = torch.flatten(x, 1)  # (batch, 11264)

        logits = self.fc1_bn(self.fc1(x))
        return logits

    def forward(self, x):
        """Return the logits for ``x`` (thin wrapper around ``get_logits``)."""
        logits = self.get_logits(x)
        return logits

In [None]:
class CLASSIC1(nn.Module):
    """VGG-style CNN: five double-conv stages (each ending in a 2x max-pool),
    global average pooling, dropout and a linear classifier.

    Args:
        num_classes: Number of output logits. Defaults to 200, the value that
            used to be hard-coded, so ``CLASSIC1()`` builds the same network.
    """

    def __init__(self, num_classes=200):
        super(CLASSIC1, self).__init__()

        # 5 stages with double conv + pooling
        self.stage1 = self.conv_block(3, 32)
        self.stage2 = self.conv_block(32, 64)
        self.stage3 = self.conv_block(64, 128)
        self.stage4 = self.conv_block(128, 256) # widths kept at powers of 2 (256, 512)
        self.stage5 = self.conv_block(256, 512)

        self.gap = nn.AdaptiveAvgPool2d(1)   # global average pooling: fewer parameters, better generalization
        self.dropout = nn.Dropout(p=0.4)     # regularization; moderate rate to avoid over-regularizing
        self.fc1 = nn.Linear(512, num_classes)

    @staticmethod
    def conv_block(in_ch, out_ch):
        """Build one stage: Conv -> BN -> ReLU -> Conv -> BN -> ReLU -> MaxPool(2)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),

            nn.MaxPool2d(2)
        )

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for ``x`` of shape (batch, 3, H, W)."""
        # Maps [0, 1] to [-1, 1]; presumably inputs are expected in [0, 1].
        x = (x - 0.5) * 2.0

        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        # Five poolings divide each spatial side by 32: a 256x256 input is now
        # (batch, 512, 8, 8). The global average pool below makes the
        # classifier independent of that spatial size.

        x = self.gap(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        logits = self.fc1(x)
        return logits

In [None]:
class ClassicResidualBlock(nn.Module):
    """Residual stage used by CLASSICRES: two 3x3 conv-BN layers with a skip
    connection, ReLU activations and a 2x max-pool at the end.

    The block has its own name because a later cell defines a different
    ``ResidualBlock`` (SiLU, optional pooling). Class names are looked up when
    a model is instantiated, so with a shared name CLASSICRES would silently be
    built from whichever definition ran last.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)

        # shortcut: identity if channels match, otherwise 1x1 conv
        self.shortcut = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1, bias=False)

        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return the block output; spatial size is halved by the final pool."""
        identity = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        out = self.relu(out)
        out = self.pool(out)
        return out


# Backward-compatible alias for code that still refers to the old name.
ResidualBlock = ClassicResidualBlock


class CLASSICRES(nn.Module):
    """Five pooling residual stages, global average pooling, dropout and a
    linear classifier.

    Args:
        num_classes: Number of output logits (default 200).
    """

    def __init__(self, num_classes=200):
        super().__init__()

        # Stage-level residual blocks (sizes shown for a 256x256 input)
        self.stage1 = ClassicResidualBlock(3, 32)       # 256 -> 128
        self.stage2 = ClassicResidualBlock(32, 64)      # 128 -> 64
        self.stage3 = ClassicResidualBlock(64, 128)     # 64 -> 32
        self.stage4 = ClassicResidualBlock(128, 256)    # 32 -> 16
        self.stage5 = ClassicResidualBlock(256, 512)    # 16 -> 8

        # Classifier
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=0.4)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for ``x`` of shape (batch, 3, H, W)."""
        # Maps [0, 1] to [-1, 1]; presumably inputs are expected in [0, 1].
        x = (x - 0.5) * 2.0

        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)

        x = self.gap(x)                  # (B, 512, 1, 1)
        x = torch.flatten(x, 1)          # (B, 512)
        x = self.dropout(x)
        logits = self.fc(x)
        return logits

In [None]:
class SEBlock(nn.Module):
    """Squeeze-and-excitation gate: rescales every channel of a feature map by
    a factor in (0, 1) computed from the channel-wise spatial means.

    Args:
        channels: Number of channels of the feature map.
        reduction: Bottleneck ratio of the gating MLP. The hidden width is
            ``channels // reduction`` but never less than 1, so blocks with
            fewer channels than ``reduction`` still get a usable bottleneck.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = max(1, channels // reduction)
        self.fc1 = nn.Linear(channels, hidden)
        self.fc2 = nn.Linear(hidden, channels)
        self.activation = nn.ReLU()  # <- changed back from SiLU
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` (batch, channels, H, W) scaled per channel by the gate."""
        b, c, h, w = x.size()
        # Squeeze: global average pooling
        y = x.mean(dim=(2, 3))           # (B, C)
        # Excitation: MLP
        y = self.fc2(self.activation(self.fc1(y)))  # (B, C)
        y = self.sigmoid(y).view(b, c, 1, 1)
        # Scale: multiply original feature map
        return x * y


class ResidualBlock(nn.Module):
    """Residual stage with SiLU activations.

    Two 3x3 conv-BN layers form the main branch; the skip branch is the
    identity when the channel counts match and a 1x1 convolution otherwise.
    Squeeze-and-excitation gating and a final 2x max-pool are optional.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        use_se: Apply an ``SEBlock`` to the sum of both branches.
        use_pool: Halve the spatial size with a max-pool at the end.
    """

    def __init__(self, in_ch, out_ch, use_se=False, use_pool=False):
        super().__init__()
        self.use_se = use_se
        self.use_pool = use_pool

        # Main branch
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)

        # Skip branch
        if in_ch == out_ch:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, bias=False)

        self.act = nn.SiLU(inplace=True)  # SiLU instead of ReLU

        # Optional parts are only registered when requested.
        if use_pool:
            self.pool = nn.MaxPool2d(2)
        if use_se:
            self.se = SEBlock(out_ch)

    def forward(self, x):
        """Return the block output for ``x`` of shape (batch, in_ch, H, W)."""
        skip = self.shortcut(x)

        y = self.conv1(x)
        y = self.bn1(y)
        y = self.act(y)
        y = self.conv2(y)
        y = self.bn2(y)
        y += skip

        # The gate is applied to the summed branches, before the activation.
        if self.use_se:
            y = self.se(y)

        y = self.act(y)
        if not self.use_pool:
            return y
        return self.pool(y)


class MODERNRES(nn.Module):
    """Lightweight residual network: five SE residual stages, global average
    pooling, dropout and a linear classifier.

    Every stage requests pooling explicitly. ``ResidualBlock`` defaults to
    ``use_pool=False``, so without the explicit flag the network never
    downsamples and every stage runs at full input resolution.

    Args:
        num_classes: Number of output logits (default 200).
    """

    def __init__(self, num_classes=200):
        super().__init__()

        # Stage-level residual blocks (sizes shown for a 256x256 input)
        self.stage1 = ResidualBlock(3, 32, use_se=True, use_pool=True)      # 256 -> 128
        self.stage2 = ResidualBlock(32, 64, use_se=True, use_pool=True)     # 128 -> 64
        self.stage3 = ResidualBlock(64, 96, use_se=True, use_pool=True)     # 64 -> 32
        self.stage4 = ResidualBlock(96, 128, use_se=True, use_pool=True)    # 32 -> 16
        self.stage5 = ResidualBlock(128, 160, use_se=True, use_pool=True)   # 16 -> 8

        # Classifier
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=0.4)
        self.fc = nn.Linear(160, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for ``x`` of shape (batch, 3, H, W)."""
        # Maps [0, 1] to [-1, 1]; presumably inputs are expected in [0, 1].
        x = (x - 0.5) * 2.0

        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)

        x = self.gap(x)              # (B, 160, 1, 1)
        x = torch.flatten(x, 1)      # (B, 160)
        x = self.dropout(x)
        logits = self.fc(x)
        return logits

In [None]:
class MODERNRESDEEP(nn.Module):
    """Deeper residual network: nine SE residual stages, global average
    pooling, dropout and a linear classifier.

    Pooling alternates: odd stages halve the spatial size, even stages keep
    it. The odd stages request pooling explicitly because ``ResidualBlock``
    defaults to ``use_pool=False``; relying on that default left the whole
    network without any downsampling.

    Args:
        num_classes: Number of output logits (default 200).
    """

    def __init__(self, num_classes=200):
        super().__init__()

        # Stage-level residual blocks (sizes shown for a 256x256 input)
        self.stage1 = ResidualBlock(3, 32, use_se=True, use_pool=True)       # 256 -> 128
        self.stage2 = ResidualBlock(32, 64, use_se=True, use_pool=False)
        self.stage3 = ResidualBlock(64, 96, use_se=True, use_pool=True)      # 128 -> 64
        self.stage4 = ResidualBlock(96, 128, use_se=True, use_pool=False)
        self.stage5 = ResidualBlock(128, 160, use_se=True, use_pool=True)    # 64 -> 32
        self.stage6 = ResidualBlock(160, 192, use_se=True, use_pool=False)
        self.stage7 = ResidualBlock(192, 224, use_se=True, use_pool=True)    # 32 -> 16
        self.stage8 = ResidualBlock(224, 256, use_se=True, use_pool=False)
        self.stage9 = ResidualBlock(256, 288, use_se=True, use_pool=True)    # 16 -> 8

        # Classifier
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=0.4)
        self.fc = nn.Linear(288, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for ``x`` of shape (batch, 3, H, W)."""
        # Maps [0, 1] to [-1, 1]; presumably inputs are expected in [0, 1].
        x = (x - 0.5) * 2.0

        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.stage7(x)
        x = self.stage8(x)
        x = self.stage9(x)

        x = self.gap(x)              # (B, 288, 1, 1)
        x = torch.flatten(x, 1)      # (B, 288)
        x = self.dropout(x)
        logits = self.fc(x)
        return logits

Test model activations go here

In [None]:
# Instantiate the model selected by `model_name`.
# The registry maps every accepted name to its class; only the selected class
# is instantiated. Benchmark notes from earlier runs are kept next to each entry.
model_registry = {
  "M3MAX": ModelM3MAX,
  #gets 12.99 percent with: balanced augdata, lr 0.001, moment 0.9, wd 0.001, batchsize 32, epochs 15,
  "SIMPLE1": SIMPLE1,
  # "SIMP1" is the name this cell originally checked for; the config cell
  # documents "SIMPLE1". Both are accepted so neither spelling fails silently.
  "SIMP1": SIMPLE1,
  #gets 12.7 percent with: balanced augdata, lr 0.001, moment 0.9, wd 0.001, batchsize 32, epochs 15 (but fewer params, not much though)
  "CLASSIC1": CLASSIC1,
  #gets 15.6 percent with: balanced augdata, lr 0.001, moment 0.9, wd 0.001, batchsize 32, epochs 15 (but fewer params, not much though)
  #but has quite some more potential with more epochs/higher learning rate/higher batch size
  "CLASSICRES": CLASSICRES,
  #gets 23 percent with: balanced augdata, lr 0.001, moment 0.9, wd 0.001, batchsize 64, epochs 15 (but fewer params, not much though)
  "MODERNRES": MODERNRES,
  #gets 23-25 percent on 30 epochs, but is much lighter than classic, trains about 2x as fast
  "MODERNRESDEEP": MODERNRESDEEP,
}

# Fail loudly on a typo instead of leaving `custom_model` undefined (or stale
# from an earlier run of this cell).
if model_name not in model_registry:
  raise ValueError(
      f"Unknown model_name {model_name!r}; choose one of {sorted(model_registry)}"
  )
custom_model = model_registry[model_name]()


Class and function definitions go here

In [None]:
# Remove randomness for benchmarking
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) with ``seed``
    and put cuDNN into deterministic, non-benchmarking mode."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
# Seed every RNG once up front when reproducible runs are requested.
if use_seed:
    set_seed(seed)

In [None]:
# Dataset class
class CSVDataset(Dataset):
    """Image dataset described by a CSV file.

    The CSV must provide an ``image_path`` column (path relative to
    ``base_dir``; a leading "/" is ignored) and a 1-based ``label`` column.
    An ``id`` column is read only when ``return_id`` is True.

    Args:
        csv_file: Path of the CSV file.
        base_dir: Directory the image paths are relative to.
        transform: Optional callable applied to every loaded PIL image.
        return_id: When True, items are ``(image, label, id)`` instead of
            ``(image, label)``.
        augmentation_tags: Optional tag or iterable of tags. When given, only
            rows whose path contains ``_<tag>.jpg`` for one of the tags are
            kept; the tag ``original`` is always included. ``None`` keeps
            every row.
    """

    def __init__(self,
                 csv_file,
                 base_dir,
                 transform=None,
                 return_id=False,
                 augmentation_tags=None):
        self.df = pd.read_csv(csv_file)

        # Apply augmentation filtering if tags are provided
        if augmentation_tags is not None:
            # A bare string is one tag, not an iterable of characters.
            if isinstance(augmentation_tags, str):
                augmentation_tags = [augmentation_tags]
            # Any iterable works here (list concatenation used to reject
            # tuples and sets); 'original' is always included.
            tags_to_include = set(augmentation_tags) | {'original'}

            mask = pd.Series(False, index=self.df.index)
            for tag in tags_to_include:
                # Plain substring match on the file name suffix
                mask |= self.df['image_path'].str.contains(f'_{tag}.jpg', regex=False)
            self.df = self.df[mask].copy()

        self.base_dir = base_dir
        self.transform = transform
        self.return_id = return_id

    def __len__(self):
        """Number of rows left after optional tag filtering."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the ``idx``-th image (positional index) and its 0-based label."""
        row = self.df.iloc[idx]

        # extract fields
        img_id = row['id'] if self.return_id else None
        relative_path = row['image_path'].lstrip('/')  # keep os.path.join from discarding base_dir
        label = row['label'] - 1   # shift to 0-based indexing

        # build full path
        img_path = os.path.join(self.base_dir, relative_path)

        # Load inside a context manager so the file handle is closed right
        # away; convert() returns an independent, fully loaded copy.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # transform
        if self.transform is not None:
            image = self.transform(image)

        # optionally return id
        if self.return_id:
            return image, label, img_id

        return image, label

In [None]:
def train_model(model,
                train_loader,
                val_loader,
                criterion,
                optimizer,
                schedular=None,
                num_epochs=10,
                early_stopping=False,
                epochs_no_improve=0,
                patience=5,
                min_delta=0.0,
                device="cuda"):

    dataloaders_dict = {"train": train_loader, "val": val_loader}
    since = time.time()
    model.to(device)
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    if use_scaler:
      scaler = GradScaler()

    # Initialize history dictionary
    hist = {
        "train_loss": [],
        "val_loss": [],
        "train_acc": [],
        "val_acc": []
    }

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in tqdm(dataloaders_dict[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                if use_scaler:
                  with torch.set_grad_enabled(phase == 'train'):
                    with autocast("cuda"):  # Mixed precision context
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                        _, preds = torch.max(outputs, 1)

                    if phase == 'train':
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()
                else:
                  with torch.set_grad_enabled(phase == 'train'):
                      outputs = model(inputs)
                      loss = criterion(outputs, labels)

                      _, preds = torch.max(outputs, 1)

                      if phase == 'train':
                          loss.backward()
                          optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders_dict[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders_dict[phase].dataset)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Save metrics in history
            if phase == 'train':
                hist['train_loss'].append(epoch_loss)
                hist['train_acc'].append(epoch_acc.item())
            else:
                hist['val_loss'].append(epoch_loss)
                hist['val_acc'].append(epoch_acc.item())

                # Early stopping logic
                if epoch_acc > best_acc + min_delta:
                    print(f"Validation improved ({best_acc:.4f} → {epoch_acc:.4f})")
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1
                    print(f"No improvement for {epochs_no_improve} epoch(s).")

        if epochs_no_improve >= patience and early_stopping:
            print(f"Early stopping triggered at epoch {epoch+1}!")
            break

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # load best model weights
    model.load_state_dict(best_model_wts)

    return model, hist

The most basic transformations: convert each image to a tensor, resize it to the configured size, and standardize the RGB values.

In [None]:
# Standard preprocessing: PIL image -> tensor -> resize to `size` -> per-channel
# normalization with mean 0.5 and std 0.5.
transformations = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((size)),
    transforms.Normalize(mean = (0.5,0.5,0.5), std = (0.5,0.5,0.5))
    ])
# NOTE(review): every custom model in this notebook also rescales its input
# with (x - 0.5) * 2.0, so together with Normalize above the images are
# rescaled twice. Behaviour is left unchanged here because existing checkpoints
# were trained this way - confirm which of the two steps is intended.
## Probably better to follow the original resnet transformations
#See: (model.ResNet152_Weights.IMAGENET1K_V1.transforms)

# `model_transforms` is what the dataset cells consume. It used to be defined
# only when use_model_transforms was False, so enabling the flag ended in a
# NameError. Now the standard transformations are also the fallback when no
# model-specific transforms have been defined.
if not use_model_transforms:
  model_transforms = transformations
elif "model_transforms" not in globals():
  print("use_model_transforms is True but no model-specific transforms are "
        "defined; falling back to the standard transformations.")
  model_transforms = transformations


Define datasets based on augmented or not

In [None]:
# Non-augmented run: load the original training CSV and split it into a
# training and a validation subset with a seeded generator.
# `split` (training fraction) must have been defined by an earlier cell.
if not use_augmented:
  full_dataset = CSVDataset(str(dirpath / "train_images.csv"),
                            str(dirpath),
                            transform=model_transforms,
                            return_id=False)
  n_total = len(full_dataset)
  train_size = int(split * n_total)
  val_size = n_total - train_size  # the remainder goes to validation
  split_generator = torch.Generator().manual_seed(seed)
  train_dataset, val_dataset = random_split(full_dataset,
                                            [train_size, val_size],
                                            generator=split_generator)
  # Loader over the complete, unsplit dataset
  loader = DataLoader(full_dataset, shuffle=True, batch_size=train_batch_size)

In [None]:
# Augmented run: the training set comes from the balanced or the full
# augmented CSV (paths relative to the working directory); validation always
# uses the validation split, with paths relative to the dataset folder.
if use_augmented:
  train_csv_name = "train_balanced.csv" if use_balanced else "train_augmented.csv"
  train_dataset = CSVDataset(csv_file=str(cwd / train_csv_name),
                             base_dir=str(cwd),
                             transform=model_transforms,
                             return_id=False)
  val_dataset = CSVDataset(csv_file=str(val_images_csv),
                           base_dir=str(dirpath),
                           transform=model_transforms,
                           return_id=False)

In [None]:
# Data loaders
# Worker settings shared by both loaders.
loader_options = dict(num_workers=2,
                      pin_memory=True,
                      prefetch_factor=2,
                      persistent_workers=True)

# Training batches are reshuffled every epoch; validation order stays fixed.
train_loader = DataLoader(train_dataset,
                          batch_size=train_batch_size,
                          shuffle=True,
                          **loader_options)
val_loader = DataLoader(val_dataset,
                        batch_size=val_batch_size,
                        shuffle=False,
                        **loader_options)

In [None]:
# Pick the first GPU when one is available, otherwise fall back to the CPU.
torch.cuda.empty_cache()
if torch.cuda.is_available():
  device = torch.device("cuda:0")
else:
  device = torch.device("cpu")
device

device(type='cuda', index=0)

In [None]:
# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer over all parameters of the selected model.
# NOTE(review): AdamW is built with its default weight decay; the configured
# `wd` and `moment` values are only referenced by the SGD variant kept below
# for reference - confirm whether `wd` should be passed to AdamW.
# optimizer = optim.SGD(params_to_update, lr=learning_rate, momentum=moment,
#                       weight_decay=wd
#                       )
params_to_update = custom_model.parameters()
optimizer = optim.AdamW(params_to_update, lr=learning_rate)

In [None]:
# Train one fresh model per seed and store its weights and training history.

# Accepted model names. "SIMP1" is the spelling the notebook originally
# checked for, "SIMPLE1" the one documented in the config cell; both work.
model_classes = {
  "M3MAX": ModelM3MAX,
  "SIMPLE1": SIMPLE1,
  "SIMP1": SIMPLE1,
  "CLASSIC1": CLASSIC1,
  "CLASSICRES": CLASSICRES,
  "MODERNRES": MODERNRES,
  "MODERNRESDEEP": MODERNRESDEEP,
}
# Fail before training starts instead of silently reusing a stale model.
if model_name not in model_classes:
  raise ValueError(
      f"Unknown model_name {model_name!r}; choose one of {sorted(model_classes)}"
  )

# Create the output folder up front: a missing folder would otherwise only
# surface at torch.save, after the whole training run has finished.
save_dir = "/content/drive/MyDrive/Test"
os.makedirs(save_dir, exist_ok=True)

for seed in SEEDS:
  set_seed(seed)
  # Fresh, identically seeded model for this run
  custom_model = model_classes[model_name]()

  # Initialize optimizer and criterion
  # NOTE(review): AdamW uses its default weight decay here; the configured
  # `wd` and `moment` values are not applied - confirm that this is intended.
  params_to_update = custom_model.parameters()
  optimizer = optim.AdamW(params_to_update, lr=learning_rate)
  criterion = nn.CrossEntropyLoss()

  # Train model
  model_trained, hist = train_model(custom_model,
                              train_loader,
                              val_loader,
                              criterion,
                              optimizer,
                              early_stopping = early_stopping,
                              patience = patience,
                              min_delta = min_delta,
                              schedular=None,
                              num_epochs=num_epochs,
                              device=device)
  torch.save(model_trained.state_dict(), f"{save_dir}/{model_name}_{seed}.pth")
  with open(f"{save_dir}/{model_name}_{seed}_acc.pkl", "wb") as f:
        pickle.dump(hist, f)

  print(f"Seed {seed}: model and history saved.")


Epoch 1/15
----------


100%|██████████| 1250/1250 [07:50<00:00,  2.66it/s]


train Loss: 5.0758 Acc: 0.0133


100%|██████████| 25/25 [00:06<00:00,  3.96it/s]


val Loss: 4.8974 Acc: 0.0178
Validation improved (0.0000 → 0.0178)

Epoch 2/15
----------


100%|██████████| 1250/1250 [07:50<00:00,  2.65it/s]


train Loss: 4.7664 Acc: 0.0296


100%|██████████| 25/25 [00:05<00:00,  4.39it/s]


val Loss: 4.7888 Acc: 0.0216
Validation improved (0.0178 → 0.0216)

Epoch 3/15
----------


100%|██████████| 1250/1250 [07:51<00:00,  2.65it/s]


train Loss: 4.5184 Acc: 0.0456


100%|██████████| 25/25 [00:05<00:00,  4.38it/s]


val Loss: 4.6096 Acc: 0.0318
Validation improved (0.0216 → 0.0318)

Epoch 4/15
----------


100%|██████████| 1250/1250 [07:50<00:00,  2.65it/s]


train Loss: 4.3049 Acc: 0.0658


100%|██████████| 25/25 [00:05<00:00,  4.38it/s]


val Loss: 4.5974 Acc: 0.0407
Validation improved (0.0318 → 0.0407)

Epoch 5/15
----------


100%|██████████| 1250/1250 [07:51<00:00,  2.65it/s]


train Loss: 4.0620 Acc: 0.0966


100%|██████████| 25/25 [00:05<00:00,  4.40it/s]


val Loss: 4.5574 Acc: 0.0522
Validation improved (0.0407 → 0.0522)

Epoch 6/15
----------


100%|██████████| 1250/1250 [07:50<00:00,  2.66it/s]


train Loss: 3.8337 Acc: 0.1269


100%|██████████| 25/25 [00:05<00:00,  4.38it/s]


val Loss: 4.6579 Acc: 0.0573
Validation improved (0.0522 → 0.0573)

Epoch 7/15
----------


100%|██████████| 1250/1250 [07:50<00:00,  2.66it/s]


train Loss: 3.5880 Acc: 0.1655


100%|██████████| 25/25 [00:05<00:00,  4.39it/s]


val Loss: 4.6105 Acc: 0.0712
Validation improved (0.0573 → 0.0712)

Epoch 8/15
----------


100%|██████████| 1250/1250 [07:50<00:00,  2.66it/s]


train Loss: 3.3463 Acc: 0.2075


100%|██████████| 25/25 [00:05<00:00,  4.39it/s]


val Loss: 4.6648 Acc: 0.0776
Validation improved (0.0712 → 0.0776)

Epoch 9/15
----------


100%|██████████| 1250/1250 [07:50<00:00,  2.66it/s]


train Loss: 3.0962 Acc: 0.2491


100%|██████████| 25/25 [00:05<00:00,  4.38it/s]


val Loss: 4.7536 Acc: 0.0827
Validation improved (0.0776 → 0.0827)

Epoch 10/15
----------


100%|██████████| 1250/1250 [07:50<00:00,  2.66it/s]


train Loss: 2.8979 Acc: 0.2901


100%|██████████| 25/25 [00:05<00:00,  4.37it/s]


val Loss: 4.8099 Acc: 0.1018
Validation improved (0.0827 → 0.1018)

Epoch 11/15
----------


100%|██████████| 1250/1250 [07:50<00:00,  2.66it/s]


train Loss: 2.6835 Acc: 0.3376


100%|██████████| 25/25 [00:05<00:00,  4.39it/s]


val Loss: 4.8947 Acc: 0.0865
No improvement for 1 epoch(s).

Epoch 12/15
----------


100%|██████████| 1250/1250 [07:50<00:00,  2.66it/s]


train Loss: 2.5058 Acc: 0.3752


100%|██████████| 25/25 [00:05<00:00,  4.38it/s]


val Loss: 5.0015 Acc: 0.0941
No improvement for 2 epoch(s).

Epoch 13/15
----------


100%|██████████| 1250/1250 [07:50<00:00,  2.66it/s]


train Loss: 2.3455 Acc: 0.4085


100%|██████████| 25/25 [00:05<00:00,  4.39it/s]


val Loss: 5.0197 Acc: 0.0954
No improvement for 3 epoch(s).

Epoch 14/15
----------


100%|██████████| 1250/1250 [07:50<00:00,  2.66it/s]


train Loss: 2.1920 Acc: 0.4428


100%|██████████| 25/25 [00:05<00:00,  4.39it/s]


val Loss: 5.1899 Acc: 0.0980
No improvement for 4 epoch(s).

Epoch 15/15
----------


100%|██████████| 1250/1250 [07:50<00:00,  2.66it/s]


train Loss: 2.0576 Acc: 0.4733


100%|██████████| 25/25 [00:05<00:00,  4.39it/s]


val Loss: 5.2323 Acc: 0.1005
No improvement for 5 epoch(s).
Early stopping triggered at epoch 15!
Training complete in 119m 7s
Best val Acc: 0.1018
Seed 42: model and history saved.


In [None]:
# Compare the forward-pass FLOPs of three models on one 256x256 RGB input and
# plot the totals as a bar chart.
# NOTE(review): TEMP, MODERNRESS, EnsembleModel and FlopCountAnalysis are not
# defined or imported anywhere in this notebook, so this cell fails on a fresh
# kernel. Presumably they come from another notebook or package (e.g.
# MODERNRESS may be a variant or typo of MODERNRES) - confirm and add the
# missing definitions/imports.

# Example input
batchsize = 1
example_input = torch.randn(batchsize, 3, 256, 256)

# Init model
curr_model = TEMP()
curr_model.eval()

# compute flops
flops = FlopCountAnalysis(curr_model, example_input)
fw_flops = flops.total()

# Collect one record per model; turned into a dataframe below
data = []
data.append({
    'model_name': 'baseline',
    'fw_flops': fw_flops,
})

# init model
curr_model = MODERNRESS()
curr_model.eval()

# compute flops
flops = FlopCountAnalysis(curr_model, example_input)
fw_flops = flops.total()

data.append({
    'model_name': 'base model',
    'fw_flops': fw_flops,
})


# init model
currmodel = MODERNRESS()
# The list holds the SAME model instance seven times, not seven separate models
modelList = [currmodel] * 7
curr_model = EnsembleModel(modelList)
curr_model.eval()

# freeze parameters of individual models
for param in curr_model.parameters():
    param.requires_grad = False

# unfreeze parameters of the classifier
for param in curr_model.classifier.parameters():
    param.requires_grad = True
curr_model.eval()

# compute flops
flops = FlopCountAnalysis(curr_model, example_input)
fw_flops = flops.total()

data.append({
    'model_name': 'ensembled',
    'fw_flops': fw_flops,
})

df = pd.DataFrame(data)

print(df)

# Bar chart of total forward FLOPs per model, in billions
fig, ax = plt.subplots(figsize=(8,6))
ax.bar(df['model_name'], df['fw_flops']/1e9)
ax.set_ylabel('FLOPs (billions)')
ax.set_title('Total forward FLOPs for models')
plt.show()
