# # CIFAR-10 CNN Classifier - Comprehensive Evaluation

# ## 1. Import Required Libraries

In [None]:
import torch
import torchvision
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve
from sklearn.preprocessing import label_binarize
import pandas as pd

# ## 2. Helper Functions for Visualization

In [None]:
def plot_confusion_matrix(y_true, y_pred, classes):
    """Draw the confusion matrix of ``y_pred`` against ``y_true`` as a heatmap.

    Args:
        y_true: Ground-truth labels, expected to be integer class indices in
            ``range(len(classes))``.
        y_pred: Predicted labels, same encoding and length as ``y_true``.
        classes: Class names used as tick labels; position ``i`` names
            class index ``i``.
    """
    # Pin the label set so the matrix is always len(classes) x len(classes).
    # Without it sklearn sizes the matrix from the labels that actually occur,
    # and a class missing from both inputs would shift every tick label.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(classes)))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

def plot_roc_auc(y_true, y_score, num_classes, classes):
    """Plot one-vs-rest ROC curves, one per class, with the AUC in the legend.

    Args:
        y_true: Ground-truth labels, expected to be integer class indices in
            ``range(num_classes)``.
        y_score: Array-like of shape ``(n_samples, num_classes)``; column
            ``i`` is used as the score for class ``i``.
        num_classes: Number of classes (and score columns) to plot.
        classes: Class names used in the legend, indexed by class index.
    """
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)

    # Build the one-vs-rest indicator matrix directly. label_binarize returns
    # a single column when there are exactly two classes, which would make
    # the per-class column indexing below fail for binary problems.
    y_true_bin = (y_true.reshape(-1, 1) == np.arange(num_classes)).astype(int)

    plt.figure(figsize=(10, 8))
    for i in range(num_classes):
        fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_score[:, i])
        class_auc = roc_auc_score(y_true_bin[:, i], y_score[:, i])
        plt.plot(fpr, tpr, label=f'{classes[i]} (AUC={class_auc:.2f})')
    # Diagonal reference line (FPR == TPR).
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve for Each Class')
    # Anchor the legend outside the axes so it does not cover the curves.
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.show()

def plot_metric_bar(metrics_df, metric_name):
    """Show a bar chart of one column of ``metrics_df``, one bar per row.

    Args:
        metrics_df: DataFrame whose index supplies the x-axis categories and
            which contains a column named ``metric_name``.
        metric_name: Column to plot; also used in the chart title.

    The y-axis is fixed to [0, 1] and each bar is annotated with its value
    rounded to two decimals.
    """
    plt.figure(figsize=(10, 6))
    sns.barplot(data=metrics_df, x=metrics_df.index, y=metric_name)
    plt.title(f'{metric_name} per Class')
    plt.xticks(rotation=45)
    plt.ylim(0, 1)

    # Write each value slightly above the top of its bar.
    column = metrics_df[metric_name]
    for bar_pos, score in enumerate(column):
        label_height = score + 0.02
        plt.text(bar_pos, label_height, f"{score:.2f}", ha='center')

    plt.show()

def plot_training_history(train_losses, train_accuracies):
    """Plot per-epoch training loss and accuracy in two side-by-side panels.

    Args:
        train_losses: Sequence of loss values, one per epoch.
        train_accuracies: Sequence of accuracy values (percent), one per epoch.
    """
    _, axes = plt.subplots(1, 2, figsize=(16, 6))

    # (series, title/legend text, y-axis label) for the left and right panel.
    panels = (
        (train_losses, 'Training Loss', 'Loss'),
        (train_accuracies, 'Training Accuracy', 'Accuracy (%)'),
    )
    for axis, (series, caption, y_label) in zip(axes, panels):
        axis.plot(series, label=caption)
        axis.set_title(caption)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.legend()

    plt.show()

# ## 3. CNN Model Definition

In [None]:
class SimpleCNN(nn.Module):
    """Small CNN: three conv blocks followed by a two-layer classifier head.

    Each conv block is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2), so the spatial size halves per block. The head expects
    128 * 4 * 4 features, i.e. 3-channel 32x32 inputs.

    Args:
        num_classes: Number of output logits.
    """

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Return one Conv -> BatchNorm -> ReLU -> MaxPool stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def __init__(self, num_classes=10):
        super().__init__()
        # Attribute names and layer order are kept so state_dict keys match
        # checkpoints saved from the earlier definition.
        self.conv1 = self._conv_block(3, 32)
        self.conv2 = self._conv_block(32, 64)
        self.conv3 = self._conv_block(64, 128)
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return unnormalised class logits of shape (batch, num_classes)."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # Flatten everything except the batch dimension. Unlike view(-1, ...),
        # a wrongly sized input fails in fc1 instead of being silently
        # re-batched.
        x = torch.flatten(x, 1)
        # Non-linearity between the two linear layers; without it fc1 and fc2
        # compose into a single linear map.
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# ## 4. Data Loading and Transformation

