In [None]:
# !pip install open_clip_torch
# !pip install transformers
# !pip install torch # DO NOT RUN IN GCP NOTEBOOK!!!
# !pip install open_clip_torch
# !pip install git+https://github.com/openai/CLIP.git

In [None]:
import pandas as pd
from torch.utils.data import DataLoader
from torch import nn
from torch import optim
from PIL import Image
from tqdm import tqdm
import open_clip
import torch
from torch.utils.data import Dataset

# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# Load the training rows; point TRAIN_CSV_PATH elsewhere to train on other data.
TRAIN_CSV_PATH = 'train_subset.csv'
subset = pd.read_csv(TRAIN_CSV_PATH)

In [None]:
def get_image_file_path(id, image_root="images", id_width=10):
    """Return the relative path of the product image for an article id.

    Images are laid out as ``<image_root>/<first 3 chars>/<padded id>.jpg``, where the
    id is left-padded with zeros to ``id_width`` characters. Padding (instead of always
    prepending a single "0") yields the same path whether ``id`` arrives as an int that
    lost its leading zero or as an already zero-padded string.

    Args:
        id: article id (int or str). Name kept for backward compatibility even though it
            shadows the builtin.
        image_root: directory that holds the 3-character sub-folders.
        id_width: total width of the zero-padded id used in folder and file names.

    Returns:
        The image path as a string.
    """
    folder_id = str(id).zfill(id_width)
    return f"{image_root}/{folder_id[:3]}/{folder_id}.jpg"

In [None]:
# concatenate text fields as text input

def generate_combined_string(row):
    """Build a caption "<Value> <master colour> <first sentence of description>".

    The colour prefix is omitted when ``perceived_colour_master_name`` is missing, empty,
    or "undefined"/"unknown" (case-insensitive). When a colour prefix is used, the
    description's first letter is lower-cased so the caption reads as one phrase.

    Args:
        row: a mapping (e.g. a DataFrame row) with ``detail_desc``,
            ``perceived_colour_master_name`` and ``perceived_colour_value_name`` entries.

    Returns:
        The caption with no leading/trailing space. A missing (NaN) ``detail_desc``
        yields just the colour phrase.
    """
    description = row['detail_desc']
    # Missing descriptions come through pandas as NaN (a float), which has no .split.
    first_sentence = "" if pd.isna(description) else str(description).split('. ')[0]

    master_raw = row['perceived_colour_master_name']
    master_name = "" if pd.isna(master_raw) else str(master_raw).lower()
    if master_name in ('', 'undefined', 'unknown'):
        colour_string = ""
    else:
        value_name = str(row['perceived_colour_value_name'])
        # Capitalize first letter, lower-case the rest; slicing is safe on "".
        value_name = value_name[:1].upper() + value_name[1:].lower()
        colour_string = f"{value_name} {master_name}"
        first_sentence = first_sentence[:1].lower() + first_sentence[1:]

    # Join only non-empty parts so a missing colour or description leaves no stray space.
    return " ".join(part for part in (colour_string, first_sentence) if part)

text_data = subset.apply(generate_combined_string, axis=1)

# Show some example captions. Iterating head() avoids label lookups, so it works for any
# index and for subsets with fewer than 30 rows.
for caption in text_data.head(30):
    print(caption)

image_paths = subset["article_id"].map(get_image_file_path).tolist()

Dusty light khaki green short dress in a patterned viscose weave with a small stand-up collar and small V-neck opening at the top
Medium dusty khaki green top in lightweight sweatshirt fabric in a relaxed fit with long raglan sleeves and ribbing around the neckline, cuffs and hem.
Dusty light grey sleeveless top in soft jersey with a sheen
Dark mole baby Exclusive
Medium blue cold shoulder blouse in woven fabric with narrow, adjustable shoulder straps, elastication and a flounce at the top, long sleeves with elastication at the cuffs, and a smocked hem.
Dusty light pink wide sports top in fast-drying functional fabric with a slightly wider neckline, short cap sleeves and a drawstring at one side of the hem
Dark black ankle-length jeans in washed, stretch denim with a high waist, zip fly and button, patch front pockets, back pockets and straight, wide legs.
Light white shirt in a linen weave in a straight, relaxed style with a grandad collar, classic front and yoke at the back
Light whi

In [None]:
import os

# Remove captions whose image file is missing on disk (paths are relative to the
# current working directory).
existing_pairs = [
    (text, path)
    for text, path in zip(text_data, image_paths)
    if os.path.isfile(path)
]
filtered_texts = [text for text, _ in existing_pairs]
filtered_paths = [path for _, path in existing_pairs]

