In [4]:
import torch
import clip
from PIL import Image
import numpy as np


device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

# Smoke test: score an all-black 224x224 RGB image against three captions.
image = preprocess(Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "darkness"]).to(device)
print(image.min(), image.max())  # value range after `preprocess`
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    print(image_features.shape, text_features.shape)
    # logits_per_image: one row per image, one column per caption.
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

# For this all-black input the recorded (CUDA) run printed
# [[0.7686 0.06018 0.1714]]; exact values may vary with device/precision.
print("Label probs:", probs)

tensor(-1.7923, device='cuda:0') tensor(-1.4802, device='cuda:0')
torch.Size([1, 512]) torch.Size([3, 512])
Label probs: [[0.7686  0.06018 0.1714 ]]


In [5]:
# List the checkpoint names reported by CLIP; this notebook loads "ViT-B/16"
# and "RN50x4" from this list via clip.load().
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [7]:
# Load the RN50x4 CLIP variant, rebinding the `model`/`preprocess` names that
# previously held ViT-B/16, then display the returned preprocessing transform.
# NOTE: the first call downloads the checkpoint (see the progress bar below).
model, preprocess = clip.load("RN50x4", device=device)
preprocess

 92%|███████████████████████████████████▉   | 371M/402M [02:24<00:08, 3.82MiB/s]

: 

In [38]:
import torch
import torch.nn.functional as F
import torch.nn as nn

