In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
import math
from typing import Optional, Tuple
import glob

# ARGS

In [2]:

class CFG:
    epochs = 50
    train_batch_size = 10
    val_batch_size = 10
    test_batch_size = 10
    
    train_start = 0
    train_end = 4500
    
    val_start = 4500
    val_end =  5000
    
    test_start = 5000
    test_end = 5109
    
    lr = 2e-5
    # a higer learning rate was used for CNN compared to timm models as these models do not use any pretrained weights or any form of transfer learning
    # but for ViT even smaller learning rates did not yeild good results
    patience = 5
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DATASET

In [3]:

class PathologyDataset(Dataset):
    def __init__(self, patches, labels, start_idx, end_idx):
        self.patches = patches
        self.labels = labels
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.length = end_idx - start_idx
    
    def __len__(self):
        return self.length
    
    def __getitem__(self, idx):
        actual_idx = self.start_idx + idx
        
        patch = self.patches[actual_idx].astype(np.float32) 
        label = self.labels[actual_idx]
        
        patch = torch.tensor(patch) 
        label = torch.tensor(label, dtype=torch.long) 
        
        return patch, label


In [4]:
inputs = np.load("/kaggle/input/miccaireg/images.npy", mmap_mode="r")
labels = np.load("/kaggle/input/miccaireg/labels.npy")

print(f"Inputs shape: {inputs.shape}")
print(f"Labels shape: {labels.shape}")
print(f"Labels range: {labels.min()} to {labels.max()}")
print(f"Unique labels: {np.unique(labels)}")


Inputs shape: (5109, 3, 256, 256)
Labels shape: (5109,)
Labels range: 0 to 7
Unique labels: [0 1 2 3 4 5 6 7]


In [5]:
train_dataset = PathologyDataset(inputs, labels, CFG.train_start, CFG.train_end)
val_dataset = PathologyDataset(inputs, labels, CFG.val_start, CFG.val_end)
test_dataset = PathologyDataset(inputs, labels, CFG.test_start, CFG.test_end)

train_loader = DataLoader(train_dataset, batch_size=CFG.train_batch_size, shuffle=True, pin_memory=True,num_workers=4)
val_loader   = DataLoader(val_dataset, batch_size=CFG.val_batch_size, shuffle=False, pin_memory=True,num_workers=4)
test_loader  = DataLoader(test_dataset, batch_size=CFG.test_batch_size, shuffle=False, pin_memory=True,num_workers=2)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

sample_batch, sample_labels = next(iter(train_loader))
print(f"Sample batch shape: {sample_batch.shape}")
print(f"Sample labels shape: {sample_labels.shape}")
print(f"Sample labels: {sample_labels}")


Train batches: 450
Val batches: 50
Test batches: 11
Sample batch shape: torch.Size([10, 3, 256, 256])
Sample labels shape: torch.Size([10])
Sample labels: tensor([5, 5, 0, 6, 0, 6, 3, 6, 5, 6])


# MODEL

# ViT MODEL

In [6]:
class PatchEmbedding(nn.Module):
    """
    Convert image patches to embeddings
    """
    def __init__(self, img_size: int = 256, patch_size: int = 16, in_channels: int = 3, embed_dim: int = 768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        
        # convolution to create patches and embed them
        self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, channels, height, width)
        x = self.projection(x)  # (batch_size, embed_dim, n_patches_sqrt, n_patches_sqrt)
        x = x.flatten(2)        # (batch_size, embed_dim, n_patches)
        x = x.transpose(1, 2)   # (batch_size, n_patches, embed_dim)
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim: int = 768, num_heads: int = 12, dropout: float = 0.1):
        super().__init__()
        assert embed_dim % num_heads == 0
        
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, embed_dim = x.shape
        
        qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch_size, num_heads, seq_len, head_dim)
        q, k, v = qkv.unbind(0)
        
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        
        x = (attn @ v).transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        x = self.proj(x)
        x = self.dropout(x)
        
        return x

