### !ls /home/saurav/Documents/required_dataset/images

In [1]:
from tqdm import tqdm
import os
import clip
import pandas as pd
import torch
from PIL import Image
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import CLIPProcessor, CLIPModel

# Load the CLIP model and its matching image preprocessing function.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
# On CUDA, clip.load hands back half-precision weights. Optimising fp16
# parameters directly with AdamW is numerically unstable (loss goes to NaN),
# so train in full precision. This is a no-op on CPU, where weights are fp32.
model.float()

2025-01-31 15:30:20.384477: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-31 15:30:20.440496: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Dataset locations: everything lives under DATA_PATH.
DATA_PATH = '/home/saurav/Documents/'
image_folder = f"{DATA_PATH}required_dataset/images"
text_file = f"{DATA_PATH}required_dataset/styles2.csv"

# Product metadata table, one row per image.
df = pd.read_csv(text_file)
total_rows = df.shape[0]

# Training batch size and a (currently empty) embedding store.
batch_size = 2
embeddings = dict()

In [3]:
def generate_description(row):
    """Build a space-separated text caption from one product metadata row.

    Parameters
    ----------
    row : mapping-like (e.g. a pandas Series or a dict)
        Must provide every key listed in ``fields`` below.

    Returns
    -------
    str
        The non-missing field values, in ``fields`` order, joined by single
        spaces.
    """
    fields = (
        'gender',
        'masterCategory',
        'subCategory',
        'articleType',
        'baseColour',
        'season',
        'year',
        'usage',
        'productDisplayName',
    )
    parts = []
    for field in fields:
        value = row[field]
        # Empty CSV cells arrive as NaN/None; str() would otherwise inject the
        # literal word "nan" into the caption fed to the text encoder.
        if pd.isna(value):
            continue
        # A numeric column that contains NaN is read as float, so a year such
        # as 2011 would be rendered "2011.0"; show integral floats as ints.
        if isinstance(value, float) and value.is_integer():
            value = int(value)
        parts.append(str(value))
    return ' '.join(parts)

In [4]:
#ls /home/saurav/Documents/required_dataset/images

In [5]:
#custom dataset loader creation
class ImageTextDataset(torch.utils.data.Dataset):
    """Paired image/caption dataset for CLIP fine-tuning.

    Pairs whose image file does not exist are dropped once, at construction
    time (with a warning), so every index in ``range(len(dataset))`` yields a
    real sample and the DataLoader's default collate never sees ``None``.

    Parameters
    ----------
    image_paths : sequence of str
        Path to the image file of each sample.
    texts : sequence of str
        Caption of each sample, aligned with ``image_paths``.
    preprocess : callable
        Image transform applied to the opened PIL image (the CLIP
        preprocessing function).
    """

    def __init__(self, image_paths, texts, preprocess):
        self.image_paths = []
        self.texts = []  #list of the corresponding texts
        for image_path, text in zip(image_paths, texts):
            if not os.path.exists(image_path):
                print(f"Warning: image file {image_path} not found. Skipping this image")
                continue
            self.image_paths.append(image_path)
            self.texts.append(text)
        self.preprocess = preprocess  #The CLIP processor

    def __len__(self):
        """Return the number of usable (image, text) pairs."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``{'image': tensor, 'text': token tensor}`` for sample ``idx``.

        Tensors are left on the CPU: the training loop moves each batch to the
        target device, and creating CUDA tensors here would break DataLoader
        worker processes.
        """
        # Context manager closes the file handle once preprocessing has
        # produced the image tensor.
        with Image.open(self.image_paths[idx]) as image:
            image_inputs = self.preprocess(image)

        # truncate=True clips captions longer than CLIP's context length
        # instead of raising an error.
        text_inputs = clip.tokenize([self.texts[idx]], truncate=True).squeeze(0)

        return {'image': image_inputs, 'text': text_inputs}  #return processed image-text pair

In [6]:
# Build the image-path and caption lists, then wrap them in a DataLoader.
batch = df.iloc[0:total_rows]

image_id = []
image_paths = []
texts = []
# Single pass over the rows: collect the id, its image path and its caption.
for _, record in batch.iterrows():
    record_id = str(record['id'])
    image_id.append(record_id)
    image_paths.append(os.path.join(image_folder, f"{record_id}.jpg"))
    texts.append(generate_description(record))

dataset = ImageTextDataset(image_paths, texts, preprocess)  # paired image/text dataset
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)  # shuffled mini-batches

In [7]:
# Training setup: cross-entropy objective and an AdamW optimiser over every
# model parameter.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=5e-5)

In [None]:
#training the model
model.to(device)  #move model to the target device
model.train()  #set model to training mode

num_epoch = 20

for epoch in tqdm(range(num_epoch)):
    epoch_loss = 0.0   #sum of batch losses seen this epoch
    num_batches = 0    #number of batches that actually contributed
    for batch in tqdm(dataloader):
        if not batch:   #skip empty batch
            continue
        optimizer.zero_grad()  #reset gradients

        inputs = {k: v.to(device) for k, v in batch.items()}  #move batch to GPU/CPU
        # The CLIP model's forward pass already returns the scaled similarity
        # logits (image->text and text->image), not raw features, so they are
        # used directly instead of being multiplied together again.
        logits_per_image, logits_per_text = model(**inputs)

        # The i-th image matches the i-th text, so the targets are 0..B-1.
        ground_truth = torch.arange(len(logits_per_image), device=device)
        # Symmetric contrastive loss: average of both retrieval directions.
        loss = (loss_fn(logits_per_image, ground_truth) + loss_fn(logits_per_text, ground_truth)) / 2

        loss.backward()  #backpropagation
        optimizer.step()  #update model parameters

        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch; max() guards against an epoch in
    # which every batch was skipped.
    mean_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}/{num_epoch}, Loss: {mean_loss}")
            

  0%|                                                                 | 0/20 [00:00<?, ?it/s]
  0%|                                                              | 0/22224 [00:00<?, ?it/s][A
  0%|                                                   | 1/22224 [00:02<15:35:20,  2.53s/it][A
  0%|                                                   | 2/22224 [00:04<14:01:09,  2.27s/it][A
  0%|                                                   | 3/22224 [00:06<13:50:52,  2.24s/it][A
  0%|                                                   | 4/22224 [00:08<13:36:01,  2.20s/it][A
  0%|                                                   | 5/22224 [00:11<13:53:52,  2.25s/it][A
  0%|                                                   | 6/22224 [00:13<14:07:03,  2.29s/it][A
  0%|                                                   | 7/22224 [00:15<13:37:00,  2.21s/it][A
  0%|                                                   | 8/22224 [00:17<13:13:58,  2.14s/it][A
  0%|                            

In [12]:
#saving the finetuned model
#saves the model's state dictionary
# torch.save does not create missing parent directories, so make sure the
# output directory exists first.
os.makedirs("fine_tuned_clip", exist_ok=True)
torch.save(model.state_dict(), "fine_tuned_clip/model_weights.pth")   #fine_tuned_clip is the directory

In [None]:
#loading the fine-tuned CLIP weights back into the model
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load("fine_tuned_clip/model_weights.pth", map_location=device))