In [1]:
import os
import pandas as pd
from PIL import Image
import numpy as np

import torch

# Expose only GPU index 5 to CUDA for this kernel. The index is machine-specific;
# adjust it for the host you run on.
# NOTE(review): this is set after `import torch`. It should still take effect as long as
# CUDA has not been initialised yet, but setting it before the import is safer — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

In [2]:
from dataset_utils import load_dataset

# Name of the dataset under evaluation; it is also passed to
# get_clip_img_caption_features further down.
dataset_name = "mimic-cxr"
# `df` is indexed by column name below (its "labels" column is used for scoring);
# `label_names` is forwarded to both the text-encoding and the accuracy helpers.
df, label_names = load_dataset(dataset_name)

In [3]:
from clip_utils import load_clip

# Load the CLIP variant "ViT-B/16". `model` and `transform` are used by the feature
# extraction and label-encoding cells; `clip_name` is not used in the visible cells.
model, transform, clip_name = load_clip("ViT-B/16")



In [4]:
from clip_utils import get_clip_img_caption_features

# Compute CLIP features for the images and captions of `df`. Only `img_features` is
# used by the evaluation cells visible here.
# NOTE(review): `bs` is presumably the batch size — confirm in clip_utils.
img_features, caption_features = get_clip_img_caption_features(df, model, transform, dataset_name, bs=128)

In [5]:
from zero_shot_utils import calc_binary_acc, create_label_encs, calc_accuracies

In [9]:
# Leading positional arguments for create_label_encs. Despite the "kwargs" name this is
# a tuple that is splatted with `*`, not a keyword-argument dict.
label_kwargs = (model, label_names)
# Leading positional arguments for the accuracy helpers: image features, the "labels"
# column as a NumPy array, and the label names.
acc_func_args = (img_features, df["labels"].to_numpy(), label_names)

In [10]:
# Single-prefix run: "A picture of ", with use_norm disabled.
prefixes = ["A picture of "]
accs = calc_binary_acc(
    *acc_func_args,
    *create_label_encs(
        *label_kwargs,
        prefixes,
        use_multi_label_setting=False,
        use_norm=False,
    ),
)
# Mean of the stacked per-entry accuracies, shown as a Python float.
torch.stack(accs).mean().item()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/377095 [00:00<?, ?it/s]

0.09709356725215912

In [11]:
# Single-prefix run: "A photo of ", with use_norm disabled.
prefixes = ["A photo of "]
accs = calc_binary_acc(
    *acc_func_args,
    *create_label_encs(
        *label_kwargs,
        prefixes,
        use_multi_label_setting=False,
        use_norm=False,
    ),
)
# Mean of the stacked per-entry accuracies, shown as a Python float.
torch.stack(accs).mean().item()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/377095 [00:00<?, ?it/s]

0.09257160127162933

In [12]:
# Single-prefix run with a sentence-style prefix; unlike the runs that pass
# use_norm=False, this one enables use_norm.
prefixes = ["A photo of a people. "]
accs = calc_binary_acc(
    *acc_func_args,
    *create_label_encs(
        *label_kwargs,
        prefixes,
        use_multi_label_setting=False,
        use_norm=True,
    ),
)
# Mean of the stacked per-entry accuracies, shown as a Python float.
torch.stack(accs).mean().item()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/377095 [00:00<?, ?it/s]

0.044776707887649536

In [13]:
# Encode every prefix on its own (one create_label_encs call per prefix) and hand the
# resulting lists of positive / negative encodings to calc_binary_acc in a single call.
prefixes = [
    "A photo of ",
    "A ",
    "A picture of ",
    "picture ",
    "A picture of a person. ",
    "flickr. ",
    "stock image. ",
]

pos_encs, neg_encs = [], []
for prefix in prefixes:
    pos_enc, neg_enc = create_label_encs(
        *label_kwargs,
        [prefix],
        use_multi_label_setting=False,
        use_norm=False,
    )
    pos_encs.append(pos_enc)
    neg_encs.append(neg_enc)

accs = calc_binary_acc(*acc_func_args, pos_encs, neg_encs)
# Mean of the stacked per-entry accuracies, shown as a Python float.
torch.stack(accs).mean().item()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/377095 [00:00<?, ?it/s]

0.12272033095359802

In [14]:
import random
import string

# Random-prefix run: each prefix is S characters drawn from uppercase letters and
# digits, followed by ". ". The draw uses a dedicated, seeded generator so that
# Restart & Run All produces the same prefixes (and therefore the same accuracy)
# every time, without touching the global `random` state.
S = 10  # number of random characters in each prefix
NUM_RANDOM_PREFIXES = 5
PREFIX_SEED = 0

