## MMTB

## Import

In [1]:
import glob
import json
import logging
import os
import random

from PIL import Image, ImageFile

import numpy as np
import torch

from sklearn.metrics import f1_score
from sklearn.metrics import precision_recall_fscore_support
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from memotion_utility import load_data

import transformers
from transformers import (
    WEIGHTS_NAME,
    AdamW,
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    MMBTConfig,
    MMBTForClassification,
    get_linear_schedule_with_warmup,
)
from transformers.trainer_utils import is_main_process

## Configs

In [2]:
ImageFile.LOAD_TRUNCATED_IMAGES = True
CSV_FILE = '/kaggle/input/memotion-dataset-7k/memotion_dataset_7k/labels.csv'
ROOT_DIR = '/kaggle/input/memotion-dataset-7k/memotion_dataset_7k/images'
MAX_LEN = 512
LR = 1e-6
batch_size = 16
IMAGE_SIZE = (224,224)
epochs = 5
NUM_WARMUP_STEPS = 12
NUM_TRAINING_STEPS = 1230
downsample = True
caption = None
tokenizer_name = 'bert-base-uncased'
num_image_embeds = 1

## Load data

In [3]:
df_train,df_val,df_test = load_data(CSV_FILE,downsample = downsample,captions = caption)

train : 
 label
1    1953
0    1953
Name: count, dtype: int64
val : 
 label
1    343
0    217
Name: count, dtype: int64
test : 
 label
1    856
0    543
Name: count, dtype: int64


In [4]:
import json
import os
import pandas as pd
from collections import Counter

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch import nn
from torch.utils.data import Dataset



POOLING_BREAKDOWN = {1: (1, 1), 2: (2, 1), 3: (3, 1), 4: (2, 2), 5: (5, 1), 6: (3, 2), 7: (7, 1), 8: (4, 2), 9: (3, 3)}


class ImageEncoder(nn.Module):
    def __init__(self, num_image_embeds):
        super().__init__()
        model = torchvision.models.resnet152(pretrained=True)
        modules = list(model.children())[:-2]
        self.model = nn.Sequential(*modules)
        self.pool = nn.AdaptiveAvgPool2d(POOLING_BREAKDOWN[num_image_embeds])

    def forward(self, x):
        # Bx3x224x224 -> Bx2048x7x7 -> Bx2048xN -> BxNx2048
        out = self.pool(self.model(x))
        out = torch.flatten(out, start_dim=2)
        out = out.transpose(1, 2).contiguous()
        return out  # BxNx2048


class HatefulMemesData(Dataset):
    def __init__(self, df, tokenizer, transforms, max_seq_length, data_dir):
        self.df = df
        self.tokenizer = tokenizer
        self.transforms = transforms
        self.max_seq_length = max_seq_length
        self.data_dir = data_dir


    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        
        example = self.df.iloc[index]
        sentence = torch.LongTensor(self.tokenizer.encode(example["text"], add_special_tokens=True))
        start_token, sentence, end_token = sentence[0], sentence[1:-1], sentence[-1]
        sentence = sentence[:self.max_seq_length]

        label = torch.tensor(example["label"], dtype=torch.float)

        image = Image.open(os.path.join(self.data_dir, example["image_name"])).convert("RGB")
        image = self.transforms(image)
        
        inputs = {
            "modal_start_tokens": start_token,
            "modal_end_tokens": end_token,
            "input_ids": sentence,
            "input_modal":image,
            "labels": label,
        }

#         if self.print_text:
#             for k in inputs.keys():
#                 print(k, inputs[k].shape, inputs[k].dtype)

        return inputs

    def get_label_frequencies(self):
        label_freqs = Counter()
        for label in self.data["label"]:
            label_freqs.update([label])
        return label_freqs