class CrossAttentionTransformer(nn.Module):
    """Multi-head cross-attention block with dropout and a residual connection.

    Tokens of ``x`` provide the queries; tokens of ``conditional_x`` provide
    the keys and values. The attended result is projected back to ``q_dim``,
    passed through dropout, and added to ``x``.

    Args:
        q_dim: feature size of ``x``.
        cross_attention_dim: feature size of ``conditional_x``.
        embed_dim: size of EACH head's query/key/value projection; the total
            projection width is ``num_heads * embed_dim``.
        num_heads: number of attention heads.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, q_dim, cross_attention_dim, embed_dim, num_heads=1, dropout=0.2):
        super().__init__()

        self.num_heads = num_heads
        self.head_dim = embed_dim
        total_embed_dim = num_heads * embed_dim
        # Linear transformations for queries, keys, and values
        self.query_linear = nn.Linear(q_dim, total_embed_dim)
        self.key_linear = nn.Linear(cross_attention_dim, total_embed_dim)
        self.value_linear = nn.Linear(cross_attention_dim, total_embed_dim)
        self.dropout = nn.Dropout(dropout)
        # Project back to the *query* width so the residual `output + x` in
        # forward() is valid even when q_dim != cross_attention_dim.
        self.out_linear = nn.Linear(total_embed_dim, q_dim)

    def forward(self, x, conditional_x):
        """Attend from ``x`` to ``conditional_x``.

        Args:
            x: (batch, len_x, q_dim) query tokens.
            conditional_x: (batch, len_c, cross_attention_dim) tokens attended
                over; must have the same batch size as ``x``.

        Returns:
            (batch, len_x, q_dim) tensor: ``x`` plus the attended update.
        """
        batch_size, len_x, _ = x.size()
        _, len_conditional_x, _ = conditional_x.size()

        query = self.query_linear(x)
        key = self.key_linear(conditional_x)
        value = self.value_linear(conditional_x)

        # (batch, len, heads * head_dim) -> (batch, heads, len, head_dim)
        query = query.view(batch_size, len_x, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, len_conditional_x, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, len_conditional_x, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention; scores are (batch, heads, len_x, len_c).
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)
        attended_values = torch.matmul(attention_weights, value)

        # Merge heads back: (batch, len_x, heads * head_dim)
        attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, len_x, -1)

        output = self.dropout(self.out_linear(attended_values))
        return output + x


In [39]:

# Example usage: 2 sequences of 5*(512//8) = 320 "image" tokens attend to
# 64 "text" tokens; every token is an `emp`-dimensional (8) vector.
batch_size = 2
sequence_length = 5 * (512 // 8)  # 320 tokens
emp = 8          # token width (used as both q_dim and cross_attention_dim)
heads = 16
embed_dim = 8    # per-head projection size inside CrossAttentionTransformer


images_emps = torch.rand((batch_size, sequence_length, emp))  # (2, 320, 8)
text_emp    = torch.rand((batch_size, 64, emp))               # (2, 64, 8)

model = CrossAttentionTransformer(q_dim=emp, cross_attention_dim=emp, embed_dim=embed_dim, num_heads=heads)
output = model(images_emps, text_emp)
print(output.shape)  # torch.Size([2, 320, 8]) -- same shape as images_emps


torch.Size([2, 320, 8])


In [34]:
import torch
import torch.nn.functional as F
import torch.nn as nn

class CrossAttentionTransformerMultiLayer(nn.Module):
    """Stack of ``num_layers`` CrossAttentionTransformer blocks.

    Every block updates ``vectors1`` by attending to the same ``vector2``.
    ``dropout`` is forwarded to each block; its default (0.2) matches the
    block's own default, so existing callers see no change.
    """

    def __init__(self, num_layers, q_dim, cross_attention_dim, embed_dim, num_heads=1, dropout=0.2):
        super().__init__()

        self.layers = nn.ModuleList([
            CrossAttentionTransformer(
                q_dim=q_dim,
                cross_attention_dim=cross_attention_dim,
                embed_dim=embed_dim,
                num_heads=num_heads,
                dropout=dropout,
            )
            for _ in range(num_layers)
        ])

    def forward(self, vectors1, vector2):
        """Apply the blocks in order; each adds a residual update to ``vectors1``."""
        for layer in self.layers:
            vectors1 = layer(vectors1, vector2)
        return vectors1


In [35]:
# Example usage
batch_size = 2
sequence_length = 5*(512//8)
emp = 8
heads = 16
embed_dim = 8 
layers = 2

images_emps = torch.rand((2, sequence_length, emp))  # Batch size of 2, sequence length of 10, embedding dimension of 512
text_emp    = torch.rand((2, 64, emp))               # Batch size of 2, sequence length of 8, embedding dimension of 512

model = CrossAttentionTransformerMultiLayer(num_layers=layers,q_dim=emp, cross_attention_dim=emp,embed_dim=embed_dim, num_heads=heads)
output = model(images_emps, text_emp)
print(output.shape)  # Output shape: torch.Size([2, 10, 512])

torch.Size([2, 320, 8])


In [None]:

class CrossAttentionTransformerMultiLayer(nn.Module):
    """Stack of ``num_layers`` CrossAttentionTransformer blocks.

    Every block updates ``x`` by attending to the same ``conditional_x``;
    ``dropout`` is forwarded to each block.
    """

    def __init__(self, num_layers, q_dim, cross_attention_dim, embed_dim, num_heads=1, dropout=0.2):
        super().__init__()

        # Built from CrossAttentionTransformer: the name `CrossAttentionLayer`
        # is not defined anywhere and raised NameError whenever num_layers > 0.
        self.layers = nn.ModuleList([
            CrossAttentionTransformer(
                q_dim=q_dim,
                cross_attention_dim=cross_attention_dim,
                embed_dim=embed_dim,
                num_heads=num_heads,
                dropout=dropout,
            )
            for _ in range(num_layers)
        ])

    def forward(self, x, conditional_x):
        """Apply the blocks in order; each adds a residual update to ``x``."""
        for layer in self.layers:
            x = layer(x, conditional_x)
        return x

class CrossAttentionEncoder(nn.Module):
    """Cross-attention neck that fuses image and text embeddings.

    ``args`` must provide ``att_head_emp``, ``neck_layers``, ``n_heads`` and
    ``neck_dropout``. Image tokens are updated by attending to the text
    tokens; the result is concatenated with the (unchanged) text tokens,
    flattened, and concatenated with ``pos_emps``.
    """

    def __init__(self, args):
        super().__init__()
        self.att_head_emp = args.att_head_emp

        self.encoder = CrossAttentionTransformerMultiLayer(
            num_layers=args.neck_layers,
            q_dim=args.att_head_emp,
            cross_attention_dim=args.att_head_emp,
            embed_dim=args.att_head_emp,
            num_heads=args.n_heads,
            dropout=args.neck_dropout,
        )

        self.flatten = nn.Flatten()

    def forward(self, input_x):
        """
        Args:
            input_x: tuple ``(images_emps, text_emps, pos_emps)``. The image and
                text tensors share the leading batch dimension, and their
                remaining elements are regrouped into tokens of width
                ``att_head_emp``. ``pos_emps`` must be 2-D (batch, k) because
                it is concatenated with the flattened tokens along dim 1.

        Returns:
            (batch, (n_image_tokens + n_text_tokens) * att_head_emp + k) tensor.
        """
        images_emps, text_emps, pos_emps = input_x
        batch_size = images_emps.shape[0]

        # Regroup each embedding into tokens of width att_head_emp.
        images_emps = images_emps.reshape(batch_size, -1, self.att_head_emp)
        text_emps = text_emps.reshape(batch_size, -1, self.att_head_emp)

        images_emps = self.encoder(images_emps, text_emps)

        text_images_embeddings = self.flatten(torch.cat([images_emps, text_emps], dim=1))
        return torch.cat([text_images_embeddings, pos_emps], dim=1)

    def get_opt_params(self):
        """Optimizer parameter groups: only the cross-attention encoder's parameters."""
        return [
            {"params": self.encoder.parameters()}
        ]


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention with per-head shared projections.

    Inputs are split into ``num_heads`` chunks of ``head_dim = embed_size //
    num_heads`` features. The same ``head_dim -> head_dim`` linear maps
    (``values``/``keys``/``queries``) are applied to every head, and
    ``fc_out`` mixes the concatenated heads back to ``embed_size``.
    """

    def __init__(self, embed_size, num_heads):
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        assert (
            self.head_dim * num_heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(num_heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask=None):
        """
        Args:
            values: (N, kv_len, embed_size)
            keys: (N, kv_len, embed_size); must have the same length as ``values``.
            query: (N, q_len, embed_size)
            mask: optional tensor broadcastable to (N, num_heads, q_len, kv_len);
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            (N, q_len, embed_size) tensor.
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into heads: (N, len, heads, head_dim).
        values = self.values(values.reshape(N, value_len, self.num_heads, self.head_dim))
        keys = self.keys(keys.reshape(N, key_len, self.num_heads, self.head_dim))
        queries = self.queries(query.reshape(N, query_len, self.num_heads, self.head_dim))

        # Per-head scores (N, heads, q_len, kv_len), scaled by sqrt(head_dim),
        # the per-head key size. Scaling by sqrt(embed_size) would soften the
        # softmax by an extra factor of sqrt(num_heads).
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) / (self.head_dim ** 0.5)

        if mask is not None:
            # Most negative finite value of the dtype: a literal -1e20 cannot
            # be represented in float16 and makes masked_fill fail there.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=3)

        # Weighted sum of values, then concatenate heads: (N, q_len, embed_size).
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.num_heads * self.head_dim
        )

        return self.fc_out(out)


