<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Building_a_Simple_Multi_Modal_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from transformers import BertModel

# Shared sizes used by the example model below:
#   text_encoder_dim  - width of the BERT hidden state (768 for bert-base)
#   image_encoder_dim - width of the embedding produced by ImageEncoder
#   num_classes       - number of logits emitted by the classifier head
text_encoder_dim, image_encoder_dim, num_classes = 768, 512, 10

# Example Text Encoder (BERT)
class TextEncoder(nn.Module):
    """Encode token ids into one vector per sequence with a pretrained BERT.

    The sequence embedding is the final hidden state at position 0. That
    position is the ``[CLS]`` token when inputs come from a BERT tokenizer.

    Args:
        model_name: Hugging Face checkpoint passed to
            ``BertModel.from_pretrained``. Defaults to ``"bert-base-uncased"``.
    """

    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)

    def forward(self, text, attention_mask=None):
        """Return the position-0 final hidden state for each sequence.

        Args:
            text: Token ids, presumably a ``(batch, seq_len)`` integer tensor.
            attention_mask: Optional mask of the same shape (1 = real token,
                0 = padding). Pass it for padded batches so padding tokens do
                not influence the pooled representation. When ``None``, BERT
                attends to every position (the previous behaviour).

        Returns:
            Tensor of shape ``(batch, hidden_size)``.
        """
        outputs = self.bert(input_ids=text, attention_mask=attention_mask)
        # Position 0 holds [CLS] for tokenizer-produced input; its final hidden
        # state is BERT's conventional sequence-level summary.
        return outputs.last_hidden_state[:, 0, :]

# Example Image Encoder (Simple CNN)
class ImageEncoder(nn.Module):
    """Map 3-channel images to fixed-size embeddings with a one-layer CNN.

    Architecture: 3x3 conv (3 -> 16 channels, stride 2, padding 1), ReLU,
    flatten, then a linear projection to ``out_dim``.

    The linear layer's input width depends on the spatial size of the images.
    It is therefore derived from ``input_size`` rather than hard-coded; the
    defaults reproduce the original layer sizes for 32x32 inputs.

    Args:
        input_size: Height/width of the square input images. Default 32.
        out_dim: Embedding width. When ``None``, uses the module-level
            ``image_encoder_dim``.

    Raises:
        ValueError: If ``input_size`` is not positive.
    """

    def __init__(self, input_size=32, out_dim=None):
        super().__init__()
        if input_size < 1:
            raise ValueError(f"input_size must be positive, got {input_size}")
        if out_dim is None:
            out_dim = image_encoder_dim
        # Conv2d output side = floor((H + 2*pad - k) / stride) + 1; with k=3,
        # pad=1 and stride=2 this is (H + 1) // 2 (32 -> 16, as originally).
        conv_side = (input_size + 1) // 2
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(16 * conv_side * conv_side, out_dim),
        )

    def forward(self, image):
        """Embed a batch of images.

        Args:
            image: Tensor of shape ``(batch, 3, input_size, input_size)``.

        Returns:
            Tensor of shape ``(batch, out_dim)``.
        """
        return self.cnn(image)

# MultiModal Model
class MultiModalModel(nn.Module):
    """Fuse text and image embeddings by concatenation and classify them.

    Each encoder produces an embedding. The two embeddings are concatenated
    along the feature axis and passed through a single linear layer.

    Args:
        text_encoder: Module mapping the text input to ``(batch, text_dim)``.
        image_encoder: Module mapping the image input to ``(batch, image_dim)``.
        text_dim: Width of the text embedding. When ``None``, uses the
            module-level ``text_encoder_dim``.
        image_dim: Width of the image embedding. When ``None``, uses the
            module-level ``image_encoder_dim``.
        n_classes: Number of output logits. When ``None``, uses the
            module-level ``num_classes``.
    """

    def __init__(
        self,
        text_encoder,
        image_encoder,
        text_dim=None,
        image_dim=None,
        n_classes=None,
    ):
        super().__init__()
        # Fall back to the module-level example sizes so the existing
        # two-argument construction keeps its original layer shapes.
        if text_dim is None:
            text_dim = text_encoder_dim
        if image_dim is None:
            image_dim = image_encoder_dim
        if n_classes is None:
            n_classes = num_classes
        self.text_encoder = text_encoder
        self.image_encoder = image_encoder
        self.fc = nn.Linear(text_dim + image_dim, n_classes)

    def forward(self, text, image):
        """Return class logits for a batch of (text, image) pairs.

        Args:
            text: Input accepted by ``text_encoder``.
            image: Input accepted by ``image_encoder``.

        Returns:
            Logits of shape ``(batch, n_classes)``.
        """
        text_embedding = self.text_encoder(text)
        image_embedding = self.image_encoder(image)
        # Concatenate on the last (feature) axis; for 2-D embeddings this is
        # the same as dim=1.
        combined = torch.cat([text_embedding, image_embedding], dim=-1)
        return self.fc(combined)

# Seed before building the model so both the randomly initialised image
# encoder and the random example inputs are reproducible.
torch.manual_seed(0)

# Instantiate encoders and multi-modal model
text_encoder = TextEncoder()
image_encoder = ImageEncoder()
model = MultiModalModel(text_encoder, image_encoder)
# Inference demo: put every submodule (e.g. BERT's dropout) in eval mode.
model.eval()

# Example inputs
batch_size, seq_len = 8, 64
# Random ids stand in for tokenizer output; all are below the
# bert-base-uncased vocabulary size (30522).
text_input = torch.randint(0, 10000, (batch_size, seq_len))
# TextEncoder pools position 0, which a BERT tokenizer fills with [CLS]
# (id 101 in bert-base-uncased); mimic that here.
text_input[:, 0] = 101
image_input = torch.randn(batch_size, 3, 32, 32)  # 32x32 RGB images

# Forward pass without building an autograd graph (not needed for a demo).
with torch.no_grad():
    outputs = model(text_input, image_input)
print(outputs.shape)  # (batch_size, num_classes)
print(outputs)