## Predicting the input class from descriptions of highly activating images

In [1]:
import os
# Work from the parent of the notebook's directory so that the local modules
# imported below (utils, data_utils, similarity) and the relative 'data/...'
# paths resolve. The starting directory is remembered on the first run, so
# re-running this cell does not keep climbing one level per execution.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

import math
import torch
import pandas as pd

import matplotlib
from matplotlib import pyplot as plt
from sentence_transformers import SentenceTransformer

import clip
import utils
import data_utils
import similarity

## Settings

In [2]:
# Network whose neurons are described, and the layer being dissected.
target_name = 'resnet50'
target_layer = 'layer4'

# CLIP backbone, probing dataset and concept vocabulary (one concept per line).
clip_name = 'ViT-B/16'
d_probe = 'imagenet_broden'
concept_set = 'data/20k.txt'
batch_size = 200
# Fall back to CPU so the notebook still runs (slowly) without a GPU;
# on CUDA machines this is identical to the previous hard-coded 'cuda'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pool_mode = 'avg'

# Where cached activations are written / read.
save_dir = 'saved_activations'
# Function used to score every neuron against every concept.
similarity_fn = similarity.soft_wpmi

## Run CLIP-Dissect

In [3]:
# Record target-layer and CLIP activations for the probing set; the resulting
# files are located again in the next cell via utils.get_save_names with the
# same arguments.
utils.save_activations(
    clip_name=clip_name,
    target_name=target_name,
    target_layers=[target_layer],
    d_probe=d_probe,
    concept_set=concept_set,
    batch_size=batch_size,
    device=device,
    pool_mode=pool_mode,
    save_dir=save_dir,
)

In [4]:
# Paths of the cached target / CLIP-image / CLIP-text activations.
save_names = utils.get_save_names(
    clip_name=clip_name,
    target_name=target_name,
    target_layer=target_layer,
    d_probe=d_probe,
    concept_set=concept_set,
    pool_mode=pool_mode,
    save_dir=save_dir,
)
target_save_name, clip_save_name, text_save_name = save_names

# `similarities` has one row per neuron and one column per concept.
similarities, target_feats = utils.get_similarity_from_activations(
    target_save_name, clip_save_name, text_save_name, similarity_fn, device=device
)

# Concept vocabulary: column j of `similarities` corresponds to words[j].
with open(concept_set, 'r') as f:
    words = f.read().split('\n')

# Label each neuron with its highest-scoring concept.
vals, ids = similarities.max(dim=1)
descriptions = {"CLIP-Dissect": [words[idx] for idx in ids.tolist()]}

100%|██████████| 2048/2048 [00:14<00:00, 145.32it/s]


torch.Size([2048, 20000])


## Calculate standard accuracy

In [5]:
#only use imagenet val for this part
pil_data = data_utils.get_data('imagenet_val')
target_model, target_preprocess = data_utils.get_target_model(target_name, device)

save_names = utils.get_save_names(clip_name = clip_name, target_name = target_name,
                                  target_layer = target_layer, d_probe = 'imagenet_val',
                                  concept_set = concept_set, pool_mode=pool_mode,
                                  save_dir = save_dir)
target_save_name, clip_save_name, text_save_name = save_names

dataset = data_utils.get_data('imagenet_val', target_preprocess)
utils.save_target_activations(target_model, dataset, target_save_name, target_layers = [target_layer], batch_size = batch_size,
                              device = device, pool_mode = pool_mode)

In [6]:
# Cached target-layer activations for the validation split, kept on CPU and
# moved to `device` one batch at a time by the loops that follow.
activations = torch.load(target_save_name, map_location='cpu')
# Final linear layer of the target model: logits = act @ W_f.T + b_f.
W_f = target_model.fc.weight
b_f = target_model.fc.bias

targets = torch.LongTensor(pil_data.targets).to(device)
with open('data/imagenet_labels.txt', 'r') as f:
    classes = f.read().split('\n')
# A trailing newline in the label file would otherwise add an empty "class"
# that has no matching logit but could still be chosen as the closest class.
if classes and classes[-1] == '':
    classes.pop()

# The batched loops slice `activations` and `targets` with the same indices and
# compare class indices against logit positions, so these sizes must agree.
assert len(activations) == len(targets), \
    "activations ({}) and targets ({}) differ in length".format(len(activations), len(targets))
assert len(classes) == W_f.shape[0], \
    "label file has {} classes but the model has {} logits".format(len(classes), W_f.shape[0])

In [7]:
# Top-1 accuracy of the target model, recomputed from the cached activations
# through its final linear layer.
n_correct = 0
with torch.no_grad():
    for start in range(0, len(targets), batch_size):
        stop = start + batch_size
        batch_act = activations[start:stop].to(device)
        logits = batch_act @ W_f.T + b_f
        batch_pred = logits.max(dim=1).indices
        n_correct += (batch_pred == targets[start:stop]).sum()
print("Standard Accuracy:{:.2f}%".format(n_correct / len(targets) * 100))

Standard Accuracy:76.13%


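## Measure how often the most contributing neuron's description matches the class

For each validation image, take the neuron in `target_layer` that contributes most to the predicted class logit. Map that neuron's description to the closest ImageNet class name, using a weighted sum of CLIP and MPNet text similarities. Then count how often this class equals the ground-truth label and how often it equals the model's prediction.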

In [8]:
# Two text encoders used to compare neuron descriptions with class names.
mpnet_model = SentenceTransformer('all-mpnet-base-v2')
clip_model, _ = clip.load(clip_name, device=device)

# Unit-normalised CLIP text embeddings of every class name.
with torch.no_grad():
    class_clip = clip_model.encode_text(clip.tokenize(classes).to(device))
    class_clip /= class_clip.norm(dim=-1, keepdim=True)

# MPNet sentence embeddings of every class name, as a tensor on `device`.
class_mpnet = torch.tensor(mpnet_model.encode(classes)).to(device)
print(class_clip.shape, class_mpnet.shape)

torch.Size([1000, 512]) torch.Size([1000, 768])


In [9]:
# Precomputed neuron descriptions from the baseline methods.
name_conversion = {'resnet50':'resnet50_imagenet', 'resnet18_places':'resnet18_places365'}
# One CLIP-Dissect description exists per neuron of `target_layer`.
n_neurons = len(descriptions["CLIP-Dissect"])

netdissect_res = pd.read_csv('data/NetDissect_results/{}_{}.csv'.format(name_conversion[target_name],
                                                                       target_layer))
# Descriptions are later indexed by neuron position, so order rows by unit
# (exactly as done for MILAN below) whenever the file records the unit.
if 'unit' in netdissect_res.columns:
    netdissect_res = netdissect_res.sort_values(by=['unit'])
descriptions["Network Dissection"] = list(netdissect_res['label'])

milan_base = pd.read_csv('data/MILAN_results/m_base_{}.csv'.format(name_conversion[target_name]))
milan_base = milan_base[milan_base['layer']==target_layer]
milan_base = milan_base.sort_values(by=['unit'])
descriptions["MILAN base"] = list(milan_base['description'])

# A length mismatch would silently misalign descriptions with neurons.
for method in ("Network Dissection", "MILAN base"):
    assert len(descriptions[method]) == n_neurons, \
        "{} has {} descriptions, expected one per neuron ({})".format(
            method, len(descriptions[method]), n_neurons)

In [10]:
# Weight of the CLIP text similarity relative to the MPNet similarity when a
# neuron description is mapped to its closest class name.
CLIP_WEIGHT = 3

for key in descriptions:
    print(key)
    with torch.no_grad():
        # truncate=True: clip.tokenize otherwise raises for any description
        # longer than CLIP's context length (e.g. long sentence-level captions).
        tokens = clip.tokenize(descriptions[key], truncate=True).to(device)
        desc_clip = clip_model.encode_text(tokens)
        desc_clip /= desc_clip.norm(dim=-1, keepdim=True)

    desc_mpnet = mpnet_model.encode(descriptions[key])
    desc_mpnet = torch.tensor(desc_mpnet).to(device)

    correct_gt = 0
    correct_pred = 0

    with torch.no_grad():
        for i in range(math.ceil(len(targets)/batch_size)):
            targ = targets[i*batch_size:(i+1)*batch_size]
            act = activations[i*batch_size:(i+1)*batch_size].to(device)

            out = act@W_f.T + b_f
            pred = torch.max(out, dim=1)[1]

            # Per-neuron term of the predicted-class logit (bias excluded);
            # keep the neuron with the largest contribution for each image.
            contrib = W_f[pred]*act
            max_contrib = torch.max(contrib, dim=1)[1]

            # Map that neuron's description to the most similar class name.
            clip_cos = desc_clip[max_contrib]@class_clip.T
            mpnet_cos = desc_mpnet[max_contrib]@class_mpnet.T

            cos = CLIP_WEIGHT*clip_cos + mpnet_cos
            most_sim = torch.max(cos, dim=1)[1]

            correct_gt += torch.sum(most_sim==targ)
            correct_pred += torch.sum(most_sim==pred)

    print("Same as gt:{:.2f}%".format(100*correct_gt/len(targets)))
    print("Same as pred:{:.2f}% \n".format(100*correct_pred/len(targets)))

CLIP-Dissect
Same as gt:9.87%
Same as pred:11.83% 

Network Dissection
Same as gt:3.04%
Same as pred:3.68% 

MILAN base
Same as gt:2.30%
Same as pred:2.63% 

