In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
from transformers import RobertaTokenizer, ViTFeatureExtractor
import torch
import torch.nn as nn
from transformers import RobertaModel, ViTModel, T5ForConditionalGeneration
import torch.optim as optim
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm

In [2]:
class MemeDataset(Dataset):
    def __init__(self, csv_file, csv_file2, image_dir, tokenizer, feature_extractor, max_text_length=128):
        """
        Args:
            csv_file (str): Path to the CSV file with image paths and LLaVA rationales.
            csv_file2 (str): Path to the CSV file with OCR text.
            image_dir (str): Directory with all the meme images.
            tokenizer (transformers tokenizer): Tokenizer for OCR text.
            feature_extractor (transformers feature extractor): Feature extractor for images.
            max_text_length (int): Maximum length for text tokens.
        """
        self.data = pd.read_csv(csv_file)  # CSV with image paths and rationales
        self.data2 = pd.read_csv(csv_file2)  # CSV with OCR text
        self.image_dir = image_dir
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.max_text_length = max_text_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Extract relevant data using `idx`
        image_path = os.path.join(self.image_dir, str(self.data.iloc[idx]['image_path']))  # Get image path for idx
        ocr_text = str(self.data2.iloc[idx]['ocr'])  # Get OCR text for idx
        target_explanation = str(self.data.iloc[idx]['rationale'])  # Get LLaVA rationale for idx

        # Process image
        image = Image.open(image_path).convert("RGB")
        image = self.feature_extractor(images=image, return_tensors="pt")['pixel_values'].squeeze()

        # Process text (OCR text)
        text_encoding = self.tokenizer(
            ocr_text,
            padding='max_length',
            truncation=True,
            max_length=self.max_text_length,
            return_tensors="pt"
        )

        # Process target explanation (rationale)
        target_encoding = self.tokenizer(
            target_explanation,
            padding='max_length',
            truncation=True,
            max_length=self.max_text_length,
            return_tensors="pt"
        )

        return {
            'image': image,
            'ocr_text_input_ids': text_encoding['input_ids'].squeeze(),
            'ocr_text_attention_mask': text_encoding['attention_mask'].squeeze(),
            'target_ids': target_encoding['input_ids'].squeeze(),
            'target_attention_mask': target_encoding['attention_mask'].squeeze(),
        }

# Example usage:
# tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
# feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
# dataset = MemeDataset('path_to_csv.csv', 'path_to_csv2.csv', 'path_to_images/', tokenizer, feature_extractor)
# dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

In [8]:
class MultimodalExplanationModel(nn.Module):
    def __init__(self, text_model_name='roberta-base', vision_model_name='google/vit-base-patch16-224', t5_model_name='t5-small'):
        super(MultimodalExplanationModel, self).__init__()
        
        # Text and Vision encoders
        self.text_encoder = RobertaModel.from_pretrained(text_model_name)
        self.vision_encoder = ViTModel.from_pretrained(vision_model_name)
        
        # Decoder for generating explanations
        self.decoder = T5ForConditionalGeneration.from_pretrained(t5_model_name)

        # Linear layer to combine the two modalities
        self.fc = nn.Linear(self.text_encoder.config.hidden_size + self.vision_encoder.config.hidden_size, self.decoder.config.d_model)

    def forward(self, ocr_text_input_ids, ocr_text_attention_mask, image, target_ids=None, target_attention_mask=None):
    # Textual features
        text_outputs = self.text_encoder(input_ids=ocr_text_input_ids, attention_mask=ocr_text_attention_mask)
        text_features = text_outputs.last_hidden_state[:, 0, :]  # Use the <s> token (CLS token) for classification
        
        # Visual features
        vision_outputs = self.vision_encoder(pixel_values=image)
        vision_features = vision_outputs.last_hidden_state[:, 0, :]  # Use the [CLS] token for classification
        
        # Combine features
        combined_features = torch.cat((text_features, vision_features), dim=1)
        combined_features = self.fc(combined_features)
        
        # Repeat the combined features to create a pseudo-sequence
        repeated_features = combined_features.unsqueeze(1).repeat(1, target_ids.size(1), 1)  # Repeat along the sequence length
        
        # Generate explanation
        decoder_outputs = self.decoder(
            input_ids=target_ids,
            attention_mask=target_attention_mask,
            encoder_outputs=(repeated_features,),  # Pass the repeated features as the encoder output
            labels=target_ids
        )

        return decoder_outputs.loss, decoder_outputs.logits

    
    # def forward(self, ocr_text_input_ids, ocr_text_attention_mask, image, target_ids=None, target_attention_mask=None):
    #     # Textual features
    #     text_outputs = self.text_encoder(input_ids=ocr_text_input_ids, attention_mask=ocr_text_attention_mask)
    #     text_features = text_outputs.last_hidden_state[:, 0, :]  # Use the <s> token (CLS token) for classification
        
    #     # Visual features
    #     vision_outputs = self.vision_encoder(pixel_values=image)
    #     vision_features = vision_outputs.last_hidden_state[:, 0, :]  # Use the [CLS] token for classification
        
    #     # Combine features
    #     combined_features = torch.cat((text_features, vision_features), dim=1)
    #     combined_features = self.fc(combined_features)
        
    #     # Generate explanation
    #     decoder_outputs = self.decoder(
    #         input_ids=target_ids,
    #         attention_mask=target_attention_mask,
    #         encoder_outputs=(combined_features.unsqueeze(1),),
    #         labels=target_ids
    #     )

    #     return decoder_outputs.loss, decoder_outputs.logits

