In [None]:
# Mount Google Drive at /content/drive so the dataset CSV and model
# checkpoints stored under "My Drive" are reachable from this runtime.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import os
from PIL import Image
from tqdm import tqdm
import re
import matplotlib.pyplot as plt
import textwrap
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import warnings
# `message` is a regular expression matched against the start of the warning
# text, so the literal prefixes are escaped (e.g. '.' and '()' in "os.fork()").
warnings.filterwarnings("ignore", category=UserWarning, message=re.escape("Palette"))         # ignore 'palette images expressed in bytes' warning
warnings.filterwarnings("ignore", category=RuntimeWarning, message=re.escape("os.fork()"))  # ignore os.fork() multithreading warning
warnings.filterwarnings("ignore", category=UserWarning, message=re.escape("Glyph"))           # ignore warning about emojis missing from font for displaying predictions

# `transformers.AdamW` is deprecated and has been removed from newer
# transformers releases; fall back to PyTorch's implementation there.
# NOTE: torch.optim.AdamW has different defaults (eps=1e-8, weight_decay=1e-2
# vs. 1e-6 / 0.0), so pass them explicitly where the optimizer is built.
try:
    from transformers import AdamW
except ImportError:
    from torch.optim import AdamW
from transformers import AutoImageProcessor, ViTModel
from transformers import RobertaTokenizer, RobertaModel

import numpy as np
import torch.nn.functional as F

pd.set_option('display.max_colwidth', None)  # No truncation of text

