In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch

class ImageEncoder(nn.Module):
    """Single-head self-attention encoder block with a low-rank Q/K/V bottleneck.

    The input is projected to ``latent_dim``-wide queries, keys and values,
    attended over the second-to-last axis, projected back up to
    ``embedding_dim``, and passed through residual + LayerNorm, a feed-forward
    stack, another residual + LayerNorm and a final linear projection.

    Args:
        embedding_dim: Size of the last dimension of the input and output
            (default 16, the width originally hard-coded here).
        latent_dim: Width of the Q/K/V projections (default 2).
    """

    def __init__(self, embedding_dim=16, latent_dim=2):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.latent_dim = latent_dim

        # Downscaling layers for Q, K, V
        self.w_q = nn.Linear(embedding_dim, latent_dim)
        self.w_k = nn.Linear(embedding_dim, latent_dim)
        self.w_v = nn.Linear(embedding_dim, latent_dim)

        # Upscaling back to embedding dim
        self.latent_upscale = nn.Linear(latent_dim, embedding_dim)

        # Layer norm. A single instance is applied at both normalisation
        # points in forward(), so both share one set of affine parameters.
        self.layer_norm = nn.LayerNorm(embedding_dim)

        # Feedforward block.
        # NOTE(review): there is no activation between these layers, so the
        # stack is mathematically a single affine map. Left as-is so the
        # parameter names (feed_fwd.0/1/2) stay checkpoint-compatible.
        self.feed_fwd = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.Linear(embedding_dim, embedding_dim),
            nn.Linear(embedding_dim, embedding_dim)
        )

        # Final projection
        self.output_proj = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        """Encode ``x``.

        Args:
            x: Float tensor whose last dimension equals ``embedding_dim``.
                Presumably ``(batch, num_patches, embedding_dim)`` -- confirm
                against the caller.

        Returns:
            Tensor with the same shape as ``x``.
        """
        Q = self.w_q(x)
        K = self.w_k(x)
        V = self.w_v(x)

        # Scale by sqrt(d_k), where d_k is the width of Q/K (latent_dim).
        # The previous code divided by sqrt(embedding_dim), which over-damps
        # the logits because Q and K live in the much smaller latent space.
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.latent_dim ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=-1)
        context = torch.matmul(attention_weights, V)

        context = self.latent_upscale(context)

        # Residual + Norm
        x = self.layer_norm(context + x)

        # Feedforward + Norm
        ff_out = self.feed_fwd(x)
        out = self.layer_norm(ff_out + x)

        # Final linear projection
        return self.output_proj(out)

In [2]:
class TextEncoder(nn.Module):
    """Token embedding followed by a single-head self-attention encoder block.

    Token ids are embedded, summed with caller-supplied positional encodings,
    then passed through low-rank Q/K/V attention, residual + LayerNorm, a
    feed-forward stack, another residual + LayerNorm and a final projection.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of the token embeddings and of every layer that
            operates on them. Previously the layers were hard-coded to 16, so
            any other value failed in ``forward`` with a shape mismatch.
        latent_dim: Width of the Q/K/V projections (default 2).
    """

    def __init__(self, vocab_size, embedding_dim, latent_dim=2):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.embedding_dim = embedding_dim
        self.latent_dim = latent_dim

        # Downscaling layers for Q, K, V
        self.w_q = nn.Linear(embedding_dim, latent_dim)
        self.w_k = nn.Linear(embedding_dim, latent_dim)
        self.w_v = nn.Linear(embedding_dim, latent_dim)

        # Upscaling back to the embedding width. One LayerNorm instance is
        # applied at both normalisation points in forward().
        self.latent_upscale = nn.Linear(latent_dim, embedding_dim)
        self.layer_norm = nn.LayerNorm(embedding_dim)

        # NOTE(review): no activation between these layers, so the stack is
        # mathematically a single affine map. Left as-is so the parameter
        # names (feed_fwd.0/1/2) stay checkpoint-compatible.
        self.feed_fwd = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.Linear(embedding_dim, embedding_dim),
            nn.Linear(embedding_dim, embedding_dim)
        )

        self.output_proj = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, token_ids, pe):
        """Encode a batch of token ids.

        Args:
            token_ids: Integer tensor of indices into the embedding table.
                Presumably ``(batch, seq_len)`` -- confirm against the caller.
            pe: Positional encodings added to the embeddings; must broadcast
                against the embedded tokens (last dimension ``embedding_dim``).

        Returns:
            Tensor shaped like the embedded input, last dimension
            ``embedding_dim``.
        """
        x = self.embedding(token_ids) + pe  # Add positional encodings

        Q = self.w_q(x)
        K = self.w_k(x)
        V = self.w_v(x)

        # Scale by sqrt(d_k), where d_k is the width of Q/K (latent_dim),
        # rather than by sqrt(embedding_dim) as before.
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.latent_dim ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=-1)
        context = torch.matmul(attention_weights, V)

        # Residual + Norm
        context = self.latent_upscale(context)
        x = self.layer_norm(context + x)

        # Feedforward + Norm
        ff_out = self.feed_fwd(x)
        out = self.layer_norm(ff_out + x)

        return self.output_proj(out)

In [None]:
class CLIPMini(nn.Module):
    """Pairs an image encoder with a text encoder and returns unit-length vectors.

    Args:
        vocab_size: Forwarded to ``TextEncoder``.
        embedding_dim: Forwarded to ``TextEncoder`` only; ``ImageEncoder`` is
            constructed with no arguments.
    """

    def __init__(self, vocab_size, embedding_dim=16):
        super().__init__()
        # Image encoder is registered first, then the text encoder.
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder(vocab_size, embedding_dim)

    def forward(self, image_patches, text_tokens, pe):
        """Encode both modalities and return ``(img_vec, txt_vec)``.

        Args:
            image_patches: Input handed to the image encoder; assumed to be
                ``[B, num_patches, D]`` -- confirm against the data pipeline.
            text_tokens: Token ids handed to the text encoder.
            pe: Positional encodings handed to the text encoder.

        Returns:
            Tuple of two tensors, each L2-normalised along its last dimension.
        """
        patch_features = self.image_encoder(image_patches)
        token_features = self.text_encoder(text_tokens, pe)

        # Image side: average over dim 1. Text side: keep only position 0
        # (treated as the CLS slot).
        image_summary = patch_features.mean(dim=1)
        text_summary = token_features[:, 0, :]

        # Rescale each vector to unit L2 norm along the feature axis.
        return F.normalize(image_summary, dim=-1), F.normalize(text_summary, dim=-1)