<a href="https://colab.research.google.com/github/11061995/COE-523-Assignment1/blob/main/efficientvit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Step 1: Mounting data & repo

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!rm -rf Cream  # Remove existing directory if needed
!git clone https://github.com/microsoft/Cream.git
!ls
#cd Cream/EfficientViT/classification

Cloning into 'Cream'...
remote: Enumerating objects: 2175, done.[K
remote: Counting objects: 100% (265/265), done.[K
remote: Compressing objects: 100% (76/76), done.[K
remote: Total 2175 (delta 201), reused 189 (delta 189), pack-reused 1910 (from 2)[K
Receiving objects: 100% (2175/2175), 8.81 MiB | 25.34 MiB/s, done.
Resolving deltas: 100% (831/831), done.
Cream  data  drive  sample_data


In [None]:
import torch
from torchvision import datasets, transforms

# Define transforms (CIFAR-10 images are 32x32 pixels)
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
])

val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
])

# Download and load CIFAR-10 datasets
train_dataset = datasets.CIFAR10(
    root='./data',   # Dataset will be saved to /content/data
    train=True,      # Load training set
    download=True,   # Auto-download if not found
    transform=train_transform
)

val_dataset = datasets.CIFAR10(
    root='./data',   # Same root path
    train=False,     # Load validation/test set
    download=True,
    transform=val_transform
)

print("Training samples:", len(train_dataset))    # 50,000 images
print("Validation samples:", len(val_dataset))   # 10,000 images

Training samples: 50000
Validation samples: 10000


### Step 2:Reorganize the validation set to match the train folder structure:

In [None]:
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define transforms
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
])

val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
])

# Load datasets
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=val_transform)

# Create data loaders
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

### Step 3: Modify the Configuration File

Edit the config file (configs/efficientvit.yaml) to match your dataset:

In [None]:
# ✅ Correct (Python dictionary)
config = {
    "model": {
        "type": "efficientvit",
        "input_size": 32  # For CIFAR-10 (32x32 images)
    },
    "data": {
        "name": "cifar10",
        "num_classes": 10
    }
}

In [None]:
%cd Cream/EfficientViT/classification/model/
!ls  # Verify efficientvit.py exists

/content/Cream/EfficientViT/classification/model
build.py  efficientvit.py  __init__.py


In [None]:
cd classification/

[Errno 2] No such file or directory: 'classification/'
/content/Cream/EfficientViT/classification/model


### Step 4: install requirement

Modify the dataset loading in train.py (or create a new file datasets/cifar10.py)

In [None]:
!pip3 install torch torchvision timm einops fvcore easydict matplotlib numpy yacs scikit-image pillow



### Step 5: Adjust Model Architecture

In models/efficientvit.py, ensure the final classification head matches your dataset classes

In [None]:
!pip install efficientvit
from efficientvit import EfficientViT

# Initialize model with PROPER parameters
model = EfficientViT(
    img_size=32,       # ✅ Correct parameter name (not image_size)
    num_classes=10,      # CIFAR-10 has 10 classes
    patch_size=4,        # 32x32 → 32/4=8 patches per dimension
    embed_dim=[64, 128, 192],  # Reduced dimensions for CIFAR-10
    window_size=[6, 6, 6],     # Smaller window size for 32x32 input
    kernels=[5, 5, 5, 5]
)