print(f"Kept {len(existing_pairs)} of {len(image_paths)} captions with an image on disk.")
assert existing_pairs, "No image files found - check the images/ folder location."

In [None]:
# Define hyperparameters
# Optimisation
LR: float = 1e-7  # AdamW learning rate. NOTE(review): very small for fine-tuning; consider a sweep.
BATCH_SIZE: int = 128  # image/caption pairs per batch

# Schedule / early stopping
EPOCH: int = 100  # maximum number of training epochs
PATIENCE: int = 5  # stop after this many epochs without a new best validation loss

In [None]:
# Load the pretrained CLIP model, its image transforms and the matching tokenizer.
# A single constant keeps model and tokenizer on the same checkpoint.
MODEL_NAME = 'hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(MODEL_NAME, device=device)
tokenizer = open_clip.get_tokenizer(MODEL_NAME)

# The dataset applies `preprocess` to every split, so validation images also get the
# training transform. NOTE(review): open_clip's train transform may include random
# augmentation - confirm, and consider `preprocess_val` for the validation split.
preprocess = preprocess_train

In [None]:
# Inspiration from https://github.com/openai/CLIP/issues/83

class image_title_dataset(Dataset):
    """Pairs each image file with its pre-tokenized caption.

    Args:
        list_image_path: image file paths; entry i is the image for caption i.
        list_txt: captions, tokenized once up front.
        tokenizer_fn: optional callable mapping a list of strings to a tensor of token
            ids; defaults to the notebook-level ``tokenizer``.
        transform: optional image transform; defaults to the notebook-level
            ``preprocess`` (looked up when an item is fetched).

    Raises:
        ValueError: if the number of paths and captions differ.
    """

    def __init__(self, list_image_path, list_txt, tokenizer_fn=None, transform=None):
        if len(list_image_path) != len(list_txt):
            raise ValueError(
                f"got {len(list_image_path)} image paths but {len(list_txt)} captions"
            )
        if tokenizer_fn is None:
            tokenizer_fn = tokenizer
        self.image_path = list_image_path
        self.title = tokenizer_fn(list_txt)
        self.transform = transform

    def __len__(self):
        return len(self.title)

    def __getitem__(self, idx):
        transform = self.transform if self.transform is not None else preprocess
        # `with` closes the underlying file as soon as the pixels have been transformed.
        with Image.open(self.image_path[idx]) as image:  # Image from PIL module
            pixels = transform(image)
        title = self.title[idx]
        return pixels, title

In [None]:
# ready the data
list_image_path = filtered_paths
list_txt = filtered_texts
dataset = image_title_dataset(list_image_path, list_txt)

# Hold out 10% as a validation set for early stopping. A seeded generator makes the
# split reproducible across kernel restarts.
SPLIT_SEED = 42
train, val = torch.utils.data.random_split(
    dataset, [0.9, 0.1], generator=torch.Generator().manual_seed(SPLIT_SEED)
)
# DataLoader defaults to shuffle=False, which would feed identical batches in identical
# order every epoch; reshuffle the training split each epoch instead.
train_dataloader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val, batch_size=BATCH_SIZE)

loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
# betas/eps follow the CLIP paper; the small learning rate is meant to be safe for fine-tuning.
optimizer = optim.AdamW(model.parameters(), lr=LR, betas=(0.9, 0.98), eps=1e-6)

In [None]:
# Training loop with early stopping on the validation loss.
#
# open_clip's `model(images, texts)` returns the normalised
# (image_features, text_features, logit_scale[, logit_bias]) - NOT the
# (logits_per_image, logits_per_text) pair returned by the OpenAI `clip` package from
# the linked issue. The image/text similarity logits therefore have to be built
# explicitly before applying CrossEntropyLoss.

CHECKPOINT_DIR = "models"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)  # torch.save fails if the folder is missing


