In [None]:
# Download a file from Google Drive by its file id using the `gdown` CLI
# (presumably the hateful_memes.zip archive unpacked in the next cell — confirm).
!gdown "https://drive.google.com/uc?id=1Fqd6UDB02avLsg5g9nZ5AUGzo8bysQHP"

In [None]:
# Unpack the archive into the working directory.
# `-o` overwrites existing files without asking, so re-running this cell cannot
# stall on unzip's interactive "replace?" prompt; `-q` suppresses the per-file
# listing that would otherwise flood the cell output.
!unzip -o -q "hateful_memes.zip" -d .

In [None]:
import json
import os

import torch
import torchvision
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import AutoModel, AutoTokenizer
from transformers import (
    BertConfig,
    BertModel,
    Trainer,
    TrainingArguments,
    get_cosine_schedule_with_warmup,
)


In [None]:
# Emotion label names; a label's position in this list is its class index.
ARTEMIS_EMOTIONS = [
    'amusement',
    'awe',
    'contentment',
    'excitement',
    'anger',
    'disgust',
    'fear',
    'sadness',
    'something else',
]

# Label name -> class index, following the order of ARTEMIS_EMOTIONS.
EMOTION_TO_IDX = dict(zip(ARTEMIS_EMOTIONS, range(len(ARTEMIS_EMOTIONS))))

# Class index -> label name (the inverse of EMOTION_TO_IDX).
IDX_TO_EMOTION = dict(enumerate(ARTEMIS_EMOTIONS))

In [None]:
class MoodBoardDataset(Dataset):
    """Image / emotion / caption samples described by a JSON Lines file.

    Every non-blank line of ``jsonl_file`` is parsed as one JSON object. When
    an item is fetched, its ``'img'`` value is joined onto ``img_dir`` to
    locate the image, and its ``'emotion'`` and ``'caption'`` values are
    returned unchanged.

    Args:
        jsonl_file: Path to the annotation file (one JSON object per line).
        img_dir: Directory that each record's ``'img'`` path is relative to.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.
    """

    def __init__(self, jsonl_file, img_dir, transform=None):
        # `img_annots` is the name __len__/__getitem__ read. The constructor
        # used to store the records only as `img_annotations`, so every access
        # raised AttributeError. Both names now point at the same list.
        self.img_annots = self._load_annotations(jsonl_file)
        self.img_annotations = self.img_annots
        self.img_dir = img_dir
        self.transform = transform

    def _load_annotations(self, jsonl_file):
        """Return the list of JSON objects stored one-per-line in ``jsonl_file``."""
        # Explicit UTF-8 keeps parsing independent of the platform default
        # encoding; blank lines (e.g. a trailing newline) are skipped instead
        # of crashing json.loads.
        with open(jsonl_file, 'r', encoding='utf-8') as file:
            return [json.loads(line) for line in file if line.strip()]

    def __len__(self):
        """Number of annotation records."""
        return len(self.img_annots)

    def __getitem__(self, x):
        """Return ``{"image", "emotion", "caption"}`` for the record at index ``x``."""
        img_info = self.img_annots[x]
        img_path = os.path.join(self.img_dir, img_info['img'])
        # convert() yields an independent in-memory image, so the underlying
        # file handle can be released as soon as the with-block exits.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return {
            "image": image,
            "emotion": img_info['emotion'],
            "caption": img_info['caption'],
        }

In [None]:
tokenizer = AutoTokenizer.from_pretrained("ayoubkirouane/BERT-Emotions-Classifier")

# Built once here rather than on every batch: resize to 224x224 and convert
# the PIL image to a float tensor.
_IMAGE_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def data_collate_fn(batch):
    """Collate ``MoodBoardDataset`` samples into one batch of model inputs.

    Args:
        batch: List of sample dicts with ``'image'``, ``'emotion'`` and
            ``'caption'`` keys.

    Returns:
        dict with
          * ``"images"``: the per-sample image tensors stacked along a new
            leading batch dimension,
          * ``"input_ids"`` / ``"attention_mask"``: the captions tokenized
            together with padding and truncation,
          * ``"labels"``: ``torch.long`` class indices looked up in
            ``EMOTION_TO_IDX``.

    Raises:
        KeyError: if a sample's emotion is not a key of ``EMOTION_TO_IDX``.
    """
    # A dataset constructed with its own `transform` may already yield tensors;
    # those are passed through because ToTensor only accepts PIL images/arrays.
    images = torch.stack(
        [
            sample['image'] if isinstance(sample['image'], torch.Tensor)
            else _IMAGE_TRANSFORM(sample['image'])
            for sample in batch
        ],
        dim=0,
    )

    # Use the notebook-wide mapping so label indices cannot drift from it.
    labels = torch.tensor(
        [EMOTION_TO_IDX[sample['emotion']] for sample in batch],
        dtype=torch.long,
    )

    # The dataset exposes the text under 'caption'.
    captions = [sample['caption'] for sample in batch]

    # Tokenized here so the result can be passed straight into the model's
    # forward(); the attention mask is kept so padding can be told apart from
    # real tokens downstream.
    tokenized_texts = tokenizer(captions, padding=True, truncation=True, return_tensors='pt')

    return {
        "images": images,
        "input_ids": tokenized_texts['input_ids'],
        "attention_mask": tokenized_texts['attention_mask'],
        "labels": labels,
    }

