In [48]:
import torch
import torch.nn as nn
from transformers import RobertaModel, RobertaTokenizer, ViTModel
import pandas as pd

In [49]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

class HatefulMemeDataset(Dataset):
    def __init__(self, texts, images, rationales, img_transform=None, max_length=512):
        self.texts = texts
        self.images = images
        self.rationales = rationales
        # self.tokenizer = tokenizer
        self.img_transform = img_transform
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        rationale = self.rationales[idx]
        img = Image.open(self.images[idx])

        if self.img_transform:
            img = self.img_transform(img)

        # Tokenize text and rationale
        # text_tokenized = self.tokenizer(text, max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt')
        # rationale_tokenized = self.tokenizer(rationale, max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt')

        # Return items
        return {
            'text': text,
            'image': img,
            'rationale': rationale
        }

def create_trainloader(texts, images, rationales, batch_size=8, max_length=512):
    img_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])


    dataset = HatefulMemeDataset(texts, images, rationales, img_transform, max_length)
    trainloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    return trainloader


# Example usage:
# texts = ["Meme text 1", "Meme text 2", ...]
# images = ["path/to/image1.jpg", "path/to/image2.jpg", ...]
# rationales = ["Rationale for meme 1", "Rationale for meme 2", ...]
# tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
# trainloader = create_trainloader(texts, images, rationales, tokenizer)


In [50]:
df1 = pd.read_csv('train.csv')
df2 = pd.read_csv('rationale.csv')

texts = df1['ocr'][:50].tolist()
images = df2['image_path'][:50].tolist()
rationales = df2['rationale'][:50].to_list()

# tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
trainloader = create_trainloader(texts, images, rationales)

In [51]:
for batch in trainloader:
    print(batch)


{'text': ['✓\nTweet\nNot an IlTian\n@Not_an_IITian\n@ShaadiDotCom\nloses two of its\ncustomer, every time it fulfills its\npurpose of existence.\n7:17 PM Dec 12, 2020 Twitter Web App\n000', 'Met mees\ndash\ndewan s\nWhen everyone is talking about lockdown 4\nAnd you are mentally prepared for lockdown 10', "Discovering something that doesn't exist\nPeople\nmassively\nimmigrating\nto socialist\ncountries", 'urganabedanaxtpras\nSTATES\nKamala Harris\n@KamalaHarris\nWe did it, @JoeBiden.', 'The EU\nPIXAR\nProblem with\nInmigration\nRols in France\nMomes\nBreat\nComption of\nMembers of\nthe EU\nAre article 13 memes dead yet?', 'Allerhang\nMujhe exam nahi dena hai', 'Aisi Koi Jagah Nahi\nJahan Me Gaya Nahi\nBhai Kabhi Sasural Gaye Ho?', 'TOI\nwmein car\n(Acur gyehr ram poal\nOur other'], 'image': tensor([[[[ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  2.2489,  2.2489],
          [ 2.2489,  2.2489,  2.2489,  ...,  2.2489,  

In [52]:
class MultimodalHatefulMemeClassifier(nn.Module):
    def __init__(self, text_hidden_dim=768, img_hidden_dim=768, decoder_hidden_dim=512, vocab_size=50265):
        super(MultimodalHatefulMemeClassifier, self).__init__()

        # Text Encoder (RoBERTa)
        self.text_encoder = RobertaModel.from_pretrained('roberta-base')
        self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

        # Image Encoder (Vision Transformer - ViT)
        self.img_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224')

        # Fusion layer
        self.fusion = nn.Linear(text_hidden_dim + img_hidden_dim, decoder_hidden_dim)

        # Classification Head
        self.classification_head = nn.Linear(decoder_hidden_dim, 2)  # Binary classification: hateful or not

        # Decoder (Transformer Decoder for Reasoning)
        self.decoder = nn.TransformerDecoderLayer(d_model=decoder_hidden_dim, nhead=8)
        self.fc_out = nn.Linear(decoder_hidden_dim, vocab_size)

    def encode_text(self, text):
        inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
        outputs = self.text_encoder(**inputs)
        return outputs.last_hidden_state[:, 0, :]  # Take the [CLS] token representation

    def encode_image(self, image):
        outputs = self.img_encoder(pixel_values=image)
        return outputs.last_hidden_state[:, 0, :]  # Take the [CLS] token representation

    def generate_rationale(self, fused_features, max_length=50):
        # Initialize the target sequence with the start token
        tgt_seq = torch.tensor([[self.tokenizer.cls_token_id]])  # <s> token

        rationale_output = []

        for _ in range(max_length):
            tgt_emb = self.fc_out(tgt_seq)  # Embed the current sequence
            output = self.decoder(tgt_emb, fused_features)
            output = self.fc_out(output)

            # Get the last token's logits and predict the next token
            next_token_logits = output[:, -1, :]
            next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)

            # Append predicted token to rationale output
            rationale_output.append(next_token_id.item())

            # Break if the end token is generated
            if next_token_id == self.tokenizer.eos_token_id:
                break

            # Append the predicted token to the tgt_seq for the next iteration
            tgt_seq = torch.cat((tgt_seq, next_token_id.unsqueeze(0)), dim=1)

        return torch.tensor(rationale_output).unsqueeze(0)  # Return as a tensor

    def forward(self, text_features, image_features, ground_truth_rationale=None, max_length=50):
        # Encode text and image
    
        # Fuse the features
        fused_features = torch.cat((text_features, image_features), dim=1)
        fused_features = self.fusion(fused_features).unsqueeze(0)

        # Classification output (hateful or non-hateful)
        classification_output = self.classification_head(fused_features)

        # Generate rationale
        generated_rationale = self.generate_rationale(fused_features, max_length=max_length)

        # If ground truth rationale is provided, calculate the loss
        rationale_loss = None
        if ground_truth_rationale is not None:
            criterion = nn.CrossEntropyLoss()
            rationale_loss = criterion(generated_rationale.view(-1), ground_truth_rationale.view(-1))

        return classification_output, generated_rationale, rationale_loss

In [53]:
import torch.optim as optim 

def train_model(model, trainloader, num_epochs=10, learning_rate=1e-4, device='cuda'):
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion_classification = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for batch in trainloader:
            # Ensure that all inputs are moved to the specified device
            text = batch['text']
            image = batch['image']
            rationale = batch['rationale']
            
            text_features = encode_text(text)
            image_features = self.encode_image(image)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            classification_output, generated_rationale, rationale_loss = model(text, image, rationale)

            # Classification loss
            classification_labels = torch.zeros(classification_output.size(0), dtype=torch.long).to(device)
            classification_loss = criterion_classification(classification_output, classification_labels)

            # Total loss
            total_loss = classification_loss + rationale_loss

            # Backward pass and optimization
            total_loss.backward()
            optimizer.step()

            # Print statistics
            running_loss += total_loss.item()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(trainloader):.4f}")

    print("Training completed.")


In [None]:
# Example usage:

# model = MultimodalHatefulMemeClassifier()
# classification_output, generated_rationale, rationale_loss = model(text, image, ground_truth_rationale)
# print(classification_output.shape, generated_rationale.shape, rationale_loss)