# 1. Import Libraries and Load Dataset

In [1]:
%%capture
!pip install gdown

In [2]:
import json
import gdown
train_json = json.load(open('/kaggle/input/dsc24-vimmsd/vimmsd-train.json', encoding='utf-8'))
dev_json = json.load(open('/kaggle/input/dsc24-vimmsd/vimmsd-public-test.json', encoding='utf-8'))

In [48]:
import torch
from torch.nn.functional import normalize
from tqdm.notebook import tqdm

In [3]:

visual_embeds = torch.load('/kaggle/input/dsc-visual-embeddings/visual_embeds.pt')
# img_w = torch.load('/kaggle/input/lovecat-beitv2-b-p/beitv2-b-p.pt') # already-saved features
len(visual_embeds)

12218

In [4]:
visual_embeds['ac7931bb887ad853b41675f07595bf04469970d1b099ffc8806a4ceaac7d7940.jpg'].shape

torch.Size([100, 1024])

# Config

In [61]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from transformers import BertTokenizer, AdamW, get_linear_schedule_with_warmup
from transformers import VisualBertModel, VisualBertConfig

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score

In [6]:
class Config:
   def __init__(self, random_seed = 42, max_len = 450, n_epochs = 2, batch_size = 8, lrate=2.5e-5,
                n_warmup_steps=400, warmup_ratio=0.05,
                visual_embedding_dim=1024, visual_model_name = 'uclanlp/visualbert-vqa-coco-pre',
                classes = ['not-sarcasm', 'text-sarcasm', 'image-sarcasm', 'multi-sarcasm']):
       self.random_seed = random_seed
       self.max_len = max_len
       self.n_classes = len(classes)
       self.classes = classes
       self.n_epochs = n_epochs
       self.batch_size = batch_size
       self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
       self.n_warmup_steps = n_warmup_steps
       self.warmup_ratio = warmup_ratio
       self.visual_embedding_dim = visual_embedding_dim
       self.visual_model_name = visual_model_name
       self.n_training_steps = n_epochs * 9724 // batch_size
       self.lrate = lrate

In [7]:
import random

import numpy as np

import torch

def set_SEED():
    SEED = 42
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
set_SEED()

In [9]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Dataset

In [41]:
class MemesDataset(Dataset):
    '''Wrap the tokenization process in a PyTorch Dataset, along with converting the labels to tensors'''
    def __init__(self, data: pd.DataFrame, tokenizer: BertTokenizer, config: Config, visual_embeds):
        self.tokenizer = tokenizer
        self.data = data
        self.max_len = config.max_len
        self.visual_embeds = visual_embeds
        self.classes = config.classes

        # One-hot encode the labels
        for class_name in self.classes:
            self.data[class_name] = self.data['label'].apply(lambda x: 1 if x == class_name else 0)
        self.data[self.classes] = self.data[self.classes].astype('int8')
        print(self.data[self.classes].head())
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):

        data_row = self.data.iloc[index]
        text = data_row.caption
        labels = data_row[self.classes].values.astype(int)
        image_id = data_row['image_id']

        tokens = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        input_ids = tokens["input_ids"].flatten()
        attention_mask = tokens["attention_mask"].flatten()

        visual_embedding = self.visual_embeds[image_id].to('cpu')
        visual_attention_mask = torch.ones(visual_embedding.shape[:-1], dtype=torch.float)
        visual_token_type_ids = torch.ones(visual_embedding.shape[:-1], dtype=torch.long)

        return dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            visual_embedding=visual_embedding,
            visual_attention_mask=visual_attention_mask,
            visual_token_type_ids=visual_token_type_ids,
            labels=torch.tensor(labels, dtype=torch.long)
        )

## Prepare datasets

In [11]:
emoji_file_path = '/kaggle/input/datasets-preprocesing/emoji_to_vietnamese.json'
stopword_path = '/kaggle/input/datasets-preprocesing/vietnamese-stopwords.txt'

