In [12]:
from PIL import Image
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import torchvision.transforms as transforms
import clip
from pathlib import Path
import pandas as pd
from typing import Optional, Dict
import torch
import os

transform = { 
    "base": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
    ]),

    "crop": transforms.Compose([
        transforms.CenterCrop((266, 375)),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
    ]),

    "brightness": transforms.Compose([
        transforms.ColorJitter(brightness=0.3),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
    ])
}

def augmented_dataset(path, mode):
    base_dataset = get_dataset(path, mode, augmentation_type= "base" )
    crop_dataset = get_dataset(path, mode, augmentation_type="crop")
    brightness_dataset = get_dataset(path, mode, augmentation_type="brightness")

    return ConcatDataset([base_dataset, crop_dataset, brightness_dataset])


class ImageDataset(Dataset):
    def __init__(self, data: pd.DataFrame, transform: Optional[Dict[str, transforms.Compose]] = None, augmentation_type: str = None):
        self.data = data
        self.transform = transform
        self.augmentation_type = augmentation_type

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        labels = self.data.iloc[idx]
        path = labels["IMAGE_PATH"]
        description = labels["DESCRIPTION"]
        image_id = labels["MFC"]

        assert os.path.exists(path), f"Image {path} does not exist"
        image = Image.open(path).convert("RGB")
        description = clip.tokenize(description, truncate=True)

        if self.transform:
            image = self.transform[self.augmentation_type](image)
        return image, description, (self.augmentation_type, image_id)

    @staticmethod
    def collate_fn(batch):
        images, description, images_id = zip(*batch)
        return torch.stack(images), torch.vstack(description), list(images_id)


def get_dataset(path: str, df_path: str, augmentation_type: str) -> Dataset:
    """
    Returns a DataLoader object for the dataset at the specified path.
    """
    df = pd.read_csv(df_path)
    path = Path(path) if isinstance(path, str) else path
    df["IMAGE_PATH"] = df["IMAGE_PATH"].apply(lambda x: path / x)
    return ImageDataset(data=df, augmentation_type=augmentation_type, transform=transform)


In [23]:
ecommerce_path="/kaggle/input/armani-catalogue/images/images"
train_csv="/kaggle/input/armani-traintest/train.csv"
test_csv= "/kaggle/input/armani-traintest/test.csv"
dataset = augmented_dataset(ecommerce_path,train_csv)
test_dataset= get_dataset(ecommerce_path,test_csv, "base")
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=os.cpu_count(), collate_fn=ImageDataset.collate_fn)
test_dataloader= DataLoader(test_dataset, batch_size=128, shuffle=True, num_workers=os.cpu_count(), collate_fn=ImageDataset.collate_fn)

In [17]:
from transformers import CLIPProcessor, CLIPModel
import torch
from tqdm.notebook import tqdm
from torch import nn
import clip



# Load the model and processor
device = "cuda:0" if torch.cuda.is_available() else "cpu" # If using GPU then use mixed precision training.
model, preprocess = clip.load("ViT-B/32",device=device,jit=False) #Must set jit=False for training



In [18]:
device

'cuda:0'

In [None]:
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-6,betas=(0.9,0.98),eps=1e-6,weight_decay=0.4) #Params used from paper, the lr is smaller, more safe for fine tuning to new dataset
#define the number of epochs
num_epochs = 20

# Early Stopping
patience = 5  # Numero massimo di epoche senza miglioramenti
best_val_loss = float("inf")
no_improvement = 0


#training loop
for epoch in range(num_epochs):
    model.train()
    #plotting the progress bar  
    pbar = tqdm(train_dataloader, total=len(train_dataloader))
    #iterate over the training dataloader
    train_loss=0
    val_loss=0
    for batch in pbar:
         #set the gradients to zero
        optimizer.zero_grad()
        #load the images and texts
        images,texts, _ = batch 

        
        images= images.to(device)

        texts = texts.to(device)
    
        logits_per_image, logits_per_text = model(images, texts)

        # Compute loss
        ground_truth = torch.arange(len(images),dtype=torch.long,device=device)
        total_loss = (loss_img(logits_per_image,ground_truth) + loss_txt(logits_per_text,ground_truth))/2
        
        train_loss+=total_loss.item()
        
        # Backward pass
        total_loss.backward()

        optimizer.step()

    avg_train_loss = train_loss/len(train_dataloader)
    tqdm.write(f"Epoch {epoch+1}/{num_epochs},Train Loss: {avg_train_loss:.4f}")
    #validation step
    model.eval()
    with torch.no_grad():
        #iterate over the test dataloader
        for batch in test_dataloader:
            images, texts, _ = batch
            images= images.to(device)
            texts = texts.to(device)
        
            logits_per_image, logits_per_text = model(images, texts)
            
            # Compute loss
            ground_truth = torch.arange(len(images),dtype=torch.long,device=device)
            total_loss = (loss_img(logits_per_image,ground_truth) + loss_txt(logits_per_text,ground_truth))/2
            val_loss+=total_loss.item()

    avg_val_loss = val_loss/len(test_dataloader)
    tqdm.write(f"Epoch {epoch+1}/{num_epochs}, Val Loss: {avg_val_loss:.4f}")

     # Check Early Stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        no_improvement = 0
        # Salva il modello migliore
        torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': total_loss,
        }, f"/kaggle/working/clipmodel_5.pt") #just change to your preferred folder/filename
        print("Validation loss improved. Model saved.")
    else:
        no_improvement += 1
        print(f"No improvement for {no_improvement} epoch(s).")

    if no_improvement >= patience:
        print("Early stopping triggered.")
        break
        

    #update the progress bar
    pbar.set_description(f"Epoch {epoch}/{num_epochs}, Loss: {total_loss.item():.4f}")

  0%|          | 0/265 [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 1/20,Train Loss: 1.3992
Epoch 1/20, Val Loss: 1.0607
Validation loss improved. Model saved.


  0%|          | 0/265 [00:00<?, ?it/s]

Epoch 2/20,Train Loss: 0.6878
Epoch 2/20, Val Loss: 0.9388
Validation loss improved. Model saved.


  0%|          | 0/265 [00:00<?, ?it/s]

Epoch 3/20,Train Loss: 0.4530
Epoch 3/20, Val Loss: 0.9179
Validation loss improved. Model saved.


  0%|          | 0/265 [00:00<?, ?it/s]