In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
from tqdm import tqdm
import glob
import sys
sys.path.append("/kaggle/input/pip-vlmfs")
import timm

# ARGS

In [2]:
class CFG:
    """Run configuration: training schedule, batch sizes, split bounds and device.

    Each ``*_start`` / ``*_end`` pair is a half-open index range ``[start, end)``
    into the sample arrays; ``PathologyDataset`` consumes them as such.
    """

    # Optimisation schedule.
    epochs = 50      # upper bound on the number of training epochs
    lr = 2e-5        # initial learning rate handed to the optimizer
    patience = 5     # non-improving validation epochs tolerated before early stopping

    # Mini-batch size used by each split's DataLoader.
    train_batch_size = val_batch_size = test_batch_size = 10

    # Contiguous index ranges of the three splits.
    train_start, train_end = 0, 4500
    val_start, val_end = 4500, 5000
    test_start, test_end = 5000, 5109

    # Prefer the GPU whenever CUDA is usable, otherwise fall back to the CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# DATASET

In [3]:
class PathologyDataset(Dataset):
    """Map-style dataset exposing the half-open slice ``[start_idx, end_idx)`` of two parallel arrays.

    Args:
        patches: Indexable collection of samples (e.g. a NumPy array or memmap);
            item ``i`` must be convertible with ``torch.tensor``.
        labels: Indexable collection of integer labels aligned with ``patches``.
        start_idx: First absolute index belonging to this split (inclusive).
        end_idx: One past the last absolute index belonging to this split.

    Raises:
        ValueError: If the range is negative, reversed, or reaches past the end
            of ``patches`` or ``labels``.
    """

    def __init__(self, patches, labels, start_idx, end_idx):
        if not 0 <= start_idx <= end_idx:
            raise ValueError(
                f"Invalid split range [{start_idx}, {end_idx}): expected 0 <= start_idx <= end_idx"
            )
        if end_idx > len(patches) or end_idx > len(labels):
            raise ValueError(
                f"Split end {end_idx} exceeds the available data "
                f"(patches: {len(patches)}, labels: {len(labels)})"
            )

        self.patches = patches
        self.labels = labels
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.length = end_idx - start_idx

    def __len__(self):
        """Return the number of samples in this split."""
        return self.length

    def __getitem__(self, idx):
        """Return ``(patch, label)`` tensors for the ``idx``-th sample of the split.

        Negative indices count from the end of the split. Anything outside
        ``[-len(self), len(self))`` raises ``IndexError`` instead of silently
        reading a sample that belongs to a neighbouring split.
        """
        if idx < 0:
            idx += self.length
        if not 0 <= idx < self.length:
            raise IndexError(f"Index out of range for a split of length {self.length}")

        # Translate the split-relative index into a position in the shared arrays.
        actual_idx = self.start_idx + idx

        patch = self.patches[actual_idx]
        label = self.labels[actual_idx]

        # torch.tensor copies, so the returned sample is detached from the
        # (possibly read-only, memory-mapped) source array.
        patch = torch.tensor(patch)
        label = torch.tensor(label, dtype=torch.long)

        return patch, label



In [4]:
# The image array is opened read-only as a memory map (mmap_mode="r") rather
# than loaded into RAM; the label vector is loaded in full.
inputs = np.load("/kaggle/input/miccaireg/images.npy", mmap_mode="r")
labels = np.load("/kaggle/input/miccaireg/labels.npy")

# Fail fast if the two arrays are misaligned or if the hard-coded split bounds
# in CFG do not fit the data actually on disk.
assert len(inputs) == len(labels), (
    f"images/labels length mismatch: {len(inputs)} vs {len(labels)}"
)
for split_name, split_start, split_end in (
    ("train", CFG.train_start, CFG.train_end),
    ("val", CFG.val_start, CFG.val_end),
    ("test", CFG.test_start, CFG.test_end),
):
    assert 0 <= split_start <= split_end <= len(labels), (
        f"{split_name} split [{split_start}, {split_end}) does not fit {len(labels)} samples"
    )

print(f"Inputs shape: {inputs.shape}")
print(f"Labels shape: {labels.shape}")
print(f"Labels range: {labels.min()} to {labels.max()}")
print(f"Unique labels: {np.unique(labels)}")

print(f"Using device: {CFG.device}")
print(f"Training samples: {CFG.train_end - CFG.train_start}")
print(f"Validation samples: {CFG.val_end - CFG.val_start}")
print(f"Test samples: {CFG.test_end - CFG.test_start}")