def load_resources(stopword_path, emoji_file_path):
    # ƒê·ªçc stopword t·ª´ file txt
    with open(stopword_path, 'r', encoding='utf-8') as f:
        stopwords = set(f.read().splitlines())

    # ƒê·ªçc emoji t·ª´ file JSON
    with open(emoji_file_path, 'r', encoding='utf-8') as emoji_file:
        emoji_dict = json.load(emoji_file)

    return stopwords, emoji_dict

stopwords, emoji_dict = load_resources(stopword_path, emoji_file_path)

In [12]:
import pandas as pd

train_df = pd.DataFrame(train_json).T
test_df = pd.DataFrame(dev_json).T # public test

In [13]:
train_df['image_id'] = train_df['image'].astype(str)
test_df['image_id'] = test_df['image'].astype(str)

train_df['caption'] = train_df['caption'].astype(str)
test_df['caption'] = test_df['caption'].astype(str)

train_df['label'] = train_df['label'].astype(str)
test_df['label'] = test_df['label'].astype(str)

train_df.drop(columns=['image'], inplace=True)
test_df.drop(columns=['image'], inplace=True)

train_df.head()

Unnamed: 0,caption,label,image_id
0,C√¥ ·∫•y tr√™n m·∫°ng vs c√¥ ·∫•y ngo√†i ƒë·ªùi =))),multi-sarcasm,8ae451edcd8ebf697f8763ece249115813149c55733bf8...
1,Ng∆∞·ªùi t√¢m linh giao ti·∫øp v·ªõi ng∆∞·ªùi th·ª±c t·∫ø :))),not-sarcasm,35370ffd6c791d6f8c4ab3dd4363ed468fab41e4824ee9...
2,H√¨nh nh∆∞ TrƒÉng h√¥m nay ƒë·∫πp qu√° m·ªçi ng∆∞·ªùi ·∫°! üòÉ ...,multi-sarcasm,316fdd1477725b9fb1a55015ac06b68b92b50bd4303e08...
3,M·ªåI NG∆Ø·ªúI NGHƒ® SAO V·ªÄ PH√ÅT BI·ªÇU C·ª¶A SHARK VI·ªÜT...,not-sarcasm,8a0f34e0e30e4e5cfb306933c1d25fa801a5da78646b59...
4,2 tay hai n√†ng ch·ª© vi·ªác g√¨ ph·∫£i l·ªá hai h√†ng,multi-sarcasm,e517a5e95d1065886a7c815e82fe254381d4f9f4b244d4...


### Features enrichment

In [14]:
!gdown 1q7_-PEQQ6IR3Ortz45vEiSiOnweRJ40H # cap train
!gdown 1wldmw8IJgX-nK2_yfo8fLx55KfWGIZhp # cap dev 

Downloading...
From: https://drive.google.com/uc?id=1q7_-PEQQ6IR3Ortz45vEiSiOnweRJ40H
To: /kaggle/working/vi_train_captions.json
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11.3M/11.3M [00:00<00:00, 217MB/s]
Downloading...
From: https://drive.google.com/uc?id=1wldmw8IJgX-nK2_yfo8fLx55KfWGIZhp
To: /kaggle/working/vi_dev_captions.json
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1.14M/1.14M [00:00<00:00, 126MB/s]


In [15]:
!gdown 1AnM0RUMfyGYWaiUgafufEKMjB8zo5dUt # object reg train
!gdown 1sk2vJutRJLCUwQ6ZKJQeYFKpvdkjUiBs # object reg dev

Downloading...
From: https://drive.google.com/uc?id=1AnM0RUMfyGYWaiUgafufEKMjB8zo5dUt
To: /kaggle/working/objects-recognition-train.json
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2.49M/2.49M [00:00<00:00, 178MB/s]
Downloading...
From: https://drive.google.com/uc?id=1sk2vJutRJLCUwQ6ZKJQeYFKpvdkjUiBs
To: /kaggle/working/objects-recognition-dev.json
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 317k/317k [00:00<00:00, 106MB/s]


In [16]:
!gdown 1nh3y-lXq2CEc_rwzeTqGIU69VZA4eVEn # OCR dev
!gdown 1YSn-dWwprc0nhOgRUIj5aFPaW9lxKWZT

