In [1]:
import torch
import clip
from PIL import Image
import numpy as np
import json
from tqdm import tqdm
import pandas as pd
import os
import sys

In [2]:
# Seed every random number generator this notebook touches so reruns are reproducible.
myseed = 53
torch.manual_seed(myseed)
np.random.seed(myseed)
if torch.cuda.is_available():
    # Seed the current device and every other visible GPU.
    torch.cuda.manual_seed_all(myseed)
    torch.cuda.manual_seed(myseed)
# Disable cuDNN autotuning (it may choose different kernels run to run) and
# request deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
# Input/output locations. A script version of this notebook would take these
# from sys.argv[1], sys.argv[2] and sys.argv[3]; inside the notebook they are fixed defaults.
image_path, id2label_path, output_path = (
    "./hw3_data/p1_data/val",            # directory of .png images to classify
    "./hw3_data/p1_data/id2label.json",  # JSON mapping class id -> class name
    "./p1_predict.csv",                  # where predictions are written
)

### Main: zero-shot classification of the images with CLIP

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pretrained CLIP ViT-B/32 together with its matching image preprocessing transform.
model, preprocess = clip.load("ViT-B/32", device=device)

# os.listdir returns entries in arbitrary, platform-dependent order. Sort them so
# the iteration order (and the row order of the output CSV) is reproducible.
image_names = sorted(f for f in os.listdir(image_path) if f.endswith(".png"))
# print(image_names)

In [5]:
# Load the class-id -> class-name mapping. Keys are presumably stringified integer
# class ids; the assert below checks this.
# Build `labels` so that labels[i] is the name of class i. The position predicted
# in this list is later compared with the integer id parsed from each filename,
# so ordering by the file's key order (as `.items()` alone does) is fragile.
# Open as UTF-8 explicitly: the platform default encoding may garble non-ASCII names.
with open(id2label_path, 'r', encoding='utf-8') as file:
    id2label = json.load(file)
sorted_items = sorted(id2label.items(), key=lambda kv: int(kv[0]))
assert [int(k) for k, _ in sorted_items] == list(range(len(sorted_items))), \
    "id2label keys must be the contiguous class ids 0..N-1"
labels = [name for _, name in sorted_items]
# print(labels)

In [6]:
# Zero-shot CLIP classification: score each image against one text prompt per class.
#
# The prompts do not depend on the image. Tokenize and encode them once, up
# front, instead of re-running the text encoder for every image.
# Other templates that were tried: "This is a photo of {c}",
# "This is not a photo of {c}", "No {c}, no score."
text = torch.cat([clip.tokenize(f"a photo of a {c}") for c in labels]).to(device)
with torch.no_grad():
    text_features = model.encode_text(text)
# L2-normalise so the dot product below is cosine similarity.
text_features /= text_features.norm(dim = -1, keepdim = True)

values = []   # softmax probability of the top class, per image
indices = []  # 1-element index tensor of the top class, per image
for image_name in tqdm(image_names):
    # Use `with` so each file handle is closed once preprocessing has read it.
    with Image.open(os.path.join(image_path, image_name)) as img:
        image = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image)
    image_features /= image_features.norm(dim = -1, keepdim = True)

    # Scale the cosine similarities by 100.0 before the softmax over classes.
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    value, index = similarity[0].topk(1)  # topk(n): the n most similar labels

    values.append(value.item())
    indices.append(index)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2500/2500 [01:33<00:00, 26.86it/s]


In [7]:
# Turn each 1-element index tensor from topk into a plain int class id.
predict = [idx.item() for idx in indices]

In [8]:
# One row per image: its filename and the predicted class id.
df = pd.DataFrame(dict(filename=image_names, label=predict))

In [9]:
# Write the predictions without the pandas row index.
df.to_csv(path_or_buf=output_path, index=False)

In [10]:
# Ground truth comes from the filename: the integer before the first '_'
# (assumes names look like "<class_id>_<anything>.png"; parsing fails otherwise).
corr = sum(int(n.split('_')[0]) == l for n, l in zip(image_names, predict))
if image_names:
    print(f'accuracy: {corr / len(image_names)}')
else:
    # Avoid ZeroDivisionError when the image directory has no .png files.
    print('accuracy: n/a (no images found)')

accuracy: 0.7112
