In [1]:
import math
import os
import sys

import open_clip
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [2]:
# captions.json is read with orient='index': its keys become the index
# (image filenames) and its values the caption text.
captions = (
    pd.read_json('captions.json', orient='index')
    .reset_index()
    .set_axis(['image_path', 'text'], axis=1)
)

# Fail fast on bad data: a missing caption would otherwise only surface as a
# tokenizer crash deep inside a training epoch.
assert captions['image_path'].is_unique, "duplicate image paths in captions.json"
assert captions['text'].notna().all(), "some images have no caption"
captions

Unnamed: 0,image_path,text
0,MEN-Denim-id_00000080-01_7_additional.jpg,The lower clothing is of long length. The fabr...
1,MEN-Denim-id_00000089-01_7_additional.jpg,"His tank top has sleeves cut off, cotton fabri..."
2,MEN-Denim-id_00000089-02_7_additional.jpg,"His sweater has long sleeves, cotton fabric an..."
3,MEN-Denim-id_00000089-03_7_additional.jpg,"His shirt has short sleeves, cotton fabric and..."
4,MEN-Denim-id_00000089-04_7_additional.jpg,"The sweater the person wears has long sleeves,..."
...,...,...
42539,WOMEN-Tees_Tanks-id_00007979-04_4_full.jpg,The lady wears a tank tank shirt with pure col...
42540,WOMEN-Tees_Tanks-id_00007979-04_7_additional.jpg,The person wears a sleeveless tank shirt with ...
42541,WOMEN-Tees_Tanks-id_00007981-03_1_front.jpg,This woman wears a sleeveless tank top with ot...
42542,WOMEN-Tees_Tanks-id_00007981-03_3_back.jpg,The tank top the lady wears has sleeves cut of...


In [3]:
class FashionOutfitDataset(Dataset):
    """Map-style dataset yielding ``(image_tensor, token_ids)`` pairs for CLIP training.

    Args:
        data: DataFrame whose first column holds an image filename and whose
            second column holds the caption (both read positionally via ``iloc``).
        preprocess: callable turning an RGB ``PIL.Image`` into a model input.
        tokenizer: callable taking a list of strings and returning a batch of
            token ids; only the first (and only) row is returned per item.
        image_dir: directory the filenames in ``data`` are relative to
            (defaults to ``"images"``, the previously hard-coded location).
    """

    def __init__(self, data, preprocess, tokenizer, image_dir="images"):
        self.data = data
        self.preprocess = preprocess
        self.tokenizer = tokenizer
        self.image_dir = image_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image_path = self.data.iloc[index, 0]
        text = self.data.iloc[index, 1]

        # Use the context manager so the file handle is released deterministically
        # (Pillow keeps it open for lazily-loaded / multi-frame images otherwise).
        with Image.open(os.path.join(self.image_dir, image_path)) as img:
            rgb = img.convert("RGB")
        image = self.preprocess(rgb)

        # Tokenizer works on batches; wrap the single caption and unwrap the row.
        tokens = self.tokenizer([text])[0]
        return image, tokens

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# Fine-tuning ViT-B-32 on CPU is very slow; make the choice visible up front.
print(f"Training on: {device}")

# create_model_and_transforms returns (model, train_transform, eval_transform).
# The third value is an *image* transform, not a tokenizer -- the text
# tokenizer comes from open_clip.get_tokenizer below.
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    'ViT-B-32', pretrained='laion2b_s34b_b79k'
)
preprocess = preprocess_train  # the transform this notebook trains with
tokenizer = open_clip.get_tokenizer('ViT-B-32')

BATCH_SIZE = 32
# Decode JPEGs in parallel worker processes so the model isn't starved.
# Non-fork start methods (spawn: macOS/Windows default) re-import __main__ and
# cannot find notebook-defined classes like FashionOutfitDataset, so only use
# workers on Linux.
# NOTE(review): Python 3.14 defaults Linux to forkserver -- set NUM_WORKERS = 0
# if workers fail with "Can't get attribute 'FashionOutfitDataset'".
NUM_WORKERS = min(4, os.cpu_count() or 1) if sys.platform.startswith("linux") else 0

dataset = FashionOutfitDataset(captions, preprocess_train, tokenizer)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=(device == "cuda"),
)

model = model.to(device)

In [5]:
# Contrastive objective: cross-entropy over the in-batch image/text similarity
# matrix, with the matching pair on the diagonal as the target class.
loss_fn = nn.CrossEntropyLoss()

# Small learning rate for fine-tuning pretrained weights.
# NOTE(review): 5e-6 is not justified anywhere in the notebook -- confirm.
optimizer = optim.AdamW(params=model.parameters(), lr=5e-6)

def train_clip(model, dataloader, epochs=3):
    """Fine-tune a CLIP-style model with the symmetric contrastive loss.

    In each batch of ``(images, texts)`` pairs, the i-th image matches the i-th
    caption (the diagonal of the similarity matrix); every other pairing in the
    batch is treated as a negative. Uses the module-level ``optimizer`` and
    ``loss_fn`` defined in this cell.

    Args:
        model: model exposing ``encode_image``, ``encode_text`` and a learnable
            ``logit_scale`` parameter (log of the similarity multiplier).
        dataloader: iterable yielding ``(images, token_ids)`` tensor batches.
        epochs: number of passes over ``dataloader``.

    Returns:
        list[float]: mean per-batch loss for each epoch.
    """
    model.train()
    # Follow the model's placement instead of relying on a global `device`.
    device = next(model.parameters()).device
    # CLIP caps the learned multiplier at 100 to keep training stable.
    max_logit_scale = math.log(100)

    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        n_batches = 0
        for images, texts in tqdm(dataloader, desc=f"epoch {epoch + 1}/{epochs}"):
            images = images.to(device, non_blocking=True)
            texts = texts.to(device, non_blocking=True)

            image_features = model.encode_image(images)
            text_features = model.encode_text(texts)

            # L2-normalize so the dot product is cosine similarity.
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Cosine similarities lie in [-1, 1]; without the learned scale the
            # softmax over the batch is nearly uniform and gradients vanish.
            logit_scale = model.logit_scale.exp()
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logits_per_image.T  # same matrix, other direction

            labels = torch.arange(images.shape[0], device=device)
            loss = (loss_fn(logits_per_image, labels) + loss_fn(logits_per_text, labels)) / 2

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                model.logit_scale.clamp_(0, max_logit_scale)

            total_loss += loss.item()
            n_batches += 1

        # Report the mean, not the sum, so epochs/batch sizes are comparable.
        mean_loss = total_loss / max(n_batches, 1)
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch+1} - Loss: {mean_loss:.4f}")
    return epoch_losses

In [6]:
# The saved run below measured ~13.8 s per batch (~5 h for one epoch of 1330
# batches) before being interrupted -- check the training device first.
train_clip(model=model, dataloader=dataloader, epochs=5)

  1%|          | 8/1330 [01:50<5:03:52, 13.79s/it]


KeyboardInterrupt: 