Downloading...
From: https://drive.google.com/uc?id=1nh3y-lXq2CEc_rwzeTqGIU69VZA4eVEn
To: /kaggle/working/ocr-results-dev.json
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 669k/669k [00:00<00:00, 98.8MB/s]
Downloading...
From: https://drive.google.com/uc?id=1YSn-dWwprc0nhOgRUIj5aFPaW9lxKWZT
To: /kaggle/working/ocr-results-train.json
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4.06M/4.06M [00:00<00:00, 204MB/s]


In [17]:

# Open JSON files with utf-8 encoding to handle non-ASCII characters
with open('/kaggle/working/vi_train_captions.json', encoding='utf-8') as f:
    cap_train = json.load(f)

with open('/kaggle/working/vi_dev_captions.json', encoding='utf-8') as f:
    cap_test = json.load(f)

with open('/kaggle/working/objects-recognition-train.json', encoding='utf-8') as f:
    obj_train = json.load(f)

with open('/kaggle/working/objects-recognition-dev.json', encoding='utf-8') as f:
    obj_test = json.load(f)

with open('/kaggle/working/ocr-results-dev.json', encoding='utf-8') as f:
    ocr_test = json.load(f)

with open('/kaggle/working/ocr-results-train.json', encoding='utf-8') as f:
    ocr_train = json.load(f)

In [18]:
def json_to_df(json):
    df = pd.DataFrame(json)
    df['image_id'] = df['image'].astype(str)
    df.drop(columns=['image'], inplace=True)

    return df

In [19]:
for item in ocr_train:
    item["OCR"] = ", ".join(item["OCR"])
for item in ocr_test:
    item["OCR"] = ", ".join(item["OCR"])

In [20]:
cap_test_df = json_to_df(cap_test)

cap_train_df = json_to_df(cap_train)

obj_train_df = json_to_df(obj_train)

obj_test_df = json_to_df(obj_test)

ocr_train_df = json_to_df(ocr_train)

ocr_test_df = json_to_df(ocr_test)

In [21]:
obj_train_df['object_recognition'] = obj_train_df['object_recognition'].apply(lambda x: "Trong h√¨nh c√≥ " + x + ". " if len(x) > 0 else "")
obj_test_df['object_recognition'] = obj_test_df['object_recognition'].apply(lambda x: "Trong h√¨nh c√≥ " + x + ". " if len(x) > 0 else "")
obj_test_df.head()

Unnamed: 0,object_recognition,image_id
0,"Trong h√¨nh c√≥ ng∆∞·ªùi, ƒë·ªì v·∫≠t, phim ho·∫°t h√¨nh, ƒë...",2d06d8c77c741d001916199346cc112847e6bcf61b3dce...
1,"Trong h√¨nh c√≥ truy·ªán, ng∆∞·ªùi, d·∫£i, phim ho·∫°t h√¨...",c981f23fc77cebd06ea872ea2c0ff6ec43a9d2517366ed...
2,"Trong h√¨nh c√≥ s·ªë, ghi, ch·ªØ.",342c9a8f91adeacde0f2c26dee3e6b86861b43e948d10b...
3,"Trong h√¨nh c√≥ di·ªÅu, ng∆∞·ªùi, vƒÉn b·∫£n.",2aa95c65c0a6444caff0657ed21e27fbc403af1727749a...
4,"Trong h√¨nh c√≥ suv, xe th·ªÉ, qu·∫£ng c√°o, xe h∆°i, ...",9d6ebb26087b8d6051f77ef7cbf3e9a0d750baa41b45d7...


In [22]:
ocr_train_df['OCR'] = ocr_train_df['OCR'].apply(lambda row: "Ch·ªØ trong h√¨nh l√† " + row + ". " if len(row) > 0 else "")
ocr_test_df['OCR'] = ocr_test_df['OCR'].apply(lambda row: "Ch·ªØ trong h√¨nh l√† " + row + ". " if len(row) > 0 else "")

