In [1]:
import pandas as pd
import torch
from torchvision import datasets, transforms
from torchvision.models import vit_h_14, ViT_H_14_Weights
from torch.utils.data import DataLoader

In [2]:
from helpers.helpers import set_seed
from helpers.sae import SparseAutoencoder, train_sae_on_layer, evaluate_sae_with_probe

In [3]:
# --- 0. Reproducibility & configuration ---
# Seed first so everything below runs under a fixed seed
# (which RNGs are covered depends on helpers.set_seed, not shown in this notebook).
set_seed(42)

# Checkpoint whose state dict is loaded into the model further down.
MODEL_SAVE_PATH = './classifiers/baseline/vit_h_99.56.pth'
# Side length (in pixels) of the square images fed to the model.
IMG_RES = 384

# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# Weights enum kept for reference only: it is NOT passed to the constructor below.
weights = ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1
# Alternative enum member: ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1
# Build the architecture without pretrained weights; the parameters are filled in
# later by load_state_dict from MODEL_SAVE_PATH.
model = vit_h_14(weights=None) 

In [5]:
# Tell the model which input resolution to expect, then work out how many
# patch tokens that resolution produces.
model.image_size = IMG_RES
patch_size = model.patch_size  # expected to be 14 for vit_h_14
# 384 is not a multiple of 14, hence the floor division: 384 // 14 = 27 patches per side.
num_patches = (IMG_RES // patch_size) * (IMG_RES // patch_size)  # 27 * 27 = 729

In [6]:
# Grab the model's current positional embedding; the following cells resize it.
orig_pos_embed = model.encoder.pos_embedding  # Shape: [1, 257, 1280] (257 = 1 cls token + 256 patches)
print(f"Original pos_embed shape: {orig_pos_embed.shape}")

Original pos_embed shape: torch.Size([1, 257, 1280])


In [7]:
# Derive the embedding width and the side lengths of the old and new patch grids.
embed_dim = orig_pos_embed.shape[-1]  # 1280
num_orig_patches = orig_pos_embed.shape[1] - 1  # 256 patches (257 tokens minus the class token)
# round() rather than int(): a float square root such as 15.999... must not be truncated downwards.
orig_grid_size = round(num_orig_patches ** 0.5)  # 16 for 256 patches (16x16 grid)
new_grid_size = round(num_patches ** 0.5)  # 27 for 729 patches (27x27 grid)
# The reshape to a 2-D grid that follows only makes sense for perfect squares; fail early and clearly.
assert orig_grid_size ** 2 == num_orig_patches, (
    f"original patch count {num_orig_patches} is not a perfect square"
)
assert new_grid_size ** 2 == num_patches, (
    f"target patch count {num_patches} is not a perfect square"
)

In [8]:
# Resize the patch part of the positional embedding from the original grid to the new grid.
# Drop the class-token slot, then lay the remaining tokens out as a 2-D grid.
patch_pos_grid = orig_pos_embed[:, 1:, :].reshape(
    1, orig_grid_size, orig_grid_size, embed_dim
)  # [1, 16, 16, 1280]

# interpolate() works on channels-first input, so embed_dim is moved in front of the grid axes.
patch_pos_resized = torch.nn.functional.interpolate(
    patch_pos_grid.permute(0, 3, 1, 2),  # [1, 1280, 16, 16]
    size=(new_grid_size, new_grid_size),  # -> [1, 1280, 27, 27]
    mode='bilinear',
    align_corners=False,
)

# Back to channels-last, then flatten the grid into a token sequence.
pos_embed = patch_pos_resized.permute(0, 2, 3, 1).reshape(1, num_patches, embed_dim)  # [1, 729, 1280]

In [9]:
# Put the class-token position back in front of the resized patch positions.
cls_token_embed = orig_pos_embed[:, 0:1, :]  # [1, 1, 1280]
new_pos_embed = torch.cat((cls_token_embed, pos_embed), dim=1)  # [1, 730, 1280]

# Install the resized tensor as the encoder's positional embedding.
# NOTE(review): nn.Parameter presumably wraps the tensor without copying, so the parameter
# and `new_pos_embed` would share storage - confirm before relying on either one later.
model.encoder.pos_embedding = torch.nn.Parameter(new_pos_embed)

In [10]:
# Replace the classification head with a 10-way linear layer (the notebook evaluates on CIFAR-10),
# then restore the weights stored at MODEL_SAVE_PATH.
num_ftrs = model.heads.head.in_features
model.heads.head = torch.nn.Linear(num_ftrs, 10)
# map_location="cpu": the model is still on the CPU here (it is moved to `device` afterwards),
# and this lets a checkpoint that was saved from a GPU run load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source
# (consider weights_only=True where the installed PyTorch supports it).
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location="cpu"))

<All keys matched successfully>

In [11]:
# Move the model to the target device and switch it to evaluation mode.
# Both .to() and .eval() return the module itself, so the calls can be chained.
model = model.to(device).eval()
print(f"Successfully loaded model from {MODEL_SAVE_PATH} to device: {device}")

Successfully loaded model from ./classifiers/baseline/vit_h_99.56.pth to device: cuda


In [12]:
# Evaluation preprocessing: upscale to the model resolution, convert to a tensor,
# then normalise each channel (these look like the widely used ImageNet statistics).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
eval_transform = transforms.Compose([
    transforms.Resize((IMG_RES, IMG_RES)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# CIFAR-10 test split (download=True fetches it into ./data when missing), iterated in fixed order.
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=eval_transform,
)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=16)

Files already downloaded and verified


In [13]:
# Measure top-1 accuracy of the model over the whole test loader.
test_correct = 0
test_total = 0
with torch.no_grad():  # inference only, so no autograd bookkeeping is needed
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # The index of the largest logit per sample is the predicted class.
        # argmax replaces the legacy `torch.max(outputs.data, 1)` idiom.
        predicted = outputs.argmax(dim=1)
        test_total += labels.size(0)
        test_correct += (predicted == labels).sum().item()

# Accuracy as a percentage of all evaluated samples.
final_accuracy = 100 * test_correct / test_total
print(f"\n🎉 Final Accuracy of the best model on the test set: {final_accuracy:.2f}%")


🎉 Final Accuracy of the best model on the test set: 99.56%


In [14]:
# Persist the positional embedding that the evaluated model actually uses.
# Reading it from the model (rather than from the pre-load `new_pos_embed` variable) makes the
# saved values independent of whether that variable still shares storage with the parameter.
# .detach().cpu() stores a plain CPU tensor without autograd state or a CUDA device dependency.
torch.save(model.encoder.pos_embedding.detach().cpu(), 'pos_embed_edge_384_99.56.pth')