[31mERROR: Could not find a version that satisfies the requirement efficientvit (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for efficientvit[0m[31m
[0m



In [None]:
model = EfficientViT(img_size=32, num_classes=10)
x = torch.randn(2, 3, 32, 32)  # Test batch
print(model(x).shape)  # Should output [2, 10]

torch.Size([2, 10])


### Step 6: Start Training

In [None]:
# Change to the correct directory
%cd /content/Cream/EfficientViT/classification/

/content/Cream/EfficientViT/classification


In [None]:
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

####Create a training loop

In [None]:
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

def train(model, loader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in tqdm(loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    train_loss = running_loss / len(loader)
    train_acc = 100. * correct / total
    return train_loss, train_acc

def validate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in tqdm(loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    val_loss = running_loss / len(loader)
    val_acc = 100. * correct / total
    return val_loss, val_acc

####Run the training

In [None]:
num_epochs = 50
best_acc = 0

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    # Train
    train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)

    # Validate
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    # Update scheduler
    scheduler.step()

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

    # Save best model
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), 'best_model.pth')
        print("Saved best model!")

    print("-" * 50)

Epoch 1/50


100%|██████████| 391/391 [00:50<00:00,  7.75it/s]
100%|██████████| 79/79 [00:04<00:00, 17.85it/s]


Train Loss: 1.6996 | Train Acc: 37.28%
Val Loss: 1.4624 | Val Acc: 46.33%
Saved best model!
--------------------------------------------------
Epoch 2/50


100%|██████████| 391/391 [00:45<00:00,  8.51it/s]
100%|██████████| 79/79 [00:05<00:00, 14.67it/s]


Train Loss: 1.4531 | Train Acc: 46.92%
Val Loss: 1.3487 | Val Acc: 50.93%
Saved best model!
--------------------------------------------------
Epoch 3/50


100%|██████████| 391/391 [00:47<00:00,  8.17it/s]
100%|██████████| 79/79 [00:04<00:00, 19.62it/s]


Train Loss: 1.3458 | Train Acc: 50.98%
Val Loss: 1.2339 | Val Acc: 55.13%
Saved best model!
--------------------------------------------------
Epoch 4/50


100%|██████████| 391/391 [00:44<00:00,  8.78it/s]
100%|██████████| 79/79 [00:03<00:00, 20.60it/s]


Train Loss: 1.2894 | Train Acc: 53.06%
Val Loss: 1.2092 | Val Acc: 56.95%
Saved best model!
--------------------------------------------------
Epoch 5/50


100%|██████████| 391/391 [00:43<00:00,  8.94it/s]
100%|██████████| 79/79 [00:04<00:00, 16.42it/s]


Train Loss: 1.2423 | Train Acc: 55.39%
Val Loss: 1.1688 | Val Acc: 57.84%
Saved best model!
--------------------------------------------------
Epoch 6/50


100%|██████████| 391/391 [00:43<00:00,  8.97it/s]
100%|██████████| 79/79 [00:03<00:00, 21.00it/s]


Train Loss: 1.2039 | Train Acc: 56.84%
Val Loss: 1.1743 | Val Acc: 58.37%
Saved best model!
--------------------------------------------------
Epoch 7/50


100%|██████████| 391/391 [00:43<00:00,  8.95it/s]
100%|██████████| 79/79 [00:05<00:00, 15.56it/s]


Train Loss: 1.1723 | Train Acc: 57.85%
Val Loss: 1.1007 | Val Acc: 61.14%
Saved best model!
--------------------------------------------------
Epoch 8/50


100%|██████████| 391/391 [00:43<00:00,  9.04it/s]
100%|██████████| 79/79 [00:03<00:00, 20.78it/s]


Train Loss: 1.1458 | Train Acc: 58.95%
Val Loss: 1.0728 | Val Acc: 62.00%
Saved best model!
--------------------------------------------------
Epoch 9/50


100%|██████████| 391/391 [00:43<00:00,  9.02it/s]
100%|██████████| 79/79 [00:05<00:00, 15.71it/s]


Train Loss: 1.1165 | Train Acc: 59.95%
Val Loss: 1.0512 | Val Acc: 62.71%
Saved best model!
--------------------------------------------------
Epoch 10/50


100%|██████████| 391/391 [00:43<00:00,  8.99it/s]
100%|██████████| 79/79 [00:03<00:00, 20.93it/s]


Train Loss: 1.0916 | Train Acc: 60.85%
Val Loss: 1.0567 | Val Acc: 62.52%
--------------------------------------------------
Epoch 11/50


100%|██████████| 391/391 [00:42<00:00,  9.15it/s]
100%|██████████| 79/79 [00:04<00:00, 15.93it/s]


Train Loss: 1.0780 | Train Acc: 61.63%
Val Loss: 1.0234 | Val Acc: 64.01%
Saved best model!
--------------------------------------------------
Epoch 12/50


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
100%|██████████| 79/79 [00:03<00:00, 20.66it/s]


Train Loss: 1.0572 | Train Acc: 62.12%
Val Loss: 1.0250 | Val Acc: 63.93%
--------------------------------------------------
Epoch 13/50


100%|██████████| 391/391 [00:43<00:00,  9.08it/s]
100%|██████████| 79/79 [00:04<00:00, 16.10it/s]


Train Loss: 1.0454 | Train Acc: 62.69%
Val Loss: 0.9854 | Val Acc: 65.34%
Saved best model!
--------------------------------------------------
Epoch 14/50


100%|██████████| 391/391 [00:43<00:00,  9.03it/s]
100%|██████████| 79/79 [00:03<00:00, 21.23it/s]


Train Loss: 1.0229 | Train Acc: 63.65%
Val Loss: 1.0034 | Val Acc: 64.72%
--------------------------------------------------
Epoch 15/50


100%|██████████| 391/391 [00:43<00:00,  9.03it/s]
100%|██████████| 79/79 [00:05<00:00, 15.54it/s]


Train Loss: 1.0072 | Train Acc: 64.01%
Val Loss: 0.9551 | Val Acc: 66.52%
Saved best model!
--------------------------------------------------
Epoch 16/50


100%|██████████| 391/391 [00:42<00:00,  9.12it/s]
100%|██████████| 79/79 [00:03<00:00, 21.46it/s]


Train Loss: 0.9930 | Train Acc: 64.53%
Val Loss: 0.9408 | Val Acc: 66.94%
Saved best model!
--------------------------------------------------
Epoch 17/50


100%|██████████| 391/391 [00:44<00:00,  8.82it/s]
100%|██████████| 79/79 [00:04<00:00, 15.94it/s]


Train Loss: 0.9789 | Train Acc: 65.04%
Val Loss: 0.9713 | Val Acc: 66.13%
--------------------------------------------------
Epoch 18/50


100%|██████████| 391/391 [00:44<00:00,  8.76it/s]
100%|██████████| 79/79 [00:03<00:00, 20.58it/s]


Train Loss: 0.9610 | Train Acc: 65.90%
Val Loss: 0.9069 | Val Acc: 67.61%
Saved best model!
--------------------------------------------------
Epoch 19/50


100%|██████████| 391/391 [00:45<00:00,  8.62it/s]
100%|██████████| 79/79 [00:03<00:00, 20.75it/s]


Train Loss: 0.9510 | Train Acc: 66.04%
Val Loss: 0.9274 | Val Acc: 67.68%
Saved best model!
--------------------------------------------------
Epoch 20/50


100%|██████████| 391/391 [00:44<00:00,  8.73it/s]
100%|██████████| 79/79 [00:04<00:00, 16.57it/s]


Train Loss: 0.9398 | Train Acc: 66.63%
Val Loss: 0.9163 | Val Acc: 68.32%
Saved best model!
--------------------------------------------------
Epoch 21/50


100%|██████████| 391/391 [00:44<00:00,  8.84it/s]
100%|██████████| 79/79 [00:03<00:00, 20.90it/s]


Train Loss: 0.9244 | Train Acc: 67.25%
Val Loss: 0.9358 | Val Acc: 67.23%
--------------------------------------------------
Epoch 22/50


100%|██████████| 391/391 [00:45<00:00,  8.60it/s]
100%|██████████| 79/79 [00:04<00:00, 17.24it/s]


Train Loss: 0.9184 | Train Acc: 67.42%
Val Loss: 0.8861 | Val Acc: 69.35%
Saved best model!
--------------------------------------------------
Epoch 23/50


100%|██████████| 391/391 [00:43<00:00,  8.95it/s]
100%|██████████| 79/79 [00:03<00:00, 20.98it/s]


Train Loss: 0.9107 | Train Acc: 67.66%
Val Loss: 0.8936 | Val Acc: 68.64%
--------------------------------------------------
Epoch 24/50


100%|██████████| 391/391 [00:44<00:00,  8.84it/s]
100%|██████████| 79/79 [00:04<00:00, 18.71it/s]


Train Loss: 0.8961 | Train Acc: 68.19%
Val Loss: 0.8938 | Val Acc: 68.53%
--------------------------------------------------
Epoch 25/50


100%|██████████| 391/391 [00:44<00:00,  8.80it/s]
100%|██████████| 79/79 [00:04<00:00, 19.72it/s]


Train Loss: 0.8896 | Train Acc: 68.38%
Val Loss: 0.8672 | Val Acc: 69.92%
Saved best model!
--------------------------------------------------
Epoch 26/50


100%|██████████| 391/391 [00:44<00:00,  8.70it/s]
100%|██████████| 79/79 [00:03<00:00, 21.08it/s]


Train Loss: 0.8743 | Train Acc: 68.66%
Val Loss: 0.8743 | Val Acc: 69.44%
--------------------------------------------------
Epoch 27/50


100%|██████████| 391/391 [00:43<00:00,  8.96it/s]
100%|██████████| 79/79 [00:04<00:00, 16.52it/s]


Train Loss: 0.8702 | Train Acc: 69.01%
Val Loss: 0.8552 | Val Acc: 70.12%
Saved best model!
--------------------------------------------------
Epoch 28/50


100%|██████████| 391/391 [00:43<00:00,  9.00it/s]
100%|██████████| 79/79 [00:03<00:00, 21.21it/s]


Train Loss: 0.8516 | Train Acc: 69.75%
Val Loss: 0.8484 | Val Acc: 70.41%
Saved best model!
--------------------------------------------------
Epoch 29/50


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
100%|██████████| 79/79 [00:04<00:00, 16.03it/s]


Train Loss: 0.8446 | Train Acc: 69.83%
Val Loss: 0.8622 | Val Acc: 70.24%
--------------------------------------------------
Epoch 30/50


100%|██████████| 391/391 [00:43<00:00,  8.99it/s]
100%|██████████| 79/79 [00:03<00:00, 20.90it/s]


Train Loss: 0.8318 | Train Acc: 70.38%
Val Loss: 0.8489 | Val Acc: 70.51%
Saved best model!
--------------------------------------------------
Epoch 31/50


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
100%|██████████| 79/79 [00:04<00:00, 15.97it/s]


Train Loss: 0.8255 | Train Acc: 70.62%
Val Loss: 0.8314 | Val Acc: 71.59%
Saved best model!
--------------------------------------------------
Epoch 32/50


100%|██████████| 391/391 [00:42<00:00,  9.11it/s]
100%|██████████| 79/79 [00:03<00:00, 21.29it/s]


Train Loss: 0.8223 | Train Acc: 70.81%
Val Loss: 0.8434 | Val Acc: 70.36%
--------------------------------------------------
Epoch 33/50


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
100%|██████████| 79/79 [00:05<00:00, 15.68it/s]


Train Loss: 0.8206 | Train Acc: 70.83%
Val Loss: 0.8110 | Val Acc: 71.86%
Saved best model!
--------------------------------------------------
Epoch 34/50


100%|██████████| 391/391 [00:43<00:00,  8.95it/s]
100%|██████████| 79/79 [00:03<00:00, 20.90it/s]


Train Loss: 0.8055 | Train Acc: 71.43%
Val Loss: 0.8239 | Val Acc: 71.23%
--------------------------------------------------
Epoch 35/50


100%|██████████| 391/391 [00:43<00:00,  9.02it/s]
100%|██████████| 79/79 [00:05<00:00, 15.46it/s]


Train Loss: 0.7951 | Train Acc: 71.50%
Val Loss: 0.8207 | Val Acc: 71.45%
--------------------------------------------------
Epoch 36/50


100%|██████████| 391/391 [00:43<00:00,  8.98it/s]
100%|██████████| 79/79 [00:03<00:00, 21.35it/s]


Train Loss: 0.7925 | Train Acc: 71.78%
Val Loss: 0.8083 | Val Acc: 71.56%
--------------------------------------------------
Epoch 37/50


100%|██████████| 391/391 [00:43<00:00,  9.07it/s]
100%|██████████| 79/79 [00:04<00:00, 16.48it/s]


Train Loss: 0.7807 | Train Acc: 72.33%
Val Loss: 0.8162 | Val Acc: 71.30%
--------------------------------------------------
Epoch 38/50


100%|██████████| 391/391 [00:44<00:00,  8.85it/s]
100%|██████████| 79/79 [00:04<00:00, 18.39it/s]


Train Loss: 0.7802 | Train Acc: 72.19%
Val Loss: 0.7979 | Val Acc: 72.36%
Saved best model!
--------------------------------------------------
Epoch 39/50


100%|██████████| 391/391 [00:45<00:00,  8.58it/s]
100%|██████████| 79/79 [00:03<00:00, 21.06it/s]


Train Loss: 0.7668 | Train Acc: 72.65%
Val Loss: 0.8018 | Val Acc: 72.07%
--------------------------------------------------
Epoch 40/50


100%|██████████| 391/391 [00:49<00:00,  7.91it/s]
100%|██████████| 79/79 [00:04<00:00, 19.27it/s]


Train Loss: 0.7630 | Train Acc: 72.96%
Val Loss: 0.8000 | Val Acc: 72.05%
--------------------------------------------------
Epoch 41/50


100%|██████████| 391/391 [00:44<00:00,  8.80it/s]
100%|██████████| 79/79 [00:04<00:00, 16.70it/s]


Train Loss: 0.7502 | Train Acc: 73.20%
Val Loss: 0.7907 | Val Acc: 72.75%
Saved best model!
--------------------------------------------------
Epoch 42/50


100%|██████████| 391/391 [00:45<00:00,  8.57it/s]
100%|██████████| 79/79 [00:03<00:00, 20.93it/s]


Train Loss: 0.7431 | Train Acc: 73.27%
Val Loss: 0.7915 | Val Acc: 72.25%
--------------------------------------------------
Epoch 43/50


100%|██████████| 391/391 [00:43<00:00,  9.02it/s]
100%|██████████| 79/79 [00:05<00:00, 14.16it/s]


Train Loss: 0.7403 | Train Acc: 73.73%
Val Loss: 0.7862 | Val Acc: 72.53%
--------------------------------------------------
Epoch 44/50


100%|██████████| 391/391 [00:42<00:00,  9.10it/s]
100%|██████████| 79/79 [00:03<00:00, 21.20it/s]


Train Loss: 0.7380 | Train Acc: 73.45%
Val Loss: 0.7897 | Val Acc: 72.79%
Saved best model!
--------------------------------------------------
Epoch 45/50


100%|██████████| 391/391 [00:43<00:00,  9.09it/s]
100%|██████████| 79/79 [00:05<00:00, 15.68it/s]


Train Loss: 0.7275 | Train Acc: 73.98%
Val Loss: 0.7851 | Val Acc: 73.23%
Saved best model!
--------------------------------------------------
Epoch 46/50


100%|██████████| 391/391 [00:43<00:00,  9.08it/s]
100%|██████████| 79/79 [00:03<00:00, 21.21it/s]


Train Loss: 0.7211 | Train Acc: 74.58%
Val Loss: 0.7739 | Val Acc: 73.27%
Saved best model!
--------------------------------------------------
Epoch 47/50


100%|██████████| 391/391 [00:43<00:00,  8.95it/s]
100%|██████████| 79/79 [00:04<00:00, 15.98it/s]


Train Loss: 0.7203 | Train Acc: 74.26%
Val Loss: 0.7784 | Val Acc: 72.88%
--------------------------------------------------
Epoch 48/50


100%|██████████| 391/391 [00:42<00:00,  9.18it/s]
100%|██████████| 79/79 [00:03<00:00, 21.12it/s]


Train Loss: 0.7076 | Train Acc: 74.93%
Val Loss: 0.7804 | Val Acc: 72.89%
--------------------------------------------------
Epoch 49/50


100%|██████████| 391/391 [00:44<00:00,  8.74it/s]
100%|██████████| 79/79 [00:04<00:00, 17.19it/s]


Train Loss: 0.7032 | Train Acc: 74.73%
Val Loss: 0.7818 | Val Acc: 72.59%
--------------------------------------------------
Epoch 50/50


100%|██████████| 391/391 [00:44<00:00,  8.76it/s]
100%|██████████| 79/79 [00:03<00:00, 21.03it/s]

Train Loss: 0.6969 | Train Acc: 75.13%
Val Loss: 0.7743 | Val Acc: 73.12%
--------------------------------------------------





### Step 7: Evaluate the Model

In [None]:
# Load best model
model.load_state_dict(torch.load('best_model.pth'))

# Final evaluation
val_loss, val_acc = validate(model, val_loader, criterion, device)
print(f"Final Validation Accuracy: {val_acc:.2f}%")

100%|██████████| 313/313 [00:09<00:00, 32.34it/s]

Final Validation Accuracy: 73.27%





###Step 8: Visualization

In [None]:
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt

class EfficientViTCAM:
    def __init__(self, model):
        self.model = model
        self.features = None
        self.gradients = None
        self.model.eval()

        # Find the last convolutional layer
        def find_target_layer(module):
            target_layer = None
            for name, layer in reversed(list(module.named_modules())):
                if isinstance(layer, torch.nn.Conv2d):
                    target_layer = layer
                    print(f"✅ Found target layer: {name}")
                    break
            return target_layer

        self.target_layer = find_target_layer(self.model)

        if self.target_layer is None:
            raise ValueError("❌ No convolutional layer found for Grad-CAM")

        # Register hooks
        self.target_layer.register_forward_hook(self.save_features)
        self.target_layer.register_full_backward_hook(self.save_gradients)
        print("✅ Hooks registered successfully")

    def save_features(self, module, input, output):
        self.features = output.detach()
        print(f"✅ Features saved - shape: {self.features.shape}")

    def save_gradients(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()
        print(f"✅ Gradients saved - shape: {self.gradients.shape}")

    def generate_cam(self, input_image, target_class=None):
        print("\n=== Generating CAM ===")
        print(f"Input image shape: {input_image.shape}")

        # Forward pass
        output = self.model(input_image.unsqueeze(0))
        print(f"Model output shape: {output.shape}")

        if target_class is None:
            target_class = output.argmax(dim=1).item()
        print(f"Target class: {target_class}")

        # Backward pass
        self.model.zero_grad()
        one_hot = torch.zeros_like(output)
        one_hot[0][target_class] = 1
        output.backward(gradient=one_hot)
        print("✅ Backward pass completed")

        # Process gradients and activations
        pooled_gradients = torch.mean(self.gradients, dim=[0, 2, 3], keepdim=True)
        activations = self.features
        print(f"Pooled gradients shape: {pooled_gradients.shape}")
        print(f"Activations shape: {activations.shape}")

        # Weight activations by gradients
        weighted_activations = activations * pooled_gradients
        print(f"Weighted activations shape: {weighted_activations.shape}")

        # Generate heatmap
        heatmap = torch.mean(weighted_activations, dim=1).squeeze()
        print(f"Heatmap before processing shape: {heatmap.shape}")

        # Convert to numpy and process
        heatmap = heatmap.cpu().numpy()
        print(f"Heatmap numpy shape: {heatmap.shape}")

        heatmap = np.maximum(heatmap, 0)
        print(f"Heatmap after max(0): {heatmap.shape}")

        # Resize to input image dimensions
        heatmap = cv2.resize(heatmap, (input_image.shape[2], input_image.shape[1]))
        print(f"Heatmap after resize: {heatmap.shape}")

        # Normalize
        heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap) + 1e-8)
        print("✅ Heatmap generated successfully")

        return heatmap

def visualize_heatmap(image, heatmap, true_class, pred_class, alpha=0.5):
    print("\n=== Visualizing ===")
    # Convert tensor image to numpy
    img = image.cpu().numpy().transpose(1, 2, 0)
    mean = np.array([0.4914, 0.4822, 0.4465])
    std = np.array([0.2470, 0.2435, 0.2616])
    img = std * img + mean  # Un-normalize
    img = np.clip(img, 0, 1)
    print(f"Image numpy shape: {img.shape}")

    # Create heatmap
    print(f"Heatmap input shape: {heatmap.shape}")
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    print(f"Heatmap after color mapping: {heatmap.shape}")

    # Superimpose heatmap on image
    superimposed_img = cv2.addWeighted(np.uint8(255 * img), alpha, heatmap, 1-alpha, 0)
    print(f"Final output shape: {superimposed_img.shape}")

    # Display
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(img)
    plt.title(f"True: {true_class}")
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(superimposed_img)
    plt.title(f"Predicted: {pred_class}")
    plt.axis('off')

    plt.tight_layout()
    print("✅ Showing visualization...")
    plt.show()

def generate_and_visualize_heatmaps(model, data_loader, num_examples=5):
    print("===== Starting Grad-CAM Visualization =====")
    try:
        cam_generator = EfficientViTCAM(model)
    except ValueError as e:
        print(f"❌ Error initializing Grad-CAM: {e}")
        return

    classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
              'dog', 'frog', 'horse', 'ship', 'truck']

    # Get some samples
    data_iter = iter(data_loader)
    sample_images, sample_labels = next(data_iter)
    sample_images, sample_labels = sample_images.to(device), sample_labels.to(device)
    print(f"Loaded {len(sample_images)} examples")

    for i in range(min(num_examples, len(sample_images))):
        print(f"\n==== Processing example {i+1}/{min(num_examples, len(sample_images))} ====")
        image = sample_images[i]
        label = sample_labels[i].item()

        # Get prediction
        with torch.no_grad():
            output = model(image.unsqueeze(0))
            pred_class = output.argmax(dim=1).item()
        print(f"True class: {classes[label]}, Predicted: {classes[pred_class]}")

        # Generate heatmap
        try:
            heatmap = cam_generator.generate_cam(image, pred_class)
            # Visualize
            visualize_heatmap(image, heatmap, classes[label], classes[pred_class])
        except Exception as e:
            print(f"❌ Error generating heatmap for example {i}: {e}")

# Run the visualization
print("===== Starting Visualization Process =====")
generate_and_visualize_heatmaps(model, val_loader, num_examples=5)
print("===== Visualization Complete =====")

===== Starting Visualization Process =====
===== Starting Grad-CAM Visualization =====
✅ Found target layer: blocks3.5.ffn1.m.pw2.c
✅ Hooks registered successfully
Loaded 32 examples

==== Processing example 1/5 ====
✅ Features saved - shape: torch.Size([1, 192, 1, 1])
✅ Features saved - shape: torch.Size([1, 192, 1, 1])
True class: cat, Predicted: cat

=== Generating CAM ===
Input image shape: torch.Size([3, 32, 32])
✅ Features saved - shape: torch.Size([1, 192, 1, 1])
✅ Features saved - shape: torch.Size([1, 192, 1, 1])
Model output shape: torch.Size([1, 10])
Target class: 3
✅ Gradients saved - shape: torch.Size([1, 192, 1, 1])
✅ Gradients saved - shape: torch.Size([1, 192, 1, 1])
✅ Backward pass completed
Pooled gradients shape: torch.Size([1, 192, 1, 1])
Activations shape: torch.Size([1, 192, 1, 1])
Weighted activations shape: torch.Size([1, 192, 1, 1])
Heatmap before processing shape: torch.Size([])
Heatmap numpy shape: ()
Heatmap after max(0): ()
❌ Error generating heatmap for ex