In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import itertools

# Select the compute device: CUDA when PyTorch can see a GPU, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [2]:
# Define the model architecture (needs to be exactly the same as the saved model)
class ModifiedVGG16(nn.Module):
    """VGG16 convolutional backbone with a smaller replacement classifier head.

    The ``features`` stack is taken from torchvision's VGG16 and frozen
    (``requires_grad = False``); only the new classifier
    (512*7*7 -> 4096 -> 1024 -> ``num_classes``) has trainable parameters.

    The layer layout must stay identical to the one used when the checkpoint
    was saved, otherwise ``load_state_dict`` will not find matching keys.

    Args:
        num_classes: Output size of the final linear layer.
        pretrained: When True (the default, same behaviour as before) the
            backbone is initialised from torchvision's ImageNet weights.
            Pass False to skip that download when the weights are about to
            be overwritten by ``load_state_dict``.
    """

    def __init__(self, num_classes=10, pretrained=True):
        super().__init__()

        # torchvision >= 0.13 deprecates the boolean ``pretrained`` flag in
        # favour of a ``weights`` enum. Use the enum when it exists and fall
        # back to the legacy flag on older releases.
        weights_enum = getattr(torchvision.models, "VGG16_Weights", None)
        if weights_enum is not None:
            vgg16 = torchvision.models.vgg16(
                weights=weights_enum.IMAGENET1K_V1 if pretrained else None
            )
        else:
            vgg16 = torchvision.models.vgg16(pretrained=pretrained)

        # Keep all layers except the classifier
        self.features = vgg16.features

        # Freeze the feature layers so only the new head can receive gradients
        for param in self.features.parameters():
            param.requires_grad = False

        # New classifier head. The Sequential indices (0, 3, 6 for the linear
        # layers) are part of the state_dict keys, so the order must not change.
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 1024),  # Modified layer
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(1024, num_classes)  # Modified layer
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for image batch ``x``.

        The first linear layer expects 512*7*7 features per sample, i.e. the
        feature map produced for the 224x224 inputs used in this notebook.
        """
        x = self.features(x)
        # Flatten everything except the batch dimension; unlike ``view`` this
        # also works if the feature map happens to be non-contiguous.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


In [3]:

# Build the evaluation pipeline for the CIFAR-10 test split
transform = transforms.Compose(
    [
        # Upsample to the 224x224 resolution the VGG16 head was sized for
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Per-channel normalisation; presumably the ImageNet statistics used
        # with torchvision's pretrained models. Must match the training setup.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# train=False selects the held-out test images; download=True fetches the
# archive into ./data only if it is not already there.
test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

# No shuffling: evaluation order does not affect the metrics
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=32, shuffle=False
)

Files already downloaded and verified


In [4]:
# Create a new model instance with the same architecture as the saved one
loaded_model = ModifiedVGG16(num_classes=10).to(device)
save_path = 'modified_vgg16_cifar10.pth'

# Load the saved state.
# map_location remaps tensors that were saved from a GPU onto the device
# selected above, so the checkpoint also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file, so only open checkpoints you trust.
checkpoint = torch.load(save_path, map_location=device)
loaded_model.load_state_dict(checkpoint['model_state_dict'])

# Switch Dropout to inference behaviour; as the cell's last expression this
# also displays the model structure.
loaded_model.eval()



ModifiedVGG16(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, 

In [5]:
# Initialize metrics
correct = 0                 # number of correctly classified samples
total = 0                   # number of evaluated samples
class_correct = [0] * 10    # per-class count of correct predictions
class_total = [0] * 10      # per-class count of samples
all_preds = []              # predicted class per sample (for the confusion matrix)
all_labels = []             # ground-truth class per sample

# Test the model
with torch.no_grad():  # inference only, no autograd bookkeeping needed
    for images, labels in test_loader:
        images = images.to(device)
        outputs = loaded_model(images)
        _, predicted = torch.max(outputs, 1)

        # Move results to the CPU once per batch; all bookkeeping below then
        # runs without further device synchronisation.
        predicted = predicted.cpu()
        labels = labels.cpu()
        hits = predicted == labels

        # Overall accuracy
        total += labels.size(0)
        correct += hits.sum().item()

        # Per-class accuracy, counted for the whole batch at once.
        # (The previous per-sample loop squeezed the comparison tensor, which
        # turns a single-sample batch into a 0-dim tensor and breaks indexing.)
        n_classes = len(class_total)
        batch_totals = torch.bincount(labels, minlength=n_classes).tolist()
        batch_hits = torch.bincount(labels[hits], minlength=n_classes).tolist()
        for cls in range(n_classes):
            class_total[cls] += batch_totals[cls]
            class_correct[cls] += batch_hits[cls]

        # Store predictions and labels for confusion matrix
        all_preds.extend(predicted.numpy())
        all_labels.extend(labels.numpy())

# Calculate overall accuracy
accuracy = 100 * correct / total
print(f'Overall Test Accuracy: {accuracy:.2f}%')

Overall Test Accuracy: 87.05%