In [4]:
%pip install sentencepiece

Note: you may need to restart the kernel to use updated packages.


In [9]:
# tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
from transformers import T5Tokenizer

# Ensure the tokenizer is correctly set up
tokenizer = T5Tokenizer.from_pretrained('t5-small')

feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
model = MultimodalExplanationModel()
optimizer = AdamW(model.parameters(), lr=5e-5)

# Data
dataset = MemeDataset('./DATA/scripts/rationale.csv','./DATA/scripts/train.csv', './DATA', tokenizer, feature_extractor)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# for batch in tqdm(dataloader):
#         optimizer.zero_grad()
        
#         # Get data
#         print(
#         ocr_text_input_ids = batch['ocr_text_input_ids'],
#         ocr_text_attention_mask = batch['ocr_text_attention_mask'],
#         image = batch['image'],
#         target_ids = batch['target_ids'],
#         )

**sample rationale**
Generated Explanation: Target Group or Person: The meme targets a specific individual, Narendra Modi, who is a political figure. It references his actions and the implication that he is not in a position of influence. Content Evaluation: The text is potentially offensive due to its reference to a political figure's actions and the use of a political figure to satirize the situation. Context and Implications: The context is political commentary on the political climate, which may be seen as a political commentary on the political climate. Overall Assessment: The meme uses humor to comment on political issues and political

In [10]:
# Training loop
model.train()
for epoch in range(10):  # Set the number of epochs
    epoch_loss = 0
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        
        # Get data
        ocr_text_input_ids = batch['ocr_text_input_ids']
        ocr_text_attention_mask = batch['ocr_text_attention_mask']
        image = batch['image']
        target_ids = batch['target_ids']
        target_attention_mask = batch['target_attention_mask']
        
        # Forward pass
        loss, logits = model(
            ocr_text_input_ids=ocr_text_input_ids,
            ocr_text_attention_mask=ocr_text_attention_mask,
            image=image,
            target_ids=target_ids,
            target_attention_mask=target_attention_mask
        )
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    print(f"Epoch {epoch+1}/{10} - Loss: {epoch_loss/len(dataloader)}")

100%|██████████| 875/875 [1:03:47<00:00,  4.37s/it]


Epoch 1/10 - Loss: 3.57098759160723


100%|██████████| 875/875 [1:06:38<00:00,  4.57s/it]


Epoch 2/10 - Loss: 2.8444256003243584


100%|██████████| 875/875 [1:02:35<00:00,  4.29s/it]


Epoch 3/10 - Loss: 2.705521892820086


100%|██████████| 875/875 [49:42<00:00,  3.41s/it]


Epoch 4/10 - Loss: 2.6137118225097655


100%|██████████| 875/875 [49:47<00:00,  3.41s/it]


Epoch 5/10 - Loss: 2.5509440326690673


100%|██████████| 875/875 [49:44<00:00,  3.41s/it]


Epoch 6/10 - Loss: 2.497531508854457


100%|██████████| 875/875 [49:38<00:00,  3.40s/it]


Epoch 7/10 - Loss: 2.4435541902269637


100%|██████████| 875/875 [49:40<00:00,  3.41s/it]


Epoch 8/10 - Loss: 2.389111866269793


100%|██████████| 875/875 [50:03<00:00,  3.43s/it]


Epoch 9/10 - Loss: 2.3292890570504325


100%|██████████| 875/875 [49:41<00:00,  3.41s/it]

Epoch 10/10 - Loss: 2.26340590436118





In [None]:
from transformers.modeling_outputs import BaseModelOutput

model.eval()
with torch.no_grad():
    for batch in dataloader:
        ocr_text_input_ids = batch['ocr_text_input_ids']
        ocr_text_attention_mask = batch['ocr_text_attention_mask']
        image = batch['image']

        # Get text and vision features
        text_features = model.text_encoder(ocr_text_input_ids, attention_mask=ocr_text_attention_mask).last_hidden_state[:, 0, :]
        vision_features = model.vision_encoder(pixel_values=image).last_hidden_state[:, 0, :]
        
        # Concatenate and pass through the linear layer
        combined_features = model.fc(torch.cat((text_features, vision_features), dim=1))
        
        # Repeat features to simulate sequence and wrap in BaseModelOutput
        repeated_features = combined_features.unsqueeze(1).repeat(1, 1, 1)
        encoder_outputs = BaseModelOutput(last_hidden_state=repeated_features)

        # Generate explanation using T5 generate method
        generated_ids = model.decoder.generate(
            input_ids=None, 
            encoder_outputs=encoder_outputs,
            max_length=150
        )

        generated_explanation = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        print("Generated Explanation:", generated_explanation)


In [13]:
# Save the model
torch.save(model.state_dict(), 'multimodal_explanation_model.pth')

# Load the model
model.load_state_dict(torch.load('multimodal_explanation_model.pth'))


<All keys matched successfully>