class MLP(nn.Module):
    def __init__(self, embed_dim: int = 768, mlp_ratio: float = 4.0, dropout: float = 0.1):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim: int = 768, num_heads: int = 12, mlp_ratio: float = 4.0, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio, dropout)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class PathologyViT(nn.Module):
    """
        img_size: Input image size (default: 256)
        patch_size: Patch size (default: 16)
        in_channels: Number of input channels (default: 3)
        num_classes: Number of classification classes (default: 8)
        embed_dim: Embedding dimension (default: 768)
        depth: Number of transformer layers (default: 12)
        num_heads: Number of attention heads (default: 12)
        mlp_ratio: MLP hidden dim ratio (default: 4.0)
        dropout: Dropout rate (default: 0.1)
    """
    
    def __init__(
        self,
        img_size: int = 256,
        patch_size: int = 16,
        in_channels: int = 3,
        num_classes: int = 8,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1
    ):
        super().__init__()
        
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.img_size = img_size
        self.patch_size = patch_size
        
        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches
        
        # CLS token and position embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        
        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        
        self._init_weights()
    
    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
    
    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the transformer, returns CLS token features
        This is useful for feature extraction in Stage 2
        """
        batch_size = x.shape[0]
        
        x = self.patch_embed(x)  # (batch_size, n_patches, embed_dim)
        
        # Add CLS token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (batch_size, n_patches + 1, embed_dim)
        
        # Add position embeddings
        x = x + self.pos_embed
        x = self.dropout(x)
        
        # Pass through transformer blocks
        for block in self.blocks:
            x = block(x)
        
        # Final layer norm
        x = self.norm(x)
        return x[:, 0]  # Return CLS token features: (batch_size, embed_dim)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for classification
        """
        # Get CLS token features
        cls_features = self.forward_features(x)  # (batch_size, embed_dim)
        
        # Classification
        logits = self.head(cls_features)  # (batch_size, num_classes)
        
        return logits
    
    def get_attention_maps(self, x: torch.Tensor, layer_idx: int = -1) -> torch.Tensor:
        """
        Extract attention maps for visualization
        """
        batch_size = x.shape[0]
        
        # Forward pass up to the specified layer
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed
        x = self.dropout(x)
        
        for i, block in enumerate(self.blocks):
            if i == layer_idx or (layer_idx == -1 and i == len(self.blocks) - 1):
                # Extract attention from this layer
                x_norm = block.norm1(x)
                qkv = block.attn.qkv(x_norm).reshape(batch_size, x.shape[1], 3, block.attn.num_heads, block.attn.head_dim)
                qkv = qkv.permute(2, 0, 3, 1, 4)
                q, k, v = qkv.unbind(0)
                
                attn = (q @ k.transpose(-2, -1)) * block.attn.scale
                attn = F.softmax(attn, dim=-1)
                
                # Return attention maps (focusing on CLS token attention to patches)
                return attn[:, :, 0, 1:]  # (batch_size, num_heads, n_patches)
            
            x = block(x)
        
        return None

def vit_model(model_size: str = "base", num_classes: int = 8, img_size: int = 256) -> PathologyViT:
    configs = {
        "tiny": {"embed_dim": 192, "depth": 12, "num_heads": 3},
        "small": {"embed_dim": 384, "depth": 12, "num_heads": 6},
        "base": {"embed_dim": 768, "depth": 12, "num_heads": 12},
        "large": {"embed_dim": 1024, "depth": 24, "num_heads": 16}
    }
    
    if model_size not in configs:
        raise ValueError(f"Model size {model_size} not supported. Choose from {list(configs.keys())}")
    
    config = configs[model_size]
    
    return PathologyViT(
        img_size=img_size,
        patch_size=16,
        in_channels=3,
        num_classes=num_classes,
        embed_dim=config["embed_dim"],
        depth=config["depth"],
        num_heads=config["num_heads"],
        mlp_ratio=4.0,
        dropout=0.1
    )


# CNN MODEL

In [7]:
class ConvBlock(nn.Module):
    """Basic convolutional block with BatchNorm and ReLU"""
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, 
                 stride: int = 1, padding: int = 1, use_dropout: bool = False, dropout_rate: float = 0.2):
        super().__init__()
        
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.use_dropout = use_dropout
        if use_dropout:
            self.dropout = nn.Dropout2d(dropout_rate)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        if self.use_dropout:
            x = self.dropout(x)
        return x

