# 1. Import Necessary Libraries

In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel
from torchvision import models, transforms
from sklearn.metrics import classification_report
from PIL import Image
import pandas as pd

# 2. Dataset Preparation

In [3]:
class MemeDataset(Dataset):
    """Pair each meme image with its tokenized transcription and label.

    Each row of ``df`` is read for an ``image_id`` and a ``transcriptions``
    column, and optionally a ``labels`` column. Rows without ``labels``
    (e.g. the test split) get the placeholder label ``-1``.

    Args:
        df: pandas DataFrame with one row per meme.
        image_folder: Directory containing ``<image_id>.jpg`` files.
        tokenizer: Callable tokenizer (called with ``max_length``, ``padding``,
            ``truncation`` and ``return_tensors='pt'``).
        max_len: Fixed sequence length; text is padded/truncated to it.
        transform: Optional callable applied to the RGB ``PIL.Image``.
    """

    def __init__(self, df, image_folder, tokenizer, max_len, transform=None):
        self.df = df
        self.image_folder = image_folder
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.df)

    @staticmethod
    def _clean_text(text):
        """Return *text* as a string, mapping missing values (NaN/None/NA) to ''.

        pandas reads empty CSV cells as float NaN, which the tokenizer rejects.
        Non-string scalars (e.g. a purely numeric transcription parsed as int)
        are converted with ``str``.
        """
        if isinstance(text, str):
            return text
        if pd.api.types.is_scalar(text) and pd.isna(text):
            return ""
        return str(text)

    def __getitem__(self, idx):
        """Return a dict with ``image``, ``text_inputs`` and ``label`` for row *idx*."""
        row = self.df.iloc[idx]
        image_path = f"{self.image_folder}/{row['image_id']}.jpg"
        text = self._clean_text(row['transcriptions'])
        label = row.get('labels', -1)  # -1 for test set (no labels column)

        # Text Preprocessing
        text_inputs = self.tokenizer(
            text,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Image Preprocessing: open in a context manager so the file handle is
        # released immediately; convert() returns a new image object that stays
        # usable after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)

        return {
            'image': image,
            # return_tensors='pt' yields a leading batch dim of 1 for a single
            # string; drop it so the DataLoader can stack samples itself.
            'text_inputs': {k: v.squeeze(0) for k, v in text_inputs.items()},
            'label': torch.tensor(label, dtype=torch.long) if label != -1 else label
        }

# 3. Model Definition

In [4]:
class MultimodalModel(nn.Module):
    """Late-fusion classifier over MuRIL text features and ResNet-50 image features.

    The text encoder's pooled output is projected to ``proj_dim``. The
    ImageNet-pretrained ResNet-50 has its final ``fc`` layer replaced by a
    ``proj_dim`` projection. The two vectors are concatenated and classified
    by a small MLP.

    Args:
        muril_model_name: Model id or path passed to ``AutoModel.from_pretrained``.
        num_labels: Number of output classes.
        proj_dim: Width of each modality's projection (default 256).
        hidden_dim: Width of the classifier's hidden layer (default 128).
        dropout: Dropout probability inside the classifier (default 0.3).
    """

    def __init__(self, muril_model_name, num_labels, proj_dim=256, hidden_dim=128, dropout=0.3):
        super().__init__()

        # MuRIL for Malayalam text
        self.muril_model = AutoModel.from_pretrained(muril_model_name)
        self.muril_fc = nn.Linear(self.muril_model.config.hidden_size, proj_dim)

        # Image branch: pretrained ResNet-50 with its head swapped for a projection.
        self.image_model = self._load_resnet50()
        self.image_model.fc = nn.Linear(self.image_model.fc.in_features, proj_dim)

        # Combined classifier (module indices match the original layout, so
        # state_dict keys are unchanged).
        self.classifier = nn.Sequential(
            nn.Linear(2 * proj_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_labels),
        )

    @staticmethod
    def _load_resnet50():
        """Return an ImageNet-pretrained ResNet-50.

        ``pretrained=True`` is deprecated since torchvision 0.13 in favour of
        the ``weights`` enum; ``IMAGENET1K_V1`` is the checkpoint that
        ``pretrained=True`` selected. Older torchvision has no enum, so the
        legacy flag is used there.
        """
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is None:
            return models.resnet50(pretrained=True)
        return models.resnet50(weights=weights_enum.IMAGENET1K_V1)

    def forward(self, text_inputs, image):
        """Return classification logits for a batch.

        Args:
            text_inputs: Dict of tensors forwarded as keyword arguments to the
                text encoder (presumably input_ids / attention_mask from the
                paired tokenizer — confirm against the caller).
            image: Batch of preprocessed images (assumed (batch, 3, H, W)).
        """
        # Text features from MuRIL
        muril_outputs = self.muril_model(**text_inputs)
        muril_features = self.muril_fc(muril_outputs.pooler_output)

        # Image features
        image_features = self.image_model(image)

        # Combine features along the feature axis and classify.
        combined_features = torch.cat((muril_features, image_features), dim=1)
        return self.classifier(combined_features)

# 4. Training Function Definition

In [None]:
def train_model(model, dataloader, optimizer, criterion, device):
    """Run one training epoch, print a classification report, return mean loss.

    For each batch the function does the forward pass, computes the loss,
    backpropagates and steps the optimizer. Predicted classes (argmax over
    the logits) and true labels are accumulated for the epoch's report.

    Args:
        model: Module called as ``model(text_inputs, images)`` returning logits.
        dataloader: Yields dicts with ``image``, ``text_inputs`` and ``label``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss taking ``(logits, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        The average of the per-batch losses.
    """
    model.train()
    total_loss = 0
    all_preds = []
    all_labels = []
    for batch in dataloader:
        images = batch['image'].to(device)
        text_inputs = {k: v.to(device) for k, v in batch['text_inputs'].items()}
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        outputs = model(text_inputs, images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Collect predictions and labels for classification report
        preds = torch.argmax(outputs, dim=1)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

    # Classification report. zero_division=0 reports 0 instead of warning when
    # a class is never predicted (common in early epochs).
    print("Train Classification Report:")
    print(classification_report(all_labels, all_preds, zero_division=0))

    return total_loss / len(dataloader)

# 5. Validation Function Definition

In [None]:
def validate_model(model, dataloader, criterion, device):
    """Evaluate the model on labeled data, print a classification report, return mean loss.

    The model runs in eval mode without gradient tracking. Predicted classes
    (argmax over the logits) and true labels are accumulated for the report.

    Args:
        model: Module called as ``model(text_inputs, images)`` returning logits.
        dataloader: Yields dicts with ``image``, ``text_inputs`` and ``label``.
        criterion: Loss taking ``(logits, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        The average of the per-batch losses.
    """
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            images = batch['image'].to(device)
            text_inputs = {k: v.to(device) for k, v in batch['text_inputs'].items()}
            labels = batch['label'].to(device)

            outputs = model(text_inputs, images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            # Collect predictions and labels for classification report
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    # Classification report. zero_division=0 reports 0 instead of warning when
    # a class is never predicted.
    print("Validation Classification Report:")
    print(classification_report(all_labels, all_preds, zero_division=0))

    return total_loss / len(dataloader)

# 6. Prediction Function Definition

In [None]:
def predict(model, dataloader, device):
    """Return the argmax class index for every sample yielded by *dataloader*.

    The model is put in eval mode and run with gradient tracking disabled.
    Only the ``image`` and ``text_inputs`` entries of each batch are read, so
    unlabeled data works.

    Returns:
        A flat list of predicted class indices (NumPy integer scalars), in
        dataloader order.
    """
    model.eval()
    collected = []
    with torch.no_grad():
        for batch in dataloader:
            pixel_batch = batch['image'].to(device)
            token_batch = {
                name: tensor.to(device)
                for name, tensor in batch['text_inputs'].items()
            }
            logits = model(token_batch, pixel_batch)
            batch_preds = logits.argmax(dim=1).cpu().numpy()
            collected += list(batch_preds)
    return collected

# 7. Main Function Definition

In [None]:
def main(
    train_csv="train.csv",
    dev_csv="dev.csv",
    test_csv="test.csv",
    train_image_folder="train_images",
    dev_image_folder="dev_images",
    test_image_folder="test_images",
    output_csv="test_predictions.csv",
):
    """Train the multimodal classifier, validate each epoch, and save test predictions.

    Args:
        train_csv, dev_csv, test_csv: CSV files read by ``MemeDataset`` (columns
            ``image_id`` and ``transcriptions``, plus ``labels`` for train/dev).
        train_image_folder, dev_image_folder, test_image_folder: Directories
            holding ``<image_id>.jpg`` for each split.
        output_csv: Path where the test DataFrame, with a predicted ``labels``
            column, is written.

    NOTE: the default paths are placeholders relative to the working
    directory — point them at the actual dataset location.
    """
    muril_model_name = "google/muril-base-cased"
    max_len = 128
    batch_size = 16
    num_epochs = 45
    lr = 2e-5
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Data Loader
    tokenizer = AutoTokenizer.from_pretrained(muril_model_name)
    # Images must become tensors: the default DataLoader collate cannot batch
    # PIL images. The ImageNet mean/std suit the ImageNet-pretrained ResNet-50
    # image branch.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    train_df = pd.read_csv(train_csv)
    dev_df = pd.read_csv(dev_csv)
    test_df = pd.read_csv(test_csv)

    train_dataset = MemeDataset(train_df, train_image_folder, tokenizer, max_len, transform)
    dev_dataset = MemeDataset(dev_df, dev_image_folder, tokenizer, max_len, transform)
    test_dataset = MemeDataset(test_df, test_image_folder, tokenizer, max_len, transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Model Initialization
    model = MultimodalModel(muril_model_name, num_labels=2).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # `verbose` is deprecated on ReduceLROnPlateau; the LR is printed below instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

    # Train and validate model
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        train_loss = train_model(model, train_loader, optimizer, criterion, device)
        val_loss = validate_model(model, dev_loader, criterion, device)
        scheduler.step(val_loss)
        print(f"Train Loss = {train_loss:.4f}, Validation Loss = {val_loss:.4f}")
        print(f"Learning rate = {optimizer.param_groups[0]['lr']:.2e}")

    # Predict on test set
    predictions = predict(model, test_loader, device)
    test_df['labels'] = predictions

    # Prediction saving in the file
    test_df.to_csv(output_csv, index=False)
    print(f"Saved {len(predictions)} test predictions to {output_csv}")

In [None]:
# Script entry point: run the full train/validate/predict pipeline.
# Inside a notebook kernel __name__ is also "__main__", so executing this
# cell starts training too.
if __name__ == "__main__":
    main()