In [None]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn.functional as F
import wandb
from peft import LoraConfig, get_peft_model, PromptTuningConfig, PeftType
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel
# from src import util

torch.manual_seed(42)

In [None]:
def get_image_paths_and_labels_from_df(df, data_dir):
    """Collect the image files that exist on disk together with their metadata rows.

    For every value in ``df["article_id"]`` the expected image location is
    ``{data_dir}/images/0{first two characters of the id}/0{id}.jpg``.
    Articles whose file is missing are skipped.

    Args:
        df: DataFrame with an ``article_id`` column.
        data_dir: Dataset root directory containing the ``images`` folder.

    Returns:
        A pair ``(image_paths, labels)`` of equally long lists. ``image_paths``
        holds the paths that exist; ``labels[i]`` is the sub-DataFrame of ``df``
        whose ``article_id`` equals the id behind ``image_paths[i]``.
    """
    article_ids = df["article_id"].values

    # Row positions per article_id, computed once. Filtering the whole frame
    # with a boolean mask for every image made this function quadratic.
    rows_by_article = df.groupby("article_id", sort=False).indices

    image_paths = []
    labels = []
    for article_id in article_ids:
        image_path = f"{data_dir}/images/0{str(article_id)[:2]}/0{article_id}.jpg"
        if not os.path.exists(image_path):
            continue

        image_paths.append(image_path)
        positions = rows_by_article.get(article_id)
        if positions is None:
            # groupby drops missing keys; fall back to the mask for such ids.
            labels.append(df[df["article_id"] == article_id])
        else:
            labels.append(df.iloc[positions])

    return image_paths, labels

class ImageDataset(torch.utils.data.Dataset):
    """Dataset of images addressed by file path.

    Indexing returns ``(image, image_id)``. ``image_id`` is the integer parsed
    from the file name (the text before its first ``.``). ``image`` is the first
    entry of the processor's ``pixel_values`` when a processor was given,
    otherwise an in-memory copy of the PIL image.

    Args:
        image_paths: Sequence of image file paths; every path must exist.
        processor: Optional callable invoked as
            ``processor(images=..., return_tensors="pt", padding=True)`` whose
            result exposes ``["pixel_values"]``.

    Raises:
        FileNotFoundError: If any path in ``image_paths`` does not exist.
    """

    def __init__(self, image_paths, processor=None):
        self.image_paths = image_paths
        self.processor = processor
        self.image_ids = []

        for image_path in self.image_paths:
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image {image_path} not found.")
            # basename() honours the platform's path separator instead of
            # assuming "/" when isolating the file name.
            file_name = os.path.basename(image_path)
            self.image_ids.append(int(file_name.split(".")[0]))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # The context manager closes the underlying file as soon as the pixels
        # have been consumed, rather than leaving the handle to the garbage
        # collector.
        with Image.open(self.image_paths[idx]) as image_file:
            if self.processor is not None:
                inputs = self.processor(images=image_file, return_tensors="pt", padding=True)
                image = inputs["pixel_values"][0]
            else:
                # copy() loads the pixel data so the image stays usable after
                # the file is closed.
                image = image_file.copy()
        return image, self.image_ids[idx]

In [None]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# CLIP weights, moved to the selected device; files are cached in ./model.
model = CLIPModel.from_pretrained(
    "openai/clip-vit-base-patch32", cache_dir="model", local_files_only=False
).to(device)

# Matching processor for the same checkpoint.
processor = CLIPProcessor.from_pretrained(
    "openai/clip-vit-base-patch32", cache_dir="model", local_files_only=False
)

In [None]:
# data_dir = "data"
# data_dir = "/kaggle/input/h-and-m-personalized-fashion-recommendations"
# Dataset root. Set the HM_DATA_DIR environment variable to use another copy
# of the data; without it the original local path below is used.
data_dir = os.environ.get(
    "HM_DATA_DIR",
    "/Users/tobiaspeihengli/Downloads/DD2430/h-and-m-personalized-fashion-recommendations",
)
# Article metadata table. No filtering happens here; articles without an image
# are skipped later, when image paths are resolved.
articles = pd.read_csv(f"{data_dir}/articles.csv")

In [None]:
# article_id -> position of that article's row in `articles`
article_id_to_idx = dict(zip(articles["article_id"], range(len(articles))))

# Every column of the dataframe is treated as a potential class. For each one
# keep its distinct values and a value -> index lookup over those values.
class_names = list(articles.columns)
label_names = {}
label_names_to_idx = {}
for class_name in class_names:
    label_names[class_name] = articles[class_name].unique()
    label_names_to_idx[class_name] = dict(
        zip(label_names[class_name], range(len(label_names[class_name])))
    )

