In [11]:
import os
import torch
import torch.nn as nn
import xml.etree.ElementTree as ET
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.models as models
import torchvision.transforms as transforms
from tqdm import tqdm

# 1. Define paths for ImageNet validation images and annotations
IMAGE_DIR = "/home/kajm20/mnist/ILSVRC/Data/CLS-LOC/val"  
ANNOTATION_DIR = "/home/kajm20/mnist/ILSVRC/Annotations/CLS-LOC/val"  

imagenet_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# 3. Load the synset mapping
synset_mapping_path = "/home/kajm20/mnist/ILSVRC/LOC_synset_mapping.txt"
wordnet_to_imagenet = {}

with open(synset_mapping_path) as f:
    for idx, line in enumerate(f.readlines()):
        wordnet_id, _ = line.split(' ', 1)
        wordnet_to_imagenet[wordnet_id] = idx  

# 4. Define the custom dataset class
class ImageNetValDataset(Dataset):
    def __init__(self, image_dir, annotation_dir, transform=None):
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.transform = transform
        self.annotation_files = sorted(os.listdir(annotation_dir))

    def __len__(self):
        return len(self.annotation_files)

    def __getitem__(self, idx):
        annotation_path = os.path.join(self.annotation_dir, self.annotation_files[idx])
        tree = ET.parse(annotation_path)
        root = tree.getroot()
        wordnet_id = root.find("object").find("name").text  

        class_idx = wordnet_to_imagenet.get(wordnet_id, -1)  
        image_filename = root.find("filename").text + ".JPEG"
        image_path = os.path.join(self.image_dir, image_filename)

        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        return image, class_idx

num_workers = min(4, os.cpu_count())
# 5. Initialize the dataset and dataloader
imagenet_val_dataset = ImageNetValDataset(IMAGE_DIR, ANNOTATION_DIR, transform=imagenet_transform)
imagenet_val_loader = DataLoader(imagenet_val_dataset, batch_size=32, shuffle=False, num_workers=num_workers)

# 6. Define the model (EfficientNet-B0 with pre-trained weights)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.efficientnet_b0(weights='DEFAULT')  
model.to(device)
model.eval() 



EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActivat

In [18]:
# Hook function
def hook_fn(module, inputs, output):
    layer_name = f"{module.__class__.__name__} ({id(module)})"
    activation_values[layer_name] = activation_values.get(layer_name, 0) + output.detach().cpu().sum(dim=0)

for name, layer in model.named_modules():
    layer.register_forward_hook(hook_fn)  # Register hook for all layers


In [16]:
for name, layer in model.named_modules():
    if hasattr(layer, "_forward_hooks"):
        for hook_id in list(layer._forward_hooks.keys()):  # Convert to list to avoid runtime changes
            layer._forward_hooks.pop(hook_id)


In [19]:
for name, layer in model.named_modules():
    if layer._forward_hooks:
        print(f"Layer: {name}, Hooks: {layer._forward_hooks}")


Layer: , Hooks: OrderedDict({1011: <function hook_fn at 0x7e8e3120c860>})
Layer: features, Hooks: OrderedDict({1012: <function hook_fn at 0x7e8e3120c860>})
Layer: features.0, Hooks: OrderedDict({1013: <function hook_fn at 0x7e8e3120c860>})
Layer: features.0.0, Hooks: OrderedDict({1014: <function hook_fn at 0x7e8e3120c860>})
Layer: features.0.1, Hooks: OrderedDict({1015: <function hook_fn at 0x7e8e3120c860>})
Layer: features.0.2, Hooks: OrderedDict({1016: <function hook_fn at 0x7e8e3120c860>})
Layer: features.1, Hooks: OrderedDict({1017: <function hook_fn at 0x7e8e3120c860>})
Layer: features.1.0, Hooks: OrderedDict({1018: <function hook_fn at 0x7e8e3120c860>})
Layer: features.1.0.block, Hooks: OrderedDict({1019: <function hook_fn at 0x7e8e3120c860>})
Layer: features.1.0.block.0, Hooks: OrderedDict({1020: <function hook_fn at 0x7e8e3120c860>})
Layer: features.1.0.block.0.0, Hooks: OrderedDict({1021: <function hook_fn at 0x7e8e3120c860>})
Layer: features.1.0.block.0.1, Hooks: OrderedDict(

In [21]:
activation_values = {}  # Stores activation values

# 10. Define the evaluation function
def evaluate_model(model, dataloader):
    correct = 0
    total = 0

    with torch.no_grad():  
        for images, labels in tqdm(dataloader):
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)  

            _, predicted = torch.max(outputs, 1)  
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    accuracy = (correct / total) * 100
    return accuracy

