In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from torchtune.modules import RMSNorm
from tokenizers import Tokenizer
from pathlib import Path
from transformers import RobertaTokenizer, RobertaModel
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, ToTensor
from torchvision.transforms.v2 import RGB

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.utils.data import random_split
from PIL import Image

from transformers import ViTImageProcessor, ViTForImageClassification, ViTFeatureExtractor


import timm

In [21]:

@dataclass
class ModelArgs:
    """Hyperparameters and settings shared by the cells of this notebook.

    None of the attributes carry type annotations, so ``@dataclass`` creates no
    fields: these are plain class attributes, read as ``ModelArgs.<name>``.
    """
    # --- data / input sizes ---
    img_size = (224, 224)  # (height, width) of images; keep in sync with the image transforms
    block_size = 77  # max caption length in tokens (tokenizer max_length)
    batch_size = 32  # DataLoader batch size
    # Input width of the text LayerNorm and of the text projection layer; assumed to
    # equal the text encoder's hidden size (768 for roberta-base).
    embeddings_dims = 768
    projection_dims = 768  # output width of the text/image projection layers

    # --- settings not referenced by the encoder/training cells shown here ---
    attn_dropout = 0.1
    no_of_heads = 12  # IMP needs to be thoroughly calculated
    dropout = 0.1
    epochs = 100
    no_of_decoder_layers = 12  # IMP needs to be thoroughly calculated
    vocab_size = 2000

    # --- optimizer ---
    lr = 4e-4  # optimizer-level default lr
    weight_decay_optim = 0.2  # weight decay for the projection heads
    beta_1 = 0.9
    beta_2 = 0.98
    epsilon = 1e-6
    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5

    # --- runtime ---
    # Fall back to CPU so the notebook still runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # --- encoders ---
    model_name = 'resnet50'  # timm backbone used by VisionModel
    pretrained = True  # load pretrained weights for the timm image backbone
    trainable = True  # whether encoder backbone weights receive gradients

    # --- SigLIP logit parameters (initial values) ---
    # Float tensor, not an int, so it can be wrapped in nn.Parameter.
    bias = torch.tensor(-10.0)
    # Log of the logit scale: exp(temperature) == 10 at initialization.
    temperature = torch.log(torch.tensor(10.0))

In [3]:
class Normalization(nn.Module):
    """Thin wrapper that applies LayerNorm across the trailing embedding dimension.

    Args:
        embeddings_dims: size of the last dimension of the input
            (defaults to ``ModelArgs.embeddings_dims``).
    """

    def __init__(self, embeddings_dims: int = ModelArgs.embeddings_dims):
        super().__init__()
        norm = nn.LayerNorm(embeddings_dims)
        self.layernorm_layer = norm

    def forward(self, x):
        """Return ``x`` normalized over its last dimension."""
        normed = self.layernorm_layer(x)
        return normed
        

In [None]:
# Hugging Face checkpoint shared by the tokenizer and the standalone encoder below.
ROBERTA_CHECKPOINT = 'roberta-base'
tokenizer = RobertaTokenizer.from_pretrained(ROBERTA_CHECKPOINT)
# NOTE(review): `model` is not used by the cells shown here; TextModel loads its own
# RoBERTa copy. Consider dropping this load to save memory.
model = RobertaModel.from_pretrained(ROBERTA_CHECKPOINT)

