# Model

> Model definitions and helpers: backbone construction, checkpoint loading, and task-specific parameter freezing (eval / probing / fine-tuning).

In [None]:
#| default_exp model

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import os

import torch
import torch.nn as nn
from torchvision import transforms, models



In [None]:
#| export
def load_ckpt(model, ckpt_path):
    """Load a saved state dict into `model` when `ckpt_path` names an existing file.

    Parameters
    ----------
    model : torch.nn.Module
        Module updated in place via `model.load_state_dict` (strict key
        matching, PyTorch's default).
    ckpt_path : str, os.PathLike or None
        File written with `torch.save(module.state_dict(), ...)`. `None` or a
        path that does not exist leaves `model` untouched and only prints a
        message (the message says "random" weights even if `model` already
        holds pretrained weights).

    Returns
    -------
    torch.nn.Module
        The same `model` object, so the call can be chained/assigned.
    """
    # `os.path.exists(None)` raises TypeError instead of returning False, and
    # None is the default checkpoint of the model builders, so test it first.
    if ckpt_path is None or not os.path.exists(ckpt_path):
        print(f"No checkpoint found at {ckpt_path}. Using random initialized weights.")
        return model

    # Tensors are materialised on CPU; load_state_dict then copies them into the
    # model's existing parameters on whatever device those already live on.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(state_dict)
    print(f"Loaded pretrained weights from {ckpt_path}")
    return model

In [None]:
#| export
import torch
import torch.nn as nn
from torchvision import models
class CustomModel(nn.Module):
    """Torchvision classifier whose final linear layer is resized to `num_classes`.

    Parameters
    ----------
    backbone : str
        "resnet18" or "efficientnet" (EfficientNet-B0). Anything else raises
        ValueError.
    ckpt : str or None
        Optional path to a state dict for the wrapped torchvision model
        (`self.model`, i.e. keys without a "model." prefix), applied after the
        head is resized. `None` skips checkpoint loading.
    num_classes : int
        Output size of the replaced classification layer.
    pretrained : bool
        Start from torchvision's ImageNet weights (IMAGENET1K_V1, the same
        weights the deprecated `pretrained=True` flag selected).
    """

    def __init__(self, backbone= "resnet18", ckpt= None, num_classes= 3, pretrained=True):
        super().__init__()

        if backbone == "resnet18":
            weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            self.model = models.resnet18(weights=weights)
            self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        elif backbone == "efficientnet":
            weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
            self.model = models.efficientnet_b0(weights=weights)
            # classifier is (Dropout, Linear); only the Linear is replaced.
            self.model.classifier[1] = nn.Linear(self.model.classifier[1].in_features, num_classes)
        else:
            raise ValueError(f"Unsupported backbone: {backbone!r}")

        # Only touch the weights when a checkpoint path was actually given.
        if ckpt is not None:
            self.model = load_ckpt(self.model, ckpt)

    def forward(self, x):
        """Return logits for a batch of images.

        A 5-D input has its first two dims merged, e.g. (B, N, C, H, W) ->
        (B*N, C, H, W), so the result has one row per image in B*N and is NOT
        aggregated back per sample. Presumably N indexes views/crops of one
        sample -- TODO confirm with the data pipeline.
        """
        if x.ndim == 5:
            x = x.flatten(0, 1)
        return self.model(x)

In [None]:
#| export
def build_model(backbone="resnet18", ckpt=None, num_classes=3, pretrained=True):
    """Construct and return a `CustomModel` from plain arguments.

    All arguments are forwarded unchanged to `CustomModel`: the backbone name,
    an optional checkpoint path, the number of output classes and whether to
    start from torchvision's pretrained weights.
    """
    return CustomModel(
        backbone=backbone,
        ckpt=ckpt,
        num_classes=num_classes,
        pretrained=pretrained,
    )

