In [16]:
# imports
import math, statistics, time
from collections import defaultdict
import numpy as np
from tqdm.auto import tqdm
from datetime import datetime
# import torch_xla
# import torch_xla.core.xla_model as xm
import pickle
import pandas as pd
from transformers import AutoModel
import torch.nn as nn
import torch

# from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

import warnings
warnings.filterwarnings("ignore")

# HF token
token = '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'
from huggingface_hub import HfApi, HfFolder
api=HfApi()
folder=HfFolder()
api.set_access_token(token)
folder.save_token(token)

# constants
dataset = "dank_memes"
pre_trained_model_checkpoint = "roberta-base"
model_name = "roberta-base-memes-900k-subset-75"
hub_model_id = "armageddon/roberta-base-memes-900k-subset-75"
stride = 150

In [17]:
# load meme dataset
meme_dict = None
with open('./meme_900k_cleaned_data.pkl', 'rb') as f:
    meme_dict = pickle.load(f)
print("Keys in meme dict dataset:", meme_dict.keys())
print("Number of uuids:", len(meme_dict['uuid_label_dic']))

Keys in meme dict dataset: dict_keys(['label_uuid_dic', 'uuid_label_dic', 'uuid_caption_dic', 'uuid_image_path_dic'])
Number of uuids: 300


In [18]:
# utility functions
def clean_and_unify_caption(caption):
    return caption[0].strip()+', '+caption[1].strip()

In [22]:
# # select the uuids
# import os
# import regex as re
# training_uuids = []
# dir_path = './memes900k_qa/'
# for path in os.listdir(dir_path):
#     if os.path.isfile(os.path.join(dir_path, path)):
#         if not re.match(r'.*_manual.pkl', path):
#             training_uuids.append(path.split('.')[0])
# print(len(training_uuids))
# labels = {k:v for k,v in zip(training_uuids, range(len(training_uuids)))}

# with open('./models/data/training_label.pkl', 'wb') as f:
#     pickle.dump(labels, f)

with open('./models/data/training_label.pkl', 'rb') as f:
    labels = pickle.load(f)

In [28]:
# # create pandas dataframe
# temp_arr = []
# for uuid in training_uuids:
#     for caption in meme_dict['uuid_caption_dic'][uuid]:
#         temp_arr.append([uuid, clean_and_unify_caption(caption)])
# df = pd.DataFrame(temp_arr, columns=['category', 'text'])

# # split dataset
# np.random.seed(42)
# df_train, df_val, df_test = np.split(df.sample(frac=1, random_state=42), 
#                                      [int(.8*len(df)), int(.9*len(df))])

# print(len(df_train),len(df_val), len(df_test))

180000 22500 22500


In [29]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.labels = [labels[label] for label in df['category']]
        self.texts = [tokenizer(text, padding='max_length', max_length = 50, truncation=True,
                                return_tensors="pt") for text in df['text']]

    def classes(self):
        return self.labels

    def __len__(self):
        return len(self.labels)

    def get_batch_labels(self, idx):
        # Fetch a batch of labels
        return np.array(self.labels[idx])

    def get_batch_texts(self, idx):
        # Fetch a batch of inputs
        return self.texts[idx]

    def __getitem__(self, idx):
        batch_texts = self.get_batch_texts(idx)
        batch_y = self.get_batch_labels(idx)
        return batch_texts, batch_y

In [37]:
# train_dataset = Dataset(df_train)
# val_dataset = Dataset(df_val)
# test_dataset = Dataset(df_test)
train_dataset = torch.load('./models/data/train_dataset')
val_dataset = torch.load('./models/data/val_dataset')
test_dataset = torch.load('./models/data/test_dataset')

In [55]:
from transformers import AutoModelForSequenceClassification, AutoModel
class Meme_Classifier(nn.Module):
    def __init__(self, num_labels, dropout=0.3):
        super(Meme_Classifier, self).__init__()
        self.model = AutoModel.from_pretrained(pre_trained_model_checkpoint)
        self.dropout = nn.Dropout(dropout)
        self.linear1 = nn.Linear(768, 512)
        self.linear2 = nn.Linear(512, num_labels)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        _, pooled_output = self.model(input_ids=input_id, attention_mask=mask,return_dict=False)
        dropout_output1 = self.dropout(pooled_output)
        linear_output1 = self.dropout(self.relu(self.linear1(dropout_output1)))
        final_output = self.relu(self.linear2(linear_output1))
        return final_output

