<a href="https://colab.research.google.com/github/abdulkadir-erol/multimodal-toxicity-detection/blob/main/CLIPvBridgeTower.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the third-party packages for this runtime (shell commands run via IPython's `!`).
!pip install transformers
# NOTE(review): no visible cell imports torcheval — confirm this install is still needed.
!pip install torcheval

In [None]:
# Mount Google Drive at /content/gdrive; the dataset paths used further down
# all live under /content/gdrive/MyDrive. Only works inside Google Colab.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import json
from PIL import Image
import logging
from transformers import CLIPProcessor, CLIPModel, BridgeTowerProcessor, BridgeTowerModel
from tqdm import tqdm
import random
import time
from numpy.linalg import norm
import torch.nn.functional as F

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
# No model has been created at this point of the notebook, so nothing is moved
# here: each model is sent to `device` in the cell that instantiates it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def collate_fn(batch, pad_token_id=0):
  """Collate per-sample dicts into a single right-padded batch.

  Args:
    batch: list of samples, each a mapping with
      - "text": mapping holding "input_ids" and "attention_mask" tensors; their
        leading dimension is squeezed away (assumes shape (1, seq_len), as
        produced with return_tensors='pt' — TODO confirm for other processors),
      - "image": tensor; every sample must share one shape so they can be stacked,
      - "label": value accepted by torch.tensor.
    pad_token_id: id used to pad "input_ids" up to the longest sequence of the
      batch. Defaults to 0, the value that used to be hard-coded. The attention
      mask is always padded with 0.

  Returns:
    dict with "text" -> {"input_ids", "attention_mask"} (both of shape
    (batch, longest_seq_len)), "image" -> stacked images, "label" -> 1-D tensor.
  """
  input_ids = [sample["text"]["input_ids"].squeeze(0) for sample in batch]
  attention_masks = [sample["text"]["attention_mask"].squeeze(0) for sample in batch]

  # pad_sequence already right-pads to the longest sequence, so no manual
  # torch.cat with zero tensors is needed beforehand.
  padded_input_ids = pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
  padded_attention_mask = pad_sequence(attention_masks, batch_first=True, padding_value=0)

  return {
    "text": {
      "input_ids": padded_input_ids,
      "attention_mask": padded_attention_mask
    },
    "image": torch.stack([sample["image"] for sample in batch]),
    "label": torch.tensor([sample["label"] for sample in batch])
  }

In [None]:
class CustomDataset(Dataset):
    """Dataset yielding processed (image, text, label) samples from a JSON-lines file.

    Every non-empty line of the file is parsed as one JSON object; __getitem__
    reads its "label", "text" and "img" keys.
    """

    # Prefix prepended to each "img" entry when no image_dir is passed.
    DEFAULT_IMAGE_DIR = "/content/gdrive/MyDrive/GSU - Research/Emergent Abilities/dataset/facebook_hateful_memes/"

    def __init__(self, json_file, model, processor, max_length, image_dir=None):
      """
      Args:
        json_file: path of the JSON-lines annotation file.
        model: kept on the instance as ``self.clip``; the dataset itself never calls it.
        processor: callable invoked with ``text=``/``images=`` keyword arguments
          to tokenize the caption and preprocess the image.
        max_length: maximum number of text tokens; the tokenizer truncates
          longer captions to this length.
        image_dir: prefix prepended verbatim to each "img" entry (include the
          trailing slash). Defaults to DEFAULT_IMAGE_DIR.
      """
      self.data = self.load_data(json_file)
      self.max_length = max_length
      self.clip = model
      self.processor = processor
      self.image_dir = self.DEFAULT_IMAGE_DIR if image_dir is None else image_dir

    def __len__(self):
      """Number of annotation records."""
      return len(self.data)

    def __getitem__(self, idx):
      """Return {'image': pixel tensor, 'text': tokenizer output, 'label': label} for record idx."""
      item = self.data[idx]
      label = item["label"]
      text = item["text"]
      image_path = self.image_dir + item["img"]

      # Decode inside a context manager so the file handle is released, and
      # normalise to 3-channel RGB (PNG files may be palette/RGBA/greyscale).
      with Image.open(image_path) as raw_image:
        img = raw_image.convert("RGB")

      # Truncate on tokens, not characters: slicing the caption to 77
      # characters drops text needlessly and still does not bound the number
      # of tokens the text encoder receives.
      tokens = self.processor(
          text=text,
          padding='longest',
          truncation=True,
          max_length=self.max_length,
          images=None,
          return_tensors='pt',
      )

      images = self.processor(
          text=None,
          images=img,
          return_tensors='pt'
      )['pixel_values']

      return {
            'image': images.squeeze(0),
            'text': tokens,
            'label': label
        }

    def load_data(self, json_file):
      """Parse a JSON-lines file into a list of dicts, skipping blank lines."""
      with open(json_file, "r", encoding="utf-8") as f:
          return [json.loads(line) for line in f if line.strip()]