In [5]:
class TextModel(nn.Module):
    """Text encoder: RoBERTa -> first-token hidden state -> LayerNorm -> linear projection.

    ``forward`` expects a dict-like batch with ``input_ids`` and ``attention_mask``.
    The tokenizer returns ``(1, L)`` tensors per sample, so a collated batch is
    typically ``(B, 1, L)``.
    """

    def __init__(self, trainable=None):
        """
        Args:
            trainable: whether RoBERTa's weights receive gradients. ``None`` (default)
                uses ``ModelArgs.trainable``, matching how ``VisionModel`` is configured.
        """
        super().__init__()
        if trainable is None:
            trainable = ModelArgs.trainable

        self.layer_norm = Normalization()
        self.model = RobertaModel.from_pretrained('roberta-base')
        self.tokenizer = tokenizer
        self.multimodalTextLayerProjector = nn.Linear(in_features=ModelArgs.embeddings_dims, out_features=ModelArgs.projection_dims, device=ModelArgs.device)

        for p in self.model.parameters():
            p.requires_grad = trainable
        self.model.train()

    def forward(self, x):
        """Encode a tokenized batch into projected text embeddings.

        Args:
            x: dict-like with ``input_ids`` and ``attention_mask`` tensors. It is read
               but not modified.

        Returns:
            Tensor of shape ``(B, ModelArgs.projection_dims)``.
        """
        # squeeze(1) turns (B, 1, L) into (B, L). It is a no-op on (B, L) when L > 1.
        # Keep the results in locals so the caller's batch dict is not mutated.
        input_ids = x['input_ids'].squeeze(1)
        attention_mask = x['attention_mask'].squeeze(1)
        hidden = self.model(input_ids=input_ids, attention_mask=attention_mask)['last_hidden_state']
        # Hidden state of the first token (position 0) summarizes the sequence.
        first_token = hidden[:, 0, :]
        return self.multimodalTextLayerProjector(self.layer_norm(first_token))

In [16]:
class VisionModel(nn.Module):
    """Image encoder: a timm backbone with its classification head removed.

    The backbone is built with ``num_classes=0`` and average global pooling, so
    ``forward`` returns one pooled feature vector per image. The vector's width
    depends on the chosen backbone (timm usually exposes it as
    ``self.model.num_features``; confirm for non-standard models).

    Args:
        model_name: timm model identifier.
        pretrained: whether timm loads pretrained weights.
        trainable: value assigned to ``requires_grad`` on every backbone parameter.
    """

    def __init__(
        self, model_name=ModelArgs.model_name, pretrained=ModelArgs.pretrained, trainable=ModelArgs.trainable
    ):
        super().__init__()
        backbone = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )
        # Toggle gradients for all backbone parameters at once.
        backbone.requires_grad_(trainable)
        self.model = backbone

    def forward(self, x):
        """Return pooled backbone features for the image batch ``x``."""
        features = self.model(x)
        return features

In [22]:
class SigLip(nn.Module):
    """SigLIP-style dual encoder producing a (text x image) logit matrix.

    ``logits[i, j] = exp(temperature) * cos(text_i, image_j) + bias``
    """

    def __init__(self):
        super().__init__()

        self.vision = VisionModel()
        self.text = TextModel()
        self.multimodelTextLayerPorjector = nn.Linear(in_features=ModelArgs.embeddings_dims, out_features=ModelArgs.projection_dims, device=ModelArgs.device)
        # The pooled image feature width comes from the timm backbone (resnet50 pools to
        # 2048, not 768), so size the projector from it. Fall back to embeddings_dims if
        # the backbone does not expose `num_features`.
        vision_dims = getattr(self.vision.model, 'num_features', ModelArgs.embeddings_dims)
        self.multimodalVisionLayerProjector = nn.Linear(in_features=vision_dims, out_features=ModelArgs.projection_dims, device=ModelArgs.device)
        # nn.Parameter shares storage with the tensor it wraps. clone() keeps optimizer steps
        # from writing into the class-level ModelArgs defaults. as_tensor(float32) also
        # accepts a plain number, such as an int bias.
        self.temperature = nn.Parameter(torch.as_tensor(ModelArgs.temperature, dtype=torch.float32).clone(), requires_grad=True)
        self.bias = nn.Parameter(torch.as_tensor(ModelArgs.bias, dtype=torch.float32).clone(), requires_grad=True)

    def forward(self, x):
        """Score every caption in the batch against every image.

        Args:
            x: dict-like batch. It is passed whole to ``self.text``, which reads
               ``input_ids``/``attention_mask``. Its ``'image'`` entry goes to ``self.vision``.

        Returns:
            ``(B_text, B_image)`` logits: row i is caption i, column j is image j.
        """
        embeds_text = self.text(x)
        proj_txt = torch.nn.functional.normalize(self.multimodelTextLayerPorjector(embeds_text))

        images = x['image']
        # ToTensorV2 keeps the source dtype, so un-normalized images arrive as uint8,
        # which float conv weights cannot consume. Scale those to [0, 1] floats.
        if not images.is_floating_point():
            images = images.float() / 255.0
        embeds_img = self.vision(images)
        proj_img = torch.nn.functional.normalize(self.multimodalVisionLayerProjector(embeds_img))

        # Positive scaled cosine similarity plus bias. In SigLIP the minus sign lives in
        # the loss (-log_sigmoid(label * logit)), not in the logits.
        logits = (proj_txt @ proj_img.T) * torch.exp(self.temperature) + self.bias
        return logits