# cap_train_df['caption'] = cap_train_df['caption'].apply(lambda row: row[:150] if len(row) > 150 else row)
# cap_test_df['caption'] = cap_test_df['caption'].apply(lambda row: row[:150] if len(row) > 150 else row)

In [23]:
def enrich(df1, df2, add_field):
    temp = df2.set_index('image_id')
    
    df1['caption'] = df1.apply(
        lambda row: row['caption'] + ' ' + temp.loc[row['image_id'], add_field]
        if row['image_id'] in temp.index else row['caption'], axis=1
    )

    return df1


In [24]:
train_df['caption'] = train_df['caption'].apply(lambda x: x[:150] if len(x) > 150 else x)
test_df['caption'] = test_df['caption'].apply(lambda x: x[:150] if len(x) > 150 else x)

train_df = enrich(train_df, ocr_train_df, 'OCR')
test_df = enrich(test_df, ocr_test_df, 'OCR')

train_df = enrich(train_df, cap_train_df, 'caption')
test_df = enrich(test_df, cap_test_df, 'caption')

# train_df = enrich(train_df, obj_train_df, 'object_recognition')
# test_df = enrich(test_df, obj_test_df, 'object_recognition')

In [25]:
train_df.head()

Unnamed: 0,caption,label,image_id
0,C√¥ ·∫•y tr√™n m·∫°ng vs c√¥ ·∫•y ngo√†i ƒë·ªùi =))) B·ª©c ·∫£...,multi-sarcasm,8ae451edcd8ebf697f8763ece249115813149c55733bf8...
1,Ng∆∞·ªùi t√¢m linh giao ti·∫øp v·ªõi ng∆∞·ªùi th·ª±c t·∫ø :))...,not-sarcasm,35370ffd6c791d6f8c4ab3dd4363ed468fab41e4824ee9...
2,H√¨nh nh∆∞ TrƒÉng h√¥m nay ƒë·∫πp qu√° m·ªçi ng∆∞·ªùi ·∫°! üòÉ ...,multi-sarcasm,316fdd1477725b9fb1a55015ac06b68b92b50bd4303e08...
3,M·ªåI NG∆Ø·ªúI NGHƒ® SAO V·ªÄ PH√ÅT BI·ªÇU C·ª¶A SHARK VI·ªÜT...,not-sarcasm,8a0f34e0e30e4e5cfb306933c1d25fa801a5da78646b59...
4,2 tay hai n√†ng ch·ª© vi·ªác g√¨ ph·∫£i l·ªá hai h√†ng H...,multi-sarcasm,e517a5e95d1065886a7c815e82fe254381d4f9f4b244d4...


### Text preprocessing

In [26]:
import re
def preprocess_text(text):
    def remove_stopwords(text):
        return text
    def replace_emojis(text):
        for emoji, description in emoji_dict.get('emoji', {}).items():
            text = text.replace(emoji, description)  # Thay th·∫ø emoji b·∫±ng m√¥ t·∫£
        return text

    def replace_emoticons(text):
        for emoticon, meaning in emoji_dict.get('bi·ªÉu_t∆∞·ª£ng', {}).items():
            emoticon_pattern = re.escape(emoticon) + r"{1,}"
            text = re.sub(emoticon_pattern, meaning, text)
        return text

    def normalize_text(text):
        text = text.lower()  # Chuy·ªÉn th√†nh ch·ªØ th∆∞·ªùng
        text = re.sub(r'(?<=\w)[\/\.\-\_,\\](?=\w)', '', text)  # Lo·∫°i b·ªè d·∫•u ch·∫•m ho·∫∑c g·∫°ch n·ªëi trong t·ª´
        return text

    text = replace_emojis(text)       # Thay th·∫ø emoji
    text = replace_emoticons(text)    # Thay th·∫ø bi·ªÉu c·∫£m
    text = normalize_text(text)       # Chu·∫©n h√≥a vƒÉn b·∫£n
    text = remove_stopwords(text)
    return text

