In [4]:
import torch
import clip
from PIL import Image
import numpy as np


device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

# Sanity check: score a uniformly black 224x224 RGB image against three text prompts.
image = preprocess(Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "darkness"]).to(device)
print(image.min(), image.max())  # value range of the preprocessed image tensor
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    print(image_features.shape, text_features.shape)
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

# The result depends on the input image; for this black image the recorded run printed
# [[0.7686 0.06018 0.1714]] (one probability per prompt, in prompt order).
print("Label probs:", probs)

tensor(-1.7923, device='cuda:0') tensor(-1.4802, device='cuda:0')
torch.Size([1, 512]) torch.Size([3, 512])
Label probs: [[0.7686  0.06018 0.1714 ]]


In [5]:
# Show the model names reported by the clip package; "ViT-B/16" and "RN50x4", which are passed to clip.load in this notebook, both appear in this list.
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [7]:
# Load the RN50x4 CLIP variant and display its preprocessing pipeline.
# NOTE(review): this rebinds `model` and `preprocess`, replacing the ViT-B/16 pair loaded
# in the first cell; use distinct names if both models are needed afterwards.
model, preprocess = clip.load("RN50x4", device=device)
preprocess  # last expression -> rendered as the cell output

 92%|███████████████████████████████████▉   | 371M/402M [02:24<00:08, 3.82MiB/s]

: 

In [36]:
import torch
import torch.nn.functional as F
import torch.nn as nn

class CrossAttentionTransformer(nn.Module):
    """Single cross-attention block with dropout and a residual connection.

    Queries are projected from ``x``; keys and values are projected from
    ``conditional_x``. Each of the ``num_heads`` heads works in its own
    ``embed_dim``-wide subspace (heads are concatenated, so the total width is
    ``num_heads * embed_dim``). The attended result is projected back to
    ``q_dim`` so it can be added to ``x``.

    Args:
        q_dim: feature size of ``x`` (the query stream).
        cross_attention_dim: feature size of ``conditional_x``.
        embed_dim: per-head feature size.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the projected output before
            the residual add.
    """

    def __init__(self, q_dim, cross_attention_dim, embed_dim, num_heads=1, dropout=0.2):
        super().__init__()

        self.num_heads = num_heads
        # Each head gets the full embed_dim; heads are concatenated, not split.
        self.head_dim = embed_dim
        total_embed_dim = num_heads * embed_dim

        # Linear projections: queries from x, keys/values from conditional_x.
        self.query_linear = nn.Linear(q_dim, total_embed_dim)
        self.key_linear = nn.Linear(cross_attention_dim, total_embed_dim)
        self.value_linear = nn.Linear(cross_attention_dim, total_embed_dim)
        self.dropout = nn.Dropout(dropout)
        # Project back to the *query* width: forward() adds this result to x,
        # so it must have q_dim features. Projecting to cross_attention_dim
        # broke the residual add whenever q_dim != cross_attention_dim.
        self.out_linear = nn.Linear(total_embed_dim, q_dim)

    def forward(self, x, conditional_x):
        """Attend from ``x`` to ``conditional_x`` and add the result to ``x``.

        Args:
            x: query tensor of shape (batch, len_x, q_dim).
            conditional_x: conditioning tensor of shape
                (batch, len_conditional_x, cross_attention_dim); assumed to
                have the same batch size as ``x``.

        Returns:
            Tensor of shape (batch, len_x, q_dim).
        """
        batch_size, len_x, _ = x.size()
        _, len_conditional_x, _ = conditional_x.size()

        query = self.query_linear(x)
        key = self.key_linear(conditional_x)
        value = self.value_linear(conditional_x)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, len_x, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, len_conditional_x, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, len_conditional_x, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores: (batch, heads, len_x, len_conditional_x)
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)
        attended_values = torch.matmul(attention_weights, value)

        # (batch, heads, len_x, head_dim) -> (batch, len_x, heads * head_dim)
        attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, len_x, -1)

        output = self.dropout(self.out_linear(attended_values))
        return output + x


In [37]:

# Example usage: smoke-test a single cross-attention block on random inputs.
batch_size = 2
sequence_length = 5 * (512 // 8)  # 320 query tokens
emp = 8                           # feature size of both input streams
heads = 16
embed_dim = 8                     # per-head width

images_emps = torch.rand((batch_size, sequence_length, emp))  # (2, 320, 8) query tokens
text_emp = torch.rand((batch_size, 64, emp))                  # (2, 64, 8) conditioning tokens

model = CrossAttentionTransformer(q_dim=emp, cross_attention_dim=emp, embed_dim=embed_dim, num_heads=heads)
output = model(images_emps, text_emp)
assert output.shape == (batch_size, sequence_length, emp)
print(output.shape)  # torch.Size([2, 320, 8])


RuntimeError: shape '[2, 320, 16, 8]' is invalid for input of size 16384

In [34]:
import torch
import torch.nn.functional as F
import torch.nn as nn

class CrossAttentionTransformerMultiLayer(nn.Module):
    """Stack of ``num_layers`` CrossAttentionTransformer blocks applied in sequence.

    Every layer attends from the running query stream to the same
    conditioning tensor.

    Args:
        num_layers: number of stacked blocks (0 makes forward() return its
            first input unchanged).
        q_dim, cross_attention_dim, embed_dim, num_heads: forwarded to every
            CrossAttentionTransformer layer.
        dropout: dropout probability forwarded to every layer; the default of
            0.2 matches CrossAttentionTransformer's own default, so existing
            callers see no change.
    """

    def __init__(self, num_layers, q_dim, cross_attention_dim, embed_dim, num_heads=1, dropout=0.2):
        super().__init__()

        self.layers = nn.ModuleList([
            CrossAttentionTransformer(
                q_dim=q_dim,
                cross_attention_dim=cross_attention_dim,
                embed_dim=embed_dim,
                num_heads=num_heads,
                dropout=dropout,
            )
            for _ in range(num_layers)
        ])

    def forward(self, vectors1, vector2):
        """Run every layer in order.

        Args:
            vectors1: query stream, fed through each layer in turn.
            vector2: conditioning tensor, passed unchanged to every layer.

        Returns:
            The query stream after the last layer.
        """
        for layer in self.layers:
            vectors1 = layer(vectors1, vector2)
        return vectors1


In [35]:
# Example usage: smoke-test a stack of cross-attention blocks on random inputs.
batch_size = 2
sequence_length = 5 * (512 // 8)  # 320 query tokens
emp = 8                           # feature size of both input streams
heads = 16
embed_dim = 8                     # per-head width
layers = 2

images_emps = torch.rand((batch_size, sequence_length, emp))  # (2, 320, 8) query tokens
text_emp = torch.rand((batch_size, 64, emp))                  # (2, 64, 8) conditioning tokens

model = CrossAttentionTransformerMultiLayer(
    num_layers=layers, q_dim=emp, cross_attention_dim=emp, embed_dim=embed_dim, num_heads=heads
)
output = model(images_emps, text_emp)
assert output.shape == (batch_size, sequence_length, emp)
print(output.shape)  # torch.Size([2, 320, 8])

torch.Size([2, 320, 8])
