In [None]:
from datasets import load_dataset
import torch
from PIL import Image
import open_clip

# Labelled fashion images; only the "train" split is used, so take it first and
# shuffle just that split rather than every split of the DatasetDict.
ds = load_dataset("adhamelarabawy/fashion_human_classification")["train"].shuffle()

# Single source of truth for the architecture so that the model and its
# tokenizer can never drift apart.
MODEL_NAME = 'ViT-B-32'
PRETRAINED_TAG = 'laion2b_s34b_b79k'

model, _, preprocess = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=PRETRAINED_TAG)
# open_clip hands the model back in training mode; switch to eval so layers
# such as dropout / stochastic depth / batch-norm behave deterministically
# during inference.
model.eval()
tokenizer = open_clip.get_tokenizer(MODEL_NAME)

In [None]:
# Size of the evaluation subset and the seed used to draw it. A fixed seed
# makes the subset (and therefore every metric computed from it) reproducible
# across kernel restarts.
num_samples = 1000
shuffle_seed = 0

# Clamp to the dataset size so a small dataset does not make `select` raise.
samples = ds.shuffle(seed=shuffle_seed).select(range(min(num_samples, len(ds))))

# Normalised prompt embeddings, memoised per (prompt, device) for the model
# object they were computed with, so the constant prompt is not re-encoded for
# every sample. The cache is reset whenever `model` is rebound to a new object;
# if the weights are modified in place, re-run this cell to clear it.
_prompt_feature_cache = {"model": None, "features": {}}


def _autocast_for(device):
    """Return an autocast context that is only active when `device` is CUDA."""
    return torch.autocast(device_type="cuda", enabled=device.type == "cuda")


def _prompt_features(prompt, device):
    """Return the unit-norm CLIP text embedding of `prompt`, computed once per model/device."""
    if _prompt_feature_cache["model"] is not model:
        _prompt_feature_cache["model"] = model
        _prompt_feature_cache["features"] = {}

    cached = _prompt_feature_cache["features"]
    key = (prompt, str(device))
    if key not in cached:
        tokens = tokenizer([prompt]).to(device)
        with torch.no_grad(), _autocast_for(device):
            features = model.encode_text(tokens)
            features /= features.norm(dim=-1, keepdim=True)
        cached[key] = features
    return cached[key]


def add_score(sample):
    """Attach a CLIP image/text similarity score for the prompt "human" to `sample`.

    Args:
        sample: A dataset row. Only ``sample["image"]`` is read; it is passed
            straight to the open_clip ``preprocess`` transform.

    Returns:
        The same ``sample`` mapping with a new ``"score"`` entry: 100 times the
        cosine similarity between the image embedding and the embedding of the
        text "human" (a raw similarity, not a probability).
    """
    # Run on whatever device the model currently lives on.
    device = next(model.parameters()).device
    image = preprocess(sample["image"]).unsqueeze(0).to(device)
    text_features = _prompt_features("human", device)

    with torch.no_grad(), _autocast_for(device):
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # One image against one prompt, so this is a 1x1 tensor.
        similarity = 100.0 * image_features @ text_features.T

    sample["score"] = similarity.item()
    return sample

# Score every selected row; `map` returns a new dataset that carries the extra
# "score" column written by `add_score`.
samples = samples.map(function=add_score)

In [None]:
# show the distribution of scores for the samples, categorized by label
import matplotlib.pyplot as plt
import numpy as np

def _ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


# Ground-truth flags and CLIP scores for the scored subset.
# assumes "has_human" holds booleans or 0/1 integers — TODO confirm
labels = np.array(samples["has_human"]).astype(bool)
scores = np.array(samples["score"])

human_scores = scores[labels]
no_human_scores = scores[~labels]
human_mean = human_scores.mean()
no_human_mean = no_human_scores.mean()

# Decision threshold: midpoint between the two class means unless a manual
# override is given.
# NOTE(review): the notebook previously computed the midpoint and then
# unconditionally overwrote it with 4.5; that value is kept here to preserve
# the reported numbers, but it looks stale for scores of 100 * cosine
# similarity. Set `manual_thresh = None` to use the midpoint instead.
manual_thresh = 4.5
thresh = (human_mean + no_human_mean) / 2 if manual_thresh is None else manual_thresh

fig, ax = plt.subplots()
ax.hist(no_human_scores, bins=100, alpha=0.5, label='No Human')
ax.hist(human_scores, bins=100, alpha=0.5, label='Human')

ax.set_xlabel('CLIP Similarity Score')
ax.set_ylabel('Frequency')
# The title reports the actual sample count and the architecture loaded above.
ax.set_title(f'CLIP Similarity Scores for Human vs. No Human Images (n={len(scores)}, ViT-B-32)')
ax.legend()

# Draw a vertical line at the threshold ...
ax.axvline(x=thresh, color='k', linestyle='--')
# ... and one at each class mean.
ax.axvline(x=human_mean, color='gray', linestyle='--')
ax.axvline(x=no_human_mean, color='gray', linestyle='--')

# A sample is predicted "human" when its score is strictly above the threshold;
# scores equal to the threshold count as "no human", so every sample lands in
# exactly one confusion-matrix cell.
predicted_human = scores > thresh
tp = np.sum(predicted_human & labels)
tn = np.sum(~predicted_human & ~labels)
fp = np.sum(predicted_human & ~labels)
fn = np.sum(~predicted_human & labels)

print(f"Accuracy: {_ratio(tp + tn, tp + tn + fp + fn)}")
print(f"Precision: {_ratio(tp, tp + fp)}")
print(f"Recall: {_ratio(tp, tp + fn)}")
print(f"F1: {_ratio(2 * tp, 2 * tp + fp + fn)}")
# Fraction of samples that survive if everything predicted "human" is dropped.
print(f"Keep Rate: {_ratio(tn + fn, len(scores))}")
# fig.savefig("results/vanilla_scores.png")