article_ids = label_names["article_id"]
selected_class_names = [
    "product_group_name",
    "product_type_name",
    "graphical_appearance_name",
    "colour_group_name",
    "perceived_colour_value_name",
    "perceived_colour_master_name",
    "department_name",
    "index_name",
    "index_group_name",
    "section_name",
    "garment_group_name",
]
# selected_class_names = ["product_type_name", "graphical_appearance_name"]

In [None]:
# One DataFrame per product_code: whole groups are split, so all rows sharing
# a product_code end up in the same subset.
grouped = articles.groupby("product_code")
groups = [product_rows for _product_code, product_rows in grouped]

# 80% of the groups for training; the remaining 20% is halved into
# validation and test (0.8 / 0.1 / 0.1 of the groups).
train_groups, holdout_groups = train_test_split(groups, test_size=0.2, random_state=42)
val_groups, test_groups = train_test_split(holdout_groups, test_size=0.5, random_state=42)

train_df, val_df, test_df = (
    pd.concat(split_groups) for split_groups in (train_groups, val_groups, test_groups)
)

# 256 for local test
# train_df = train_df.sample(256)
# val_df = val_df.sample(256)
# test_df = test_df.sample(256)

print(f"{len(train_df)=} {len(val_df)=} {len(test_df)=}")

In [None]:
# train_paths, train_labels = get_image_paths_and_labels_from_df(train_df, data_dir)
# val_paths, val_labels = get_image_paths_and_labels_from_df(val_df, data_dir)
# test_paths, test_labels = get_image_paths_and_labels_from_df(test_df, data_dir)

In [None]:
# os.environ["TOKENIZERS_PARALLELISM"] = "false"

# train_dataset = ImageDataset(train_paths, processor)
# train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True) # Don't forget to change the batch size to 1

# val_dataset = ImageDataset(val_paths, processor)
# val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=256, shuffle=False)

# test_dataset = ImageDataset(test_paths, processor)
# test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False)

In [None]:
class MultiOutputLayer(torch.nn.Module):
    """Small projection head: Linear -> SiLU -> Dropout(0.5) -> Linear.

    Args:
        input_size: Width of the incoming features.
        inter_size: Width of the hidden layer.
        output_size: Width of the produced features.
    """

    def __init__(self, input_size, inter_size, output_size):
        super().__init__()
        self.fc1 = torch.nn.Linear(in_features=input_size, out_features=inter_size)
        self.fc2 = torch.nn.Linear(in_features=inter_size, out_features=output_size)
        self.dropout = torch.nn.Dropout(p=0.5)
        self.act = torch.nn.SiLU()

    def forward(self, x):
        """Project ``x`` (last dimension ``input_size``) to ``output_size`` features."""
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.fc2(hidden)

In [None]:
# class MultiOutputClipModel(torch.nn.Module):
#     def __init__(self, clip_model, class_names, vision_hidden_size, inter_size, output_size):
#         super(MultiOutputClipModel, self).__init__()
#         self.clip_model = clip_model
#         self.class_names = class_names
#         self.output_layers = torch.nn.ModuleDict({
#             class_name: MultiOutputLayer(vision_hidden_size, inter_size, output_size)
#             for class_name in self.class_names
#         })
    
#     def forward(
#         self,
#         text_input_dict,
#         pixel_values,
#         # position_ids = None,
#         output_attentions = None,
#         output_hidden_states = None,
#         return_dict = None,
#     ):

#         output_attentions = output_attentions if output_attentions is not None else self.clip_model.config.output_attentions
#         output_hidden_states = (
#             output_hidden_states if output_hidden_states is not None else self.clip_model.config.output_hidden_states
#         )
#         return_dict = return_dict if return_dict is not None else self.clip_model.config.use_return_dict

#         vision_outputs = self.clip_model.vision_model(
#             pixel_values=pixel_values,
#             output_attentions=output_attentions,
#             output_hidden_states=output_hidden_states,
#             return_dict=return_dict,
#         )

#         vision_embeds = vision_outputs[1]
#         vision_embeds_dict = {
#             class_name: output_layer(vision_embeds) 
#                 for class_name, output_layer in self.output_layers.items()
#         }

#         text_outputs_dict = {
#             class_name: self.clip_model.text_model(
#                 input_ids=text_input_dict[class_name]["input_ids"],
#                 attention_mask=text_input_dict[class_name]["attention_mask"],
#                 # position_ids=position_ids,
#                 output_attentions=output_attentions,
#                 output_hidden_states=output_hidden_states,
#                 return_dict=return_dict,
#             ) for class_name in self.class_names
#         }

