In [1]:
from PIL import Image
import requests
from transformers import CLIPProcessor, CLIPModel
import torch
import pandas as pd
from src import util
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import os
from peft import LoraConfig, get_peft_model
import wandb

In [2]:
# set random seed 42
torch.manual_seed(42)

<torch._C.Generator at 0x7ff2491c9030>

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", 
                                  cache_dir="model", local_files_only=True)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", 
                                          cache_dir="model", local_files_only=True)
model = model.to(device)



In [4]:
data_dir = "data"
# articles_filtered: remove articles with no images
articles = pd.read_csv(f"{data_dir}/articles_filtered.csv")

In [5]:
# map from article_id to df index
article_id_to_idx = {article_id: idx for idx, article_id in enumerate(articles["article_id"])}

# get all classes of the dataframe
class_names = articles.columns.tolist()
label_names = dict()
label_names_to_idx = dict()
for class_name in class_names:
    label_names[class_name] = articles[class_name].unique()
    label_names_to_idx[class_name] = {label_name: idx for idx, label_name in enumerate(label_names[class_name])}

article_ids = label_names["article_id"]
# selected_class_names = ["product_group_name", "product_type_name", "graphical_appearance_name", "colour_group_name", "perceived_colour_value_name", "perceived_colour_master_name", "department_name", "index_name", "index_group_name", "section_name", "garment_group_name"]
selected_class_names = ["product_type_name", "graphical_appearance_name"]

In [7]:
# grouped by product_code
grouped = articles.groupby("product_code")
groups = [group for _, group in grouped]

# split 0.6/0.2/0.2
train_groups, test_groups = train_test_split(groups, test_size=0.4, random_state=42) 
val_groups, test_groups = train_test_split(test_groups, test_size=0.5, random_state=42) 

train_df = pd.concat(train_groups)
val_df = pd.concat(val_groups)
test_df = pd.concat(test_groups)

print(f"{len(train_df)=} {len(val_df)=} {len(test_df)=}")

len(train_df)=63272 len(val_df)=20882 len(test_df)=20946


In [8]:
train_paths, train_labels = util.get_image_paths_and_labels_from_df(train_df, data_dir)
val_paths, val_labels = util.get_image_paths_and_labels_from_df(val_df, data_dir)
test_paths, test_labels = util.get_image_paths_and_labels_from_df(test_df, data_dir)

In [9]:
class MultiOutputLayer(torch.nn.Module):
    def __init__(self, input_size, inter_size, output_size):
        super(MultiOutputLayer, self).__init__()
        self.fc1 = torch.nn.Linear(input_size, inter_size)
        self.fc2 = torch.nn.Linear(inter_size, output_size)
        self.dropout = torch.nn.Dropout(0.5)
        self.act = torch.nn.SiLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [10]:
class MultiOutputClipModel(torch.nn.Module):
    def __init__(self, clip_model, class_names, vision_hidden_size, inter_size, output_size):
        super(MultiOutputClipModel, self).__init__()
        self.clip_model = clip_model
        self.class_names = class_names
        self.output_layers = torch.nn.ModuleDict({
            class_name: MultiOutputLayer(vision_hidden_size, inter_size, output_size)
            for class_name in self.class_names
        })
    
    def forward(
        self,
        text_input_dict,
        pixel_values,
        # position_ids = None,
        output_attentions = None,
        output_hidden_states = None,
        return_dict = None,
    ):

        output_attentions = output_attentions if output_attentions is not None else self.clip_model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.clip_model.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.clip_model.config.use_return_dict

        vision_outputs = self.clip_model.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        vision_embeds = vision_outputs[1]
        vision_embeds_dict = {
            class_name: output_layer(vision_embeds) 
                for class_name, output_layer in self.output_layers.items()
        }

        text_outputs_dict = {
            class_name: self.clip_model.text_model(
                input_ids=text_input_dict[class_name]["input_ids"],
                attention_mask=text_input_dict[class_name]["attention_mask"],
                # position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            ) for class_name in self.class_names
        }

        text_embeds_dict = {
            class_name: self.clip_model.text_projection(text_outputs[1])
                for class_name, text_outputs in text_outputs_dict.items()
        }

        logits_per_image_dict = {
            class_name: vision_embeds_dict[class_name] @ text_embeds_dict[class_name].T
                for class_name in self.output_layers.keys()
        }

        return logits_per_image_dict

In [11]:
# custom criterion: cross entropy loss across all classes
class MultiOutputClipCriterion(torch.nn.Module):
    def __init__(self, class_names):
        super(MultiOutputClipCriterion, self).__init__()
        self.class_names = class_names
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, logits_dict, labels_dict):
        loss = 0
        for class_name in self.class_names:
            logits = logits_dict[class_name]
            labels = labels_dict[class_name]
            loss += self.criterion(logits, labels)
        return loss

In [12]:
train_dataset = util.ImageDataset(train_paths, processor)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)


In [14]:
# freeze all parameters in model

for param in model.parameters():
    param.requires_grad = False

In [15]:
mo_model = MultiOutputClipModel(model, selected_class_names, 768, 32, 512).to(device)
mo_model.train()

