In [1]:
# Import required packages from PyTorch
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F


# Fix random seeds for reproducibility. Seeding numpy alone is not enough:
# torch has its own RNG, which drives DataLoader shuffling, torchvision's
# random augmentations and the initialisation of newly created layers.
import numpy as np

SEED = 7
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the torch RNG on all devices (CPU and CUDA)

# moves your model to train on your gpu if available else it uses your cpu
device = ("cuda" if torch.cuda.is_available() else "cpu")
print(device)


cuda


In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import models, transforms, datasets
from torch.nn import functional as F
from torch import nn
import cv2
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

# Check the device (a torch.device here; the first cell stored a plain string).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing / loading settings shared by the train and test pipelines.
IMG_SIZE = 224  # CIFAR-10 images are upsampled to this size for the ImageNet-pretrained backbones
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
BATCH_SIZE = 32

# Prepare transformations for CIFAR-10; only the training pipeline augments
# (random horizontal flips).
train_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

test_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Load CIFAR-10 dataset (downloaded into ./data if not already present)
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transforms)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transforms)

# Pinned host memory speeds up host -> GPU copies; it is only useful with CUDA.
PIN_MEMORY = torch.cuda.is_available()
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=PIN_MEMORY)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=PIN_MEMORY)

class_names = trainset.classes
NUM_CLASSES = len(class_names)

# Load the ImageNet-pretrained backbones.
squeezenet = models.squeezenet1_0(weights=models.SqueezeNet1_0_Weights.DEFAULT)
vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
alexnet = models.alexnet(weights=models.AlexNet_Weights.DEFAULT)

# Replace each classification head with a fresh NUM_CLASSES-way layer. The input
# width is read from the layer being replaced rather than hard-coded.
squeezenet.classifier[1] = nn.Conv2d(squeezenet.classifier[1].in_channels, NUM_CLASSES, kernel_size=1)
vgg.classifier[6] = nn.Linear(vgg.classifier[6].in_features, NUM_CLASSES)
resnet.fc = nn.Linear(resnet.fc.in_features, NUM_CLASSES)
alexnet.classifier[6] = nn.Linear(alexnet.classifier[6].in_features, NUM_CLASSES)

models_dict = {
    "SqueezeNet": squeezenet,
    "VGG16": vgg,
    "ResNet18": resnet,
    "AlexNet": alexnet
}

# Move every model, including its new head, to the target device in one place.
for _backbone in models_dict.values():
    _backbone.to(device)

# Training function
def train_model(model, trainloader, criterion, optimizer, epochs=10, device=None):
    """Train ``model`` in place for ``epochs`` passes over ``trainloader``.

    After each epoch, prints the mean per-batch loss.

    Args:
        model: the ``nn.Module`` to optimise; it is put into training mode.
        trainloader: iterable of ``(images, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``trainloader``.
        device: where batches are sent. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters), so the
            function does not depend on a global ``device``.

    Returns:
        list[float]: mean per-batch training loss of each epoch. An epoch
        with no batches reports 0.0.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        n_batches = 0
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1
        # Count batches ourselves: works for loaders without __len__ and
        # avoids ZeroDivisionError on an empty loader.
        mean_loss = running_loss / max(n_batches, 1)
        epoch_losses.append(mean_loss)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {mean_loss:.4f}")
    return epoch_losses

# Fine-tune every backbone with the same loss and epoch budget. Each model
# gets its own fresh Adam optimizer, so no optimizer state is shared between them.
epochs = 3
criterion = nn.CrossEntropyLoss()

for name, model in models_dict.items():
    print(f"\nFine-tuning {name}...")
    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
    train_model(model, trainloader, criterion, optimizer, epochs=epochs)

# Function to evaluate model performance
def evaluate_model(model, testloader, name=None, class_names=None, device=None):
    """Evaluate a classifier on ``testloader`` and plot its confusion matrix.

    Prints accuracy and macro-averaged precision / recall / F1, then shows the
    confusion matrix (rows = true class, columns = predicted class).

    Args:
        model: classifier whose output has one score per class along dim 1.
        testloader: iterable of ``(images, labels)`` batches.
        name: plot title suffix; defaults to the model's class name.
        class_names: optional class labels for the axis ticks. When given, the
            confusion matrix has one row/column per class, even for classes
            absent from both labels and predictions.
        device: where inference runs. Defaults to the device of the model's
            first parameter (CPU if it has none).

    Returns:
        dict with keys ``accuracy``, ``precision``, ``recall``, ``f1`` and
        ``confusion_matrix``.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in testloader:
            outputs = model(images.to(device))
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            # Labels are only compared on the CPU, so they never go to the GPU.
            all_labels.extend(labels.cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    # zero_division=0 scores a never-predicted class as 0 (the value the default
    # produces anyway) without emitting a warning.
    precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    cm_labels = list(range(len(class_names))) if class_names is not None else None
    conf_matrix = confusion_matrix(all_labels, all_preds, labels=cm_labels)

    print(f"Accuracy: {acc:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")

    title = name or model.__class__.__name__
    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(conf_matrix, interpolation='nearest', cmap=plt.cm.Blues)
    ax.set_title(f"Confusion Matrix - {title}")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    if class_names is not None:
        ticks = np.arange(len(class_names))
        ax.set_xticks(ticks)
        ax.set_xticklabels(class_names, rotation=45, ha="right")
        ax.set_yticks(ticks)
        ax.set_yticklabels(class_names)
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    plt.show()

    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "confusion_matrix": conf_matrix,
    }