In [27]:
train_df['caption'] = train_df['caption'].astype(str)
test_df['caption'] = test_df['caption'].astype(str)

In [28]:
train_df['caption'] = train_df['caption'].apply(preprocess_text)
test_df['caption'] = test_df['caption'].apply(preprocess_text)

In [29]:
train_df['caption'].iloc[2003]

'may m√† g·∫∑p ƒë∆∞·ª£c t√¥i ch·ªØ trong h√¨nh l√† tr·ªùl olll_, l√†m sao th·∫ø n√†y?!?, c·∫≠u b·∫°n n√†y ƒëang v·∫Ω d·ªü, th·ªâ lƒÉn ƒë√πng ra (o gi·∫≠t, s√πl b·ªçt m√©p!, nguy qu√°!, ƒë·ªÉ t√¥l gl√∫p!, 1.  h√¨nh ·∫£nh m√¥ t·∫£ m·ªôt d·∫£i truy·ªán tranh c√≥ m·ªôt nh√¢n v·∫≠t ho·∫°t h√¨nh v√† m·ªôt ng∆∞·ªùi ƒëang v·∫Ω tr√™n gi√° v·∫Ω. nh√¢n v·∫≠t ho·∫°t h√¨nh ƒëang c·∫ßm m·ªôt c√¢y c·ªç v·∫Ω tr√™n m·ªôt tay v√† m·ªôt c√¢y c·ªç tr√™n tay kia, trong khi ng∆∞·ªùi ƒë√≥ ƒëang ng·ªìi tr√™n gi√° v·∫Ω v·ªõi n·ª• c∆∞·ªùi tr√™n m√¥i. n·ªôi dung c·ªßa h√¨nh ·∫£nh c√≥ th·ªÉ mang t√≠nh ch√¢m bi·∫øm, v√¨ ng∆∞·ªùi ƒë√≥ ƒëang c·ªë g·∫Øng tr√™u ch·ªçc n·ªó l·ª±c v·∫Ω c·ªßa ng∆∞·ªùi kh√°c. tuy nhi√™n, n√≥ c≈©ng c√≥ th·ªÉ mang t√≠nh gi·∫£i tr√≠, v√¨ ng∆∞·ªùi ƒë√≥ d∆∞·ªùng nh∆∞ ƒëang c√≥ m·ªôt kho·∫£ng th·ªùi gian vui v·∫ª khi v·∫Ω tr√™n gi√° v·∫Ω. nh√¨n chung, n·ªôi dung c·ªßa h√¨nh ·∫£nh c√≥ th·ªÉ ƒë∆∞·ª£c coi l√† ch√¢m bi·∫øm ho·∫∑c gi·∫£i tr√≠, t√πy thu·ªôc v√†o quan ƒëi·ªÉm c·ªßa ng∆∞·ªùi ƒë√≥.'

In [30]:
import pandas as pd

from sklearn.model_selection import train_test_split

X = train_df.drop(columns=['label'])  # Features
y = train_df['label']  # Labels

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, stratify=y, random_state=42)

train_df = pd.concat([X_train, y_train], axis=1)
val_df = pd.concat([X_test, y_test], axis=1)

In [31]:
print(train_df.shape, val_df.shape, test_df.shape)

(9724, 3) (1081, 3) (1413, 3)


## Create datasets

In [32]:
from transformers import AutoTokenizer
config = Config()
tokenizer = AutoTokenizer.from_pretrained('uitnlp/visobert')

Downloading config.json:   0%|          | 0.00/644 [00:00<?, ?B/s]

Downloading (‚Ä¶)tencepiece.bpe.model:   0%|          | 0.00/471k [00:00<?, ?B/s]

In [42]:
train_dataset = MemesDataset(train_df, tokenizer, config, visual_embeds)
val_dataset = MemesDataset(val_df, tokenizer, config, visual_embeds)
test_dataset = MemesDataset(test_df, tokenizer, config, visual_embeds)

       not-sarcasm  text-sarcasm  image-sarcasm  multi-sarcasm