#         text_embeds_dict = {
#             class_name: self.clip_model.text_projection(text_outputs[1])
#                 for class_name, text_outputs in text_outputs_dict.items()
#         }

#         logits_per_image_dict = {
#             class_name: vision_embeds_dict[class_name] @ text_embeds_dict[class_name].T
#                 for class_name in self.output_layers.keys()
#         }

#         return logits_per_image_dict

In [None]:
class MultiOutputClipModel(torch.nn.Module):
    """CLIP wrapper with one image head and one learnable soft prompt per class name.

    For every class name the model scores images against that class's label
    texts: the image side passes the vision model's output through a dedicated
    ``MultiOutputLayer``; the text side prepends ``num_virtual_tokens``
    trainable embeddings to the label tokens before running CLIP's text
    encoder and text projection.

    Args:
        clip_model: Model exposing ``vision_model``, ``text_model``,
            ``text_projection`` and ``config``.
        class_names: Names of the attributes to predict; they are the keys of
            every dict handled by this module.
        vision_hidden_size: Width of the vision feature fed to each head.
        inter_size: Hidden width of each head.
        output_size: Output width of each head; it must match the width
            produced by ``text_projection`` for the final matrix product.
        num_virtual_tokens: Number of soft-prompt vectors per class name.
    """

    def __init__(self, clip_model, class_names, vision_hidden_size, inter_size, output_size, num_virtual_tokens):
        super(MultiOutputClipModel, self).__init__()
        self.clip_model = clip_model
        self.class_names = class_names
        self.output_layers = torch.nn.ModuleDict({
            class_name: MultiOutputLayer(vision_hidden_size, inter_size, output_size)
            for class_name in self.class_names
        })

        # Soft prompt embeddings per class, sized like the text token embeddings
        self.num_virtual_tokens = num_virtual_tokens
        embedding_dim = self.clip_model.text_model.embeddings.token_embedding.embedding_dim
        self.soft_prompt_embeddings = torch.nn.ParameterDict({
            class_name: torch.nn.Parameter(torch.randn(num_virtual_tokens, embedding_dim))
            for class_name in self.class_names
        })

    def _encode_text(self, class_name, text_inputs, output_attentions, output_hidden_states, return_dict):
        """Embed one class's label texts with that class's soft prompt prepended.

        Returns a tensor with one projected text embedding per label text.

        Raises:
            ValueError: If soft prompt plus tokens exceed the number of
                positions the text model has embeddings for.
        """
        text_model = self.clip_model.text_model
        embeddings = text_model.embeddings
        input_ids = text_inputs["input_ids"]
        padding_mask = text_inputs["attention_mask"]

        token_embeds = embeddings.token_embedding(input_ids)
        num_texts = token_embeds.shape[0]

        # Same soft prompt in front of every label text of this class.
        soft_prompt = self.soft_prompt_embeddings[class_name].unsqueeze(0).expand(num_texts, -1, -1)
        hidden_states = torch.cat([soft_prompt, token_embeds], dim=1)
        seq_len = hidden_states.shape[1]

        # The encoder is called directly, so nothing else adds positional
        # information: add the text model's position embeddings here for the
        # whole (prompt + tokens) sequence.
        max_positions = embeddings.position_embedding.num_embeddings
        if seq_len > max_positions:
            raise ValueError(
                f"Soft prompt plus text tokens span {seq_len} positions but the text model "
                f"only has {max_positions} position embeddings."
            )
        position_ids = torch.arange(seq_len, device=hidden_states.device)
        hidden_states = hidden_states + embeddings.position_embedding(position_ids)

        # 1 = real position (soft prompt or token), 0 = padding.
        prompt_mask = torch.ones(
            num_texts, self.num_virtual_tokens, dtype=padding_mask.dtype, device=padding_mask.device
        )
        keep_mask = torch.cat([prompt_mask, padding_mask], dim=1)
        # Index of the last non-padding position; assumes right-padded inputs
        # whose final real token is the end-of-text token — TODO confirm.
        valid_token_indices = (keep_mask.sum(dim=-1) - 1).long()

        # The encoder adds its mask to the attention scores (see the
        # transformers CLIP modeling code), so blocked entries must be large
        # negative numbers, not zeros. A query may look neither at padding
        # keys nor at later positions (causal attention, which the CLIP text
        # tower normally applies).
        is_padding = (keep_mask == 0)[:, None, None, :].to(hidden_states.device)
        is_future = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=hidden_states.device
        ).triu(diagonal=1)
        blocked = is_padding | is_future  # (num_texts, 1, seq_len, seq_len)
        attention_mask = torch.zeros(
            blocked.shape, dtype=hidden_states.dtype, device=hidden_states.device
        ).masked_fill(blocked, torch.finfo(hidden_states.dtype).min)

        encoder_outputs = text_model.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = encoder_outputs[0]

        # Pool at the last real token of every text, then normalise and project.
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            valid_token_indices.to(last_hidden_state.device),
        ]
        pooled_output = text_model.final_layer_norm(pooled_output)
        return self.clip_model.text_projection(pooled_output)

    def forward(
        self,
        text_input_dict,
        pixel_values,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Score a batch of images against the label texts of every class name.

        Args:
            text_input_dict: Maps each class name to a mapping with
                ``"input_ids"`` and ``"attention_mask"`` for that class's
                label texts.
            pixel_values: Image batch passed to the vision model.
            output_attentions, output_hidden_states, return_dict: Forwarded to
                the vision model and text encoder; ``None`` falls back to the
                CLIP config.

        Returns:
            Dict mapping each class name to a logits tensor with one row per
            image and one column per label text.
        """
        output_attentions = output_attentions if output_attentions is not None else self.clip_model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.clip_model.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.clip_model.config.use_return_dict

        # Image side: shared vision model, then one head per class name.
        vision_outputs = self.clip_model.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Second output element; presumably the pooled image feature — confirm
        # against the transformers version in use.
        vision_embeds = vision_outputs[1]
        vision_embeds_dict = {
            class_name: output_layer(vision_embeds)
            for class_name, output_layer in self.output_layers.items()
        }

        # Text side: label texts with the class's soft prompt prepended.
        text_embeds_dict = {
            class_name: self._encode_text(
                class_name,
                text_input_dict[class_name],
                output_attentions,
                output_hidden_states,
                return_dict,
            )
            for class_name in self.class_names
        }

        # Image/text similarity per class name.
        logits_per_image_dict = {
            class_name: vision_embeds_dict[class_name] @ text_embeds_dict[class_name].T
            for class_name in self.output_layers.keys()
        }

        return logits_per_image_dict


In [None]:
# custom criterion: cross entropy loss across all classes
class MultiOutputClipCriterion(torch.nn.Module):
    """Sum of cross-entropy losses, one term per class name.

    Args:
        class_names: Keys to read from the logits and labels dicts.
    """

    def __init__(self, class_names):
        super().__init__()
        self.class_names = class_names
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, logits_dict, labels_dict):
        """Return the cross-entropy summed over ``self.class_names``.

        Both dicts are indexed by class name. With no class names the result
        is the plain integer ``0``.
        """
        return sum(
            self.criterion(logits_dict[name], labels_dict[name])
            for name in self.class_names
        )

In [None]:
# freeze all parameters in model

# for param in model.parameters():
#     param.requires_grad = False

# Define LoRA configuration
lora_config = LoraConfig(
    r=8,                  # Low-rank dimension (adjustable)
    lora_alpha=32,          # Scaling factor (adjustable)
    target_modules=["q_proj", "v_proj", "k_proj"],  # Names of the layers LoRA is applied to
    lora_dropout=0.05,       # Dropout rate (optional)
    bias="none",            # Whether to include biases ("none", "all", "lora_only")
    # task_type is intentionally left unset. "classification" is not one of
    # PEFT's TaskType values (SEQ_CLS, CAUSAL_LM, FEATURE_EXTRACTION, ...).
    # Without a task type get_peft_model returns the generic PeftModel wrapper.
    # NOTE(review): recent PEFT releases are expected to reject unknown task
    # types — confirm against the installed version.
)

# Apply LoRA to the CLIP model
# NOTE(review): running this cell twice wraps an already wrapped model again;
# reload `model` before re-running it.
model = get_peft_model(model, lora_config)


In [None]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def get_few_shot_train_df(articles, class_name, n_few_shot):
    """Draw up to ``n_few_shot`` rows for every distinct value of a column.

    Args:
        articles: DataFrame to sample from.
        class_name: Column whose distinct values define the labels.
        n_few_shot: Maximum number of rows taken per label; labels with fewer
            rows contribute all of them.

    Returns:
        The per-label samples concatenated in order of first appearance of
        each label. Sampling uses ``random_state=42`` for every label.
    """
    column = articles[class_name]
    rows_per_label = (articles[column == label] for label in column.unique())
    return pd.concat([
        rows.sample(n=min(n_few_shot, len(rows)), random_state=42)
        for rows in rows_per_label
    ])

def get_class_df(df, class_name):
    """Return the rows of ``df`` whose ``class_name`` value is not missing."""
    has_label = ~df[class_name].isna()
    return df.loc[has_label]

# Training length and soft-prompt size; both can be tuned.
num_epochs, num_virtual_tokens = 10, 16

# Few-shot budgets to sweep, in samples per label: [1, 2, 4, 8, 16]
n_few_shot_list = [1 << exponent for exponent in range(5)]

def _true_label_indices(image_ids, class_name, device):
    """Return a tensor on ``device`` with the label index of every image id in a batch.

    Each id is resolved to its row in ``articles`` via ``article_id_to_idx``;
    the value in column ``class_name`` is then mapped through
    ``label_names_to_idx_class`` (the table of the class currently being run).
    """
    label_to_idx = label_names_to_idx_class[class_name]
    indices = [
        label_to_idx[articles.loc[article_id_to_idx[image_id.item()], class_name]]
        for image_id in image_ids
    ]
    return torch.tensor(indices).to(device)


def validate(model, dataloader, criteria, device, text_inputs, class_names):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: Module called as ``model(pixel_values=..., text_input_dict=...)``
            that returns a dict of logits per class name.
        dataloader: Yields ``(images, image_ids)`` batches.
        criteria: Loss called as ``criteria(logits_dict, labels_dict)``.
        device: Device the batches and labels are moved to.
        text_inputs: Tokenized label texts per class name.
        class_names: Class names to evaluate.

    Returns:
        ``(avg_loss, accuracy)``: the sample-weighted mean loss and a dict of
        accuracy per class name. The model is left in eval mode.
    """
    model.eval()
    total_loss = 0.0
    total_correct = {class_name: 0 for class_name in class_names}
    total_samples = 0

    with torch.no_grad():
        for images, image_ids in tqdm(dataloader):
            images = images.to(device)
            logits_per_image_dict = model(pixel_values=images, text_input_dict=text_inputs)

            # Get true labels from image_ids
            true_labels_dict = {
                class_name: _true_label_indices(image_ids, class_name, device)
                for class_name in class_names
            }

            # Compute loss
            loss = criteria(logits_per_image_dict, true_labels_dict)
            total_loss += loss.item() * images.size(0)

            # Predictions and accuracy
            total_samples += images.size(0)
            for class_name in class_names:
                preds = logits_per_image_dict[class_name].argmax(dim=1)
                total_correct[class_name] += (preds == true_labels_dict[class_name]).sum().item()

    avg_loss = total_loss / total_samples
    accuracy = {class_name: total_correct[class_name] / total_samples for class_name in class_names}
    return avg_loss, accuracy


for class_name in selected_class_names:
    print(f"Starting experiments for class: {class_name}")
    # Prepare label names and indices for this class
    labels = label_names[class_name]
    label_names_class = {class_name: labels}
    label_names_to_idx_class = {class_name: {label: idx for idx, label in enumerate(labels)}}

    # Prepare validation and test datasets for this class
    val_df_class = get_class_df(val_df, class_name)
    test_df_class = get_class_df(test_df, class_name)
    val_paths, val_labels = get_image_paths_and_labels_from_df(val_df_class, data_dir)
    test_paths, test_labels = get_image_paths_and_labels_from_df(test_df_class, data_dir)
    val_dataset = ImageDataset(val_paths, processor)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=256, shuffle=False)
    test_dataset = ImageDataset(test_paths, processor)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False)

    # Text inputs for the model
    text_input_dict = {
        class_name: processor(text=[f"{label}" for label in labels],
                              return_tensors="pt", padding=True).to(device)
    }

    for n_few_shot in n_few_shot_list:
        print(f"\nTraining with n_few_shot = {n_few_shot}")
        # Prepare training dataset for this class and n_few_shot
        train_df_class = get_few_shot_train_df(train_df, class_name, n_few_shot)
        train_paths, train_labels = get_image_paths_and_labels_from_df(train_df_class, data_dir)
        train_dataset = ImageDataset(train_paths, processor)
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=n_few_shot, shuffle=True)

        # Release the previous run's model and optimizer before copying so
        # only one adapted backbone is alive at a time.
        mo_model = None
        optimizer = None

        # Every run trains its own copy of the LoRA-adapted backbone. The
        # optimizer below updates the backbone's trainable weights too, so
        # reusing `model` directly would carry weights learned in earlier runs
        # (smaller n_few_shot, other classes) into this one.
        run_backbone = copy.deepcopy(model)
        mo_model = MultiOutputClipModel(run_backbone, [class_name], 768, 128, 512, num_virtual_tokens).to(device)
        mo_model.train()

        criteria = MultiOutputClipCriterion(class_names=[class_name])
        optimizer = torch.optim.AdamW(mo_model.parameters(), lr=1e-4)
        step = 0

        # Initialize wandb with n_few_shot and class_name in config
        wandb.init(project=f"Few_shot_experiment_{class_name}",
                   name=f"n_few_shot_{n_few_shot}",
                   config={'n_few_shot': n_few_shot, 'class_name': class_name},
                   reinit=True)

        try:
            for epoch in range(num_epochs):
                mo_model.train()
                total_loss = 0.0
                total_correct = 0
                total_samples = 0

                for images, image_ids in tqdm(train_dataloader):
                    images = images.to(device)
                    logits_per_image_dict = mo_model(pixel_values=images, text_input_dict=text_input_dict)

                    # Get true labels from image_ids
                    true_labels_dict = {
                        class_name: _true_label_indices(image_ids, class_name, device)
                    }

                    # Compute loss
                    loss = criteria(logits_per_image_dict, true_labels_dict)
                    total_loss += loss.item() * images.size(0)

                    # Predictions and accuracy
                    total_samples += images.size(0)
                    preds = logits_per_image_dict[class_name].argmax(dim=1)
                    correct = (preds == true_labels_dict[class_name]).sum().item()
                    total_correct += correct

                    # Backward pass
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    # log the loss and accuracy to wandb
                    wandb.log({"train_loss": loss.item(), "train_accuracy": correct / images.size(0)},
                              step=step)
                    step += 1

                avg_loss = total_loss / total_samples
                accuracy = total_correct / total_samples
                print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

            # Validate once, after the last training epoch of this run
            val_loss, val_accuracy_dict = validate(mo_model, val_dataloader, criteria, device, text_input_dict, [class_name])
            val_accuracy = val_accuracy_dict[class_name]
            print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

            # Log to wandb
            log_dict = {
                "val_loss": val_loss,
                "val_accuracy": val_accuracy,
                "n_few_shot": n_few_shot
            }
            wandb.log(log_dict, step=step)
        finally:
            # Close the wandb run even when training or validation fails, so
            # the next run does not start inside a half-finished one.
            wandb.finish()

        print(f"Finished training for n_few_shot = {n_few_shot} for class: {class_name}")

    print(f"Completed all experiments for class: {class_name}\n")


