In [1]:
import os
import sys
import numpy as np
import torch

from datetime import datetime
from typing import Tuple
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torch.nn import KLDivLoss, CrossEntropyLoss, CosineEmbeddingLoss, MSELoss
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau

### Building the CLIP Model (ViT-B/32 configuration, no pretrained weights loaded)

In [16]:
from clip.model import Transformer, VisionTransformer
from clip.model import convert_weights
from clip.model import CLIP


# Architecture settings handed to the CLIP constructor below. The model is
# built directly from these values; no checkpoint is loaded in this cell.
embed_dim = 512

# Settings whose names refer to the vision side of the model.
image_resolution = 224
vision_layers = 12
vision_width = 768
vision_patch_size = 32

# Settings whose names refer to the text side of the model.
context_length = 77
vocab_size = 49408
transformer_width = 512
transformer_heads = 8
transformer_layers = 12

# Keyword arguments make the mapping between each setting and the
# constructor parameter explicit (same order as the positional call).
clip_model = CLIP(
    embed_dim=embed_dim,
    image_resolution=image_resolution,
    vision_layers=vision_layers,
    vision_width=vision_width,
    vision_patch_size=vision_patch_size,
    context_length=context_length,
    vocab_size=vocab_size,
    transformer_width=transformer_width,
    transformer_heads=transformer_heads,
    transformer_layers=transformer_layers,
)

### Convert the Dataset into a Torch DataLoader

In [17]:
import clip
from torch.utils.data import DataLoader
from datasets import load_from_disk

# Load the dataset stored under ./data/processed.
dset = load_from_disk("./data/processed")

# Only the image preprocessing pipeline is needed from clip.load; the
# pretrained model returned alongside it is discarded. Load it on the CPU so
# the throwaway copy does not occupy GPU memory that training will need.
_, preprocess = clip.load("ViT-B/32", device="cpu")


def transform_func(examples):
    """Apply the CLIP `preprocess` pipeline to every image of a batch.

    Args:
        examples: mapping with an "image" entry holding an iterable of images.

    Returns:
        The same `examples` object, with its "image" entry replaced by the
        list of preprocessed images.
    """
    processed = []
    for picture in examples["image"]:
        processed.append(preprocess(picture))
    examples["image"] = processed
    return examples


# Register the transform on the dataset (rebinds `dset` to the result).
dset = dset.with_transform(transform_func)

# Shuffled training loader: batches of 16 examples, 8 worker processes.
train_dataloader = DataLoader(
    dset,
    batch_size=16,
    shuffle=True,
    num_workers=8,
)

In [18]:
# Display the dataset summary (features and number of rows).
dset

Dataset({
    features: ['caption', 'image'],
    num_rows: 7012
})

In [28]:
class Trainer:
    """Contrastive trainer for a CLIP-style image/text model.

    The model must expose `encode_image` and `encode_text`; each batch from
    the dataloader must be a mapping with "caption" and "image" entries.

    Args:
        clip_model: model to train; moved to CUDA when available, else CPU.
        train_dataloader: iterable of batches that also supports `len()`.
        epochs: number of epochs run by `train()` (default 30).
        lr: learning rate of the Adam optimizer (default 1e-3).
        temperature: scaling factor used for the logits and the soft targets
            (default 1).
    """

    def __init__(self, clip_model, train_dataloader, epochs=30, lr=1e-3, temperature=1):

        self.model = clip_model
        self.train_dataloader = train_dataloader

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model = self.model.to(self.device)

        self.epochs = epochs
        self.start_epoch = 1
        self.temperature = temperature

        # Kept for compatibility; the training loop uses `cross_entropy`
        # below (soft targets) rather than these criteria.
        self.loss_img = nn.CrossEntropyLoss()
        self.loss_txt = nn.CrossEntropyLoss()
        # set up optimizer
        self.optimizer = Adam(
            self.model.parameters(),
            lr=lr,
            betas=(0.9, 0.98),
            eps=1e-6,
            weight_decay=0.2,
        )

    def cross_entropy(self, preds, targets, reduction="none"):
        """Cross entropy between logits `preds` and soft `targets`.

        Args:
            preds: logits; log-softmax is taken over the last dimension.
            targets: target distribution with the same shape as `preds`.
            reduction: "none" (per-row loss), "mean" or "sum".

        Returns:
            The per-row loss tensor, or its mean / sum as a scalar tensor.

        Raises:
            ValueError: if `reduction` is not one of the supported values
                (previously this case silently returned None).
        """
        loss = (-targets * F.log_softmax(preds, dim=-1)).sum(1)
        if reduction == "none":
            return loss
        if reduction == "mean":
            return loss.mean()
        if reduction == "sum":
            return loss.sum()
        raise ValueError(
            f"Unsupported reduction {reduction!r}; expected 'none', 'mean' or 'sum'"
        )

    def compute_loss(self, images, texts, return_outputs=False):
        """Compute the symmetric contrastive loss for one batch.

        Args:
            images: batch of image tensors.
            texts: captions, tokenized here with `clip.tokenize`.
            return_outputs: currently unused; kept for interface compatibility.

        Returns:
            Scalar tensor: mean over the batch of the averaged text->image
            and image->text soft-target cross entropies.
        """
        texts = clip.tokenize(texts)
        texts = texts.to(self.device)
        # NOTE(review): images are cast to fp16 regardless of the model's
        # parameter dtype - confirm the model re-casts its input as needed.
        images = images.to(self.device).half()

        image_embeddings = self.model.encode_image(images)
        text_embeddings = self.model.encode_text(texts)

        # Rows index texts, columns index images.
        logits = (text_embeddings @ image_embeddings.T) / self.temperature

        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T

        # Soft targets built from the average intra-modal similarity.
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )

        texts_loss = self.cross_entropy(logits, targets, reduction="none")
        images_loss = self.cross_entropy(logits.T, targets.T, reduction="none")
        con_loss = (images_loss + texts_loss) / 2.0
        con_loss = con_loss.mean()
        return con_loss

    def train(self):
        """Run every epoch from `start_epoch` to `epochs`, printing the mean loss."""
        for epoch in range(self.start_epoch, self.epochs + 1):
            print(f"Starting Epoch {epoch} ------------------------------------------")
            loss_value = self._train_epoch(epoch)
            print(f"Combined Loss Value after {epoch} Epoch is {loss_value}")

    def _train_epoch(self, epoch):
        """Train for one pass over the dataloader.

        Args:
            epoch: index of the current epoch (not used by the computation).

        Returns:
            Mean batch loss over the epoch as a float; NaN when the
            dataloader yields no batches.
        """
        num_batches = len(self.train_dataloader)
        # Accumulate plain Python floats: adding the loss tensor itself would
        # chain every batch's autograd graph into the running total.
        running_loss = 0.0
        for batch_idx, data in enumerate(self.train_dataloader):

            self.optimizer.zero_grad()

            texts = data["caption"]
            images = data["image"]

            loss = self.compute_loss(images, texts)

            loss.backward()

            self.optimizer.step()

            running_loss += loss.item()

            if batch_idx % 100 == 0:
                print(
                    f"Loss after {batch_idx}/{num_batches} Batch is {running_loss/(batch_idx+1)} "
                )

        if num_batches == 0:
            # Nothing was trained on; there is no meaningful average.
            return float("nan")
        return running_loss / num_batches

