<p align="center">بسم الله الرحمن الرحیم</p>

# Libraries

In [None]:
# Install the Hugging Face `transformers` package and OpenAI's `clip` package
# (the latter straight from its GitHub repository).
! pip install transformers
! pip install git+https://github.com/openai/CLIP.git

In [None]:
import gc
import time
import copy
import PIL
import torch
import os
import dill
import clip
import requests
from tqdm import tqdm
import pandas as pd
import numpy as np
from PIL import Image
import torch.nn as nn
import matplotlib.pyplot as plt
from pkg_resources import packaging
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast
from sklearn.model_selection import train_test_split
from transformers import CLIPModel, CLIPConfig, CLIPVisionModel, CLIPFeatureExtractor
from transformers import AutoModel, AutoTokenizer, AutoModel, TFAutoModel, AutoConfig
from transformers import BertModel
from transformers import TrainingArguments, Trainer, RobertaModel
from transformers import default_data_collator

# Report the installed PyTorch version.
print("Torch version:", torch.__version__)

Torch version: 1.13.1+cu116


# Configuration and Hyperparameters

In [None]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Fraction of the full DataFrame held out as the test split.
TEST_SIZE = 0.1
# Fraction of the remaining (non-test) rows held out as the validation split.
VAL_SIZE = 0.1
# Batch size shared by the train, validation and test DataLoaders.
BATCH_SIZE = 256
# Passed as `num_epochs` to train_model and as `epochs` to the OneCycleLR scheduler.
EPOCH = 10
# `lr` handed to the Adam optimizer.
# NOTE(review): OneCycleLR presumably replaces this value with its own schedule — confirm.
LR = 1e-7
# `eps` handed to the Adam optimizer.
EPS = 1e-9
# `weight_decay` handed to the Adam optimizer.
WEIGHT_DECAY = 0.1
# `max_lr` (peak learning rate) handed to the OneCycleLR scheduler.
# NOTE(review): 1e-2 looks high as a peak for fine-tuning a pretrained model — confirm.
MAX_LR = 1e-2
# Path prefix under which train_model writes its per-epoch checkpoints.
BASE_MODEL_PATH = '/content/drive/MyDrive/clip_trained_model/'

# Data Preparation

In [None]:
# Placeholder for the dataset table. Later cells read its 'image_path' and
# 'text' columns, so it has to be populated before they are run.
df = pd.DataFrame()

In [None]:
class CLIP_Dataset(Dataset):
    """Pairs of (preprocessed image, tokenized caption) for CLIP fine-tuning.

    Captions are tokenized once, up front, with ``clip.tokenize``; images are
    opened and transformed lazily on every ``__getitem__`` call.

    Args:
        list_image_path: Sequence of image file paths.
        list_text: Sequence of captions, aligned index-by-index with
            ``list_image_path``.
        transform: Optional callable applied to each opened PIL image. When
            omitted, the module-level ``preprocess`` is looked up at access
            time, so it must exist before the first item is fetched.

    Raises:
        ValueError: If the two sequences differ in length.
    """

    def __init__(self, list_image_path, list_text, transform=None):
        if len(list_image_path) != len(list_text):
            raise ValueError(
                f"Got {len(list_image_path)} image paths but {len(list_text)} texts; "
                "they must be aligned one-to-one."
            )
        self.image_path = list_image_path
        self.texts = clip.tokenize(list_text)
        self.transform = transform

    def __len__(self):
        # Previously read the non-existent ``self.title`` attribute, which made
        # len(dataset) -- and therefore every DataLoader -- raise AttributeError.
        return len(self.image_path)

    def __getitem__(self, idx):
        # Fall back to the notebook-level CLIP transform when none was injected.
        transform = self.transform if self.transform is not None else preprocess
        # The context manager closes the underlying file once the image has been
        # transformed, so long runs do not accumulate open file handles.
        with Image.open(self.image_path[idx]) as img:
            image = transform(img)
        text = self.texts[idx]
        return image, text

In [None]:
# Two-stage split: first carve the test set off the full table, then carve the
# validation set off what remains (so VAL_SIZE is relative to the non-test rows).
train_df, test_df = train_test_split(df, test_size=TEST_SIZE, shuffle=True, random_state=42)
train_df, val_df = train_test_split(train_df, test_size=VAL_SIZE, shuffle=True, random_state=42)

train_dataset = CLIP_Dataset(train_df['image_path'].tolist(), train_df['text'].tolist())
val_dataset = CLIP_Dataset(val_df['image_path'].tolist(), val_df['text'].tolist())
test_dataset = CLIP_Dataset(test_df['image_path'].tolist(), test_df['text'].tolist())

# One DataLoader per split, keyed by phase name.
# The training loader reshuffles every epoch: the contrastive loss uses the
# other samples of the batch as negatives, so a fixed batch composition would
# show the model the same negatives over and over. Validation and test keep
# their order so their losses are comparable from run to run.
dataloader = {
    "train": DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True),
    "val": DataLoader(val_dataset, batch_size=BATCH_SIZE),
    "test": DataLoader(test_dataset, batch_size=BATCH_SIZE),
}

In [None]:
# Sanity check: print the tensor shapes of the first batch of every split.
# NOTE: fetching a batch applies the image preprocessing transform, which is
# only created by `clip.load` in the model-loading cell further down -- on a
# fresh kernel that cell has to be executed before this one.
for phase in ["train", "val", "test"]:
    print(f"=== {phase} ===")
    for images, texts in dataloader[phase]:
        print(f"image shape: {images.shape}")
        # Previously this line printed the image tensor's shape a second time.
        print(f"text shape: {texts.shape}")
        break