def collate_fn(batch):
    #print(batch)
    lens = [len(row["input_ids"]) for row in batch]
    bsz, max_seq_len = len(batch), max(lens)

    mask_tensor = torch.zeros(bsz, max_seq_len, dtype=torch.long)
    text_tensor = torch.zeros(bsz, max_seq_len, dtype=torch.long)

    for i_batch, (input_row, length) in enumerate(zip(batch, lens)):
        text_tensor[i_batch, :length] = input_row["input_ids"]
        mask_tensor[i_batch, :length] = 1

    img_tensor = torch.stack([row["input_modal"] for row in batch])
    tgt_tensor = torch.stack([row["labels"] for row in batch])
    img_start_token = torch.stack([row["modal_start_tokens"] for row in batch])
    img_end_token = torch.stack([row["modal_end_tokens"] for row in batch])

    return text_tensor, mask_tensor, img_tensor, img_start_token, img_end_token, tgt_tensor


def get_labels():
    return [
        "not offensive",
        "offensive"
    ]


def get_image_transforms():
    return transforms.Compose(
        [
            transforms.Resize(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.46777044, 0.44531429, 0.40661017],
                std=[0.12221994, 0.12145835, 0.14380469],
            ),
        ]
    )

In [5]:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [6]:
transforms = get_image_transforms()
dataset_train = HatefulMemesData(df_train, tokenizer, transforms, MAX_LEN,ROOT_DIR)
dataset_val = HatefulMemesData(df_val, tokenizer, transforms, MAX_LEN,ROOT_DIR)
dataset_test = HatefulMemesData(df_test, tokenizer, transforms, MAX_LEN,ROOT_DIR)

In [7]:
dataset_train[0]