In [None]:
# Build the full model. TextModel loads pretrained RoBERTa, and VisionModel builds the
# timm backbone with ModelArgs.pretrained.
siglip = SigLip()

In [23]:
# Image preprocessing
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

# ImageNet mean/std (timm's default normalization for its pretrained resnet50 weights;
# re-check if ModelArgs.model_name changes). max_pixel_value is the largest 8-bit
# *pixel value* (255), not the image size.
IMAGE_MEAN = (0.485, 0.456, 0.406)
IMAGE_STD = (0.229, 0.224, 0.225)


def build_image_transforms(size=ModelArgs.img_size):
    """Deterministic resize -> crop -> normalize -> CHW float tensor pipeline.

    Args:
        size: (height, width) of the output image.
    """
    height, width = size
    return A.Compose(
        [
            A.Resize(height=height, width=width),
            # No-op after the exact resize above. It keeps the output size fixed if the
            # Resize is later switched to an aspect-preserving one.
            A.CenterCrop(height=height, width=width),
            # Produces float32 input in the range the pretrained backbone expects.
            A.Normalize(mean=IMAGE_MEAN, std=IMAGE_STD, max_pixel_value=255.0),
            ToTensorV2(),
        ]
    )


train_transforms = build_image_transforms()
# Evaluation uses the same fixed-size pipeline so variable-size images can be batched.
test_transforms = build_image_transforms()
# Backward-compatible alias for the original (misspelled) name.
test_tyransforms = test_transforms

In [None]:
import pandas as pd

# Flickr8k captions file. The dataset class reads column 0 as the image file name
# and column 1 as the caption.
CAPTIONS_PATH = 'data/flickr8000/captions.txt'
df = pd.read_csv(CAPTIONS_PATH)
assert df.shape[1] >= 2, f"expected at least 2 columns in {CAPTIONS_PATH}, got {df.shape[1]}"
df.head()

In [None]:
# Small, reproducible subset for fast iteration; set SAMPLE_FRAC = 1.0 for the full data.
SAMPLE_FRAC = 0.01
SAMPLE_SEED = 42
df_sampled = df.sample(frac=SAMPLE_FRAC, random_state=SAMPLE_SEED)
# NOTE(review): sampling is per caption row. If an image has several caption rows, its
# captions can land in both the train and validation splits. Consider grouping by image.
df_sampled.head()