In [56]:
# training loop
from torch.optim import Adam
from tqdm import tqdm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train(model, train_dataset, val_dataset, learning_rate, loss_diff, max_epochs):
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32)

    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr= learning_rate)

    model = model.to(device)
    criterion = criterion.to(device)
    epoch_num = 0
    prev_loss = float('inf')
    while True:
            epoch_num+=1
            total_acc_train = 0
            total_loss_train = 0

            for train_input, train_label in tqdm(train_dataloader):
                train_label = train_label.to(device)
                mask = train_input['attention_mask'].to(device)
                input_id = train_input['input_ids'].squeeze(1).to(device)

                output = model(input_id, mask)
                
                batch_loss = criterion(output, train_label.long())
                total_loss_train += batch_loss.item()
                
                acc = (output.argmax(dim=1) == train_label).sum().item()
                total_acc_train += acc

                model.zero_grad()
                batch_loss.backward()
                optimizer.step()
            
            total_acc_val = 0
            total_loss_val = 0
            
            with torch.no_grad():

                for val_input, val_label in val_dataloader:

                    val_label = val_label.to(device)
                    mask = val_input['attention_mask'].to(device)
                    input_id = val_input['input_ids'].squeeze(1).to(device)

                    output = model(input_id, mask)

                    batch_loss = criterion(output, val_label.long())
                    total_loss_val += batch_loss.item()
                    
                    acc = (output.argmax(dim=1) == val_label).sum().item()
                    total_acc_val += acc
            
            print(
                f'Epochs: {epoch_num + 1} | Train Loss: {total_loss_train / len(train_dataset): .3f} \
                | Train Accuracy: {total_acc_train / len(train_dataset): .3f} \
                | Val Loss: {total_loss_val / len(val_dataset): .3f} \
                | Val Accuracy: {total_acc_val / len(val_dataset): .3f}')
            
            if epoch_num>=max_epochs or abs(prev_loss-total_loss_train)<=loss_diff:
                break
            prev_loss=total_loss_train

In [57]:
model = Meme_Classifier(len(labels))
LR = 1e-6
max_epochs = 20
loss_diff = 0.01
train(model, train_dataset, val_dataset, LR, loss_diff, max_epochs)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.decoder.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 5625/5625 [06:54<00:00, 13.58it/s]


Epochs: 2 | Train Loss:  0.120                 | Train Accuracy:  0.270                 | Val Loss:  0.107                 | Val Accuracy:  0.358


100%|██████████| 5625/5625 [06:54<00:00, 13.56it/s]


Epochs: 3 | Train Loss:  0.097                 | Train Accuracy:  0.389                 | Val Loss:  0.090                 | Val Accuracy:  0.407


100%|██████████| 5625/5625 [06:55<00:00, 13.55it/s]


Epochs: 4 | Train Loss:  0.084                 | Train Accuracy:  0.425                 | Val Loss:  0.081                 | Val Accuracy:  0.432


100%|██████████| 5625/5625 [06:54<00:00, 13.58it/s]


Epochs: 5 | Train Loss:  0.076                 | Train Accuracy:  0.452                 | Val Loss:  0.075                 | Val Accuracy:  0.454


100%|██████████| 5625/5625 [06:54<00:00, 13.57it/s]


Epochs: 6 | Train Loss:  0.071                 | Train Accuracy:  0.476                 | Val Loss:  0.072                 | Val Accuracy:  0.470


100%|██████████| 5625/5625 [06:53<00:00, 13.60it/s]


Epochs: 7 | Train Loss:  0.067                 | Train Accuracy:  0.498                 | Val Loss:  0.069                 | Val Accuracy:  0.484


100%|██████████| 5625/5625 [06:53<00:00, 13.60it/s]


Epochs: 8 | Train Loss:  0.064                 | Train Accuracy:  0.518                 | Val Loss:  0.067                 | Val Accuracy:  0.494


100%|██████████| 5625/5625 [06:54<00:00, 13.58it/s]


Epochs: 9 | Train Loss:  0.061                 | Train Accuracy:  0.535                 | Val Loss:  0.066                 | Val Accuracy:  0.498


100%|██████████| 5625/5625 [06:54<00:00, 13.58it/s]


Epochs: 10 | Train Loss:  0.059                 | Train Accuracy:  0.550                 | Val Loss:  0.065                 | Val Accuracy:  0.505


100%|██████████| 5625/5625 [06:54<00:00, 13.58it/s]


Epochs: 11 | Train Loss:  0.057                 | Train Accuracy:  0.564                 | Val Loss:  0.065                 | Val Accuracy:  0.512


100%|██████████| 5625/5625 [06:54<00:00, 13.59it/s]


Epochs: 12 | Train Loss:  0.055                 | Train Accuracy:  0.577                 | Val Loss:  0.064                 | Val Accuracy:  0.516


100%|██████████| 5625/5625 [06:54<00:00, 13.58it/s]