class ResidualBlock(nn.Module):
    """Residual block with skip connection"""
    def __init__(self, in_channels: int, out_channels: int, stride: int = 1, downsample: bool = False):
        super().__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        
        self.downsample = None
        if downsample or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        
        if self.downsample is not None:
            identity = self.downsample(x)
        
        out += identity
        out = self.relu(out)
        
        return out

class SEBlock(nn.Module):
    """Squeeze-and-Excitation block for channel attention"""
    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

class PathologyConvNet(nn.Module):
    def __init__(
        self,
        num_classes: int = 8,
        in_channels: int = 3,
        base_channels: int = 64,
        use_se: bool = True,
        use_residual: bool = True,
        dropout_rate: float = 0.2
    ):
        super().__init__()
        
        self.num_classes = num_classes
        self.use_se = use_se
        self.use_residual = use_residual
        
        # Initial convolution - larger kernel to capture more context
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        
        # Feature extraction layers
        self.layer1 = self._make_layer(base_channels, base_channels, 2, stride=1)
        self.layer2 = self._make_layer(base_channels, base_channels * 2, 2, stride=2)
        self.layer3 = self._make_layer(base_channels * 2, base_channels * 4, 3, stride=2)
        self.layer4 = self._make_layer(base_channels * 4, base_channels * 8, 3, stride=2)
        self.layer5 = self._make_layer(base_channels * 8, base_channels * 8, 2, stride=2)
        
        # Global pooling and classification
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(dropout_rate)
        
        # Feature dimension for downstream tasks
        self.feature_dim = base_channels * 8
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(self.feature_dim // 2, num_classes)
        )
        
        # Initialize weights
        self._init_weights()
    
    def _make_layer(self, in_channels: int, out_channels: int, num_blocks: int, stride: int = 1):
        layers = []
        
        # First block (may have stride > 1 for downsampling)
        if self.use_residual:
            layers.append(ResidualBlock(in_channels, out_channels, stride, stride > 1))
            in_channels = out_channels
            
            # Remaining blocks
            for _ in range(1, num_blocks):
                layers.append(ResidualBlock(in_channels, out_channels))
        else:
            # Simple conv blocks
            layers.append(ConvBlock(in_channels, out_channels, stride=stride, use_dropout=True))
            in_channels = out_channels
            
            for _ in range(1, num_blocks):
                layers.append(ConvBlock(in_channels, out_channels, use_dropout=True))
        
        # Add SE block if requested
        if self.use_se:
            layers.append(SEBlock(out_channels))
        
        return nn.Sequential(*layers)
    
    def _init_weights(self):
        """Initialize weights using He initialization for ReLU networks"""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    
    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Returns: (batch_size, feature_dim) tensor
        """
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        
        x = self.global_pool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        
        return x
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for classification
        """
        features = self.forward_features(x)
        logits = self.classifier(features)
        return logits
    
    def get_feature_maps(self, x: torch.Tensor, layer_name: str = 'layer4') -> torch.Tensor:
        """
        Extract intermediate feature maps for visualization
        """
        x = self.stem(x)
        if layer_name == 'stem':
            return x
        
        x = self.layer1(x)
        if layer_name == 'layer1':
            return x
        
        x = self.layer2(x)
        if layer_name == 'layer2':
            return x
        
        x = self.layer3(x)
        if layer_name == 'layer3':
            return x
        
        x = self.layer4(x)
        if layer_name == 'layer4':
            return x
        
        x = self.layer5(x)
        if layer_name == 'layer5':
            return x
        
        return x

