In [22]:
import os 
from tqdm import tqdm
import numpy as np
import clip
import torch 
from PIL import Image
import pandas as pd

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device)

100%|███████████████████████████████████████| 338M/338M [01:12<00:00, 4.86MiB/s]


In [16]:
def predict_with_pil_image(image_input, clip_labels, top_k=5):
    image_input = preprocess(image_input).unsqueeze(0).to(device)
    text_inputs = torch.cat([clip.tokenize(cl) for cl in clip_labels]).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_inputs)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity_score_probability = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    percentages, indices = similarity_score_probability[0].topk(5)
    similarity_score_probability = percentages.numpy() * 100.0 
    indices = indices.numpy()
    return similarity_score_probability, indices


In [17]:
image = Image.open('clips-data-2020/clips/clips-25001.png')


In [18]:
raw_labels = list(range(100))
clip_labels = ['{} number of paper clips where some paper clips are partially occluded'.format(label) for label in raw_labels]


In [19]:
percentages, indices = predict_with_pil_image(image, clip_labels)
print("\nTop 5 predictions:\n")
for percent, index in zip(percentages, indices):
    print(f"{raw_labels[index]}: {percent:.2f}%")


Top 5 predictions:

0: 3.36%
4: 2.78%
2: 2.77%
3: 2.46%
5: 2.42%


In [20]:
medium_image = Image.open('clips-data-2020/clips/clips-25086.png')
hard_image = Image.open('clips-data-2020/clips/clips-25485.png')


In [None]:
testing_data = pd.read_csv('train.csv')
errors = []
for idx, row in tqdm(testing_data.iterrows(), total=1000):
    image = Image.open(
        'clips-data-2020/clips/clips-{}.png'.format(row['id'])
    )
    percentages, indices = predict_with_pil_image(image, clip_labels, 1)
    errors.append(abs(row['clip_count'] - raw_labels[indices[0]]))
    if idx == 1000:
        break
print('{} average count error'.format(np.mean(errors)))

  2%|▋                                      | 19/1000 [02:02<1:54:56,  7.03s/it]