In [None]:
# Checkpoint name handed to CLIPModel.from_pretrained / CLIPProcessor.from_pretrained in later cells.
model_id = 'openai/clip-vit-base-patch32'

In [None]:
# Load the CLIP processor and model once and share them between the three
# splits instead of calling from_pretrained() separately for every split.
clip_processor = CLIPProcessor.from_pretrained(model_id)
clip_model = CLIPModel.from_pretrained(model_id)  # only stored on the datasets (CustomDataset.clip)

DATA_DIR = '/content/gdrive/MyDrive/GSU - Research/Emergent Abilities/dataset/facebook_hateful_memes/'

train_dataset = CustomDataset(json_file=DATA_DIR + 'train.jsonl', model=clip_model, processor=clip_processor, max_length=77)
dev_dataset = CustomDataset(json_file=DATA_DIR + 'dev_seen.jsonl', model=clip_model, processor=clip_processor, max_length=77)
test_dataset = CustomDataset(json_file=DATA_DIR + 'test_seen.jsonl', model=clip_model, processor=clip_processor, max_length=77)

# Only the training split is shuffled; evaluation order stays fixed.
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

In [None]:
class FusionNet(nn.Module):
    """Small MLP classifier applied to the element-wise mean of two embeddings.

    The averaged embedding goes through Linear -> ReLU -> Dropout(0.3) -> Linear
    and a final sigmoid, giving one score between 0 and 1 per row.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Layers are created in the same order and under the same attribute
        # names, so weight initialisation and state_dict keys do not change.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, emb1, emb2):
        """Average two same-shaped embeddings and return sigmoid scores with a trailing dim of 1."""
        fused = torch.stack([emb1, emb2]).mean(dim=0)
        hidden = self.dropout(self.relu(self.fc1(fused)))
        return self.sigmoid(self.fc2(hidden))

class CustomCLIPModel(nn.Module):
    """CLIP backbone followed by a FusionNet head that outputs one probability per sample."""

    def __init__(self):
        super(CustomCLIPModel, self).__init__()
        self.clip = CLIPModel.from_pretrained(model_id)
        # Size the head from the checkpoint's projection dimension (512 for the
        # base checkpoint) so other CLIP checkpoints work without edits.
        self.fusion_net = FusionNet(input_size=self.clip.config.projection_dim, hidden_size=128)

    def forward(self, image_data, text_data):
        """Return a 1-D tensor of fused probabilities for a batch.

        Args:
            image_data: pixel values passed to ``get_image_features``.
            text_data: mapping with "input_ids" and "attention_mask" tensors.
        """
        # .long() casts in place on the tensor's current device, whereas
        # .type(torch.LongTensor) produces a CPU tensor and breaks GPU runs.
        # The caller's dict is left untouched.
        input_ids = text_data['input_ids'].long()
        text_emb = self.clip.get_text_features(input_ids=input_ids, attention_mask=text_data['attention_mask'])
        img_emb = self.clip.get_image_features(pixel_values=image_data)

        fused_probs = self.fusion_net(text_emb, img_emb)
        return fused_probs.squeeze(dim=1)

In [None]:
# Build the CLIP classifier on the selected device, then its optimizer and the
# binary cross-entropy loss applied to the model's probability outputs.
model = CustomCLIPModel().to(device)
# NOTE(review): torch.optim.Adam applies weight_decay as an L2 term on the
# gradients — confirm 0.2 is intended here rather than AdamW-style decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=5e-5,
    betas=(0.9, 0.98),
    eps=1e-6,
    weight_decay=0.2,
)
criterion = nn.BCELoss()

In [None]:
# Early-stopping bookkeeping that train() reads and updates as module-level globals.
patience, counter = 3, 0        # allowed non-improving epochs / non-improving epochs so far
best_val_loss = float('inf')    # lowest validation loss seen so far
early_stop = False              # becomes True once `counter` reaches `patience`

In [None]:
def _forward_batch(model, batch, criterion):
    """Move one collated batch to `device`, run the model and return (output, labels, loss)."""
    text_data = {key: value.to(device) for key, value in batch['text'].items()}
    image_data = batch['image'].to(device)
    # .float() casts without leaving the device; .type(torch.FloatTensor)
    # would copy every batch back to the CPU.
    labels = batch['label'].to(device).float()
    output = model(image_data, text_data).float()
    return output, labels, criterion(output, labels)


def train(model, train_loader, val_loader, optimizer, criterion, epochs=10):
    """Train `model` with per-epoch validation and early stopping.

    Args:
        model: module called as ``model(image_data, text_data)``; its output is
            compared against float labels and thresholded at 0.5 for accuracy.
        train_loader: iterable of batches with 'text', 'image' and 'label' entries.
        val_loader: iterable of validation batches with the same layout.
        optimizer: optimizer stepped once per training batch.
        criterion: loss called as ``criterion(output, labels)``.
        epochs: maximum number of epochs.

    Side effects:
        Updates the module-level ``best_val_loss``, ``counter`` and
        ``early_stop``; reads the module-level ``patience`` and ``device``.
        Progress goes through ``logging.info``, so it is only visible when
        logging is configured to show INFO records.
    """
    global best_val_loss, counter, early_stop

    # Start every run from a clean slate; otherwise a second call would inherit
    # early_stop=True / best_val_loss from the previous model and stop at once.
    best_val_loss = float('inf')
    counter = 0
    early_stop = False

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for batch in tqdm(train_loader, total=len(train_loader)):
            optimizer.zero_grad()
            _, _, loss = _forward_batch(model, batch, criterion)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)
        logging.info(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {train_loss:.4f}")

        # Validation
        model.eval()
        val_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, total=len(val_loader)):
                output, labels, loss = _forward_batch(model, batch, criterion)
                val_loss += loss.item()
                predicted = (output > 0.5).float()
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

        val_loss /= len(val_loader)
        accuracy = 100 * correct / total
        print("correct",correct)
        print("total", total)
        logging.info(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.2f}%")

        # Early stopping: count consecutive epochs without a new best validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                early_stop = True
                logging.info("Early stopping")
                break

In [None]:
# Training
train(model, train_loader, val_loader, optimizer, criterion)

# Testing: mean per-batch loss and accuracy (0.5 threshold) on the test split
model.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():
    pbar = tqdm(test_loader, total=len(test_loader))
    for batch in pbar:
        text_data = {key: value.to(device) for key, value in batch['text'].items()}
        image_data = batch['image'].to(device)
        # .float() keeps the tensors on `device`; .type(torch.FloatTensor)
        # would copy them back to the CPU on every batch.
        labels = batch['label'].to(device).float()

        output = model(image_data, text_data).float()

        test_loss += criterion(output, labels).item()
        predicted = (output > 0.5).float()  # Adjust threshold for binary classification
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

test_loss /= len(test_loader)
test_accuracy = 100 * correct / total
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

In [None]:
class BridgeTowerHead(nn.Module):
    """Two-layer MLP mapping 1536-dimensional inputs to 2 output logits."""

    def __init__(self):
        super().__init__()

        # Linear(1536 -> 512) -> ReLU -> Linear(512 -> 2)
        layers = [
            nn.Linear(1536, 512),
            nn.ReLU(),
            nn.Linear(512, 2),
        ]
        self.linear_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Run `x` (last dimension 1536) through the stack and return the raw logits."""
        return self.linear_stack(x)