class LightweightConvNet(nn.Module):
    def __init__(
        self,
        num_classes: int = 8,
        in_channels: int = 3,
        base_channels: int = 32,
        dropout_rate: float = 0.3
    ):
        super().__init__()
        
        self.num_classes = num_classes
        
        self.features = nn.Sequential(
            # Block 1
            ConvBlock(in_channels, base_channels, kernel_size=3, stride=1),
            ConvBlock(base_channels, base_channels, kernel_size=3, stride=1),
            nn.MaxPool2d(2, 2),
            
            # Block 2  
            ConvBlock(base_channels, base_channels * 2, kernel_size=3, stride=1),
            ConvBlock(base_channels * 2, base_channels * 2, kernel_size=3, stride=1),
            nn.MaxPool2d(2, 2),
            
            # Block 3
            ConvBlock(base_channels * 2, base_channels * 4, kernel_size=3, stride=1),
            ConvBlock(base_channels * 4, base_channels * 4, kernel_size=3, stride=1),
            nn.MaxPool2d(2, 2),
            
            # Block 4
            ConvBlock(base_channels * 4, base_channels * 8, kernel_size=3, stride=1),
            ConvBlock(base_channels * 8, base_channels * 8, kernel_size=3, stride=1),
            nn.AdaptiveAvgPool2d(1)
        )
        
        self.feature_dim = base_channels * 8
        
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(self.feature_dim, self.feature_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(self.feature_dim // 2, num_classes)
        )
        
        self._init_weights()
    
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    
    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = torch.flatten(x, 1)
        return x
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.forward_features(x)
        logits = self.classifier(features)
        return logits

def cnn_model(model_size: str = "standard",num_classes: int = 8,**kwargs) -> nn.Module:
    
    if model_size == "small":
        return LightweightConvNet(num_classes=num_classes, **kwargs)
    
    elif model_size == "base":
        return PathologyConvNet(
            num_classes=num_classes,
            base_channels=64,
            use_se=False,
            use_residual=True,
            **kwargs
        )
    
    elif model_size == "large":
        return PathologyConvNet(
            num_classes=num_classes,
            base_channels=64,
            use_se=True,
            use_residual=True,
            **kwargs
        )
    
    else:
        raise ValueError(f"model_size must be one of: small, base, large")


In [8]:
model_type = "vit"
variant = "small"

model = vit_model("small", num_classes=8, img_size=256)
# model = cnn_model("small", num_classes=8)

model = model.to(CFG.device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model created successfully!")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: {total_params/1e6:.2f}M parameters")

Model created successfully!
Total parameters: 21,677,960
Trainable parameters: 21,677,960
Model size: 21.68M parameters


In [9]:
model.forward_features(torch.rand(2,3,256,256).to(CFG.device)).shape

torch.Size([2, 384])

# TRAINING

In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=CFG.lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='min', 
    factor=0.5, 
    patience=3, 
    min_lr=1e-7,
    verbose=False
)

print(f"Criterion: {criterion}")
print(f"Optimizer: {optimizer}")
print(f"Scheduler: {scheduler}")

Criterion: CrossEntropyLoss()
Optimizer: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 2e-05
    maximize: False
    weight_decay: 0
)
Scheduler: <torch.optim.lr_scheduler.ReduceLROnPlateau object at 0x79736a583910>




In [11]:
def manage_checkpoints(save_dir, keep_last_n=3):
    checkpoint_pattern = os.path.join(save_dir, 'epoch*.pth')
    checkpoint_files = glob.glob(checkpoint_pattern)
    
    checkpoints = []
    for checkpoint_file in checkpoint_files:
        filename = os.path.basename(checkpoint_file)
        try:
            epoch_num = int(filename.replace('epoch', '').replace('.pth', ''))
            checkpoints.append((epoch_num, checkpoint_file))
        except ValueError:
            continue
    
    checkpoints.sort(reverse=True)
    
    if len(checkpoints) > keep_last_n:
        for _, checkpoint_file in checkpoints[keep_last_n:]:
            try:
                os.remove(checkpoint_file)
                print(f"Removed old checkpoint: {os.path.basename(checkpoint_file)}")
            except Exception as e:
                print(f"Error removing checkpoint {checkpoint_file}: {e}")


In [12]:
print("Starting training...")

best_val_loss = float('inf')
patience_counter = 0
save_dir = "/kaggle/working"

os.makedirs(save_dir, exist_ok=True)

all_step_train_losses = []
all_epoch_train_losses = []
all_epoch_val_losses = []