In [None]:
#| export
def init_model(cfg):
    """Build the model described by `cfg` and choose which parameters are trainable.

    Reads `cfg.model.backbone`, `cfg.model.ckpt`, `cfg.model.pretrained`,
    `cfg.data.label_names` (its length sets the number of output classes) and
    `cfg.task`:

    - "eval":        every parameter frozen.
    - "probing":     only the classification head is trainable
                     (`fc` for resnet18, `classifier[1]` for efficientnet).
    - "fine-tuning": every parameter trainable.

    Returns
    -------
    CustomModel
        The configured model.

    Raises
    ------
    ValueError
        If `cfg.task` is none of the above. This is checked before the model is
        built, so a bad task fails fast instead of after weights are
        downloaded/loaded.
    """
    task = cfg.task
    if task not in ("eval", "probing", "fine-tuning"):
        raise ValueError(
            f"Unsupported task {task!r}. Choose either 'eval', 'probing', or 'fine-tuning'."
        )

    model = build_model(
        backbone=cfg.model.backbone,
        ckpt=cfg.model.ckpt,
        num_classes=len(cfg.data.label_names),
        pretrained=cfg.model.pretrained,
    )
    # Start fully frozen; each task then re-enables only what it trains.
    model.requires_grad_(False)

    if task == "probing":
        if cfg.model.backbone == "resnet18":
            model.model.fc.requires_grad_(True)
        elif cfg.model.backbone == "efficientnet":
            model.model.classifier[1].requires_grad_(True)
    elif task == "fine-tuning":
        model.requires_grad_(True)

    return model

In [None]:
#| hide
from omegaconf import OmegaConf
# Smoke test: reuse the ResNet *eval* config, point it at a local checkpoint
# and switch the task to linear probing, so only the classification head
# (`fc` for a resnet18 backbone) stays trainable.
cfg = OmegaConf.load("../cfgs/task_1/resnet/eval.yaml")
cfg.model.ckpt = "../pretrained_backbone/ckpt_resnet18_ep50.pt"
cfg.task = "probing"
model = init_model(cfg)


  state_dict = torch.load(ckpt_path, map_location='cpu')


Loaded pretrained weights from ../pretrained_backbone/ckpt_resnet18_ep50.pt


In [None]:
#| hide
# Display the module tree of the model built in the previous cell.
model