In [None]:
# Training split. No dataset-level transform is passed, so samples carry PIL
# images; resizing, tensor conversion and tokenization happen per batch in
# data_collate_fn (the collate function defined in this notebook).
train_dataset = MoodBoardDataset(jsonl_file='data/train.jsonl', img_dir='data/')
train_dataloader = DataLoader(train_dataset, batch_size=10, shuffle=True, collate_fn=data_collate_fn)

In [None]:
# Validation split.
# NOTE(review): this cell reads 'data/val.jsonl' while the Trainer cell reads
# 'data/dev.jsonl' — confirm which annotation file actually exists.
val_dataset = MoodBoardDataset(jsonl_file='data/val.jsonl', img_dir='data/')
# Wrap the dataset created on the line above and use the notebook's
# data_collate_fn. Evaluation does not benefit from a random order, so
# shuffling is off.
val_dataloader = DataLoader(val_dataset, batch_size=10, shuffle=False, collate_fn=data_collate_fn)

In [None]:
class MultiModalEarlyFusionBertModel(BertModel):
    """Early-fusion classifier over CNN image tokens and BERT text tokens.

    Image features from a truncated CNN are projected to ``config.hidden_size``
    and concatenated, along the sequence axis, with the penultimate hidden
    states of a pretrained BERT encoder. A learned modality embedding is added
    to every token first, then the sequence is max-pooled and classified.

    NOTE(review): because this subclasses ``BertModel``, ``super().__init__``
    also builds a BERT stack from ``config`` that ``forward`` never uses (only
    ``self.bert_model`` is called). The base class is left unchanged to keep
    the public type stable.

    Args:
        config: ``BertConfig`` whose ``hidden_size`` sizes the projection,
            modality-embedding and classification layers. It must match the
            hidden size of the pretrained text encoder loaded below.
        num_classes: Number of output classes.
        cnn_model: CNN backbone; its last two child modules are dropped so the
            spatial feature map is exposed (for a torchvision ResNet those are
            the average pool and the ``fc`` head).
    """

    # Modality ids fed to `modality_embedding`.
    NUM_MODALITIES = 2
    IMAGE_MODALITY_ID = 0
    TEXT_MODALITY_ID = 1

    def __init__(self, config, num_classes, cnn_model):
        super().__init__(config)

        # Use the backbone that was passed in, not a module-level global.
        self.cnn_model = cnn_model
        # f_cnn runs the image through every layer before the final pooling
        # and linear head.
        self.f_cnn = torch.nn.Sequential(*list(self.cnn_model.children())[:-2])
        # Channel width of the feature map: read from the backbone's `fc` head
        # when it has one, otherwise fall back to 512 (the previous constant).
        cnn_channels = getattr(getattr(cnn_model, 'fc', None), 'in_features', 512)
        self.image_layer = nn.Linear(cnn_channels, config.hidden_size)

        self.bert_model = AutoModel.from_pretrained("ayoubkirouane/BERT-Emotions-Classifier", output_hidden_states=True)

        # One learned vector per modality (image, text).
        self.modality_embedding = nn.Embedding(self.NUM_MODALITIES, config.hidden_size)
        self.fusion_linear = nn.Linear(config.hidden_size, num_classes)

    def forward(self, input_ids=None, attention_mask=None, images=None, labels=None):
        """Classify a batch of (image, caption) pairs.

        Args:
            input_ids: Token-id tensor for the captions.
            attention_mask: Optional mask for ``input_ids`` (non-zero = real
                token). When given, padded positions are excluded from pooling.
            images: Image batch accepted by the CNN backbone.
            labels: Optional class indices; when given, a cross-entropy loss
                is added to the output.

        Returns:
            dict with ``'logits'`` and, when ``labels`` is given, ``'loss'``.
        """
        # Feature map is assumed to be (B, C, H, W). Flatten the spatial dims
        # and move channels last to get (B, H*W, C) image tokens. A bare
        # .view(B, -1, C) would reinterpret memory without transposing and mix
        # channel and spatial values.
        feature_map = self.f_cnn(images)
        image_tokens = feature_map.flatten(2).transpose(1, 2)
        # Project to the text hidden size so both modalities can be concatenated.
        image_tokens = self.image_layer(image_tokens)

        text_outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        # Penultimate layer of the encoder's hidden states.
        text_tokens = text_outputs.hidden_states[-2]

        # Tag each token with its modality: distinct ids for image and text.
        image_modality_ids = torch.full(
            image_tokens.shape[:2], self.IMAGE_MODALITY_ID,
            dtype=torch.long, device=image_tokens.device,
        )
        text_modality_ids = torch.full(
            text_tokens.shape[:2], self.TEXT_MODALITY_ID,
            dtype=torch.long, device=text_tokens.device,
        )
        fused_image_embeddings = image_tokens + self.modality_embedding(image_modality_ids)
        fused_text_embeddings = text_tokens + self.modality_embedding(text_modality_ids)

        if attention_mask is not None:
            # Push padded text positions to the dtype minimum so they can never
            # win the max-pool below. Image tokens are always present, so the
            # pooled values stay finite.
            padding = attention_mask.to(torch.bool).logical_not().unsqueeze(-1)
            fused_text_embeddings = fused_text_embeddings.masked_fill(
                padding, torch.finfo(fused_text_embeddings.dtype).min
            )

        # (B, n_image_tokens + n_text_tokens, hidden)
        fused_embeddings = torch.cat((fused_image_embeddings, fused_text_embeddings), dim=1)

        # Max over the token axis -> (B, hidden)
        pooled_embeddings, _ = torch.max(fused_embeddings, dim=1)
        logits = self.fusion_linear(pooled_embeddings)

        outputs = {'logits': logits}
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            outputs['loss'] = loss_fct(logits, labels.view(-1))
        return outputs