Inputs shape: (5109, 3, 256, 256)
Labels shape: (5109,)
Labels range: 0 to 7
Unique labels: [0 1 2 3 4 5 6 7]
Using device: cuda
Training samples: 4500
Validation samples: 500
Test samples: 109


In [5]:
# Three views over the same arrays, one per index range configured in CFG.
train_dataset = PathologyDataset(inputs, labels, CFG.train_start, CFG.train_end)
val_dataset = PathologyDataset(inputs, labels, CFG.val_start, CFG.val_end)
test_dataset = PathologyDataset(inputs, labels, CFG.test_start, CFG.test_end)

# Only the training loader shuffles. Validation and test iterate in a fixed
# order so that evaluation does not depend on the sampler's random state.
train_loader = DataLoader(train_dataset, batch_size=CFG.train_batch_size, shuffle=True, pin_memory=True, num_workers=4)
val_loader   = DataLoader(val_dataset, batch_size=CFG.val_batch_size, shuffle=False, pin_memory=True, num_workers=4)
test_loader  = DataLoader(test_dataset, batch_size=CFG.test_batch_size, shuffle=False, pin_memory=True, num_workers=2)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

# Smoke test: pull a single training batch and inspect its shapes and labels.
sample_batch, sample_labels = next(iter(train_loader))
print(f"Sample batch shape: {sample_batch.shape}")
print(f"Sample labels shape: {sample_labels.shape}")
print(f"Sample labels: {sample_labels}")


Train batches: 450
Val batches: 50
Test batches: 11
Sample batch shape: torch.Size([10, 3, 256, 256])
Sample labels shape: torch.Size([10])
Sample labels: tensor([5, 7, 7, 5, 0, 0, 4, 1, 5, 5])


# MODEL

#### ViT MODEL

In [6]:
class TimmViTModel(nn.Module):
    """timm Vision Transformer backbone followed by a small MLP classification head.

    Args:
        num_classes: Number of output logits.
        model_name: timm model identifier passed to ``timm.create_model``.
        pretrained: Whether timm should load pretrained backbone weights. Pass
            ``False`` when the weights are about to be restored from a
            checkpoint, so that nothing has to be downloaded.
        img_size: Input resolution forwarded to ``timm.create_model``.
    """

    def __init__(
        self,
        num_classes: int = 8,
        model_name: str = "vit_small_patch16_224",
        pretrained: bool = True,
        img_size: int = 256,
    ):
        super().__init__()

        # num_classes=0 requests the backbone without timm's own classifier;
        # its output then feeds the head defined below.
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            img_size=img_size,
        )

        # Width of the feature vector the head receives.
        self.feature_dim = self.backbone.num_features

        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),

            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes)
        )

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Return the backbone output for ``x`` (the input of the classification head).

        Note: this calls the backbone itself, not timm's ``backbone.forward_features``.
        """
        return self.backbone(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits with ``num_classes`` entries per sample."""
        features = self.forward_features(x)
        logits = self.classifier(features)
        return logits

#### CNN MODEL

In [7]:
class TimmCNNModel(nn.Module):
    """timm CNN backbone followed by a small MLP classification head.

    Args:
        num_classes: Number of output logits.
        model_name: timm model identifier passed to ``timm.create_model``.
            Alternative noted by the author:
            ``timm/mobilenetv4_conv_medium.e500_r256_in1k``.
        pretrained: Whether timm should load pretrained backbone weights. Pass
            ``False`` when the weights are about to be restored from a
            checkpoint, so that nothing has to be downloaded.
    """

    def __init__(
        self,
        num_classes: int = 8,
        model_name: str = "efficientnet_b0",
        pretrained: bool = True,
    ):
        super().__init__()

        # num_classes=0 requests the backbone without timm's own classifier;
        # its output then feeds the head defined below.
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
        )

        # Width of the feature vector the head receives.
        self.feature_dim = self.backbone.num_features

        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),

            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes)
        )

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Return the backbone output for ``x`` (the input of the classification head)."""
        return self.backbone(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits with ``num_classes`` entries per sample."""
        features = self.forward_features(x)
        logits = self.classifier(features)
        return logits


In [8]:
# Pick the architecture to train; swap the two lines to use the CNN variant.
model = TimmViTModel(num_classes=8)
# model = TimmCNNModel(num_classes=8)

model = model.to(CFG.device)

# Total counts include every parameter, whether or not it receives gradients.
total_params      = sum(p.numel() for p in model.parameters())
backbone_params   = sum(p.numel() for p in model.backbone.parameters())
classifier_params = sum(p.numel() for p in model.classifier.parameters())

