In [7]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
from transformers import RobertaTokenizer, ViTFeatureExtractor
import torch
import torch.nn as nn
from transformers import RobertaModel, ViTModel, T5ForConditionalGeneration
import torch.optim as optim
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm

In [1]:


class MemeDataset(Dataset):
    def __init__(self, csv_file, image_dir, tokenizer, feature_extractor, max_text_length=128):
        """
        Args:
            csv_file (str): Path to the CSV file with image paths, OCR text, and LLaVA rationales.
            image_dir (str): Directory with all the meme images.
            tokenizer (transformers tokenizer): Tokenizer for OCR text.
            feature_extractor (transformers feature extractor): Feature extractor for images.
            max_text_length (int): Maximum length for text tokens.
        """
        self.data = pd.read_csv(csv_file)
        self.image_dir = image_dir
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.max_text_length = max_text_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data.iloc[idx]
        image_path = os.path.join(self.image_dir, item['image_name'])
        ocr_text = item['ocr_text']
        target_explanation = item['llava_rationale']  # LLaVA-generated rationale

        # Process image
        image = Image.open(image_path).convert("RGB")
        image = self.feature_extractor(images=image, return_tensors="pt")['pixel_values'].squeeze()

        # Process text
        text_encoding = self.tokenizer(
            ocr_text,
            padding='max_length',
            truncation=True,
            max_length=self.max_text_length,
            return_tensors="pt"
        )

        # Process target explanation
        target_encoding = self.tokenizer(
            target_explanation,
            padding='max_length',
            truncation=True,
            max_length=self.max_text_length,
            return_tensors="pt"
        )

        return {
            'image': image,
            'ocr_text_input_ids': text_encoding['input_ids'].squeeze(),
            'ocr_text_attention_mask': text_encoding['attention_mask'].squeeze(),
            'target_ids': target_encoding['input_ids'].squeeze(),
            'target_attention_mask': target_encoding['attention_mask'].squeeze(),
        }

# Example usage:
# tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
# feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
# dataset = MemeDataset('path_to_csv.csv', 'path_to_images/', tokenizer, feature_extractor)
# dataloader = DataLoader(dataset, batch_size=8, shuffle=True)


In [4]:
class MultimodalExplanationModel(nn.Module):
    def __init__(self, text_model_name='roberta-base', vision_model_name='google/vit-base-patch16-224', t5_model_name='t5-small'):
        super(MultimodalExplanationModel, self).__init__()
        
        # Text and Vision encoders
        self.text_encoder = RobertaModel.from_pretrained(text_model_name)
        self.vision_encoder = ViTModel.from_pretrained(vision_model_name)
        
        # Decoder for generating explanations
        self.decoder = T5ForConditionalGeneration.from_pretrained(t5_model_name)

        # Linear layer to combine the two modalities
        self.fc = nn.Linear(self.text_encoder.config.hidden_size + self.vision_encoder.config.hidden_size, self.decoder.config.d_model)
    
    def forward(self, ocr_text_input_ids, ocr_text_attention_mask, image, target_ids=None, target_attention_mask=None):
        # Textual features
        text_outputs = self.text_encoder(input_ids=ocr_text_input_ids, attention_mask=ocr_text_attention_mask)
        text_features = text_outputs.last_hidden_state[:, 0, :]  # Use the <s> token (CLS token) for classification
        
        # Visual features
        vision_outputs = self.vision_encoder(pixel_values=image)
        vision_features = vision_outputs.last_hidden_state[:, 0, :]  # Use the [CLS] token for classification
        
        # Combine features
        combined_features = torch.cat((text_features, vision_features), dim=1)
        combined_features = self.fc(combined_features)
        
        # Generate explanation
        decoder_outputs = self.decoder(
            input_ids=target_ids,
            attention_mask=target_attention_mask,
            encoder_outputs=(combined_features.unsqueeze(1),),
            labels=target_ids
        )

        return decoder_outputs.loss, decoder_outputs.logits

In [9]:
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
model = MultimodalExplanationModel()
optimizer = AdamW(model.parameters(), lr=5e-5)

# Data
dataset = MemeDataset('./DATA/train.csv', './DATA/trainImages', tokenizer, feature_extractor)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
for i in tqdm(dataloader):
    print(i)

  0%|          | 0/875 [00:00<?, ?it/s]


KeyError: 'image_name'

In [6]:


# Training loop
model.train()
for epoch in range(10):  # Set the number of epochs
    epoch_loss = 0
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        
        # Get data
        ocr_text_input_ids = batch['ocr_text_input_ids']
        ocr_text_attention_mask = batch['ocr_text_attention_mask']
        image = batch['image']
        target_ids = batch['target_ids']
        target_attention_mask = batch['target_attention_mask']
        
        # Forward pass
        loss, logits = model(
            ocr_text_input_ids=ocr_text_input_ids,
            ocr_text_attention_mask=ocr_text_attention_mask,
            image=image,
            target_ids=target_ids,
            target_attention_mask=target_attention_mask
        )
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    print(f"Epoch {epoch+1}/{10} - Loss: {epoch_loss/len(dataloader)}")

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

NameError: name 'AdamW' is not defined

In [None]:
model.eval()
with torch.no_grad():
    for batch in dataloader:
        ocr_text_input_ids = batch['ocr_text_input_ids']
        ocr_text_attention_mask = batch['ocr_text_attention_mask']
        image = batch['image']

        generated_ids = model.decoder.generate(
            input_ids=None, 
            encoder_outputs=(model.fc(torch.cat(
                (model.text_encoder(ocr_text_input_ids, attention_mask=ocr_text_attention_mask).last_hidden_state[:, 0, :],
                 model.vision_encoder(pixel_values=image).last_hidden_state[:, 0, :]), dim=1)).unsqueeze(1),),
            max_length=150
        )

        generated_explanation = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        print("Generated Explanation:", generated_explanation)


In [None]:
# Save the model
torch.save(model.state_dict(), 'multimodal_explanation_model.pth')

# Load the model
model.load_state_dict(torch.load('multimodal_explanation_model.pth'))
