In [2]:
import sys
import os
sys.path.append('..')

from Utils.Accuracy_measures import topk_accuracy
from Utils.TinyImageNet_loader import get_tinyimagenet_dataloaders
import torchvision.transforms as transforms


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
from timm import create_model

# Define Metrics
def calculate_topk_accuracy(outputs, labels, k=1):
    """Count samples whose true label is among the ``k`` highest-scoring classes.

    Despite its name, this returns a raw *count* of correct samples (an
    ``int``), not a fraction; callers divide by the number of samples.

    Args:
        outputs: Score tensor indexed as ``outputs[sample, class]`` (e.g. logits).
        labels: Integer class-index tensor with one entry per sample.
        k: How many top-scoring classes count as a hit. Values larger than the
            number of classes are clamped, because then every sample is a hit.

    Returns:
        Number of samples (``int``) whose label appears in their top-``k``.
    """
    # Tensor.topk raises if k exceeds the size of the class dimension; clamp so
    # that e.g. top-5 on a problem with fewer than 5 classes is well defined.
    k = min(k, outputs.size(1))
    _, top_k_preds = outputs.topk(k, dim=1)
    # Compare each row's k predicted indices against that row's label. topk
    # indices within a row are distinct, so at most one can match and summing
    # the boolean matrix counts hits per sample.
    correct = top_k_preds.eq(labels.view(-1, 1).expand_as(top_k_preds))
    return correct.sum().item()

# Training Function
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader`` and return training accuracy.

    Args:
        model: Classifier; switched to training mode before the loop.
        train_loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the model's parameters; stepped once per batch.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Top-1 accuracy (``float`` in [0, 1]) of the predictions made during
        training, i.e. from the forward pass that precedes each batch's
        optimizer step.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    model.train()
    train_correct = 0
    train_total = 0
    for images, labels in tqdm(train_loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accuracy bookkeeping needs no autograd history; detaching avoids
        # recording extra graph nodes for the argmax.
        preds = outputs.detach().argmax(dim=1)
        train_correct += preds.eq(labels).sum().item()
        train_total += labels.size(0)

    if train_total == 0:
        raise ValueError("train_loader yielded no samples; cannot compute accuracy")
    return train_correct / train_total

# Evaluation Function (used for both validation and the final test pass)
def evaluate(model, loader, device):
    """Measure top-1, top-3 and top-5 accuracy of ``model`` over ``loader``.

    The model is put in eval mode and no gradients are recorded. Mode is not
    restored afterwards; the training routine switches it back itself.

    Args:
        model: Classifier producing per-class scores for a batch of images.
        loader: Iterable of ``(images, labels)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(top1_acc, top3_acc, top5_acc)`` of floats in [0, 1].

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    correct1 = correct3 = correct5 = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            # calculate_topk_accuracy returns per-batch hit counts, not ratios.
            correct1 += calculate_topk_accuracy(outputs, labels, k=1)
            correct3 += calculate_topk_accuracy(outputs, labels, k=3)
            correct5 += calculate_topk_accuracy(outputs, labels, k=5)
            total += labels.size(0)

    if total == 0:
        raise ValueError("loader yielded no samples; cannot compute accuracy")
    top1_acc = correct1 / total
    top3_acc = correct3 / total
    top5_acc = correct5 / total
    return top1_acc, top3_acc, top5_acc

# Main Script
if __name__ == "__main__":
    # ---- Hyperparameters -------------------------------------------------
    image_size = 224
    batch_size = 64
    num_epochs = 10
    learning_rate = 0.0001
    weight_decay = 0.01

    # ---- Data transforms -------------------------------------------------
    # Per-channel mean/std shared by every Normalize step below (these match
    # the commonly used ImageNet statistics).
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    square_size = (image_size, image_size)

    # Training pipeline: resize, then light geometric augmentation.
    tiny_transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.Resize(square_size),
        transforms.RandomCrop(image_size, padding=5),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ])

    # Validation and test use the same deterministic resize + normalize steps;
    # each split gets its own Compose instance.
    tiny_transform_val = transforms.Compose([
        transforms.Resize(square_size),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ])
    tiny_transform_test = transforms.Compose([
        transforms.Resize(square_size),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ])

    # ---- Data loaders ----------------------------------------------------
    # get_tinyimagenet_dataloaders is imported from Utils.TinyImageNet_loader
    # and yields the (train, val, test) loaders.
    train_loader, val_loader, test_loader = get_tinyimagenet_dataloaders(
        data_dir='../datasets',
        transform_train=tiny_transform_train,
        transform_val=tiny_transform_val,
        transform_test=tiny_transform_test,
        batch_size=batch_size,
        image_size=image_size,
    )

    # ---- Model, loss, optimizer -----------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pretrained timm Swin-B configured with a 200-way output.
    model = create_model(
        'swin_base_patch4_window7_224', pretrained=True, num_classes=200
    ).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )

    # ---- Per-epoch metric history ---------------------------------------
    train_accuracies, val_top1_accuracies = [], []
    val_top3_accuracies, val_top5_accuracies = [], []

    # ---- Training loop ---------------------------------------------------
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        train_accuracy = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_top1_acc, val_top3_acc, val_top5_acc = evaluate(model, val_loader, device)

        for history, value in (
            (train_accuracies, train_accuracy),
            (val_top1_accuracies, val_top1_acc),
            (val_top3_accuracies, val_top3_acc),
            (val_top5_accuracies, val_top5_acc),
        ):
            history.append(value)

        print(f"Train Accuracy: {train_accuracy:.4f}")
        print(
            f"Val Top-1 Accuracy: {val_top1_acc:.4f}, "
            f"Top-3 Accuracy: {val_top3_acc:.4f}, "
            f"Top-5 Accuracy: {val_top5_acc:.4f}"
        )

    # ---- Final test evaluation (weights from the last epoch) ------------
    print("\nFinal Test Evaluation:")
    test_top1_acc, test_top3_acc, test_top5_acc = evaluate(model, test_loader, device)
    for k_name, k_acc in (
        ("Top-1", test_top1_acc),
        ("Top-3", test_top3_acc),
        ("Top-5", test_top5_acc),
    ):
        print(f"Test {k_name} Accuracy: {k_acc:.4f}")

    # ---- Accuracy curves -------------------------------------------------
    plt.figure(figsize=(10, 6))
    epoch_axis = range(1, num_epochs + 1)
    for series, curve_label in (
        (train_accuracies, "Train Accuracy"),
        (val_top1_accuracies, "Val Top-1 Accuracy"),
        (val_top3_accuracies, "Val Top-3 Accuracy"),
        (val_top5_accuracies, "Val Top-5 Accuracy"),
    ):
        plt.plot(epoch_axis, series, label=curve_label)
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.title("Training and Validation Accuracy")
    plt.legend()
    plt.grid()
    plt.show()


  from .autonotebook import tqdm as notebook_tqdm