MultiOutputClipModel(
  (clip_model): CLIPModel(
    (text_model): CLIPTextTransformer(
      (embeddings): CLIPTextEmbeddings(
        (token_embedding): Embedding(49408, 512)
        (position_embedding): Embedding(77, 512)
      )
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-11): 12 x CLIPEncoderLayer(
            (self_attn): CLIPSdpaAttention(
              (k_proj): Linear(in_features=512, out_features=512, bias=True)
              (v_proj): Linear(in_features=512, out_features=512, bias=True)
              (q_proj): Linear(in_features=512, out_features=512, bias=True)
              (out_proj): Linear(in_features=512, out_features=512, bias=True)
            )
            (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): QuickGELUActivation()
              (fc1): Linear(in_features=512, out_features=2048, bias=True)
              (fc2): Linear(in_features=2048, out_features=5

In [16]:
# show all trainable parameters in mo_model
for name, param in mo_model.named_parameters():
    if param.requires_grad:
        print(name)

output_layers.product_type_name.fc1.weight
output_layers.product_type_name.fc1.bias
output_layers.product_type_name.fc2.weight
output_layers.product_type_name.fc2.bias
output_layers.graphical_appearance_name.fc1.weight
output_layers.graphical_appearance_name.fc1.bias
output_layers.graphical_appearance_name.fc2.weight
output_layers.graphical_appearance_name.fc2.bias


In [17]:
# generate text input
text_input_dict = {
    class_name: processor(text=[f"A photo of a {label}" for label in label_names[class_name]], 
                          return_tensors="pt", padding=True).to(device)
    for class_name in selected_class_names
}

In [18]:
num_epochs = 1  # Adjust as needed
criteria = MultiOutputClipCriterion(class_names=selected_class_names)
optimizer = torch.optim.AdamW(mo_model.parameters(), lr=1e-4)
step = 0
wandb.init(project="clip-multi-output")

for epoch in range(num_epochs):
    mo_model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for images, image_ids in tqdm(train_dataloader):
        images = images.to(device)
        logits_per_image_dict = mo_model(pixel_values=images, text_input_dict=text_input_dict)

        # Get true labels from image_ids
        true_labels_dict = {
            class_name: [label_names_to_idx[class_name][articles.loc[article_id_to_idx[image_id.item()], class_name]] 
                       for image_id in image_ids]
            for class_name in selected_class_names
        }
        true_labels_dict = {class_name: torch.tensor(true_labels).to(device) 
                            for class_name, true_labels in true_labels_dict.items()}
        
        # Compute loss
        loss = criteria(logits_per_image_dict, true_labels_dict)
        total_loss += loss.item() * images.size(0)

        # Predictions and accuracy
        correct = 0
        total_samples += images.size(0)
        for class_name in selected_class_names:
            _, preds = torch.max(logits_per_image_dict[class_name], dim=1)
            correct += (preds == true_labels_dict[class_name]).sum().item()
        total_correct += correct

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # log the loss and accuracy to wandb
        wandb.log({"loss": loss.item(), "accuracy": correct / images.size(0) / len(selected_class_names)},
                  step=step)
        step += 1

    avg_loss = total_loss / total_samples / len(selected_class_names)
    accuracy = total_correct / total_samples / len(selected_class_names)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

    # Validate after each epoch
    # val_loss, val_accuracy = validate(model, val_dataloader, criteria, device, text_inputs, class_name)
    # print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

wandb.finish()

# Save the model
torch.save(mo_model.state_dict(), "model/2_output_clip_model-2.pth")

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33monjackay[0m ([33monjackay-kth-royal-institute-of-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 495/495 [35:22<00:00,  4.29s/it]


Epoch [1/1], Loss: 2.1472, Accuracy: 0.4863


VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁▃▄▄▅▅▅▅▆▆▆▆▇▇▆▆▆▇▇▇▇█▇▇▇▇▇▇▇█▇▇█▇██▇▇▇▇
loss,███▇▆▅▅▅▄▄▃▃▃▃▃▃▂▂▂▃▂▃▂▂▂▂▂▂▂▁▂▂▂▁▁▁▂▁▁▁

0,1
accuracy,0.5625
loss,3.68765


In [22]:
def validate(model, dataloader, criteria, device, text_inputs, class_names):
    model.eval()
    total_loss = 0.0
    total_correct = {class_name: 0 for class_name in class_names}
    total_samples = 0

    with torch.no_grad():
        for images, image_ids in tqdm(dataloader):
            images = images.to(device)
            logits_per_image_dict = model(pixel_values=images, text_input_dict=text_inputs)

            # Get true labels from image_ids
            true_labels_dict = {
                class_name: [label_names_to_idx[class_name][articles.loc[article_id_to_idx[image_id.item()], class_name]] 
                           for image_id in image_ids]
                for class_name in class_names
            }
            true_labels_dict = {class_name: torch.tensor(true_labels).to(device)
                                for class_name, true_labels in true_labels_dict.items()}
            
            # Compute loss
            loss = criteria(logits_per_image_dict, true_labels_dict)
            total_loss += loss.item() * images.size(0)

            # Predictions and accuracy
            total_samples += images.size(0)
            for class_name in class_names:
                _, preds = torch.max(logits_per_image_dict[class_name], dim=1)
                total_correct[class_name] += (preds == true_labels_dict[class_name]).sum().item()

    avg_loss = total_loss / total_samples / len(class_names)
    accuracy = {class_name: total_correct[class_name] / total_samples for class_name in class_names}
    return avg_loss, accuracy

In [24]:
val_dataset = util.ImageDataset(val_paths, processor)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=1024, shuffle=False)

test_dataset = util.ImageDataset(test_paths, processor)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1024, shuffle=False)

In [25]:
avg_loss, accuracy = validate(mo_model, val_dataloader, criteria, device, text_input_dict, selected_class_names)

100%|██████████| 21/21 [09:21<00:00, 26.76s/it]


In [26]:
print(avg_loss)
print(accuracy)

1.3311598339182198
{'product_type_name': 0.6103342591705775, 'graphical_appearance_name': 0.7141557322095585}
