In [None]:
# Mount Google Drive under /content/drive (requires the google.colab package).
from google.colab import drive
drive.mount('/content/drive')
#parent_dir = "/content/drive/MyDrive/"


In [None]:
#!pip install datasets
!nvidia-smi

# Import Libraries and Modules

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
from torchvision.datasets import Caltech101
from torch.utils.data import random_split
import time
import os

# cuDNN backend switches, set once right after the imports.
torch.backends.cudnn.enabled = True  # Enabled by default
# NOTE(review): presumably enables cuDNN's per-shape algorithm auto-tuning — confirm against the PyTorch docs.
torch.backends.cudnn.benchmark = True


# Positional Embedding

In [None]:
class PositionalEmbedding(nn.Module):
  """Adds a fixed sinusoidal position table to a sequence of embeddings.

  The table is stored as the buffer ``pe`` with shape
  ``(1, max_seq_length, width)``; it has no trainable parameters.

  Args:
    width: number of features per position.
    max_seq_length: number of positions in the table.
  """
  def __init__(self, width, max_seq_length):
    super().__init__()

    # Column i uses the angle pos / 10000 ** (2 * (i // 2) / width):
    # even columns store its sine, odd columns the cosine of the same angle.
    positions = np.arange(max_seq_length, dtype=np.float64)[:, None]
    even_index = (np.arange(width) // 2) * 2
    angles = positions / np.power(10000.0, even_index / width)

    table = np.empty((max_seq_length, width), dtype=np.float64)
    table[:, 0::2] = np.sin(angles[:, 0::2])
    table[:, 1::2] = np.cos(angles[:, 1::2])

    # Built in float64 and rounded once to float32, like the element-wise version.
    self.register_buffer('pe', torch.from_numpy(table).float().unsqueeze(0))

  def forward(self, x):
    """Return ``x`` plus the position table.

    The table is cut to the length of ``x`` along its second-to-last axis,
    so sequences shorter than ``max_seq_length`` are accepted.
    """
    return x + self.pe[:, :x.size(-2)]


# Smoke test: add a 32-position, 8-feature table to a random batch of the same shape.
pe = PositionalEmbedding(width = 8, max_seq_length=32)
print(f"p.pe.shape:{pe.pe.shape}")
x = torch.randn(1,32,8)
print(f"x.shape:{x.shape}")
res = pe(x)
print(f"res:{res.shape}")

# The table is stored with register_buffer, not as an nn.Parameter, so
# parameters() yields nothing and the count printed below is 0.
trainable_params = sum(p.numel() for p in pe.parameters() if p.requires_grad)
print(f"Starting training with: {trainable_params} parameters")

print()

# Multi-Head Attention

In [None]:
class AttentionHead(nn.Module):
  """One scaled dot-product attention head with a low-rank branch on the query.

  The query is ``query(x) + qB(qA(x))``; key and value have no low-rank
  branch in this version.

  Args:
    width: number of input features per token.
    head_size: number of output features of this head.
    lora_rank: inner size of the low-rank query branch.
  """
  def __init__(self, width, head_size, lora_rank):
    super().__init__()
    self.head_size = head_size

    # Base projections.
    self.query = nn.Linear(width, head_size)
    self.key = nn.Linear(width, head_size)
    self.value = nn.Linear(width, head_size)

    # Low-rank query branch: width -> lora_rank -> head_size.
    # Both layers keep nn.Linear's default initialisation.
    self.qA = nn.Linear(width, lora_rank)
    self.qB = nn.Linear(lora_rank, head_size)

    # Start with the base projections trainable and the low-rank branch frozen.
    self.set_lora_mode(False)

  def set_lora_mode(self, mode):
    """Choose which weights receive gradients.

    With ``mode`` true only ``qA``/``qB`` are trainable; with ``mode`` false
    only ``query``/``key``/``value`` are.
    """
    train_base = not mode
    for base_layer in (self.query, self.key, self.value):
      base_layer.requires_grad_(train_base)
    for low_rank_layer in (self.qA, self.qB):
      low_rank_layer.requires_grad_(mode)

  def forward(self, x, mask=None):
    """Attend over the second-to-last axis of ``x``.

    Positions where ``mask == 0`` are set to ``-inf`` before the softmax.
    The result has ``head_size`` features in its last axis.
    """
    queries = self.query(x) + self.qB(self.qA(x))
    keys = self.key(x)
    values = self.value(x)

    # Pairwise query/key scores, scaled by sqrt(head_size).
    scores = (queries @ keys.transpose(-2, -1)) / (self.head_size ** 0.5)

    if mask is not None:
      scores = scores.masked_fill(mask == 0, float("-inf"))

    weights = torch.softmax(scores, dim=-1)
    return weights @ values
# Smoke test for AttentionHead: print the projection weights and check the
# output shapes for two sequence lengths.
ah = AttentionHead(width = 8, head_size = 4, lora_rank = 1)
print(f"query.weight:{ah.query.weight} {ah.query.in_features}:{ah.query.out_features}")
print(f"qA.weight:{ah.qA.weight} {ah.qA.in_features}:{ah.qA.out_features}")
print(f"qB.weight:{ah.qB.weight} {ah.qB.in_features}:{ah.qB.out_features}")
print()
print(f"key.weight:{ah.key.weight} {ah.key.in_features}:{ah.key.out_features}")
#print(f"kA.weight:{ah.kA.weight} {ah.kA.in_features}:{ah.kA.out_features}")
#print(f"kB.weight:{ah.kB.weight} {ah.kB.in_features}:{ah.kB.out_features}")
print()
print(f"value.weight:{ah.value.weight} {ah.value.in_features}:{ah.value.out_features}")
#print(f"vA.weight:{ah.vA.weight} {ah.vA.in_features}:{ah.vA.out_features}")
#print(f"vB.weight:{ah.vB.weight} {ah.vB.in_features}:{ah.vB.out_features}")

#32 tokens 8 features per token
short_input = torch.randn(1,32,8)
print(f"short_input.shape:{short_input.shape}")
short_output = ah(short_input)
print(f"short_output.shape:{short_output.shape}")
print()

#128 tokens 8 features per token
long_input = torch.randn(1,128,8)
print(f"long_input.shape:{long_input.shape}")
long_output = ah(long_input)
print(f"long_output.shape:{long_output.shape}")
# The snippet below is a bare string literal, so it is never executed.
# NOTE(review): it calls the AttentionHead class rather than the `ah` instance.
"""
illegal_input = torch.randn(1,128,16)
print(f"illegal_input.shape:{illegal_input.shape}")
illegal_output = AttentionHead(illegal_input)
print(f"illegal_output.shape:{illegal_output.shape}")
"""
# Trainable-parameter counts: right after construction, in base mode, and in LoRA mode.
trainable_params = sum(p.numel() for p in ah.parameters() if p.requires_grad)
print(f"Initializing with: {trainable_params} parameters")

ah.set_lora_mode(False)
trainable_params = sum(p.numel() for p in ah.parameters() if p.requires_grad)
print(f"Training with: {trainable_params} parameters")


ah.set_lora_mode(True)
trainable_params = sum(p.numel() for p in ah.parameters() if p.requires_grad)
print(f"Finetuning with: {trainable_params} parameters")



print()

In [None]:
class MultiHeadAttention(nn.Module):
  """Runs several AttentionHead modules in parallel and mixes their outputs.

  Each head produces ``width // n_heads`` features. The heads are
  concatenated along the last axis and projected back to ``width`` by ``W_o``.

  Args:
    width: number of input and output features per token.
    n_heads: number of attention heads.
    lora_rank: inner size of each head's low-rank query branch.

  Raises:
    ValueError: if ``n_heads`` exceeds ``width`` (heads would have no features).
  """
  def __init__(self, width, n_heads, lora_rank):
    super().__init__()
    self.head_size = width // n_heads
    if self.head_size == 0:
      raise ValueError(f"n_heads ({n_heads}) must not exceed width ({width})")

    # The concatenated heads carry n_heads * head_size features. That equals
    # `width` only when width is divisible by n_heads, so size the projection
    # from the real concatenated width to keep forward() shape-consistent.
    self.W_o = nn.Linear(self.head_size * n_heads, width)

    self.heads = nn.ModuleList([AttentionHead(width, self.head_size, lora_rank) for _ in range(n_heads)])
    self.set_lora_mode(False)

  def set_lora_mode(self, mode):
    """Forward ``mode`` to every head; ``W_o`` is trainable only when ``mode`` is false."""
    for head in self.heads:
      head.set_lora_mode(mode)

    self.W_o.requires_grad_(not mode)

  def forward(self, x, mask=None):
    """Apply every head to ``x`` (passing ``mask`` through) and project the concatenation."""
    # Combine attention heads
    out = torch.cat([head(x, mask=mask) for head in self.heads], dim=-1)

    return self.W_o(out)

# Smoke test for MultiHeadAttention: 4 heads over 32 features.
MHA = MultiHeadAttention(width = 32, n_heads = 4, lora_rank = 1)
print(f"W_o.weight:{MHA.W_o.weight} {MHA.W_o.in_features}:{MHA.W_o.out_features}")
#print(f"oA.weight:{MHA.oA.weight} {MHA.oA.in_features}:{MHA.oA.out_features}")
#print(f"oB.weight:{MHA.oB.weight} {MHA.oB.in_features}:{MHA.oB.out_features}")
print()

# 36 tokens, 32 features per token
short_input = torch.randn(1,36,32)
print(f"short_input.shape:{short_input.shape}")
short_output = MHA(short_input)
print(f"short_output.shape:{short_output.shape}")
print()

# 128 tokens, 32 features per token
long_input = torch.randn(1,128,32)
print(f"long_input.shape:{long_input.shape}")
long_output = MHA(long_input)
print(f"long_output.shape:{long_output.shape}")
print()

# The snippet below is a bare string literal, so it is never executed.
""""
illegal_input = torch.randn(1,128,16)
print(f"illegal_input.shape:{illegal_input.shape}")
illegal_output = MHA(illegal_input)
print(f"illegal_output.shape:{illegal_output.shape}")
"""
# Trainable-parameter counts: right after construction, in base mode, and in LoRA mode.
trainable_params = sum(p.numel() for p in MHA.parameters() if p.requires_grad)
print(f"Initializing with: {trainable_params} parameters")

MHA.set_lora_mode(False)
trainable_params = sum(p.numel() for p in MHA.parameters() if p.requires_grad)
print(f"Training with: {trainable_params} parameters")


MHA.set_lora_mode(True)
trainable_params = sum(p.numel() for p in MHA.parameters() if p.requires_grad)
print(f"Finetuning with: {trainable_params} parameters")

print()

# Transformer Encoder

In [None]:
class TransformerEncoder(nn.Module):
    """One transformer block: attention and MLP sub-layers, each with a residual.

    LayerNorm is applied to the input of each sub-layer, and the sub-layer
    output is added back onto the un-normalised input.

    Args:
        width: number of features per token.
        n_heads: number of attention heads.
        lora_rank: inner size of the low-rank query branch in each head.
        r_mlp: the hidden MLP layer has ``width * r_mlp`` features.
    """
    def __init__(self, width, n_heads, lora_rank, r_mlp=4):
        super().__init__()
        self.width = width
        self.n_heads = n_heads

        hidden_width = self.width * r_mlp

        # Attention sub-layer and its normalisation.
        self.ln1 = nn.LayerNorm(width)
        self.mha = MultiHeadAttention(width, n_heads, lora_rank)

        # Multilayer perceptron sub-layer and its normalisation.
        self.ln2 = nn.LayerNorm(width)
        self.mlp = nn.Sequential(
            nn.Linear(self.width, hidden_width),
            nn.GELU(),
            nn.Linear(hidden_width, self.width),
        )

        self.set_lora_mode(False)

    def set_lora_mode(self, mode):
        """Forward ``mode`` to the attention module; the MLP trains only when ``mode`` is false.

        The two LayerNorm modules are not touched here, so their parameters
        keep whatever ``requires_grad`` state they already have.
        """
        self.mlp.requires_grad_(not mode)
        self.mha.set_lora_mode(mode)

    def forward(self, x, mask=None):
        """Return ``x`` after the attention and MLP residual updates."""
        attended = self.mha(self.ln1(x), mask=mask)
        x = x + attended

        fed_forward = self.mlp(self.ln2(x))
        return x + fed_forward

# Smoke test for TransformerEncoder: 4 heads over 32 features.
TE = TransformerEncoder(width = 32, n_heads = 4, lora_rank=1)
# 36 tokens, 32 features per token
short_input = torch.randn(1,36,32)
print(f"short_input.shape:{short_input.shape}")
short_output = TE(short_input)
print(f"short_output.shape:{short_output.shape}")
print()

# 128 tokens, 32 features per token
long_input = torch.randn(1,128,32)
print(f"long_input.shape:{long_input.shape}")
long_output = TE(long_input)
print(f"long_output.shape:{long_output.shape}")
print()

# The snippet below is a bare string literal, so it is never executed.
# NOTE(review): it calls MHA rather than TE.
"""
illegal_input = torch.randn(1,128,16)
print(f"illegal_input.shape:{illegal_input.shape}")
illegal_output = MHA(illegal_input)
print(f"illegal_output.shape:{illegal_output.shape}")
"""

# Trainable-parameter counts: right after construction, in base mode, and in LoRA mode.
trainable_params = sum(p.numel() for p in TE.parameters() if p.requires_grad)
print(f"Initializing with: {trainable_params} parameters")

TE.set_lora_mode(False)
trainable_params = sum(p.numel() for p in TE.parameters() if p.requires_grad)
print(f"Training with: {trainable_params} parameters")


TE.set_lora_mode(True)
trainable_params = sum(p.numel() for p in TE.parameters() if p.requires_grad)
print(f"Finetuning with: {trainable_params} parameters")

print()

# Tokenizer

In [None]:
def tokenizer(text, max_seq_length, encode=True, mask=None):
    """Byte-level tokenizer with start (2), end (3) and padding (0) tokens.

    Encoding (``encode=True``):
        ``text`` is a string. It is wrapped in the start/end markers, encoded
        as UTF-8 bytes and right-padded with zeros to ``max_seq_length``.
        Returns ``(tokens, mask)``, both int32 tensors of length
        ``max_seq_length``; ``mask`` is 1 on real tokens and 0 on padding.

    Decoding (``encode=False``):
        ``text`` is a token sequence and ``mask`` its mask. The start/end
        markers and the padding are dropped and the remaining bytes are
        decoded as UTF-8. Returns ``(string, None)``.

    Raises:
        ValueError: if the encoded text does not fit into ``max_seq_length``,
            or if ``mask`` is missing when decoding.
    """
    if encode:
        # Measure the length in bytes, not characters: a non-ASCII character
        # occupies several tokens.
        raw = (chr(2) + text + chr(3)).encode("utf-8")
        n_real = len(raw)
        if n_real > max_seq_length:
            raise ValueError(
                f"text needs {n_real} tokens including the start/end markers, "
                f"but max_seq_length is {max_seq_length}"
            )
        n_pad = max_seq_length - n_real

        out = torch.IntTensor(list(raw) + [0] * n_pad)
        # Taken from the encoded length, so a zero byte inside the text is
        # still counted as a real token.
        mask = torch.IntTensor([1] * n_real + [0] * n_pad)
    else:
        if mask is None:
            raise ValueError("decoding requires the mask returned by the encoding call")
        n_real = len(mask.nonzero())
        # Drop the start marker (index 0) and the end marker (index n_real - 1).
        payload = bytes(int(token) for token in text[1:n_real - 1])
        out = payload.decode("utf-8", errors="replace")
        mask = None

    return out, mask
# Disabled tokenizer walkthrough, kept as a bare string literal so it is never
# executed. The last input (40 characters plus the two markers) does not fit
# into max_seq_length=32, so that call is expected to fail.
"""
short_input = "hello world"
print(f"len short input: {len(short_input)}")
tokens, mask = tokenizer("hello world", encode = True, mask=None, max_seq_length=32)
print(tokens)
print(mask)

long_input = ("0123456789"*3)
print(f"len long input: {len(long_input)}") #start and end tokens added
tokens, mask = tokenizer(long_input, encode = True, mask=None, max_seq_length=32)
print(tokens)
print(mask)


illegal_input = ("0123456789"*4)
print(f"len illegal input: {len(illegal_input)}") #start and end tokens added
tokens, mask = tokenizer(illegal_input, encode = True, mask=None, max_seq_length=32)
print(tokens)
print(mask)
"""
print()

# Text Encoder

In [None]:
class TextEncoder(nn.Module):
    """Embeds token sequences and returns one unit-length vector per sequence.

    Args:
        vocab_size: number of rows in the token embedding table.
        width: number of features per token inside the encoder.
        max_seq_length: number of positions in the positional table.
        n_heads: attention heads per transformer block.
        n_layers: number of transformer blocks.
        emb_dim: size of the returned embedding.
        lora_rank: inner size of the low-rank query branches.
    """
    def __init__(self, vocab_size, width, max_seq_length, n_heads, n_layers, emb_dim, lora_rank):
        super().__init__()

        # Maximum length of input sequence
        self.max_seq_length = max_seq_length

        # Token embedding table followed by the fixed positional table.
        self.encoder_embedding = nn.Embedding(vocab_size, width)
        self.positional_embedding = PositionalEmbedding(width, max_seq_length)

        blocks = [TransformerEncoder(width, n_heads, lora_rank) for _ in range(n_layers)]
        self.encoder = nn.ModuleList(blocks)

        # Learned projection from the text features to the shared embedding.
        self.projection = nn.Parameter(torch.randn(width, emb_dim))

        self.set_lora_mode(False)

    def set_lora_mode(self, mode):
        """Forward ``mode`` to every block; embedding and projection train only when ``mode`` is false."""
        train_base = not mode
        self.encoder_embedding.requires_grad_(train_base)
        self.positional_embedding.requires_grad_(train_base)
        self.projection.requires_grad_(train_base)
        for block in self.encoder:
            block.set_lora_mode(mode)

    def forward(self, text, mask):
        """Encode ``text`` and return the normalised embedding of its last unmasked token.

        ``mask`` is indexed as ``mask[:, 0]`` and summed over ``dim=1``, so it
        needs a batch axis followed by at least two more axes.
        """
        x = self.positional_embedding(self.encoder_embedding(text))

        for block in self.encoder:
            x = block(x, mask=mask)

        # The first mask row of each sample counts its unmasked tokens; the
        # last of them (count - 1) is where the features are read from.
        batch_index = torch.arange(text.shape[0])
        last_token = torch.sum(mask[:, 0], dim=1) - 1
        x = x[batch_index, last_token]

        # joint multimodal embedding
        if self.projection is not None:
            x = x @ self.projection

        return x / torch.norm(x, dim=-1, keepdim=True)


# Parameter-count check for TextEncoder (this rebinds the name TE).
TE = TextEncoder(vocab_size=256, width=32, max_seq_length=128, n_heads=4, n_layers=4, emb_dim=32, lora_rank=1)

# Trainable-parameter counts: right after construction, in base mode, and in LoRA mode.
trainable_params = sum(p.numel() for p in TE.parameters() if p.requires_grad)
print(f"Initializing with: {trainable_params} parameters")

TE.set_lora_mode(False)
trainable_params = sum(p.numel() for p in TE.parameters() if p.requires_grad)
print(f"Training with: {trainable_params} parameters")


TE.set_lora_mode(True)
trainable_params = sum(p.numel() for p in TE.parameters() if p.requires_grad)
print(f"Finetuning with: {trainable_params} parameters")

print()

# Image Encoder

In [None]:
class ImageEncoder(nn.Module):
    """Patch-based image encoder that returns one unit-length vector per image.

    Args:
        width: number of features per patch token.
        img_size: ``(height, width)`` of the input images.
        patch_size: ``(height, width)`` of one patch; must divide ``img_size``.
        n_channels: number of image channels.
        n_layers: number of transformer blocks.
        n_heads: attention heads per block; must divide ``width``.
        emb_dim: size of the returned embedding.
        lora_rank: inner size of the low-rank query branches.
    """
    def __init__(self, width, img_size, patch_size, n_channels, n_layers, n_heads, emb_dim, lora_rank):
        super().__init__()

        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, "img_size dimensions must be divisible by patch_size dimensions"
        assert width % n_heads == 0, "width must be divisible by n_heads"

        image_area = img_size[0] * img_size[1]
        patch_area = patch_size[0] * patch_size[1]
        self.n_patches = image_area // patch_area

        # One extra position for the class token.
        self.max_seq_length = self.n_patches + 1

        # A strided convolution cuts the image into patches and embeds each one.
        self.linear_project = nn.Conv2d(n_channels, width, kernel_size=patch_size, stride=patch_size)

        # Learned class token, placed in front of the patch tokens.
        self.cls_token = nn.Parameter(torch.randn(1, 1, width))

        self.positional_embedding = PositionalEmbedding(width,self.max_seq_length)

        blocks = [TransformerEncoder(width,n_heads, lora_rank) for _ in range(n_layers)]
        self.encoder = nn.ModuleList(blocks)

        # Learned projection from the image features to the shared embedding.
        self.projection = nn.Parameter(torch.randn(width, emb_dim))

        self.set_lora_mode(False)

    def set_lora_mode(self, mode):
        """Forward ``mode`` to every block; patch embedding, class token and projection train only when ``mode`` is false."""
        train_base = not mode
        self.linear_project.requires_grad_(train_base)
        self.cls_token.requires_grad_(train_base)
        self.projection.requires_grad_(train_base)
        for block in self.encoder:
            block.set_lora_mode(mode)

    def forward(self,x):
        """Encode a batch of images and return the normalised class-token embedding."""
        # Patch embedding: flatten the spatial grid into a token axis.
        patches = self.linear_project(x).flatten(2).transpose(1, 2)

        # Put one class token in front of every sample, then add positions.
        cls = self.cls_token.expand(patches.size()[0], -1, -1)
        tokens = self.positional_embedding(torch.cat((cls, patches), dim=1))

        for block in self.encoder:
            tokens = block(tokens)

        # Only the class token (position 0) is kept.
        pooled = tokens[:, 0, :]

        # joint multimodal embedding
        if self.projection is not None:
            pooled = pooled @ self.projection

        return pooled / torch.norm(pooled, dim=-1, keepdim=True)

# Parameter-count check for ImageEncoder: 32x32 single-channel images cut into
# 8x8 patches (16 patches plus the class token).
IE = ImageEncoder(width=32, img_size=(32,32), patch_size=(8,8), n_channels=1, n_layers=4, n_heads=4, emb_dim=32, lora_rank=1)


# Trainable-parameter counts: right after construction, in base mode, and in LoRA mode.
trainable_params = sum(p.numel() for p in IE.parameters() if p.requires_grad)
print(f"Initializing with: {trainable_params} parameters")

IE.set_lora_mode(False)
trainable_params = sum(p.numel() for p in IE.parameters() if p.requires_grad)
print(f"Training with: {trainable_params} parameters")


IE.set_lora_mode(True)
trainable_params = sum(p.numel() for p in IE.parameters() if p.requires_grad)
print(f"Finetuning with: {trainable_params} parameters")

print()


# CLIP Model

In [None]:
class CLIP(nn.Module):
    """Image encoder plus text encoder trained with a symmetric contrastive loss.

    Args:
        emb_dim: size of the shared embedding both encoders project into.
        vit_width, img_size, patch_size, n_channels, vit_layers, vit_heads:
            passed to ImageEncoder.
        vocab_size, text_width, max_seq_length, text_heads, text_layers:
            passed to TextEncoder.
        lora_rank: inner size of the low-rank query branches in both encoders.
    """
    def __init__(self, emb_dim, vit_width, img_size, patch_size, n_channels, vit_layers, vit_heads, vocab_size, text_width, max_seq_length, text_heads, text_layers, lora_rank):
        super().__init__()

        self.image_encoder = ImageEncoder(vit_width, img_size, patch_size, n_channels, vit_layers, vit_heads, emb_dim, lora_rank)

        self.text_encoder = TextEncoder(vocab_size, text_width, max_seq_length, text_heads, text_layers, emb_dim, lora_rank)

        # Learned log-scale applied to the similarity logits (initialised to log(1 / 0.07)).
        self.temperature = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        # Kept for callers that read it; forward() no longer depends on it.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"useing: {self.device}")
        self.set_lora_mode(False)

    def set_lora_mode(self, mode):
        """Forward ``mode`` to both encoders; ``temperature`` is not touched."""
        self.image_encoder.set_lora_mode(mode)
        self.text_encoder.set_lora_mode(mode)

    def forward(self,image,text, mask=None):
        """Return the contrastive loss for a batch of matching image/text pairs.

        Row ``i`` of ``image`` is treated as the match of row ``i`` of ``text``.
        """
        I_e = self.image_encoder(image)
        T_e = self.text_encoder(text, mask=mask)

        # scaled pairwise cosine similarities [n, n]
        logits = (I_e @ T_e.transpose(-2,-1)) * torch.exp(self.temperature)

        # The matching pair sits on the diagonal. Create the targets on the
        # logits' own device: self.device is only a guess made at construction
        # time and is wrong when the model lives on a different device.
        labels = torch.arange(logits.shape[0], device=logits.device)

        # symmetric loss function
        loss_i = nn.functional.cross_entropy(logits.transpose(-2,-1), labels)
        loss_t = nn.functional.cross_entropy(logits, labels)

        return (loss_i + loss_t) / 2

# Parameter-count check for the full CLIP model.
# text_heads must divide text_width (128), otherwise the concatenated head
# outputs would not add up to text_width features.
C = CLIP(emb_dim=64, vit_width=128, img_size=(224,224), patch_size=(14,14), n_channels=1, vit_layers=4, vit_heads=4, vocab_size=256, text_width=128, max_seq_length=64, text_heads=4, text_layers=2, lora_rank=1)


# Trainable-parameter counts: right after construction, in base mode, and in LoRA mode.
trainable_params = sum(p.numel() for p in C.parameters() if p.requires_grad)
print(f"Initializing with: {trainable_params} parameters")

C.set_lora_mode(False)
trainable_params = sum(p.numel() for p in C.parameters() if p.requires_grad)
print(f"Training with: {trainable_params} parameters")



C.set_lora_mode(True)
trainable_params = sum(p.numel() for p in C.parameters() if p.requires_grad)
print(f"Finetuning with: {trainable_params} parameters")

print()

# Dataset

In [None]:
# Image size every sample is resized to, and the fixed token length of every caption.
img_size = (128,128)
max_text_seq_length = 128

dataset = load_dataset("fashion_mnist")
# NOTE(review): presumably makes rows come back as torch tensors — confirm against the `datasets` docs.
dataset.set_format(type='torch')
# Hold out 10% of the training split for validation; the fixed seed keeps the split reproducible.
num_train_samples = len(dataset['train'])
split_10 = int(0.1 * num_train_samples)
split_90 = num_train_samples - split_10
mnist_train_dataset, mnist_val_dataset = random_split(dataset['train'], [split_90, split_10], generator=torch.Generator().manual_seed(42))
mnist_test_dataset = dataset["test"]


class FashionMNIST(Dataset):
    """Fashion-MNIST samples paired with a tokenized text caption.

    Args:
        split: ``"train"`` or ``"val"`` select the matching subset; any other
            value selects the test set (``self.split`` is then ``"test"``).

    Each item is a dict with the keys ``"image"``, ``"caption"`` and ``"mask"``.
    """
    def __init__(self, split="train"):
        if split=="train":
            self.dataset = mnist_train_dataset
            self.split = "train"
        elif split=="val":
            self.dataset = mnist_val_dataset
            self.split = "val"
        else:
            self.dataset = mnist_test_dataset
            self.split = "test"

        # Every split uses the same pipeline: resize, then convert to float32.
        self.transform = T.Compose([
            T.Resize(img_size),
            T.ConvertImageDtype(torch.float32)
        ])

        class_names = {
            0: "t-shirt/top",
            1: "trousers",
            2: "pullover",
            3: "dress",
            4: "coat",
            5: "sandal",
            6: "shirt",
            7: "sneaker",
            8: "bag",
            9: "ankle boot"
        }
        self.captions = {k: f"an image of {v}" for k, v in class_names.items()}


    def __len__(self):
        return len(self.dataset)

    def __getitem__(self,i):
        # Fetch the row once; every access to self.dataset[i] builds it again.
        sample = self.dataset[i]
        img = self.transform(sample["image"])

        cap, mask = tokenizer(self.captions[sample["label"].item()], max_seq_length = max_text_seq_length)

        # Repeat the mask vector as every row of a square (seq, seq) mask.
        mask = mask.repeat(len(mask),1)

        return {"image": img, "caption": cap, "mask": mask}
#ds = FashionMNIST(split="train")
#for i in ds:
#    print(i)

In [None]:
# Preprocessing applied by the Caltech101 dataset itself: resize, single-channel
# grayscale, tensor conversion.
transform = T.Compose([
              T.Resize(img_size),  # Resize to a fixed size (e.g., for models like ResNet)
              T.Grayscale(num_output_channels=1),  # Convert to grayscale with 1 output channel
              T.ToTensor(),
            ])

dataset = Caltech101(root="./", target_type='category', transform=transform, download=True)
# 70% train / 20% validation / remainder test.
train_size = int(0.7 * len(dataset))
val_size = int(0.2 * len(dataset))
test_size = len(dataset) - train_size - val_size
# Seed the split like the Fashion-MNIST one. An unseeded split changes with
# every session, so a reloaded checkpoint could be evaluated on images it was
# trained on.
cal_train_dataset, cal_val_dataset, cal_test_dataset = random_split(dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(42))

class Caltech101_Wrapper(Dataset):
    """Caltech101 samples paired with a tokenized text caption.

    Args:
        split: ``"train"`` or ``"val"`` select the matching subset; any other
            value selects the test subset.

    Each item is a dict with the keys ``"image"``, ``"caption"`` and ``"mask"``.
    """
    def __init__(self, split="train"):
        if split == "train":
          self.dataset = cal_train_dataset
        elif split == "val":
          self.dataset = cal_val_dataset
        else:
          self.dataset = cal_test_dataset

        # Captions for the different object categories, keyed by label index.
        # NOTE(review): verify that these indices follow the same category
        # order as the labels produced by the loaded dataset.
        self.captions = {
          0: "accordion",
          1: "airplane",
          2: "anchor",
          3: "ant",
          4: "from Google backgrounds",
          5: "barrel",
          6: "bass fish",
          7: "beaver",
          8: "binoculars",
          9: "bonsai tree",
          10: "brain",
          11: "brontosaurus",
          12: "Buddha statue",
          13: "butterfly",
          14: "camera",
          15: "cannon",
          16: "car viewed from the side",
          17: "ceiling fan",
          18: "cellphone",
          19: "chair",
          20: "chandelier",
          21: "cougar's body",
          22: "cougar's face",
          23: "crab",
          24: "crayfish",
          25: "crocodile",
          26: "crocodile's head",
          27: "cup",
          28: "dalmatian dog",
          29: "dollar bill",
          30: "dolphin",
          31: "dragonfly",
          32: "electric guitar",
          33: "elephant",
          34: "emu",
          35: "euphonium instrument",
          36: "ewer",
          37: "face",
          38: "face in an easy-to-recognize pose",
          39: "ferry",
          40: "flamingo",
          41: "flamingo's head",
          42: "character Garfield",
          43: "gerenuk",
          44: "gramophone",
          45: "grand piano",
          46: "hawksbill turtle",
          47: "headphones",
          48: "hedgehog",
          49: "helicopter",
          50: "ibis bird",
          51: "inline skate",
          52: "Joshua tree",
          53: "kangaroo",
          54: "ketch sailboat",
          55: "lamp",
          56: "laptop",
          57: "leopard",
          58: "llama",
          59: "lobster",
          60: "lotus flower",
          61: "mandolin",
          62: "mayfly",
          63: "menorah",
          64: "metronome",
          65: "minaret",
          66: "motorbike",
          67: "nautilus shell",
          68: "octopus",
          69: "okapi",
          70: "pagoda",
          71: "panda bear",
          72: "pigeon",
          73: "pizza",
          74: "platypus",
          75: "pyramid",
          76: "revolver gun",
          77: "rhinoceros",
          78: "rooster",
          79: "saxophone",
          80: "schooner sailboat",
          81: "pair of scissors",
          82: "scorpion",
          83: "sea horse",
          84: "soccer ball",
          85: "character Snoopy",
          86: "starfish",
          87: "stapler",
          88: "stegosaurus",
          89: "stop sign",
          90: "strawberry",
          91: "sunflower",
          92: "tick insect",
          93: "trilobite fossil",
          94: "umbrella",
          95: "watch",
          96: "water lily",
          97: "wheelchair",
          98: "wild cat",
          99: "Windsor chair",
          100: "wrench tool",
          101: "yin-yang symbol"
          }
        self.captions = {k: f"an image of {v}" for k, v in self.captions.items()}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        # Index the dataset once: every self.dataset[i] access builds the
        # sample again, so indexing twice doubles the per-item work.
        sample = self.dataset[i]
        img = sample[0]
        label = sample[1]

        # Look up the caption for this label, with a generic fallback.
        caption = self.captions.get(label, "An image of an object")

        # Tokenize the caption and repeat its mask vector as every row of a
        # square (seq, seq) mask.
        cap, mask = tokenizer(caption, max_seq_length = max_text_seq_length)
        mask = mask.repeat(len(mask), 1)

        return {"image": img, "caption": cap, "mask": mask}

#Caltech101_Wrapper(split="train")

# Training Functions

In [None]:
def epoch_func(model, data_loader, optimizer = None):
    """Run one pass over ``data_loader`` and return the mean loss per sample.

    Args:
        model: called as ``model(image, caption, mask)``; must return a scalar
            loss that is the mean over the batch.
        data_loader: sized iterable of dicts with the keys ``"image"``,
            ``"caption"`` and ``"mask"``.
        optimizer: when given, one optimisation step is taken per batch; when
            ``None`` the pass only evaluates.

    Returns:
        A 0-d float32 numpy array holding the sample-weighted mean loss.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    start_time = time.time()

    # Move the batches to wherever the model actually lives instead of relying
    # on a notebook-level `device` variable.
    try:
        run_device = next(model.parameters()).device
    except StopIteration:
        run_device = torch.device("cpu")

    loss_sum = 0.0
    n_samples = 0
    n_batches = len(data_loader)
    for i, data in enumerate(data_loader):
        img, cap, mask = data["image"].to(run_device), data["caption"].to(run_device), data["mask"].to(run_device)
        loss = model(img,cap,mask)
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Accumulate a plain float so no autograd history is kept across
        # batches. Weight by the batch size because `loss` is a batch mean and
        # the last batch may be smaller.
        batch_size = len(data["image"])
        loss_sum += loss.item() * batch_size
        n_samples += batch_size
        print(f"\rBatch:{(i+1)}/{n_batches}, Avg Loss: {loss_sum/max(n_samples, 1):.5f}, {time.time()-start_time:.2f}s", end='', flush=True)
        del loss, img, cap, mask

    print()
    if n_samples == 0:
        raise ValueError("data_loader produced no samples")
    return np.asarray(loss_sum / n_samples, dtype=np.float32)



In [None]:
def train_model(model, train_loader, epochs, lr, save_dir, val_loader):
    """Train ``model`` with Adam, checkpoint the best epoch and plot the loss curves.

    Args:
        model: model whose forward pass returns the loss (see ``epoch_func``).
        train_loader: loader for the training data.
        epochs: number of training epochs.
        lr: Adam learning rate.
        save_dir: directory for ``clip.pt`` and ``Training.png``; created if missing.
        val_loader: loader for the validation data, or ``None`` to skip
            validation. Without it, the training loss decides which epoch is
            checkpointed.

    The state dict is saved whenever the monitored loss improves. The loss
    before any training is recorded as the first point of each curve.
    """
    os.makedirs(save_dir,  exist_ok=True)
    # Join the paths so the files land inside save_dir with or without a trailing slash.
    checkpoint_path = os.path.join(save_dir, "clip.pt")
    plot_path = os.path.join(save_dir, "Training.png")

    optimizer = optim.Adam(model.parameters(), lr=lr)
    has_val = val_loader is not None

    best_loss = np.inf
    with torch.no_grad():
        training_losses = [epoch_func(model, train_loader, optimizer=None)]
        validation_losses = [epoch_func(model, val_loader, optimizer=None)] if has_val else []
    torch.cuda.empty_cache()

    # Sanity check: the optimizer must hold every parameter that is currently trainable.
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    num_optimized_params = sum(p.numel() for group in optimizer.param_groups for p in group['params'] if p.requires_grad)
    assert trainable_params == num_optimized_params
    print(f"Starting training {save_dir} with: {trainable_params} parameters")

    for epoch in range(epochs):
        print(f"Epoch: {epoch+1}/{epochs}")
        avg_training_loss = epoch_func(model, train_loader, optimizer=optimizer)
        training_losses.append(avg_training_loss)

        if has_val:
            with torch.no_grad():
                avg_validation_loss = epoch_func(model, val_loader, optimizer=None)
            validation_losses.append(avg_validation_loss)
            monitored_loss = avg_validation_loss
        else:
            monitored_loss = avg_training_loss

        if monitored_loss < best_loss:
            best_loss = monitored_loss
            torch.save(model.state_dict(), checkpoint_path)
            print("Model Saved.")
        print("\n")

    plt.plot(training_losses, label = "training")
    if has_val:
        plt.plot(validation_losses, label = "validation")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"Training Loss {save_dir}")
    plt.legend()
    plt.savefig(plot_path)
    plt.show()

In [None]:
def get_accuracy(model, data_set, data_loader):
    """Zero-shot classification accuracy of ``model`` on ``data_loader``.

    Every caption in ``data_set.captions`` is embedded once. Each image is
    assigned the caption with the highest similarity, and the prediction
    counts as correct when that caption's tokens equal the sample's caption
    tokens.

    Relies on the notebook-level names ``tokenizer``, ``max_seq_length`` and
    ``device``.

    Returns:
        ``correct / total`` as a float, or ``0.0`` if the loader is empty.
    """
    start_time = time.time()

    with torch.no_grad():
        # Tokenize each caption once; row k of `text` is the k-th caption in
        # dict order, which is also column k of the similarity matrix below.
        encoded = [tokenizer(caption, max_seq_length) for caption in data_set.captions.values()]
        text = torch.stack([tokens for tokens, _ in encoded]).to(device)
        mask = torch.stack([caption_mask for _, caption_mask in encoded])

        # Repeat each caption's mask vector as every row of a square mask.
        mask = mask.repeat(1,len(mask[0])).reshape(len(mask),len(mask[0]),len(mask[0])).to(device)

        text_features = model.text_encoder(text, mask=mask)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        correct, total = 0,0
        for i, data in enumerate(data_loader):
            images, labels = data["image"].to(device), data["caption"].to(device)

            image_features = model.image_encoder(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            similarity = (100.0 * (image_features @ text_features.T)).softmax(dim=-1)
            _, indices = torch.max(similarity,1)

            # Look predictions up by position in `text`. This matches the
            # similarity columns even if the caption keys are not 0..n-1, and
            # avoids re-tokenizing captions for every batch.
            pred = text[indices]
            correct += int((pred == labels).all(dim=1).sum())
            total += len(labels)

            print(f"\rBatch:{(i+1)}/{len(data_loader)} Acc:{100 * (correct / total):.2f}%, {time.time()-start_time:0.2f}s", end='', flush=True)
        print()

    return correct / total if total else 0.0
#model = CLIP(emb_dim, vit_width, img_size, patch_size, n_channels, vit_layers, vit_heads, vocab_size, text_width, max_seq_length, text_heads, text_layers, lora_rank).to(device)
#get_accuracy(model, cal_test_set, cal_test_loader)

In [None]:
# Shared hyperparameters for the experiments below.
emb_dim = 32

# Both encoders reuse the same width and head count.
encoder_heads=4
encoder_width=32

patch_size = (8,8)
n_channels = 1
vit_width = encoder_width
vit_heads = encoder_heads

# Text sequences use the same fixed length the dataset wrappers tokenize to.
max_seq_length = max_text_seq_length
vocab_size = 256
text_width = encoder_width
text_heads = encoder_heads

training_epochs=32
lora_epochs = 8

training_batch_size = 128
lora_batch_size = 128

# Use the GPU when one is available; its name is printed only in that case.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")


In [None]:
# Fashion-MNIST splits; their loaders use lora_batch_size.
mnist_train_set = FashionMNIST(split = "train")
mnist_val_set = FashionMNIST(split = "val")
mnist_test_set = FashionMNIST(split = "test")

mnist_train_loader = DataLoader(mnist_train_set, shuffle=True, batch_size=lora_batch_size, num_workers=4)
mnist_val_loader = DataLoader(mnist_val_set, shuffle=True, batch_size=lora_batch_size, num_workers=4)
mnist_test_loader = DataLoader(mnist_test_set, shuffle=False, batch_size=lora_batch_size, num_workers=4)

# Caltech101 splits; their loaders use training_batch_size.
cal_train_set = Caltech101_Wrapper(split = "train")
cal_val_set = Caltech101_Wrapper(split = "val")
cal_test_set = Caltech101_Wrapper(split = "test")

cal_train_loader = DataLoader(cal_train_set, shuffle=True, batch_size=training_batch_size, num_workers = 4)
cal_val_loader = DataLoader(cal_val_set, shuffle=True, batch_size=training_batch_size, num_workers = 4)
# NOTE(review): unlike mnist_test_loader, this test loader is shuffled.
cal_test_loader = DataLoader(cal_test_set, shuffle=True, batch_size=training_batch_size, num_workers = 4)

def plot_batch(loader):
  """Display the first batch produced by `loader` as a square grid of images.

  The grid has ceil(sqrt(loader.batch_size)) rows and columns. Only the first
  batch is drawn; if the loader yields nothing, no figure is created.

  Args:
    loader: an iterable with a `batch_size` attribute whose batches are
      mappings with an "image" entry. Each image is permuted from
      (dim0, dim1, dim2) to (dim1, dim2, dim0) before plotting
      (assumes channel-first tensors - TODO confirm against the datasets).
  """
  first_batch = next(iter(loader), None)
  if first_batch is None:
    return

  fig_side_length = int(np.ceil(np.sqrt(loader.batch_size)))
  # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid, so the
  # [row, col] indexing below also works when batch_size == 1.
  fig, axes = plt.subplots(fig_side_length, fig_side_length, squeeze=False)

  # Hide ticks and frames everywhere up front so that grid cells left over
  # after the last image stay blank.
  for ax in axes.flat:
    ax.axis("off")

  for i, img in enumerate(first_batch["image"]):
    row, col = divmod(i, fig_side_length)
    axes[row, col].imshow(img.permute(1, 2, 0).numpy(), cmap="gray")

  plt.show()

#plot_batch(mnist_train_loader)
#plot_batch(mnist_val_loader)
#plot_batch(mnist_test_loader)
#plot_batch(cal_train_loader)
#plot_batch(cal_val_loader)
#plot_batch(cal_test_loader)

In [None]:
def run_expirment(emb_dim, vit_width, img_size, patch_size, n_channels, vit_layers, vit_heads,
                vocab_size, text_width, max_seq_length, text_heads, text_layers, lora_rank, lr, lora_lr, save_dir):
    """Run one experiment: base training, then training again in LoRA mode.

    A fresh CLIP model is built from the architecture arguments and moved to
    the global `device`. It is first passed to train_model with the Caltech
    loaders (learning rate `lr`, `training_epochs` epochs) and scored on the
    Caltech test set; then LoRA mode is switched on and it is passed to
    train_model with the FashionMNIST loaders (learning rate `lora_lr`,
    `lora_epochs` epochs) and scored on the FashionMNIST test set.

    After each train_model call the weights are reloaded from `clip.pt` in the
    corresponding sub-directory (`CAL/` or `MNIST/`) of `save_dir`; train_model
    is presumably what writes that file - confirm against its definition.

    Parameter counts and both accuracies are written to
    `{save_dir}metrics.txt`. `save_dir` is concatenated directly with file and
    folder names, so it must end with a path separator.

    Returns:
        (cal_accuracy, ft_accuracy) as returned by get_accuracy.
    """
    cal_dir = f"{save_dir}CAL/"
    mnist_dir = f"{save_dir}MNIST/"

    model = CLIP(emb_dim, vit_width, img_size, patch_size, n_channels, vit_layers, vit_heads,
                 vocab_size, text_width, max_seq_length, text_heads, text_layers, lora_rank).to(device)

    def count_grad_params():
        # Number of scalar weights that currently have requires_grad set.
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    def reload_checkpoint(stage_dir):
        # Replace the in-memory weights with the checkpoint stored for this stage.
        state = torch.load(f"{stage_dir}clip.pt", map_location=device, weights_only = True)
        model.load_state_dict(state)

    # Report how many weights are trainable in each mode before any training.
    print(f"When training: {count_grad_params()} parameters")
    model.set_lora_mode(True)
    print(f"When finetuning: {count_grad_params()} parameters")
    model.set_lora_mode(False)

    # Stage 1: base training and evaluation.
    train_model(model, cal_train_loader, training_epochs, lr, cal_dir, val_loader=cal_val_loader)
    reload_checkpoint(cal_dir)
    with torch.no_grad():
        cal_accuracy = get_accuracy(model, cal_test_set, cal_test_loader)

    print("\n\n\n")

    # Stage 2: LoRA-mode training and evaluation.
    model.set_lora_mode(True)
    train_model(model, mnist_train_loader, lora_epochs, lora_lr, mnist_dir, val_loader=mnist_val_loader)
    reload_checkpoint(mnist_dir)
    with torch.no_grad():
        ft_accuracy = get_accuracy(model, mnist_test_set, mnist_test_loader)

    # Re-count in both modes for the metrics file; the model is left in LoRA mode.
    model.set_lora_mode(False)
    trainable_params = count_grad_params()
    model.set_lora_mode(True)
    ftable_params = count_grad_params()

    report = (
        f"trainable params: {trainable_params}\n"
        f"Caltech Accuracy: {cal_accuracy}\n"
        f"finetuneable params: {ftable_params}\n"
        f"MNIST accuracy: {ft_accuracy}\n"
    )
    with open(f"{save_dir}metrics.txt", "w") as file:
        file.write(report)

    return cal_accuracy, ft_accuracy



In [None]:
start_time = time.time()

# Best result so far for each task; -1 means the first run always replaces it.
best_cal_accuracy, best_cal_params = -1, [0, 0]
best_ft_accuracy, best_ft_params = -1, [0, 0]
# Every run as (accuracy, [lr, layers, lora_rank]), in execution order.
cal_performance_history, ft_performance_history = [], []

# Hyper-parameter grid explored below.
layers_to_try = [4]#[1, 4, 8, 16, 32]
ranks_to_try = [1, 4, 16]
learning_rates_to_try =[1e-3]#, 1e-4, 1e-5]

# Output tree: parent_dir / epochs / lr+batch sizes / layers / model settings.
parent_dir = "/content/drive/MyDrive/last_run/"
os.makedirs(parent_dir, exist_ok=True)

epochs_dir = f"{parent_dir}{training_epochs=}-{lora_epochs=}/"
os.makedirs(epochs_dir, exist_ok=True)

for lr in learning_rates_to_try:
    lora_lr = lr  # the same learning rate is used for both stages
    lr_dir = f"{epochs_dir}{lr=}-{training_batch_size=}-{lora_batch_size=}/"
    os.makedirs(lr_dir, exist_ok=True)

    for layers in layers_to_try:
        # Not created on its own; os.makedirs(model_dir) below creates it as a parent.
        layer_dir = f"{lr_dir}{layers=}/"

        for lora_rank in ranks_to_try:
            model_dir = f"{layer_dir}w:{encoder_width}--h:{encoder_heads}-ed:{emb_dim}-rank:{lora_rank}/"
            os.makedirs(model_dir, exist_ok=True)

            # `layers` sets the depth of both the image and the text encoder.
            cal_accuracy, ft_accuracy = run_expirment(
                emb_dim=emb_dim, vit_width=vit_width, img_size=img_size, patch_size=patch_size,
                n_channels=n_channels, vit_layers=layers, vit_heads=vit_heads,
                vocab_size=vocab_size, text_width=text_width, max_seq_length=max_seq_length,
                text_heads=text_heads, text_layers=layers, lora_rank=lora_rank,
                lr=lr, lora_lr=lora_lr, save_dir=model_dir)

            cal_performance_history.append((cal_accuracy, [lr, layers, lora_rank]))
            ft_performance_history.append((ft_accuracy, [lr, layers, lora_rank]))

            if cal_accuracy > best_cal_accuracy:
                best_cal_accuracy, best_cal_params = cal_accuracy, [lr, layers, lora_rank]
                print(f"New best cal model:\n{best_cal_accuracy=}\n{best_cal_params=}")
            if ft_accuracy > best_ft_accuracy:
                best_ft_accuracy, best_ft_params = ft_accuracy, [lr, layers, lora_rank]
                print(f"New best ft model:\n{best_ft_accuracy=}\n{best_ft_params=}")
        print("\n\n\n")

print(f"finished experiment in {time.time() - start_time:.2f} seconds")

# Summary of the whole sweep: best settings per task, then the full histories.
with open(f"{epochs_dir}performances.txt", "w") as file:
    file.writelines([
        f"{best_cal_accuracy=}\n",
        f"{best_cal_params=}\n",
        f"{best_ft_accuracy=}\n",
        f"{best_ft_params=}\n",
        f"{cal_performance_history=}\n",
        f"{ft_performance_history=}\n",
    ])







In [None]:
# One figure per task: a column of bar charts (one per learning rate) showing
# the accuracy of every (layers, rank) combination recorded during the sweep.
for history, name in [(cal_performance_history, "cal"), (ft_performance_history, "MNIST")]:
    n_axes = len(learning_rates_to_try)
    points_per_ax = len(layers_to_try) * len(ranks_to_try)
    # One inch of width per recorded run; never below 1 so that an empty
    # history does not request a zero-width figure.
    fig_width = max(len(history), 1)
    fig_height = n_axes*6  # Fixed height per subplot

    # squeeze=False always yields a 2-D array of axes, so the single-subplot
    # case needs no special handling; keep the one column as a 1-D array.
    fig, axes = plt.subplots(nrows=n_axes, ncols=1, figsize=(fig_width, fig_height), squeeze=False)
    axes = axes[:, 0]

    # History entries are (accuracy, [lr, layers, lora_rank]).
    values = [entry[0] for entry in history]
    labels = [f"({entry[1][1]}, {entry[1][2]})" for entry in history]

    # Consecutive groups of len(ranks_to_try) bars (i.e. one `layers` value)
    # share a colour; the list is sized for a single subplot.
    palette = ["red","yellow","green","orange", "blue", "purple"]
    colors = [palette[(j//len(ranks_to_try))%len(palette)] for j in range(points_per_ax)]

    # The history is ordered by learning rate first, so each subplot takes the
    # next contiguous slice of points_per_ax entries.
    for i in range(n_axes):
        starting_index = i * points_per_ax
        ending_index = (i+1) * points_per_ax
        axes[i].bar(labels[starting_index:ending_index], values[starting_index:ending_index], color=colors)
        axes[i].set_xlabel(f"learning_rate = {learning_rates_to_try[i]}, (layers, rank)")
        axes[i].set_ylabel("accuracy")

    fig.suptitle(f'{epochs_dir}{name}_performance', y=1)
    plt.tight_layout()

    fig.savefig(f"{epochs_dir}{name}_history.png")
    # Show the plot, then release the figure before building the next one
    plt.show()
    plt.close(fig)