In [None]:
# os.environ["TOKENIZERS_PARALLELISM"] = "false"

# n_few_shot = 4  # Adjust the number of few-shot examples per label

# def get_few_shot_train_df(articles, class_name, n_few_shot):
#     labels = articles[class_name].unique()
#     few_shot_samples = []
#     for label in labels:
#         label_articles = articles[articles[class_name] == label]
#         samples = label_articles.sample(n=min(n_few_shot, len(label_articles)), random_state=42)
#         few_shot_samples.append(samples)
#     train_df_class = pd.concat(few_shot_samples)
#     return train_df_class

# def get_class_df(df, class_name):
#     return df[df[class_name].notna()]

# num_epochs = 1  # Adjust as needed
# num_virtual_tokens = 16  # You can adjust this number

# for class_name in selected_class_names:
#     print(f"Training for class: {class_name}")
#     # Prepare label names and indices for this class
#     labels = label_names[class_name]
#     label_names_class = {class_name: labels}
#     label_names_to_idx_class = {class_name: {label: idx for idx, label in enumerate(labels)}}

#     # Prepare training, validation, and test datasets for this class
#     train_df_class = get_few_shot_train_df(train_df, class_name, n_few_shot)
#     val_df_class = get_class_df(val_df, class_name)
#     test_df_class = get_class_df(test_df, class_name)

