# Libraries

In [None]:
import os
import csv
import time
import yaml
import shutil
import random
import numpy as np
from tqdm import tqdm
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt 
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader

# Utils

### Set `DOWNLOAD` below to True to copy the dataset linked here into the output (working) directory, split into train/test folders
https://www.kaggle.com/datasets/dimensi0n/imagenet-256

In [None]:
# True: build train/test folders under /kaggle/working/ from the attached
# imagenet-256 dataset and train from there.
# False: read the pre-split /kaggle/input/imagenet/imagenet copy directly.
DOWNLOAD: bool = False

In [None]:
def load_yaml(path):
    """Parse a YAML config file.

    Args:
        path: path to the YAML file (``str`` or ``Path``).

    Returns:
        The parsed document. An empty file yields ``{}`` instead of ``None``,
        so callers can safely use ``.get`` on the result.
    """
    # Explicit UTF-8: the default encoding of open() is locale-dependent.
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # safe_load returns None for an empty document.
    return {} if cfg is None else cfg

def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch (CPU + CUDA) RNGs and force deterministic cuDNN.

    Args:
        seed: value passed to every seeding function.
    """
    # Same seeding order as before: Python, NumPy, torch CPU, current GPU, all GPUs.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True   # deterministic convolution algorithms
    cudnn.benchmark = False      # no auto-tuner, which could pick different kernels per run
    print(f"Random seed set to {seed}")

def save_training_plots(
    loss_history,
    train_acc_history,
    test_acc_history,
    epoch_times,
    output_dir="outputs/plots"
):
    """Save loss, accuracy and per-epoch time line charts as PNG files.

    ``output_dir`` (and its parents) is created if needed. The files written
    are ``loss.png``, ``accuracy.png`` and ``epoch_time.png``. The x axis is the
    1-based epoch number derived from ``len(loss_history)``, so the other
    histories must have the same length for matplotlib to plot them.
    """
    plot_dir = Path(output_dir)
    plot_dir.mkdir(parents=True, exist_ok=True)
    x_epochs = np.arange(1, len(loss_history) + 1)

    def _line_chart(filename, series, ylabel, title):
        # series: sequence of (values, legend label) drawn on one figure.
        plt.figure()
        for values, legend in series:
            plt.plot(x_epochs, values, label=legend)
        plt.xlabel("Epoch")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.grid(True)
        plt.savefig(plot_dir / filename)
        plt.close()

    _line_chart(
        "loss.png",
        [(loss_history, "Train Loss")],
        "Loss",
        "Training Loss",
    )
    _line_chart(
        "accuracy.png",
        [(train_acc_history, "Train Accuracy"), (test_acc_history, "Test Accuracy")],
        "Accuracy",
        "Train vs Test Accuracy",
    )
    _line_chart(
        "epoch_time.png",
        [(epoch_times, "Time per Epoch (s)")],
        "Seconds",
        "Epoch Time",
    )

    print(f"\nPlots saved to: {plot_dir.resolve()}")

In [None]:
# Model / optimizer / schedule settings (fc_layer, lr, momentum, weight_decay,
# step_size, gamma, epochs, start_from are read later in this notebook).
VGG_CFG = load_yaml(Path("/kaggle/input/vgg16-config") / "vgg.yaml")
# Dataset / loader settings (root, split_ratio, mean, std, num_classes,
# image_size, batch_size, num_workers are read later in this notebook).
DATA_CFG = load_yaml(Path("/kaggle/input/vgg16-config") / "imagenet.yaml")

# Dataset

In [None]:
def download_data(data_dir):
    """Download the ``dimensi0n/imagenet-256`` Kaggle dataset into ``data_dir``.

    The entries of the downloaded dataset directory are moved directly into
    ``data_dir``, so the class folders sit at its top level (where
    ``process_data`` looks for them).

    Args:
        data_dir: destination directory; created if missing.

    Returns:
        ``data_dir`` as a ``Path``.
    """
    # kagglehub is not imported at the top of the notebook; importing it here
    # prevents the NameError the call below used to raise.
    import kagglehub

    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)

    download_path = Path(kagglehub.dataset_download("dimensi0n/imagenet-256"))
    print("Path to dataset files:", download_path)

    # The previous code expected the files under <download_path>/versions/1.
    # Keep that location if it exists, otherwise use the returned path itself.
    source = download_path / "versions" / "1"
    if not source.is_dir():
        source = download_path

    print("Moving data into ", data_dir)
    # Move entries one by one: shutil.move(src_dir, existing_dir) would nest
    # src_dir *inside* data_dir rather than placing its contents there.
    for entry in source.iterdir():
        shutil.move(str(entry), str(data_dir / entry.name))
    return data_dir

def process_data(download):
    """Copy a folder-per-class image dataset into ``train/`` and ``test/`` splits.

    Each sub-directory of the source is one class. Its ``.jpg``/``.jpeg``/``.png``
    files are shuffled with the global ``random`` module and cut at
    ``DATA_CFG["split_ratio"]`` (fraction kept for training). The files are then
    copied to ``<dest>/train/<class>/`` and ``<dest>/test/<class>/``.

    Args:
        download: if True, fetch the data with ``download_data`` into
            ``DATA_CFG['root']`` and write the split inside that directory.
            Otherwise read ``/kaggle/input/imagenet-256`` and write the split
            to ``/kaggle/working/``.
    """
    if download:
        data_path = Path(download_data(DATA_CFG['root']))
        # Split is written inside the data directory itself; the train/test
        # folders are skipped by the class loop below.
        new_data_path = data_path
    else:
        data_path = Path("/kaggle/input/imagenet-256")
        new_data_path = Path("/kaggle/working/")

    train_path = new_data_path / "train"
    test_path = new_data_path / "test"
    train_path.mkdir(parents=True, exist_ok=True)
    test_path.mkdir(parents=True, exist_ok=True)

    image_suffixes = {".jpg", ".jpeg", ".png"}
    class_count = 0
    image_count = 0
    # Sorted so the class order (and therefore the random draws each class
    # receives) does not depend on filesystem listing order.
    for class_dir in tqdm(sorted(data_path.iterdir())):
        if class_dir in (train_path, test_path) or not class_dir.is_dir():
            continue
        class_name = class_dir.name
        class_count += 1

        img_paths = [
            p for p in sorted(class_dir.iterdir())
            if p.suffix.lower() in image_suffixes
        ]
        image_count += len(img_paths)
        random.shuffle(img_paths)

        split_idx = int(len(img_paths) * DATA_CFG["split_ratio"])
        for subset, dest_root in ((img_paths[:split_idx], train_path),
                                  (img_paths[split_idx:], test_path)):
            dest = dest_root / class_name
            dest.mkdir(parents=True, exist_ok=True)
            for img in subset:
                shutil.copy(img, dest / img.name)

    # Images per class (previously inverted); guard against an empty source.
    avg_per_class = image_count / class_count if class_count else 0.0
    print(f"Class Count: {class_count} \t Image Count: {image_count} \t Average Image per Class: {avg_per_class:.3f}")

In [None]:
# NOTE(review): despite the flag name, this does not call kagglehub. It runs
# process_data(download=False), which copies /kaggle/input/imagenet-256 into
# train/ and test/ folders under /kaggle/working/. That is the data_root the
# training and inference cells use when DOWNLOAD is True.
if DOWNLOAD:
    process_data(download=False)

In [None]:
def build_transforms(image_size=224, train=True, resize_size=None):
    """Build the torchvision preprocessing pipeline.

    Training: random resized crop, horizontal flip, tensor conversion and
    normalization. Evaluation: resize the shorter side to ``resize_size``,
    center-crop to ``image_size``, then tensor conversion and normalization.
    Normalization uses ``DATA_CFG["mean"]`` / ``DATA_CFG["std"]``.

    Args:
        image_size: output crop size (int, or a sequence accepted by torchvision).
        train: select the training (augmented) or evaluation pipeline.
        resize_size: evaluation resize size. Defaults to the standard
            256/224 ratio applied to ``image_size`` (256 for 224), so larger
            crops no longer exceed a fixed 256 resize.

    Returns:
        A ``T.Compose`` transform.
    """
    normalize = T.Normalize(mean=DATA_CFG["mean"], std=DATA_CFG["std"])

    if train:
        return T.Compose([
            T.RandomResizedCrop(image_size),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize,
        ])

    if resize_size is None:
        base = min(image_size) if isinstance(image_size, (list, tuple)) else image_size
        resize_size = int(base) * 256 // 224
    return T.Compose([
        T.Resize(resize_size),
        T.CenterCrop(image_size),
        T.ToTensor(),
        normalize,
    ])
        
class ImageNetDataset(Dataset):
    """Folder-per-class image dataset read from ``<root>/<split>/<class_name>/``.

    Class folders are sorted by name and, when fewer classes are requested
    than exist, sub-sampled with a fixed stride so that at most
    ``num_classes`` classes are kept. The selection depends only on the sorted
    folder names, so splits with the same class folders get identical class
    lists and label indices.

    Args:
        root: dataset root that contains one directory per split.
        split: split directory name, e.g. ``"train"`` or ``"test"``.
        transform: optional callable applied to each loaded RGB PIL image.
        num_classes: number of classes to keep; defaults to
            ``DATA_CFG.get("num_classes", 1000)``.

    Raises:
        ValueError: if ``num_classes`` is not positive.
    """

    # Same suffixes process_data() copies into the split folders.
    _IMG_SUFFIXES = frozenset({".jpg", ".jpeg", ".png"})

    def __init__(self, root, split="train", transform=None, num_classes=None):
        self.root = Path(root) / split
        self.transform = transform

        if num_classes is None:
            num_classes = DATA_CFG.get("num_classes", 1000)
        num_classes = int(num_classes)
        if num_classes <= 0:
            raise ValueError(f"num_classes must be positive, got {num_classes}")

        # Sort BEFORE sub-sampling: iterdir() order is filesystem-dependent, so
        # striding over the raw listing could select different classes for
        # train and test. The [:num_classes] cap keeps the class count equal to
        # the model's output size even when it doesn't divide the folder count.
        class_names = sorted(d.name for d in self.root.iterdir() if d.is_dir())
        stride = max(1, len(class_names) // num_classes)
        self.classes = class_names[::stride][:num_classes]
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

        # (image_path, label) pairs in a deterministic order.
        self.samples = []
        for cls_name in self.classes:
            label = self.class_to_idx[cls_name]
            for img_path in sorted((self.root / cls_name).iterdir()):
                if img_path.suffix.lower() in self._IMG_SUFFIXES:
                    self.samples.append((img_path, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        img_path, label = self.samples[index]
        # The context manager closes the file right away; convert() returns a
        # fully loaded copy, so open handles don't accumulate in workers.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# Model

In [None]:
class VGG16(nn.Module):
    """VGG-16 with batch normalization: 13 conv layers and 3 fully connected layers.

    Module order inside ``features`` and ``classifer`` matches the original
    hand-written ``nn.Sequential``. Existing checkpoints (keys
    ``features.<i>.*`` / ``classifer.<i>.*``) therefore still load.

    Args:
        num_classes: output size of the final linear layer.
        fc_dim: width of the two hidden fully connected layers; defaults to
            ``VGG_CFG.get("fc_layer", 1024)``.
    """

    # Conv output channels per VGG-16 block; each block ends with a 2x2 max-pool.
    _BLOCKS = (
        (64, 64),
        (128, 128),
        (256, 256, 256),
        (512, 512, 512),
        (512, 512, 512),
    )

    def __init__(self, num_classes=1000, fc_dim=None):
        super().__init__()
        if fc_dim is None:
            fc_dim = VGG_CFG.get("fc_layer", 1024)
        fc_dim = int(fc_dim)

        layers = []
        in_channels = 3
        for block in self._BLOCKS:
            for out_channels in block:
                layers += [
                    nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                ]
                in_channels = out_channels
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        self.features = nn.Sequential(*layers)

        # A 224x224 input already gives a 7x7 map after five pools, so this is
        # an identity there. For other input sizes it produces the 7x7 map the
        # first Linear expects. It has no parameters, so checkpoint keys are
        # unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        # Attribute name "classifer" (sic) is kept: saved checkpoints store the
        # classifier weights under this key.
        self.classifer = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * 7 * 7, fc_dim),
            nn.BatchNorm1d(fc_dim),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(fc_dim, fc_dim),
            nn.BatchNorm1d(fc_dim),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(fc_dim, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        return self.classifer(x)

# Train

In [None]:
# Config
set_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# With DOWNLOAD, process_data() wrote the train/test split to /kaggle/working/.
if DOWNLOAD:
    data_root = "/kaggle/working/"
else:
    data_root = "/kaggle/input/imagenet/imagenet"

# Datasets & loaders
train_dataset = ImageNetDataset(root=data_root,
                                split="train",
                                transform=build_transforms(DATA_CFG["image_size"], train=True))
test_dataset = ImageNetDataset(root=data_root,
                               split="test",
                               transform=build_transforms(DATA_CFG["image_size"], train=False))

# drop_last=True on the train loader keeps every batch full (BatchNorm1d in
# train mode rejects a batch of one). The test loader keeps every sample so
# test accuracy covers the whole split.
train_loader = DataLoader(train_dataset,
                          batch_size=DATA_CFG["batch_size"],
                          shuffle=True,
                          num_workers=DATA_CFG["num_workers"],
                          drop_last=True)
test_loader = DataLoader(test_dataset,
                         batch_size=DATA_CFG["batch_size"],
                         shuffle=False,
                         num_workers=DATA_CFG["num_workers"],
                         drop_last=False)

# Model, loss, optimizer
model = VGG16(num_classes=DATA_CFG.get("num_classes", 1000)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=float(VGG_CFG.get("lr", 0.001)),
    momentum=float(VGG_CFG.get("momentum", 0.9)),
    weight_decay=float(VGG_CFG.get("weight_decay", 1e-4)),
)
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=int(VGG_CFG.get("step_size", 30)),
    gamma=float(VGG_CFG.get("gamma", 0.1)),
)

# Checkpoint
num_epochs = int(VGG_CFG.get("epochs", 50))  # epochs are indexed 0 .. num_epochs-1
output_dir = Path("/kaggle/working/outputs/checkpoints")
output_dir.mkdir(parents=True, exist_ok=True)

start_epoch = 0
best_acc = 0.0

loss_history = []
train_acc_history = []
test_acc_history = []
epoch_times = []

# "start_from" is the epoch number of a checkpoint to resume from; a string
# value disables resuming.
start_from = VGG_CFG.get("start_from", None)
if start_from is not None and not isinstance(start_from, str):
    ckpt_epoch = int(start_from)
    model_dir = Path("/kaggle/input/vgg16/pytorch/default/1")
    ckpt_path = model_dir / f"vgg16_epoch_{ckpt_epoch}.pth"

    checkpoint = torch.load(ckpt_path, map_location=device)

    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    scheduler.load_state_dict(checkpoint["scheduler_state"])

    best_acc = checkpoint.get("best_acc", 0.0)

    loss_history = checkpoint.get("loss_history", [])
    train_acc_history = checkpoint.get("train_acc_history", [])
    test_acc_history = checkpoint.get("test_acc_history", [])
    epoch_times = checkpoint.get("epoch_times", [])

    start_epoch = checkpoint["epoch"] + 1

    print(f"Resumed from epoch {start_epoch}")

# Training loop (previously range(..., num_epochs+1), i.e. one extra epoch)
for epoch in range(start_epoch, num_epochs):
    start_time = time.time()

    # Training
    model.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0
    for images, labels in tqdm(train_loader, desc=f"[Train] Epoch {epoch+1}/{num_epochs}"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        correct_train += (preds == labels).sum().item()
        total_train += labels.size(0)

    # Average over the samples actually seen: drop_last=True skips the tail
    # batch, so len(train_loader.dataset) would overstate the denominator.
    epoch_loss = running_loss / max(total_train, 1)
    train_acc = correct_train / max(total_train, 1)
    loss_history.append(epoch_loss)
    train_acc_history.append(train_acc)

    # Test
    model.eval()
    correct_test = 0
    total_test = 0

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc=f"[Test] Epoch {epoch+1}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds = outputs.argmax(dim=1)
            correct_test += (preds == labels).sum().item()
            total_test += labels.size(0)

    test_acc = correct_test / max(total_test, 1)
    test_acc_history.append(test_acc)

    epoch_time = time.time() - start_time
    epoch_times.append(epoch_time)

    print(f"Epoch {epoch+1} | Loss: {epoch_loss:.4f} | Train Acc: {train_acc*100:.2f}% | Test Acc: {test_acc*100:.2f}% | Time: {epoch_time:.2f}s")

    # Save plots
    save_training_plots(
        loss_history=loss_history,
        train_acc_history=train_acc_history,
        test_acc_history=test_acc_history,
        epoch_times=epoch_times,
        output_dir="outputs/plots"
    )

    # Step BEFORE checkpointing: the saved optimizer/scheduler state must be
    # the state for epoch+1, which is where a resumed run starts.
    scheduler.step()

    # Update best_acc first so the checkpoint records the current best.
    is_best = test_acc > best_acc
    if is_best:
        best_acc = test_acc

    # Save checkpoint
    ckpt = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "best_acc": best_acc,

        # histories
        "loss_history": loss_history,
        "train_acc_history": train_acc_history,
        "test_acc_history": test_acc_history,
        "epoch_times": epoch_times,
    }

    ckpt_path = output_dir / f"vgg16_epoch_{epoch}.pth"
    torch.save(ckpt, ckpt_path)
    print(f"Checkpoint saved to : {ckpt_path}")

    if is_best:
        best_ckpt_path = output_dir / "vgg16_best.pth"
        torch.save(ckpt, best_ckpt_path)
        print(f"Saved best model to {best_ckpt_path}")

print("\nTraining Summary")
print(f"Best Test Accuracy: {best_acc*100:.2f}%")
print(f"Total time: {sum(epoch_times):.2f} seconds")
# np.min / np.max raise on an empty list (e.g. nothing trained, nothing resumed).
if epoch_times:
    print(f"Avg time/epoch: {np.mean(epoch_times):.2f} seconds")
    print(f"Min epoch time: {np.min(epoch_times):.2f} seconds")
    print(f"Max epoch time: {np.max(epoch_times):.2f} seconds")

# Inference

In [None]:
def _run_inference(model, loader, device, topk):
    """Run ``model`` over ``loader`` and collect top-k hits, per-class counts and confusions.

    Returns:
        ``(total, topk_correct, confusion_counter, per_class_total, per_class_correct)``.
        ``confusion_counter`` maps ``(true, pred)`` -> count for mistakes only.
    """
    total = 0
    topk_correct = [0] * len(topk)
    confusion_counter = Counter()      # (true, pred) -> count, mistakes only
    per_class_total = Counter()        # true label -> number of samples
    per_class_correct = Counter()      # true label -> correct top-1 predictions
    max_k = max(topk, default=1)

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="[Inference]"):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            probs = torch.softmax(model(images), dim=1)

            # One sorted topk for the largest k serves every k. It is clamped
            # so that k > num_classes does not raise.
            k_eff = min(max_k, probs.size(1))
            hits = torch.topk(probs, k_eff, dim=1).indices == labels.unsqueeze(1)
            for i, k in enumerate(topk):
                topk_correct[i] += hits[:, :k].any(dim=1).sum().item()

            preds = probs.argmax(dim=1)
            for t, p in zip(labels.cpu().tolist(), preds.cpu().tolist()):
                per_class_total[t] += 1
                if t == p:
                    per_class_correct[t] += 1
                else:
                    confusion_counter[(t, p)] += 1

            total += labels.size(0)

    return total, topk_correct, confusion_counter, per_class_total, per_class_correct


def _plot_most_confused(most_confused, idx_to_class, plot_path):
    """Save a bar chart of ``((true, pred), count)`` confusion pairs to ``plot_path``."""
    labels_plot = [
        f"{idx_to_class[t]}->{idx_to_class[p]}"
        for (t, p), _ in most_confused
    ]
    counts = [c for _, c in most_confused]

    plt.figure(figsize=(10, 5))
    plt.bar(range(len(counts)), counts)
    plt.xticks(range(len(counts)), labels_plot, rotation=45, ha="right")
    plt.ylabel("Count")
    plt.title("Top 10 Most Confused Class Pairs")
    plt.tight_layout()
    plt.savefig(plot_path)
    plt.close()


def _plot_confused_samples(dataset, selected_pairs, out_path):
    """Save a 3x4 grid with one random image from each class of each (true, pred) pair.

    Pairs whose classes are absent from ``dataset`` (e.g. filtered out by the
    class subset) or have no images are skipped with a message.
    """
    name_to_idx = {name: idx for idx, name in enumerate(dataset.classes)}
    paths_by_label = {}
    for img_path, label in dataset.samples:
        paths_by_label.setdefault(label, []).append(img_path)

    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    for ax in axes.flat:
        ax.axis("off")  # cells left unused stay blank
    capacity = axes.size // 2  # each pair occupies two side-by-side cells

    slot = 0
    for true_name, pred_name in selected_pairs:
        if slot >= capacity:
            break
        t = name_to_idx.get(true_name)
        p = name_to_idx.get(pred_name)
        if not paths_by_label.get(t) or not paths_by_label.get(p):
            print(f"Skipping pair {true_name} -> {pred_name}: class not in test set")
            continue

        row, col = slot // 2, (slot % 2) * 2
        for offset, (label, name) in enumerate(((t, true_name), (p, pred_name))):
            ax = axes[row, col + offset]
            with Image.open(random.choice(paths_by_label[label])) as img:
                ax.imshow(img.convert("RGB"))
            ax.set_title(name, fontsize=10)

        # axis("off") hides spines, so re-enable the predicted image's frame
        # and keep only its left spine as a black line between the pair.
        pred_ax = axes[row, col + 1]
        pred_ax.axis("on")
        pred_ax.set_xticks([])
        pred_ax.set_yticks([])
        for side, spine in pred_ax.spines.items():
            spine.set_visible(side == "left")
        pred_ax.spines["left"].set_color("black")
        pred_ax.spines["left"].set_linewidth(2)

        slot += 1

    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


def _write_per_class_csv(per_class_correct, per_class_total, idx_to_class, csv_path):
    """Write per-class top-1 accuracy (sorted high -> low) to ``csv_path``."""
    rows = sorted(
        (
            (cls, idx_to_class[cls], per_class_correct[cls] / n, per_class_correct[cls], n)
            for cls, n in per_class_total.items()
        ),
        key=lambda r: r[2],
        reverse=True,
    )
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["class_id", "class_name", "accuracy", "correct", "total"])
        for cls, name, acc, correct, total_cls in rows:
            writer.writerow([cls, name, f"{acc:.4f}", correct, total_cls])


def inference(params_path, topk=(1,5)):
    """Evaluate a saved VGG16 checkpoint on the test split and write reports.

    Prints top-k accuracy for each k in ``topk`` and the 10 most frequent
    (true -> predicted) mistakes. It also writes, under
    ``/kaggle/working/outputs``:
      * ``plots/most_confused_pairs.png``: bar chart of those mistakes;
      * ``plots/most_confused_pairs_samples.png``: sample images for a
        hand-picked list of confusable class pairs;
      * ``metrics/per_class_accuracy.csv``: per-class top-1 accuracy.

    Args:
        params_path: checkpoint file. A relative path is resolved against
            ``/kaggle/working/outputs/checkpoints``. An absolute path is used
            as-is, because pathlib's ``/`` keeps an absolute right operand.
        topk: k values for which top-k accuracy is reported.
    """
    # Setup
    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    plots_dir = Path("/kaggle/working/outputs/plots")
    plots_dir.mkdir(parents=True, exist_ok=True)
    metric_dir = Path("/kaggle/working/outputs/metrics")
    metric_dir.mkdir(parents=True, exist_ok=True)

    data_root = "/kaggle/working/" if DOWNLOAD else "/kaggle/input/imagenet/imagenet"

    # Data
    test_dataset = ImageNetDataset(
        root=data_root,
        split="test",
        transform=build_transforms(DATA_CFG["image_size"], train=False)
    )
    test_loader = DataLoader(
        test_dataset, batch_size=64, shuffle=False, num_workers=1
    )
    idx_to_class = test_dataset.classes

    # Model
    model = VGG16(num_classes=DATA_CFG.get("num_classes", len(idx_to_class))).to(device)
    ckpt_path = Path("/kaggle/working/outputs/checkpoints") / params_path
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    model.eval()

    (total, topk_correct, confusion_counter,
     per_class_total, per_class_correct) = _run_inference(model, test_loader, device, topk)

    if total == 0:
        print("\nNo test samples found; skipping metrics.")
        return

    # Accuracy
    print("\nAccuracy:")
    for k, correct in zip(topk, topk_correct):
        print(f"Top-{k}: {correct / total:.4f}")

    # Confusion analysis
    most_confused = confusion_counter.most_common(10)
    print("\nTop 10 most confused class pairs (true -> predicted):")
    for (t, p), count in most_confused:
        print(f"{idx_to_class[t]} -> {idx_to_class[p]} : {count}")

    if most_confused:
        plot_path = plots_dir / "most_confused_pairs.png"
        _plot_most_confused(most_confused, idx_to_class, plot_path)
        print(f"\nConfusion plot saved to: {plot_path}")

    # Selected pairs for the 3x4 image grid
    selected_pairs = [
        ("sidewinder", "horned_viper"),
        ("desktop_computer", "screen"),
        ("blenheim_spaniel", "welsh_springer_spaniel"),
        ("barn_spider", "wolf_spider"),
        ("potpie", "bagel"),
        ("bedlington_terrier", "miniature_poodle")
    ]
    sample_img_path = plots_dir / "most_confused_pairs_samples.png"
    _plot_confused_samples(test_dataset, selected_pairs, sample_img_path)
    print(f"\nSample images of confused pairs saved to: {sample_img_path}")

    # Per-class accuracy CSV
    csv_path = metric_dir / "per_class_accuracy.csv"
    _write_per_class_csv(per_class_correct, per_class_total, idx_to_class, csv_path)
    print(f"\nPer-class accuracy CSV saved to: {csv_path}")

# Evaluate the epoch-100 checkpoint from the attached model input. The path is
# absolute, so inference()'s join onto its checkpoint directory leaves it as-is.
param_path = Path("/kaggle/input/vgg16/pytorch/default/1").joinpath("vgg16_epoch_100.pth")
inference(param_path)

In [None]:
def summarize_checkpoint_times(ckpt_path):
    """Print and return the average and total epoch time stored in a checkpoint.

    Args:
        ckpt_path: path to a checkpoint dict saved by the training loop.

    Returns:
        ``(avg_seconds, total_seconds)``. Returns ``None`` when the checkpoint
        has no ``"epoch_times"`` entry or the recorded list is empty.
    """
    ckpt = torch.load(ckpt_path, map_location="cpu")

    # Check if epoch_times exists
    if "epoch_times" not in ckpt:
        print("Checkpoint does not contain 'epoch_times'.")
        return None

    epoch_times = ckpt["epoch_times"]
    if not epoch_times:
        # Previously a ZeroDivisionError when no epoch had been recorded.
        print("Checkpoint contains no recorded epoch times.")
        return None

    total_time = sum(epoch_times)
    avg_time = total_time / len(epoch_times)

    def format_hms(seconds):
        h = int(seconds // 3600)
        m = int((seconds % 3600) // 60)
        s = int(seconds % 60)
        return f"{h}h {m}m {s}s"

    print(f"Average epoch time: {format_hms(avg_time)}")
    print(f"Total training time: {format_hms(total_time)}")

    return avg_time, total_time

# Print the average and total epoch time recorded in the epoch-100 checkpoint.
ckpt_file = Path("/kaggle/input/vgg16/pytorch/default/1", "vgg16_epoch_100.pth")
summarize_checkpoint_times(ckpt_file)