In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset
from transformers import ViTModel, BertModel, BertTokenizer
from sklearn.model_selection import train_test_split

In [None]:
class CoTransformer(nn.Module):
    """Cross-attention block: `query` attends over `key`/`value`, then residual + LayerNorm.

    Wraps nn.MultiheadAttention with its default batch_first=False, so inputs are
    sequence-first: query (L_q, batch, dim), key/value (L_kv, batch, dim). The output
    has the same shape as `query`.

    Args:
        dim: embedding size of query, key and value; must be divisible by `num_heads`.
        num_heads: number of attention heads (default 8, the previously hard-coded value).
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, num_heads=num_heads)
        self.linear = nn.Linear(dim, dim)
        self.layer_norm = nn.LayerNorm(dim)

    def forward(self, query, key, value):
        """Return LayerNorm(Linear(attention(query, key, value)) + query)."""
        # The attention weights are discarded, so don't ask MultiheadAttention to compute them.
        attn_output, _ = self.attention(query, key, value, need_weights=False)
        # Residual connection back to the query stream keeps the output aligned with `query`.
        co_transformed_repr = self.layer_norm(self.linear(attn_output) + query)
        return co_transformed_repr

In [None]:
vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
bert_model = BertModel.from_pretrained('bert-base-uncased')

# Both backbones are used as frozen feature extractors: the optimizer is built from the
# classifier's parameters only, so backbone gradients would never be applied (nor zeroed).
# Turning them off avoids building autograd graphs through the backbones and accumulating
# stale .grad tensors. A for-loop is used so the cell does not display a huge module repr.
for backbone in (vit_model, bert_model):
    backbone.requires_grad_(False)

In [None]:
class MyModel(nn.Module):
    """Two-way co-attention classifier over image features and question features.

    Each modality attends to the other through its own CoTransformer. The two attended
    sequences have different lengths (image patches vs. question tokens), so each is
    reduced to its first position before the two vectors are concatenated and fed to
    a small MLP head.

    Args:
        num_classes: number of output logits.
        dim: shared hidden size of the image and question features.
    """

    def __init__(self, num_classes, dim):
        super().__init__()
        # Question stream attends over the image stream, and vice versa.
        self.co_transformer_vit_to_bert = CoTransformer(dim=dim)
        self.co_transformer_bert_to_vit = CoTransformer(dim=dim)

        self.layer1 = nn.Sequential(
            nn.Linear(in_features=dim * 2, out_features=512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.3),
        )

        self.layer2 = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            # Must match the 256 features produced by the Linear above (was 512 -> crash).
            nn.BatchNorm1d(256),
            nn.Dropout(0.3),
        )

        self.output_layer = nn.Linear(in_features=256, out_features=num_classes)

    def forward(self, img_query, img_key, img_value, ques_query, ques_key, ques_value):
        """Compute class logits.

        All inputs are sequence-first, (seq_len, batch, dim), the layout expected by
        nn.MultiheadAttention with batch_first=False. Image and question sequence
        lengths may differ.

        Returns:
            Logits of shape (batch, num_classes).
        """
        # Question tokens attend over image patches; image patches attend over question tokens.
        joint_repr_vit_to_bert = self.co_transformer_vit_to_bert(ques_query, img_key, img_value)
        joint_repr_bert_to_vit = self.co_transformer_bert_to_vit(img_query, ques_key, ques_value)

        # Pool each attended sequence to one vector per example by taking position 0
        # (the [CLS] position for ViT/BERT hidden states). Un-pooled, the two sequences
        # cannot be concatenated along features (different lengths), and BatchNorm1d
        # needs (batch, features) input.
        pooled_vit_to_bert = joint_repr_vit_to_bert[0]
        pooled_bert_to_vit = joint_repr_bert_to_vit[0]
        combined_repr = torch.cat((pooled_vit_to_bert, pooled_bert_to_vit), dim=-1)

        x = self.layer1(combined_repr)
        x = self.layer2(x)
        output = self.output_layer(x)
        return output

In [None]:
class MyDataset(Dataset):
    """Map-style Dataset over an in-memory sequence of (image, question, answer) records.

    Args:
        data: any sized, indexable collection whose items unpack into exactly
            three values; it is stored as-is (no copy).
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        """Number of records in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, index):
        """Return record `index` as an (image, question, answer) tuple.

        Unpacking (rather than returning the raw item) normalizes list records to
        tuples and raises ValueError for records without exactly three fields.
        """
        record = self.data[index]
        img, ques, ans = record
        return (img, ques, ans)

In [None]:
# Synthetic stand-in for a real VQA dataset: (image tensor, question string, class label).
# Keep enough samples that the 80/20 split yields a training batch of more than one
# example -- BatchNorm1d raises in training mode on a batch of size 1.
N_DUMMY_SAMPLES = 16
DUMMY_QUESTIONS = ("What is this?", "What color is this?")
dummy_generator = torch.Generator().manual_seed(0)  # reproducible random images
dummy_data = [
    # Question i % 2 is paired with label i % 2, as in the original two examples.
    (torch.randn(3, 224, 224, generator=dummy_generator), DUMMY_QUESTIONS[i % 2], i % 2)
    for i in range(N_DUMMY_SAMPLES)
]

In [None]:
# Split dataset into training and validation sets; seeded so Restart & Run All
# reproduces the same split.
train_data, val_data = train_test_split(dummy_data, test_size=0.2, random_state=42)

# Define data loaders
train_loader = DataLoader(MyDataset(train_data), batch_size=32, shuffle=True)
val_loader = DataLoader(MyDataset(val_data), batch_size=32, shuffle=False)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
num_classes: int = 2  # answer labels in the dummy data are 0 and 1

In [None]:
# ViT-base and BERT-base both produce 768-dimensional hidden states.
model = MyModel(dim=768, num_classes=num_classes)

# Loss on raw logits vs. integer class labels; only the classifier's parameters are optimized.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
num_epochs = 10


def encode_batch(images, questions):
    """Encode a batch with the pretrained backbones; return sequence-first hidden states.

    The backbones act as fixed feature extractors (the optimizer only holds `model`'s
    parameters), so they run under no_grad: no autograd graph is built through them and
    nothing accumulates in their never-zeroed .grad buffers.

    Args:
        images: pixel tensor from the DataLoader, presumably (batch, 3, 224, 224).
        questions: the batch's question strings (a sequence of str from default collate).

    Returns:
        (img_states, ques_states): the ViT and BERT `last_hidden_state`s, each transposed
        from (batch, seq_len, hidden) to (seq_len, batch, hidden) -- the layout
        nn.MultiheadAttention expects with its default batch_first=False.
    """
    with torch.no_grad():
        img_states = vit_model(pixel_values=images).last_hidden_state
        question_tokens = tokenizer(list(questions), return_tensors='pt', padding=True, truncation=True)
        ques_states = bert_model(**question_tokens).last_hidden_state
    return img_states.transpose(0, 1), ques_states.transpose(0, 1)


def run_classifier(img_states, ques_states):
    """Run `model`, using each modality's hidden states as its own query, key and value.

    (The previous `.split(1, dim=-1)` produced one chunk per hidden unit and could not be
    unpacked into three tensors; MultiheadAttention applies its own Q/K/V projections.)
    """
    return model(img_states, img_states, img_states,
                 ques_states, ques_states, ques_states)


for epoch in range(num_epochs):
    # --- Training ---
    model.train()
    train_loss = 0.0
    for images, questions, answers in train_loader:
        img_states, ques_states = encode_batch(images, questions)

        optimizer.zero_grad()
        output = run_classifier(img_states, ques_states)
        loss = criterion(output, answers)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # --- Validation ---
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, questions, answers in val_loader:
            img_states, ques_states = encode_batch(images, questions)
            output = run_classifier(img_states, ques_states)

            val_loss += criterion(output, answers).item()

            predicted = output.argmax(dim=1)
            total += answers.size(0)
            correct += (predicted == answers).sum().item()

    # Guard against empty loaders so a tiny split does not raise ZeroDivisionError.
    avg_train_loss = train_loss / max(len(train_loader), 1)
    avg_val_loss = val_loss / max(len(val_loader), 1)
    val_accuracy = 100.0 * correct / total if total else float('nan')
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, "
          f"Val Loss: {avg_val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%")