{'modal_start_tokens': tensor(101),
 'modal_end_tokens': tensor(102),
 'input_ids': tensor([2292, 2053, 2028, 2425, 2017, 2017, 2024, 4895, 5714, 6442, 4630]),
 'input_modal': tensor([[[-3.7631, -3.7631, -3.7631,  ..., -2.3192, -2.3192, -2.2871],
          [-3.7631, -3.7631, -3.7631,  ..., -2.3192, -2.3192, -2.2551],
          [-3.7631, -3.7631, -3.7631,  ..., -2.2871, -2.2871, -2.2551],
          ...,
          [-3.7952, -3.7952, -3.7952,  ..., -3.0251, -2.9610, -2.9930],
          [-3.7952, -3.7952, -3.7952,  ..., -3.0572, -2.9610, -3.0251],
          [-3.7952, -3.7952, -3.7952,  ..., -3.1535, -3.0893, -3.1856]],
 
         [[-3.6018, -3.6018, -3.6018,  ..., -2.9238, -2.9238, -2.9238],
          [-3.6018, -3.6018, -3.6018,  ..., -2.8915, -2.8915, -2.8915],
          [-3.6018, -3.6018, -3.6018,  ..., -2.8915, -2.8592, -2.8915],
          ...,
          [-3.6341, -3.6341, -3.6341,  ..., -3.0529, -3.0852, -3.0852],
          [-3.6341, -3.6341, -3.6341,  ..., -3.0852, -3.0852, -3.1175],


In [8]:
transformer_config = AutoConfig.from_pretrained(tokenizer_name)
transformer = AutoModel.from_pretrained(tokenizer_name)
img_encoder = ImageEncoder(num_image_embeds)

config = MMBTConfig(transformer_config, num_labels = 2)
model = MMBTForClassification(config, transformer, img_encoder)

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Downloading: "https://download.pytorch.org/models/resnet152-394f9c45.pth" to /root/.cache/torch/hub/checkpoints/resnet152-394f9c45.pth
100%|██████████| 230M/230M [00:02<00:00, 92.6MB/s]


In [9]:
print(MMBTForClassification.__doc__)


    MMBT Model with a sequence classification/regression head on top (a linear layer on top of the pooled output)
    
    MMBT model was proposed in [Supervised Multimodal Bitransformers for Classifying Images and
    Text](https://github.com/facebookresearch/mmbt) by Douwe Kiela, Suvrat Bhooshan, Hamed Firooz, Davide Testuggine.
    It's a supervised multimodal bitransformer model that fuses information from text and other image encoders, and
    obtain state-of-the-art performance on various multimodal classification benchmark tasks.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter re

In [10]:
# Import necessary libraries
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

# Create data loaders
train_sampler = RandomSampler(dataset_train)
train_dataloader = DataLoader(dataset_train, sampler=train_sampler, batch_size=batch_size, collate_fn=collate_fn)

eval_sampler = SequentialSampler(dataset_val)
eval_dataloader = DataLoader(dataset_val, sampler=eval_sampler, batch_size=batch_size, collate_fn=collate_fn)

# Initialize optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=0.01)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=NUM_WARMUP_STEPS, num_training_steps=NUM_TRAINING_STEPS)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = nn.BCEWithLogitsLoss()

#model.config.use_return_dict = False

# Training loop
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for step, batch in enumerate(train_dataloader):
        # Unpack the batch and move to device
        batch = tuple(t.to(device) for t in batch)
        labels = batch[5].to(torch.int64)
        inputs = {
            "input_ids": batch[0],
            "input_modal": batch[2],
            "attention_mask": batch[1],
            "modal_start_tokens": batch[3],
            "modal_end_tokens": batch[4],
            'return_dict' : True
        }
        outputs = model(**inputs)
       
        logits = outputs[0]  # model outputs are always tuple in transformers (see doc)
        labels = nn.functional.one_hot(labels, num_classes=2).to(torch.float32)
        loss = criterion(logits, labels)
        total_loss += loss
        
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    
    
    avg_train_loss = total_loss / len(train_dataloader)
    print(f"Average training loss for epoch {epoch + 1}: {avg_train_loss}")
    
    # Evaluation loop
    model.eval()
    eval_loss = 0
    correct_predictions = 0
    total_predictions = 0
    
    all_labels = []
    all_predictions = []
    
    with torch.no_grad():
        for batch in eval_dataloader:
            # Unpack the batch and move to device
            batch = tuple(t.to(device) for t in batch)
            labels = batch[5].to(torch.int64)
            inputs = {
                "input_ids": batch[0],
                "input_modal": batch[2],
                "attention_mask": batch[1],
                "modal_start_tokens": batch[3],
                "modal_end_tokens": batch[4],
                'return_dict' : True
            }
            
            # Forward pass
            outputs = model(**inputs)
            logits = outputs[0]
            labels_ohe = nn.functional.one_hot(labels, num_classes=2).to(torch.float32)
            loss = criterion(logits, labels_ohe)
            eval_loss += loss
            
           # Store the predictions and true labels
            predictions = torch.argmax(logits, dim=-1)
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
            # Calculate the number of correct predictions
            correct_predictions += (predictions == labels).sum().item()
            total_predictions += labels.size(0)
    
        # Calculate average evaluation loss
        avg_eval_loss = eval_loss / len(eval_dataloader)

        # Calculate accuracy
        accuracy = correct_predictions / total_predictions

        # Convert lists to numpy arrays
        all_predictions = np.array(all_predictions)
        all_labels = np.array(all_labels)

        # Calculate precision, recall, and F1 score for each class and macro average
        precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_predictions, average=None)
        precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(all_labels, all_predictions, average='macro')

        # Print results
        print(f"Average evaluation loss for epoch {epoch + 1}: {avg_eval_loss}")
        print(f"Evaluation accuracy for epoch {epoch + 1}: {accuracy}")
        print(f"Precision for each class: {precision}")
        print(f"Recall for each class: {recall}")
        print(f"F1 score for each class: {f1}")
        print(f"Macro average precision: {precision_macro}")
        print(f"Macro average recall: {recall_macro}")
        print(f"Macro average F1 score: {f1_macro}")

print("Training complete.")



Average training loss for epoch 1: 0.7004415392875671
Average evaluation loss for epoch 1: 0.6925012469291687
Evaluation accuracy for epoch 1: 0.525
Precision for each class: [0.37563452 0.60606061]
Recall for each class: [0.34101382 0.64139942]
F1 score for each class: [0.35748792 0.62322946]
Macro average precision: 0.4908475619135518
Macro average recall: 0.4912066208972068
Macro average F1 score: 0.49035869223084394




Average training loss for epoch 2: 0.6970082521438599
Average evaluation loss for epoch 2: 0.6930735111236572
Evaluation accuracy for epoch 2: 0.5089285714285714
Precision for each class: [0.384      0.60967742]
Recall for each class: [0.44239631 0.55102041]
F1 score for each class: [0.4111349  0.57886677]
Macro average precision: 0.49683870967741933
Macro average recall: 0.4967083607636603
Macro average F1 score: 0.4950008361999141




Average training loss for epoch 3: 0.6942203640937805
Average evaluation loss for epoch 3: 0.6940150856971741
Evaluation accuracy for epoch 3: 0.5214285714285715
Precision for each class: [0.40925267 0.6344086 ]
Recall for each class: [0.52995392 0.51603499]
F1 score for each class: [0.46184739 0.56913183]
Macro average precision: 0.5218306355948418
Macro average recall: 0.522994451236716
Macro average F1 score: 0.5154896111778302




Average training loss for epoch 4: 0.6946079730987549
Average evaluation loss for epoch 4: 0.6934839487075806
Evaluation accuracy for epoch 4: 0.5321428571428571
Precision for each class: [0.41509434 0.63728814]
Recall for each class: [0.50691244 0.54810496]
F1 score for each class: [0.45643154 0.58934169]
Macro average precision: 0.526191237607931
Macro average recall: 0.5275086993322675
Macro average F1 score: 0.5228866140298392




Average training loss for epoch 5: 0.6926446557044983
Average evaluation loss for epoch 5: 0.69283527135849
Evaluation accuracy for epoch 5: 0.5339285714285714
Precision for each class: [0.41603053 0.63758389]
Recall for each class: [0.50230415 0.55393586]
F1 score for each class: [0.45511482 0.59282371]
Macro average precision: 0.5268072134842974
Macro average recall: 0.5281200037618734
Macro average F1 score: 0.5239692677477454
Training complete.


In [11]:
## test 
test_sampler = SequentialSampler(dataset_test)
test_dataloader = DataLoader(dataset_test, sampler=test_sampler, batch_size=batch_size, collate_fn=collate_fn)

model.eval()
test_loss = 0
correct_predictions = 0
total_predictions = 0

all_labels = []
all_predictions = []
with torch.no_grad():
    test_loss = 0
    for batch in test_dataloader:
        # Unpack the batch and move to device
        batch = tuple(t.to(device) for t in batch)
        labels = batch[5].to(torch.int64)
        inputs = {
            "input_ids": batch[0],
            "input_modal": batch[2],
            "attention_mask": batch[1],
            "modal_start_tokens": batch[3],
            "modal_end_tokens": batch[4],
            'return_dict': True
        }
        
        # Forward pass
        outputs = model(**inputs)
        logits = outputs[0]
        labels_ohe = nn.functional.one_hot(labels, num_classes=2).to(torch.float32)
        loss = criterion(logits, labels_ohe)
        test_loss += loss
        
        # Store the predictions and true labels
        predictions = torch.argmax(logits, dim=-1)
        all_predictions.extend(predictions.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        # Calculate the number of correct predictions
        correct_predictions += (predictions == labels).sum().item()
        total_predictions += labels.size(0)

    # Calculate average test loss
    avg_test_loss = test_loss / len(test_dataloader)

    # Calculate accuracy
    accuracy = correct_predictions / total_predictions

    # Convert lists to numpy arrays
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)

    # Calculate precision, recall, and F1 score for each class and macro average
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_predictions, average=None)
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(all_labels, all_predictions, average='macro')

    # Print results
    print(f"Average test loss: {avg_test_loss}")
    print(f"Test accuracy: {accuracy}")
    print(f"Precision for each class: {precision}")
    print(f"Recall for each class: {recall}")
    print(f"F1 score for each class: {f1}")
    print(f"Macro average precision: {precision_macro}")
    print(f"Macro average recall: {recall_macro}")
    print(f"Macro average F1 score: {f1_macro}")

print("Training and evaluation complete.")



Average test loss: 0.6953418254852295
Test accuracy: 0.5089349535382416
Precision for each class: [0.398017   0.62193362]
Recall for each class: [0.5174954  0.50350467]
F1 score for each class: [0.44995997 0.55648806]
Macro average precision: 0.5099753095503803
Macro average recall: 0.5105000344228154
Macro average F1 score: 0.5032240123926126
Training and evaluation complete.