# Report test-set metrics and a confusion matrix for every fine-tuned model.
for name in models_dict:
    model = models_dict[name]
    print(f"\nEvaluating {name}...")
    evaluate_model(model, testloader)

# Function to generate Grad-CAM
def generate_gradcam(model, img_tensor, target_layer):
    """Compute a Grad-CAM heatmap for the model's top-scoring class.

    The target layer's activation is weighted, channel by channel, by the
    spatially averaged gradient of the predicted class score. The channels are
    then summed and passed through ReLU.

    Args:
        model: classifier whose output has one score per class along dim 1.
        img_tensor: input batch; only the first item is explained (callers
            pass a single-image batch).
        target_layer: sub-module of ``model`` whose forward output is used.
            The output must be 4-D (N, C, H, W), since gradients are averaged
            over dims 2 and 3.

    Returns:
        (heatmap, pred_class, pred_score):
        - heatmap: detached CPU tensor of shape (H, W), non-negative, scaled so
          its maximum is 1 whenever any value is positive.
        - pred_class: int index of the highest-scoring class.
        - pred_score: softmax probability of that class.

    Raises:
        RuntimeError: if ``target_layer`` did not run during the forward pass.
    """
    activations = []

    def hook_fn(module, inputs, output):
        activations.append(output)

    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    # Grad-CAM needs an autograd graph, so force grad mode even if the caller
    # is inside torch.no_grad().
    with torch.enable_grad():
        handle = target_layer.register_forward_hook(hook_fn)
        try:
            output = model(img_tensor.to(model_device))
        finally:
            # Always detach the hook, otherwise a failed forward leaves it
            # registered and it fires on every later call.
            handle.remove()

        if not activations:
            raise RuntimeError("target_layer was not executed during the forward pass")
        feature_map = activations[0]

        pred_class = output[0].argmax().item()
        pred_score = F.softmax(output[0], dim=0)[pred_class].item()

        grads = torch.autograd.grad(output[0, pred_class], feature_map)[0]

    weights = grads.mean(dim=(2, 3), keepdim=True)
    # Index the first item instead of squeeze() so a 1x1 map stays 2-D.
    gradcam = F.relu((weights * feature_map).sum(dim=1))[0].detach()

    if gradcam.max() > 0:
        gradcam = gradcam / gradcam.max()

    return gradcam.cpu(), pred_class, pred_score

