In [11]:
import os
import torch
import torch.nn as nn
import xml.etree.ElementTree as ET
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.models as models
import torchvision.transforms as transforms
from tqdm import tqdm

# 1. Define paths for ImageNet validation images and annotations
IMAGE_DIR = "/home/kajm20/mnist/ILSVRC/Data/CLS-LOC/val"  
ANNOTATION_DIR = "/home/kajm20/mnist/ILSVRC/Annotations/CLS-LOC/val"  

imagenet_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# 3. Load the synset mapping
synset_mapping_path = "/home/kajm20/mnist/ILSVRC/LOC_synset_mapping.txt"
wordnet_to_imagenet = {}

with open(synset_mapping_path) as f:
    for idx, line in enumerate(f.readlines()):
        wordnet_id, _ = line.split(' ', 1)
        wordnet_to_imagenet[wordnet_id] = idx  

# 4. Define the custom dataset class
class ImageNetValDataset(Dataset):
    def __init__(self, image_dir, annotation_dir, transform=None):
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.transform = transform
        self.annotation_files = sorted(os.listdir(annotation_dir))

    def __len__(self):
        return len(self.annotation_files)

    def __getitem__(self, idx):
        annotation_path = os.path.join(self.annotation_dir, self.annotation_files[idx])
        tree = ET.parse(annotation_path)
        root = tree.getroot()
        wordnet_id = root.find("object").find("name").text  

        class_idx = wordnet_to_imagenet.get(wordnet_id, -1)  
        image_filename = root.find("filename").text + ".JPEG"
        image_path = os.path.join(self.image_dir, image_filename)

        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        return image, class_idx

num_workers = min(4, os.cpu_count())
# 5. Initialize the dataset and dataloader
imagenet_val_dataset = ImageNetValDataset(IMAGE_DIR, ANNOTATION_DIR, transform=imagenet_transform)
imagenet_val_loader = DataLoader(imagenet_val_dataset, batch_size=32, shuffle=False, num_workers=num_workers)

# 6. Define the model (EfficientNet-B0 with pre-trained weights)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.efficientnet_b0(weights='DEFAULT')  
model.to(device)
model.eval() 



EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActivat

In [18]:
# Hook function
def hook_fn(module, inputs, output):
    layer_name = f"{module.__class__.__name__} ({id(module)})"
    activation_values[layer_name] = activation_values.get(layer_name, 0) + output.detach().cpu().sum(dim=0)

for name, layer in model.named_modules():
    layer.register_forward_hook(hook_fn)  # Register hook for all layers


In [16]:
for name, layer in model.named_modules():
    if hasattr(layer, "_forward_hooks"):
        for hook_id in list(layer._forward_hooks.keys()):  # Convert to list to avoid runtime changes
            layer._forward_hooks.pop(hook_id)


In [19]:
for name, layer in model.named_modules():
    if layer._forward_hooks:
        print(f"Layer: {name}, Hooks: {layer._forward_hooks}")


Layer: , Hooks: OrderedDict({1011: <function hook_fn at 0x7e8e3120c860>})
Layer: features, Hooks: OrderedDict({1012: <function hook_fn at 0x7e8e3120c860>})
Layer: features.0, Hooks: OrderedDict({1013: <function hook_fn at 0x7e8e3120c860>})
Layer: features.0.0, Hooks: OrderedDict({1014: <function hook_fn at 0x7e8e3120c860>})
Layer: features.0.1, Hooks: OrderedDict({1015: <function hook_fn at 0x7e8e3120c860>})
Layer: features.0.2, Hooks: OrderedDict({1016: <function hook_fn at 0x7e8e3120c860>})
Layer: features.1, Hooks: OrderedDict({1017: <function hook_fn at 0x7e8e3120c860>})
Layer: features.1.0, Hooks: OrderedDict({1018: <function hook_fn at 0x7e8e3120c860>})
Layer: features.1.0.block, Hooks: OrderedDict({1019: <function hook_fn at 0x7e8e3120c860>})
Layer: features.1.0.block.0, Hooks: OrderedDict({1020: <function hook_fn at 0x7e8e3120c860>})
Layer: features.1.0.block.0.0, Hooks: OrderedDict({1021: <function hook_fn at 0x7e8e3120c860>})
Layer: features.1.0.block.0.1, Hooks: OrderedDict(

In [20]:
activation_values = {}  # Stores activation values

# 10. Define the evaluation function
def evaluate_model(model, dataloader):
    correct = 0
    total = 0

    with torch.no_grad():  
        for images, labels in tqdm(dataloader):
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)  

            _, predicted = torch.max(outputs, 1)  
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    accuracy = (correct / total) * 100
    return accuracy

# 11. Evaluate the model on ImageNet validation set
accuracy = evaluate_model(model, imagenet_val_loader)
print(f"EfficientNet-B0 Top-1 Accuracy on ImageNet: {accuracy:.2f}%")

  0%|          | 7/1563 [00:06<24:19,  1.07it/s]


KeyboardInterrupt: 

In [9]:
activation_values

{'Conv2d (139149088319744)': tensor([[[[-6.1614e+04, -6.4736e+04, -6.4768e+04,  ..., -6.4818e+04,
            -6.4742e+04, -6.4611e+04],
           [-3.8792e+02, -1.0424e+02, -8.0834e+01,  ..., -1.1180e+02,
            -7.3723e+01, -8.6023e+01],
           [-3.4326e+02, -7.1924e+01, -6.5156e+01,  ..., -7.6025e+01,
            -7.6475e+01, -9.9233e+01],
           ...,
           [-3.0130e+02, -1.0557e+02, -1.1884e+02,  ..., -1.1021e+02,
            -1.6407e+02, -5.1570e+01],
           [-3.3381e+02, -6.4971e+01, -1.0209e+02,  ...,  1.2123e+01,
             4.6029e+01, -5.3397e+01],
           [-2.6913e+02, -5.1732e+01, -1.2867e+01,  ..., -3.0070e+01,
            -3.6203e+00, -1.5449e+01]],
 
          [[-8.2052e+04, -1.4913e+03, -1.3505e+03,  ..., -1.3348e+03,
            -1.4058e+03, -1.3166e+03],
           [-8.7504e+04,  2.2601e+02,  2.9173e+02,  ...,  3.0114e+02,
             3.0142e+02,  3.4803e+02],
           [-8.7628e+04,  3.0711e+02,  2.8056e+02,  ...,  3.3981e+02,
           