In [25]:
import pickle
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from tqdm import tqdm
import torch.nn as nn
from src.model import vit_tiny_patch16_224

In [26]:
trained_model = vit_tiny_patch16_224(num_classes=10)
checkpoint = torch.load('best_model.pth', map_location="cpu", weights_only=False)
trained_model.load_state_dict(checkpoint['model_state_dict'])
trained_model.eval()

VisionTransformer(
  (patch_embed): PatchEmbedding(
    (projection): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_dropout): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x TransformerEncoder(
      (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_dropout): Dropout(p=0.1, inplace=False)
      )
      (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (mlp): FeedForward(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (gelu): GELU(approximate='none')
        (dropout1): Dropout(p=0.1, inplace=False)
        (fc2): Linear(in_features=768, out_features=192, bias=True)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (norm): LayerNor

In [27]:
class CIFAR10DATASET(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        Args:
            data_dir: Path to CIFAR-10 data directory
            transform: Optional transform to apply to images
        """
        self.data_dir = data_dir
        self.transform = transform

        # Load data and labels
        self.data = []
        self.labels = []

        # Load test batch
        test_file = os.path.join(data_dir, 'test_batch')
        with open(test_file, 'rb') as f:
            test_dict = pickle.load(f, encoding='bytes')
        self.data = test_dict[b'data']  # Shape: (10000, 3072)
        self.labels = test_dict[b'labels']

        # Reshape data from (N, 3072) to (N, 3, 32, 32)
        self.data = self.data.reshape(-1, 3, 32, 32)
        # Convert from [0, 255] to [0, 1]
        self.data = self.data.astype(np.float32) / 255.0

        # Load label names
        meta_file = os.path.join(data_dir, 'batches.meta')
        with open(meta_file, 'rb') as f:
            meta_dict = pickle.load(f, encoding='bytes')
        self.label_names = [name.decode('utf-8') for name in meta_dict[b'label_names']]

        print(f"Loaded test data: {len(self.data)} images")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image = torch.from_numpy(self.data[idx])  # Shape: (3, 32, 32)
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

In [61]:
class CIFAR100DATASET(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        Args:
            data_dir: Path to CIFAR-10 data directory
            transform: Optional transform to apply to images
        """
        self.data_dir = data_dir
        self.transform = transform

        # Load data and labels
        self.data = []
        self.labels = []

        # Load test batch
        test_file = os.path.join(data_dir, 'test')
        with open(test_file, 'rb') as f:
            test_dict = pickle.load(f, encoding='bytes')
        self.data = test_dict[b'data']  # Shape: (10000, 3072)
        self.labels = test_dict[b'coarse_labels']

        # Reshape data from (N, 3072) to (N, 3, 32, 32)
        self.data = self.data.reshape(-1, 3, 32, 32)
        # Convert from [0, 255] to [0, 1]
        self.data = self.data.astype(np.float32) / 255.0

        # Load label names
        meta_file = os.path.join(data_dir, 'meta')
        with open(meta_file, 'rb') as f:
            meta_dict = pickle.load(f, encoding='bytes')
        print(meta_dict.keys())
        self.label_names = [name.decode('utf-8') for name in meta_dict[b'coarse_label_names']]

        print(f"Loaded test data: {len(self.data)} images")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image = torch.from_numpy(self.data[idx])  # Shape: (3, 32, 32)
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

In [68]:
def validate(model, val_loader, criterion, device):
    """Validate the model."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(val_loader)
        for batch_idx, (images, labels) in enumerate(pbar):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Statistics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Update progress bar
            pbar.set_postfix({
                'loss': running_loss / (batch_idx + 1),
                'acc': 100. * correct / total
            })

    epoch_loss = running_loss / len(val_loader)
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc

In [69]:
dataset = CIFAR100DATASET('cifar-100', transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.Normalize(
                mean=[0.4914, 0.4822, 0.4465],
                std=[0.2470, 0.2435, 0.2616]
            ),
        ]))

dict_keys([b'fine_label_names', b'coarse_label_names'])
Loaded test data: 10000 images


In [70]:
criterion = nn.CrossEntropyLoss()

In [71]:
val_loader =  DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=torch.cuda.is_available()
    )

In [72]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [73]:
val_loss, val_acc = validate(trained_model, val_loader, criterion, device)

  0%|          | 0/10000 [00:00<?, ?it/s]

tensor([[[[ 1.1700,  1.1700,  1.1700,  ...,  1.5669,  1.5669,  1.5669],
          [ 1.1700,  1.1700,  1.1700,  ...,  1.5669,  1.5669,  1.5669],
          [ 1.1700,  1.1700,  1.1700,  ...,  1.5669,  1.5669,  1.5669],
          ...,
          [-0.7193, -0.7193, -0.7193,  ...,  0.8048,  0.8048,  0.8048],
          [-0.7193, -0.7193, -0.7193,  ...,  0.8048,  0.8048,  0.8048],
          [-0.7193, -0.7193, -0.7193,  ...,  0.8048,  0.8048,  0.8048]],

         [[ 1.4823,  1.4823,  1.4823,  ...,  1.7883,  1.7883,  1.7883],
          [ 1.4823,  1.4823,  1.4823,  ...,  1.7883,  1.7883,  1.7883],
          [ 1.4823,  1.4823,  1.4823,  ...,  1.7883,  1.7883,  1.7883],
          ...,
          [-0.6114, -0.6114, -0.6114,  ...,  0.6609,  0.6609,  0.6609],
          [-0.6114, -0.6114, -0.6114,  ...,  0.6609,  0.6609,  0.6609],
          [-0.6114, -0.6114, -0.6114,  ...,  0.6609,  0.6609,  0.6609]],

         [[ 2.0259,  2.0259,  2.0259,  ...,  2.0709,  2.0709,  2.0709],
          [ 2.0259,  2.0259,  




IndexError: Target 10 is out of bounds.

In [34]:
print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

  Val Loss: 0.9240 | Val Acc: 83.02%
