<a href="https://colab.research.google.com/github/Murad-pitafi/Computer-Vision/blob/main/Cross_Modal_Data_Generation_Using_VAEs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import torchvision.transforms as T
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import AdamW
from PIL import Image
import random


In [3]:
# Preprocessing shared by training and inference: resize to 224x224, convert
# to a float tensor, then normalise with the ImageNet channel statistics that
# torchvision's pretrained ResNet weights expect.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def collate_captions(batch):
    """Collate ``(image, annotations)`` samples into one training batch.

    The default collate function cannot be used here: every image comes with
    a *list* of annotation dicts, and the default either fails when those
    lists differ in length or transposes them so that entries no longer line
    up with the images.

    Args:
        batch: sequence of ``(image_tensor, annotations)`` pairs as produced
            by ``CocoDetection``; each annotation is expected to be a dict
            with a ``'caption'`` string.

    Returns:
        ``(images, targets)`` where ``images`` is the stacked image tensor
        and ``targets`` holds one ``{'caption': [str, ...]}`` mapping per
        image, in the same order as ``images``.
    """
    images, annotations = zip(*batch)
    targets = [{'caption': [ann['caption'] for ann in anns]} for anns in annotations]
    return torch.stack(images), targets


# NOTE: ``annFile`` must point at a COCO *captions* annotation file; the
# detection ("instances") files carry no 'caption' key.
coco_data = datasets.CocoDetection(root='path_to_coco_images', annFile='path_to_annotations', transform=transform)
train_loader = DataLoader(coco_data, batch_size=16, shuffle=True, collate_fn=collate_captions)

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# GPT-2 ships without a padding token; reuse EOS so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token


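# NOTE: The COCO dataset is large and is not bundled with this notebook; replace the placeholder paths above ('path_to_coco_images', 'path_to_annotations') with the locations of your own copy.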

In [None]:
class ImageEncoder(nn.Module):
    """Image-to-latent encoder built on a pretrained ResNet-18.

    The network's final fully connected layer is replaced by a linear
    projection with ``latent_dim`` outputs.
    """

    def __init__(self, latent_dim=256):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Swap the classification head for a projection to the latent size.
        head_inputs = backbone.fc.in_features
        backbone.fc = nn.Linear(head_inputs, latent_dim)
        self.cnn = backbone

    def forward(self, x):
        """Return the backbone's output for the image batch ``x``."""
        latent = self.cnn(x)
        return latent


In [None]:
class TextDecoder(nn.Module):
    """GPT-2 caption decoder conditioned on a latent vector.

    The latent vector is projected to GPT-2's embedding width and handed to
    the language model as a one-step ``encoder_hidden_states`` sequence that
    the cross-attention layers attend to.
    """

    def __init__(self, latent_dim=256):
        super(TextDecoder, self).__init__()
        # GPT-2 raises when given ``encoder_hidden_states`` unless it was
        # built with cross-attention layers. Those layers are not part of the
        # pretrained checkpoint, so they start from fresh weights and are
        # learned during training.
        self.model = GPT2LMHeadModel.from_pretrained('gpt2', add_cross_attention=True)
        self.fc = nn.Linear(latent_dim, self.model.config.n_embd)

    def forward(self, latent_vec, input_ids, attention_mask=None):
        """Run the language model on ``input_ids`` conditioned on ``latent_vec``.

        Args:
            latent_vec: latent tensor whose last dimension is ``latent_dim``;
                for a ``(batch, latent_dim)`` input the conditioning context
                becomes ``(batch, 1, n_embd)``.
            input_ids: caption token ids; also used as the training labels.
            attention_mask: optional mask aligned with ``input_ids`` (1 for
                real tokens, 0 for padding). When given, it is forwarded to
                GPT-2 and padded positions are excluded from the loss.

        Returns:
            The GPT-2 model output, including ``loss`` and ``logits``.
        """
        context = self.fc(latent_vec).unsqueeze(1)
        labels = input_ids
        if attention_mask is not None:
            # -100 is the label value the language-model loss ignores.
            labels = input_ids.masked_fill(attention_mask == 0, -100)
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=context,
            labels=labels,
        )


In [None]:
class CrossModalVAE(nn.Module):
    """Image-to-caption model: an ``ImageEncoder`` feeding a ``TextDecoder``.

    The encoder output is handed to the decoder unchanged; this class applies
    no sampling step or KL term of its own.
    """

    def __init__(self, latent_dim=256):
        super().__init__()
        self.encoder = ImageEncoder(latent_dim)
        self.decoder = TextDecoder(latent_dim)

    def forward(self, images, captions):
        """Encode ``images`` and decode ``captions`` against the result.

        Returns whatever the decoder returns for ``(latent, captions)``.
        """
        return self.decoder(self.encoder(images), captions)


In [None]:
latent_dim = 256
# The model has to live on the same device as the batches it is fed; use the
# GPU when one is available and fall back to the CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move the model first so the optimizer is built over the on-device parameters.
model = CrossModalVAE(latent_dim=latent_dim).to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)
# Kept for interactive use; the model output already carries its own loss.
criterion = nn.CrossEntropyLoss()


In [None]:
epochs = 5
# Send every batch to whichever device the model's parameters live on, so the
# inputs and the weights can never end up on different devices.
device = next(model.parameters()).device
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for images, targets in train_loader:
        images = images.to(device)
        # One caption per target: random.choice draws from ann['caption'], so
        # that entry should be a list of caption strings (a bare string would
        # yield a single character instead).
        captions = [random.choice(ann['caption']) for ann in targets]
        encoded_caps = tokenizer(captions, return_tensors='pt', padding=True, truncation=True)
        input_ids = encoded_caps.input_ids.to(device)

        optimizer.zero_grad()
        output = model(images, input_ids)
        # The language-model loss is computed inside the model's forward pass.
        loss = output.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    # Guard the average against an empty loader.
    num_batches = max(len(train_loader), 1)
    print(f'Epoch {epoch+1}, Loss: {total_loss / num_batches:.4f}')


In [None]:
model.eval()
# Run inference on whichever device the model's parameters live on.
device = next(model.parameters()).device
test_image = Image.open('path_to_test_image').convert('RGB')
test_image = transform(test_image).unsqueeze(0).to(device)

# Upper bound on the caption length, in generated tokens.
max_caption_tokens = 30

with torch.no_grad():
    latent = model.encoder(test_image)
    # The latent is a float vector, not token ids, so it cannot be passed to
    # ``generate`` as the prompt. Project it the same way TextDecoder.forward
    # does and condition GPT-2 on it through ``encoder_hidden_states``.
    context = model.decoder.fc(latent).unsqueeze(1)
    # Start decoding from GPT-2's beginning-of-sequence token.
    generated = torch.full((1, 1), tokenizer.bos_token_id, dtype=torch.long, device=device)
    # Greedy decoding: append the most likely next token until EOS or the cap.
    for _ in range(max_caption_tokens):
        logits = model.decoder.model(input_ids=generated, encoder_hidden_states=context).logits
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)
        if next_token.item() == tokenizer.eos_token_id:
            break

generated_caption = tokenizer.decode(generated[0], skip_special_tokens=True)
print(f'Generated Caption: {generated_caption}')
