In [15]:
import numpy as np 
import pandas as pd 
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import os
from PIL import Image
import json
import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torch.amp import autocast, GradScaler
import time
import copy

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# for dirname, _, filenames in os.walk('/kaggle/input/imagenet100'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# Preprocessing Steps

In [2]:
def collect_all_data(data_dir, json_mapping_path):
    """
    Collect all image paths and labels in a deterministic way
    """
    data_dir = (data_dir)
    
    # Load class mapping
    with open(json_mapping_path, 'r') as f:
        folder_to_class = json.load(f)
    
    # Create class to index mapping
    classes = sorted(list(set(folder_to_class.values())))
    class_to_idx = {cls: idx for idx, cls in enumerate(classes)}
    
    # Collect all image paths and labels
    all_image_paths = []
    all_labels = []
    directories = os.listdir(data_dir)[1:]

    # Sort folder names for deterministic ordering
    folder_names = []
    folder_paths = []

    for d in directories:
        folder_names.extend([f for f in os.listdir(f'{data_dir}/{d}')
                              if f in folder_to_class])
        folder_paths.extend([os.path.join(f'{data_dir}/{d}', f) for f in os.listdir(f'{data_dir}/{d}')
                              if f in folder_to_class])
    # print(folder_names[0])
    # print(folder_paths[0])
    for i, folder_name in enumerate(folder_names):
        if os.path.isdir(folder_paths[i]):
            class_name = folder_to_class[folder_name]
            class_idx = class_to_idx[class_name]
            
            # Get all image files and sort them for deterministic ordering
            image_files = sorted([
                os.path.join(folder_paths[i], img_file) for img_file in os.listdir(folder_paths[i])
                if img_file.endswith('.JPEG')
            ])
            # print(image_files)
            # for img_file in image_files:
            all_image_paths.extend(image_files)
            all_labels.extend([class_idx]*len(image_files))
    
    return all_image_paths, all_labels, classes, class_to_idx, folder_to_class
# collect_all_data(root, f'{root}/Labels.json')

In [3]:
# split data paths and labels
def split(all_image_paths, all_labels, test_size=0.2, val_size=0.1, random_state=42):
    """
    Split data once to ensure mutual exclusivity
    """
    # print(f"Splitting: {all_image_paths[0]}")
    # First split: separate test set
    X_temp, X_test, y_temp, y_test = train_test_split(
        all_image_paths, all_labels,
        test_size=test_size,
        stratify=all_labels,
        random_state=random_state
    )
    
    # Second split: separate train and validation from remaining data
    val_size_adjusted = val_size / (1 - test_size)
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp,
        test_size=val_size_adjusted,
        stratify=y_temp,
        random_state=random_state
    )
    # print(X_train[0])
    # Verify mutual exclusivity
    train_set = set(X_train)
    val_set = set(X_val)
    test_set = set(X_test)
    # print(train_set[0])
    
    assert len(train_set.intersection(val_set)) == 0, "Train and validation sets overlap!"
    assert len(train_set.intersection(test_set)) == 0, "Train and test sets overlap!"
    assert len(val_set.intersection(test_set)) == 0, "Validation and test sets overlap!"
    
    print("✓ Data splits verified as mutually exclusive")
    print(f"Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")
    
    return {
        'train': {'paths': X_train, 'labels': y_train},
        'val': {'paths': X_val, 'labels': y_val},
        'test': {'paths': X_test, 'labels': y_test}
    }
# img_path_splits = split(img_paths, labels)
# img_path_splits['train']['paths'][0]

In [4]:
class imagenet100(Dataset):
    def __init__(self, image_paths, labels, classes, classes_to_idx, transform = None, split=None):
        self.transform = transform
        self.image_paths = image_paths
        self.labels = labels
        self.classes = classes
        self.classes_to_idx = classes_to_idx
        self.num_classes = len(classes)
        
    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label


In [5]:
from torchvision import transforms