In [29]:
# Instantiate the trainer with the model and training loader defined above.
# NOTE(review): this rebinds the name `Trainer` from the class to the
# instance, so the class-definition cell has to be re-run before this cell
# can be executed a second time. Later cells rely on this name.
Trainer = Trainer(clip_model=clip_model, train_dataloader=train_dataloader)

In [30]:
# Run the training loop on the trainer instance created in the previous cell;
# progress is printed every 100 batches and at the end of each epoch.
Trainer.train()

Starting Epoch 1 ------------------------------------------
Loss after 0/439 Batch is 28.542377471923828 
Loss after 100/439 Batch is 5.4333343505859375 
Loss after 200/439 Batch is 3.624295473098755 
Loss after 300/439 Batch is 2.9344849586486816 
Loss after 400/439 Batch is 3.1431682109832764 
Combined Loss Value after 1 Epoch is 3.00731281810578
Starting Epoch 2 ------------------------------------------
Loss after 0/439 Batch is 2.125880002975464 
Loss after 100/439 Batch is 1.5105254650115967 
Loss after 200/439 Batch is 1.5328954458236694 
Loss after 300/439 Batch is 1.91533625125885 
Loss after 400/439 Batch is 1.9073659181594849 
Combined Loss Value after 2 Epoch is 1.8783043794045415
Starting Epoch 3 ------------------------------------------
Loss after 0/439 Batch is 1.544331431388855 
Loss after 100/439 Batch is 1.5405009984970093 
Loss after 200/439 Batch is 1.589935064315796 
Loss after 300/439 Batch is 1.5900063514709473 
Loss after 400/439 Batch is 1.5982508659362793 
Co

Loss after 100/439 Batch is 1.5146740674972534 
Loss after 200/439 Batch is 1.512189269065857 
Loss after 300/439 Batch is 1.5381606817245483 
Loss after 400/439 Batch is 1.5296515226364136 
Combined Loss Value after 24 Epoch is 1.5279656412389664
Starting Epoch 25 ------------------------------------------
Loss after 0/439 Batch is 1.4173530340194702 
Loss after 100/439 Batch is 1.5520180463790894 
Loss after 200/439 Batch is 1.5596923828125 
Loss after 300/439 Batch is 1.5372940301895142 
Loss after 400/439 Batch is 1.5338860750198364 
Combined Loss Value after 25 Epoch is 1.527862757376495
Starting Epoch 26 ------------------------------------------
Loss after 0/439 Batch is 1.748734951019287 
Loss after 100/439 Batch is 1.5017292499542236 
Loss after 200/439 Batch is 1.5147569179534912 
Loss after 300/439 Batch is 1.5090062618255615 
Loss after 400/439 Batch is 1.5010285377502441 
Combined Loss Value after 26 Epoch is 1.5001472351491316
Starting Epoch 27 ---------------------------

In [32]:
from datetime import datetime
# Timestamp embedded in the checkpoint file names so runs do not overwrite
# each other.
format_data = "date_%d_%m_%y_time_%H_%M_%S"
timestamp = datetime.now().strftime(format_data)

# torch.save does not create missing directories; make sure the output
# folder exists so a finished training run is not lost at save time.
os.makedirs("results", exist_ok=True)

# NOTE(review): only the `visual` and `transformer` sub-modules are saved;
# any parameters the model holds outside these two are not written -
# confirm this is intended.
torch.save(Trainer.model.visual.state_dict(), f"results/{timestamp}_OG_CLIP_VisionTransformer.pt")
torch.save(Trainer.model.transformer.state_dict(), f"results/{timestamp}_OG_CLIP_TextTransformer.pt")