# Installing CLIP

In [64]:
# !pip install ftfy regex tqdm
# !pip install git+https://github.com/openai/CLIP.git

In [65]:
from PIL import Image
import torch
from torch import nn, optim
import glob
import os
import pandas as pd
import json
import numpy as np
import clip
from torch.utils.data import Dataset, DataLoader, BatchSampler
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
import random
from matplotlib.pyplot import imshow
import torchtext
import nltk, re, string, collections
from nltk.util import ngrams
import collections
%matplotlib inline
# Number of distinct labels drawn per batch; passed as n_classes to the
# BalancedBatchSampler below (with n_samples=1, so this is also the batch size).
BATCH_SIZE = 10
# Number of passes over train_dataloader in the training loop.
EPOCH = 5

# CLIP 모델

In [66]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the pretrained ViT-B/32 CLIP weights together with the matching image
# preprocessing transform; jit=False asks for the regular (non-JIT) model.
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

# 전처리 마친 데이터셋 (pickle)

In [67]:
# Load the preprocessed dataset. Later cells read the 'image_link' and
# 'processed_image' columns from it.
# NOTE(review): unpickling can execute arbitrary code — only load this file if
# it comes from a trusted source.
df = pd.read_pickle("image_processed_data")

In [68]:
# Drop rows that have no value in 'processed_image' (mutates df in place).
df.dropna(subset=['processed_image'], inplace=True)

In [69]:
# Renumber the index after the drop, discarding the old index (in place).
df.reset_index(drop=True, inplace=True)

In [70]:
# from googletrans import Translator
# def translate(text):
#     translator = Translator()
#     translated = translator.translate(text, src='ko', dest='en')
#     return translated.text

# df['caption']=df['caption'].apply(translate)

In [71]:
# NOTE(review): this overwrites every row's caption with the same placeholder
# string "abc". With identical text for all samples in a batch, the symmetric
# cross-entropy used in the training cell cannot go below
# ln(BATCH_SIZE) = ln(10) ≈ 2.3026 — which is the plateau visible in the
# recorded training output. Replace with real per-image captions before training.
df['caption']="abc"

## train_test_split (0.2)

In [72]:
# Hold out 20% of the rows as the test split; random_state pins the shuffle
# so the split is reproducible.
train, test = train_test_split(df, test_size=0.2, random_state=42)
# Show the resulting split sizes.
len(train), len(test)

(412, 104)

## 학습용 Dataset으로 변환 

In [73]:
from torch.utils.data import Dataset

class ClipDataset(Dataset):
    """Map-style dataset over a DataFrame of image/caption rows.

    The frame must provide the columns ``image_link``, ``caption`` and
    ``processed_image``; each is copied into a plain Python list at
    construction time.

    Each item is the tuple ``(processed_image, caption, label)``. The label of
    an image link is its position in ``img_paths``; when a link occurs more
    than once, the position of its last occurrence is used for all of them.

    Args:
        df: source DataFrame (kept on ``self.df``).
        preprocess: stored on ``self.preprocess``; not used by the methods
            of this class.
    """

    def __init__(self, df, preprocess):
        self.preprocess = preprocess
        self.df = df
        self.img_paths = df['image_link'].tolist()
        self.captions = df['caption'].tolist()
        self.processed_images = df['processed_image'].tolist()
        # Later duplicates overwrite earlier ones, so the last position wins.
        self.path2label = dict(zip(self.img_paths, range(len(self.img_paths))))

    def __len__(self):
        """Number of rows, measured on the caption list."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Return ``(processed_image, caption, label)`` for row ``idx``."""
        link = self.img_paths[idx]
        return self.processed_images[idx], self.captions[idx], self.path2label[link]

In [74]:
# Wrap each split in a ClipDataset; both share the CLIP preprocessing transform.
train_dataset = ClipDataset(train, preprocess)
test_dataset = ClipDataset(test, preprocess)

In [75]:
# Sanity check: split sizes and the first training item
# (processed image, caption, label).
len(train_dataset), len(test_dataset), train_dataset[0]