# Trainable counts keep only parameters with requires_grad=True, so the two
# groups of numbers differ as soon as part of the model is frozen.
total_trainable_params      = sum(p.numel() for p in model.parameters() if p.requires_grad)
backbone_trainable_params   = sum(p.numel() for p in model.backbone.parameters() if p.requires_grad)
classifier_trainable_params = sum(p.numel() for p in model.classifier.parameters() if p.requires_grad)

print(f"Total parameters: {total_params / 1e6:.2f}M")
print(f"Backbone parameters: {backbone_params / 1e6:.2f}M")
print(f"Classifier parameters: {classifier_params}")

print(f"Total trainable parameters: {total_trainable_params / 1e6:.2f}M")
print(f"Backbone trainable parameters: {backbone_trainable_params / 1e6:.2f}M")
print(f"Classifier trainable parameters: {classifier_trainable_params}")

model.safetensors:   0%|          | 0.00/88.2M [00:00<?, ?B/s]

Total parameters: 22.02M
Backbone parameters: 21.69M
Classifier parameters: 331528
Total trainable parameters: 22.02M
Backbone trainable parameters: 21.69M
Classifier trainable parameters: 331528


In [9]:
# Shape probe: push a random 2-image batch through timm's own
# `backbone.forward_features` (not the wrapper's `forward_features`, which calls
# the backbone itself). The recorded output is [2, 257, 384] -- presumably 256
# patch tokens plus one class token, each 384-dimensional; TODO confirm.
model.backbone.forward_features(torch.rand(2,3,256,256).to(CFG.device)).shape

torch.Size([2, 257, 384])

In [10]:
# Multi-class classification objective on raw logits.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=CFG.lr)

# Halve the learning rate once the monitored value (the validation loss passed
# to scheduler.step) has not improved for 3 consecutive epochs; never go below
# 1e-7. The `verbose` argument is deprecated in recent PyTorch releases and its
# former value (False) was the default, so it is simply omitted.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    min_lr=1e-7,
)

print(f"Criterion: {criterion}")
print(f"Optimizer: {optimizer}")
print(f"Scheduler: {scheduler}")

Criterion: CrossEntropyLoss()
Optimizer: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 2e-05
    maximize: False
    weight_decay: 0
)
Scheduler: <torch.optim.lr_scheduler.ReduceLROnPlateau object at 0x7c1bc2c59490>




In [11]:
def manage_checkpoints(save_dir, keep_last_n=3):
    """Delete all but the ``keep_last_n`` highest-numbered ``epoch<N>.pth`` files in ``save_dir``.

    Files whose name does not reduce to an integer once ``epoch`` and ``.pth``
    are stripped are ignored and never deleted.

    Args:
        save_dir: Directory that holds the checkpoint files.
        keep_last_n: Number of checkpoints (largest epoch numbers) to keep.

    Raises:
        ValueError: If ``keep_last_n`` is negative.
    """
    if keep_last_n < 0:
        raise ValueError(f"keep_last_n must be non-negative, got {keep_last_n}")

    # Escape the directory so that glob metacharacters in the path itself
    # ('[', '*', '?') are matched literally; only the file name is a pattern.
    checkpoint_pattern = os.path.join(glob.escape(save_dir), 'epoch*.pth')
    checkpoint_files = glob.glob(checkpoint_pattern)

    checkpoints = []
    for checkpoint_file in checkpoint_files:
        filename = os.path.basename(checkpoint_file)
        try:
            epoch_num = int(filename.replace('epoch', '').replace('.pth', ''))
        except ValueError:
            # Not an epoch-numbered checkpoint; leave it alone.
            continue
        checkpoints.append((epoch_num, checkpoint_file))

    # Newest (highest epoch number) first, so the tail holds the files to drop.
    checkpoints.sort(reverse=True)

    for _, checkpoint_file in checkpoints[keep_last_n:]:
        try:
            os.remove(checkpoint_file)
            print(f"Removed old checkpoint: {os.path.basename(checkpoint_file)}")
        except OSError as e:
            # Best effort: report the failure and keep pruning the remaining files.
            print(f"Error removing checkpoint {checkpoint_file}: {e}")


In [12]:
print("Starting training...")