def _batch_loss(images, texts):
    """Return the symmetric CLIP contrastive loss for one (images, texts) batch.

    Row i of the image batch matches row i of the text batch, so the target for
    row i of both logit matrices is i.
    """
    images = images.to(device)
    texts = texts.to(device)

    outputs = model(images, texts)
    if isinstance(outputs, dict):  # model created with output_dict=True
        image_features = outputs["image_features"]
        text_features = outputs["text_features"]
        logit_scale = outputs["logit_scale"]
    else:
        image_features, text_features, logit_scale = outputs[:3]

    logits_per_image = logit_scale * image_features @ text_features.T
    logits_per_text = logits_per_image.T

    ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
    return (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2


min_loss = float("inf")
min_epoch = None
epochs_without_improvement = 0
for epoch in range(EPOCH):
    # --- training ---
    model.train()
    train_loss = 0.0
    for images, texts in tqdm(train_dataloader):
        optimizer.zero_grad()
        total_loss = _batch_loss(images, texts)
        total_loss.backward()
        optimizer.step()
        train_loss += total_loss.item()
    train_loss /= len(train_dataloader)
    print(f'TRAIN EPOCH:\t{epoch},\tLOSS:\t{train_loss}')

    # --- validation ---
    model.eval()
    with torch.no_grad():
        val_loss = sum(_batch_loss(images, texts).item() for images, texts in val_dataloader)
    val_loss /= len(val_dataloader)
    print(f'VAL EPOCH:\t{epoch},\tLOSS:\t{val_loss}')

    # --- checkpointing + early stopping ---
    if val_loss < min_loss:
        min_loss = val_loss
        min_epoch = epoch
        epochs_without_improvement = 0
        print(f'New minimum loss reached at epoch {epoch}')
        # Only improving epochs are saved, so the newest file is always the best model.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss,  # mean training loss over this epoch
            'val_loss': val_loss,
        }, f"{CHECKPOINT_DIR}/colour_first_lr_{LR}_epoch_{epoch}_valloss_{val_loss}_better.pt")
    else:
        # NaN validation losses also land here, so a diverged run stops too.
        epochs_without_improvement += 1
        if epochs_without_improvement >= PATIENCE:
            print(f'Loss did not improve in last {PATIENCE} epochs. Stopping.')
            break

100%|██████████| 580/580 [53:37<00:00,  5.55s/it]


TRAIN EPOCH:	0,	LOSS:	6.227045070713964
VAL EPOCH:	0,	LOSS:	6.213827683375432
New minimum loss reached at epoch 0


100%|██████████| 580/580 [53:26<00:00,  5.53s/it]


TRAIN EPOCH:	1,	LOSS:	6.2016163390258265
VAL EPOCH:	1,	LOSS:	6.190689783829909
New minimum loss reached at epoch 1


100%|██████████| 580/580 [53:22<00:00,  5.52s/it]


TRAIN EPOCH:	2,	LOSS:	6.182627651609224
VAL EPOCH:	2,	LOSS:	6.176787618490366
New minimum loss reached at epoch 2


100%|██████████| 580/580 [53:01<00:00,  5.48s/it]


TRAIN EPOCH:	3,	LOSS:	6.1726986794636165
VAL EPOCH:	3,	LOSS:	6.1703459959763745
New minimum loss reached at epoch 3


100%|██████████| 580/580 [52:52<00:00,  5.47s/it]


TRAIN EPOCH:	4,	LOSS:	6.168175218845236
VAL EPOCH:	4,	LOSS:	6.167339376302865
New minimum loss reached at epoch 4


100%|██████████| 580/580 [53:01<00:00,  5.48s/it]


TRAIN EPOCH:	5,	LOSS:	6.16599769016792
VAL EPOCH:	5,	LOSS:	6.16579426251925
New minimum loss reached at epoch 5


100%|██████████| 580/580 [53:15<00:00,  5.51s/it]


TRAIN EPOCH:	6,	LOSS:	6.164820159714798
VAL EPOCH:	6,	LOSS:	6.164921797238863
New minimum loss reached at epoch 6


100%|██████████| 580/580 [52:58<00:00,  5.48s/it]


TRAIN EPOCH:	7,	LOSS:	6.164113446761822
VAL EPOCH:	7,	LOSS:	6.1643769631019
New minimum loss reached at epoch 7


100%|██████████| 580/580 [52:40<00:00,  5.45s/it]


TRAIN EPOCH:	8,	LOSS:	6.163654725304966
VAL EPOCH:	8,	LOSS:	6.164038071265588
New minimum loss reached at epoch 8


100%|██████████| 580/580 [52:43<00:00,  5.45s/it]


TRAIN EPOCH:	9,	LOSS:	6.163334690291306
VAL EPOCH:	9,	LOSS:	6.163804765848013
New minimum loss reached at epoch 9


 18%|█▊        | 105/580 [09:35<43:37,  5.51s/it]

In [None]:
# # To load the fine-tuned model do:
# model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K', device=device)
# checkpoint = torch.load('models/coolmodel.pt') # for example models/CLIP-ViT-B-32-laion2B-s34B-b79K_1_epoch.pt
# model.load_state_dict(checkpoint['model_state_dict'])