In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision.models import vit_b_16
from torch.cuda.amp import GradScaler, autocast

# Paths to the Tiny ImageNet dataset
TRAIN_DIR = 'tiny-imagenet-200/train'
VAL_DIR = 'tiny-imagenet-200/val'
WORDS_FILE = 'tiny-imagenet-200/words.txt'
WNIDS_FILE = 'tiny-imagenet-200/wnids.txt'

# Hyperparameters
BATCH_SIZE = 128  # Increased batch size
NUM_EPOCHS = 50   # Increased number of epochs
LEARNING_RATE = 1e-4  # Reduced learning rate
NUM_CLASSES = 200
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class TinyImageNetDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform
        self.le = LabelEncoder()
        unique_labels = sorted(set(labels))  # Sort the unique labels
        self.le.fit(unique_labels)
        self.encoded_labels = self.le.transform(labels)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img_path = self.data[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.encoded_labels[idx]
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

def load_training_data():
    data = []
    labels = []
    class_dirs = sorted(os.listdir(TRAIN_DIR))  # Ensure consistent class order
    for wnid in class_dirs:
        wnid_dir = os.path.join(TRAIN_DIR, wnid, 'images')
        if os.path.isdir(wnid_dir):
            image_files = sorted(os.listdir(wnid_dir))  # Ensure consistent image order
            for img_file in image_files:
                img_path = os.path.join(wnid_dir, img_file)
                data.append(img_path)
                labels.append(wnid)
    return data, labels

def load_validation_data():
    val_data = []
    val_labels = []
    annotations_file = os.path.join(VAL_DIR, 'val_annotations.txt')
    val_annotations = {}
    with open(annotations_file, 'r') as f:
        for line in f:
            tokens = line.strip().split('\t')
            filename, wnid = tokens[0], tokens[1]
            val_annotations[filename] = wnid
    images_dir = os.path.join(VAL_DIR, 'images')
    image_files = sorted(os.listdir(images_dir))  # Ensure consistent image order
    for img_file in image_files:
        img_path = os.path.join(images_dir, img_file)
        if img_file in val_annotations:
            val_data.append(img_path)
            val_labels.append(val_annotations[img_file])
    return val_data, val_labels

# Updated transform with data augmentation and adjusted image size
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Load training data
train_data, train_labels = load_training_data()
train_dataset = TinyImageNetDataset(train_data, train_labels, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

# Load validation data
val_data, val_labels = load_validation_data()

# Split validation data into new validation and test sets
val_data, test_data, val_labels, test_labels = train_test_split(
    val_data, val_labels, test_size=0.5, random_state=42, stratify=val_labels)

# Create validation dataset and dataloader
val_dataset = TinyImageNetDataset(val_data, val_labels, transform=transform)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# Create test dataset and dataloader
test_dataset = TinyImageNetDataset(test_data, test_labels, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# Initialize the pretrained ViT model
model = vit_b_16(pretrained=True)
model.heads = nn.Linear(model.heads.head.in_features, NUM_CLASSES)
model = model.to(DEVICE)

# Define loss function with label smoothing
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Define optimizer with weight decay
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)

# Define learning rate scheduler
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

# Initialize GradScaler for mixed precision training
scaler = GradScaler()

def accuracy(output, target, topk=(1, 3, 5)):
    """Computes the accuracy over the k top predictions"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)  # Get indices of top k predictions
    pred = pred.t()  # Transpose to shape (k, batch_size)
    correct = pred.eq(target.view(1, -1).expand_as(pred))  # Check if predictions are correct

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # Count correct predictions
        res.append(correct_k.mul_(100.0 / batch_size))  # Compute accuracy percentage
    return res

def train(model, device, train_loader, criterion, optimizer, scheduler, scaler, epoch):
    model.train()
    running_loss = 0.0
    top1_acc = 0.0
    top3_acc = 0.0
    top5_acc = 0.0
    total = 0

    loop = tqdm(train_loader)
    for batch_idx, (inputs, targets) in enumerate(loop):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        running_loss += loss.item()
        acc1, acc3, acc5 = accuracy(outputs, targets, topk=(1, 3, 5))
        top1_acc += acc1.item()
        top3_acc += acc3.item()
        top5_acc += acc5.item()
        total += 1

        # Update tqdm loop
        loop.set_description(f"Epoch [{epoch}/{NUM_EPOCHS}]")
        loop.set_postfix(loss=(running_loss/total), top1_acc=(top1_acc/total), top3_acc=(top3_acc/total), top5_acc=(top5_acc/total))

    return running_loss/total, top1_acc/total, top3_acc/total, top5_acc/total

def validate(model, device, val_loader, criterion):
    model.eval()
    running_loss = 0.0
    top1_acc = 0.0
    top3_acc = 0.0
    top5_acc = 0.0
    total = 0

    with torch.no_grad():
        loop = tqdm(val_loader)
        for batch_idx, (inputs, targets) in enumerate(loop):
            inputs, targets = inputs.to(device), targets.to(device)

            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, targets)

            running_loss += loss.item()
            acc1, acc3, acc5 = accuracy(outputs, targets, topk=(1, 3, 5))
            top1_acc += acc1.item()
            top3_acc += acc3.item()
            top5_acc += acc5.item()
            total += 1

            # Update tqdm loop
            loop.set_description(f"Validation")
            loop.set_postfix(loss=(running_loss/total), top1_acc=(top1_acc/total), top3_acc=(top3_acc/total), top5_acc=(top5_acc/total))

    return running_loss/total, top1_acc/total, top3_acc/total, top5_acc/total

def test(model, device, test_loader, criterion):
    model.eval()
    running_loss = 0.0
    top1_acc = 0.0
    top3_acc = 0.0
    top5_acc = 0.0
    total = 0

    with torch.no_grad():
        loop = tqdm(test_loader)
        for batch_idx, (inputs, targets) in enumerate(loop):
            inputs, targets = inputs.to(device), targets.to(device)

            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, targets)

            running_loss += loss.item()
            acc1, acc3, acc5 = accuracy(outputs, targets, topk=(1, 3, 5))
            top1_acc += acc1.item()
            top3_acc += acc3.item()
            top5_acc += acc5.item()
            total += 1

            # Update tqdm loop
            loop.set_description(f"Test")
            loop.set_postfix(loss=(running_loss/total), top1_acc=(top1_acc/total), top3_acc=(top3_acc/total), top5_acc=(top5_acc/total))

    avg_loss = running_loss / total
    avg_top1_acc = top1_acc / total
    avg_top3_acc = top3_acc / total
    avg_top5_acc = top5_acc / total

    print(f"Test Loss: {avg_loss:.4f}")
    print(f"Test Acc@1: {avg_top1_acc:.2f}%")
    print(f"Test Acc@3: {avg_top3_acc:.2f}%")
    print(f"Test Acc@5: {avg_top5_acc:.2f}%")

    return avg_loss, avg_top1_acc, avg_top3_acc, avg_top5_acc

def plot_metrics(train_metric, val_metric, metric_name):
    epochs = range(1, NUM_EPOCHS + 1)
    plt.figure()
    plt.plot(epochs, train_metric, 'b', label=f'Training {metric_name}')
    plt.plot(epochs, val_metric, 'r', label=f'Validation {metric_name}')
    plt.title(f'Training and Validation {metric_name}')
    plt.xlabel('Epochs')
    plt.ylabel(metric_name)
    plt.legend()
    plt.savefig(f'{metric_name}.png')
    plt.show()

train_losses = []
val_losses = []
train_top1_acc = []
val_top1_acc = []
train_top3_acc = []
val_top3_acc = []
train_top5_acc = []
val_top5_acc = []

best_acc = 0.0

# Training loop
for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_acc1, train_acc3, train_acc5 = train(
        model, DEVICE, train_loader, criterion, optimizer, scheduler, scaler, epoch)
    val_loss, val_acc1, val_acc3, val_acc5 = validate(
        model, DEVICE, val_loader, criterion)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_top1_acc.append(train_acc1)
    val_top1_acc.append(val_acc1)
    train_top3_acc.append(train_acc3)
    val_top3_acc.append(val_acc3)
    train_top5_acc.append(train_acc5)
    val_top5_acc.append(val_acc5)

    # Save the best model
    if val_acc1 > best_acc:
        best_acc = val_acc1
        torch.save(model.state_dict(), 'best_model.pth')
        print(f"Best model saved with accuracy: {best_acc:.2f}%")

    print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    print(f"Train Acc@1: {train_acc1:.2f}%, Val Acc@1: {val_acc1:.2f}%")
    print(f"Train Acc@3: {train_acc3:.2f}%, Val Acc@3: {val_acc3:.2f}%")
    print(f"Train Acc@5: {train_acc5:.2f}%, Val Acc@5: {val_acc5:.2f}%")

# Plot Loss
plot_metrics(train_losses, val_losses, 'Loss')

# Plot Top-1 Accuracy
plot_metrics(train_top1_acc, val_top1_acc, 'Top-1 Accuracy')

# Plot Top-3 Accuracy
plot_metrics(train_top3_acc, val_top3_acc, 'Top-3 Accuracy')

# Plot Top-5 Accuracy
plot_metrics(train_top5_acc, val_top5_acc, 'Top-5 Accuracy')

print("Model Parameters:")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"{name}: {param.data.shape}")

# Load the best model
model.load_state_dict(torch.load('best_model.pth'))

# Evaluate on test set
test_loss, test_acc1, test_acc3, test_acc5 = test(model, DEVICE, test_loader, criterion)


FileNotFoundError: [WinError 3] The system cannot find the path specified: 'tiny-imagenet-200/train'