for epoch in range(CFG.epochs):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch+1}/{CFG.epochs}")
    print(f"{'='*60}")
    
    model.train()
    train_losses = []
    
    train_pbar = tqdm(train_loader, desc=f"Training Epoch {epoch+1}")
    for batch_idx, (data, target) in enumerate(train_pbar):
        data, target = data.to(CFG.device), target.to(CFG.device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        step_loss = loss.item()
        train_losses.append(step_loss)
        all_step_train_losses.append(step_loss)  
        train_pbar.set_postfix({'Loss': f'{step_loss:.4f}'})
    
    avg_train_loss = np.mean(train_losses)
    all_epoch_train_losses.append(avg_train_loss)
    
    model.eval()
    val_losses = []
    correct = 0
    total = 0
    
    with torch.no_grad():
        val_pbar = tqdm(val_loader, desc=f"Validation Epoch {epoch+1}")
        for data, target in val_pbar:
            data, target = data.to(CFG.device), target.to(CFG.device)
            output = model(data)
            loss = criterion(output, target)
            val_losses.append(loss.item())
            
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            
            val_pbar.set_postfix({'Loss': f'{loss.item():.4f}'})
    
    avg_val_loss = np.mean(val_losses)
    all_epoch_val_losses.append(avg_val_loss)
    
    val_accuracy = 100 * correct / total
    
    print(f"Train Loss: {avg_train_loss:.4f}")
    print(f"Val Loss: {avg_val_loss:.4f}")
    print(f"Val Accuracy: {val_accuracy:.2f}%")
    
    prev_lr = optimizer.param_groups[0]['lr']
    scheduler.step(avg_val_loss)
    new_lr = optimizer.param_groups[0]['lr']
    if prev_lr != new_lr:
        print(f"LR decreased to {new_lr}")
    
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0

        model_filename = f"epoch{epoch+1}.pth"
        model_path = os.path.join(save_dir, model_filename)
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
            'loss': float(avg_train_loss),
            'val_loss': float(avg_val_loss),
            'val_accuracy': float(val_accuracy),
        }

        torch.save(checkpoint, model_path)
        manage_checkpoints(save_dir, keep_last_n=5)
    else:
        patience_counter += 1
        print(f"Patience: {patience_counter}/{CFG.patience}")
        
        if patience_counter >= CFG.patience:
            print(f"Early stopping triggered after {epoch+1} epochs!")
            break

Starting training...

Epoch 1/50


Training Epoch 1: 100%|██████████| 450/450 [00:40<00:00, 11.00it/s, Loss=2.0358]
Validation Epoch 1: 100%|██████████| 50/50 [00:01<00:00, 28.18it/s, Loss=2.0143]


Train Loss: 2.0322
Val Loss: 2.0173
Val Accuracy: 0.00%

Epoch 2/50


Training Epoch 2: 100%|██████████| 450/450 [00:40<00:00, 11.04it/s, Loss=1.8265]
Validation Epoch 2: 100%|██████████| 50/50 [00:01<00:00, 30.17it/s, Loss=1.0714]


Train Loss: 1.8405
Val Loss: 1.8333
Val Accuracy: 10.40%

Epoch 3/50


Training Epoch 3: 100%|██████████| 450/450 [00:40<00:00, 11.11it/s, Loss=1.6857]
Validation Epoch 3: 100%|██████████| 50/50 [00:01<00:00, 30.60it/s, Loss=1.8662]


Train Loss: 1.5905
Val Loss: 1.7619
Val Accuracy: 35.80%

Epoch 4/50


Training Epoch 4: 100%|██████████| 450/450 [00:40<00:00, 11.11it/s, Loss=1.2196]
Validation Epoch 4: 100%|██████████| 50/50 [00:01<00:00, 30.74it/s, Loss=0.8510]


Train Loss: 1.4334
Val Loss: 1.1537
Val Accuracy: 59.40%

Epoch 5/50


Training Epoch 5: 100%|██████████| 450/450 [00:40<00:00, 11.11it/s, Loss=1.1550]
Validation Epoch 5: 100%|██████████| 50/50 [00:01<00:00, 30.11it/s, Loss=6.1617]


Train Loss: 1.2848
Val Loss: 3.6530
Val Accuracy: 37.60%
Patience: 1/5

Epoch 6/50


Training Epoch 6: 100%|██████████| 450/450 [00:40<00:00, 11.11it/s, Loss=1.3563]
Validation Epoch 6: 100%|██████████| 50/50 [00:01<00:00, 30.22it/s, Loss=1.7635]


Train Loss: 1.1703
Val Loss: 1.6563
Val Accuracy: 44.80%
Patience: 2/5

