## Similarity function comparison

In [1]:
import os
# Virtually move to the parent directory so that project-local modules (utils,
# similarity) and relative paths such as data/ and saved_activations resolve.
# The starting directory is remembered on first run, so re-running this cell
# always lands in the same place instead of climbing one more level each time.
_NOTEBOOK_DIR = globals().get("_NOTEBOOK_DIR", os.getcwd())
os.chdir(os.path.join(_NOTEBOOK_DIR, os.pardir))

import torch
from sentence_transformers import SentenceTransformer
from sklearn import metrics

import clip
import utils
import similarity

## Settings

In [2]:
# Names of functions in the `similarity` module to compare (resolved by name later).
similarity_fns = ["cos_similarity", "rank_reorder", "wpmi", "soft_wpmi"]
# Probing datasets passed as `d_probe` to utils.save_activations.
d_probes = ['cifar100_train', 'broden', 'imagenet_val', 'imagenet_broden']

clip_name = 'ViT-B/16'          # CLIP backbone passed to clip.load
target_name = 'resnet50'        # network whose layer is being described
target_layer = 'fc'             # layer of the target network to probe
batch_size = 200
# Use the GPU when present; fall back to CPU so the notebook still runs (slowly) without one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pool_mode = 'avg'               # pooling mode forwarded to utils (presumably spatial pooling)
save_dir = 'saved_activations'  # directory for cached activations

In [3]:
# ImageNet class names, one per line; index i is presumably class id i.
# Note: split("\n") keeps a trailing '' entry if the file ends with a newline.
with open("data/imagenet_labels.txt", "r") as labels_file:
    cls_id_to_name = labels_file.read().split("\n")

# Sentence-embedding model behind the "mpnet similarity" scores reported below.
model = SentenceTransformer('all-mpnet-base-v2')
# CLIP model; the second return value (the image preprocessing transform in the
# OpenAI CLIP API) is not needed in this notebook.
clip_model, _preprocess = clip.load(clip_name, device=device)

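## Cosine similarity to ground-truth class names

For each similarity function and probing dataset, each row of the target layer's similarity matrix (presumably one per `fc` neuron) is labelled with its highest-scoring word from `data/20k.txt`. The labels are then compared with the ImageNet class names using both CLIP and all-mpnet-base-v2 text embeddings; higher is better.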

In [4]:
# Label each row of the similarity matrix with its best-matching word from the 20k
# concept set, then score those labels against the ImageNet class names.
concept_set = 'data/20k.txt'

with open(concept_set, 'r') as f:
    # NOTE(review): split('\n') yields a trailing '' if the file ends with a newline.
    # Kept as-is so indices stay aligned with how utils reads the concept set — confirm.
    words = f.read().split('\n')

# utils.save_activations does not take the similarity function, so compute/cache the
# activations once per probing dataset instead of once per (similarity_fn, d_probe).
for d_probe in d_probes:
    utils.save_activations(clip_name = clip_name, target_name = target_name, target_layers = [target_layer], 
                           d_probe = d_probe, concept_set = concept_set, batch_size = batch_size, 
                           device = device, pool_mode=pool_mode, save_dir = save_dir)

for similarity_fn in similarity_fns:
    # Resolve the function by name; getattr is safer and clearer than eval on a built string.
    similarity_func = getattr(similarity, similarity_fn)
    for d_probe in d_probes:
        target_save_name, clip_save_name, text_save_name = utils.get_save_names(
            clip_name = clip_name, target_name = target_name,
            target_layer = target_layer, d_probe = d_probe,
            concept_set = concept_set, pool_mode=pool_mode,
            save_dir = save_dir)

        similarities, target_feats = utils.get_similarity_from_activations(
            target_save_name, clip_save_name, text_save_name, similarity_func, device=device)

        # Highest-scoring concept (column) for each row, mapped back to its word.
        pred_ids = torch.argmax(similarities, dim=1)
        clip_preds = [words[int(pred)] for pred in pred_ids]

        clip_cos, mpnet_cos = utils.get_cos_similarity(clip_preds, cls_id_to_name, clip_model, model, device, batch_size)
        print("Similarity fn: {}, D_probe: {}".format(similarity_fn, d_probe))
        print("Clip similarity: {:.4f}, mpnet similarity: {:.4f}".format(clip_cos, mpnet_cos))

Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 1/1 [00:01<00:00,  1.65s/it]


Similarity fn: cos_similarity, D_probe: cifar100_train
Clip similarity: 0.6484, mpnet similarity: 0.2756


100%|██████████| 1/1 [00:01<00:00,  1.81s/it]


Similarity fn: cos_similarity, D_probe: broden
Clip similarity: 0.6235, mpnet similarity: 0.2153


100%|██████████| 1/1 [00:01<00:00,  1.48s/it]


Similarity fn: cos_similarity, D_probe: imagenet_val
Clip similarity: 0.6216, mpnet similarity: 0.2829


100%|██████████| 1/1 [00:03<00:00,  3.17s/it]


Similarity fn: cos_similarity, D_probe: imagenet_broden
Clip similarity: 0.6421, mpnet similarity: 0.2587
Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 1000/1000 [02:49<00:00,  5.91it/s]


Similarity fn: rank_reorder, D_probe: cifar100_train
Clip similarity: 0.7227, mpnet similarity: 0.3247


100%|██████████| 1000/1000 [03:39<00:00,  4.55it/s]


Similarity fn: rank_reorder, D_probe: broden
Clip similarity: 0.7471, mpnet similarity: 0.3856


100%|██████████| 1000/1000 [02:46<00:00,  6.00it/s]


Similarity fn: rank_reorder, D_probe: imagenet_val
Clip similarity: 0.7832, mpnet similarity: 0.4911


100%|██████████| 1000/1000 [06:54<00:00,  2.41it/s]


Similarity fn: rank_reorder, D_probe: imagenet_broden
Clip similarity: 0.7866, mpnet similarity: 0.5035
Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 1000/1000 [00:01<00:00, 622.31it/s]


Similarity fn: wpmi, D_probe: cifar100_train
Clip similarity: 0.7192, mpnet similarity: 0.3457


100%|██████████| 1000/1000 [00:01<00:00, 597.84it/s]


Similarity fn: wpmi, D_probe: broden
Clip similarity: 0.7427, mpnet similarity: 0.3886


100%|██████████| 1000/1000 [00:01<00:00, 553.30it/s]


Similarity fn: wpmi, D_probe: imagenet_val
Clip similarity: 0.7944, mpnet similarity: 0.5301


100%|██████████| 1000/1000 [00:01<00:00, 553.67it/s]


Similarity fn: wpmi, D_probe: imagenet_broden
Clip similarity: 0.7930, mpnet similarity: 0.5266
Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 1000/1000 [00:05<00:00, 185.16it/s]


torch.Size([1000, 20000])
Similarity fn: soft_wpmi, D_probe: cifar100_train
Clip similarity: 0.7300, mpnet similarity: 0.3671


100%|██████████| 1000/1000 [00:04<00:00, 203.97it/s]


torch.Size([1000, 20000])
Similarity fn: soft_wpmi, D_probe: broden
Clip similarity: 0.7412, mpnet similarity: 0.3946


100%|██████████| 1000/1000 [00:04<00:00, 209.43it/s]


torch.Size([1000, 20000])
Similarity fn: soft_wpmi, D_probe: imagenet_val
Clip similarity: 0.7900, mpnet similarity: 0.5262


100%|██████████| 1000/1000 [00:04<00:00, 209.87it/s]


torch.Size([1000, 20000])
Similarity fn: soft_wpmi, D_probe: imagenet_broden
Clip similarity: 0.7900, mpnet similarity: 0.5239


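## Accuracies

Here the 1000 ImageNet class names are the concept set, so row *i* of the similarity matrix should rank class name *i* highest. For every similarity function and probing dataset we report:

- top-1 and top-5 accuracy;
- the mean and median rank of the correct name;
- the AUC of the maximum similarity as a predictor of a correct top-1 match.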

In [5]:
def get_topk_acc(sim, k=5):
    """Top-k accuracy (in percent) of recovering each row's own index.

    Row i counts as correct when column i is among its k largest entries. In this
    notebook, rows are presumably the target-layer neurons and columns the ImageNet
    class names, so neuron i should match class i.

    Args:
        sim: 2-D tensor of shape (num_rows, num_cols) with num_cols >= k.
        k: number of top-scoring columns to consider per row.

    Returns:
        Percentage of correct rows as a float in [0, 100].

    Raises:
        ValueError: if ``sim`` has no rows.
    """
    num_rows = sim.shape[0]
    if num_rows == 0:
        raise ValueError("sim must have at least one row")
    topk_ids = torch.topk(sim, k=k, dim=1).indices
    targets = torch.arange(num_rows, device=sim.device).unsqueeze(1)
    # topk indices within a row are distinct, so a row matches at most once.
    correct = int((topk_ids == targets).any(dim=1).sum())
    return (correct / num_rows) * 100

def get_correct_rank_mean_median(sim):
    """Mean and median 1-based rank of column i within row i of ``sim``.

    Each row is sorted by descending similarity; the rank of row i is the position
    of column i in that order (1 = ranked first).

    Args:
        sim: 2-D tensor of shape (num_rows, num_cols) with num_cols >= num_rows.

    Returns:
        (mean, median): the mean rank as a float and the upper median rank as an int
        (for an even number of rows, the larger of the two middle values).

    Raises:
        ValueError: if ``sim`` has no rows or fewer columns than rows.
    """
    num_rows, num_cols = sim.shape[0], sim.shape[1]
    if num_rows == 0:
        raise ValueError("sim must have at least one row")
    if num_cols < num_rows:
        raise ValueError("sim needs at least as many columns as rows")
    order = torch.sort(sim, dim=1, descending=True).indices
    targets = torch.arange(num_rows, device=sim.device).unsqueeze(1)
    # Exactly one match per row; nonzero() lists them in row order, column 1 is the
    # 0-based position in the sorted row.
    positions = (order == targets).nonzero()[:, 1]
    ranks = (positions + 1).tolist()

    mean = sum(ranks) / len(ranks)
    median = sorted(ranks)[len(ranks) // 2]
    return mean, median

def get_auc(sim):
    """ROC AUC of each row's max similarity as a score for a correct top-1 match.

    Row i's top-1 prediction is correct when its argmax column is i. The AUC measures
    how well the max similarity value separates correct rows from incorrect ones.

    Args:
        sim: 2-D tensor of shape (num_rows, num_cols).

    Returns:
        The AUC, or NaN when every row is correct or every row is wrong (AUC is
        undefined with a single class, and roc_auc_score would raise).
    """
    max_sim, preds = torch.max(sim.detach().cpu(), dim=1)
    gtruth = torch.arange(sim.shape[0])
    correct = (preds == gtruth).numpy()
    if correct.all() or not correct.any():
        return float('nan')
    return metrics.roc_auc_score(correct, max_sim.numpy())

In [9]:
# Use the ImageNet class names themselves as the concept set, so row i of the
# similarity matrix should rank column i highest.
concept_set = 'data/imagenet_labels.txt'
# Not used in this cell; kept so the global `words` reflects the current concept set.
with open(concept_set, 'r') as f: 
    words = (f.read()).split('\n')

# utils.save_activations does not take the similarity function, so compute/cache the
# activations once per probing dataset instead of once per (similarity_fn, d_probe).
for d_probe in d_probes:
    utils.save_activations(clip_name = clip_name, target_name = target_name, target_layers = [target_layer], 
                           d_probe = d_probe, concept_set = concept_set, batch_size = batch_size, 
                           device = device, pool_mode=pool_mode, save_dir = save_dir)

for similarity_fn in similarity_fns:
    # Resolve the function by name; getattr is safer and clearer than eval on a built string.
    similarity_func = getattr(similarity, similarity_fn)
    for d_probe in d_probes:
        target_save_name, clip_save_name, text_save_name = utils.get_save_names(
            clip_name = clip_name, target_name = target_name,
            target_layer = target_layer, d_probe = d_probe,
            concept_set = concept_set, pool_mode=pool_mode,
            save_dir = save_dir)

        similarities, target_feats = utils.get_similarity_from_activations(
            target_save_name, clip_save_name, text_save_name, similarity_func, device=device)

        print("Similarity fn: {}, D_probe: {}".format(similarity_fn, d_probe))
        print("Top 1 acc: {:.2f}%, Top 5 acc: {:.2f}%".format(get_topk_acc(similarities, k=1),
                                                         get_topk_acc(similarities, k=5)))

        mean, median = get_correct_rank_mean_median(similarities)
        print("Mean rank of correct class: {:.2f}, Median rank of correct class: {}".format(mean, median))
        print("AUC: {:.4f}".format(get_auc(similarities)))



Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 1/1 [00:00<00:00,  9.83it/s]


Similarity fn: cos_similarity, D_probe: cifar100_train
Top 1 acc: 8.60%, Top 5 acc: 25.10%
Mean rank of correct class: 53.94, Median rank of correct class: 21
AUC: 0.5926


100%|██████████| 1/1 [00:00<00:00,  8.83it/s]


Similarity fn: cos_similarity, D_probe: broden
Top 1 acc: 5.70%, Top 5 acc: 21.30%
Mean rank of correct class: 63.92, Median rank of correct class: 24
AUC: 0.5710


100%|██████████| 1/1 [00:00<00:00, 11.84it/s]


Similarity fn: cos_similarity, D_probe: imagenet_val
Top 1 acc: 15.90%, Top 5 acc: 43.80%
Mean rank of correct class: 22.56, Median rank of correct class: 7
AUC: 0.4849


100%|██████████| 1/1 [00:00<00:00,  5.21it/s]


Similarity fn: cos_similarity, D_probe: imagenet_broden
Top 1 acc: 11.30%, Top 5 acc: 34.60%
Mean rank of correct class: 32.64, Median rank of correct class: 11
AUC: 0.5003
Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 1000/1000 [00:12<00:00, 81.02it/s]


Similarity fn: rank_reorder, D_probe: cifar100_train
Top 1 acc: 36.60%, Top 5 acc: 67.50%
Mean rank of correct class: 13.63, Median rank of correct class: 3
AUC: 0.6338


100%|██████████| 1000/1000 [00:10<00:00, 93.23it/s]


Similarity fn: rank_reorder, D_probe: broden
Top 1 acc: 57.70%, Top 5 acc: 83.70%
Mean rank of correct class: 6.69, Median rank of correct class: 1
AUC: 0.6853


100%|██████████| 1000/1000 [00:13<00:00, 75.74it/s]


Similarity fn: rank_reorder, D_probe: imagenet_val
Top 1 acc: 89.80%, Top 5 acc: 98.60%
Mean rank of correct class: 2.28, Median rank of correct class: 1
AUC: 0.6434


100%|██████████| 1000/1000 [00:14<00:00, 67.79it/s]


Similarity fn: rank_reorder, D_probe: imagenet_broden
Top 1 acc: 89.90%, Top 5 acc: 98.20%
Mean rank of correct class: 2.12, Median rank of correct class: 1
AUC: 0.5993
Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 1000/1000 [00:00<00:00, 7502.62it/s]


Similarity fn: wpmi, D_probe: cifar100_train
Top 1 acc: 24.00%, Top 5 acc: 55.00%
Mean rank of correct class: 20.46, Median rank of correct class: 4
AUC: 0.6355


100%|██████████| 1000/1000 [00:00<00:00, 6698.21it/s]


Similarity fn: wpmi, D_probe: broden
Top 1 acc: 47.10%, Top 5 acc: 79.40%
Mean rank of correct class: 7.58, Median rank of correct class: 2
AUC: 0.7118


100%|██████████| 1000/1000 [00:00<00:00, 6421.12it/s]


Similarity fn: wpmi, D_probe: imagenet_val
Top 1 acc: 86.90%, Top 5 acc: 98.10%
Mean rank of correct class: 2.00, Median rank of correct class: 1
AUC: 0.7176


100%|██████████| 1000/1000 [00:00<00:00, 6964.22it/s]


Similarity fn: wpmi, D_probe: imagenet_broden
Top 1 acc: 86.90%, Top 5 acc: 98.10%
Mean rank of correct class: 1.99, Median rank of correct class: 1
AUC: 0.7270
Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 1000/1000 [00:00<00:00, 1393.59it/s]


torch.Size([1000, 1000])
Similarity fn: soft_wpmi, D_probe: cifar100_train
Top 1 acc: 46.30%, Top 5 acc: 79.40%
Mean rank of correct class: 8.61, Median rank of correct class: 2
AUC: 0.6673


100%|██████████| 1000/1000 [00:00<00:00, 1180.32it/s]


torch.Size([1000, 1000])
Similarity fn: soft_wpmi, D_probe: broden
Top 1 acc: 70.70%, Top 5 acc: 90.00%
Mean rank of correct class: 4.80, Median rank of correct class: 1
AUC: 0.7856


100%|██████████| 1000/1000 [00:00<00:00, 1344.09it/s]


torch.Size([1000, 1000])
Similarity fn: soft_wpmi, D_probe: imagenet_val
Top 1 acc: 95.50%, Top 5 acc: 98.90%
Mean rank of correct class: 1.18, Median rank of correct class: 1
AUC: 0.9208


100%|██████████| 1000/1000 [00:00<00:00, 1253.33it/s]


torch.Size([1000, 1000])
Similarity fn: soft_wpmi, D_probe: imagenet_broden
Top 1 acc: 95.40%, Top 5 acc: 99.00%
Mean rank of correct class: 1.19, Median rank of correct class: 1
AUC: 0.9166
