<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Example_Cross_Attention_for_Text_Image_Integration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, CLIPProcessor, CLIPModel
from PIL import Image
import requests
from io import BytesIO

class CrossAttentionModel(nn.Module):
    """Fuse text and image features with a single cross-attention layer.

    Text token features are the attention queries; image features are the
    keys and values, so each text token gathers information from the image.

    Args:
        text_model: Encoder called as ``text_model(input_ids=...,
            attention_mask=...)`` whose output exposes ``last_hidden_state``.
        image_model: Encoder called as ``image_model(pixel_values=...)`` whose
            output exposes ``last_hidden_state``. A composite model that
            carries its image encoder in a ``vision_model`` attribute (for
            example a joint text+image CLIP model) is also accepted; that
            sub-module is used for the forward pass.
        embed_dim: Feature size handed to ``nn.MultiheadAttention``. Both
            encoders must produce hidden states of this size.
        num_heads: Number of attention heads; must divide ``embed_dim``.
    """

    def __init__(
        self,
        text_model: nn.Module,
        image_model: nn.Module,
        embed_dim: int = 768,
        num_heads: int = 8,
    ) -> None:
        super().__init__()
        self.text_model = text_model
        self.image_model = image_model
        # batch_first keeps its default (False): the layer expects
        # sequence-first tensors, hence the transposes in forward().
        self.cross_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)

    def _image_encoder(self) -> nn.Module:
        """Return the module that maps ``pixel_values`` to hidden states."""
        # A joint model cannot be called with pixel_values alone and does not
        # return last_hidden_state; its vision tower can. Resolving this at
        # call time leaves self.image_model (and state_dict keys) untouched.
        return getattr(self.image_model, "vision_model", self.image_model)

    def forward(
        self,
        text_inputs: torch.Tensor,
        text_attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run both encoders and attend from text tokens to image features.

        Args:
            text_inputs: Token ids passed to the text encoder as ``input_ids``.
            text_attention_mask: Attention mask passed to the text encoder.
            pixel_values: Image tensor passed to the image encoder.

        Returns:
            The attention output, sequence-first:
            ``(text_seq_len, batch, embed_dim)``.
        """
        text_outputs = self.text_model(input_ids=text_inputs, attention_mask=text_attention_mask)
        # (batch, seq_len, embed_dim) -> (seq_len, batch, embed_dim)
        text_feats = text_outputs.last_hidden_state.transpose(0, 1)

        image_outputs = self._image_encoder()(pixel_values=pixel_values)
        # (batch, seq_len, embed_dim) -> (seq_len, batch, embed_dim)
        image_feats = image_outputs.last_hidden_state.transpose(0, 1)

        # Query = text, key/value = image. The averaged attention weights
        # returned alongside the output are not needed here.
        attn_output, _ = self.cross_attention(text_feats, image_feats, image_feats)
        return attn_output

# Example usage
# Load a text model and image model from Hugging Face
text_model = AutoModel.from_pretrained("bert-base-uncased")
image_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

# Initialize the CrossAttentionModel with the text and image models.
# CLIPModel is the joint text+image model: it cannot be called with
# pixel_values alone and its output has no last_hidden_state. Its vision
# tower can, and its hidden size (768) matches BERT-base and the model's
# default embed_dim.
cross_attention_model = CrossAttentionModel(text_model, image_model.vision_model)
# Inference only: switch off dropout so the printed output is reproducible.
cross_attention_model.eval()

# Example input
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
text_input = tokenizer("This is an example sentence.", return_tensors="pt")

# Using a valid image URL
image_url = "https://images.unsplash.com/photo-1516637090014-3aceb6cd25b3"  # Replace with a valid image URL
# Bound the wait and fail loudly on HTTP errors; otherwise an error page
# would be handed to PIL and surface as an opaque "cannot identify image".
response = requests.get(image_url, timeout=30)
response.raise_for_status()
# Normalise to 3-channel RGB so grayscale/RGBA/palette images also work.
image = Image.open(BytesIO(response.content)).convert("RGB")

image_input = processor(images=image, return_tensors="pt")["pixel_values"]

# Forward pass through the cross-attention model (no gradients needed).
with torch.no_grad():
    output = cross_attention_model(text_input["input_ids"], text_input["attention_mask"], image_input)

print(output)