In [None]:
# Mount Google Drive into the Colab filesystem; the dataset, images and model
# checkpoints used below are all read from / written to paths under this mount.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install easyocr

# Test EasyOCR

In [None]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import easyocr

In [None]:
RANDOM_STATE = 8

# Locations of the raw table and of the image folder on Drive.
dataset_path = '/content/drive/My Drive/multimodal_classifier/data/WildFireCan-MMD.csv'
base_path2 = '/content/drive/My Drive/multimodal_classifier/data/images/'

# Columns that are not needed for this experiment.
columns_to_drop2 = [
    'tweet_id', 'img_id', 'posted_at', 'author_id', 'author_loc', 'author_name',
    'author_usrname', 'media_keys', 'urls', 'predicted_label', 'contains_personal_info',
]

# Read -> drop unneeded columns -> re-root image paths -> shuffle, as one pipeline.
# The stored image paths are backslash-separated; component 7 is kept and
# prefixed with the Drive image folder.
dataset = (
    pd.read_csv(dataset_path)
    .drop(columns=columns_to_drop2)
    .assign(image=lambda df: df['image'].apply(lambda raw: base_path2 + raw.split('\\')[7]))
    .sample(frac=1, random_state=RANDOM_STATE)
    .reset_index(drop=True)
)

In [None]:
# Pick one row of the shuffled table to try OCR on: its image path and its text.
k = 3
pth = dataset.at[k, 'image']
txt = dataset.at[k, 'text']

In [None]:
# Load the sample as RGB and render it on an axis-free canvas.
image = Image.open(pth).convert("RGB")

fig, ax = plt.subplots()
ax.imshow(image)
ax.set_axis_off()  # hide ticks and frame
plt.show()

In [None]:
# Build the English EasyOCR reader. Construction loads the model into memory,
# so this cell only needs to run once per session.
reader = easyocr.Reader(['en'])

In [None]:
# Run OCR on the sample image. The next cell reads element [1] of every entry
# in `result` as the recognised text fragment.
result = reader.readtext(pth)

In [None]:
# Concatenate every recognised fragment (element 1 of each detection), each one
# preceded by a single space -- so a non-empty result starts with a space.
str1 = ''.join(' ' + detection[1] for detection in result)

In [None]:
# Show the concatenated OCR output for the sample image.
print(str1)
# print(txt)  # uncomment to compare against the sample's `text` column

In [None]:
# Length of the OCR string in characters (includes the separating spaces).
print(len(str1))

# 3-head Classifier

In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import os
from PIL import Image
from tqdm import tqdm
import re
import matplotlib.pyplot as plt
import textwrap
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import warnings
# `message` is a regular expression matched against the start of the warning text.
warnings.filterwarnings("ignore", category=UserWarning, message="Palette")        # ignore the 'palette images expressed in bytes' warning
warnings.filterwarnings("ignore", category=RuntimeWarning, message="os.fork()")   # ignore the os.fork() multithreading warning
warnings.filterwarnings("ignore", category=UserWarning, message="Glyph")          # ignore warnings about emoji missing from the font used when displaying predictions

from transformers import AdamW
from transformers import AutoImageProcessor, ViTModel
from transformers import RobertaTokenizer, RobertaModel

#import easyocr
import numpy as np
import torch.nn.functional as F

pd.set_option('display.max_colwidth', None)  # show full cell contents; never truncate long text columns

