In [1]:
import open_clip
import torch
from PIL import Image
from tqdm import tqdm

from pathlib import Path


In [2]:
# open_clip.list_pretrained()

In [3]:
# ! wget -O master.zip https://github.com/alexeygrigorev/clothing-dataset-small/archive/master.zip && unzip master.zip && rm master.zip && mv clothing-dataset-small-master data > /dev/null

In [4]:
model_name = "ViT-H-14" # "ViT-H-14"
weights = "laion2b_s32b_b79k" # "laion2b_s32b_b79k"

In [5]:
classes = [x.stem for x in Path("data/train").glob("*")]
print(classes)

['shoes', 'pants', 'dress', 'shirt', 'outwear', 't-shirt', 'skirt', 'shorts', 'hat', 'longsleeve']


In [6]:
model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=weights)
tokenizer = open_clip.get_tokenizer(model_name)

In [7]:
text = tokenizer([f"{x} clothing item" for x in classes])

with torch.no_grad(), torch.cuda.amp.autocast():
    
    text_features = model.encode_text(text)
    text_features /= text_features.norm(dim=-1, keepdim=True)


In [8]:
validation_files = list(Path("data/validation").glob("*/*"))
validation_labels = [x.parent.stem for x in validation_files]

In [9]:
validation_predictions = []

with torch.no_grad(), torch.cuda.amp.autocast():
    
    for image in tqdm(validation_files):
            
        image = preprocess(
                    Image.open(image).convert("RGB")
                ).unsqueeze(0)
        
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1).squeeze().tolist()
        
        validation_predictions.append(sorted(list(zip(text_probs, classes)), reverse=True))


        

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 341/341 [03:16<00:00,  1.74it/s]


In [None]:
validation_top_prediction = [x[0][1] for x in validation_predictions]

In [14]:
validation_predictions[0]

[(0.9992687106132507, 'shoes'),
 (0.0005700313486158848, 'outwear'),
 (0.000109674314444419, 'shorts'),
 (3.829190609394573e-05, 'hat'),
 (1.1568498848646414e-05, 'pants'),
 (9.062980552698718e-07, 'dress'),
 (8.301198590743297e-07, 'skirt'),
 (2.8311593069929586e-08, 'longsleeve'),
 (1.5165813493922542e-08, 'shirt'),
 (1.214558320583592e-08, 't-shirt')]

In [11]:
accuracy = sum([int(y == y_hat) for y, y_hat in zip(validation_labels, validation_top_prediction)]) / len(validation_labels)
print(f"Validation zero shot accuracy = {accuracy}")

Validation zero shot accuracy = 0.9002932551319648


In [12]:
# 84% + label_update -> 90% + larger_model -> 
# Supervised xception validation accuracy -> https://github.com/alexeygrigorev/mlbookcamp-code/blob/master/chapter-07-neural-nets/07-neural-nets-test.ipynb 