# Model Loading

In [None]:
# List the model names that the installed `clip` package can load.
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [None]:
# Load the "ViT-L/14" CLIP model onto `device`, together with its image
# preprocessing transform (`preprocess` is what CLIP_Dataset applies to images).
# jit=False requests the regular (non-JIT) model.
model, preprocess = clip.load("ViT-L/14", device=device, jit=False)
# Cast the model's floating-point weights to float32.
# NOTE(review): presumably needed because clip.load yields fp16 weights on GPU — confirm.
model = model.float()

In [None]:
# How many trailing parameter tensors of each tower stay trainable.
freeze_layer_thr = 20

# Number of parameter tensors in the text transformer ([0]) and in the visual
# tower ([1]). Note these are parameter tensors (weights, biases, ...), not
# whole layers.
layer_num = [
    sum(1 for _ in model.transformer.parameters()),
    sum(1 for _ in model.visual.parameters()),
]

# In each tower, freeze everything except the last `freeze_layer_thr`
# parameter tensors (in the order `.parameters()` yields them).
for tower, total in zip((model.transformer, model.visual), layer_num):
    for position, tensor in enumerate(tower.parameters()):
        tensor.requires_grad = (total - position) <= freeze_layer_thr

# Optimizer, Loss, LR_Scheduler

In [None]:
# Two cross-entropy criteria: train_model applies one to the image->text
# logits and one to the text->image logits and averages them.
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()

# Adam over every model parameter, configured from the hyper-parameter cell.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=LR,
    betas=(0.9, 0.98),
    eps=EPS,
    weight_decay=WEIGHT_DECAY,
)

# One-cycle schedule sized for EPOCH epochs of len(dataloader['train'])
# steps each, peaking at MAX_LR.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=MAX_LR,
    epochs=EPOCH,
    steps_per_epoch=len(dataloader['train']),
)

# Training

In [None]:
def train_model(model, optimizer, scheduler, num_epochs=10, start_epoch=0):
    """Fine-tune ``model`` with the symmetric image/text contrastive loss.

    Every epoch runs a training pass followed by a validation pass over the
    module-level ``dataloader`` dict, prints the mean per-batch loss of each
    pass, and saves the model's ``state_dict`` under ``BASE_MODEL_PATH`` once
    the training pass is done.

    Args:
        model: Called as ``model(images, texts)``; must return the pair
            ``(logits_per_image, logits_per_text)``.
        optimizer: Optimizer that updates the model's parameters.
        scheduler: Learning-rate scheduler. It is stepped once per optimizer
            step (i.e. per training batch), which is what a schedule sized
            with ``steps_per_epoch * epochs`` total steps requires.
        num_epochs: Exclusive upper bound of the zero-based epoch index.
        start_epoch: Zero-based epoch index to begin at. Defaults to 0; pass a
            larger value only when resuming from a checkpoint (restoring the
            model/optimizer/scheduler state is the caller's job).

    Returns:
        None. Progress is printed and checkpoints are written to disk.
    """
    since = time.time()

    for epoch in range(start_epoch, num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            total_loss = 0.0
            num = 0

            for batch in tqdm(dataloader[phase]):
                images, texts = batch
                images = images.to(device)
                texts = texts.to(device)

                if is_train:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    logits_per_image, logits_per_text = model(images, texts)
                    # The i-th image matches the i-th text, so the target class
                    # for row i of either logit matrix is simply i.
                    ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
                    batch_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2

                    if is_train:
                        batch_loss.backward()
                        optimizer.step()
                        # Advance the schedule with every optimizer step; stepping
                        # only once per epoch would leave a per-batch schedule
                        # stuck at its very beginning.
                        scheduler.step()

                # .item() detaches the value so the running total does not keep
                # every batch's autograd graph alive.
                total_loss += batch_loss.item()
                num += 1

            # An empty loader yields NaN instead of a ZeroDivisionError.
            epoch_loss = total_loss / num if num else float('nan')

            print(f'{phase} Loss: {epoch_loss:.4f}')

            if is_train:
                torch.save(model.state_dict(), BASE_MODEL_PATH+f'clip_en_fi_ep{epoch}.pt')

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')

In [None]:
# Launch fine-tuning with num_epochs=EPOCH, using the optimizer and scheduler
# defined earlier in the notebook.
train_model(model=model, optimizer=optimizer, scheduler=scheduler, num_epochs=EPOCH)

# Testing

In [None]:
# Evaluate on the held-out test split: mean of the per-batch symmetric
# contrastive loss, computed in eval mode and without gradients.
model.eval()
total_loss = 0.0
num = 0
with torch.no_grad():
    for batch in tqdm(dataloader['test']):
        images, texts = batch
        images = images.to(device)
        texts = texts.to(device)
        logits_per_image, logits_per_text = model(images, texts)
        # The i-th image matches the i-th text, so row i's target class is i.
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        # .item() keeps the running total a plain Python float, so the final
        # print shows a number rather than a tensor repr.
        total_loss += ((loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2).item()
        num += 1
# An empty test loader yields NaN instead of a ZeroDivisionError.
total_loss = total_loss / num if num else float('nan')
print(f'Test loss: {total_loss}')