# Early-stopping bookkeeping: best validation loss seen so far and the number
# of consecutive epochs that failed to improve on it.
best_val_loss = float('inf')
patience_counter = 0
save_dir = "/kaggle/working"
os.makedirs(save_dir, exist_ok=True)
# NOTE(review): epoch*.pth files left in save_dir by an earlier run are not
# cleared here; the later cell that loads the highest-numbered checkpoint could
# then pick up a stale file. Start from an empty directory when re-running.

# Loss histories, written to disk once training ends.
all_step_train_losses = []
all_epoch_train_losses = []
all_epoch_val_losses = []

for epoch in range(CFG.epochs):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch+1}/{CFG.epochs}")
    print(f"{'='*60}")

    # ---- training pass ----
    model.train()
    train_losses = []

    train_pbar = tqdm(train_loader, desc=f"Training Epoch {epoch+1}")
    for data, target in train_pbar:
        data, target = data.to(CFG.device), target.to(CFG.device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        train_losses.append(step_loss)
        all_step_train_losses.append(step_loss)
        train_pbar.set_postfix({'Loss': f'{step_loss:.4f}'})

    avg_train_loss = np.mean(train_losses)
    all_epoch_train_losses.append(avg_train_loss)

    # ---- validation pass (no gradients, eval-mode layers) ----
    model.eval()
    val_losses = []
    correct = 0
    total = 0

    with torch.no_grad():
        val_pbar = tqdm(val_loader, desc=f"Validation Epoch {epoch+1}")
        for data, target in val_pbar:
            data, target = data.to(CFG.device), target.to(CFG.device)
            output = model(data)
            loss = criterion(output, target)
            val_losses.append(loss.item())

            # Predicted class = index of the largest logit.
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            val_pbar.set_postfix({'Loss': f'{loss.item():.4f}'})

    avg_val_loss = np.mean(val_losses)
    all_epoch_val_losses.append(avg_val_loss)

    val_accuracy = 100 * correct / total

    print(f"Train Loss: {avg_train_loss:.4f}")
    print(f"Val Loss: {avg_val_loss:.4f}")
    print(f"Val Accuracy: {val_accuracy:.2f}%")

    # Let the plateau scheduler react to the validation loss and report any change.
    prev_lr = optimizer.param_groups[0]['lr']
    scheduler.step(avg_val_loss)
    new_lr = optimizer.param_groups[0]['lr']
    if prev_lr != new_lr:
        print(f"LR decreased to {new_lr}")

    if avg_val_loss < best_val_loss:
        # New best model: reset patience and write a checkpoint. Checkpoints
        # are only written on improvement, so within one run the
        # highest-numbered file is also the best one.
        best_val_loss = avg_val_loss
        patience_counter = 0

        model_filename = f"epoch{epoch+1}.pth"
        model_path = os.path.join(save_dir, model_filename)
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': float(avg_train_loss),
            'val_loss': float(avg_val_loss),
            'val_accuracy': float(val_accuracy),
        }

        torch.save(checkpoint, model_path)
        manage_checkpoints(save_dir, keep_last_n=5)
    else:
        patience_counter += 1
        print(f"Patience: {patience_counter}/{CFG.patience}")

        if patience_counter >= CFG.patience:
            print(f"Early stopping triggered after {epoch+1} epochs!")
            break

np.save(os.path.join(save_dir, "train_step_losses.npy"), np.array(all_step_train_losses))
np.save(os.path.join(save_dir, "train_epoch_losses.npy"), np.array(all_epoch_train_losses))
np.save(os.path.join(save_dir, "val_epoch_losses.npy"), np.array(all_epoch_val_losses))

Starting training...

Epoch 1/50


Training Epoch 1: 100%|██████████| 450/450 [00:44<00:00, 10.08it/s, Loss=0.6658]
Validation Epoch 1: 100%|██████████| 50/50 [00:01<00:00, 30.39it/s, Loss=0.8709]


Train Loss: 1.3016
Val Loss: 0.6075
Val Accuracy: 86.60%

Epoch 2/50


Training Epoch 2: 100%|██████████| 450/450 [00:44<00:00, 10.16it/s, Loss=0.2265]
Validation Epoch 2: 100%|██████████| 50/50 [00:01<00:00, 33.37it/s, Loss=0.5759]


Train Loss: 0.6921
Val Loss: 0.3962
Val Accuracy: 88.80%

Epoch 3/50


Training Epoch 3: 100%|██████████| 450/450 [00:44<00:00, 10.20it/s, Loss=0.9964]
Validation Epoch 3: 100%|██████████| 50/50 [00:01<00:00, 32.65it/s, Loss=0.0333]


Train Loss: 0.5282
Val Loss: 0.3872
Val Accuracy: 86.60%