In [None]:
# Load the pretrained BridgeTower processor and model.
# NOTE: this rebinds the global `model`, which earlier cells used for the CLIP classifier.
processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base")
model = BridgeTowerModel.from_pretrained("BridgeTower/bridgetower-base")

# Move the model to `device`; as the cell's last expression it also echoes the module repr.
model.to(device)

In [None]:
# Rebuild the three splits with the BridgeTower processor/model.
# NOTE(review): CustomDataset calls the processor once with images=None and once
# with text=None — confirm BridgeTowerProcessor accepts both call forms.
train_dataset = CustomDataset(
    json_file='/content/gdrive/MyDrive/GSU - Research/Emergent Abilities/dataset/facebook_hateful_memes/train.jsonl',
    model=model,
    processor=processor,
    max_length=77,
)
dev_dataset = CustomDataset(
    json_file='/content/gdrive/MyDrive/GSU - Research/Emergent Abilities/dataset/facebook_hateful_memes/dev_seen.jsonl',
    model=model,
    processor=processor,
    max_length=77,
)
test_dataset = CustomDataset(
    json_file='/content/gdrive/MyDrive/GSU - Research/Emergent Abilities/dataset/facebook_hateful_memes/test_seen.jsonl',
    model=model,
    processor=processor,
    max_length=77,
)