In [None]:
class MultimodalDataset(Dataset):
    """Dataset pairing pre-tokenized text with images loaded lazily from disk.

    Args:
        text_encodings: mapping from tokenizer output name (e.g. 'input_ids',
            'attention_mask') to a sequence indexable by sample position.
        image_paths: per-sample image file paths.
        labels: per-sample integer class labels.
        image_processor: callable that takes a PIL image and
            ``return_tensors="pt"`` and returns a mapping containing
            'pixel_values' with a leading batch dimension of 1.
    """

    def __init__(self, text_encodings, image_paths, labels, image_processor):
        self.text_encodings = text_encodings
        self.image_paths = image_paths
        self.labels = labels
        self.image_processor = image_processor

    def __getitem__(self, idx):
        """Return the text fields, processed 'image' and 'label' for sample ``idx``."""
        item = {key: val[idx] for key, val in self.text_encodings.items()}
        # Close the file deterministically: Pillow opens images lazily and, for
        # multi-frame formats such as GIF, keeps the handle open until close().
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        # Drop the processor's leading batch dimension; DataLoader re-batches.
        image = self.image_processor(image, return_tensors="pt")['pixel_values'].squeeze(0)
        item['image'] = image
        item['label'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

    def __len__(self):
        """Number of samples, taken from the number of labels."""
        return len(self.labels)

class MultimodalModel(nn.Module):
    """Late-fusion classifier over RoBERTa text features and ViT image features.

    Each modality has its own linear head; the returned logits are the mean of
    the two heads' logits.

    Args:
        num_labels: number of output classes.
        text_model_name: pretrained RoBERTa checkpoint name or path.
        image_model_name: pretrained ViT checkpoint name or path. When omitted,
            the notebook-level ``IMAGE_MODEL`` constant is used.
    """

    def __init__(self, num_labels, text_model_name='roberta-base', image_model_name=None):
        super().__init__()
        if image_model_name is None:
            # Backward-compatible default: read the global config constant.
            image_model_name = IMAGE_MODEL
        self.text_model = RobertaModel.from_pretrained(text_model_name)
        self.image_model = ViTModel.from_pretrained(image_model_name)
        # Separate classifiers for each modality
        self.text_classifier = nn.Linear(self.text_model.config.hidden_size, num_labels)
        self.image_classifier = nn.Linear(self.image_model.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, images):
        """Return fused logits of shape (batch, num_labels).

        Args:
            input_ids: token ids for the text encoder.
            attention_mask: attention mask matching ``input_ids``.
            images: pixel values for the ViT encoder.
        """
        # Text arm: first-position token embedding (<s>, RoBERTa's [CLS] equivalent)
        text_outputs = self.text_model(input_ids, attention_mask=attention_mask)
        text_features = text_outputs.last_hidden_state[:, 0, :]
        text_logits = self.text_classifier(text_features)
        # Image arm: first-position ([CLS]) token embedding from ViT
        image_features = self.image_model(pixel_values=images).last_hidden_state[:, 0, :]
        image_logits = self.image_classifier(image_features)
        # Late fusion: average the logits from both arms
        logits = (text_logits + image_logits) / 2
        return logits

def train(model, dataloader, optimizer, device, criterion):
    """Run one training epoch over ``dataloader``.

    Args:
        model: module called as ``model(input_ids, attention_mask, images)``
            that returns per-class logits.
        dataloader: yields dicts with 'input_ids', 'attention_mask', 'image'
            and 'label' tensors.
        optimizer: optimizer stepped once per batch.
        device: device every batch tensor is moved to.
        criterion: loss called as ``criterion(logits, labels)``.

    Returns:
        ``(mean of per-batch losses, accuracy over all samples seen)``.
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        # Flatten text tensors to (batch, seq_len). A bare ``.squeeze()`` would
        # also remove the batch dimension when the final batch has one sample.
        input_ids = batch['input_ids'].reshape(batch['input_ids'].size(0), -1).to(device)
        attention_mask = batch['attention_mask'].reshape(batch['attention_mask'].size(0), -1).to(device)
        images = batch['image'].to(device)
        labels = batch['label'].to(device)
        # Forward pass and loss
        logits = model(input_ids, attention_mask, images)
        loss = criterion(logits, labels)
        total_loss += loss.item()
        # Backpropagation
        loss.backward()
        optimizer.step()
        # Running accuracy from the arg-max class
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    accuracy = correct / total
    return total_loss / len(dataloader), accuracy

def evaluate(model, dataloader, device, criterion, target_names=None):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Prints a per-class classification report (4 decimal digits) as a side effect.

    Args:
        model: module called as ``model(input_ids, attention_mask, images)``
            that returns per-class logits.
        dataloader: yields dicts with 'input_ids', 'attention_mask', 'image'
            and 'label' tensors.
        device: device every batch tensor is moved to.
        criterion: loss called as ``criterion(logits, labels)``.
        target_names: class names indexed by encoded label id. Defaults to the
            notebook-level ``label_encoder.classes_``.

    Returns:
        ``(all_targets, all_predictions, avg_loss, accuracy)``. The first two are
        lists of per-sample encoded label ids; ``avg_loss`` is the mean of the
        per-batch losses.
    """
    if target_names is None:
        target_names = label_encoder.classes_
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    all_predictions = []
    all_targets = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            # Flatten text tensors to (batch, seq_len). A bare ``.squeeze()``
            # would also remove the batch dimension for a one-sample batch.
            input_ids = batch['input_ids'].reshape(batch['input_ids'].size(0), -1).to(device)
            attention_mask = batch['attention_mask'].reshape(batch['attention_mask'].size(0), -1).to(device)
            images = batch['image'].to(device)
            labels = batch['label'].to(device)
            # Forward pass and loss
            logits = model(input_ids, attention_mask, images)
            loss = criterion(logits, labels)
            total_loss += loss.item()
            # Running accuracy from the arg-max class
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            # Collect predictions and targets for reporting
            all_predictions.extend(preds.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())
    accuracy = correct / total
    avg_loss = total_loss / len(dataloader)
    # Pin the label ids so the report always lines up with `target_names`, even
    # when a class is absent from both targets and predictions.
    report = classification_report(
        all_targets,
        all_predictions,
        labels=np.arange(len(target_names)),
        target_names=target_names,
        digits=4,
    )
    print(report)
    return all_targets, all_predictions, avg_loss, accuracy

# load and process data

In [None]:
RANDOM_STATE = 8
IMAGE_MODEL = 'google/vit-base-patch16-384'

# Seed torch's default RNG (also seeds all CUDA devices) so classifier-head
# initialisation and DataLoader shuffling repeat across runs. Some GPU kernels
# remain nondeterministic, so results may still vary slightly.
torch.manual_seed(RANDOM_STATE)

In [None]:
# load dataset
dataset_path = '/content/drive/My Drive/multimodal_classifier/data/WildFireCan-MMD.csv'
dataset = pd.read_csv(dataset_path)

# Replace samples without extracted text with filler. `read_csv` parses empty
# cells as NaN, so missing values must be detected with `isna()`; blank or
# whitespace-only strings are treated as missing too.
missing_text_mask = (
    dataset['extracted_text'].isna()
    | (dataset['extracted_text'].astype(str).str.strip() == '')
)
dataset.loc[missing_text_mask, 'extracted_text'] = 'Image does not contain text.'

# shuffle data
dataset = dataset.sample(frac=1, random_state=RANDOM_STATE).reset_index(drop=True)

# Split the data into train and test sets (80/20 split), stratifying by 'label'
train_df, test_df = train_test_split(dataset, test_size=0.2, random_state=RANDOM_STATE, stratify=dataset['label'])

In [None]:
# Pretrained preprocessors for the 'roberta-base' text checkpoint and the
# IMAGE_MODEL image checkpoint.
image_processor = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path=IMAGE_MODEL)
tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path='roberta-base')