Epoch 4/50


Training Epoch 4: 100%|██████████| 450/450 [00:44<00:00, 10.18it/s, Loss=0.3630]
Validation Epoch 4: 100%|██████████| 50/50 [00:01<00:00, 34.08it/s, Loss=0.3757]


Train Loss: 0.4052
Val Loss: 0.3759
Val Accuracy: 88.80%

Epoch 5/50


Training Epoch 5: 100%|██████████| 450/450 [00:44<00:00, 10.21it/s, Loss=0.1784]
Validation Epoch 5: 100%|██████████| 50/50 [00:01<00:00, 34.05it/s, Loss=0.3717]


Train Loss: 0.3065
Val Loss: 0.3186
Val Accuracy: 90.20%

Epoch 6/50


Training Epoch 6: 100%|██████████| 450/450 [00:44<00:00, 10.20it/s, Loss=0.2271]
Validation Epoch 6: 100%|██████████| 50/50 [00:01<00:00, 34.05it/s, Loss=0.0101]


Train Loss: 0.2421
Val Loss: 0.4164
Val Accuracy: 88.00%
Patience: 1/5

Epoch 7/50


Training Epoch 7: 100%|██████████| 450/450 [00:44<00:00, 10.20it/s, Loss=0.1696]
Validation Epoch 7: 100%|██████████| 50/50 [00:01<00:00, 33.89it/s, Loss=0.5057]


Train Loss: 0.1747
Val Loss: 0.3236
Val Accuracy: 90.60%
Patience: 2/5

Epoch 8/50


Training Epoch 8: 100%|██████████| 450/450 [00:44<00:00, 10.19it/s, Loss=0.1751]
Validation Epoch 8: 100%|██████████| 50/50 [00:01<00:00, 34.10it/s, Loss=1.3918]


Train Loss: 0.1556
Val Loss: 0.3865
Val Accuracy: 89.60%
Patience: 3/5

Epoch 9/50


Training Epoch 9: 100%|██████████| 450/450 [00:44<00:00, 10.20it/s, Loss=0.0184]
Validation Epoch 9: 100%|██████████| 50/50 [00:01<00:00, 33.92it/s, Loss=0.4479]


Train Loss: 0.1586
Val Loss: 0.4753
Val Accuracy: 87.80%
LR decreased to 1e-05
Patience: 4/5

Epoch 10/50


Training Epoch 10: 100%|██████████| 450/450 [00:44<00:00, 10.20it/s, Loss=0.4749]
Validation Epoch 10: 100%|██████████| 50/50 [00:01<00:00, 34.14it/s, Loss=0.7237]

Train Loss: 0.0942
Val Loss: 0.4302
Val Accuracy: 88.20%
Patience: 5/5
Early stopping triggered after 10 epochs!





# TESTING 

In [13]:
import os
import re

folder = '/kaggle/working'

# fullmatch anchors both ends of the name, so look-alikes such as
# "epoch5.pth.bak" are not mistaken for checkpoints (re.match only anchors the start).
files = [f for f in os.listdir(folder) if re.fullmatch(r'epoch\d+\.pth', f)]

# Order by epoch number, not lexicographically ("epoch10" must follow "epoch9").
files.sort(key=lambda x: int(re.search(r'\d+', x).group()))

# Highest-numbered checkpoint, or None when the folder holds no checkpoint at all.
latest_file = files[-1] if files else None