# Only the training loader shuffles; validation and test keep file order.
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

In [None]:
# BridgeTowerHead instance placed on `device`.
head = BridgeTowerHead().to(device)

# Optimisation hyper-parameters.
lr, weight_decay, epochs = 1e-4, 1e-3, 10

criterion = nn.CrossEntropyLoss()
# Only the head's parameters are handed to the optimizer.
optimizer = optim.Adam(
    head.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

In [None]:
class BridgeTowerClassifier(nn.Module):
    """Adapter that lets train() drive the BridgeTower backbone together with `head`.

    train() calls ``model(image_data, text_data)``, casts the labels to float
    and thresholds the output at 0.5, so forward() returns P(label == 1) taken
    from a softmax over the head's two logits. Passing the bare backbone does
    not work: it neither accepts that call signature nor involves `head`.
    """

    def __init__(self, backbone, head):
        super().__init__()
        self.backbone = backbone
        self.head = head
        # The optimizer holds head.parameters() only, so the backbone acts as a
        # frozen feature extractor and needs no gradients.
        for param in self.backbone.parameters():
            param.requires_grad_(False)

    def train(self, mode=True):
        """Switch the head's mode but keep the frozen backbone in eval mode."""
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, image_data, text_data):
        with torch.no_grad():
            # pooler_output is expected to be 1536-wide for bridgetower-base,
            # matching the head's first Linear layer — TODO confirm.
            features = self.backbone(
                input_ids=text_data['input_ids'].long(),
                attention_mask=text_data['attention_mask'],
                pixel_values=image_data,
            ).pooler_output
        logits = self.head(features)
        return torch.softmax(logits, dim=1)[:, 1]


bridgetower_classifier = BridgeTowerClassifier(model, head).to(device)
# With 0/1 labels, BCE on P(label == 1) equals the two-class cross-entropy on
# the logits, and it fits the float labels train() produces.
bridgetower_criterion = nn.BCELoss()

# Start from fresh early-stopping state so nothing carries over from the CLIP run.
best_val_loss, counter, early_stop = float('inf'), 0, False

# Training
train(bridgetower_classifier, train_loader, val_loader, optimizer, bridgetower_criterion, epochs=epochs)

# Testing
bridgetower_classifier.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():
    pbar = tqdm(test_loader, total=len(test_loader))
    for batch in pbar:
        text_data = {key: value.to(device) for key, value in batch['text'].items()}
        image_data = batch['image'].to(device)
        # .float() keeps the tensors on `device` (no round trip through the CPU)
        labels = batch['label'].to(device).float()

        output = bridgetower_classifier(image_data, text_data).float()

        test_loss += bridgetower_criterion(output, labels).item()
        predicted = (output > 0.5).float()  # Adjust threshold for binary classification
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

test_loss /= len(test_loader)
test_accuracy = 100 * correct / total
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")