In [1]:
import os
import pandas as pd
from PIL import Image
import numpy as np

# Pin this kernel to one physical GPU. CUDA reads this variable when it is
# initialised, so it has to be set before torch (imported in a later cell)
# touches the GPU.
GPU_ID = "7"
os.environ.update(CUDA_VISIBLE_DEVICES=GPU_ID)

In [2]:
from dataset_utils import load_dataset

dataset = "flickr30k"
# Unpack all eight items returned by load_dataset; the meaning and type of each
# item is defined by dataset_utils.load_dataset.
(img_paths, img_ids, captions, label_strs,
 train_ids, val_ids, test_ids, label_names) = load_dataset(dataset)

[nltk_data] Downloading package stopwords to
[nltk_data]     /srv/home/8wiehe/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /srv/home/8wiehe/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
import clip
import torch

print(clip.available_models())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_name = "ViT-B/16"
model, transform = clip.load(clip_name, jit=False, device=device)

# Directory for this dataset's saved CLIP features (it is created with
# os.makedirs in the feature-extraction cell).
load_path = os.path.join(dataset, "feats")

# Model names such as "ViT-B/16" contain "/", which would act as a path
# separator in the feature filenames, so replace it with "_".
clip_load_name = clip_name.replace("/", "_")
clip_load_name



['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']


'ViT-B_16'

In [4]:
from utils import get_image_features, get_text_features

# Optional sanity check: rank the captions of one image by CLIP cosine
# similarity. Disabled by default; set test_sims = True to run it.
test_sims = False
idx = 5002
personal_caption = ""  # optional extra caption scored alongside the dataset ones
top_k_captions = 5     # how many best-matching captions to show

if test_sims:
    # NOTE(review): get_data is not defined in any visible cell. From its use
    # here it must return (img_name, captions_for_image, image_path); define it
    # before enabling this check.
    img_name, img_captions, img = get_data(idx)
    # Work on a copy so the list returned by get_data (which may be shared with
    # the dataset) is not extended again on every run of this cell.
    img_captions = list(img_captions)
    if personal_caption:
        img_captions.append(personal_caption)

    # NOTE(review): load_path is a directory here, while the feature-extraction
    # cell passes a .pt file path; with save=False this is presumably unused — confirm.
    img_features = get_image_features([img], model, transform, device,
                                      load_path, batch_size=16, save=False)
    caption_features = get_text_features(img_captions, model, device,
                                         load_path, batch_size=16, save=False)

    sims = torch.cosine_similarity(img_features, caption_features)
    # topk raises if k exceeds the number of candidates, so clamp it.
    top_k = sims.topk(k=min(top_k_captions, sims.numel()))

    # Render inline; PIL's Image.show() launches an external viewer process
    # instead, which never appears in a (possibly remote) notebook.
    display(Image.open(img))
    # Distinct loop name so the global image index `idx` is not overwritten.
    for caption_idx, sim in zip(top_k.indices.tolist(), top_k.values.tolist()):
        print(f"{sim:.4f} {img_captions[caption_idx]}")

In [5]:
from utils import get_image_features, get_text_features

# Encode every image and every caption with CLIP and save (save=True) the
# results under load_path, one .pt file per modality named after the model.
os.makedirs(load_path, exist_ok=True)
bs = 32

img_feats_file = os.path.join(load_path, f"{clip_load_name}_img_feats.pt")
caption_feats_file = os.path.join(load_path, f"{clip_load_name}_caption_feats.pt")

img_features = get_image_features(
    img_paths, model, transform, device, img_feats_file,
    batch_size=bs, save=True,
)
caption_features = get_text_features(
    captions["caption_text"], model, device, caption_feats_file,
    batch_size=bs, save=True,
)

In [6]:
import torch
import gc

# Give unused GPU memory back from PyTorch's caching allocator. Collect garbage
# first: tensors kept alive only by reference cycles are released by
# gc.collect(), and only after that can empty_cache() return their blocks.
n_unreachable = gc.collect()
torch.cuda.empty_cache()
n_unreachable

20

In [18]:
from zero_shot_utils import calc_binary_acc, create_label_encs, calc_accuracies

In [8]:
# Positional arguments shared by every create_label_encs(...) call below
# (despite the name, these are positional, not keyword, arguments).
label_kwargs = model, label_names
# Positional arguments shared by every calc_binary_acc(...) call below.
acc_func_args = img_features, label_strs, label_names

In [10]:
# Single prompt prefix "A picture of " with use_norm disabled; the cell shows the
# mean of the accuracies returned by calc_binary_acc.
prefixes = ["A picture of "]
label_encs = create_label_encs(*label_kwargs, prefixes,
                               use_multi_label_setting=False, use_norm=False)
accs = calc_binary_acc(*acc_func_args, *label_encs)
torch.stack(accs).mean().item()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/31783 [00:00<?, ?it/s]

0.7251743674278259

In [11]:
# Same evaluation as the "A picture of " cell, with prefix "A photo of ".
prefixes = ["A photo of "]
label_encs = create_label_encs(*label_kwargs, prefixes,
                               use_multi_label_setting=False, use_norm=False)
accs = calc_binary_acc(*acc_func_args, *label_encs)
torch.stack(accs).mean().item()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/31783 [00:00<?, ?it/s]

0.7638537287712097

In [12]:
# Prefix with extra context, this time with use_norm enabled. The prompt text
# (including "a people") is kept verbatim because the recorded accuracy was
# produced with exactly this string.
prefixes = ["A photo of a people. "]
label_encs = create_label_encs(*label_kwargs, prefixes,
                               use_multi_label_setting=False, use_norm=True)