#     train_paths, train_labels = get_image_paths_and_labels_from_df(train_df_class, data_dir)
#     val_paths, val_labels = get_image_paths_and_labels_from_df(val_df_class, data_dir)
#     test_paths, test_labels = get_image_paths_and_labels_from_df(test_df_class, data_dir)

#     train_dataset = ImageDataset(train_paths, processor)
#     train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True) # Don't forget to change the batch size to 1

#     val_dataset = ImageDataset(val_paths, processor)
#     val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=256, shuffle=False)

#     test_dataset = ImageDataset(test_paths, processor)
#     test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False)

#     mo_model = MultiOutputClipModel(model, [class_name], 768, 128, 512, num_virtual_tokens).to(device)
#     mo_model.train()

#     text_input_dict = {
#         class_name: processor(text=[f"{label}" for label in labels], 
#                               return_tensors="pt", padding=True).to(device)
#     }

#     criteria = MultiOutputClipCriterion(class_names=[class_name])
#     optimizer = torch.optim.AdamW(mo_model.parameters(), lr=1e-4)
#     step = 0
#     wandb.init(project=f"Few_shot_experiment_{class_name}")

#     def validate(model, dataloader, criteria, device, text_inputs, class_names):
#         model.eval()
#         total_loss = 0.0
#         total_correct = {class_name: 0 for class_name in class_names}
#         total_samples = 0

