In [1]:
# vit model 
from typing import Callable, NamedTuple, List, Optional, Union
from functools import partial
from collections import OrderedDict
import torch
import math
from torch import nn
from torchvision.ops.misc import Conv2dNormActivation, MLP
from torchvision.models.vision_transformer import MLPBlock, ConvStemConfig, Encoder

class VisionTransformer(nn.Module):
    """Vision Transformer as per https://arxiv.org/abs/2010.11929."""

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        n_learned_tokens: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
    ):
        super().__init__()
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.n_learned_tokens = n_learned_tokens
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer

        if conv_stem_configs is not None:
            # As per https://arxiv.org/abs/2106.14881
            seq_proj = nn.Sequential()
            prev_channels = 3
            for i, conv_stem_layer_config in enumerate(conv_stem_configs):
                seq_proj.add_module(
                    f"conv_bn_relu_{i}",
                    Conv2dNormActivation(
                        in_channels=prev_channels,
                        out_channels=conv_stem_layer_config.out_channels,
                        kernel_size=conv_stem_layer_config.kernel_size,
                        stride=conv_stem_layer_config.stride,
                        norm_layer=conv_stem_layer_config.norm_layer,
                        activation_layer=conv_stem_layer_config.activation_layer,
                    ),
                )
                prev_channels = conv_stem_layer_config.out_channels
            seq_proj.add_module(
                "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
            )
            self.conv_proj: nn.Module = seq_proj
        else:
            self.conv_proj = nn.Conv2d(
                in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
            )

        seq_length = (image_size // patch_size) ** 2

        # Add a class token
        self.class_token = nn.Parameter(torch.zeros(1, n_learned_tokens, hidden_dim))
        seq_length += n_learned_tokens

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        if isinstance(self.conv_proj, nn.Conv2d):
            # Init the patchify stem
            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
            if self.conv_proj.bias is not None:
                nn.init.zeros_(self.conv_proj.bias)
        elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
            # Init the last 1x1 conv of the conv stem
            nn.init.normal_(
                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
            )
            if self.conv_proj.conv_last.bias is not None:
                nn.init.zeros_(self.conv_proj.conv_last.bias)

        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        if isinstance(self.heads.head, nn.Linear):
            nn.init.zeros_(self.heads.head.weight)
            nn.init.zeros_(self.heads.head.bias)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!")
        torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        return x

    def forward(self, x: torch.Tensor):
        # Reshape and permute the input tensor
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        return x


In [2]:
import transformers
from transformers import AutoTokenizer, AutoModel


class BertEmbedding(nn.Module):
    def __init__(self):
        super(BertEmbedding, self).__init__()
                
        bert_id = 'NeuML/pubmedbert-base-embeddings'
        self.bert = AutoModel.from_pretrained(bert_id)
        self.tokenizer = AutoTokenizer.from_pretrained(bert_id)

    def forward(self, x):
        input_ids = self.tokenizer(x, return_tensors='pt', padding=True, truncation=True, max_length=30)['input_ids']
        return self.bert(input_ids).pooler_output

In [3]:
vit_model = VisionTransformer(image_size=512,
        patch_size=16,
        num_layers=16,
        num_heads=8,
        hidden_dim=768,
        n_learned_tokens=32,
        mlp_dim=1280)

bert_model = BertEmbedding()


In [None]:
from torch.nn import functional as F
class CrossAttentionModel(nn.Module):
    def __init__(self, d_h=768, d_g=768, d=768, K=32, M=8, temperature=0.07, use_p_tokens=False):
        """
        d_h: Dimension of each visual mixture token embedding (determined by ViT).
        d_g: Dimension of text encoder embeddings.
        d: Output dimension after cross-attention.
        K: Number of learnable tokens for cross-attention.
        M: Number of attention heads.
        temperature: Temperature for the softmax operation.
        """
        super(CrossAttentionModel, self).__init__()
        
        # Parameters
        self.K = K
        self.M = M
        self.d = d
        self.temperature = temperature

        # Initialize ViT as the image encoder
        self.image_encoder = VisionTransformer(image_size=512,
                                                    patch_size=16,
                                                    num_layers=16,
                                                    num_heads=M,
                                                    hidden_dim=d_h,
                                                    n_learned_tokens=K,
                                                    mlp_dim=d_h)
        self.d_h = d_h
        self.d_g = d_g
        self.use_p_tokens = use_p_tokens
        # Linear transformations for multi-head cross-attention
        self.W_q = nn.Parameter(torch.randn(self.d_g, self.d))      # Query projection for each head
        self.W_k = nn.Parameter(torch.randn(self.d_h, self.d))      # Key projection for each head
        self.W_v = nn.Parameter(torch.randn(self.d_h, self.d))      # Value projection for each head
        self.W_o = nn.Parameter(torch.randn(self.d, self.d))        # Final output projection

        # Placeholder for text encoder (can be any pretrained text model)
        self.text_encoder = BertEmbedding()             # Replace with actual text encoder
        
        # Projection matrix for text vector
        self.W_T = nn.Parameter(torch.randn(d_g, d))
        self.a = torch.log(torch.tensor(10.0))
        self.b = torch.tensor(-10.0)


    def forward(self, image, text):
        """
        image: Tensor representing an image, with shape (batch_size, channels, height, width).
        text: Tensor representing a text caption, with shape (batch_size, d_g).
        """
        batch_size = image.size(0)


        # Step 2: Pass the image and the learnable tokens through the ViT
        # Concatenate the learnable tokens to the patch embeddings as extra tokens
        h_i = self.image_encoder(image)          # (batch_size, N, d_h) N = Numer of patches + K
        if not self.use_p_tokens:
            # Extract only the learnable token outputs as visual mixture tokens
            h_i = h_i[:, :self.K, :]                               # (batch_size, K, d_h)

        # Step 3: Compute text representation
        g_j = self.text_encoder(text)                          # (batch_size, d_g)

        # Step 4: Project text representation for each head to obtain queries
        Q_j = torch.einsum('bg,gd->bd', g_j, self.W_q)        # (batch_size, d)

        # Step 5: Project each visual mixture token for keys and values
        K_i = torch.einsum('bkd,dm->bkm', h_i, self.W_k)       # (batch_size, K, d)
        V_i = torch.einsum('bkd,dm->bkm', h_i, self.W_v)       # (batch_size, K, d)
        

        Q_j = Q_j.view(batch_size, self.M, self.d // self.M)  # (batch_size, M, d // M)
        K_i = K_i.view(batch_size, self.K, self.M, self.d // self.M)  # (batch_size, K, M, d // M)
        V_i = V_i.view(batch_size, self.K, self.M, self.d // self.M)  # (batch_size, K, M, d // M)
        print("Multi-head queries (text)", Q_j.shape, "Multi-head Keys(image)", K_i.shape, "multi-head values(images)",  V_i.shape, "expected (batch_size, M, d // M), (batch_size, K, M, d // M), (batch_size, K, M, d // M)")
        # Step 6: Calculate attention scores for each head and mixture token
        attention_logits = torch.einsum('bmd,ckmd->bckm', Q_j, K_i)  # (batch_size, batch_size, K, M)
        print("attention_logits", attention_logits.shape, "expected (batch_size, batch_size, K, M)")
        attention_weights = F.softmax(attention_logits / self.temperature, dim=1)  # (batch_size, batch_size, K, M)

        # Step 7: Apply attention weights to values
        weighted_values = torch.einsum('bckm,bkmd->bkmd', attention_weights, V_i).sum(dim=1)  # (batch_size, K, M, d // M)
        # weighted_values = (attention_weights * V_i).sum(dim=1)  # (batch_size, M, d // M)
        print("weighted value", weighted_values.shape, "expected (batch_size, M, d // M)")
        # Step 8: Concatenate heads and project to final dimension d
        z_ij = weighted_values.view(batch_size, -1)           # (batch_size, d)
        z_ij = torch.matmul(z_ij, self.W_o)                   # (batch_size, d)

        # Step 9: Project text vector
        z_j = torch.matmul(g_j, self.W_T)                     # (batch_size, d)

        # Step 10: Normalize both representations
        z_ij = F.normalize(z_ij, p=2, dim=1)                  # (batch_size, d)
        z_j = F.normalize(z_j, p=2, dim=1)                    # (batch_size, d)

        return z_ij, z_j
    
    def llip_loss(self, Z, Z_prime):
        """
        Compute the Llip loss from text-conditioned image embeddings and text embeddings.

        Args:
            Z (torch.Tensor): Tensor of shape (batch_size, d) representing embeddings for the first modality (e.g., images).
            Z_prime (torch.Tensor): Tensor of shape (batch_size, d) representing embeddings for the second modality (e.g., text).
            a (float): Scaling factor for the similarities.
            b (float): Shifting factor for the similarities.

        Returns:
            torch.Tensor: The computed L_Llip loss.
        """
        batch_size = Z.size(0)
        a = torch.exp(self.a)
        logits = torch.einsum('ab,cd->ac', Z, Z_prime)  # (batch_size, batch_size)
        logits = (a * logits) + self.b
        pos_ids = torch.eye(batch_size)
        neg_ids = 1 - pos_ids
        pos_logits = logits * pos_ids
        
        neg_logits = logits * neg_ids
        pos_loss = -F.logsigmoid(pos_logits)
        neg_loss = -F.logsigmoid(-neg_logits).mean(dim=1)
        llip_loss_value = (pos_loss.sum() + neg_loss.sum()) / batch_size


        return llip_loss_value
    

In [5]:
text = ["aboba", "aboba", "aboba", "aboba"]
image = torch.rand(4, 3, 512, 512)
model = CrossAttentionModel()
z_ij, z_j = model(image, text)

Multi-head queries (text) torch.Size([4, 8, 96]) Multi-head Keys(image) torch.Size([4, 32, 8, 96]) multi-head values(images) torch.Size([4, 32, 8, 96]) expected (batch_size, M, d // M), (batch_size, K, M, d // M), (batch_size, K, M, d // M)
attention_logits torch.Size([4, 4, 32, 8]) expected (batch_size, batch_size, K, M)
weighted value torch.Size([4, 8, 96]) expected (batch_size, M, d // M)


In [6]:
def llip_loss(Z, Z_prime, a=1.0, b=0.0):
        """
        Compute the Llip loss from text-conditioned image embeddings and text embeddings.

        Args:
            Z (torch.Tensor): Tensor of shape (batch_size, d) representing embeddings for the first modality (e.g., images).
            Z_prime (torch.Tensor): Tensor of shape (batch_size, d) representing embeddings for the second modality (e.g., text).
            a (float): Scaling factor for the similarities.
            b (float): Shifting factor for the similarities.

        Returns:
            torch.Tensor: The computed L_Llip loss.
        """
        batch_size = Z.size(0)

        logits = torch.einsum('ab,cd->ac', Z, Z_prime)  # (batch_size, batch_size)
        logits = (a * logits) + b
        pos_ids = torch.eye(batch_size)
        neg_ids = 1 - pos_ids
        pos_logits = logits * pos_ids
        print(pos_logits)
        
        neg_logits = logits * neg_ids
        print(neg_logits)
        pos_loss = -F.logsigmoid(pos_logits)
        neg_loss = -F.logsigmoid(-neg_logits).mean(dim=1)
        llip_loss_value = (pos_loss.sum() + neg_loss.sum()) / batch_size


        return llip_loss_value