In [1]:
import os
import json
# Move the working directory up one level; the relative paths used later in
# this notebook (e.g. 'data/concept_sets/...') are resolved from there.
# Presumably this is also what lets the project-local imports below
# (data_utils, utils) resolve -- TODO confirm.
# NOTE: not idempotent -- re-running this cell moves up one more directory.
os.chdir("..")
import torch

import open_clip
import data_utils
import utils

In [2]:
# ---- Run configuration ----
device = "cuda"
dataset_name = "imagenet_val"
batch_size = 128
save_dir = 'saved_activations'
# Alternative model name that was tried: "ViT-SO400M-14-SigLIP-384"
clip_name = "ViT-L-16-SigLIP-384"

# open_clip receives whatever follows the last "_" in clip_name (the whole
# name when it contains no underscore); the pretrained tag is looked up in
# utils.CN_TO_CHECKPOINT under the full clip_name.
clip_arch = clip_name.split("_")[-1]
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
    clip_arch,
    pretrained=utils.CN_TO_CHECKPOINT[clip_name],
    device=device,
)
tokenizer = open_clip.get_tokenizer(clip_arch)

# Dataset wrapped with the model's own preprocessing transform.
clip_data = data_utils.get_data(dataset_name, clip_preprocess)

In [3]:
# Read one class name per line. splitlines() (unlike split('\n')) does not
# produce a trailing empty entry when the file ends with a newline; such an
# entry would add an extra text embedding and break the alignment between
# class names and label columns later on.
with open('data/concept_sets/imagenet_labels_clean.txt', 'r') as f:
    class_names = f.read().splitlines()

# Save the image features to disk via the project helper, then load them back
# onto `device` as float32.
clip_save_name = "{}/{}_{}.pt".format(save_dir, dataset_name, clip_name.replace('/', ''))
utils.save_clip_image_features(clip_model, clip_data, clip_save_name, batch_size, device)
clip_image_features = torch.load(clip_save_name, map_location=device).float()

In [4]:
# Superclasses with more than this many member classes are skipped as too
# generic (300 of the 1000 classes in the run recorded below, i.e. 30%).
MAX_SUBCLASSES = 300

num_classes = int(max(clip_data.targets)) + 1
num_samples = len(clip_data.targets)

# One-hot encode the base labels: one row per sample, one column per class.
one_hot_labels = torch.zeros((num_samples, num_classes))
one_hot_labels[torch.arange(num_samples), clip_data.targets] = 1
one_hot_labels = one_hot_labels.to(device)

with open('data/concept_sets/imagenet_superclass_to_ids.json', 'r') as f:
    superclass_to_id = json.load(f)

superclass_names = []
new_labels = []
for sclass, subclasses in superclass_to_id.items():
    if len(subclasses) > MAX_SUBCLASSES:
        # Report the superclasses that are dropped for being too broad.
        print(sclass, len(subclasses))
        continue
    superclass_names.append(sclass.replace("_", " "))
    # Superclass column = sum of the one-hot columns of its member classes.
    new_labels.append(one_hot_labels[:, subclasses].sum(dim=1))
new_labels = torch.stack(new_labels, dim=1)
print(one_hot_labels.shape, new_labels.shape)
one_hot_labels = torch.cat([one_hot_labels, new_labels], dim=1)

# Rebuild class_names from its first num_classes entries instead of appending
# in place, so re-running this cell does not keep growing the list and
# desynchronise it from the columns of one_hot_labels.
class_names = class_names[:num_classes] + superclass_names

entity 1000
physical entity 997
object 958
whole 949
organism 410
animal 398
vertebrate 337
artifact 522
instrumentality 358
torch.Size([50000, 1000]) torch.Size([50000, 391])


In [5]:
# Tokenise every name in class_names and embed it with the project helper.
class_tokens = tokenizer(class_names).to(device)
clip_text_features = utils.get_clip_text_features(clip_model, class_tokens).float()