class CrossAttentionEncoderLayer(nn.Module):
    """Post-norm block: cross-attention then a feed-forward net, each with a residual.

    ``src`` supplies the queries; ``conditional_src`` supplies both the keys
    and the values of the cross-attention.
    """

    def __init__(self, embed_size, num_heads, ff_hidden_size, dropout):
        super().__init__()
        self.cross_attention = CrossAttention(embed_size, num_heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, ff_hidden_size),
            nn.ReLU(),
            nn.Linear(ff_hidden_size, embed_size),
        )
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, conditional_src, src_mask):
        """
        Args:
            src: (N, src_len, embed_size) query sequence.
            conditional_src: (N, cond_len, embed_size) sequence attended over.
            src_mask: mask forwarded to the attention (entries equal to 0 are
                masked out), or None.

        Returns:
            (N, src_len, embed_size) tensor.
        """
        # Keys and values must come from the same sequence (conditional_src).
        # Using src as the values mixes two sequences and fails, or silently
        # broadcasts, when their lengths differ.
        cross_attended_src = self.cross_attention(
            values=conditional_src, keys=conditional_src, query=src, mask=src_mask
        )

        x = self.dropout(self.norm1(cross_attended_src + src))
        ff_out = self.feed_forward(x)
        return self.dropout(self.norm2(ff_out + x))