def get_transforms(input_size):
    return transforms.Compose([
        transforms.Resize((input_size, input_size)),     # Match ResNet18 input
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

In [6]:
# train_size = int(0.8*len(dataset))
# val_size = len(dataset) - train_size
# train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

def create_data_loaders(data_dir, json_mapping, batch_size=64, num_workers=4, input_size=224):
    img_paths, labels, classes, class_to_idx, folder_to_class = collect_all_data(data_dir, json_mapping)
    # print(img_paths[0]) 
    splits = split(img_paths, labels)
    transform = get_transforms(input_size)
    print(splits['train']['paths'][0])
    # Create datasets
    train_dataset = imagenet100(
        splits['train']['paths'], 
        splits['train']['labels'], 
        classes,
        class_to_idx,
        transform=transform, 
    )
    
    val_dataset = imagenet100(
        splits['val']['paths'], 
        splits['val']['labels'], 
        classes,
        class_to_idx,
        transform=transform, 
    )
    
    test_dataset = imagenet100(
        splits['test']['paths'], 
        splits['test']['labels'], 
        classes,
        class_to_idx,
        transform=transform, 
    )
    # print(train_dataset.image_paths)
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True, 
        num_workers=num_workers,
        pin_memory=True  # Faster GPU transfer
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=batch_size, 
        shuffle=False, 
        num_workers=num_workers,
        pin_memory=True
    )
    
    test_loader = DataLoader(
        test_dataset, 
        batch_size=batch_size, 
        shuffle=False, 
        num_workers=num_workers,
        pin_memory=True
    )
    print(f"Dataset sizes:")
    print(f"Train: {len(train_dataset)} images")
    print(f"Validation: {len(val_dataset)} images")
    print(f"Test: {len(test_dataset)} images")
    print(f"Number of classes: {len(classes)}")
    
    return train_loader, val_loader, test_loader, len(classes)


In [7]:
root = f"/kaggle/input/imagenet100"
train_loader, val_loader, test_loader, num_classes = create_data_loaders(root, os.path.join(root, 'Labels.json'))

✓ Data splits verified as mutually exclusive
Train: 94500, Val: 13500, Test: 27000
/kaggle/input/imagenet100/train.X4/n01860187/n01860187_1819.JPEG
Dataset sizes:
Train: 94500 images
Validation: 13500 images
Test: 27000 images
Number of classes: 100


In [8]:
print("\nTesting DataLoaders...")
for batch_idx, (images, labels) in enumerate(train_loader):
    print(f"Batch {batch_idx}: Images shape: {images.shape}, Labels shape: {labels.shape}")
    print(f"Label range: {labels.min().item()} to {labels.max().item()}")
    if batch_idx == 2:  # Just show first few batches
        break


# Show some class information
print(f"\nFirst 10 classes: {train_loader.dataset.classes[:10]}")
print(f"Class to index mapping (first 5): {dict(list(train_loader.dataset.classes_to_idx.items())[:5])}")
print(f"\n first 10 image paths: {train_loader.dataset.image_paths[:10]}")


Testing DataLoaders...
Batch 0: Images shape: torch.Size([64, 3, 224, 224]), Labels shape: torch.Size([64])
Label range: 0 to 99
Batch 1: Images shape: torch.Size([64, 3, 224, 224]), Labels shape: torch.Size([64])
Label range: 0 to 97
Batch 2: Images shape: torch.Size([64, 3, 224, 224]), Labels shape: torch.Size([64])
Label range: 4 to 91

First 10 classes: ['American alligator, Alligator mississipiensis', 'American coot, marsh hen, mud hen, water hen, Fulica americana', 'Dungeness crab, Cancer magister', 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis', 'agama', 'albatross, mollymawk', 'axolotl, mud puppy, Ambystoma mexicanum', 'bald eagle, American eagle, Haliaeetus leucocephalus', 'banded gecko', 'barn spider, Araneus cavaticus']
Class to index mapping (first 5): {'American alligator, Alligator mississipiensis': 0, 'American coot, marsh hen, mud hen, water hen, Fulica americana': 1, 'Dungeness crab, Cancer magister': 2, 'Komodo dragon, Komodo lizard,

# RESNET18

In [9]:
def conv3x3(in_channels, out_channels, stride, dilation=1):
    return nn.Conv2d(
        in_channels=in_channels, 
        out_channels=out_channels, 
        kernel_size=3, 
        stride=stride, 
        dilation=dilation, 
        padding=dilation, 
        bias=False,
    )
    
class block(nn.Module):
    '''
    Basic Block: 3x3 Conv -> Batch Norm 1 -> ReLU -> 3x3 Conv -> Batch Norm 2 -> += initial -> ReLU
    '''
    
    def __init__(self, in_channels, out_channels, stride=1):
        super(block, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride=stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.ReLU = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels, stride=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # add another layer if channel
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    def forward(self, x):
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.ReLU(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # skip connection / identity matching
        out += self.shortcut(x)
        out = self.ReLU(out)
        
        return out

class resnet18(nn.Module):
    def __init__(self, num_classes=100,stride=1):
        super(resnet18, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(
            in_channels=3, 
            out_channels=self.in_channels, 
            kernel_size=7, 
            stride=2, 
            padding=1, 
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.ReLU = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self.make_layer(64,  2, stride=1)
        self.layer2 = self.make_layer(128, 2, stride=2)
        self.layer3 = self.make_layer(256, 2, stride=2)
        self.layer4 = self.make_layer(512, 2, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(512, num_classes)
    
        
    def make_layer(self, out_channels, num_blocks, stride=1):
        strides = [stride] + [1] * (num_blocks-1)
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, out_channels, stride=s))
            self.in_channels = out_channels
        return nn.Sequential(*layers)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.ReLU(x)
        x = self.maxpool(x)

        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        
        return out
        


# CTM

In [10]:
def get_model_size(model):
    # Count total parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    # Calculate memory usage (assumes float32 = 4 bytes)
    total_size = sum(p.numel() * p.element_size() for p in model.parameters())
    
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Model size: {total_size / 1024**2:.2f} MB")
    print(f"Model size: {total_size / 1024**3:.3f} GB")
    
    return total_params, total_size

# Usage
total_params, model_size = get_model_size(resnet18())

Total parameters: 11,227,812
Trainable parameters: 11,227,812
Model size: 42.83 MB
Model size: 0.042 GB


In [11]:
import gc 
# Check current usage"
def print_memory():
    print(f"Allocated: {torch.cuda.memory_allocated()/1024**2:.0f} MB")
    print(f"Reserved: {torch.cuda.memory_reserved()/1024**2:.0f} MB")
    # Get detailed memory info
    print(torch.cuda.memory_summary())

def clean_memory():
    # Check current usage
    del model, optimizer, criterion  # Replace with your variable names
    gc.collect()
    torch.cuda.empty_cache()
print_memory()
clean_memory()


Allocated: 0 MB
Reserved: 0 MB
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|--------------------------------

UnboundLocalError: cannot access local variable 'model' where it is not associated with a value

# TRain Model

In [12]:
def validate_model(model, val_loader, criterion, device):
    """
    Comprehensive validation function
    """
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
            
            # Forward pass
            output = model(data)
            loss = criterion(output, target)
            
            # Statistics
            val_loss += loss.item()
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            
            # Store for detailed metrics
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
            
            # Clean up memory
            del data, target, output, loss
    
    # Calculate metrics
    avg_loss = val_loss / len(val_loader)
    accuracy = 100.0 * correct / total
    
    return avg_loss, accuracy, all_preds, all_targets

def train_epoch(model, train_loader, optimizer, criterion, device, 
                grad_clip_norm=1.0, print_interval=100):
    """
    Train for one epoch with progress tracking
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to('cuda', non_blocking=True), labels.to('cuda', non_blocking=True)
        
        # Forward pass
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
        optimizer.step()
        
        # Statistics
        running_loss += loss.item()
        _, predicted = torch.max(output.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        # Print progress
        if batch_idx % print_interval == 0:
            print(f'Batch {batch_idx}/{len(train_loader)}, '
                  f'Loss: {loss.item():.4f}, '
                  f'Acc: {100.*correct/total:.2f}%')
        
        # Memory cleanup
        if batch_idx % 50 == 0:
            torch.cuda.empty_cache()
        
        # del data, labels, output, loss
    
    avg_loss = running_loss / len(train_loader)
    accuracy = 100.0 * correct / total
    
    return avg_loss, accuracy

class EarlyStopping:
    """
    Early stopping to prevent overfitting
    """
    def __init__(self, patience=7, min_delta=0.001, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None
        self.counter = 0
        self.best_weights = None
        
    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1
            
        if self.counter >= self.patience:
            if self.restore_best_weights:
                model.load_state_dict(self.best_weights)
            return True
        return False
    
    def save_checkpoint(self, model):
        self.best_weights = copy.deepcopy(model.state_dict())


In [17]:
def train(model, train_loader, val_loader, num_epochs=10, 
                          learning_rate=0.001, device='cuda', patience=10):
    """
    Complete training loop with validation and early stopping
    """
    # Setup training components
    model.to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )
    early_stopping = EarlyStopping(patience=patience, restore_best_weights=True)
    
    # Training history
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [],
        'learning_rates': []
    }
    
    print(f"Starting training for {num_epochs} epochs...")
    print(f"Device: {device}")
    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
    print("-" * 60)
    
    for epoch in range(num_epochs):
        start_time = time.time()
        
        # Training phase
        train_loss, train_acc = train_epoch(
            model, train_loader, optimizer, criterion, device
        )
        
        # Validation phase
        val_loss, val_acc, val_preds, val_targets = validate_model(
            model, val_loader, criterion, device
        )
        
        # Learning rate scheduling
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        
        # Store history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['learning_rates'].append(current_lr)
        
        # Print epoch results
        epoch_time = time.time() - start_time
        print(f"\nEpoch {epoch+1}/{num_epochs} ({epoch_time:.1f}s)")
        print(f"Train - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
        print(f"Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%")
        print(f"LR: {current_lr:.6f}")
        
        # Early stopping check
        if early_stopping(val_loss, model):
            print(f"\nEarly stopping triggered at epoch {epoch+1}")
            print(f"Best validation loss: {early_stopping.best_loss:.4f}")
            break
        
        # Memory cleanup
        torch.cuda.empty_cache()
        print("-" * 60)
    
    return history, model


In [16]:
model = resnet18()
history, model = train(model, train_loader, val_loader)

Starting training for 10 epochs...
Device: cuda
Train batches: 1477, Val batches: 211
------------------------------------------------------------
Batch 0/1477, Loss: 4.6514, Acc: 0.00%
Batch 100/1477, Loss: 4.5330, Acc: 3.00%
Batch 200/1477, Loss: 4.5050, Acc: 3.97%
Batch 300/1477, Loss: 4.2283, Acc: 4.68%
Batch 400/1477, Loss: 3.9372, Acc: 5.15%
Batch 500/1477, Loss: 4.1974, Acc: 5.74%
Batch 600/1477, Loss: 4.0947, Acc: 6.25%
Batch 700/1477, Loss: 3.7566, Acc: 7.08%
Batch 800/1477, Loss: 3.7826, Acc: 7.87%
Batch 900/1477, Loss: 3.8561, Acc: 8.70%
Batch 1000/1477, Loss: 3.6027, Acc: 9.60%
Batch 1100/1477, Loss: 3.5132, Acc: 10.48%
Batch 1200/1477, Loss: 3.3873, Acc: 11.38%
Batch 1300/1477, Loss: 3.4763, Acc: 12.26%
Batch 1400/1477, Loss: 3.3101, Acc: 13.05%

Epoch 1/10 (317.7s)
Train - Loss: 3.8596, Acc: 13.63%
Val   - Loss: 3.4835, Acc: 21.24%
LR: 0.001000
------------------------------------------------------------
Batch 0/1477, Loss: 3.1798, Acc: 32.81%
Batch 100/1477, Loss: 3.1046

In [42]:
model.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[-1.3664e-01, -1.1040e-01,  4.4606e-02,  ..., -2.2538e-03,
                         -1.1823e-01, -6.9003e-02],
                        [-6.4455e-02, -8.4982e-02,  3.1349e-02,  ...,  4.8349e-02,
                         -6.6766e-02, -2.9746e-02],
                        [ 3.5980e-02, -8.1370e-02, -1.0576e-01,  ...,  7.6561e-02,
                          1.5189e-03,  2.1370e-02],
                        ...,
                        [-1.1154e-01,  7.2375e-04, -4.4782e-02,  ...,  1.1009e-01,
                          6.2710e-02, -8.9515e-02],
                        [-1.2802e-01, -7.9691e-02, -5.4215e-02,  ...,  1.5370e-01,
                          1.2523e-01, -3.6549e-02],
                        [-3.5150e-03, -3.2835e-02, -3.2212e-02,  ..., -1.2887e-02,
                         -1.0561e-02, -1.2962e-02]],
              
                       [[-2.0618e-01, -1.1127e-01,  4.7792e-02,  ..., -1.1628e-01,
                         -1.4385

In [43]:
# Save checkpoint
torch.save(model.state_dict(), '/kaggle/working/model1.pth')

In [None]:
def plot_training_history(history):
    """
    Plot training and validation metrics
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
    
    epochs = range(1, len(history['train_loss']) + 1)
    
    # Loss plot
    ax1.plot(epochs, history['train_loss'], 'b-', label='Training Loss')
    ax1.plot(epochs, history['val_loss'], 'r-', label='Validation Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True)
    
    # Accuracy plot
    ax2.plot(epochs, history['train_acc'], 'b-', label='Training Accuracy')
    ax2.plot(epochs, history['val_acc'], 'r-', label='Validation Accuracy')
    ax2.set_title('Training and Validation Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.legend()
    ax2.grid(True)
    
    # Learning rate plot
    ax3.plot(epochs, history['learning_rates'], 'g-')
    ax3.set_title('Learning Rate Schedule')
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Learning Rate')
    ax3.set_yscale('log')
    ax3.grid(True)
    
    # Overfitting check
    train_val_diff = [t - v for t, v in zip(history['train_acc'], history['val_acc'])]
    ax4.plot(epochs, train_val_diff, 'purple')
    ax4.axhline(y=0, color='black', linestyle='--', alpha=0.5)
    ax4.set_title('Training vs Validation Gap')
    ax4.set_xlabel('Epoch')
    ax4.set_ylabel('Train Acc - Val Acc (%)')
    ax4.grid(True)
    
    plt.tight_layout()
    plt.show()

def detailed_validation_report(model, val_loader, device, class_names=None):
    """
    Generate detailed validation report with confusion matrix
    """
    val_loss, val_acc, val_preds, val_targets = validate_model(
        model, val_loader, nn.CrossEntropyLoss(), device
    )
    
    print(f"\n{'='*50}")
    print("DETAILED VALIDATION REPORT")
    print(f"{'='*50}")
    print(f"Validation Accuracy: {val_acc:.2f}%")
    print(f"Validation Loss: {val_loss:.4f}")
    
    # Classification report
    if class_names is None:
        class_names = [f"Class_{i}" for i in range(len(set(val_targets)))]
    
    print(f"\n{'-'*50}")
    print("CLASSIFICATION REPORT")
    print(f"{'-'*50}")
    print(classification_report(val_targets, val_preds, target_names=class_names))
    
    # Confusion matrix
    cm = confusion_matrix(val_targets, val_preds)
    print(f"\n{'-'*50}")
    print("CONFUSION MATRIX")
    print(f"{'-'*50}")
    print(cm)
    
    return val_acc, val_loss, cm