#         with torch.no_grad():
#             for images, image_ids in tqdm(dataloader):
#                 images = images.to(device)
#                 logits_per_image_dict = model(pixel_values=images, text_input_dict=text_inputs)

#                 # Get true labels from image_ids
#                 true_labels_dict = {
#                     class_name: [label_names_to_idx_class[class_name][articles.loc[article_id_to_idx[image_id.item()], class_name]] 
#                                for image_id in image_ids]
#                     for class_name in class_names
#                 }
#                 true_labels_dict = {class_name: torch.tensor(true_labels).to(device)
#                                     for class_name, true_labels in true_labels_dict.items()}
                
#                 # Compute loss
#                 loss = criteria(logits_per_image_dict, true_labels_dict)
#                 total_loss += loss.item() * images.size(0)

#                 # Predictions and accuracy
#                 total_samples += images.size(0)
#                 for class_name in class_names:
#                     _, preds = torch.max(logits_per_image_dict[class_name], dim=1)
#                     total_correct[class_name] += (preds == true_labels_dict[class_name]).sum().item()

#         avg_loss = total_loss / total_samples
#         accuracy = {class_name: total_correct[class_name] / total_samples for class_name in class_names}
#         return avg_loss, accuracy

#     for epoch in range(num_epochs):
#         mo_model.train()
#         total_loss = 0.0
#         total_correct = 0
#         total_samples = 0