class CrossAttentionEncoder(nn.Module):
    """Token-id encoder: word + learned position embeddings, followed by a stack
    of CrossAttentionEncoderLayer blocks conditioned on ``conditional_src``.
    """

    def __init__(
        self,
        src_vocab_size,
        embed_size,
        num_layers,
        num_heads,
        ff_hidden_size,
        dropout,
        max_length,
    ):
        super().__init__()

        self.embed_size = embed_size
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                CrossAttentionEncoderLayer(
                    embed_size, num_heads, ff_hidden_size, dropout
                )
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, src, conditional_src, mask):
        """
        Args:
            src: (N, seq_length) integer token ids; seq_length must not exceed
                ``max_length`` (the size of the position table).
            conditional_src: context passed unchanged to every layer.
            mask: passed unchanged to every layer.

        Returns:
            (N, seq_length, embed_size) tensor.
        """
        N, seq_length = src.shape
        # Build the position ids on src's device instead of relying on a
        # notebook-global `device` variable.
        positions = torch.arange(seq_length, device=src.device).expand(N, seq_length)
        src = self.dropout(self.word_embedding(src) + self.position_embedding(positions))

        for layer in self.layers:
            src = layer(src, conditional_src, mask)

        return src


In [70]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention with per-head shared projections.

    Inputs are split into ``num_heads`` chunks of ``head_dim = embed_size //
    num_heads`` features. The same ``head_dim -> head_dim`` linear maps
    (``values``/``keys``/``queries``) are applied to every head, and
    ``fc_out`` mixes the concatenated heads back to ``embed_size``.
    """

    def __init__(self, embed_size, num_heads):
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        assert (
            self.head_dim * num_heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(num_heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask=None):
        """
        Args:
            values: (N, kv_len, embed_size)
            keys: (N, kv_len, embed_size); must have the same length as ``values``.
            query: (N, q_len, embed_size)
            mask: optional tensor broadcastable to (N, num_heads, q_len, kv_len);
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            (N, q_len, embed_size) tensor.
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into heads: (N, len, heads, head_dim).
        values = self.values(values.reshape(N, value_len, self.num_heads, self.head_dim))
        keys = self.keys(keys.reshape(N, key_len, self.num_heads, self.head_dim))
        queries = self.queries(query.reshape(N, query_len, self.num_heads, self.head_dim))

        # Per-head scores (N, heads, q_len, kv_len), scaled by sqrt(head_dim),
        # the per-head key size. Scaling by sqrt(embed_size) would soften the
        # softmax by an extra factor of sqrt(num_heads).
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) / (self.head_dim ** 0.5)

        if mask is not None:
            # Most negative finite value of the dtype: a literal -1e20 cannot
            # be represented in float16 and makes masked_fill fail there.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=3)

        # Weighted sum of values, then concatenate heads: (N, q_len, embed_size).
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.num_heads * self.head_dim
        )

        return self.fc_out(out)


class CrossAttentionEncoderLayer(nn.Module):
    """Post-norm block: cross-attention then a feed-forward net, each with a residual.

    ``src`` supplies the queries; ``conditional_src`` supplies both the keys
    and the values of the cross-attention. The feed-forward net keeps the
    width at ``embed_size``.
    """

    def __init__(self, embed_size, num_heads, dropout):
        super().__init__()
        self.cross_attention = CrossAttention(embed_size, num_heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, embed_size),
            nn.ReLU(),
            nn.Linear(embed_size, embed_size),
        )
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, conditional_src, src_mask=None):
        """
        Args:
            src: (N, src_len, embed_size) query sequence.
            conditional_src: (N, cond_len, embed_size) sequence attended over.
            src_mask: optional mask forwarded to the attention (entries equal
                to 0 are masked out).

        Returns:
            (N, src_len, embed_size) tensor.
        """
        # Keys and values must come from the same sequence (conditional_src).
        # Using src as the values mixes two sequences and fails, or silently
        # broadcasts, when their lengths differ.
        cross_attended_src = self.cross_attention(
            values=conditional_src, keys=conditional_src, query=src, mask=src_mask
        )

        x = self.dropout(self.norm1(cross_attended_src + src))
        ff_out = self.feed_forward(x)
        return self.dropout(self.norm2(ff_out + x))


class CrossAttentionEncoder(nn.Module):
    """Adds learned position embeddings to ``src`` and runs a stack of
    CrossAttentionEncoderLayer blocks conditioned on ``conditional_src``.
    """

    def __init__(
        self,
        embed_size,
        num_layers,
        num_heads,
        dropout,
        max_length,
    ):
        super().__init__()

        self.embed_size = embed_size
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                CrossAttentionEncoderLayer(
                    embed_size, num_heads, dropout
                )
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, src, conditional_src, mask=None):
        """
        Args:
            src: (N, seq_length, embed_size) embeddings; seq_length must not
                exceed ``max_length`` (otherwise the position lookup raises
                IndexError).
            conditional_src: context passed unchanged to every layer.
            mask: optional mask passed unchanged to every layer.

        Returns:
            (N, seq_length, embed_size) tensor.
        """
        N, seq_length, _ = src.shape
        # Create the position ids on src's device so the lookup also works
        # when the model and inputs live on the GPU.
        positions = torch.arange(seq_length, device=src.device).expand(N, seq_length)
        src = self.dropout(src + self.position_embedding(positions))

        for layer in self.layers:
            src = layer(src, conditional_src, mask)

        return src