# 11. Evaluate the model on ImageNet validation set
accuracy = evaluate_model(model, imagenet_val_loader)
print(f"EfficientNet-B0 Top-1 Accuracy on ImageNet: {accuracy:.2f}%")

100%|██████████| 1563/1563 [20:31<00:00,  1.27it/s]

EfficientNet-B0 Top-1 Accuracy on ImageNet: 77.67%





In [44]:
num_val_images = len(imagenet_val_dataset)

averaged_activations = {
    k: v / num_val_images for k, v in activation_values.items() if isinstance(v, torch.Tensor)
}
averaged_activations


{'Conv2d (139149048122576)': tensor([[[-4.8224e-01, -5.2361e-01, -5.2393e-01,  ..., -5.4800e-01,
           -5.4083e-01, -5.3303e-01],
          [-3.8411e-03, -1.7018e-03,  4.9430e-03,  ..., -3.9553e-03,
           -1.7998e-03,  4.0263e-03],
          [ 1.3811e-03,  8.9226e-03,  4.1932e-03,  ...,  3.3174e-03,
            4.6704e-03,  3.9172e-03],
          ...,
          [ 5.0961e-03, -5.5607e-03, -3.4501e-03,  ..., -5.4917e-03,
           -1.6579e-02, -6.0517e-03],
          [ 3.3903e-03,  4.9833e-03, -6.4695e-04,  ...,  8.9518e-03,
            1.0445e-02,  5.8544e-03],
          [ 3.8012e-03,  5.1938e-03,  8.5869e-03,  ...,  2.3938e-03,
            3.8824e-03,  7.8598e-03]],
 
         [[-6.1249e-01, -2.5841e-02, -1.0341e-02,  ..., -5.8507e-04,
           -1.4132e-02, -1.0778e-02],
          [-6.5535e-01, -1.4484e-02,  2.8869e-03,  ...,  4.4892e-03,
           -1.7839e-03,  6.9549e-03],
          [-6.6537e-01, -3.2544e-03, -1.0213e-03,  ...,  7.4690e-03,
            6.2123e-03, -1.30

In [51]:
batch_sum = torch.zeros(3, 224, 224).to(device)  # Ensure it's on the same device
num_images = 0  # Track total images processed

def evaluate_model(model, dataloader):
    global batch_sum, num_images  # Allow modification of global variables

    with torch.no_grad():  
        for images, labels in tqdm(dataloader):
            images = images.to(device)  # Move to same device as batch_sum
            batch_sum += torch.sum(images, dim=0)  # Accumulate sum across batches
            num_images += images.shape[0]  # Update total image count

evaluate_model(model, imagenet_val_loader)

# Compute the average pixel value across all images
print(num_images)
inmean = batch_sum / num_images  # Element-wise division
print(inmean.shape)
print(inmean)  # Should be [3, 224, 224]


100%|██████████| 1563/1563 [00:57<00:00, 27.05it/s]

50000
torch.Size([3, 224, 224])
tensor([[[-0.0044, -0.0020,  0.0010,  ...,  0.0015,  0.0019,  0.0014],
         [-0.0056, -0.0030,  0.0008,  ...,  0.0034,  0.0031,  0.0030],
         [-0.0037, -0.0010,  0.0021,  ...,  0.0041,  0.0027,  0.0026],
         ...,
         [-0.1084, -0.1092, -0.1071,  ..., -0.1085, -0.1085, -0.1068],
         [-0.1091, -0.1097, -0.1085,  ..., -0.1095, -0.1100, -0.1085],
         [-0.1108, -0.1121, -0.1115,  ..., -0.1098, -0.1102, -0.1088]],

        [[ 0.1119,  0.1139,  0.1159,  ...,  0.1180,  0.1182,  0.1177],
         [ 0.1101,  0.1122,  0.1150,  ...,  0.1188,  0.1185,  0.1184],
         [ 0.1104,  0.1124,  0.1150,  ...,  0.1182,  0.1169,  0.1175],
         ...,
         [-0.0661, -0.0669, -0.0657,  ..., -0.0679, -0.0671, -0.0639],
         [-0.0658, -0.0667, -0.0667,  ..., -0.0683, -0.0681, -0.0656],
         [-0.0670, -0.0687, -0.0691,  ..., -0.0671, -0.0671, -0.0658]],

        [[ 0.1628,  0.1644,  0.1655,  ...,  0.1661,  0.1673,  0.1673],
         [ 0.




In [56]:
if inmean.dim() == 3:
    input_image = inmean.unsqueeze(0)
input_image = input_image.to(device)

first_conv = None
for layer in model.modules():
    if isinstance(layer, torch.nn.Conv2d):
        first_conv = layer
        break

with torch.no_grad():
    output = first_conv(input_image)

# Store it in a new dictionary
weighted_activations = {"image": output.cpu()}  # move to CPU if you like

# Optional: Check shape
print(f"First conv output shape: {output.shape}")


First conv output shape: torch.Size([1, 32, 112, 112])


In [70]:
module_lookup = {f"{module.__class__.__name__} ({id(module)})": module for _, module in model.named_modules()}

keys = list(activation_values.keys())
weighted_activations = {}

# We keep track of processed layers by their names
processed_layers = set()

for i in range(len(keys) - 1):
    current_key = keys[i]
    next_key = keys[i + 1]

    current_activation = activation_values[current_key].unsqueeze(0)  # Add batch dim
    next_layer = module_lookup.get(next_key, None)

    if next_layer is None:
        continue  # Skip if the next layer is not found
    
    # Check if next layer is part of a composite block (e.g., Conv2dNormActivation)
    if isinstance(next_layer, nn.Sequential):
        # Skip all submodules if the first module in the block has already been processed
        if next_key in processed_layers:
            continue
        
        # Process the first submodule (Conv2d, BatchNorm2d, etc.)
        first_submodule = next_layer[0]  # This could be Conv2d or any other first layer
        processed_layers.add(next_key)  # Mark this block as processed
        
        print(f"Processing first submodule: {first_submodule}")
        with torch.no_grad():
            weighted_activations[current_key] = first_submodule(current_activation)
    else:
        # If it's not a Sequential, process the layer normally
        if next_key not in processed_layers:
            print(f"Processing layer: {next_layer}")
            with torch.no_grad():
                weighted_activations[current_key] = next_layer(current_activation)
            processed_layers.add(next_key)

# Print the results of weighted activations
for k, v in weighted_activations.items():
    print(f"{k}: {v.shape}")






Processing layer: BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Processing layer: SiLU(inplace=True)
Processing first submodule: Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)


RuntimeError: Given groups=1, weight of size [32, 3, 3, 3], expected input[1, 32, 112, 112] to have 3 channels, but got 32 channels instead

In [41]:
for k, v in activation_values.items():
    print(f"{k}: {v.shape}")


Conv2d (139149048122576): torch.Size([32, 112, 112])
BatchNorm2d (139149056681616): torch.Size([32, 112, 112])
SiLU (139149047962016): torch.Size([32, 112, 112])
Conv2dNormActivation (139149056660240): torch.Size([32, 112, 112])
Conv2d (139149048122896): torch.Size([32, 112, 112])
BatchNorm2d (139149056682160): torch.Size([32, 112, 112])
SiLU (139149047968016): torch.Size([32, 112, 112])
Conv2dNormActivation (139149047966336): torch.Size([32, 112, 112])
AdaptiveAvgPool2d (139149047968736): torch.Size([32, 1, 1])
Conv2d (139149048123216): torch.Size([8, 1, 1])
SiLU (139149047968976): torch.Size([8, 1, 1])
Conv2d (139149048123536): torch.Size([32, 1, 1])
Sigmoid (139149048367152): torch.Size([32, 1, 1])
SqueezeExcitation (139149048366480): torch.Size([32, 112, 112])
Conv2d (139149048123856): torch.Size([16, 112, 112])
BatchNorm2d (139149056682432): torch.Size([16, 112, 112])
Conv2dNormActivation (139149047968256): torch.Size([16, 112, 112])
Sequential (139149048369616): torch.Size([16, 1

In [49]:
for name, param in model.named_parameters():
    layer = dict(model.named_modules()).get(name.rsplit('.', 1)[0], None)
    print(f"{name} {layer} {param.shape}")


features.0.0.weight Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) torch.Size([32, 3, 3, 3])
features.0.1.weight BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) torch.Size([32])
features.0.1.bias BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) torch.Size([32])
features.1.0.block.0.0.weight Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False) torch.Size([32, 1, 3, 3])
features.1.0.block.0.1.weight BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) torch.Size([32])
features.1.0.block.0.1.bias BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) torch.Size([32])
features.1.0.block.1.fc1.weight Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1)) torch.Size([8, 32, 1, 1])
features.1.0.block.1.fc1.bias Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1)) torch.Size([8])
features.1.0.block.1.fc2.weight Conv2d(8, 32, 