In [14]:
# latest_file is None when no epoch checkpoint was found; guard the join, which
# would otherwise raise a TypeError.
best_model_path = os.path.join(save_dir, latest_file) if latest_file is not None else None
if best_model_path is not None and os.path.exists(best_model_path):
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine too.
    checkpoint = torch.load(best_model_path, map_location=CFG.device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from epoch {checkpoint['epoch']}")
else:
    # Make the fallback visible instead of silently scoring whatever is in memory.
    print("WARNING: no checkpoint found; evaluating the weights currently held in memory")

print("TESTING")

model.eval()
test_losses = []
correct = 0
total = 0
# Per-class tallies, indexed by the integer class label.
class_correct = [0] * 8
class_total = [0] * 8

with torch.no_grad():
    test_pbar = tqdm(test_loader, desc="Testing")
    for data, target in test_pbar:
        data, target = data.to(CFG.device), target.to(CFG.device)
        output = model(data)
        loss = criterion(output, target)
        test_losses.append(loss.item())

        # Predicted class = index of the largest logit.
        _, predicted = torch.max(output, 1)
        matches = (predicted == target)
        total += target.size(0)
        correct += matches.sum().item()

        # One host transfer per batch instead of one .item() call per sample.
        for label, hit in zip(target.tolist(), matches.tolist()):
            class_correct[label] += hit
            class_total[label] += 1

        test_pbar.set_postfix({'Loss': f'{loss.item():.4f}'})

avg_test_loss = np.mean(test_losses)
test_accuracy = 100 * correct / total

print(f"\nTest Results:")
print(f"Test Loss: {avg_test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.2f}%")
print(f"\nPer-class Accuracy:")
for i in range(8):
    if class_total[i] > 0:
        accuracy = 100 * class_correct[i] / class_total[i]
        print(f"Class {i}: {accuracy:.2f}% ({int(class_correct[i])}/{int(class_total[i])})")
    else:
        print(f"Class {i}: No samples")

Loaded model from epoch 5
TESTING


Testing: 100%|██████████| 11/11 [00:00<00:00, 19.63it/s, Loss=0.0060]


Test Results:
Test Loss: 0.0768
Test Accuracy: 96.33%

Per-class Accuracy:
Class 0: No samples
Class 1: 0.00% (0/1)
Class 2: 100.00% (13/13)
Class 3: 98.33% (59/60)
Class 4: No samples
Class 5: 83.33% (5/6)
Class 6: 96.55% (28/29)
Class 7: No samples





# FINAL CHECKPOINT

In [15]:
# Export the weights currently held by `model` together with the metadata needed
# to rebuild it and the test metrics measured above.
final_model_path = os.path.join(save_dir, "finalcheckpoint.pth")
torch.save(
    dict(
        model_state_dict=model.state_dict(),
        num_classes=8,
        img_size=256,
        test_accuracy=float(test_accuracy),
        test_loss=float(avg_test_loss),
    ),
    final_model_path,
)

In [16]:
# Round-trip check: reload the exported final checkpoint and repeat the test pass.
# NOTE(review): this reuses `latest_file` / `best_model_path`, so re-running the
# earlier epoch-checkpoint test cell afterwards would load this file instead;
# re-run the checkpoint-discovery cell first in that case.
latest_file = "finalcheckpoint.pth"
best_model_path = os.path.join(save_dir, latest_file)
if os.path.exists(best_model_path):
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine too.
    checkpoint = torch.load(best_model_path, map_location=CFG.device)
    model.load_state_dict(checkpoint['model_state_dict'])
else:
    # Make the fallback visible instead of silently scoring whatever is in memory.
    print("WARNING: final checkpoint not found; evaluating the weights currently held in memory")

print("TESTING")

model.eval()
test_losses = []
correct = 0
total = 0
# Per-class tallies, indexed by the integer class label.
class_correct = [0] * 8
class_total = [0] * 8

with torch.no_grad():
    test_pbar = tqdm(test_loader, desc="Testing")
    for data, target in test_pbar:
        data, target = data.to(CFG.device), target.to(CFG.device)
        output = model(data)
        loss = criterion(output, target)
        test_losses.append(loss.item())

        # Predicted class = index of the largest logit.
        _, predicted = torch.max(output, 1)
        matches = (predicted == target)
        total += target.size(0)
        correct += matches.sum().item()

        # One host transfer per batch instead of one .item() call per sample.
        for label, hit in zip(target.tolist(), matches.tolist()):
            class_correct[label] += hit
            class_total[label] += 1

        test_pbar.set_postfix({'Loss': f'{loss.item():.4f}'})

avg_test_loss = np.mean(test_losses)
test_accuracy = 100 * correct / total

print(f"\nTest Results:")
print(f"Test Loss: {avg_test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.2f}%")
print(f"\nPer-class Accuracy:")
for i in range(8):
    if class_total[i] > 0:
        accuracy = 100 * class_correct[i] / class_total[i]
        print(f"Class {i}: {accuracy:.2f}% ({int(class_correct[i])}/{int(class_total[i])})")
    else:
        print(f"Class {i}: No samples")

TESTING


Testing: 100%|██████████| 11/11 [00:00<00:00, 27.61it/s, Loss=0.0060]


Test Results:
Test Loss: 0.0768
Test Accuracy: 96.33%

Per-class Accuracy:
Class 0: No samples
Class 1: 0.00% (0/1)
Class 2: 100.00% (13/13)
Class 3: 98.33% (59/60)
Class 4: No samples
Class 5: 83.33% (5/6)
Class 6: 96.55% (28/29)
Class 7: No samples