Epochs: 13 | Train Loss:  0.052                 | Train Accuracy:  0.594                 | Val Loss:  0.064                 | Val Accuracy:  0.514


100%|██████████| 5625/5625 [06:54<00:00, 13.57it/s]


Epochs: 14 | Train Loss:  0.050                 | Train Accuracy:  0.608                 | Val Loss:  0.064                 | Val Accuracy:  0.519


100%|██████████| 5625/5625 [06:54<00:00, 13.57it/s]


Epochs: 15 | Train Loss:  0.048                 | Train Accuracy:  0.622                 | Val Loss:  0.064                 | Val Accuracy:  0.518


100%|██████████| 5625/5625 [06:54<00:00, 13.57it/s]


Epochs: 16 | Train Loss:  0.045                 | Train Accuracy:  0.639                 | Val Loss:  0.064                 | Val Accuracy:  0.519


100%|██████████| 5625/5625 [06:54<00:00, 13.57it/s]


Epochs: 17 | Train Loss:  0.043                 | Train Accuracy:  0.656                 | Val Loss:  0.065                 | Val Accuracy:  0.522


100%|██████████| 5625/5625 [06:54<00:00, 13.56it/s]


Epochs: 18 | Train Loss:  0.040                 | Train Accuracy:  0.674                 | Val Loss:  0.066                 | Val Accuracy:  0.518


100%|██████████| 5625/5625 [06:45<00:00, 13.89it/s]


Epochs: 19 | Train Loss:  0.038                 | Train Accuracy:  0.692                 | Val Loss:  0.066                 | Val Accuracy:  0.518


100%|██████████| 5625/5625 [06:53<00:00, 13.61it/s]


Epochs: 20 | Train Loss:  0.035                 | Train Accuracy:  0.711                 | Val Loss:  0.067                 | Val Accuracy:  0.515


100%|██████████| 5625/5625 [06:54<00:00, 13.57it/s]


Epochs: 21 | Train Loss:  0.033                 | Train Accuracy:  0.732                 | Val Loss:  0.069                 | Val Accuracy:  0.514


In [98]:
# save model
MODEL_PATH = './models/roberta-base-memes-900k-subset-75'
torch.save(model.state_dict(), MODEL_PATH)

In [59]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Meme_Classifier(len(labels))
model.load_state_dict(torch.load(MODEL_PATH))
model.eval()
model = model.to(device)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.decoder.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [60]:
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=True)
acc = 0
for test_input, test_label in tqdm(test_dataloader):
    mask = test_input['attention_mask'].to(device)
    input_id = test_input['input_ids'].squeeze(1).to(device)
    logits = model(input_id, mask).to('cpu')
    acc += meme_accuracy_sum_only(logits, test_label)

100%|██████████| 704/704 [00:16<00:00, 42.47it/s]


In [61]:
test_accuracy  = acc/len(test_dataset.labels)
print(test_accuracy)

0.6573333333333333


In [74]:
# now test full user captions for accuracy
import os
import regex as re
import pickle
testing_user_captions = []
dir_path = './memes900k_qa/'
for path in tqdm(os.listdir(dir_path)):
    if os.path.isfile(os.path.join(dir_path, path)):
        if not re.match(r'.*_manual.pkl', path):
            with open(os.path.join(dir_path, path), 'rb') as f:
                dic = pickle.load(f)
                for v in dic['qa'].keys(): 
                    testing_user_captions.append([v, labels[dic['uuid']]])

100%|██████████| 152/152 [00:00<00:00, 10827.69it/s]


In [77]:
tokenized = [tokenizer(text[0], padding='max_length', max_length = 50, truncation=True,
                                return_tensors="pt") for text in testing_user_captions]
# input_ids = torch.stack()

In [80]:
tokenized[0]

{'input_ids': tensor([[    0,  7424,   939,  3529,    65,    55, 15711,     9, 19982,  1666,
             6,   117, 16506,    47, 14964,   101,   195,   416,   328,     2,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0]])}

In [85]:
input_ids = torch.stack([x['input_ids'] for x in tokenized])
input_ids = input_ids.reshape(len(input_ids), -1).to(device)
masks = torch.stack([x['attention_mask'] for x in tokenized])
masks = masks.reshape(len(input_ids), -1).to(device)

In [96]:
acc = 0
for i in tqdm(range(len(input_ids))):
    logits = model(input_ids[i].reshape(1, -1), masks[i].reshape(1, -1))
    acc += meme_accuracy_sum_only(logits, [testing_user_captions[i][1]])

100%|██████████| 3745/3745 [01:25<00:00, 43.71it/s]


In [97]:
acc/len(input_ids)

0.6085447263017356