#         for images, image_ids in tqdm(train_dataloader):
#             images = images.to(device)
#             logits_per_image_dict = mo_model(pixel_values=images, text_input_dict=text_input_dict)

#             # Get true labels from image_ids
#             true_labels_dict = {
#                 class_name: [label_names_to_idx_class[class_name][articles.loc[article_id_to_idx[image_id.item()], class_name]] 
#                            for image_id in image_ids]
#             }
#             true_labels_dict = {class_name: torch.tensor(true_labels).to(device) 
#                                 for class_name, true_labels in true_labels_dict.items()}
            
#             # Compute loss
#             loss = criteria(logits_per_image_dict, true_labels_dict)
#             total_loss += loss.item() * images.size(0)

#             # Predictions and accuracy
#             correct = 0
#             total_samples += images.size(0)
#             for class_name in [class_name]:
#                 _, preds = torch.max(logits_per_image_dict[class_name], dim=1)
#                 correct += (preds == true_labels_dict[class_name]).sum().item()
#             total_correct += correct

#             # Backward pass
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()

#             # log the loss and accuracy to wandb
#             wandb.log({"train_loss": loss.item(), "train_accuracy": correct / images.size(0)},
#                       step=step)
#             step += 1

#         avg_loss = total_loss / total_samples
#         accuracy = total_correct / total_samples
#         print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

#         # Validate after each epoch
#         val_loss, val_accuracy_dict = validate(mo_model, val_dataloader, criteria, device, text_input_dict, [class_name])
#         val_accuracy = val_accuracy_dict[class_name]
#         print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

#         # Log to wandb
#         log_dict = {
#             "val_loss": val_loss,
#             "val_accuracy": val_accuracy
#         }
#         wandb.log(log_dict, step=step)

#     wandb.finish()

#     test_loss, test_accuracy_dict = validate(
#         mo_model, test_dataloader, criteria, device, text_input_dict, [class_name]
#     )

#     print(f"Test Loss: {test_loss:.4f}")
#     print(f"Test Accuracy: {test_accuracy_dict[class_name]:.4f}")

#     print(f"Finished training for class: {class_name}\n")


In [None]:
# mo_model = MultiOutputClipModel(model, selected_class_names, 768, 128, 512).to(device)
# mo_model.train()

# num_virtual_tokens = 16  # You can adjust this number
# mo_model = MultiOutputClipModel(model, selected_class_names, 768, 128, 512, num_virtual_tokens).to(device)
# mo_model.train()

In [None]:
# show all trainable parameters in mo_model
# for name, param in mo_model.named_parameters():
#     if param.requires_grad:
#         print(name)

In [None]:
# generate text input
# text_input_dict = {
#     class_name: processor(text=[f"{label}" for label in label_names[class_name]], 
#                           return_tensors="pt", padding=True).to(device)
#     for class_name in selected_class_names
# }