4004             0             0              0              1
10369            0             0              0              1
1157             1             0              0              0
6181             1             0              0              0
3987             1             0              0              0
      not-sarcasm  text-sarcasm  image-sarcasm  multi-sarcasm
5129            1             0              0              0
7471            0             0              0              1
2846            0             0              0              1
1643            1             0              0              0
3796            1             0              0              0
   not-sarcasm  text-sarcasm  image-sarcasm  multi-sarcasm
0            0             0              0              0
1            0             0              0              0
2            0             0              0              0
3            0

In [34]:
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)

In [35]:
import gc
torch.cuda.empty_cache()
gc.collect()

64

# Model

In [36]:
class MemesClassifier(pl.LightningModule):
  '''Wrap the training of VisualBERT model to classify memes'''

  def __init__(self, config: Config):
    super().__init__()
    self.configuration = VisualBertConfig.from_pretrained(config.visual_model_name, visual_embedding_dim=config.visual_embedding_dim)
    self.model = VisualBertModel(self.configuration)
    self.n_warmup_steps = config.n_warmup_steps
    self.criterion = nn.CrossEntropyLoss()
    self.dropout = nn.Dropout(0.2)
    self.classifier = nn.Linear(self.model.config.hidden_size, config.n_classes)
    self.n_training_steps = config.n_training_steps
    self.lrate = config.lrate
  
  def forward(self, input_ids, attention_mask, visual_embeds, visual_attention_mask, visual_token_type_ids, labels=None):
    output = self.model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        visual_embeds=visual_embeds,
                        visual_attention_mask=visual_attention_mask,
                        visual_token_type_ids=visual_token_type_ids
    )
    
    output = self.dropout(output.pooler_output)
    output = self.classifier(output)

    return output
  
  def training_step(self, batch, batch_idx):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    visual_embeds = batch['visual_embedding'].to(device)
    visual_attention_mask = batch['visual_attention_mask'].to(device)
    visual_token_type_ids = batch['visual_token_type_ids'].to(device)

    labels = batch['labels'].type(torch.float).to(device)
    
    outputs = self(input_ids, attention_mask, visual_embeds, visual_attention_mask, visual_token_type_ids, labels)
    loss = self.criterion(outputs, labels)
    self.log('train_loss', loss, prog_bar=True, logger=True)

    return {"loss":loss, 'predictions':outputs, 'labels':labels}

  def validation_step(self, batch, batch_idx):
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    visual_embeds = batch['visual_embedding']
    visual_attention_mask = batch['visual_attention_mask']
    visual_token_type_ids = batch['visual_token_type_ids'].to(device)
    labels = batch['labels'].type(torch.float).to(device)
    
    outputs = self(input_ids, attention_mask, visual_embeds, visual_attention_mask, visual_token_type_ids, labels)
    loss = self.criterion(outputs, labels)
    self.log('val_loss', loss, prog_bar=True, logger=True)

    return loss
  
  def configure_optimizers(self):
    optimizer = AdamW(self.parameters(), lr=self.lrate)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=self.n_warmup_steps,
        num_training_steps=self.n_training_steps
    )

    return dict(
        optimizer=optimizer,
        lr_scheduler=dict(
            scheduler=scheduler,
            interval='step'
        )
    )

# Train