In [None]:
# Per-channel mean/std applied to both splits (presumably CIFAR-10
# training-set statistics -- confirm if the dataset changes).
normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])

# Training pipeline: random flip/crop augmentation, then tensor + normalise.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: no random augmentation, so test metrics are
# deterministic and measured on the unmodified images.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# Class names in label-index order, as exposed by the dataset object.
classes = train_dataset.classes

# ## 5. Model Training

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN(num_classes=10).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

num_epochs = 10

# Per-epoch history, plotted once training finishes.
train_losses = []
train_accuracies = []

for epoch in range(num_epochs):
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}'):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # criterion returns the batch mean; scale it back to a batch sum so
        # the epoch loss is a true per-sample average even when the final
        # batch is smaller than the others.
        total_loss += loss.item() * batch_size
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += batch_size

    epoch_loss = total_loss / total
    epoch_acc = 100. * correct / total
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_acc)
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")

# Plot training history
plot_training_history(train_losses, train_accuracies)

# ## 6. Model Evaluation

In [None]:
# Switch to inference mode before scoring the test set.
model.eval()
all_labels = []
all_preds = []
all_probs = []

# Gradients are not needed for evaluation.
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Softmax over the class dimension; these scores feed the ROC curves.
        probs = outputs.softmax(dim=1)
        predicted = outputs.max(1).indices
        # Accumulate per-sample results on the CPU as NumPy values.
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(predicted.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

# Classification report
report = classification_report(
    all_labels,
    all_preds,
    target_names=classes,
    digits=4,
    output_dict=True,
)
# Drop the last three report rows so only the per-class rows remain.
metrics_df = pd.DataFrame(report).T.iloc[:-3, :]

# ## 7. Confusion Matrix Visualization

In [None]:
# Heatmap of test-set predictions against the true labels.
plot_confusion_matrix(
    y_true=all_labels,
    y_pred=all_preds,
    classes=classes,
)

# ## 8. Accuracy Visualization

In [None]:
# Single-bar chart of the overall accuracy taken from the report dict.
plt.figure(figsize=(10, 6))
plt.bar(['Overall Accuracy'], [report['accuracy']])
plt.title('Overall Model Accuracy')
plt.ylabel('Accuracy')
plt.ylim(0, 1)
# Only one bar exists, at x position 0; write its value just above it.
plt.text(0, report['accuracy'] + 0.02, f"{report['accuracy']:.4f}", ha='center')
plt.show()

# ## 9. Precision Visualization

In [None]:
# Per-class precision bars.
plot_metric_bar(metrics_df=metrics_df, metric_name='precision')

# ## 10. Recall (Sensitivity) Visualization

In [None]:
# Per-class recall bars.
plot_metric_bar(metrics_df=metrics_df, metric_name='recall')

# ## 11. F1-Score Visualization

In [None]:
# Per-class F1-score bars.
plot_metric_bar(metrics_df=metrics_df, metric_name='f1-score')

# ## 12. Support Visualization

In [None]:
# Bar chart of the number of test samples per class.
plt.figure(figsize=(10, 6))
sns.barplot(x=metrics_df.index, y='support', data=metrics_df)
plt.title('Number of Samples per Class (Support)')
plt.xticks(rotation=45)
for i, v in enumerate(metrics_df['support']):
    # The support column is float after the DataFrame transpose; cast to int
    # so the label reads "1000" rather than "1000.0".
    plt.text(i, v + 20, str(int(v)), ha='center')
plt.show()

# ## 13. ROC Curve & AUC Scores

In [None]:
# One-vs-rest ROC curves from the collected softmax scores.
plot_roc_auc(
    y_true=all_labels,
    y_score=np.array(all_probs),
    num_classes=10,
    classes=classes,
)

# ## 14. Detailed Classification Report

In [None]:
# Text version of the per-class report, four decimals per metric.
report_text = classification_report(
    y_true=all_labels,
    y_pred=all_preds,
    target_names=classes,
    digits=4,
)
print(report_text)