CustomModel(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_r

In [None]:
# Inspect the `model` section of the currently loaded config.
cfg.model

{'backbone': 'efficientnet', 'pretrained': True, 'ckpt': '../pretrained_backbone/ckpt_efficientnet_ep50.pt'}

In [None]:
#| hide
from omegaconf import OmegaConf
# Same smoke test for EfficientNet: load its probing config, point it at a
# local checkpoint and build the model. Assuming the YAML sets
# `backbone: efficientnet`, only `classifier[1]` stays trainable.
cfg = OmegaConf.load("../cfgs/task_1/efficientnet/probing.yaml")
cfg.model.ckpt = "../pretrained_backbone/ckpt_efficientnet_ep50.pt"
# The backbone is read from the YAML; set cfg.model.backbone here to override it.
cfg.task = "probing"
model = init_model(cfg)


Loaded pretrained weights from ../pretrained_backbone/ckpt_efficientnet_ep50.pt


  state_dict = torch.load(ckpt_path, map_location='cpu')


In [None]:
#| hide
# Sanity check: an ImageNet-pretrained ResNet-18 on one 256x256 input.
# `weights=` replaces the deprecated `pretrained=True` (same IMAGENET1K_V1 weights).
backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
backbone(torch.randn(1,3,256,256)).shape

Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /home/ahmed/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth
100%|██████████| 20.5M/20.5M [00:01<00:00, 12.8MB/s]


torch.Size([1, 1000])

In [None]:
#| export
import torch
import torch.nn as nn
from torchvision import models

class ResNet18WithAttention(nn.Module):
    """ResNet-18 conv trunk followed by self-attention over spatial positions.

    The trunk is ResNet-18 without its last two children (global pool and fc),
    so it yields a (B, 512, H', W') feature map. Each of the H'*W' positions
    becomes a 512-d token. The tokens go through one multi-head self-attention
    layer with a residual connection and LayerNorm, are averaged, and are then
    classified by a linear layer. A 256x256 input gives an 8x8 map, i.e. 64
    tokens.

    The trunk is built with `weights=None`, so it starts from random weights.
    `num_heads` must divide 512 (enforced by `nn.MultiheadAttention`).
    """

    def __init__(self, num_classes=1000, num_heads=8):
        super().__init__()
        trunk_layers = list(models.resnet18(weights=None).children())[:-2]
        self.backbone = nn.Sequential(*trunk_layers)

        self.embed_dim = 512  # channel count of ResNet-18's last stage

        self.mha = nn.MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.ln = nn.LayerNorm(self.embed_dim)
        self.fc = nn.Linear(self.embed_dim, num_classes)

    def forward(self, x):
        """Map an image batch (assumed (B, 3, H, W)) to (B, num_classes) logits."""
        fmap = self.backbone(x)  # (B, 512, H', W')
        # (B, 512, H', W') -> (B, 512, H'*W') -> (B, H'*W', 512): positions become the sequence.
        tokens = fmap.flatten(2).transpose(1, 2)

        attended, _ = self.mha(tokens, tokens, tokens)
        normed = self.ln(attended + tokens)

        # Average over the token axis, then classify.
        return self.fc(normed.mean(dim=1))



In [None]:
#| hide
# Smoke test: a randomly initialised ResNet18WithAttention maps one
# 256x256 image to a (1, 3) logit tensor.
model = ResNet18WithAttention(num_classes=3)
dummy_input = torch.randn(1, 3, 256, 256)
output = model(dummy_input)
print(f"Output shape: {output.shape}")

Output shape: torch.Size([1, 3])


In [None]:
#| hide
# Sanity check: an ImageNet-pretrained EfficientNet-B0 on one 256x256 input.
# `weights=` replaces the deprecated `pretrained=True` (same IMAGENET1K_V1 weights).
backbone = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
backbone(torch.randn(1,3,256,256)).shape

In [None]:
#| hide
# Show the EfficientNet-B0 module tree; its `features` stage is the part
# reused as the backbone of EfficientNetB0WithAttention.
backbone

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActivat

In [None]:
#| export
import torch
import torch.nn as nn
from torchvision import models

class EfficientNetB0WithAttention(nn.Module):
    """EfficientNet-B0 feature extractor with self-attention over spatial positions.

    Only `efficientnet_b0(...).features` is kept; EfficientNet's own pooling and
    classifier are discarded. Its (B, 1280, H', W') output becomes H'*W' tokens
    of width 1280. These pass through one multi-head self-attention layer with a
    residual connection and LayerNorm, are averaged over tokens, and are then
    classified by a linear layer. For 256x256 inputs H' = W' = 8 (64 tokens).

    The feature extractor is built with `weights=None` (random init).
    `num_heads` must divide 1280 (e.g. 8 heads -> 160 dims per head).
    """

    def __init__(self, num_classes=1000, num_heads=8):
        super().__init__()
        # Keep just the convolutional stages.
        self.backbone = models.efficientnet_b0(weights=None).features

        self.embed_dim = 1280  # channels produced by the final `features` stage

        self.mha = nn.MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.ln = nn.LayerNorm(self.embed_dim)
        self.fc = nn.Linear(self.embed_dim, num_classes)

    def forward(self, x):
        """Map an image batch (assumed (B, 3, H, W)) to (B, num_classes) logits."""
        fmap = self.backbone(x)  # (B, 1280, H', W')
        # Spatial positions become the attention sequence: (B, H'*W', 1280).
        tokens = fmap.flatten(2).transpose(1, 2)

        attended, _ = self.mha(tokens, tokens, tokens)
        normed = self.ln(attended + tokens)  # residual + LayerNorm

        # Global average over the tokens, then the linear classifier.
        return self.fc(normed.mean(dim=1))



In [None]:
#| hide
# Smoke test on the GPU when one is available: one random 256x256 image
# through a randomly initialised EfficientNetB0WithAttention -> (1, 3) logits.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = EfficientNetB0WithAttention(num_classes=3).to(device)
dummy_input = torch.randn(1, 3, 256, 256).to(device)
output = model(dummy_input)

print(f"Output shape: {output.shape}")

Output shape: torch.Size([1, 3])


In [None]:
#| hide
# Regenerate the library module from every `#| export` cell (target set by `#| default_exp model`).
import nbdev; nbdev.nbdev_export()