In [46]:
def train_model(model, train_dataset, val_dataset, tokenizer, config, visual_embeds):
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",
        filename="best-checkpoint",
        save_top_k=5,
        verbose=True,
        monitor="val_loss",
        mode="min",
    )
    
    logger = TensorBoardLogger("lightning_logs", name="memes-text")
    early_stopping_callback = EarlyStopping(monitor='val_loss', patience=3)

    trainer = pl.Trainer(
        logger=logger,
        callbacks=[early_stopping_callback, checkpoint_callback],
        max_epochs=config.n_epochs,
        accelerator="auto",
        enable_progress_bar=True,
    )

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size)

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [64]:
def evaluate_model(test_dataset, checkpoint, tokenizer, config, visual_embeds):

    trained_model = MemesClassifier.load_from_checkpoint(
        checkpoint,
        config=config
    ).to(device)

    trained_model.eval()

    predictions = []
    labels = []
    for item in tqdm(test_dataset):
        with torch.no_grad():
            prediction = trained_model(
                item["input_ids"].unsqueeze(dim=0).to(device),
                item["attention_mask"].unsqueeze(dim=0).to(device),
                item["visual_embedding"].unsqueeze(dim=0).to(device),
                item['visual_attention_mask'].unsqueeze(dim=0).to(device),
                item['visual_token_type_ids'].unsqueeze(dim=0).to(device)
            )
        predictions.append(prediction.flatten())
        labels.append(item["labels"])
    
    predictions = torch.stack(predictions).detach().cpu()
    labels = torch.stack(labels).detach().cpu()

    _, preds = torch.max(predictions, dim=1)
    _, labels = torch.max(labels, dim=1)
    
    f1_macro = f1_score(labels, preds , average="macro")
    f1_micro = f1_score(labels, preds , average="micro")

    return f1_macro, f1_micro

In [39]:
model = MemesClassifier(config).to('cuda')

Downloading config.json:   0%|          | 0.00/631 [00:00<?, ?B/s]

In [47]:
train_model(
    model=model, 
    train_dataset=train_dataset, 
    val_dataset=val_dataset,
    tokenizer=tokenizer,
    config = config,
    visual_embeds=visual_embeds
)

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [62]:
import gc
torch.cuda.empty_cache()
gc.collect()

4797

In [65]:
f1_macro, f1_micro = evaluate_model(
    val_dataset,
    checkpoint='./checkpoints/best-checkpoint.ckpt',
    tokenizer=tokenizer,
    config=config,
    visual_embeds=visual_embeds
)

  0%|          | 0/1081 [00:00<?, ?it/s]

  _, preds = torch.max(torch.tensor(predictions), dim=1)


In [66]:
print(f"Evaluate on valid set: f1_macro={f1_macro}, f1_micro={f1_micro}")

Evaluate on valid set: f1_macro=0.32361345104117534, f1_micro=0.6318223866790009


# Load checkpoints and predict



In [67]:
import gc
torch.cuda.empty_cache()
gc.collect()

4303

In [68]:

def get_prediction(test_dataset, checkpoint, tokenizer, config, visual_embeds, batch_size=2):
    # Create DataLoader
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
    )
    
    # Load and set model to evaluation mode
    trained_model = MemesClassifier.load_from_checkpoint(
        checkpoint,
        config=config
    ).to(device)
    trained_model.eval()
    
    all_predictions = []
    
    # Process batches
    for batch in tqdm(test_dataloader, desc="Getting predictions"):
        with torch.no_grad():
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}
            
            # Get predictions
            predictions = trained_model(
                batch["input_ids"],
                batch["attention_mask"],
                batch["visual_embedding"],
                batch["visual_attention_mask"],
                batch["visual_token_type_ids"]
            )
            
            # Move predictions to CPU
            all_predictions.append(predictions.cpu())
    
    # Concatenate all predictions
    predictions = torch.cat(all_predictions, dim=0)
    
    # Get probabilities and predicted classes
    probabilities = torch.softmax(predictions, dim=1)
    predicted_classes = torch.argmax(predictions, dim=1)
    
    # Convert to class labels
    predicted_labels = [config.classes[idx.item()] for idx in predicted_classes]
    
    return predicted_labels

In [74]:
predictions = get_prediction(
    test_dataset,
    checkpoint='./checkpoints/best-checkpoint.ckpt',
    tokenizer=tokenizer,
    config=config,
    visual_embeds=visual_embeds
)

Getting predictions:   0%|          | 0/707 [00:00<?, ?it/s]

In [75]:
len(predictions)

1413

In [76]:

test_predicted = {k:i for k,i in zip(dev_json.keys(),predictions)}

In [77]:
result_json = {
    "results": test_predicted,
    "phase": "dev"
}

In [78]:

with open('results.json', 'w') as fp:
    json.dump(result_json, fp,ensure_ascii=True,indent=True)