with torch.no_grad():
    # Scale both feature sets to unit L2 norm along the last dim (in place),
    # so the product below is a cosine similarity.
    clip_image_features.div_(clip_image_features.norm(dim=-1, keepdim=True))
    clip_text_features.div_(clip_text_features.norm(dim=-1, keepdim=True))
    # Image-by-text similarity matrix.
    clip_feats = torch.matmul(clip_image_features, clip_text_features.T)

In [6]:
def get_loss(a=1, b=0):
    """Mean binary cross-entropy of the rescaled similarity scores.

    The prediction for every (sample, label) pair is
    ``sigmoid(a * (clip_feats + b))``; it is compared against
    ``one_hot_labels``. Both tensors are read from the notebook's globals.

    Args:
        a: Scale applied to the shifted similarities.
        b: Shift added to the raw similarities before scaling.

    Returns:
        A 0-dim CPU tensor holding the mean loss.
    """
    # BCEWithLogitsLoss fuses the sigmoid with the log. A separate
    # sigmoid + BCELoss saturates in float32 for large |a * (x + b)|, where
    # BCELoss clamps the log term at -100 and reports a wrong loss.
    loss_fn = torch.nn.BCEWithLogitsLoss()
    with torch.no_grad():
        logits = a * (clip_feats + b)
        loss = loss_fn(logits, one_hot_labels)
        return loss.cpu()

In [7]:
# Candidate scales (a) and shifts (b) for sigmoid(a * (clip_feats + b)).
# a = 0 is left out: it makes the prediction a constant 0.5 regardless of b,
# so it would spend a full sweep over b_values on one constant loss.
a_values = [0.1, 0.25, 0.5, 0.75, 1]+[2*n for n in range(1, 76)]
b_values = [0.01*n for n in range(-100, 101)]

# Exhaustive grid search. The best (a, b) is kept in best_a / best_b so later
# cells can use it instead of reading it off the printed log.
best_loss = torch.inf
best_a, best_b = None, None
for a in a_values:
    for b in b_values:
        loss = get_loss(a, b)
        if loss < best_loss:
            best_loss, best_a, best_b = loss, a, b
            # Log only when the running best improves.
            print("a={:.2f}, b={:.2f}, loss={:.5f}".format(a, b, loss))

a=0.10, b=-1.00, loss=0.64293
a=0.25, b=-1.00, loss=0.57265
a=0.50, b=-1.00, loss=0.46869
a=0.75, b=-1.00, loss=0.38047
a=1.00, b=-1.00, loss=0.30676
a=2.00, b=-1.00, loss=0.12546
a=4.00, b=-1.00, loss=0.03004
a=6.00, b=-1.00, loss=0.02347
a=6.00, b=-0.99, loss=0.02337
a=6.00, b=-0.98, loss=0.02328
a=6.00, b=-0.97, loss=0.02319
a=6.00, b=-0.96, loss=0.02312
a=6.00, b=-0.95, loss=0.02305
a=6.00, b=-0.94, loss=0.02299
a=6.00, b=-0.93, loss=0.02295
a=6.00, b=-0.92, loss=0.02291
a=6.00, b=-0.91, loss=0.02289
a=6.00, b=-0.90, loss=0.02288
a=8.00, b=-0.73, loss=0.02276
a=8.00, b=-0.72, loss=0.02265
a=8.00, b=-0.71, loss=0.02255
a=8.00, b=-0.70, loss=0.02246
a=8.00, b=-0.69, loss=0.02240
a=8.00, b=-0.68, loss=0.02235
a=8.00, b=-0.67, loss=0.02233
a=8.00, b=-0.66, loss=0.02233
a=10.00, b=-0.58, loss=0.02225
a=10.00, b=-0.57, loss=0.02211
a=10.00, b=-0.56, loss=0.02199
a=10.00, b=-0.55, loss=0.02189
a=10.00, b=-0.54, loss=0.02183
a=10.00, b=-0.53, loss=0.02180
a=12.00, b=-0.48, loss=0.02176
a=1