# DIXIT

The following notebook contains the experiments performed to obtain a bot that can play a full game of Dixit.

The code is organized as follows:

*   MOUNT DRIVE - IMPORT LIBRARIES - UNZIP DATA  
    This section contains all the setup needed to run the experiments: the drive is mounted, the COCO dataset is unzipped and the libraries are imported. The random seeds are also set to ensure reproducibility.

*   BLIP CAPTIONS REPHRASING USING ZEPHYR  
    In this section the simple descriptive captions extracted with a BLIP model are rephrased using an open-source LLM. This model was chosen after some experiments showed that it performs this task well enough. Also, given the number of captions to be rephrased, an open-source model without usage limits was the best way to avoid paying for expensive LLM APIs or waiting days with the free ones.

    NOTE 1: A HUGGING FACE KEY IS REQUIRED TO PERFORM THE TASK!  

    NOTE 2: THE REPHRASING CAN BE DONE USING THE FREE GPU AVAILABLE ON COLAB (T4), BUT IT TAKES 12 HOURS!

*   BLIP FINE TUNING (ONLINE - REPHRASED)  
    This block contains the code used to fine tune the BLIP model on both the data found online and the rephrased captions obtained with the Zephyr LLM. Given that there were very few Dixit images, the visual part of the model was frozen during fine tuning. Also, the dataset found online caused a lot of problems, probably because the data were too noisy or nothing meaningful could be learned from them.

    NOTE: For this task an L4 GPU was used

*   CLIP FINE TUNING (ONLINE - REPHRASED)  
    This section contains the code used to fine tune CLIP on both the data found online and the data produced with the rephrasing. The performance was not great: even when fine tuning just the projection layers (the ones that project the text and image features into a common embedding space), the model kept overfitting or not learning anything meaningful. Probably the Dixit images were still too few, even with only the projection layers being trained. The problems with the online data observed during the BLIP fine tuning also showed up here.

*   FINE TUNING CLIP WITH FINE TUNED BLIP ON COCO  
    This block contains the code used to fine tune CLIP on the COCO dataset. For each image in the dataset, the fine tuned BLIP model generates a creative caption, and the resulting image-caption pair is used to fine tune CLIP. The idea was that the model could learn from more general images without overfitting and then be used on Dixit. This fine tuning improved the performance of the model by 4-5%.


*Additional info:*  
Throughout the notebook, every experiment that requires a number of epochs uses a standard value (100), large enough to allow a long training if needed. In practice, every experiment was stopped early by hand.

Also, in order to use the images inside a custom Dataset, a standard image size (224x224) was used for all the models that take images as input (BLIP and CLIP). This size was chosen only for consistency across the experiments. Note, however, that the CLIP processor by default resizes the images to this same size before feeding them to the model.




# MOUNT DRIVE - IMPORT LIBRARIES - UNZIP DATA

### Mount drive for data

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

### Unzip COCO dataset

In [None]:
!unzip "/content/drive/MyDrive/Dixit/COCO_Dataset/train2014.zip"

In [None]:
!unzip "/content/drive/MyDrive/Dixit/COCO_Dataset/val2014.zip"

In [None]:
!unzip "/content/drive/MyDrive/Dixit/COCO_Dataset/annotations_trainval2014.zip"

### Import Libraries

In [None]:
from transformers import CLIPProcessor, CLIPModel
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import AutoTokenizer, AutoModelForCausalLM

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch import optim


import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision

import os
import random
import math
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import trange, tqdm
from PIL import Image
import pandas as pd
from collections import defaultdict
import json

# Seed every random number generator used in the notebook
random.seed(23)
np.random.seed(23)
torch.manual_seed(23)

# BLIP CAPTIONS REPHRASING USING ZEPHYR

The following block contains the code used to perform the rephrasing. The process was run with two different prompts, and the code is left as it was for the second one. The first prompt remains as a comment inside the main rephrasing cell. The paths written here are the ones used to save the second-prompt rephrasing.


The following package is needed to run the LLM

In [None]:
!pip install -U bitsandbytes

In [None]:
token = "" # <-- Insert your hugging face key

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-alpha", token = token)
model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/zephyr-7b-alpha", device_map="auto", load_in_4bit=True, token = token)

Given that the rephrasing takes almost 12 hours, the process can be started, and later resumed from the last saved image, using the following function

In [None]:
def start_rephrasing(path_to_captions, path_to_rephrased_captions):
    with open(path_to_captions, 'r') as f:
        captions_dict = json.load(f)

    if not os.path.exists(path_to_rephrased_captions):
      with open(path_to_rephrased_captions, 'w') as f:
        json.dump({}, f)
      return captions_dict

    with open(path_to_rephrased_captions, 'r') as f:
        rephrased_captions_dict = json.load(f)

    return {k:v for k,v in captions_dict.items() if k not in rephrased_captions_dict}
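
The resume logic above boils down to a dictionary difference between the input file and the output file. The following self-contained sketch reproduces it on temporary files. The helper name `pending_captions`, the file names and the sample captions are hypothetical and only serve the illustration.

```python
import json
import os
import tempfile


def pending_captions(path_to_captions, path_to_rephrased_captions):
    """Return the entries of the input file that are not yet in the output file."""
    with open(path_to_captions, "r") as f:
        captions = json.load(f)

    # First run: create an empty output file and process everything.
    if not os.path.exists(path_to_rephrased_captions):
        with open(path_to_rephrased_captions, "w") as f:
            json.dump({}, f)
        return captions

    with open(path_to_rephrased_captions, "r") as f:
        done = json.load(f)

    return {k: v for k, v in captions.items() if k not in done}


with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "captions.json")
    dst = os.path.join(tmp, "rephrased.json")

    # Made-up input: three images with one descriptive caption each.
    with open(src, "w") as f:
        json.dump({"1.jpg": ["a cat"], "2.jpg": ["a dog"], "3.jpg": ["a boat"]}, f)

    first_run = pending_captions(src, dst)   # nothing has been rephrased yet

    # Simulate a session that was interrupted after the first image.
    with open(dst, "w") as f:
        json.dump({"1.jpg": ["quiet curiosity"]}, f)

    second_run = pending_captions(src, dst)  # only the missing images remain

print(sorted(first_run))
print(sorted(second_run))
```

Because the output file is rewritten after every image, an interrupted session loses at most the image that was being processed.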

In [None]:
captions_dict = start_rephrasing("/content/drive/MyDrive/Dixit/captions/captions(original_images).json", "/content/drive/MyDrive/Dixit/captions/rephrased_captions(2nd_prompt).json")

BE AWARE: It takes almost 12 hours on a T4 GPU!

In [None]:
model.eval()
for image, captions in tqdm(captions_dict.items(), desc="REPHRASING..."):

    rephrased_captions = []

    for caption in captions:
        messages = [
            {
                "role": "system",
                #"content": "You are a chatbot whose purpose is to rephrase captions in the shortest, most creative, mysterious and vague way.", (1st prompt)
                "content": "You are a chatbot whose purpose is to use three words representing emotions and abstract reasoning to summarize a caption.", #(2nd prompt)
            },
            {"role": "user", "content": caption},
        ]

        model_inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to("cuda")

        input_length = model_inputs.shape[1]

        with torch.no_grad():
          generated_ids = model.generate(
              model_inputs,
              max_new_tokens=30,
              num_return_sequences=100,
              temperature=0.7,
              top_p=0.9,
              repetition_penalty=1.2,
              no_repeat_ngram_size=3,
              eos_token_id=tokenizer.eos_token_id,
              do_sample=True
          )

        rephrased = tokenizer.batch_decode(generated_ids[:, input_length:], skip_special_tokens=True)

        rephrased_captions.extend(rephrased)

    with open("/content/drive/MyDrive/Dixit/captions/rephrased_captions(2nd_prompt).json", 'r') as f:
        rephrased_captions_dict = json.load(f)

    rephrased_captions_dict[image] = rephrased_captions

    with open("/content/drive/MyDrive/Dixit/captions/rephrased_captions(2nd_prompt).json", 'w') as f:
      json.dump(rephrased_captions_dict, f)

# BLIP FINE TUNING (ONLINE - REPHRASED)

The following cells contain the code to perform fine tuning on the BLIP model for both the dataset found online and the rephrased captions. Again, given that there are two different datasets of rephrased captions obtained with different prompts, the paths involving the rephrased dataset refer to the fine tuning done with the second-prompt dataset.

### IMAGES LOADING

In [None]:
def open_csv(path):
    df = pd.read_csv (path)
    return df

In [None]:
image_path = "/content/drive/MyDrive/Dixit/dixit_cards"

images = dict()
for image_name in os.listdir(image_path):
  image = Image.open(os.path.join(image_path, image_name)).convert('RGB')
  images[int(image_name.split(".")[0])] = image

### DATASET CREATION

You can choose which dataset (and so which kind of data) to use for the fine tuning

===================================================================================================================

To fine tune BLIP using the Dixit data found online, run the following cells:

In [None]:
class DixitDataset(Dataset):
  def __init__(self, images, annotations, processor):
    self.images = images
    self.annotations = annotations
    self.processor = processor

  def __len__(self):
    return len(self.annotations)

  def __getitem__(self, index):

    card, caption = self.annotations[index]
    image = self.images[card]

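    # Use the processor stored on the dataset instead of relying on a global variable
    inputs = self.processor(images=image, text=caption, return_tensors="pt", padding="max_length", truncation=True)
    # Drop the batch dimension added by return_tensors="pt"; the DataLoader adds it back
    for k in inputs.keys():
      inputs[k] = inputs[k].squeeze(0)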

    return inputs

Given that the model was learning to output blank captions when trained on the untouched data found online, an idea was to train only on longer captions to see how it would behave. After several attempts with different minimum lengths (the last one being 4 words, used in the following cell), it was clear that the data were too noisy and incomplete.

In [None]:
annotation_path = "/content/drive/MyDrive/Dixit/dixit.csv"
df = open_csv(annotation_path)
list_of_dicts = df.to_dict(orient='records')

captions_list = []
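for d in list_of_dicts:
  # Keep only descriptions with at least 4 words
  if len(d["DESCRIPTION"].split()) > 3:
    # The narrator is taken to be the first player whose card is flagged as target
    narrator = -1
    for i in range(1, d["PLAYERS"] + 1):
      if d[f"C{i}_TARGET"]:
        narrator = i
        break

    # Skip rows without a flagged card instead of failing on a missing column
    if narrator == -1:
      continue

    card_number = int(d[f"C{narrator}_CARD"])
    captions_list.append((card_number, d["DESCRIPTION"]))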
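
The filtering above can be tried without the CSV file. The sketch below applies the same two rules (minimum caption length, first card flagged as target) to made-up rows. The column names are the ones used in the cell above; the rows themselves, the helper `narrator_card` and the reading of `C{i}_TARGET` as a truthy flag on the narrator's card are assumptions.

```python
def narrator_card(row):
    """Return the card number of the first flagged card, or None if none is flagged."""
    for i in range(1, int(row["PLAYERS"]) + 1):
        if row.get(f"C{i}_TARGET"):
            return int(row[f"C{i}_CARD"])
    return None


# Made-up rows that mimic the column layout of dixit.csv.
rows = [
    {"DESCRIPTION": "a long road back home", "PLAYERS": 3,
     "C1_TARGET": 0, "C1_CARD": 12,
     "C2_TARGET": 1, "C2_CARD": 40,
     "C3_TARGET": 0, "C3_CARD": 7},
    {"DESCRIPTION": "hope", "PLAYERS": 2,
     "C1_TARGET": 1, "C1_CARD": 5,
     "C2_TARGET": 0, "C2_CARD": 33},
    {"DESCRIPTION": "nobody was flagged in this round", "PLAYERS": 2,
     "C1_TARGET": 0, "C1_CARD": 3,
     "C2_TARGET": 0, "C2_CARD": 9},
]

min_words = 4
pairs = []
for row in rows:
    if len(row["DESCRIPTION"].split()) < min_words:
        continue  # caption too short to be kept
    card = narrator_card(row)
    if card is None:
        continue  # no flagged card: the row cannot be linked to an image
    pairs.append((card, row["DESCRIPTION"]))

print(pairs)
```

Only the first row survives: the second caption is too short and the third row has no flagged card.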

===================================================================================================================

To fine tune BLIP on rephrased captions, run the following cells:

In [None]:
class DixitDataset(Dataset):
  def __init__(self, images, annotations, processor):
    self.images = images
    self.annotations = annotations
    self.processor = processor

  def __len__(self):
    return len(self.annotations)

  def __getitem__(self, index):

    card, captions = self.annotations[index]
    image = self.images[card]

    caption = random.choice(captions)

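    # Use the processor stored on the dataset instead of relying on a global variable
    inputs = self.processor(images=image, text=caption, return_tensors="pt", padding="max_length", truncation=True)
    # Drop the batch dimension added by return_tensors="pt"; the DataLoader adds it back
    for k in inputs.keys():
      inputs[k] = inputs[k].squeeze(0)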

    return inputs

In [None]:
path = "/content/drive/MyDrive/Dixit/captions/rephrased_captions(2nd_prompt).json"

with open(path, "r") as f:
  data = json.load(f)

captions_list = []
for image, captions in data.items():
  captions_list.append((int(image.split(".")[0]), captions))

### HYPERPARAMETERS AND CODE TO DO TRAINING

In [None]:
split = 0.15

random.shuffle(captions_list)
train_captions = captions_list[int(len(captions_list)*split):]
val_captions = captions_list[:int(len(captions_list)*split)]
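
With the rephrased data each entry of `captions_list` is one card, so the split above already separates cards. With the online data each entry is a (card, caption) pair, so the same card can end up in both sets. The following is a minimal sketch of a split that keeps the cards disjoint, assuming that is the desired behaviour. `split_by_card` is a hypothetical helper and the pairs are made up.

```python
import random
from collections import defaultdict


def split_by_card(pairs, val_fraction, seed=23):
    """Split (card, caption) pairs so that no card appears in both sets."""
    by_card = defaultdict(list)
    for card, caption in pairs:
        by_card[card].append(caption)

    # Shuffle the cards, not the pairs, with a local generator for reproducibility.
    cards = sorted(by_card)
    random.Random(seed).shuffle(cards)

    n_val = int(len(cards) * val_fraction)
    held_out = set(cards[:n_val])

    train = [(c, cap) for c, cap in pairs if c not in held_out]
    val = [(c, cap) for c, cap in pairs if c in held_out]
    return train, val


# Made-up data: 20 cards with 3 captions each.
pairs = [(card, f"caption {j} for card {card}") for card in range(20) for j in range(3)]

train_pairs, val_pairs = split_by_card(pairs, val_fraction=0.15)
train_cards = {c for c, _ in train_pairs}
val_cards = {c for c, _ in val_pairs}

print(len(train_pairs), len(val_pairs), train_cards.isdisjoint(val_cards))
```

With a split of this kind the validation loss measures how well the model handles cards it has never seen, which is closer to the overfitting problem described in this notebook.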

In [None]:
learning_rate = 5e-5

image_size = 224

nepochs = 100

batch_size = 16

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

In [None]:
train_dataset = DixitDataset(images, train_captions, processor)
val_dataset = DixitDataset(images, val_captions, processor)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=8)

Given the very small number of images (84) compared to the thousands of captions from both the online and rephrased data, the BLIP model was quickly overfitting the provided cards. To avoid this, the vision model was frozen, so that only the text part was fine tuned.

In [None]:
for param in caption_model.vision_model.parameters():
    param.requires_grad = False

In [None]:
optimizer = optim.AdamW(caption_model.parameters(), lr=learning_rate)

In [None]:
def train(epochs, caption_model, train_dataloader, val_dataloader, optimizer, path, start_epoch = 0):
  # Epochs are numbered from start_epoch+1 up to epochs (inclusive)
  for epoch in trange(start_epoch+1, epochs+1, leave=False, desc="Epoch"):
      loss_epoch = 0
      caption_model.train()
      for inputs in tqdm(train_dataloader, desc="Training", leave=False):

          inputs = inputs.to(device)

          outputs = caption_model(pixel_values = inputs["pixel_values"], input_ids = inputs["input_ids"], attention_mask = inputs["attention_mask"], labels = inputs["input_ids"])
          loss = outputs.loss

          optimizer.zero_grad()
          loss.backward()
          optimizer.step()

          loss_epoch += loss.item()

      vloss_epoch = 0
      caption_model.eval()
      with torch.no_grad():
        for inputs in tqdm(val_dataloader, desc="Validation", leave=False):
            inputs = inputs.to(device)

            outputs = caption_model(pixel_values = inputs["pixel_values"], input_ids = inputs["input_ids"], attention_mask = inputs["attention_mask"], labels = inputs["input_ids"])
            vloss = outputs.loss
            vloss_epoch += vloss.item()

      metrics =  f"EPOCH {epoch}/{epochs}. Training Loss: {loss_epoch/len(train_dataloader)} Validation Loss: {vloss_epoch/len(val_dataloader)}"
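      # Make sure the checkpoint folders exist before saving
      os.makedirs(os.path.join(path, "weights"), exist_ok=True)
      os.makedirs(os.path.join(path, "optimizer"), exist_ok=True)
      with open(os.path.join(path, "metrics.txt"), "a") as f:
          f.write(metrics + "\n")
      # Save model and optimizer state at the end of every epoch
      torch.save(caption_model.state_dict(), os.path.join(path, "weights", f"epoch{epoch}.pt"))
      torch.save(optimizer.state_dict(), os.path.join(path, "optimizer", f"epoch{epoch}.pt"))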
      print(metrics)
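
Since the training is stopped by hand, the checkpoint to keep has to be chosen from `metrics.txt`. The sketch below parses lines in the format written by `train` and returns the epoch with the lowest validation loss. The helper `best_epoch` and the sample log lines are hypothetical; the loss values are made up.

```python
import re

# Matches the line format written by train():
# "EPOCH {epoch}/{epochs}. Training Loss: {x} Validation Loss: {y}"
LINE_RE = re.compile(
    r"EPOCH (\d+)/(\d+)\. Training Loss: (\S+) Validation Loss: (\S+)"
)


def best_epoch(lines):
    """Return (epoch, validation_loss) for the lowest validation loss, or None."""
    best = None
    for line in lines:
        match = LINE_RE.search(line)
        if match is None:
            continue  # ignore lines that are not loss lines
        epoch = int(match.group(1))
        val_loss = float(match.group(4))
        if best is None or val_loss < best[1]:
            best = (epoch, val_loss)
    return best


# Made-up log: the validation loss starts rising after epoch 2.
sample_log = [
    "EPOCH 1/100. Training Loss: 2.91 Validation Loss: 2.80",
    "EPOCH 2/100. Training Loss: 2.40 Validation Loss: 2.55",
    "EPOCH 3/100. Training Loss: 1.95 Validation Loss: 2.61",
    "EPOCH 4/100. Training Loss: 1.52 Validation Loss: 2.78",
]

print(best_epoch(sample_log))
```

On a real run the list would come from reading `metrics.txt` line by line, and the returned epoch identifies the `epoch{N}.pt` files to keep.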

### RESUME TRAINING

If Google Colab disconnects the session during training, run this cell with your own values to resume it: set `start_epoch` to the last completed epoch and pass it to `train` as `start_epoch`.

In [None]:
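start_epoch = -1        # <-- last completed epoch
path_to_weights = ""    # <-- model checkpoint saved in the "weights" folder
path_to_optimizer = ""  # <-- optimizer checkpoint saved in the "optimizer" folder
caption_model.load_state_dict(torch.load(path_to_weights, map_location=device))
optimizer.load_state_dict(torch.load(path_to_optimizer, map_location=device))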

### TRAIN

An L4 GPU was used for this training

In [None]:
path_to_save = ""

In [None]:
train(nepochs, caption_model, train_dataloader, val_dataloader, optimizer, path_to_save)

# CLIP FINE TUNING (ONLINE - REPHRASED)

This block contains the code used to fine tune CLIP on both the dataset found online and the one created with Zephyr. Again, given that there are two different datasets of rephrased captions obtained with different prompts, the paths involving the rephrased dataset refer to the fine tuning done with the second-prompt dataset.

In [None]:
def open_csv(path):
    df = pd.read_csv (path)
    return df

In [None]:
image_path = "/content/drive/MyDrive/Dixit/dixit_cards"

images = dict()
for image_name in os.listdir(image_path):
  image = Image.open(os.path.join(image_path, image_name)).convert('RGB')
  images[int(image_name.split(".")[0])] = image

==============================================================================

Run this to fine tune CLIP on the data found online:

In [None]:
class DixitDataset(Dataset):
  def __init__(self, images, annotations, processor):
    self.images = images
    self.annotations = annotations
    self.processor = processor

  def __len__(self):
    return len(self.annotations)

  def __getitem__(self, index):

    card, caption = self.annotations[index]
    image = self.images[card]

    inputs = self.processor(text=caption, images=image, return_tensors="pt", padding="max_length", truncation=True).to(device)

    return {k:v.squeeze() for k,v in inputs.items()}

In [None]:
annotation_path = "/content/drive/MyDrive/Dixit/dixit.csv"
df = open_csv(annotation_path)
list_of_dicts = df.to_dict(orient='records')

captions_list = []
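for d in list_of_dicts:
  # Keep only descriptions with at least 4 words
  if len(d["DESCRIPTION"].split()) > 3:
    # The narrator is taken to be the first player whose card is flagged as target
    narrator = -1
    for i in range(1, d["PLAYERS"] + 1):
      if d[f"C{i}_TARGET"]:
        narrator = i
        break

    # Skip rows without a flagged card instead of failing on a missing column
    if narrator == -1:
      continue

    card_number = int(d[f"C{narrator}_CARD"])
    captions_list.append((card_number, d["DESCRIPTION"]))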

Run this to fine tune CLIP on the rephrased captions:

In [None]:
class DixitDataset(Dataset):
  def __init__(self, images, annotations, processor):
    self.images = images
    self.annotations = annotations
    self.processor = processor

  def __len__(self):
    return len(self.annotations)

  def __getitem__(self, index):

    card, captions = self.annotations[index]
    image = self.images[card]

    caption = random.choice(captions)

    inputs = self.processor(text=caption, images=image, return_tensors="pt", padding="max_length", truncation=True).to(device)

    return {k:v.squeeze() for k,v in inputs.items()}

In [None]:
path = "/content/drive/MyDrive/Dixit/captions/rephrased_captions(2nd_prompt).json"

with open(path, "r") as f:
  data = json.load(f)

captions_list = []
for image, captions in data.items():
  captions_list.append((int(image.split(".")[0]), captions))

==============================================================================

In [None]:
split = 0.10

random.shuffle(captions_list)
train_captions = captions_list[int(len(captions_list)*split):]
val_captions = captions_list[:int(len(captions_list)*split)]

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = "openai/clip-vit-base-patch16"
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)

optimizer = optim.AdamW(model.parameters(), lr=1e-7)

model.to(device)

In [None]:
batch_size = 4

train_dataset = DixitDataset(images, train_captions, processor)
val_dataset = DixitDataset(images, val_captions, processor)

data_loader_train = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
data_loader_val = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

Given that the CLIP model was consistently overfitting, the idea to mitigate this was to fine tune just the projection layers. These layers project the text and image features into a common embedding space. It is reasonable to assume that a pretrained CLIP model already extracts meaningful features, so the remaining problem is how to represent them in a common space.

In [None]:
for param in model.parameters():
    param.requires_grad = False

for param in model.visual_projection.parameters():
    param.requires_grad = True
for param in model.text_projection.parameters():
    param.requires_grad = True
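
To check that the freezing does what is intended, the trainable parameters can be listed and counted. The sketch below does this on a tiny stand-in module, so it runs without downloading CLIP. `ToyDualEncoder` and its layer sizes are hypothetical; only the attribute names `visual_projection` and `text_projection` mirror the ones used above.

```python
import torch
import torch.nn as nn


class ToyDualEncoder(nn.Module):
    """Tiny stand-in with two encoders and two projection layers."""

    def __init__(self):
        super().__init__()
        self.vision_model = nn.Linear(8, 8)
        self.text_model = nn.Linear(8, 8)
        self.visual_projection = nn.Linear(8, 4, bias=False)
        self.text_projection = nn.Linear(8, 4, bias=False)


toy = ToyDualEncoder()

# Freeze everything, then re-enable only the two projection layers.
for param in toy.parameters():
    param.requires_grad = False
for layer in (toy.visual_projection, toy.text_projection):
    for param in layer.parameters():
        param.requires_grad = True

trainable = [name for name, p in toy.named_parameters() if p.requires_grad]
n_trainable = sum(p.numel() for p in toy.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in toy.parameters())

# Passing only the trainable tensors to the optimizer makes the intent explicit.
toy_optimizer = torch.optim.AdamW(
    [p for p in toy.parameters() if p.requires_grad], lr=1e-7
)

print(trainable)
print(f"{n_trainable} trainable parameters out of {n_total}")
```

The same three lines that build `trainable`, `n_trainable` and `n_total` can be run on the real `model` to confirm that only the projection weights are updated.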

In [None]:
def train(epochs, model, processor, data_loader_train, data_loader_val, optimizer, path, start_epoch = 0):
  # Epochs are numbered from start_epoch+1 up to epochs (inclusive)
  for epoch in trange(start_epoch+1, epochs+1, leave=False, desc="Epoch"):
      loss_epoch = 0
      model.train()
      for inputs in tqdm(data_loader_train, desc="Training", leave=False):

          loss = model(**inputs, return_loss = True).loss

          optimizer.zero_grad()
          loss.backward()
          optimizer.step()

          loss_epoch += loss.item()

      vloss_epoch = 0
      model.eval()
      with torch.no_grad():
        for inputs in tqdm(data_loader_val, desc="Validation", leave=False):

            vloss = model(**inputs, return_loss = True).loss

            vloss_epoch += vloss.item()

      metrics =  f"EPOCH {epoch}/{epochs}. Training Loss: {loss_epoch/len(data_loader_train)} Validation Loss: {vloss_epoch/len(data_loader_val)}"
      with open(os.path.join(path, "metrics.txt"), "a") as f:
          f.write(metrics + "\n")
      torch.save(model.state_dict(), os.path.join(path, "weights", f"epoch{epoch}.pt"))
      torch.save(optimizer.state_dict(), os.path.join(path, "optimizer", f"epoch{epoch}.pt"))
      print(metrics)
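
The loss returned with `return_loss = True` is a symmetric contrastive loss: within a batch, the i-th caption should match the i-th image and vice versa. The following sketch reimplements that idea in plain Python on a small similarity matrix to show how the value behaves. The function names and the example matrices are hypothetical, and the code is an illustration, not the library implementation.

```python
import math


def cross_entropy_diagonal(logits):
    """Mean cross-entropy where the correct class of row i is column i."""
    total = 0.0
    for i, row in enumerate(logits):
        peak = max(row)  # subtract the maximum for numerical stability
        log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in row))
        total += log_sum_exp - row[i]
    return total / len(logits)


def symmetric_contrastive_loss(logits_per_text):
    """Average of the text-to-image and image-to-text cross-entropies."""
    logits_per_image = [list(col) for col in zip(*logits_per_text)]
    text_loss = cross_entropy_diagonal(logits_per_text)
    image_loss = cross_entropy_diagonal(logits_per_image)
    return (text_loss + image_loss) / 2


# Each caption is clearly most similar to its own image.
matched = [
    [10.0, 0.0, 0.0],
    [0.0, 10.0, 0.0],
    [0.0, 0.0, 10.0],
]
# Every caption is equally similar to every image.
uninformative = [[1.0, 1.0, 1.0] for _ in range(3)]

loss_matched = symmetric_contrastive_loss(matched)
loss_uninformative = symmetric_contrastive_loss(uninformative)

print(loss_matched, loss_uninformative, math.log(3))
```

A model that cannot tell the pairs apart scores the logarithm of the batch size, so with a batch size of 4 a loss that stays close to log(4) means nothing is being learned.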

In [None]:
path_to_save = ""
train(30, model, processor, data_loader_train, data_loader_val, optimizer, path_to_save, start_epoch = 0)

# FINE TUNING CLIP WITH FINE TUNED BLIP ON COCO

This block contains the CLIP fine tuning on the COCO dataset, with captions generated by the fine tuned BLIP model

In [None]:
train_set ='/content/train2014'
val_set = '/content/val2014'

train_images = [os.path.join(train_set, x) for x in os.listdir(train_set) if x.endswith('.jpg')]

val_images = [os.path.join(val_set, x) for x in os.listdir(val_set) if x.endswith('.jpg')]

In [None]:
# Train with only 10% of original images, or else an epoch can take some hours
percentage = 0.1

train_images = train_images[:int(len(train_images)*percentage)]
val_images = val_images[:int(len(val_images)*percentage)]

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

blip_model.load_state_dict(torch.load("/content/drive/MyDrive/Dixit/training_results/BLIP-rephrased(2nd_prompt)/weights/epoch50.pt"))

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

optimizer = torch.optim.AdamW(clip_model.parameters(), lr=5e-5)

As explained in the "CLIP FINE TUNING (ONLINE - REPHRASED)" section, only the projection layers are fine tuned to avoid overfitting

In [None]:
for param in clip_model.parameters():
    param.requires_grad = False

for param in clip_model.visual_projection.parameters():
    param.requires_grad = True
for param in clip_model.text_projection.parameters():
    param.requires_grad = True

## Train CLIP on COCO creative captions

In [None]:
class COCORephrasedDataset(Dataset):
  def __init__(self, path_images, transform = None):
    self.path_images = path_images
    self.transform = transform

  def __len__(self):
    return len(self.path_images)

  def __getitem__(self, index):
    path_image = self.path_images[index]

    image = Image.open(path_image).convert("RGB")
    if self.transform:
      image = self.transform(image)
    return image

Given that the BLIP processor already rescales the pixel values, the usual transforms.ToTensor() is a problem, because it performs the same rescaling. If the rescaling is done twice, the model predicts completely wrong captions. To avoid this, a custom transform is used: it does the same conversion as ToTensor() but without any rescaling.

In [None]:
batch_size = 64

image_size = 224

# Same layout conversion as ToTensor() (HWC -> CHW, float32), but pixel values stay in [0, 255]
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.Lambda(lambda pic: torch.tensor(np.array(pic), dtype=torch.float32).permute(2, 0, 1)),
])

train_dataset = COCORephrasedDataset(train_images, transform = transform)
val_dataset = COCORephrasedDataset(val_images, transform = transform)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
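
The effect of rescaling twice is easy to see with plain arithmetic. The sketch below assumes, as described above, that both steps multiply the pixel values by 1/255; plain lists stand in for image tensors and the pixel values are made up.

```python
RESCALE_FACTOR = 1 / 255


def rescale(pixels):
    """Multiply every pixel value by 1/255."""
    return [p * RESCALE_FACTOR for p in pixels]


raw_pixels = [0, 64, 128, 255]   # 8-bit values as stored in the image

once = rescale(raw_pixels)       # expected input range: 0 to 1
twice = rescale(once)            # rescaled a second time

print(max(once), max(twice))
```

After the second rescaling even a fully white pixel is below 0.004, so the model effectively receives an almost black image, which is consistent with the wrong captions observed.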

In [None]:
def compute_accuracy(logits_per_image, logits_per_text):
    image_to_text_preds = logits_per_image.argmax(dim=-1)
    image_to_text_acc = (image_to_text_preds == torch.arange(len(image_to_text_preds), device=device)).float().mean()

    text_to_image_preds = logits_per_text.argmax(dim=-1)
    text_to_image_acc = (text_to_image_preds == torch.arange(len(text_to_image_preds), device=device)).float().mean()

    return image_to_text_acc.item(), text_to_image_acc.item()
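
`compute_accuracy` measures in-batch retrieval: a prediction is correct when the highest logit of row i is in column i. The same computation in plain Python, on a hypothetical 3x3 logit matrix, looks like this (`diagonal_accuracy` is an illustrative name, not part of the notebook):

```python
def diagonal_accuracy(logits):
    """Fraction of rows whose largest value lies on the diagonal."""
    hits = 0
    for i, row in enumerate(logits):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax of the row
        hits += int(predicted == i)
    return hits / len(logits)


# Made-up similarities: rows are images, columns are captions.
logits_per_image = [
    [9.0, 1.0, 0.5],
    [0.2, 3.0, 7.0],   # image 1 prefers caption 2: a miss
    [0.1, 0.3, 8.0],
]
# The text-to-image logits are the transpose of the image-to-text ones.
logits_per_text = [list(col) for col in zip(*logits_per_image)]

image_to_text = diagonal_accuracy(logits_per_image)
text_to_image = diagonal_accuracy(logits_per_text)

print(image_to_text, text_to_image)
```

The two directions can differ, as in this example. Note also that the chance level depends on the batch size: with 64 pairs per batch, a random guess is right about 1.6% of the time.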

In [None]:
def train(epochs, clip_model, clip_processor, blip_model, blip_processor, train_dataloader, val_dataloader, optimizer, path, start_epoch = 0):
  blip_model.eval()
  # Epochs are numbered from start_epoch+1 up to epochs (inclusive)
  for epoch in trange(start_epoch+1, epochs+1, leave=False, desc="Epoch"):
      loss_epoch = 0

      clip_model.train()

      tot_image_to_text_acc = 0
      tot_text_to_image_acc = 0
      tot_training_samples = 0

      for inputs in tqdm(train_dataloader, desc="Training", leave=False):

          to_caption = blip_processor(inputs, return_tensors="pt").to(device)

          with torch.no_grad():
            captions = blip_model.generate(
                  **to_caption,
                  max_length=50,
                  num_return_sequences=1,
                  do_sample=True,
                  top_k=50,
                  top_p=0.95,
                  temperature=0.7,
                  repetition_penalty=1.2,
                  no_repeat_ngram_size=3
            )

          to_clip = [blip_processor.decode(caption, skip_special_tokens=True) for caption in captions]

          inputs_clip = clip_processor(text=to_clip, images=inputs, return_tensors="pt", padding="max_length", truncation=True).to(device)

          outputs = clip_model(**inputs_clip, return_loss=True)

          loss = outputs.loss
          logits_per_image = outputs.logits_per_image
          logits_per_text = outputs.logits_per_text

          optimizer.zero_grad()
          loss.backward()
          optimizer.step()

          loss_epoch += loss.item()

          image_to_text_acc, text_to_image_acc = compute_accuracy(logits_per_image, logits_per_text)

          tot_image_to_text_acc += image_to_text_acc
          tot_text_to_image_acc += text_to_image_acc
          tot_training_samples += 1


      vloss_epoch = 0
      clip_model.eval()

      v_tot_image_to_text_acc = 0
      v_tot_text_to_image_acc = 0
      tot_validation_samples = 0

      with torch.no_grad():
        for inputs in tqdm(val_dataloader, desc="Validation", leave=False):

            to_caption = blip_processor(inputs, return_tensors="pt").to(device)

            captions = blip_model.generate(
                  **to_caption,
                  max_length=50,
                  num_return_sequences=1,
                  do_sample=True,
                  top_k=50,
                  top_p=0.95,
                  temperature=0.7,
                  repetition_penalty=1.2,
                  no_repeat_ngram_size=3
            )

            to_clip = [blip_processor.decode(caption, skip_special_tokens=True) for caption in captions]

            inputs_clip = clip_processor(text=to_clip, images=inputs, return_tensors="pt", padding="max_length", truncation=True).to(device)

            outputs = clip_model(**inputs_clip, return_loss=True)

            vloss = outputs.loss
            logits_per_image = outputs.logits_per_image
            logits_per_text = outputs.logits_per_text

            vloss_epoch += vloss.item()

            v_image_to_text_acc, v_text_to_image_acc = compute_accuracy(logits_per_image, logits_per_text)

            v_tot_image_to_text_acc += v_image_to_text_acc
            v_tot_text_to_image_acc += v_text_to_image_acc
            tot_validation_samples += 1

      metrics =  f"EPOCH {epoch}/{epochs}. Training Loss: {loss_epoch/len(train_dataloader)} Validation Loss: {vloss_epoch/len(val_dataloader)}"
      metrics += f"\nTraining image to text accuracy: {tot_image_to_text_acc/tot_training_samples} Training text to image accuracy: {tot_text_to_image_acc/tot_training_samples}"
      metrics += f"\nValidation image to text accuracy: {v_tot_image_to_text_acc/tot_validation_samples} Validation text to image accuracy: {v_tot_text_to_image_acc/tot_validation_samples}"
      with open(os.path.join(path, "metrics.txt"), "a") as f:
          f.write(metrics + "\n")
      torch.save(clip_model.state_dict(), os.path.join(path, "weights", f"epoch{epoch}.pt"))
      torch.save(optimizer.state_dict(), os.path.join(path, "optimizer", f"epoch{epoch}.pt"))
      print(metrics)

In [None]:
path_to_save = ""

In [None]:
train(100, clip_model, clip_processor, blip_model, blip_processor, train_dataloader, val_dataloader, optimizer, path_to_save, start_epoch=0)