In [None]:
import torch
from torchvision import datasets, transforms
from torchvision.models import vit_h_14, ViT_H_14_Weights
from torch.utils.data import DataLoader


from helpers.helpers import set_seed


# Seed the run through the project helper before anything else executes.
SEED = 42
set_seed(SEED)
# GPU placement is currently disabled, so the notebook stays on the CPU:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# The weights enum is kept because its bundled preprocessing
# (`weights.transforms()`) is reused when building the evaluation dataset.
weights = ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1  # Using LINEAR for efficiency

# Build the architecture only (weights=None). Every parameter is overwritten by
# the fine-tuned checkpoint that is loaded right after via `load_state_dict`,
# so fetching the pretrained weights here would be a large, redundant download.
# NOTE(review): this relies on the builder's default input size matching the
# one used by the selected weights variant -- confirm before switching variants.
model = vit_h_14(weights=None)

# Replace the classification head with a 10-way linear layer (CIFAR-10 has
# ten classes); presumably this mirrors the head stored in the checkpoint.
NUM_CLASSES = 10
num_ftrs = model.heads.head.in_features
model.heads.head = torch.nn.Linear(num_ftrs, NUM_CLASSES)

In [None]:
MODEL_SAVE_PATH = './saved_models/25_top/best_model_lf_0.01.pth'

# Read the fine-tuned parameters from disk. `map_location="cpu"` remaps the
# stored tensors onto the CPU, so a checkpoint written from a GPU session still
# loads on a machine without CUDA (the model is kept on the CPU here).
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source (or pass weights_only=True on PyTorch versions that support it).
state_dict = torch.load(MODEL_SAVE_PATH, map_location="cpu")
model.load_state_dict(state_dict)
# model.to(device)

# Put layers such as dropout into evaluation behaviour. The trailing semicolon
# stops the notebook from echoing the full module repr as the cell output.
model.eval();

In [None]:
# Preprocessing pipeline that ships with the selected weights entry.
eval_transform = weights.transforms()

# CIFAR-10 evaluation split (train=False), downloaded into ./data when requested.
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=eval_transform,
)

# Iterate the split in its stored order, 16 samples per batch.
test_loader = DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=False,
)

In [None]:
# Top-1 accuracy over the whole test loader.
test_correct = 0
test_total = 0

# Gradients are not needed for scoring, so autograd tracking is switched off.
with torch.no_grad():
    for images, labels in test_loader:
        # images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Index of the largest logit along the class dimension is the prediction.
        predicted = outputs.argmax(dim=1)
        test_total += labels.shape[0]
        test_correct += predicted.eq(labels).sum().item()

# Convert the running counts into a percentage.
final_accuracy = (100 * test_correct) / test_total
print(f"\n🎉 Final Accuracy of the best model on the test set: {final_accuracy:.2f}%")