Epoch 7/50


Training Epoch 7: 100%|██████████| 450/450 [00:40<00:00, 11.11it/s, Loss=0.7987]
Validation Epoch 7: 100%|██████████| 50/50 [00:01<00:00, 30.26it/s, Loss=6.2305]


Train Loss: 1.0895
Val Loss: 3.5392
Val Accuracy: 38.20%
Patience: 3/5

Epoch 8/50


Training Epoch 8: 100%|██████████| 450/450 [00:40<00:00, 11.11it/s, Loss=1.4518]
Validation Epoch 8: 100%|██████████| 50/50 [00:01<00:00, 30.23it/s, Loss=3.1261]


Train Loss: 1.0634
Val Loss: 2.1780
Val Accuracy: 44.20%
LR decreased to 1e-05
Patience: 4/5

Epoch 9/50


Training Epoch 9: 100%|██████████| 450/450 [00:40<00:00, 11.10it/s, Loss=0.5318]
Validation Epoch 9: 100%|██████████| 50/50 [00:01<00:00, 30.53it/s, Loss=1.0551]

Train Loss: 0.9974
Val Loss: 1.3029
Val Accuracy: 47.20%
Patience: 5/5
Early stopping triggered after 9 epochs!





# TESTING 

In [13]:
best_model_path = os.path.join(save_dir, "epoch1.pth")
if os.path.exists(best_model_path):
    checkpoint = torch.load(best_model_path )
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from epoch {checkpoint['epoch']}")

print("TESTING")

model.eval()
test_losses = []
correct = 0
total = 0
class_correct = list(0. for i in range(8))
class_total   = list(0. for i in range(8))

with torch.no_grad():
    test_pbar = tqdm(test_loader, desc="Testing")
    for data, target in test_pbar:
        data, target = data.to(CFG.device), target.to(CFG.device)
        output = model(data)
        loss = criterion(output, target)
        test_losses.append(loss.item())
        
        _, predicted = torch.max(output, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
        
        c = (predicted == target)
        
        if c.dim() == 0:  
            c = c.unsqueeze(0)
            
        for i in range(target.size(0)):
            label = target[i].item() 
            class_correct[label] += c[i].item()
            class_total[label] += 1
        
        test_pbar.set_postfix({'Loss': f'{loss.item():.4f}'})

avg_test_loss = np.mean(test_losses)
test_accuracy = 100 * correct / total

print(f"\nTest Results:")
print(f"Test Loss: {avg_test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.2f}%")
print(f"\nPer-class Accuracy:")
for i in range(8):
    if class_total[i] > 0:
        accuracy = 100 * class_correct[i] / class_total[i]
        print(f"Class {i}: {accuracy:.2f}% ({int(class_correct[i])}/{int(class_total[i])})")
    else:
        print(f"Class {i}: No samples")

Loaded model from epoch 1
TESTING


Testing: 100%|██████████| 11/11 [00:00<00:00, 18.62it/s, Loss=1.9146]


Test Results:
Test Loss: 1.9359
Test Accuracy: 5.50%

Per-class Accuracy:
Class 0: No samples
Class 1: 0.00% (0/1)
Class 2: 0.00% (0/13)
Class 3: 0.00% (0/60)
Class 4: No samples
Class 5: 100.00% (6/6)
Class 6: 0.00% (0/29)
Class 7: No samples





# FINAL CHECKPOINT

In [14]:
final_model_path = os.path.join(save_dir, "finalcheckpoint.pth")
torch.save({
    'model_state_dict': model.state_dict(),
    'num_classes': 8,
    'img_size': 256,
    'test_accuracy': float(test_accuracy),
    'test_loss': float(avg_test_loss),
}, final_model_path)

In [15]:
try:
    np.save(os.path.join(save_dir,  f"train_step_losses_{model_type}_{variant}_scratch.npy"), np.array(all_step_train_losses))
    np.save(os.path.join(save_dir, f"train_epoch_losses_{model_type}_{variant}_scratch.npy"), np.array(all_epoch_train_losses))
    np.save(os.path.join(save_dir,   f"val_epoch_losses_{model_type}_{variant}_scratch.npy"), np.array(all_epoch_val_losses))

except Exception as e:
    print(e)