**Table of contents**<a id='toc0_'></a>    
- [Common Components](#toc1_)    
  - [Position Encoding](#toc1_1_)    
  - [Multi Headed Attention](#toc1_2_)    
  - [Feed Forward Network](#toc1_3_)    
  - [Encoder Block](#toc1_4_)    
- [Text Processing](#toc2_)    
  - [Text Tokenizer](#toc2_1_)    
  - [Text Encoder](#toc2_2_)    
- [Image Processing](#toc3_)    
- [CLIP Model](#toc4_)    
  - [Modality Projector](#toc4_1_)    
  - [Final CLIP Model](#toc4_2_)    
  - [Contrastive Loss](#toc4_3_)    
  - [Data Sets](#toc4_4_)    
- [Zero-Shot Experiment](#toc5_)    

<!-- vscode-jupyter-toc-config
	numbering=false
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision.datasets import FashionMNIST

import math

from tqdm import tqdm

In [None]:
from dataclasses import dataclass

@dataclass
class Config:
    """Hyper-parameters for the vision encoder, the text encoder and the joint CLIP space."""

    # Vision Encoder
    vision_image_size: int = 32      # input images are square, this many pixels per side
    vision_patch_size: int = 8       # side length of each square patch
    vision_num_channels: int = 1
    vision_num_layers: int = 6
    vision_num_heads: int = 4
    vision_hidden_size: int = 256
    vision_ffn_dim: int = 1024

    # Text Encoder
    lm_vocab_size: int = 256         # one embedding row per possible byte value
    lm_max_seq_len: int = 32
    # Previously declared twice (512, then 256); only the last declaration took
    # effect, so 256 is kept as the single source of truth.
    lm_hidden_size: int = 256
    lm_num_layers: int = 6
    lm_num_heads: int = 4
    lm_ffn_dim: int = 1024

    # CLIP
    clip_hidden_dim: int = 512       # width of the shared text/image embedding space
    clip_loss_temperature: float = 0.3

In [None]:
# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Bare expression: lets the notebook display the selected device.
device

# <a id='toc1_'></a>[Common Components](#toc0_)

## <a id='toc1_1_'></a>[Position Encoding](#toc0_)
For simplicity, both the text encoder and the image encoder use fixed sinusoidal (sin/cos) position encodings.

In [None]:
class PositionalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal position encoding added to a sequence.

    The table is sized from the modality selected by ``type``:

    * ``'lm'``: ``lm_max_seq_len`` positions of width ``lm_hidden_size``.
    * anything else: one position per image patch,
      ``(vision_image_size // vision_patch_size) ** 2``, of width
      ``vision_hidden_size``.

    The table is stored in the buffer ``pe`` with shape
    ``(1, seq_len, hidden_size)``.
    """

    def __init__(self, config: Config, type: str='lm'):
        super().__init__()

        # Select BOTH the sequence length and the feature width per modality.
        # The table used to be built from the text settings unconditionally,
        # which only worked for images while the two hidden sizes were equal.
        if type == 'lm':
            seq_len = config.lm_max_seq_len
            hidden_size = config.lm_hidden_size
        else:
            seq_len = (config.vision_image_size // config.vision_patch_size) ** 2
            hidden_size = config.vision_hidden_size

        pos = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_size, 2, dtype=torch.float) * -(math.log(10000.0) / hidden_size)
        )
        angles = pos * div_term  # (seq_len, ceil(hidden_size / 2))

        pe = torch.zeros(seq_len, hidden_size)
        pe[:, 0::2] = torch.sin(angles)
        # For an odd hidden_size there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(angles)[:, : hidden_size // 2]

        self.register_buffer('pe', pe.unsqueeze(0))  # Add batch dimension

    def forward(self, x):
        """
        x: Tensor of shape (batch_size, seq_len, hidden_size)

        Returns ``x`` plus the encodings of its first ``x.size(1)`` positions.
        """
        return x + self.pe[:, :x.size(1), :]

## <a id='toc1_2_'></a>[Multi Headed Attention](#toc0_)

In [None]:
class Attention(nn.Module):
    """Multi-head self-attention with a single fused Q/K/V projection.

    Args:
        hidden_dim: Feature width of the input and of the output.
        num_heads: Number of heads; must divide ``hidden_dim`` evenly.
    """

    def __init__(self, hidden_dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads

        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.head_dim = hidden_dim // num_heads

        # One linear layer emits queries, keys and values side by side.
        self.qkv_proj = nn.Linear(hidden_dim, hidden_dim * 3)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

    def _split_heads(self, t):
        """Reshape (B, S, hidden) into (B, num_heads, S, head_dim)."""
        return t.view(t.size(0), t.size(1), self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask = None):
        """Attend over ``x`` of shape (batch, seq_len, hidden_dim).

        ``mask``, when given, is indexed as a 3-D tensor (batch, seq_len,
        seq_len); entries equal to 0 mark query/key pairs that must not attend.
        Returns a tensor with the same shape as ``x``.
        """
        batch, seq_len, _ = x.shape
        queries, keys, values = (
            self._split_heads(part) for part in self.qkv_proj(x).chunk(3, dim=-1)
        )

        # Scaled dot-product scores: (batch, heads, seq_len, seq_len).
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            # Insert a singleton head axis so one mask is shared by every head.
            blocked = mask[:, None, :, :] == 0
            scores = scores.masked_fill(blocked, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        context = torch.matmul(weights, values)
        # Merge the heads back into a single feature axis.
        merged = context.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.out_proj(merged)

## <a id='toc1_3_'></a>[Feed Forward Network](#toc0_)

In [None]:
class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Args:
        hidden_dim: Width of the input and of the output.
        ffn_dim: Width of the intermediate layer.
    """

    def __init__(self, hidden_dim: int, ffn_dim: int):
        super().__init__()
        self.linear1 = nn.Linear(hidden_dim, ffn_dim)
        self.linear2 = nn.Linear(ffn_dim, hidden_dim)

    def forward(self, x):
        """Expand the last dimension of ``x``, apply ReLU, then project back."""
        expanded = self.linear1(x)
        activated = F.relu(expanded)
        return self.linear2(activated)

## <a id='toc1_4_'></a>[Encoder Block](#toc0_)

In [None]:
class Encoder(nn.Module):
    """One pre-norm transformer block: attention and FFN, each with a residual.

    Args:
        hidden_dim: Feature width of the block.
        num_heads: Number of attention heads.
        ffn_dim: Intermediate width of the feed-forward network.
        is_causal: Accepted for the callers' benefit but not used by this
            block; masking is driven entirely by the ``mask`` passed to
            ``forward``.
    """

    def __init__(self,  hidden_dim: int, num_heads: int, ffn_dim: int, is_causal: bool = False):
        super().__init__()

        self.attention = Attention(hidden_dim, num_heads)
        self.ffn = FFN(hidden_dim, ffn_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x, mask=None):
        """
        x: Tensor of shape (batch_size, seq_len, hidden_dim)
        mask: forwarded unchanged to the attention layer (may be None)
        """
        # Layer norm is applied before each sub-layer; the residual uses the
        # un-normalised input.
        attended = self.attention(self.norm1(x), mask=mask)
        x = x + attended
        transformed = self.ffn(self.norm2(x))
        return x + transformed

# <a id='toc2_'></a>[Text Processing](#toc0_)




## <a id='toc2_1_'></a>[Text Tokenizer](#toc0_)
We use a simple byte-level tokenizer:
- Adds `chr(2)` as `[SOS]` (start of sequence) and `chr(3)` as `[EOS]` (end of sequence)
- Encoding: Uses `.encode("utf-8")` to convert characters into bytes, and then wraps them in an int32 `torch` tensor
- Decoding: Converts byte indices back into text, and removes `[SOS]`/`[EOS]` during decoding

In [None]:
def tokenizer(text, encode=True, mask = None, max_length=32):
    """Byte-level tokenizer: one token id per UTF-8 byte.

    Encoding (``encode=True``):
        ``text`` is a ``str``. It is framed as ``[SOS]=2, <utf-8 bytes>,
        [EOS]=3`` and right-padded with ``0`` to exactly ``max_length`` ids.
        Text that does not fit is truncated at byte level so that ``[EOS]``
        is always kept. The incoming ``mask`` argument is ignored.
        Returns ``(ids, mask)``: two int32 tensors of length ``max_length``;
        ``mask`` is 1 on real tokens (including SOS/EOS) and 0 on padding.

    Decoding (``encode=False``):
        ``text`` is a sequence of ids as produced above and ``mask`` is the
        matching padding mask. When ``mask`` is ``None`` the number of real
        tokens is taken to be the number of non-zero ids.
        Returns ``(string, None)`` with SOS/EOS/padding stripped.

    Raises:
        ValueError: when encoding with ``max_length < 2`` (no room for the
            SOS and EOS markers).
    """
    if encode:
        if max_length < 2:
            raise ValueError("max_length must be at least 2 to hold [SOS] and [EOS]")

        # Count bytes, not characters: a non-ASCII character expands to
        # several ids, so character-based padding would overshoot max_length.
        payload = text.encode('utf-8')[:max_length - 2]
        ids = [2, *payload, 3]  # add start and end tokens
        num_real = len(ids)
        num_pad = max_length - num_real

        out = torch.tensor(ids + [0] * num_pad, dtype=torch.int32)
        mask = torch.cat([
            torch.ones(num_real, dtype=torch.int32),
            torch.zeros(num_pad, dtype=torch.int32),
        ])
    else:
        ids = torch.as_tensor(text)
        if mask is None:
            num_real = int((ids != 0).sum())
        else:
            num_real = len(torch.as_tensor(mask).nonzero())

        # Drop SOS (first real token) and EOS (last real token). The lower
        # bound keeps the slice empty instead of wrapping around on tiny inputs.
        stop = max(num_real - 1, 1)
        raw = bytes(ids[1:stop].tolist())
        # Decode the bytes together so multi-byte characters survive; a
        # sequence cut mid-character by truncation becomes U+FFFD.
        out = raw.decode('utf-8', errors='replace')
        mask = None

    return out, mask

## <a id='toc2_2_'></a>[Text Encoder](#toc0_)

In [None]:
class TextEncoder(nn.Module):
    """Transformer text encoder that summarises a sequence by its [EOS] state.

    Token ids are embedded, combined with positional encodings, passed
    through ``config.lm_num_layers`` encoder blocks, and the hidden state at
    each row's [EOS] position is returned L2-normalised.
    """

    def __init__(self, config: Config):
        super().__init__()

        self.config = config

        self.embedding = nn.Embedding(config.lm_vocab_size, config.lm_hidden_size)
        self.positional_embedding = PositionalEmbedding(config, type='lm')
        self.transformer = nn.ModuleList([
            Encoder(config.lm_hidden_size, config.lm_num_heads, config.lm_ffn_dim, is_causal=True)
            for _ in range(config.lm_num_layers)
        ])

        # Default to 3, the [EOS] byte, when the config does not define one.
        self.eos_token_id = getattr(config, 'eos_token_id', 3)

    def forward(self, text, mask=None):
        """
        text: Tensor of shape (batch_size, seq_len) holding token ids
        mask: attention mask forwarded unchanged to every encoder block, or
              None. The attention layer indexes it as a 3-D tensor
              (batch_size, seq_len, seq_len) where 0 means "do not attend".

        Returns a tensor of shape (batch_size, lm_hidden_size) with unit L2 norm.
        """
        x = self.embedding(text)
        x = self.positional_embedding(x)

        for layer in self.transformer:
            x = layer(x, mask=mask)

        # Locate [EOS] independently per row. A flat nonzero() lookup would
        # assume exactly one match per row and otherwise mis-pair rows.
        # argmax returns the FIRST maximum, so searching the reversed sequence
        # yields the LAST [EOS]. A row with no [EOS] falls back to its final
        # position.
        is_eos = (text == self.eos_token_id)
        offset_from_end = is_eos.flip(dims=[-1]).int().argmax(dim=-1)
        eos_pos = text.size(1) - 1 - offset_from_end

        batch_idx = torch.arange(x.size(0), device=x.device)
        eos_feature = x[batch_idx, eos_pos, :]

        # Normalize the output
        return F.normalize(eos_feature, dim=-1)


# <a id='toc3_'></a>[Image Processing](#toc0_)

In [None]:
import torchvision.transforms as transforms

# Mean and std of FashionMNIST grayscale images (hard-coded constants; they
# are not recomputed in this notebook).
FASHION_MNIST_MEAN = 0.2860
FASHION_MNIST_STD = 0.3530

# Preprocessing pipeline handed to the datasets as their `transform`.
image_transform = transforms.Compose([
    transforms.Resize((32, 32)),                     # Resize to 32x32 (same value as Config.vision_image_size's default)
    transforms.ToTensor(),                           # Convert to tensor, scales [0,255] → [0.0,1.0]
    transforms.Normalize(mean=[FASHION_MNIST_MEAN],  # Standardise with the dataset statistics above
                         std=[FASHION_MNIST_STD]),
])

In [None]:
class ImageEncoder(nn.Module):
    """ViT-style image encoder: patchify with a strided conv, then transformer blocks.

    No dedicated CLS token is prepended; the output at the first patch
    position is returned as the image representation.
    """

    def __init__(self, config: Config):
        super().__init__()

        self.config = config

        # Kernel size == stride == patch size, so every output location sees
        # exactly one non-overlapping patch.
        # NOTE(review): in_channels is hard-coded to 3 and does not read
        # config.vision_num_channels.
        patch = config.vision_patch_size
        self.conv_proj = nn.Conv2d(
            in_channels=3,
            out_channels=config.vision_hidden_size,
            kernel_size=patch,
            stride=patch,
            padding='valid',
        )
        self.positional_embedding = PositionalEmbedding(config, type='vision')

        blocks = [
            Encoder(config.vision_hidden_size, config.vision_num_heads, config.vision_ffn_dim, is_causal=False)
            for _ in range(config.vision_num_layers)
        ]
        self.transformer = nn.ModuleList(blocks)

    def forward(self, images):
        """
        images: Tensor of shape (batch_size, channels, height, width)

        Returns the hidden state at the first patch position,
        shape (batch_size, vision_hidden_size).
        """
        patch_grid = self.conv_proj(images)
        # Collapse the spatial grid into a sequence: (batch_size, num_patches, hidden_dim)
        tokens = patch_grid.flatten(2).transpose(1, 2)
        tokens = self.positional_embedding(tokens)

        for block in self.transformer:
            tokens = block(tokens)

        # The first patch stands in for a CLS token.
        return tokens[:, 0, :]

# <a id='toc4_'></a>[CLIP Model](#toc0_)

## <a id='toc4_1_'></a>[Modality Projector](#toc0_)

In [None]:
class ModalityProjector(nn.Module):
    """Single linear layer that maps encoder features into the shared CLIP space.

    Args:
        in_dim: Width of the encoder output.
        out_dim: Width of the shared embedding space.
    """

    def __init__(self, in_dim, out_dim:int):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        """Project ``x`` of shape (batch_size, in_dim) to (batch_size, out_dim)."""
        projected = self.linear(x)
        return projected

## <a id='toc4_2_'></a>[Final CLIP Model](#toc0_)

In [None]:
class CLIP(nn.Module):
    """Two-tower model: a text encoder and an image encoder, each followed by
    a linear projection into a shared space of width ``config.clip_hidden_dim``.
    """

    def __init__(self, config: Config):
        super().__init__()

        # Modality-specific towers.
        self.text_encoder = TextEncoder(config)
        self.image_encoder = ImageEncoder(config)

        # Heads that bring both towers to the same embedding width.
        self.text_projector = ModalityProjector(config.lm_hidden_size, config.clip_hidden_dim)
        self.image_projector = ModalityProjector(config.vision_hidden_size, config.clip_hidden_dim)

    def encode_text(self, text, mask = None):
        """Encode token ids (with an optional attention mask) into the shared space."""
        return self.text_projector(self.text_encoder(text, mask=mask))

    def encode_image(self, images):
        """Encode a batch of images into the shared space."""
        return self.image_projector(self.image_encoder(images))

    def forward(self, text, images, mask=None):
        """Encode both modalities.

        Returns a tuple in the order ``(text_features, image_features)``.
        """
        image_embeddings = self.encode_image(images)
        text_embeddings = self.encode_text(text, mask)
        return text_embeddings, image_embeddings

## <a id='toc4_3_'></a>[Contrastive Loss](#toc0_)

In [None]:
def clip_loss(text_features, image_features, temperature=0.07):
    """Symmetric contrastive (InfoNCE-style) loss over a batch of pairs.

    Row ``i`` of ``text_features`` is treated as the positive match of row
    ``i`` of ``image_features``; every other row in the batch is a negative.

    Args:
        text_features: (batch, dim) text embeddings.
        image_features: (batch, dim) image embeddings.
        temperature: Divisor applied to the cosine similarities.

    Returns:
        Scalar tensor: the mean of the text->image and image->text
        cross-entropy losses.
    """
    txt = F.normalize(text_features, dim=-1)
    img = F.normalize(image_features, dim=-1)

    # Cosine-similarity matrix, rows = texts, columns = images.
    scale = 1 / temperature
    similarity = torch.matmul(txt, img.transpose(-2, -1)) * scale

    # The matching pair sits on the diagonal.
    targets = torch.arange(txt.size(0), device=txt.device)

    text_to_image = F.cross_entropy(similarity, targets)
    image_to_text = F.cross_entropy(similarity.transpose(-2, -1), targets)

    return (text_to_image + image_to_text) / 2.0

## <a id='toc4_4_'></a>[Data Sets](#toc0_)

In [None]:
class CLIPDataset(Dataset):
    """FashionMNIST wrapped as (image, caption) pairs for contrastive training.

    Each class label is turned into the caption ``"a photo of a <class>"``
    and tokenized to 32 ids.

    Args:
        train: Selects the FashionMNIST train or test split; also controls
            whether the padding mask is merged into the attention mask.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, train: bool = True, transform = None ):
        self.train = train
        self.transform = transform

        self.dataset = FashionMNIST(root='./data', train=train, download=True)
        self.captions = {
            i: "a photo of a " + class_name
            for i, class_name in enumerate(self.dataset.classes)
        }

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return a dict with keys ``'image'``, ``'text'``, ``'mask'``, ``'label'``.

        ``'mask'`` is a boolean (seq_len, seq_len) matrix; True means the
        query position (row) may attend to the key position (column).
        """
        image, label = self.dataset[idx]

        # The source images are single-channel; replicate to three channels.
        image = image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        caption = self.captions[label]

        text, pad_mask = tokenizer(caption, encode=True, max_length=32)
        seq_len = len(text)

        # Lower-triangular: each position sees itself and earlier positions.
        mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))  # (S, S)
        # NOTE(review): the padding mask is only merged in for the training
        # split; the evaluation split keeps the purely causal mask.
        if self.train and pad_mask is not None:
            # Broadcast along the key axis: padded keys are hidden from every query.
            mask = mask & pad_mask.bool().unsqueeze(0)

        return {
            'image': image,
            'text': text,
            'mask': mask,
            'label': label
        }

In [None]:
from torch.utils.data import DataLoader
# Training hyper-parameters shared by the cells below.
batch_size = 128
epochs = 10

# Train / test splits; both use the same image preprocessing.
train_set = CLIPDataset(train = True, transform=image_transform)
test_set = CLIPDataset(train = False, transform=image_transform)

# Only the training loader shuffles.
train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_set, shuffle=False, batch_size=batch_size)

In [None]:
from tqdm import tqdm
from torch.optim import Adam

model = CLIP(Config())
model = model.to(device)
optimizer = Adam(model.parameters(), lr=1e-4)


model.train()
# Lowest batch loss seen over the WHOLE run. It is deliberately not reset per
# epoch, so "best_clip.pt" is never overwritten by a worse later epoch.
best_loss = float('inf')

# Per-step loss history, used for plotting.
losses = []

for epoch in range(epochs):
    epoch_bar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{epochs}]", leave=False)

    for data in epoch_bar:
        img = data['image'].to(device)
        text = data['text'].to(device)
        mask = data['mask'].to(device)

        # CLIP.forward returns (text_features, image_features), in that order.
        text_features, img_features = model(text, img, mask=mask)
        optimizer.zero_grad()
        loss = clip_loss(text_features, img_features)
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        epoch_bar.set_postfix(loss=step_loss)
        losses.append(step_loss)
        # NOTE(review): the checkpoint holds the weights AFTER this step,
        # while step_loss was measured before it, on a single batch.
        if step_loss <= best_loss:
            best_loss = step_loss
            torch.save(model.state_dict(), "best_clip.pt")

In [None]:
import matplotlib.pyplot as plt

# Plot the per-step training loss collected in `losses`.
plt.figure(figsize=(10, 4))
plt.plot(losses, label="Training Loss")
plt.xlabel("Step")
plt.ylabel("Loss")
plt.title("CLIP Training Loss Curve")
plt.legend()
plt.grid(True)
plt.show()

# <a id='toc5_'></a>[Zero-Shot Experiment](#toc0_)

In [None]:
# import matplotlib.pyplot as plt
# %matplotlib inline

In [None]:
import torch
import matplotlib.pyplot as plt
from torchvision.datasets import FashionMNIST
from torchvision import transforms
import random

device = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Load model and tokenizer ----
model = CLIP(Config()).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load("best_clip.pt", map_location=device))
model.eval()

# Same value as clip_loss's default temperature, which the training loop uses.
# Without this scaling the cosine similarities lie in [-1, 1] and the softmax
# below is close to uniform.
EVAL_TEMPERATURE = 0.07


def encode_prompts(prompts, max_length=32):
    """Tokenize ``prompts`` and return ``(ids, pad_mask)`` stacked on ``device``.

    ``ids`` has shape (num_prompts, max_length); ``pad_mask`` is a boolean
    tensor of the same shape, True on real tokens.
    """
    text_tokens, pad_masks = [], []
    for p in prompts:
        t, pad_mask = tokenizer(p, encode=True, max_length=max_length)
        text_tokens.append(t)
        pad_masks.append(pad_mask.to(dtype=torch.bool))
    return torch.stack(text_tokens).to(device), torch.stack(pad_masks).to(device)


# ---- Test data (loaded once; used for class names and the sample image) ----
test_dataset = FashionMNIST(root='./data', train=False)

# ---- Class prompts ----
class_names = test_dataset.classes
class_prompts = [f"a photo of a {cls}" for cls in class_names]
text, pad_mask = encode_prompts(class_prompts)
seq_len = text.shape[1]
causal_mask = torch.tril(torch.ones((seq_len, seq_len), device=device)).bool()
# (1, S, S) & (N, 1, S) -> (N, S, S): causal, with padded keys hidden.
text_mask = causal_mask.unsqueeze(0) & pad_mask.unsqueeze(1)

# ---- Image preprocessing ----
# Must match the normalisation used for training, otherwise the encoder sees
# inputs from a different distribution.
image_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[FASHION_MNIST_MEAN], std=[FASHION_MNIST_STD])
])

# ---- Load one image ----
img, label = test_dataset[random.randint(0, len(test_dataset) - 1)]
img_rgb = img.convert("RGB")
img_tensor = image_transform(img_rgb).unsqueeze(0).to(device)

# ---- Encode and compare ----
# Both towers run under no_grad: this is inference only.
with torch.no_grad():
    text_features = model.encode_text(text, mask=text_mask)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    image_features = model.encode_image(img_tensor)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    logits = (image_features @ text_features.T) / EVAL_TEMPERATURE
    probs = logits.softmax(dim=-1).squeeze(0)  # [num_classes]

    topk_probs, topk_indices = probs.topk(min(10, probs.numel()))
    topk_labels = [class_names[i] for i in topk_indices]
    topk_texts = [f"{name}: {prob.item() * 100:.2f}%" for name, prob in zip(topk_labels, topk_probs)]

# ---- Display image and predictions ----
plt.figure(figsize=(5, 6))
plt.imshow(img, cmap="gray")
plt.title(f"An image of a {class_names[label]}", fontsize=12)
plt.axis('off')

plt.tight_layout()
plt.show()

# Print the ranked predictions (up to 10 classes) below the image
for line in topk_texts:
    print(line)