(412,
 104,
 (tensor([[[1.7260, 1.7260, 1.7260,  ..., 1.6092, 1.6092, 1.6092],
           [1.7260, 1.7260, 1.7260,  ..., 1.6092, 1.6092, 1.6092],
           [1.7260, 1.7260, 1.7260,  ..., 1.6092, 1.6092, 1.6092],
           ...,
           [1.7260, 1.7260, 1.7260,  ..., 1.6530, 1.6530, 1.6530],
           [1.7260, 1.7260, 1.7260,  ..., 1.6530, 1.6530, 1.6530],
           [1.7260, 1.7260, 1.7260,  ..., 1.6530, 1.6530, 1.6530]],
  
          [[1.9548, 1.9548, 1.9548,  ..., 1.8348, 1.8348, 1.8348],
           [1.9548, 1.9548, 1.9548,  ..., 1.8348, 1.8348, 1.8348],
           [1.9548, 1.9548, 1.9548,  ..., 1.8348, 1.8348, 1.8348],
           ...,
           [1.9548, 1.9548, 1.9548,  ..., 1.8498, 1.8498, 1.8498],
           [1.9548, 1.9548, 1.9548,  ..., 1.8498, 1.8498, 1.8498],
           [1.9548, 1.9548, 1.9548,  ..., 1.8498, 1.8498, 1.8498]],
  
          [[2.0321, 2.0321, 2.0321,  ..., 1.9184, 1.9184, 1.9184],
           [2.0321, 2.0321, 2.0321,  ..., 1.9184, 1.9184, 1.9184],
          

In [76]:
# Same check for the test split.
# NOTE(review): len(test_dataset) appears twice; the first one was probably
# meant to be len(train_dataset) — confirm.
len(test_dataset), len(test_dataset), test_dataset[0]

(104,
 104,
 (tensor([[[1.7406, 1.7406, 1.7552,  ..., 1.6676, 1.6676, 1.6676],
           [1.7406, 1.7406, 1.7552,  ..., 1.6530, 1.6676, 1.6676],
           [1.7406, 1.7406, 1.7552,  ..., 1.6530, 1.6676, 1.6676],
           ...,
           [1.7114, 1.7114, 1.7114,  ..., 1.6384, 1.6530, 1.6530],
           [1.7114, 1.7114, 1.7114,  ..., 1.6384, 1.6384, 1.6530],
           [1.7260, 1.7260, 1.7260,  ..., 1.6384, 1.6384, 1.6384]],
  
          [[1.9698, 1.9698, 1.9848,  ..., 1.8798, 1.8948, 1.8948],
           [1.9698, 1.9698, 1.9848,  ..., 1.8798, 1.8948, 1.8948],
           [1.9698, 1.9698, 1.9848,  ..., 1.8798, 1.8948, 1.8948],
           ...,
           [1.9098, 1.9098, 1.9098,  ..., 1.8648, 1.8798, 1.8798],
           [1.9098, 1.9098, 1.9098,  ..., 1.8648, 1.8648, 1.8798],
           [1.9248, 1.9248, 1.9248,  ..., 1.8648, 1.8648, 1.8648]],
  
          [[2.0464, 2.0464, 2.0606,  ..., 1.9610, 1.9753, 1.9753],
           [2.0464, 2.0464, 2.0606,  ..., 1.9610, 1.9753, 1.9753],
          

In [77]:
# Peek at the first ten (image link -> label) entries of the training set.
i = 0  # stays 0 when the mapping is empty
for i, (k, v) in enumerate(train_dataset.path2label.items(), start=1):
    print(k,v)
    if i >= 10:
        break

https://image.msscdn.net/images/style/detail/41600/detail_41600_6627945de2203_500.jpg 0
https://image.msscdn.net/images/style/detail/41620/detail_41620_6627b05821ca8_500.jpg 1
https://image.msscdn.net/images/style/detail/41610/detail_41610_6627975b8dfbf_500.jpg 2
https://image.msscdn.net/images/style/detail/41498/detail_41498_66210236d45a6_500.jpg 3
https://image.msscdn.net/images/style/detail/41624/detail_41624_6627b60a0ba47_500.jpg 4
https://image.msscdn.net/images/style/detail/41610/detail_41610_6627975b227cb_500.jpg 5
https://image.msscdn.net/images/style/detail/41596/detail_41596_662796fd21413_500.jpg 6
https://image.msscdn.net/images/style/detail/41632/detail_41632_6629a1f2b8120_500.jpg 7
https://image.msscdn.net/images/style/detail/41567/detail_41567_6626781114d9f_500.jpg 8
https://image.msscdn.net/images/style/detail/41600/detail_41600_6627945d799c2_500.jpg 9


## BalancedBatchSampler (ensures no same class per batch)

In [78]:
# https://github.com/pytorch/pytorch/blob/e5742494f6080c8e6f43c37689fc18a7c4b39dfd/torch/utils/data/dataloader.py#L145
class BalancedBatchSampler(BatchSampler):
    """
    Batch sampler that draws ``n_classes`` distinct labels per batch and, for
    each drawn label, ``n_samples`` dataset indices carrying that label.

    Each yielded batch is a list of dataset indices with nominal length
    ``n_classes * n_samples`` (a label that owns fewer than ``n_samples``
    indices contributes fewer). One pass yields exactly ``len(self)`` batches.

    Args:
        labels: 1-D per-item labels; a CPU ``torch.Tensor`` (anything exposing
            ``.numpy()``) or any array-like accepted by ``np.asarray``.
        n_classes: number of distinct labels per batch. Must not exceed the
            number of distinct labels, because labels are drawn without
            replacement.
        n_samples: number of indices taken per drawn label.

    Note:
        ``BatchSampler.__init__`` is not called; this class keeps its own
        state and only provides ``__iter__`` and ``__len__``.
    """

    def __init__(self, labels, n_classes, n_samples):
        self.labels = labels
        # Convert once up front instead of once per distinct label.
        labels_np = labels.numpy() if hasattr(labels, "numpy") else np.asarray(labels)
        self.labels_set = list(set(labels_np))
        self.label_to_indices = {label: np.where(labels_np == label)[0]
                                 for label in self.labels_set}
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l])
        # Per-label cursor into its (shuffled) index array.
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.n_dataset = len(self.labels)
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        # `<=` (not `<`) so that a dataset whose size is an exact multiple of
        # the batch size still yields its final batch, matching __len__.
        while self.count + self.batch_size <= self.n_dataset:
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                start = self.used_label_indices_count[class_]
                indices.extend(self.label_to_indices[class_][start:start + self.n_samples])
                self.used_label_indices_count[class_] += self.n_samples
                # Not enough unused indices left for another draw of this
                # label: reshuffle and start over from the beginning.
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.batch_size

    def __len__(self):
        return self.n_dataset // self.batch_size
    
def _build_balanced_loader(dataset):
    """Return ``(labels, sampler, dataloader)`` for ``dataset``.

    Labels are read from position 2 of every dataset item; the sampler draws
    BATCH_SIZE distinct labels per batch with one sample each.
    """
    labels = torch.tensor([sample[2] for sample in dataset])
    sampler = BalancedBatchSampler(labels, BATCH_SIZE, 1)
    return labels, sampler, DataLoader(dataset, batch_sampler=sampler)


train_labels, train_sampler, train_dataloader = _build_balanced_loader(train_dataset)
test_labels, test_sampler, test_dataloader = _build_balanced_loader(test_dataset)

In [79]:
# Take the first batch of indices from the sampler and collect its labels.
for i, item in enumerate(train_sampler):
    labels = [train_dataset[idx][2] for idx in item]
    break
# Both numbers are equal when every label in the batch is distinct.
len(labels), len(set(labels))

(10, 10)

In [80]:
# Print the shapes and labels of the first collated test batch only.
for batch in test_dataloader:
    imgs, txts, labels = batch
    for shown in (imgs.shape, len(txts), labels, labels.shape, torch.unique(labels).shape):
        print(shown)
    break

torch.Size([10, 3, 224, 224])
10
tensor([  1,  92,  60,   4,  52,  82,  85,  69, 100,  37])
torch.Size([10])
torch.Size([10])


# Training

In [81]:
#https://github.com/openai/CLIP/issues/57
def convert_models_to_fp32(model):
    """Cast every parameter of ``model``, and its gradient when present, to float32.

    The conversion is done in place on ``p.data`` / ``p.grad.data``.
    Parameters that have no gradient (``p.grad is None``, e.g. frozen or not
    reached by the last backward pass) still get their data converted; their
    missing gradient is skipped instead of raising ``AttributeError``.

    Args:
        model: a ``torch.nn.Module``.
    """
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

# On CPU, cast the whole model to float32 once up front; the training loop
# then skips the per-step precision conversions it performs on GPU.
if device == "cpu":
    model.float()

# One cross-entropy criterion per direction (image->text and text->image);
# the training loop averages the two.
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
# Alternative optimizer settings, kept for reference:
# optimizer = optim.Adam(model.parameters(), lr=5e-5,betas=(0.9,0.98),eps=1e-6,weight_decay=0.2)
optimizer = optim.Adam(model.parameters(), lr=1e-5)
# Cosine schedule spanning the whole run: the second argument is the total
# number of training batches, and the loop calls scheduler.step() once per batch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_dataloader)*EPOCH)

In [82]:
# Fine-tune CLIP for EPOCH epochs. After every epoch the model is evaluated on
# the test dataloader; the weights with the lowest mean test loss are written
# to "best_model.pt", and the final weights to "last_model.pt".
best_te_loss = 1e5
best_ep = -1
for epoch in range(EPOCH):
    print(f"running epoch {epoch}, best test loss {best_te_loss} after epoch {best_ep}")

    # ---- training pass ----
    step = 0
    tr_loss = 0
    model.train()
    pbar = tqdm(train_dataloader, leave=False)
    for batch in pbar:
        step += 1
        optimizer.zero_grad()

        images, texts, _ = batch
        images = images.to(device)
        texts = clip.tokenize(texts).to(device)
        logits_per_image, logits_per_text = model(images, texts)
        # Image i is paired with text i, so the targets are 0..n-1. n is taken
        # from the batch itself rather than BATCH_SIZE so a batch of any size
        # gets matching targets.
        ground_truth = torch.arange(len(images), device=device)

        # Symmetric loss: mean of the image->text and text->image cross-entropies.
        total_loss = (loss_img(logits_per_image,ground_truth) + loss_txt(logits_per_text,ground_truth))/2
        total_loss.backward()
        tr_loss += total_loss.item()

        # Off-CPU, parameters and gradients are cast to float32 for the
        # optimizer step and then handed back to clip.model.convert_weights.
        on_accelerator = device != "cpu"
        if on_accelerator:
            convert_models_to_fp32(model)
        optimizer.step()
        scheduler.step()
        if on_accelerator:
            clip.model.convert_weights(model)
        pbar.set_description(f"train batchCE: {total_loss.item()}", refresh=True)
    tr_loss /= step

    # ---- evaluation pass ----
    step = 0
    te_loss = 0
    with torch.no_grad():
        model.eval()
        test_pbar = tqdm(test_dataloader, leave=False)
        for batch in test_pbar:
            step += 1
            images, texts, _ = batch
            images = images.to(device)
            texts = clip.tokenize(texts).to(device)
            logits_per_image, logits_per_text = model(images, texts)
            ground_truth = torch.arange(len(images), device=device)

            total_loss = (loss_img(logits_per_image,ground_truth) + loss_txt(logits_per_text,ground_truth))/2
            te_loss += total_loss.item()
            test_pbar.set_description(f"test batchCE: {total_loss.item()}", refresh=True)
        te_loss /= step

    # Checkpoint whenever the mean test loss improves.
    if te_loss < best_te_loss:
        best_te_loss = te_loss
        best_ep = epoch
        torch.save(model.state_dict(), "best_model.pt")
    print(f"epoch {epoch}, tr_loss {tr_loss}, te_loss {te_loss}")
torch.save(model.state_dict(), "last_model.pt")

running epoch 0, best test loss 100000.0 after epoch -1


  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

epoch 0, tr_loss 2.3521341463414633, te_loss 2.3080078125
running epoch 1, best test loss 2.3080078125 after epoch 0


  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

epoch 1, tr_loss 2.3070217225609757, te_loss 2.3052734375
running epoch 2, best test loss 2.3052734375 after epoch 1


  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

epoch 2, tr_loss 2.304973323170732, te_loss 2.30625
running epoch 3, best test loss 2.3052734375 after epoch 1


  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

epoch 3, tr_loss 2.3048780487804876, te_loss 2.305078125
running epoch 4, best test loss 2.305078125 after epoch 3


  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

epoch 4, tr_loss 2.305020960365854, te_loss 2.30546875