# Visualizing Grad-CAM
def visualize_gradcam_on_image(img_tensor, model, target_layer):
    """Blend a Grad-CAM heatmap with its input image for display.

    Args:
        img_tensor: single-image, ImageNet-normalised batch of shape
            (1, 3, H, W).
        model: classifier to explain.
        target_layer: layer of ``model`` passed to ``generate_gradcam``.

    Returns:
        (cam_img, pred_class, pred_score):
        - cam_img: float (H, W, 3) RGB array. It is the JET-coloured heatmap
          added to the de-normalised image, then divided by its maximum.
        - pred_class, pred_score: as returned by ``generate_gradcam``.
    """
    gradcam, pred_class, pred_score = generate_gradcam(model, img_tensor, target_layer)

    # Upsample the coarse map to the input's own resolution
    # (cv2.resize takes (width, height)).
    height, width = img_tensor.shape[-2:]
    gradcam_resized = cv2.resize(gradcam.detach().numpy(), (int(width), int(height)))
    heatmap = cv2.applyColorMap(np.uint8(255 * gradcam_resized), cv2.COLORMAP_JET)
    # OpenCV colormaps are BGR; matplotlib expects RGB. Without this, hot
    # regions would render blue instead of red.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    # Undo the ImageNet normalisation to recover displayable [0, 1] pixels.
    img = img_tensor.detach().squeeze(0).permute(1, 2, 0).cpu().numpy()
    img = np.clip(img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406], 0, 1)

    cam_img = heatmap + np.float32(img)
    cam_img = cam_img / np.max(cam_img)

    return cam_img, pred_class, pred_score

# Display Grad-CAM with predictions
def _gradcam_target_layer(model_name, model):
    """Return the feature-extractor layer Grad-CAM hooks for ``model``.

    The architecture is identified by substring match on ``model_name``, so
    the name must contain one of the supported architecture names.
    """
    if "SqueezeNet" in model_name:
        return model.features[12]
    if "VGG16" in model_name:
        return model.features[-1]
    if "ResNet18" in model_name:
        return model.layer4[-1]
    if "AlexNet" in model_name:
        return model.features[-1]
    raise ValueError(f"Unsupported model {model_name}.")


def display_gradcam_for_models(images, models_dict, class_names, labels=None):
    """Plot every model's Grad-CAM explanation for every image.

    Each image gets one row, and each model gets three columns:
    1. the de-normalised input;
    2. the Grad-CAM blend, titled with the predicted class and probability;
    3. an overlay of the two.

    Args:
        images: sequence of ImageNet-normalised image tensors, each (C, H, W).
        models_dict: mapping of display name -> model. Names must identify a
            supported architecture (see ``_gradcam_target_layer``).
        class_names: class-index -> label lookup.
        labels: optional true class indices aligned with ``images``. When
            given, the first column's title also shows the true class. Passing
            them explicitly avoids relying on a global ``labels`` variable.
    """
    n_rows = len(images)
    n_models = len(models_dict)
    # squeeze=False keeps ``axes`` 2-D even for a single image, so the
    # axes[row, col] indexing below always works.
    fig, axes = plt.subplots(
        n_rows, n_models * 3, figsize=(n_models * 6, n_rows * 3), squeeze=False
    )

    for img_idx, img_tensor in enumerate(images):
        # Undo the ImageNet normalisation once per image for display.
        img = img_tensor.permute(1, 2, 0).cpu().numpy()
        img = np.clip(img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406], 0, 1)

        for model_idx, (model_name, model) in enumerate(models_dict.items()):
            target_layer = _gradcam_target_layer(model_name, model)
            cam_img, pred_class, pred_score = visualize_gradcam_on_image(
                img_tensor.unsqueeze(0), model, target_layer
            )

            if labels is None:
                input_title = model_name
            else:
                input_title = f"{model_name}\nTrue: {class_names[int(labels[img_idx])]}"

            col_idx = model_idx * 3
            ax_input, ax_cam, ax_overlay = axes[img_idx, col_idx:col_idx + 3]

            ax_input.imshow(img)
            ax_input.set_title(input_title, fontsize=10)

            ax_cam.imshow(cam_img)
            ax_cam.set_title(f"Predicted: {class_names[pred_class]} (P={pred_score:.2f})", fontsize=10)

            ax_overlay.imshow(img, alpha=0.5)
            ax_overlay.imshow(cam_img, alpha=0.7)
            ax_overlay.set_title("Overlay", fontsize=10)

            for ax in (ax_input, ax_cam, ax_overlay):
                ax.axis('off')

    plt.subplots_adjust(wspace=0.3, hspace=0.4)
    plt.show()


# Test on a batch
images, labels = next(iter(testloader))
display_gradcam_for_models(images[:5], models_dict, class_names, labels=labels[:5])


Fine-tuning SqueezeNet...
Epoch [1/3], Loss: 1.6566
Epoch [2/3], Loss: 1.2102
Epoch [3/3], Loss: 1.0144

Fine-tuning VGG16...


KeyboardInterrupt: 