In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader, random_split
import torch.optim as optim
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
    transforms.Resize(224),  # AlexNet requires 224x224
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_len = int(0.8 * len(train_dataset))
val_len = len(train_dataset) - train_len
train_data, val_data = random_split(train_dataset, [train_len, val_len])

batch_size = 64
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

alexnet = models.alexnet(pretrained=True)

alexnet.classifier[6] = nn.Linear(alexnet.classifier[6].in_features, 10)
alexnet.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(alexnet.parameters(), lr=1e-4)

def train(model, epochs=5):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        print(f"Epoch [{epoch+1}] Loss: {total_loss:.4f}")

def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
    return 100 * correct / total

train(alexnet, epochs=5)
test_acc = evaluate(alexnet, test_loader)
print(f"\nAlexNet CIFAR-10 Test Accuracy: {test_acc:.2f}%")


100%|██████████| 170M/170M [00:02<00:00, 76.0MB/s]
Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [00:01<00:00, 153MB/s]
Epoch 1: 100%|██████████| 625/625 [02:02<00:00,  5.12it/s]


Epoch [1] Loss: 380.3750


Epoch 2: 100%|██████████| 625/625 [01:56<00:00,  5.35it/s]


Epoch [2] Loss: 204.2916


Epoch 3: 100%|██████████| 625/625 [01:56<00:00,  5.37it/s]


Epoch [3] Loss: 138.1592


Epoch 4: 100%|██████████| 625/625 [01:56<00:00,  5.38it/s]


Epoch [4] Loss: 97.5611


Epoch 5: 100%|██████████| 625/625 [01:54<00:00,  5.46it/s]


Epoch [5] Loss: 73.0543

AlexNet CIFAR-10 Test Accuracy: 89.91%


In [4]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve
from sklearn.preprocessing import label_binarize

def evaluate_with_preds(model, loader):
    model.eval()
    y_true, y_pred, y_prob = [], [], []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            outputs = model(images)
            probs = torch.softmax(outputs, dim=1).cpu().numpy()
            preds = np.argmax(probs, axis=1)

            y_true.extend(labels.numpy())
            y_pred.extend(preds)
            y_prob.extend(probs)
    return np.array(y_true), np.array(y_pred), np.array(y_prob)

def plot_confusion_matrix(y_true, y_pred, class_names):
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(8,6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title("Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.tight_layout()
    plt.savefig("confusion_matrix.png")
    plt.close()

def plot_roc_auc(y_true, y_prob, class_names):
    y_true_bin = label_binarize(y_true, classes=np.arange(len(class_names)))
    plt.figure(figsize=(10, 8))
    for i in range(len(class_names)):
        fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_prob[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f"{class_names[i]} (AUC={roc_auc:.2f})")
    plt.plot([0, 1], [0, 1], 'k--')
    plt.title("ROC-AUC Curve (One-vs-All)")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("roc_auc_curve.png")
    plt.close()

def plot_precision_recall(y_true, y_prob, class_names):
    y_true_bin = label_binarize(y_true, classes=np.arange(len(class_names)))
    plt.figure(figsize=(10, 8))
    for i in range(len(class_names)):
        precision, recall, _ = precision_recall_curve(y_true_bin[:, i], y_prob[:, i])
        plt.plot(recall, precision, label=class_names[i])
    plt.title("Precision-Recall Curve")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("precision_recall_curve.png")
    plt.close()

y_true, y_pred, y_prob = evaluate_with_preds(alexnet, test_loader)
class_names = test_dataset.classes

plot_confusion_matrix(y_true, y_pred, class_names)
plot_roc_auc(y_true, y_prob, class_names)
plot_precision_recall(y_true, y_prob, class_names)

print("Plots saved: confusion_matrix.png, roc_auc_curve.png, precision_recall_curve.png")


Plots saved: confusion_matrix.png, roc_auc_curve.png, precision_recall_curve.png


In [3]:
from google.colab import files

files.download('confusion_matrix.png')
files.download('roc_auc_curve.png')
files.download('precision_recall_curve.png')


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>