In [None]:
# num_epochs = 1  # Adjust as needed
# criteria = MultiOutputClipCriterion(class_names=selected_class_names)
# optimizer = torch.optim.AdamW(mo_model.parameters(), lr=1e-4)
# step = 0
# wandb.init(project="Multi_head_lora_prompt_tune_experiment256samples")

In [None]:
# def validate(model, dataloader, criteria, device, text_inputs, class_names):
#     model.eval()
#     total_loss = 0.0
#     total_correct = {class_name: 0 for class_name in class_names}
#     total_samples = 0

#     with torch.no_grad():
#         for images, image_ids in tqdm(dataloader):
#             images = images.to(device)
#             logits_per_image_dict = model(pixel_values=images, text_input_dict=text_inputs)

#             # Get true labels from image_ids
#             true_labels_dict = {
#                 class_name: [label_names_to_idx[class_name][articles.loc[article_id_to_idx[image_id.item()], class_name]] 
#                            for image_id in image_ids]
#                 for class_name in class_names
#             }
#             true_labels_dict = {class_name: torch.tensor(true_labels).to(device)
#                                 for class_name, true_labels in true_labels_dict.items()}
            
#             # Compute loss
#             loss = criteria(logits_per_image_dict, true_labels_dict)
#             total_loss += loss.item() * images.size(0)

#             # Predictions and accuracy
#             total_samples += images.size(0)
#             for class_name in class_names:
#                 _, preds = torch.max(logits_per_image_dict[class_name], dim=1)
#                 total_correct[class_name] += (preds == true_labels_dict[class_name]).sum().item()

#     avg_loss = total_loss / total_samples / len(class_names)
#     accuracy = {class_name: total_correct[class_name] / total_samples for class_name in class_names}
#     return avg_loss, accuracy

In [None]:
# for epoch in range(num_epochs):
#     mo_model.train()
#     total_loss = 0.0
#     total_correct = 0
#     total_samples = 0

#     for images, image_ids in tqdm(train_dataloader):
#         images = images.to(device)
#         logits_per_image_dict = mo_model(pixel_values=images, text_input_dict=text_input_dict)

#         # Get true labels from image_ids
#         true_labels_dict = {
#             class_name: [label_names_to_idx[class_name][articles.loc[article_id_to_idx[image_id.item()], class_name]] 
#                        for image_id in image_ids]
#             for class_name in selected_class_names
#         }
#         true_labels_dict = {class_name: torch.tensor(true_labels).to(device) 
#                             for class_name, true_labels in true_labels_dict.items()}
        
#         # Compute loss
#         loss = criteria(logits_per_image_dict, true_labels_dict)
#         total_loss += loss.item() * images.size(0)

#         # Predictions and accuracy
#         correct = 0
#         total_samples += images.size(0)
#         for class_name in selected_class_names:
#             _, preds = torch.max(logits_per_image_dict[class_name], dim=1)
#             correct += (preds == true_labels_dict[class_name]).sum().item()
#         total_correct += correct

#         # Backward pass
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         # log the loss and accuracy to wandb
#         wandb.log({"train_loss": loss.item(), "train_accuracy": correct / images.size(0) / len(selected_class_names)},
#                   step=step)
#         step += 1

#     avg_loss = total_loss / total_samples / len(selected_class_names)
#     accuracy = total_correct / total_samples / len(selected_class_names)
#     print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

#     # Validate after each epoch
#     val_loss, val_accuracy_dict = validate(mo_model, val_dataloader, criteria, device, text_input_dict, selected_class_names)
#     val_accuracy = sum(val_accuracy_dict.values()) / len(val_accuracy_dict)
#     print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

#     # Log to wandb
#     log_dict = {
#         "val_loss": val_loss,
#         "val_accuracy": val_accuracy
#     }
#     for class_name, accuracy in val_accuracy_dict.items():
#         log_dict[f"val_accuracy_{class_name}"] = accuracy

#     wandb.log(log_dict, step=step)

# wandb.finish()

In [None]:
# Save the model
# torch.save(mo_model.state_dict(), "model/multihead_lora_prompt_tune.pth")

In [None]:
# test_loss, test_accuracy_dict = validate(
#     mo_model, test_dataloader, criteria, device, text_input_dict, selected_class_names
# )

# print(f"Test Loss: {test_loss:.4f}")

# print("Test Accuracy per Class:")
# for class_name, accuracy in test_accuracy_dict.items():
#     print(f"{class_name}: {accuracy:.4f}")

# test_accuracy = sum(test_accuracy_dict.values()) / len(test_accuracy_dict)
# print(f"Average Test Accuracy: {test_accuracy:.4f}")