In [6]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

from models.networks import get_model
from data_utils.data_stats import *


dataset = 'cifar10'                 # One of cifar10, cifar100, stl10, imagenet or imagenet21
architecture = 'B_12-Wi_1024'
data_resolution = 32                # Resolution of data as it is stored
crop_resolution = 64                # Resolution of fine-tuned model (64 for all models we provide)
num_classes = CLASS_DICT[dataset]
data_path = './beton/'
eval_batch_size = 1024
checkpoint = 'in21k_cifar10'        # This means you want the network pre-trained on ImageNet21k and finetuned on CIFAR10



# Set device to GPU if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the transformation for the input data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(64),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
])

# Load the CIFAR-10 test dataset
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

# Load a pretrained ResNet model (you can choose a different model)
pretrained_model = get_model(architecture=architecture, resolution=crop_resolution, num_classes=CLASS_DICT[dataset],
                  checkpoint='in21k_cifar10')
pretrained_model = pretrained_model.to(device)
pretrained_model.eval()

# Initialize a dictionary to store per-class counts
class_counts = {class_idx: {'correct': 0, 'total': 0} for class_idx in range(num_classes)}


with torch.no_grad():
    for inputs, labels in test_loader:

        inputs = torch.reshape(inputs, (inputs.shape[0], -1))
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        outputs = pretrained_model(inputs)
        _, predicted = torch.max(outputs, 1)

        # Update counts for each class
        for class_idx in range(num_classes):
            class_mask = labels == class_idx
            class_total = class_mask.sum().item()
            class_correct = (predicted[class_mask] == class_idx).sum().item()

            class_counts[class_idx]['correct'] += class_correct
            class_counts[class_idx]['total'] += class_total

# Calculate per-class accuracy
per_class_accuracy = {class_idx: class_counts[class_idx]['correct'] / class_counts[class_idx]['total']
                      for class_idx in range(num_classes)}


# Print per-class accuracy
for class_idx in range(num_classes):
    print(f'Accuracy for class {class_idx}: {100 * per_class_accuracy[class_idx]:.2f}%')


#0: airplanes, 1: cars, 2: birds, 3: cats, 4: deer, 5: dogs, 6: frogs, 7: horses, 8: ships, 9: trucks

Files already downloaded and verified
Weights already downloaded
Load_state output <All keys matched successfully>
Accuracy for class 0: 96.20%
Accuracy for class 1: 94.60%
Accuracy for class 2: 94.80%
Accuracy for class 3: 86.80%
Accuracy for class 4: 96.10%
Accuracy for class 5: 88.20%
Accuracy for class 6: 95.90%
Accuracy for class 7: 95.20%
Accuracy for class 8: 96.60%
Accuracy for class 9: 96.50%


In [None]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

from models.networks import get_model
from data_utils.data_stats import *

# ... (the rest of your imports and settings remain unchanged)

# Initialize a dictionary to store per-class counts
class_counts = {class_idx: {'correct': 0, 'total': 0} for class_idx in range(num_classes)}

with torch.no_grad():
    for inputs, labels in test_loader:

        inputs = torch.reshape(inputs, (inputs.shape[0], -1))
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        outputs = pretrained_model(inputs)
        _, predicted = torch.max(outputs, 1)

        # Update counts for each class
        for class_idx in range(num_classes):
            class_mask = labels == class_idx
            class_total = class_mask.sum().item()
            class_correct = (predicted[class_mask] == class_idx).sum().item()

            class_counts[class_idx]['correct'] += class_correct
            class_counts[class_idx]['total'] += class_total

# Calculate per-class accuracy
per_class_accuracy = {class_idx: class_counts[class_idx]['correct'] / class_counts[class_idx]['total']
                      for class_idx in range(num_classes)}

# Calculate overall accuracy
overall_accuracy = correct_predictions / total_samples

# Print per-class accuracy
for class_idx in range(num_classes):
    print(f'Accuracy for class {class_idx}: {100 * per_class_accuracy[class_idx]:.2f}%')

# Print overall accuracy
print(f'Overall accuracy on CIFAR-10: {100 * overall_accuracy:.2f}%')


In [8]:
for i in range(20):
    print(test_dataset[i][1])

3
8
8
0
6
6
1
6
3
1
0
9
5
7
9
8
5
7
8
6


In [9]:
import matplotlib.pyplot as plt

def norm_01(array):
    return(array-np.min(array))/(np.max(array)-np.min(array))

def show_im(batch):
    img = batch.reshape(3,32,32).permute(1,2,0)
    img_np = img.detach().numpy()
    plt.imshow(norm_01(img_np))

show_im(test_dataset[12][0])



RuntimeError: shape '[3, 32, 32]' is invalid for input of size 12288

In [10]:
print(test_dataset[12][0].shape)

torch.Size([3, 64, 64])


