In [1]:
from PIL import Image
import requests
from transformers import CLIPProcessor, CLIPModel
import torch
import pandas as pd
from src import util
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pretrained CLIP checkpoint plus its matching processor, loaded strictly from
# the local "model" cache (local_files_only=True: no download is attempted).
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
clip_load_kwargs = dict(cache_dir="model", local_files_only=True)

model = CLIPModel.from_pretrained(CLIP_CHECKPOINT, **clip_load_kwargs)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT, **clip_load_kwargs)
model = model.to(device)



In [3]:
# Total parameter count as reported by the model itself.
n_params = model.num_parameters()
print(f"Number of parameters: {n_params}")

Number of parameters: 151277313


In [4]:
# Root folder for all dataset files (metadata CSVs and images).
data_dir = "data"

# Only the article metadata is used in this notebook; customers.csv and
# transactions_train.csv are intentionally not loaded.
articles_path = f"{data_dir}/articles.csv"
articles = pd.read_csv(articles_path)

In [5]:
# Peek at the first five metadata rows.
articles.head(n=5)

Unnamed: 0,article_id,product_code,prod_name,product_type_no,product_type_name,product_group_name,graphical_appearance_no,graphical_appearance_name,colour_group_code,colour_group_name,...,department_name,index_code,index_name,index_group_no,index_group_name,section_no,section_name,garment_group_no,garment_group_name,detail_desc
0,108775015,108775,Strap top,253,Vest top,Garment Upper body,1010016,Solid,9,Black,...,Jersey Basic,A,Ladieswear,1,Ladieswear,16,Womens Everyday Basics,1002,Jersey Basic,Jersey top with narrow shoulder straps.
1,108775044,108775,Strap top,253,Vest top,Garment Upper body,1010016,Solid,10,White,...,Jersey Basic,A,Ladieswear,1,Ladieswear,16,Womens Everyday Basics,1002,Jersey Basic,Jersey top with narrow shoulder straps.
2,108775051,108775,Strap top (1),253,Vest top,Garment Upper body,1010017,Stripe,11,Off White,...,Jersey Basic,A,Ladieswear,1,Ladieswear,16,Womens Everyday Basics,1002,Jersey Basic,Jersey top with narrow shoulder straps.
3,110065001,110065,OP T-shirt (Idro),306,Bra,Underwear,1010016,Solid,9,Black,...,Clean Lingerie,B,Lingeries/Tights,1,Ladieswear,61,Womens Lingerie,1017,"Under-, Nightwear","Microfibre T-shirt bra with underwired, moulde..."
4,110065002,110065,OP T-shirt (Idro),306,Bra,Underwear,1010016,Solid,10,White,...,Clean Lingerie,B,Lingeries/Tights,1,Ladieswear,61,Womens Lingerie,1017,"Under-, Nightwear","Microfibre T-shirt bra with underwired, moulde..."


In [6]:
# Positional row number of each article, keyed by article_id, so an image id can
# be mapped back to its metadata row.
article_id_to_idx = dict(zip(articles["article_id"], range(len(articles))))

# Distinct values of every column, then a per-column cardinality summary.
class_names = articles.columns.tolist()
label_names = {column: articles[column].unique() for column in class_names}
for column, distinct_values in label_names.items():
    print(f"{column}: {len(distinct_values)}")

# Unique article ids drive which images the dataset loads.
article_ids = label_names["article_id"]

article_id: 105542
product_code: 47224
prod_name: 45875
product_type_no: 132
product_type_name: 131
product_group_name: 19
graphical_appearance_no: 30
graphical_appearance_name: 30
colour_group_code: 50
colour_group_name: 50
perceived_colour_value_id: 8
perceived_colour_value_name: 8
perceived_colour_master_id: 20
perceived_colour_master_name: 20
department_no: 299
department_name: 250
index_code: 10
index_name: 10
index_group_no: 5
index_group_name: 5
section_no: 57
section_name: 56
garment_group_no: 21
garment_group_name: 21
detail_desc: 43405