In [None]:
# Tokenize the data
def tokenize_data(text):
    """Tokenize ``text`` with the global RoBERTa ``tokenizer``.

    Every sequence is padded or truncated to exactly 200 tokens, and the
    encodings are returned as PyTorch tensors.
    """
    encode_options = {
        'padding': 'max_length',
        'truncation': True,
        'max_length': 200,
        'return_tensors': 'pt',
    }
    return tokenizer(text, **encode_options)

# Encode labels: the encoder is fit on the training split only and reused for
# the test split, so both share the same integer ids.
label_encoder = LabelEncoder()
train_labels = label_encoder.fit_transform(train_df['label'].to_numpy())
test_labels = label_encoder.transform(test_df['label'].to_numpy())

# Tokenized text for each split.
# NOTE(review): this tokenizes the 'text' column; the 'extracted_text' column
# filled in the loading cell is not used here — confirm that is intended.
text_train_encodings, text_test_encodings = (
    tokenize_data(split_df['text'].tolist()) for split_df in (train_df, test_df)
)

# Image file paths for each split
train_imgs, test_imgs = train_df['image'].to_list(), test_df['image'].to_list()

# Train and test datasets sharing the same image processor
train_dataset, test_dataset = (
    MultimodalDataset(encodings, paths, targets, image_processor=image_processor)
    for encodings, paths, targets in (
        (text_train_encodings, train_imgs, train_labels),
        (text_test_encodings, test_imgs, test_labels),
    )
)

In [None]:
# Make train and test dataloaders.
# Cap the worker count at the CPUs this process may use (e.g. Colab runtimes
# often expose only a couple of cores). Requesting more workers than that
# makes DataLoader warn and adds process overhead without faster loading.
if hasattr(os, 'sched_getaffinity'):
    available_cpus = len(os.sched_getaffinity(0))
else:
    available_cpus = os.cpu_count() or 1
loader_workers = min(12, available_cpus)
# Pinned host memory only speeds up host-to-GPU copies, so use it only with CUDA.
use_pin_memory = torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=loader_workers, pin_memory=use_pin_memory)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=loader_workers, pin_memory=use_pin_memory)

# train and test

In [None]:
num_labels = len(label_encoder.classes_)
model = MultimodalModel(num_labels)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# PyTorch's AdamW: `transformers.AdamW` is deprecated and removed in newer
# transformers releases. eps and weight_decay are set to transformers.AdamW's
# defaults, so the optimisation matches the original setup.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Directory to save the best model (created if missing; no error if it exists)
model_save_path = '/content/drive/My Drive/multimodal_classifier/model/late_fusion/model_2-head'
os.makedirs(model_save_path, exist_ok=True)
# Best test accuracy so far; the training cell only checkpoints on improvement.
best_test_accuracy = 0.0

In [None]:
# train model
num_epochs = 15
for epoch in range(num_epochs):
    train_loss, train_accuracy = train(model, train_loader, optimizer, device, criterion)
    all_targets, all_predictions, test_loss, test_accuracy = evaluate(model, test_loader, device, criterion)
    print(f'Epoch {epoch + 1}/{num_epochs}')
    print(f'Training Acc:  {train_accuracy:.4f}')
    print(f'Training Loss: {train_loss:.4f}')
    print(f'Test Acc:      {test_accuracy:.4f}')
    print(f'Test Loss:     {test_loss:.4f}')

    # Save the best model based on test-set accuracy.
    # NOTE(review): choosing the checkpoint on the test set makes the final
    # test accuracy optimistically biased; a separate validation split would
    # keep the test set untouched for the final report.
    if test_accuracy > best_test_accuracy:
        best_test_accuracy = test_accuracy
        torch.save(model.state_dict(), os.path.join(model_save_path, 'model.bin'))
        tokenizer.save_pretrained(model_save_path)
        print(f"Best model saved with test accuracy: {best_test_accuracy:.4f}")

In [None]:
# load saved model
model_save_path = '/content/drive/My Drive/multimodal_classifier/model/late_fusion/model_2-head'
# map_location lets a checkpoint written on a GPU runtime load onto whatever
# `device` is active now (e.g. a CPU-only session).
model.load_state_dict(torch.load(os.path.join(model_save_path, 'model.bin'), map_location=device))

In [None]:
# Evaluate the current model (the restored checkpoint, if the cell above was
# run) on the test set. Only per-sample targets and predictions are kept.
all_targets, all_predictions, _, _ = evaluate(
    model,
    test_loader,
    device=device,
    criterion=criterion,
)

In [None]:
# show confusion matrix
class_names = label_encoder.classes_
# Pin the label ids so the matrix is always K x K and matches `display_labels`,
# even if a class never occurs in the targets or predictions.
cm = confusion_matrix(all_targets, all_predictions, labels=np.arange(len(class_names)))
fig, ax = plt.subplots()
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(ax=ax, xticks_rotation=90)
ax.set_title('Late-fusion model: test-set confusion matrix')
plt.show()