accs = calc_binary_acc(*acc_func_args, *label_encs)
torch.stack(accs).mean().item()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/31783 [00:00<?, ?it/s]

0.7734894156455994

In [13]:
# Prompt ensemble: encode each prefix separately, then hand calc_binary_acc the
# per-prefix positive and negative encodings as two parallel lists.
prefixes = ["A photo of ", "A ", "A picture of ", "picture ", "A picture of a person. ", "flickr. ", "stock image. "]

encs_per_prefix = [
    create_label_encs(*label_kwargs, [single_prefix], use_multi_label_setting=False, use_norm=False)
    for single_prefix in prefixes
]
pos_encs = [pos for pos, _neg in encs_per_prefix]
neg_encs = [neg for _pos, neg in encs_per_prefix]

accs = calc_binary_acc(*acc_func_args, pos_encs, neg_encs)
torch.stack(accs).mean().item()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/31783 [00:00<?, ?it/s]

0.7885488271713257

In [14]:
import string
import random

# Control experiment: prefixes of random upper-case letters and digits, ensembled
# the same way as the hand-written prefixes in the previous cell. A dedicated,
# seeded RNG makes the prompt set (and so the accuracy) reproducible on
# Restart & Run All without touching the global `random` state. The accuracy
# shown below was produced by an earlier, unseeded run.
RANDOM_PREFIX_SEED = 0
S = 10                  # characters per random prefix
n_random_prefixes = 5

prefix_rng = random.Random(RANDOM_PREFIX_SEED)
prefix_alphabet = string.ascii_uppercase + string.digits
prefixes = ["".join(prefix_rng.choices(prefix_alphabet, k=S)) + ". "
            for _ in range(n_random_prefixes)]

pos_encs = []
neg_encs = []

for prefix in prefixes:
    pos_enc, neg_enc = create_label_encs(*label_kwargs, [prefix], use_multi_label_setting=False, use_norm=False)
    pos_encs.append(pos_enc)
    neg_encs.append(neg_enc)

accs = calc_binary_acc(*acc_func_args, pos_encs, neg_encs)
torch.mean(torch.stack(accs)).item()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/31783 [00:00<?, ?it/s]

0.6365594863891602

In [15]:
# Prompt ensemble passed to create_label_encs in one call (all prefixes at
# once) with use_norm=True.
prefixes = ["A photo of ", "a ", "A painting of ", "flickr. ", "Artstation ", "stock image "]
ensemble_encs = create_label_encs(*label_kwargs, prefixes,
                                  use_multi_label_setting=False, use_norm=True)
accs = calc_binary_acc(*acc_func_args, *ensemble_encs)
torch.stack(accs).mean().item()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/31783 [00:00<?, ?it/s]

0.8278830647468567

In [16]:
# Same prompt ensemble as the previous cell, but leaving use_norm at
# create_label_encs's default. Both cells recorded identical accuracy, which
# suggests the default is use_norm=True — verify in zero_shot_utils.
prefixes = ["A photo of ", "a ", "A painting of ", "flickr. ", "Artstation ", "stock image "]
ensemble_encs = create_label_encs(*label_kwargs, prefixes, use_multi_label_setting=False)
accs = calc_binary_acc(*acc_func_args, *ensemble_encs)
torch.stack(accs).mean().item()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/31783 [00:00<?, ?it/s]

0.8278830647468567

In [19]:
# Multi-label evaluation over the current `prefixes`.
# NOTE(review): the last run of this cell raised NameError: name 'bs' is not
# defined. `bs` is assigned in this notebook's feature-extraction cell, so the
# failing reference presumably lives inside the imported helpers
# (zero_shot_utils / utils) — TODO confirm and fix it there.
multi_label_encs = create_label_encs(*label_kwargs, prefixes, use_multi_label_setting=True)
print(calc_accuracies(*acc_func_args, multi_label_encs))

NameError: name 'bs' is not defined

In [None]:
# Not executed in the saved run. Single-label encodings for "A photo of ",
# plus the CLIP text feature of the bare prompt with nothing appended.
pos_encs, neg_encs = create_label_encs(*label_kwargs, ["A photo of "], use_multi_label_setting=False)
baseline_prompt_feats = get_text_features(["A photo of "], model, device, None, batch_size=1, save=False)
baseline_enc = baseline_prompt_feats[0]

In [None]:
# Not executed in the saved run. Contrast the "A photo with " prefix (positive
# side) against "A photo without " (negative side).
# NOTE(review): in the single-label cells create_label_encs returns a
# (pos, neg) pair; confirm what it returns with use_multi_label_setting=1
# before passing each whole result to calc_binary_acc as one side.
pos_encs, neg_encs = (
    create_label_encs(*label_kwargs, [template], use_multi_label_setting=1)
    for template in ("A photo with ", "A photo without ")
)
accs = calc_binary_acc(*acc_func_args, pos_encs, neg_encs)
print(torch.stack(accs).mean())

In [None]:
# Not executed in the saved run. Contrast the "A " prefix (positive side)
# against "Not a " (negative side).
# NOTE(review): as in the with/without cell, confirm the return shape of
# create_label_encs under use_multi_label_setting=1.
pos_encs, neg_encs = (
    create_label_encs(*label_kwargs, [template], use_multi_label_setting=1)
    for template in ("A ", "Not a ")
)
accs = calc_binary_acc(*acc_func_args, pos_encs, neg_encs)
print(torch.stack(accs).mean())