prefix_rng = random.Random(PREFIX_SEED)
prefix_alphabet = string.ascii_uppercase + string.digits
prefixes = [
    "".join(prefix_rng.choices(prefix_alphabet, k=S)) + ". "
    for _ in range(NUM_RANDOM_PREFIXES)
]

pos_encs = []
neg_encs = []

# One create_label_encs call per prefix; the encodings are collected into lists.
for prefix in prefixes:
    pos_enc, neg_enc = create_label_encs(*label_kwargs, [prefix], use_multi_label_setting=False, use_norm=False)
    pos_encs.append(pos_enc)
    neg_encs.append(neg_enc)

accs = calc_binary_acc(*acc_func_args, pos_encs, neg_encs)
# Mean of the stacked per-entry accuracies, shown as a Python float.
torch.mean(torch.stack(accs)).item()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/377095 [00:00<?, ?it/s]

0.0498608760535717

In [15]:
# All prefixes are passed to a single create_label_encs call, with use_norm=True
# given explicitly.
prefixes = [
    "A photo of ",
    "a ",
    "A painting of ",
    "flickr. ",
    "Artstation ",
    "stock image ",
]
accs = calc_binary_acc(
    *acc_func_args,
    *create_label_encs(
        *label_kwargs,
        prefixes,
        use_multi_label_setting=False,
        use_norm=True,
    ),
)
# Mean of the stacked per-entry accuracies, shown as a Python float.
torch.stack(accs).mean().item()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/377095 [00:00<?, ?it/s]

0.20875561237335205

In [16]:
# All prefixes are passed to a single create_label_encs call, leaving use_norm at its
# default.
# NOTE(review): the recorded result equals the one from the explicit use_norm=True
# run, which suggests the default is True — confirm in zero_shot_utils.
prefixes = [
    "A photo of ",
    "a ",
    "A painting of ",
    "flickr. ",
    "Artstation ",
    "stock image ",
]
accs = calc_binary_acc(
    *acc_func_args,
    *create_label_encs(
        *label_kwargs,
        prefixes,
        use_multi_label_setting=False,
    ),
)
# Mean of the stacked per-entry accuracies, shown as a Python float.
torch.stack(accs).mean().item()

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/377095 [00:00<?, ?it/s]

0.20875561237335205

In [17]:
# FIXME: as recorded below, this call fails with
#   TypeError: calc_accuracies() takes 2 positional arguments but 4 were given
# It passes the three values in acc_func_args plus the label encodings. The expected
# signature is not visible in this notebook; check zero_shot_utils.calc_accuracies and
# pass only the two arguments it accepts.
# Note: `prefixes` here is whatever the most recently executed cell assigned.
print(calc_accuracies(*acc_func_args, create_label_encs(*label_kwargs, prefixes, use_multi_label_setting=True)))

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([14, 512])


TypeError: calc_accuracies() takes 2 positional arguments but 4 were given

In [None]:
# Positive / negative label encodings for the single prefix "A photo of ".
pos_encs, neg_encs = create_label_encs(*label_kwargs, ["A photo of "], use_multi_label_setting=False)
# Text feature of the bare prefix (first element of the returned batch).
# FIXME: `get_text_features` and `device` are not imported or defined in any cell
# visible in this notebook, so this line will raise NameError on a fresh kernel.
# Import / define them before use (the source module is not visible here).
baseline_enc = get_text_features(["A photo of "], model, device, None, batch_size=1, save=False)[0]

In [None]:
# Build the positive encodings from a "with" prompt and the negative encodings from a
# "without" prompt (evaluated in that order), then score them with calc_binary_acc.
pos_encs, neg_encs = (
    create_label_encs(*label_kwargs, [prompt_prefix], use_multi_label_setting=1)
    for prompt_prefix in ("A photo with ", "A photo without ")
)
accs = calc_binary_acc(*acc_func_args, pos_encs, neg_encs)
print(torch.mean(torch.stack(accs)))

In [None]:
# Build the positive encodings from the "A " prompt and the negative encodings from the
# "Not a " prompt (evaluated in that order), then score them with calc_binary_acc.
pos_encs, neg_encs = (
    create_label_encs(*label_kwargs, [prompt_prefix], use_multi_label_setting=1)
    for prompt_prefix in ("A ", "Not a ")
)
accs = calc_binary_acc(*acc_func_args, pos_encs, neg_encs)
print(torch.mean(torch.stack(accs)))