# Enhanced Zero-shot Learning with Neuron-concepts

In [1]:
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image

from data_utils import get_model

In [2]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

model_name = "resnet50"
model, preprocess = get_model(model_name, device)

## Register Hooks

In [3]:
# activation with pooling mode
def get_activation(outputs, mode):
    '''
    mode: how to pool activations: one of avg, max
    for fc or ViT neurons does no pooling
    '''
    if mode=='avg':
        def hook(model, input, output):
            if len(output.shape)==4: #CNN layers
                outputs.append(output.mean(dim=[2,3]).detach())
            elif len(output.shape)==3: #ViT
                outputs.append(output[:, 0].clone())
            elif len(output.shape)==2: #FC layers
                outputs.append(output.detach())
    elif mode=='max':
        def hook(model, input, output):
            if len(output.shape)==4: #CNN layers
                outputs.append(output.amax(dim=[2,3]).detach())
            elif len(output.shape)==3: #ViT
                outputs.append(output[:, 0].clone())
            elif len(output.shape)==2: #FC layers
                outputs.append(output.detach())
    return hook

In [4]:
def per_img_activations(target_model, target_name, target_layers, dataset, batch_size, device='cuda', pool_mode='avg'):

    all_features = {layer: [] for layer in target_layers}
    hooks = {}

    # register forward hook
    for layer in target_layers:
        module = dict(target_model.named_modules()).get(layer)
        if module:
            hooks[layer] = module.register_forward_hook(get_activation(all_features[layer], pool_mode))
            print(f"Hook registered for layer: {layer}")
        else:
            print(f"Warning: Layer '{layer}' does not exist in the model.")

    # Forward pass
    with torch.no_grad():
        for images in tqdm(DataLoader(dataset, batch_size, num_workers=8, pin_memory=True)):
            if "cvcl" in target_name:
                _ = target_model.encode_image(images.to(device))
            else:
                _ = target_model(images.to(device))

    # Remove 
    for layer in target_layers:
        hooks[layer].remove()

    torch.cuda.empty_cache()
    print("Activations saved and memory cleaned up.")

    # dict: {layername: activation tensor[batch_size, n_neurons]}
    activations = {layer: torch.stack(all_features[layer]) for layer in target_layers}

    return activations

Load Image

In [5]:
img = preprocess(Image.open('data/toy_example_dataset_konka/muffin/img_9246-muffin.jpg'))
data = torch.unsqueeze(img, 0)

Get activation


In [6]:
activations = per_img_activations(target_model=model, target_name=model_name, target_layers=target_layers, dataset=data, batch_size=200, device=device)

print(activations['layer4'].shape,
      activations['fc'].shape)

Hook registered for layer: layer4
Hook registered for layer: fc


100%|██████████| 1/1 [00:13<00:00, 13.13s/it]


Activations saved and memory cleaned up.
torch.Size([1, 1, 2048]) torch.Size([1, 1, 1000])


Find Top-k activationa and index

In [7]:
top_k_values, top_k_indices = torch.topk(activations['fc'], k=10, dim=2) 

# flat tensor into list
top_k_indices = top_k_indices[0].tolist()[0]
top_k_values = top_k_values[0].tolist()[0]

top_k_df = pd.DataFrame({
    'unit': top_k_indices,
    'value': top_k_values
})

Find Zero-shot Description

In [14]:
# Load descriptions generated from CLIP-dissect
descriptions = pd.read_csv('descriptions/res_konk_baby.csv')

# Filter the descriptions to match the top_k indices with the descriptions
filtered_descriptions = descriptions[descriptions.apply(lambda row: row['unit'] in top_k_indices and row['layer'] == 'layer4', axis=1)]

# Merge the filtered descriptions with the top_k DataFrame
matched_descriptions = pd.merge(filtered_descriptions, top_k_df, on='unit')

# Sort the matched descriptions by 'value' in descending order
matched_descriptions = matched_descriptions.sort_values(by='value', ascending=False)

# Print the final DataFrame
print(matched_descriptions)

    layer  unit description  similarity      value
6  layer4   930         pen    1.301636  10.217084
8  layer4   962       phone    1.341309   9.900900
5  layer4   911         car    0.707062   9.468554
7  layer4   960    stroller    0.829865   9.288763
0  layer4   415   microwave    0.744781   9.257796
1  layer4   551    computer    0.882507   8.907595
9  layer4   969        desk    3.426575   8.473521
2  layer4   588       swing    2.122437   8.260846
4  layer4   824    backpack    0.912323   8.117612
3  layer4   738    backpack    0.702332   7.558163