In [None]:
class MultimodalDataset(Dataset):
    """Dataset pairing text encodings, OCR-text encodings, an image and a label.

    Args:
        text_encodings: mapping of field name -> indexable of per-sample values
            (e.g. a tokenizer output) for the original text.
        easyocr_encodings: same structure for the OCR-extracted text; its fields
            are exposed with an ``easyocr_`` prefix.
        image_paths: indexable of image file paths, one per sample.
        labels: indexable of integer class ids, one per sample.
        image_processor: callable invoked as
            ``image_processor(pil_image, return_tensors="pt")`` whose result
            exposes a ``'pixel_values'`` entry.
    """

    def __init__(self, text_encodings, easyocr_encodings, image_paths, labels, image_processor):
        self.text_encodings = text_encodings
        self.easyocr_encodings = easyocr_encodings
        self.image_paths = image_paths
        self.labels = labels
        self.image_processor = image_processor

    def __len__(self):
        """Number of samples, taken from the label sequence."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Build the sample dict for position ``idx``.

        Keys, in order: every text-encoding field, every OCR field prefixed with
        ``easyocr_``, then ``'image'`` and ``'label'`` (a ``torch.long`` tensor).
        """
        sample = {}
        for name, values in self.text_encodings.items():
            sample[name] = values[idx]
        for name, values in self.easyocr_encodings.items():
            sample['easyocr_' + name] = values[idx]

        # The image is loaded lazily, per item, and forced to 3-channel RGB.
        rgb_image = Image.open(self.image_paths[idx]).convert("RGB")
        processed = self.image_processor(rgb_image, return_tensors="pt")
        # squeeze(0) removes the leading size-1 axis the processor adds for a single image.
        sample['image'] = processed['pixel_values'].squeeze(0)
        sample['label'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample

class MultimodalModel(nn.Module):
    """Three-branch classifier: original text, OCR text and image.

    Each branch is encoded by its own pretrained backbone (two independent
    ``roberta-base`` encoders and one ViT). The first-token embedding of every
    branch is concatenated and passed through a single linear layer.

    Args:
        num_labels: number of output classes.
        image_model_name: checkpoint name/path for the ViT backbone. When
            omitted, the module-level ``IMAGE_MODEL`` constant is used, so
            existing ``MultimodalModel(num_labels)`` calls keep working.
    """

    def __init__(self, num_labels, image_model_name=None):
        super().__init__()
        if image_model_name is None:
            # Resolved at construction time, so the constant may be defined in a
            # later notebook cell than this class.
            image_model_name = IMAGE_MODEL
        self.text_model = RobertaModel.from_pretrained('roberta-base')
        self.ocr_text_model = RobertaModel.from_pretrained('roberta-base')
        self.image_model = ViTModel.from_pretrained(image_model_name)
        fused_size = (
            self.text_model.config.hidden_size
            + self.ocr_text_model.config.hidden_size
            + self.image_model.config.hidden_size
        )
        self.fc = nn.Linear(fused_size, num_labels)

    def forward(self, input_ids, attention_mask, ocr_input_ids, ocr_attention_mask, images):
        """Return class logits, one row per sample.

        Args:
            input_ids, attention_mask: tokenized original text.
            ocr_input_ids, ocr_attention_mask: tokenized OCR text.
            images: pixel values passed to the ViT backbone as ``pixel_values``.
        """
        # Original text: embedding at sequence position 0 (the [CLS]-style token).
        text_outputs = self.text_model(input_ids, attention_mask=attention_mask)
        text_features = text_outputs.last_hidden_state[:, 0, :]
        # OCR text: same, through the separate encoder.
        ocr_text_outputs = self.ocr_text_model(ocr_input_ids, attention_mask=ocr_attention_mask)
        ocr_text_features = ocr_text_outputs.last_hidden_state[:, 0, :]
        # Image: embedding at position 0 of the ViT output sequence.
        image_features = self.image_model(pixel_values=images).last_hidden_state[:, 0, :]
        # Fuse by concatenating along the feature axis, then classify.
        combined_features = torch.cat((text_features, ocr_text_features, image_features), dim=1)
        logits = self.fc(combined_features)
        return logits

def train(model, dataloader, optimizer, device, criterion):
    """Run one training epoch.

    Args:
        model: module called as ``model(input_ids, attention_mask,
            ocr_input_ids, ocr_attention_mask, images)`` and returning logits.
        dataloader: iterable of batch dicts with keys ``input_ids``,
            ``attention_mask``, ``easyocr_input_ids``, ``easyocr_attention_mask``,
            ``image`` and ``label``.
        optimizer: optimizer stepping ``model``'s parameters.
        device: device the batch tensors are moved to.
        criterion: loss called as ``criterion(logits, labels)``.

    Returns:
        ``(mean loss per batch, accuracy over all samples)``.
    """
    token_keys = ('input_ids', 'attention_mask', 'easyocr_input_ids', 'easyocr_attention_mask')
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        # Collapse any extra singleton axes into (batch, seq_len) while always
        # keeping the batch axis. A bare .squeeze() would also drop the batch
        # axis when a batch holds a single sample (e.g. the last, partial batch).
        input_ids, attention_mask, ocr_input_ids, ocr_attention_mask = (
            batch[key].reshape(batch[key].size(0), -1).to(device) for key in token_keys
        )
        images = batch['image'].to(device)
        labels = batch['label'].to(device)

        final_logits = model(input_ids, attention_mask, ocr_input_ids, ocr_attention_mask, images)
        final_loss = criterion(final_logits, labels)
        total_loss += final_loss.item()
        # The graph is used for exactly one backward pass, so it does not need
        # to be retained (retaining it only holds on to memory).
        final_loss.backward()
        optimizer.step()

        preds = final_logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    accuracy = correct / total
    return total_loss / len(dataloader), accuracy

def evaluate(model, dataloader, device, criterion, class_names=None):
    """Evaluate ``model`` on ``dataloader`` and print a classification report.

    Args:
        model: module called as ``model(input_ids, attention_mask,
            ocr_input_ids, ocr_attention_mask, images)`` and returning logits.
        dataloader: iterable of batch dicts with keys ``input_ids``,
            ``attention_mask``, ``easyocr_input_ids``, ``easyocr_attention_mask``,
            ``image`` and ``label``.
        device: device the batch tensors are moved to.
        criterion: loss called as ``criterion(logits, labels)``.
        class_names: display names for class ids ``0..len-1``. When omitted,
            the module-level ``label_encoder.classes_`` is used, so existing
            four-argument calls behave as before.

    Returns:
        ``(all_targets, all_predictions, avg_loss, accuracy)`` -- the first two
        are flat lists with one entry per evaluated sample.
    """
    if class_names is None:
        class_names = label_encoder.classes_
    class_names = list(class_names)

    token_keys = ('input_ids', 'attention_mask', 'easyocr_input_ids', 'easyocr_attention_mask')
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            # Collapse any extra singleton axes into (batch, seq_len) while
            # always keeping the batch axis. A bare .squeeze() would also drop
            # the batch axis when a batch holds a single sample.
            input_ids, attention_mask, ocr_input_ids, ocr_attention_mask = (
                batch[key].reshape(batch[key].size(0), -1).to(device) for key in token_keys
            )
            images = batch['image'].to(device)
            labels = batch['label'].to(device)
            final_logits = model(input_ids, attention_mask, ocr_input_ids, ocr_attention_mask, images)
            loss = criterion(final_logits, labels)
            total_loss += loss.item()

            preds = final_logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            all_predictions.extend(preds.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

    accuracy = correct / total
    avg_loss = total_loss / len(dataloader)
    # Pin the label set explicitly: otherwise the report raises when a class is
    # absent from both targets and predictions, because the number of observed
    # classes no longer matches the number of target names.
    report = classification_report(
        all_targets,
        all_predictions,
        labels=list(range(len(class_names))),
        target_names=class_names,
        digits=4,
    )
    print(report)
    return all_targets, all_predictions, avg_loss, accuracy

# Scrape text from images and save

In [None]:
# Load dataset - bc_ab
dataset_path = '/content/drive/My Drive/multimodal_classifier/data/WildFireCan-MMD.csv'
base_path2 = '/content/drive/My Drive/multimodal_classifier/data/images/'

# Columns that are not needed downstream.
columns_to_drop2 = [
    'tweet_id', 'img_id', 'posted_at', 'author_id', 'author_loc', 'author_name',
    'author_usrname', 'media_keys', 'urls', 'predicted_label', 'contains_personal_info',
]

# Read -> drop unneeded columns -> re-root image paths. The stored paths are
# backslash-separated; component 7 is kept and prefixed with the Drive image folder.
dataset = (
    pd.read_csv(dataset_path)
    .drop(columns=columns_to_drop2)
    .assign(image=lambda df: df['image'].apply(lambda raw: base_path2 + raw.split('\\')[7]))
)

# Integer-encode the class labels.
label_encoder = LabelEncoder()
dataset['label_encoded'] = label_encoder.fit_transform(dataset['label'])

# Placeholder column, filled in by the OCR pass below.
dataset['extracted_text'] = None

In [None]:
# English EasyOCR reader shared by every extract_text() call.
ocr_reader = easyocr.Reader(['en'])


def extract_text(image_path):
    """Return the text EasyOCR finds in the image at ``image_path``.

    The image is read as RGB and handed to the reader as a NumPy array.
    Element [1] of every detection is joined with single spaces; the result
    is the empty string when there are no detections.
    """
    rgb_pixels = np.array(Image.open(image_path).convert("RGB"))
    detections = ocr_reader.readtext(rgb_pixels)
    return ' '.join(detection[1] for detection in detections)


# Run OCR over every image, writing each result into `extracted_text`.
print("Starting text extraction...")

total_images = len(dataset)

progress = tqdm(dataset['image'], total=total_images, desc="Processing images")
for idx, image_path in enumerate(progress):
    dataset.at[idx, 'extracted_text'] = extract_text(image_path)

print("Text extraction completed.")

In [None]:
# Persist the table, including the filled-in `extracted_text` column, so it can
# be reloaded later without re-running OCR.
dataset.to_json('/content/drive/My Drive/multimodal_classifier/data/extracted.json')

In [None]:
# Display the freshly extracted text column.
dataset['extracted_text']

# Check that scrape was successful

In [None]:
# Reload the saved table from Drive to verify that the JSON round-trips.
kfgh = pd.read_json('/content/drive/My Drive/multimodal_classifier/data/extracted.json')

In [None]:
# OCR text stored for the row labelled 0.
kfgh['extracted_text'][0]

In [None]:
import matplotlib.image as mpimg

# Show the image belonging to row 0, to compare by eye with its extracted text.
img = mpimg.imread(kfgh['image'][0])

fig, ax = plt.subplots()
ax.imshow(img)
ax.set_axis_off()
plt.show()

In [None]:
# Count the samples for which OCR found NO text.
# The previous condition `msg == '' or None` only ever tested for the empty
# string (`or None` is always falsy); missing values are now counted as well,
# and whitespace-only strings are treated as empty.
extracted = kfgh['extracted_text']
ctr = int((extracted.isna() | extracted.astype(str).str.strip().eq('')).sum())
# Report against the frame that was actually inspected.
print(ctr, '/', len(kfgh))

In [None]:
# NOTE: the variable names below say "chars", but the values are
# whitespace-delimited WORD counts of each `extracted_text` entry.

# Step 1: words per row
char_counts = [len(text.split()) for text in kfgh['extracted_text']]

# Step 2 and 3: summary statistics
average_chars = sum(char_counts) / len(char_counts)
max_chars = max(char_counts)
min_chars = min(char_counts)

results = {
    'average_chars': average_chars,
    'max_chars': max_chars,
    'min_chars': min_chars
}

# Step 4: histogram of the word counts
fig, ax = plt.subplots(figsize=(10, 6))
counts, bins, patches = ax.hist(char_counts, bins=20, color='skyblue', edgecolor='black')

# Write each bar's count above the bar, centred on the bin.
half_width = (bins[1] - bins[0]) / 2
for left_edge, height in zip(bins, counts):
    ax.text(left_edge + half_width, height, int(height), ha='center', va='bottom')

# Dashed reference lines for the average, maximum and minimum.
for value, colour, legend_label in (
    (average_chars, 'red', f'Average: {average_chars:.2f}'),
    (max_chars, 'green', f'Max: {max_chars}'),
    (min_chars, 'orange', f'Min: {min_chars}'),
):
    ax.axvline(value, color=colour, linestyle='dashed', linewidth=1, label=legend_label)

ax.set_xlabel('Number of Words')
ax.set_ylabel('Frequency')
ax.legend(loc='upper right')
plt.show()

# Process data

In [None]:
RANDOM_STATE = 8  # seed used for shuffling and for the train/test split
IMAGE_MODEL = 'google/vit-base-patch16-384'  # checkpoint name passed to ViTModel / AutoImageProcessor .from_pretrained

In [None]:
# load dataset
dataset_path = '/content/drive/My Drive/multimodal_classifier/data/WildFireCan-MMD.csv'
dataset = pd.read_csv(dataset_path)

# The OCR pass adds `extracted_text` (and Drive-rooted image paths) in memory and
# saves the result to extracted.json. If the CSV does not carry that column,
# load the saved OCR output instead of failing with a KeyError below.
if 'extracted_text' not in dataset.columns:
    dataset = pd.read_json('/content/drive/My Drive/multimodal_classifier/data/extracted.json')

# Replace samples without extracted text with filler.
# The previous test `msg == '' or None` never matched missing values; NaN/None
# (e.g. empty CSV cells) and whitespace-only strings are now replaced as well,
# so the tokenizer always receives a real string.
ocr_text = dataset['extracted_text']
has_no_text = ocr_text.isna() | ocr_text.astype(str).str.strip().eq('')
dataset['extracted_text'] = ocr_text.astype(object).mask(has_no_text, 'Image does not contain text.')

# shuffle data
dataset = dataset.sample(frac=1, random_state=RANDOM_STATE).reset_index(drop=True)

# Split the data into train and test sets (80/20 split), stratifying by 'label'
train_df, test_df = train_test_split(dataset, test_size=0.2, random_state=RANDOM_STATE, stratify=dataset['label'])

In [None]:
# Text tokenizer (roberta-base) and the image preprocessor belonging to the
# ViT checkpoint named by IMAGE_MODEL.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
image_processor = AutoImageProcessor.from_pretrained(IMAGE_MODEL)

In [None]:
def tokenize_data(text):
    """Tokenize ``text`` with the shared tokenizer.

    Sequences are padded and truncated to exactly 200 tokens and returned as
    PyTorch tensors.
    """
    return tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=200,
        return_tensors='pt',
    )


# Encodings of the original text, train then test.
text_train_encodings, text_test_encodings = (
    tokenize_data(frame['text'].tolist()) for frame in (train_df, test_df)
)

# Encodings of the OCR-extracted text, train then test.
easyocr_train_encodings, easyocr_test_encodings = (
    tokenize_data(frame['extracted_text'].tolist()) for frame in (train_df, test_df)
)

# Image paths per split.
train_imgs, test_imgs = (frame['image'].tolist() for frame in (train_df, test_df))

# Fit the label encoder on the training labels only, then apply it to both splits.
label_encoder = LabelEncoder().fit(train_df['label'].values)
train_labels = label_encoder.transform(train_df['label'].values)
test_labels = label_encoder.transform(test_df['label'].values)

# Assemble the two datasets.
train_dataset = MultimodalDataset(
    text_train_encodings, easyocr_train_encodings, train_imgs, train_labels,
    image_processor=image_processor,
)
test_dataset = MultimodalDataset(
    text_test_encodings, easyocr_test_encodings, test_imgs, test_labels,
    image_processor=image_processor,
)

In [None]:
# Settings shared by both loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=16, num_workers=12, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Train and test

In [None]:
# Seed torch's global RNG so the new classifier head's initialisation (and, by
# default, DataLoader shuffling) is repeatable from run to run.
torch.manual_seed(RANDOM_STATE)

num_labels = len(label_encoder.classes_)
model = MultimodalModel(num_labels)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
# transformers' own AdamW is deprecated in favour of torch.optim.AdamW. eps and
# weight_decay are set explicitly to the defaults of the transformers
# implementation (1e-6 and 0.0) so the optimisation itself is unchanged.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Directory to save the best model
model_save_path = '/content/drive/My Drive/multimodal_classifier/model/early_fusion/model_3-head'
# exist_ok=True makes this a single call that is safe to re-run.
os.makedirs(model_save_path, exist_ok=True)
# Best accuracy seen so far; the training loop checkpoints whenever it improves.
best_test_accuracy = 0.0

In [None]:
# train model
num_epochs = 15
for epoch in range(num_epochs):
    train_loss, train_accuracy = train(model, train_loader, optimizer, device, criterion)
    # evaluate() returns (targets, predictions, avg_loss, accuracy); unpacking
    # it into two names raised "too many values to unpack". Keep the last two.
    test_loss, test_accuracy = evaluate(model, test_loader, device, criterion)[-2:]
    print(f'Epoch {epoch + 1}/{num_epochs}')
    print(f'Training Acc:  {train_accuracy:.4f}')
    print(f'Training Loss: {train_loss:.4f}')
    print(f'Test Acc:      {test_accuracy:.4f}')
    print(f'Test Loss:     {test_loss:.4f}')

    # Save the best model based on accuracy on the test split.
    # NOTE(review): the checkpoint is selected on the same split that is
    # reported afterwards, so that figure is optimistic -- consider holding out
    # a separate validation split for model selection.
    if test_accuracy > best_test_accuracy:
        best_test_accuracy = test_accuracy
        torch.save(model.state_dict(), os.path.join(model_save_path, 'model.bin'))
        tokenizer.save_pretrained(model_save_path)
        print(f"Best model saved with test accuracy: {best_test_accuracy:.4f}")

In [None]:
# load saved model and evaluate
model_save_path = '/content/drive/My Drive/multimodal_classifier/model/early_fusion/model_3-head'
# map_location lets a checkpoint saved on GPU be restored in a CPU-only session
# (and vice versa) by remapping its tensors onto the current device.
model.load_state_dict(torch.load(os.path.join(model_save_path, 'model.bin'), map_location=device))

In [None]:
# evaluate() returns (targets, predictions, avg_loss, accuracy); only the first
# two are kept here, for the confusion matrix below.
all_targets, all_predictions, _, _ = evaluate(model, test_loader, device, criterion)

In [None]:
# show confusion matrix
# Pin the label set to every encoded class id so the matrix always has one
# row/column per display label, even if a class is absent from both the
# targets and the predictions.
class_ids = list(range(len(label_encoder.classes_)))
cm = confusion_matrix(all_targets, all_predictions, labels=class_ids)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=label_encoder.classes_)
disp.plot(xticks_rotation=90);