In [71]:
src      = torch.zeros((2, 5, 512))  # sequence being encoded (queries)
cond_src = torch.zeros((2, 1, 512))  # conditioning sequence (keys/values)


model = CrossAttentionEncoder(embed_size=512, num_layers=2, num_heads=8, dropout=0.1, max_length=5)

# src is the first argument: it receives the position embeddings (its length 5
# fits max_length) and is the sequence that attends to cond_src.
model(src, cond_src).shape


torch.Size([2, 5, 512])

In [6]:
import torch
from transformers import PerceiverModel, PerceiverConfig
class CrossAttentionPerceiver(torch.nn.Module):
    """Feeds the concatenation of two embedding sequences to a PerceiverModel."""

    def __init__(self, config):
        super().__init__()

        self.perceiver = PerceiverModel(config)

    def forward(self, input1, input2, attention_mask1=None, attention_mask2=None):
        """
        Args:
            input1: (batch, len1, d) embeddings of the first sequence.
            input2: (batch, len2, d) embeddings of the second sequence. The
                lengths may differ; batch size and d must match input1.
                d presumably has to equal config.d_model -- TODO confirm.
            attention_mask1: optional (batch, len1) mask for input1.
            attention_mask2: optional (batch, len2) mask for input2.
        Returns:
            outputs: the PerceiverModel outputs.
        """
        # Join along the sequence axis; only batch and feature sizes must agree.
        combined_input = torch.cat([input1, input2], dim=1)

        # If only one mask is given, treat the other sequence as fully visible
        # rather than silently discarding the mask that was provided.
        if attention_mask1 is None and attention_mask2 is None:
            combined_attention_mask = None
        else:
            if attention_mask1 is None:
                attention_mask1 = attention_mask2.new_ones(input1.shape[:2])
            if attention_mask2 is None:
                attention_mask2 = attention_mask1.new_ones(input2.shape[:2])
            combined_attention_mask = torch.cat([attention_mask1, attention_mask2], dim=1)

        # PerceiverModel.forward takes the embeddings as `inputs`; it has no
        # `inputs_embeds` argument.
        return self.perceiver(
            inputs=combined_input,
            attention_mask=combined_attention_mask,
        )

# Example usage:
config = PerceiverConfig()
cross_attention_perceiver = CrossAttentionPerceiver(config)

# Dummy input tensors: batch 1, sequence length 10, feature size read from
# config.d_model (768 for the default config) so the inputs keep matching the
# model's expected width if the config changes.
input1 = torch.randn(1, 10, config.d_model)
input2 = torch.randn(1, 10, config.d_model)

# Forward pass
outputs = cross_attention_perceiver(input1, input2)
print(outputs)


ImportError: cannot import name 'PerceiverModel' from 'transformers' (unknown location)

In [1]:
import torch 

In [2]:
# 512-dim embeddings split over 8 heads. batch_first=True so (batch, seq, embed)
# tensors such as the ones built below are interpreted as intended; with the
# default (False) a (1, 10, 512) input is read as seq_len=1, batch=10.
model = torch.nn.MultiheadAttention(512, 8, batch_first=True)

In [7]:
# Dummy input tensors laid out batch-first: (batch=1, seq_len=10, embed=512).
# embed must equal the MultiheadAttention's embed_dim (512), and the module
# must be built with batch_first=True for this layout to be read as intended.
from torch import nn
input1 = torch.randn(1, 10, 512)  # queries
input2 = torch.randn(1, 10, 512)  # conditioning sequence (keys and values)
position_embedding = nn.Embedding(20, 512)  # supports sequences of up to 20 tokens
positions = torch.arange(input1.shape[1]).expand(input1.shape[0], -1)
input1 = input1 + position_embedding(positions)

# Cross-attention: query from input1; key AND value from input2.
# Returns (attn_output, attn_weights).
model(input1, input2, input2)

IndexError: index out of range in self