In [3]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset
from vgg_models.vgg import vgg13_bn

# Pick the compute device: CUDA when a GPU is visible, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ---------- Load the CIFAR-10 test dataset -------------

# ToTensor turns each image into a float CHW tensor in [0, 1]; Normalize then
# standardises each RGB channel. NOTE(review): these mean/std values are
# presumably the ones the pretrained model was trained with — confirm.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
    ]
)

# Held-out split only (train=False); downloaded into ./data on first use.
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
# Fixed order (shuffle=False) keeps evaluation runs reproducible.
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

#--------- Load the permuted dataset ----------

class CustomDataset(Dataset):
    """Map-style dataset over an indexable collection of image/label records.

    Each record fetched by ``dataset_dict[index]`` must provide the keys
    ``'image'`` and ``'label'``. When a truthy ``transform`` is supplied it is
    applied to the image only; the label is returned untouched.
    """

    def __init__(self, dataset_dict, transform=None):
        self.dataset_dict = dataset_dict
        self.transform = transform

    def __len__(self):
        # Size is whatever the wrapped collection reports.
        return len(self.dataset_dict)

    def __getitem__(self, index):
        record = self.dataset_dict[index]
        sample = record['image']
        if self.transform:
            sample = self.transform(sample)
        return sample, record['label']

# Load the dataset dictionary.
# NOTE(review): torch.load unpickles the file, so only load files you trust.
# On PyTorch releases where weights_only defaults to True, a file holding PIL
# images or numpy arrays may refuse to load — TODO confirm for the installed
# version.
loaded_dataset = torch.load('cifar10_permuted_block_size_8.pth')

# Wrap the records so they go through the same normalising transform as the
# clean test split, then batch them in a fixed order.
permuted_dataset = CustomDataset(loaded_dataset, transform=transform)
permuted_loader = DataLoader(
    permuted_dataset,
    batch_size=1024,
    shuffle=False,
)

# ------------- Load the pretrained model ------------
# nn.Module.to() moves the parameters in place and returns the same module.
pretrained_model = vgg13_bn(pretrained=True).to(device)
# Inference mode for layers such as BatchNorm (use running statistics).
pretrained_model.eval()


# -----------calculate accuracies -----------

def accuracy_calc(loader, classes, dataset, *, model=None, target_device=None):
    """Print per-class accuracy and the mean per-class accuracy over ``loader``.

    Args:
        loader: iterable of ``(inputs, labels)`` batches (e.g. a ``DataLoader``);
            ``labels`` is a 1-D tensor of class indices.
        classes: class names; ``classes[i]`` names label index ``i``.
        dataset: text used as the prefix of every printed line.
        model: classifier to evaluate; defaults to the module-level
            ``pretrained_model``. Column ``j`` of its output is read as the
            score for class ``j``. The model is used as-is, so callers are
            responsible for putting it in ``eval()`` mode.
        target_device: device each batch is moved to; defaults to the
            module-level ``device``.

    The printed "total accuracy" is the unweighted mean of the per-class
    accuracies (macro average) over classes that have at least one sample.
    It equals plain sample accuracy only when the classes are balanced.
    Classes with no samples are reported as ``n/a``, and labels outside
    ``range(len(classes))`` are ignored.
    """
    # Fall back to the notebook-level model/device so existing calls still work.
    if model is None:
        model = pretrained_model
    if target_device is None:
        target_device = device

    num_classes = len(classes)
    class_ids = torch.arange(num_classes, device=target_device)
    # Running per-class tallies stay on the device, so the batch loop does not
    # force a host sync for every class of every batch.
    correct_counts = torch.zeros(num_classes, dtype=torch.long, device=target_device)
    total_counts = torch.zeros(num_classes, dtype=torch.long, device=target_device)

    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(target_device), labels.to(target_device)

            # Forward pass; the arg-max over dim 1 is the predicted class index.
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)

            # (batch, num_classes) mask: row i flags the true class of sample i.
            # Labels outside range(num_classes) match no column and are skipped.
            is_class = labels.unsqueeze(1) == class_ids
            is_hit = (predicted == labels).unsqueeze(1)
            total_counts += is_class.sum(dim=0)
            correct_counts += (is_class & is_hit).sum(dim=0)

    correct_list = correct_counts.tolist()
    total_list = total_counts.tolist()

    # Print per-class accuracy; a class with no samples cannot be scored.
    scored = []
    for class_idx, class_name in enumerate(classes):
        class_total = total_list[class_idx]
        if class_total == 0:
            print(f'{dataset} accuracy for {class_name}: n/a (no samples)')
            continue
        class_accuracy = correct_list[class_idx] / class_total
        scored.append(class_accuracy)
        print(f'{dataset} accuracy for {class_name}: {100 * class_accuracy:.2f}%')

    # Average over the classes actually scored, not a fixed class count.
    if scored:
        mean_accuracy = sum(scored) / len(scored)
        print(f'{dataset} total accuracy: {100 * mean_accuracy:.2f}%')
    else:
        print(f'{dataset} total accuracy: n/a (no samples)')

# Position i names label index i (accuracy_calc pairs labels == i with classes[i]).
classes = [
    'Airplane', 'Car', 'Bird', 'Cat', 'Deer',
    'Dog', 'Frog', 'Horse', 'Ship', 'Truck',
]

# Clean test split first, then the permuted copy, with the same class names.
accuracy_calc(test_loader, classes, dataset='CIFAR10')
accuracy_calc(permuted_loader, classes, dataset='Permuted CIFAR10')


Files already downloaded and verified
CIFAR10 accuracy for Airplane: 96.00%
CIFAR10 accuracy for Car: 96.80%
CIFAR10 accuracy for Bird: 91.70%
CIFAR10 accuracy for Cat: 87.90%
CIFAR10 accuracy for Deer: 95.50%
CIFAR10 accuracy for Dog: 90.00%
CIFAR10 accuracy for Frog: 95.50%
CIFAR10 accuracy for Horse: 96.10%
CIFAR10 accuracy for Ship: 96.70%
CIFAR10 accuracy for Truck: 95.90%
CIFAR10 total accuracy: 94.21%
Permuted CIFAR10 accuracy for Airplane: 60.10%
Permuted CIFAR10 accuracy for Car: 7.70%
Permuted CIFAR10 accuracy for Bird: 14.20%
Permuted CIFAR10 accuracy for Cat: 42.70%
Permuted CIFAR10 accuracy for Deer: 36.30%
Permuted CIFAR10 accuracy for Dog: 7.40%
Permuted CIFAR10 accuracy for Frog: 24.40%
Permuted CIFAR10 accuracy for Horse: 26.60%
Permuted CIFAR10 accuracy for Ship: 24.70%
Permuted CIFAR10 accuracy for Truck: 81.70%
Permuted CIFAR10 total accuracy: 32.58%
