<a href="https://colab.research.google.com/github/R12942159/NTU_DLCV/blob/Hw3/p1_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
import os
import csv
import json
import clip
import torch
import argparse
import numpy as np

from PIL import Image
from tqdm import tqdm
from google.colab import drive

#### Select the compute device (CUDA GPU if available, otherwise CPU).

In [15]:
# Run on the GPU when PyTorch can see one; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using: {device}")

Using: cuda


#### Configure input and output paths

In [16]:
# Colab locations of the HW3 problem-1 data. In the script version of this
# notebook these three values are meant to come from positional CLI arguments
# instead: <validation image dir> <id2label json> <output csv>.
val_path, id2label_path, output_path = (
    os.path.join('/content/hw3_data', 'p1_data', 'val'),
    os.path.join('/content/hw3_data', 'p1_data', 'id2label.json'),
    os.path.join('/content/hw3_data', 'pred.csv'),
)

In [17]:
# Sort so the CSV row order is reproducible; os.listdir order is arbitrary.
img_paths = sorted(
    os.path.join(val_path, fname) for fname in os.listdir(val_path) if fname.endswith('.png')
)

# Context manager so the JSON file handle is closed promptly.
with open(id2label_path, 'r') as id2label_file:
    id2label = json.load(id2label_file)

# The prediction written for each image is the *position* of the best-matching
# entry in `labels`, so `labels` must be ordered by numeric id. Relying on JSON
# object order (or on string order, where "10" < "2") would silently mislabel.
label_ids = sorted(id2label, key=int)
assert [int(k) for k in label_ids] == list(range(len(label_ids))), \
    "id2label keys must be the contiguous integers 0..N-1"
labels = [id2label[k] for k in label_ids]

#### Zero-shot inference (no training involved) — validation accuracy by CLIP backbone: ["RN50"(49.88%), "RN101"(50.92%), "RN50x4"(43.68%), "RN50x16"(58.72%), "RN50x64"(62.05%), "ViT-B/32"(71.16%), "ViT-B/16"(74.6%), "ViT-L/14"(81.44%), "ViT-L/14@336px"(80.68%)]

In [18]:
def train(model_name = "ViT-L/14", prompt_template = "A photo of a {}."):
    """Zero-shot classify every image in `img_paths` with CLIP and write a CSV.

    Despite the name, no training happens. A pretrained CLIP model is loaded
    and one text prompt is built per class name in `labels`. Each image is
    then assigned the index of the prompt whose embedding has the highest
    cosine similarity with the image embedding.

    Reads the module-level globals `img_paths`, `labels`, `device` and
    `output_path`.

    Args:
        model_name: backbone name passed to `clip.load` (e.g. "ViT-L/14").
        prompt_template: format string with a single `{}` slot that receives
            each class name; the default produces "A photo of a <label>.".

    Side effects:
        Writes `output_path` as a CSV with header ("filename", "label") and
        one row per image, where "label" is the index into `labels`.
    """
    model, transform = clip.load(name=model_name, device=device, jit=False)

    # The prompts are the same for every image, so tokenize and encode them
    # once up front instead of once per image.
    prompt_tokens = torch.cat(
        [clip.tokenize(prompt_template.format(name), context_length=77, truncate=False) for name in labels]
    ).to(device)
    with torch.no_grad():
        text_features = model.encode_text(prompt_tokens)  # (len(labels), embed_dim)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    rows = []
    for path in tqdm(img_paths):
        # Close the file handle as soon as the pixels have been converted.
        with Image.open(path) as raw:
            # `transform` is the preprocessing returned by clip.load for this
            # backbone; unsqueeze adds the batch dimension for encode_image.
            img = transform(raw.convert('RGB')).unsqueeze(0).to(device)

        with torch.no_grad():
            img_features = model.encode_image(img)  # (1, embed_dim)
        img_features = img_features / img_features.norm(dim=-1, keepdim=True)

        # Both sides are L2-normalised, so this is cosine similarity. A
        # temperature scale or softmax is monotonic and would not change the
        # argmax, so neither is applied.
        similarity = img_features @ text_features.T
        label = similarity[0].argmax().item()  # torch.Tensor -> int

        rows.append((os.path.basename(path), label))

    # Write output csv file
    with open(output_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(('filename', 'label'))
        writer.writerows(rows)

In [19]:
# Run zero-shot prediction over the whole validation split with ViT-L/14, the
# highest-accuracy backbone in the comparison above, and write the CSV.
train(model_name="ViT-L/14")

100%|██████████| 2500/2500 [03:27<00:00, 12.05it/s]