In [30]:
import os
import numpy as np
class CLiPDatatset(Dataset):
    """Image/caption pairs read from a captions DataFrame.

    Each row gives the image file name in column 0 and its caption in column 1.
    Items are dicts holding the tokenizer's tensors (e.g. ``input_ids``,
    ``attention_mask``, each ``(1, max_length)``) plus the transformed ``image``.

    Args:
        path: directory containing the image files.
        df: captions DataFrame; defaults to the notebook-level ``df_sampled``.
        transform: albumentations-style callable invoked as ``transform(image=arr)``;
            defaults to ``train_transforms``.
        max_length: padded/truncated caption length; defaults to ``ModelArgs.block_size``.
    """

    def __init__(self, path, df=None, transform=None, max_length=None):
        self.tokenizer = tokenizer
        self.path = path
        # Fall back to the notebook globals so existing `CLiPDatatset(path)` calls still work.
        self.df = df_sampled if df is None else df
        self.transform = train_transforms if transform is None else transform
        self.max_length = ModelArgs.block_size if max_length is None else max_length

    def __len__(self):
        return len(self.df)

    def _load_image(self, img_path):
        # Decode one image file into an RGB uint8 (H, W, 3) array. convert('RGB') maps
        # grayscale/palette/RGBA files to 3 channels, and `with` closes the file handle.
        with Image.open(img_path) as im:
            return np.array(im.convert('RGB'))

    def __getitem__(self, idx):
        # Use iloc[row, col]. Indexing the row Series positionally (`row[1]`) is a
        # deprecated pandas fallback.
        img_name = self.df.iloc[idx, 0]
        text = self.df.iloc[idx, 1]

        image = self._load_image(os.path.join(self.path, img_name))
        input_transformed = self.transform(image=image)['image']

        text_tokenized = self.tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)

        # The tokenizer already returns fresh tensors. Wrapping them in torch.tensor()
        # only copied them and emitted a UserWarning.
        encoded_items = {key: values for key, values in text_tokenized.items()}
        encoded_items['image'] = input_transformed
        return encoded_items

In [31]:
# Directory holding the Flickr8k images (named `image_dir` to avoid shadowing builtin dir()).
image_dir = 'data/flickr8000/images'
dataset = CLiPDatatset(image_dir)

# 90/10 train/validation split. A seeded generator makes the split reproducible across runs.
TRAIN_FRACTION = 0.9
SPLIT_SEED = 42
train_size = int(TRAIN_FRACTION * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# Creating dataloaders
trainloader = DataLoader(train_dataset, batch_size=ModelArgs.batch_size, shuffle=True)
valloader = DataLoader(val_dataset, batch_size=ModelArgs.batch_size, shuffle=False)


In [None]:
import itertools

# Per-module learning rates: small for the pretrained encoders, larger for the new heads.
# NOTE(review): siglip.text.parameters() also includes TextModel's own LayerNorm and
# projector, so those train at text_encoder_lr. Confirm that is intended.
params = [
    {"params": siglip.vision.parameters(), "lr": ModelArgs.image_encoder_lr},
    {"params": siglip.text.parameters(), "lr": ModelArgs.text_encoder_lr},
    {
        "params": itertools.chain(
            siglip.multimodalVisionLayerProjector.parameters(),
            siglip.multimodelTextLayerPorjector.parameters(),
        ),
        "lr": ModelArgs.head_lr,
        "weight_decay": ModelArgs.weight_decay_optim,
    },
    # Logit scale and bias. The bias must be listed or it never receives updates. Both
    # are excluded from weight decay, which would otherwise pull them toward 0.
    {"params": [siglip.temperature, siglip.bias], "lr": ModelArgs.head_lr, "weight_decay": 0.0},
]

optimizer = torch.optim.Adam(
    params,
    lr=ModelArgs.lr,
    betas=(ModelArgs.beta_1, ModelArgs.beta_2),
    eps=ModelArgs.epsilon,
)

# NOTE(review): SigLIP is defined by a pairwise sigmoid loss; CrossEntropyLoss is the
# softmax (CLIP-style) objective, under which the learned bias has no effect. Kept as-is
# because the call convention engine.train uses for loss_fn is not visible here.
loss_fn = nn.CrossEntropyLoss()

In [19]:
from going_modular import engine

In [None]:
# NOTE(review): epochs is fixed at 30 here while ModelArgs.epochs is 100. Confirm which is intended.
train_kwargs = dict(
    model=siglip,
    writer=None,
    train_dataloader=trainloader,
    test_dataloader=valloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=30,
    device=ModelArgs.device,
)
results = engine.train(**train_kwargs)