In [7]:
# Image dataset over every unique article id. Use the configured data_dir rather
# than a hard-coded "data" so images and articles.csv always come from the same place.
dataset = util.ImageDataset(data_dir=data_dir, article_ids=article_ids, processor=processor)

# shuffle=False makes batch order deterministic; presumably it follows
# article_ids order (depends on util.ImageDataset indexing — TODO confirm).
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=False)

In [8]:
# Zero-shot classification of the first batch only.
images, image_ids = next(iter(dataloader))

# One text prompt per distinct product_group_name.
product_group_names = label_names["product_group_name"]
prompts = [f"A photo of a {label}" for label in product_group_names]
text_inputs = processor(text=prompts, return_tensors="pt", padding=True)

text_inputs = text_inputs.to(device)
images = images.to(device)

with torch.no_grad():
    outputs = model(**text_inputs, pixel_values=images)

# dim=1 indexes the prompts (indices along it are looked up in
# product_group_names below), so softmax gives a per-image distribution
# over product groups.
logits_per_image = outputs.logits_per_image
probs = logits_per_image.softmax(dim=1).to("cpu")

# Keep the 5 most likely product groups for every image.
values, indices = torch.topk(probs, k=5, dim=1)

In [9]:
# Score the evaluated batch. An image counts toward top-5 accuracy when its true
# product group is among its top-5 predictions, and toward top-1 accuracy when
# that group is ranked first. Counters are reset here so the cell is re-runnable.
top1_correct = 0
top5_correct = 0

for i in range(len(images)):
    # article_id_to_idx stores *positional* row numbers, so use .iloc.
    # .loc only agrees while `articles` keeps its default RangeIndex.
    true_label = articles.iloc[article_id_to_idx[image_ids[i].item()]]["product_group_name"]
    print(f"Image id: {image_ids[i].item()} \t True label: {true_label}")
    for rank, j in enumerate(indices[i]):
        print(f"{product_group_names[j]}: {probs[i][j]:.4f}")
        if product_group_names[j] == true_label:
            top5_correct += 1
            # topk returns indices in descending-probability order, so rank 0
            # is the top-1 prediction.
            if rank == 0:
                top1_correct += 1
    print()

Image id: 108775015 	 True label: Garment Upper body
Garment Upper body: 0.4556
Nightwear: 0.1890
Garment Lower body: 0.1040
Garment Full body: 0.0965
Underwear/nightwear: 0.0912

Image id: 108775044 	 True label: Garment Upper body
Garment Upper body: 0.5538
Nightwear: 0.1782
Garment Full body: 0.1276
Underwear/nightwear: 0.0744
Garment Lower body: 0.0413

Image id: 108775051 	 True label: Garment Upper body
Garment Upper body: 0.6450
Nightwear: 0.1477
Underwear/nightwear: 0.0748
Garment Full body: 0.0575
Garment Lower body: 0.0457

Image id: 110065001 	 True label: Underwear
Underwear: 0.2966
Garment Upper body: 0.2616
Underwear/nightwear: 0.2394
Swimwear: 0.0763
Nightwear: 0.0463

Image id: 110065002 	 True label: Underwear
Underwear: 0.3397
Garment Upper body: 0.3045
Underwear/nightwear: 0.1608
Swimwear: 0.0546
Cosmetic: 0.0497

Image id: 110065011 	 True label: Underwear
Garment Upper body: 0.3652
Underwear: 0.2047
Underwear/nightwear: 0.1877
Swimwear: 0.0961
Cosmetic: 0.0426

Ima

In [10]:
# Accuracy over the single evaluated batch only, not the full article catalogue.
n_scored = len(images)
for metric, n_correct in (("Top 1", top1_correct), ("Top 5", top5_correct)):
    print(f"{metric} accuracy: {n_correct / n_scored}")

Top 1 accuracy: 0.609375
Top 5 accuracy: 0.9375
