# In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torchvision import models as vision_models
from sklearn.metrics import precision_score as calc_precision, recall_score as calc_recall

# Hyperparameters and model selection consumed by the main script below
bsz = 128  # mini-batch size handed to setup_loaders
num_epochs = 20  # number of epochs handed to run_training
lr_val = 1e-3  # learning rate handed to run_training
net_arch = "resnet50"  # architecture name handed to build_model

def setup_loaders(bsz=128, data_root='./data', num_workers=2):
    """Build the CIFAR-10 training and test data loaders.

    Both datasets are created with ``download=True`` under ``data_root``.
    Training images are randomly flipped horizontally and randomly cropped
    to 32x32 after 4 pixels of padding; both splits are then converted to
    tensors and normalized with mean 0.5 and std 0.5 on each channel.

    Args:
        bsz: Batch size used by both loaders.
        data_root: Directory the CIFAR-10 files are stored in.
        num_workers: Number of worker processes used by each loader.

    Returns:
        A ``(dl_train, dl_test)`` tuple. Only the training loader shuffles.
    """
    # Shared tail of both pipelines, defined once so the train and test
    # preprocessing cannot drift apart.
    # NOTE(review): build_model loads pretrained weights; those are usually
    # paired with the ImageNet mean/std rather than 0.5/0.5 -- confirm that
    # this normalization is intentional.
    to_normalized_tensor = [
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    tr_trans = T.Compose([
        T.RandomHorizontalFlip(),
        T.RandomCrop(32, padding=4),
        *to_normalized_tensor,
    ])
    te_trans = T.Compose(list(to_normalized_tensor))

    train_data = torchvision.datasets.CIFAR10(
        root=data_root, train=True, download=True, transform=tr_trans
    )
    dl_train = DataLoader(
        train_data, batch_size=bsz, shuffle=True, num_workers=num_workers
    )

    test_data = torchvision.datasets.CIFAR10(
        root=data_root, train=False, download=True, transform=te_trans
    )
    dl_test = DataLoader(
        test_data, batch_size=bsz, shuffle=False, num_workers=num_workers
    )

    return dl_train, dl_test

def build_model(arch, num_classes=10):
    """Create a pretrained torchvision model with a new classification head.

    The model is loaded with its ``DEFAULT`` pretrained weights, its final
    linear layer is replaced by a fresh ``nn.Linear`` with ``num_classes``
    outputs, and every parameter is marked trainable.

    Args:
        arch: One of ``"resnet18"``, ``"resnet50"`` or ``"vgg16"``.
        num_classes: Number of outputs of the replacement head.

    Returns:
        The modified model.

    Raises:
        ValueError: If ``arch`` is not one of the supported names.
    """
    supported = ("resnet18", "resnet50", "vgg16")
    # Validate before touching torchvision so a typo fails fast and the
    # message says what was received and what is accepted.
    if arch not in supported:
        raise ValueError(
            "Model architecture not supported. "
            f"Got {arch!r}; expected one of {', '.join(supported)}."
        )

    if arch == "resnet18":
        net = vision_models.resnet18(weights=vision_models.ResNet18_Weights.DEFAULT)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
    elif arch == "resnet50":
        net = vision_models.resnet50(weights=vision_models.ResNet50_Weights.DEFAULT)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
    else:  # "vgg16": the head is the last entry of the classifier stack
        net = vision_models.vgg16(weights=vision_models.VGG16_Weights.DEFAULT)
        net.classifier[6] = nn.Linear(net.classifier[6].in_features, num_classes)

    # Make full fine-tuning explicit: every parameter receives gradients.
    net.requires_grad_(True)

    return net

def _train_one_epoch(net, loader, loss_func, opt, dev):
    """Run one optimization pass over ``loader``; return the mean loss per sample."""
    net.train()
    loss_sum, seen = 0.0, 0
    for imgs, lbls in loader:
        imgs, lbls = imgs.to(dev), lbls.to(dev)
        opt.zero_grad()
        cost = loss_func(net(imgs), lbls)
        cost.backward()
        opt.step()
        # The loss is a per-batch mean; weight it by the batch size so a
        # short final batch does not count as much as a full one.
        batch_n = lbls.size(0)
        loss_sum += cost.item() * batch_n
        seen += batch_n
    return loss_sum / seen


def _evaluate(net, loader, loss_func, dev):
    """Score ``net`` on ``loader``; return (mean loss, accuracy, predictions, labels)."""
    net.eval()
    loss_sum, hits, seen = 0.0, 0, 0
    pred_list, label_list = [], []
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs, lbls = imgs.to(dev), lbls.to(dev)
            outs = net(imgs)
            batch_n = lbls.size(0)
            # Same sample weighting as in training (see _train_one_epoch).
            loss_sum += loss_func(outs, lbls).item() * batch_n
            preds = outs.argmax(dim=1)
            seen += batch_n
            hits += (preds == lbls).sum().item()
            pred_list.extend(preds.tolist())
            label_list.extend(lbls.tolist())
    return loss_sum / seen, hits / seen, pred_list, label_list


def run_training(net, loader_train, loader_test, epochs, lr,
                 ckpt_path="optimal_model.pth", lr_step_size=5, lr_gamma=0.1):
    """Train ``net`` and evaluate it on ``loader_test`` after every epoch.

    Uses cross-entropy loss, Adam, and a StepLR schedule. The model is moved
    to CUDA when available, otherwise it stays on the CPU. After each epoch
    the loss, accuracy, macro precision and macro recall on ``loader_test``
    are printed, and whenever accuracy improves the state dict is saved to
    ``ckpt_path``.

    Args:
        net: Model to train (modified in place).
        loader_train: Loader yielding ``(images, labels)`` training batches.
        loader_test: Loader yielding ``(images, labels)`` evaluation batches.
        epochs: Number of epochs to run.
        lr: Initial learning rate for Adam.
        ckpt_path: File that receives the best-accuracy state dict.
        lr_step_size: Epochs between learning-rate decays.
        lr_gamma: Multiplicative learning-rate decay factor.

    Returns:
        ``(net, history_train, history_val, accuracy_val)``: the model with
        its weights from the LAST epoch (the best weights are only on disk
        at ``ckpt_path``) and three per-epoch lists of training loss,
        evaluation loss and evaluation accuracy.

    Raises:
        ZeroDivisionError: If a loader yields no samples.
    """
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net.to(dev)

    loss_func = nn.CrossEntropyLoss()
    opt = optim.Adam(net.parameters(), lr=lr)
    lr_adjust = optim.lr_scheduler.StepLR(opt, step_size=lr_step_size, gamma=lr_gamma)

    history_train, history_val, accuracy_val = [], [], []
    best_acc = 0.0

    for ep in range(epochs):
        history_train.append(_train_one_epoch(net, loader_train, loss_func, opt, dev))

        val_loss, epoch_acc, pred_list, label_list = _evaluate(
            net, loader_test, loss_func, dev
        )
        history_val.append(val_loss)
        accuracy_val.append(epoch_acc)

        # NOTE(review): zero_division=1 scores an undefined per-class
        # precision/recall as 1.0, which can inflate the macro averages
        # (e.g. for a class that is never predicted) -- confirm intended.
        prc = calc_precision(label_list, pred_list, average='macro', zero_division=1)
        rcl = calc_recall(label_list, pred_list, average='macro', zero_division=1)

        if epoch_acc > best_acc:
            best_acc = epoch_acc
            torch.save(net.state_dict(), ckpt_path)

        print(f"Epoch {ep+1}/{epochs} - Train Loss: {history_train[-1]:.4f}, Val Loss: {history_val[-1]:.4f}, "
              f"Acc: {epoch_acc:.4f}, Precision: {prc:.4f}, Recall: {rcl:.4f}")

        lr_adjust.step()

    return net, history_train, history_val, accuracy_val

def display_stats(train_hist, val_hist, val_acc):
    """Show the loss curves and the accuracy curve side by side.

    Args:
        train_hist: Per-epoch training losses.
        val_hist: Per-epoch validation losses.
        val_acc: Per-epoch validation accuracies.
    """
    # One entry per subplot: (title, ((series, legend label), ...)).
    panels = (
        ('Loss per Epoch', ((train_hist, 'Training Loss'),
                            (val_hist, 'Validation Loss'))),
        ('Accuracy per Epoch', ((val_acc, 'Validation Accuracy'),)),
    )

    plt.figure(figsize=(12, 4))
    for slot, (heading, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        for series, tag in curves:
            plt.plot(series, label=tag)
        plt.legend()
        plt.title(heading)
    plt.show()

def store_model(net, file_path="final_cifar10_model.pth"):
    """Write the network's state dict to ``file_path`` and report the location.

    Args:
        net: Model whose ``state_dict()`` is saved.
        file_path: Destination file for the serialized weights.
    """
    weights = net.state_dict()
    torch.save(weights, file_path)
    print(f"Model saved at: {file_path}")

# Script entry point: load the data, build the network, train, plot, save
if __name__ == "__main__":
    loader_tr, loader_te = setup_loaders(bsz=bsz)
    network = build_model(arch=net_arch)
    trained_net, loss_history, val_loss_history, acc_history = run_training(
        net=network,
        loader_train=loader_tr,
        loader_test=loader_te,
        epochs=num_epochs,
        lr=lr_val,
    )
    display_stats(
        train_hist=loss_history,
        val_hist=val_loss_history,
        val_acc=acc_history,
    )
    store_model(net=trained_net)