# The config supplies hidden_size for the model's projection, modality and
# classification layers.
# NOTE(review): it is loaded from 'bert-base-uncased' while the text encoder
# inside the model is loaded from another checkpoint — their hidden sizes must
# agree; confirm.
config = BertConfig.from_pretrained('bert-base-uncased')
resnet_model = torchvision.models.resnet18(pretrained=True)
# One output class per entry in ARTEMIS_EMOTIONS, so the classifier head cannot
# drift from the label list.
model = MultiModalEarlyFusionBertModel(config, num_classes=len(ARTEMIS_EMOTIONS), cnn_model=resnet_model)

In [None]:
from sklearn.metrics import accuracy_score
import numpy as np
def compute_metrics(p):
    """Compute and print classification accuracy.

    Args:
        p: Object exposing ``predictions`` (per-class scores, argmax-ed over
            axis 1) and ``label_ids`` (reference class indices) — presumably
            the evaluation result handed over by the Trainer.

    Returns:
        ``{"accuracy": <accuracy_score of label_ids vs. predicted classes>}``
    """
    predicted_classes = np.argmax(p.predictions, axis=1)
    metrics = {"accuracy": accuracy_score(p.label_ids, predicted_classes)}
    print("Accuracy:", metrics["accuracy"])
    return metrics

In [None]:
# Training configuration. The Trainer builds its own optimizer and cosine
# learning-rate schedule from these arguments (learning_rate, weight_decay,
# lr_scheduler_type), so no hand-built optimizer/scheduler is created here.
training_args = TrainingArguments(
    num_train_epochs=10,
    learning_rate=5e-6,
    weight_decay=0.01,
    logging_dir='./logs',
    # Relative path, like logging_dir, instead of a directory at the
    # filesystem root.
    output_dir='./results',
    lr_scheduler_type="cosine",
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy` — adjust if the installed version rejects it.
    evaluation_strategy="epoch",
    # Samples are plain dicts consumed by data_collate_fn; do not let the
    # Trainer prune fields based on the model's forward() signature.
    remove_unused_columns=False,
    # Must match the key emitted by data_collate_fn and the `labels` argument
    # of the model's forward(); otherwise evaluation sees no labels and
    # compute_metrics has nothing to score.
    label_names=["labels"],
)

trainer = Trainer(
    model=model,  # the MultiModalEarlyFusionBertModel instance built earlier
    args=training_args,
    train_dataset=MoodBoardDataset(jsonl_file='data/train.jsonl', img_dir='data/'),  # training split
    eval_dataset=MoodBoardDataset(jsonl_file='data/dev.jsonl', img_dir='data/'),    # validation split
    data_collator=data_collate_fn,
    compute_metrics=compute_metrics,
)

trainer.train()
trainer.evaluate()