In [None]:
!pip install transformers

In [None]:
import math
import os
import warnings

import nltk
import numpy as np
import pandas as pd
import spacy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image
from torch.utils.data import Sampler
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm.notebook import tqdm
from transformers import GPT2Tokenizer, RobertaModel, RobertaTokenizer

# DATA

In [None]:
# The CSV was written together with its pandas index; `index_col=0` restores it as the
# index instead of loading it as a redundant "Unnamed: 0" column.
df = pd.read_csv('/content/drive/MyDrive/[16cls]_memes_dataset.csv', index_col=0)
# Preview only — rendering all rows bloats the notebook.
df.head()

Unnamed: 0,meme_template_name,meme_template_url,meme_url,meme_description,template_path
0,waiting_skeleton,https://imgflip.com/s/meme/Waiting-Skeleton.jpg,https://i.imgflip.com/4ty6ie.jpg,when mom says wait in line; and she will be ba...,waiting_skeleton.jpg
1,waiting_skeleton,https://imgflip.com/s/meme/Waiting-Skeleton.jpg,https://i.imgflip.com/4tp0ie.jpg,waiting for your mom to stop talking to her fr...,waiting_skeleton.jpg
2,batman_slapping_robin,https://imgflip.com/s/meme/Batman-Slapping-Rob...,https://i.imgflip.com/4tyoki.jpg,"buzzfeed, independent, guardian... reality",batman_slapping_robin.jpg
3,batman_slapping_robin,https://imgflip.com/s/meme/Batman-Slapping-Rob...,https://i.imgflip.com/4u38ba.jpg,because? took my pizza.,batman_slapping_robin.jpg
4,change_my_mind,https://imgflip.com/s/meme/Change-My-Mind.jpg,https://i.imgflip.com/4tcvoc.jpg,"being a programmer, doesn't mean i can hack fa...",change_my_mind.jpg
...,...,...,...,...,...
112667,drake_hotline_bling,https://imgflip.com/s/meme/Drake-Hotline-Bling...,https://i.imgflip.com/3g0gv4.jpg,upvoting upvote beggars; downvoting upvote beg...,drake_hotline_bling.jpg
112668,drake_hotline_bling,https://imgflip.com/s/meme/Drake-Hotline-Bling...,https://i.imgflip.com/3pg5kh.jpg,stocks; stonks,drake_hotline_bling.jpg
112669,drake_hotline_bling,https://imgflip.com/s/meme/Drake-Hotline-Bling...,https://i.imgflip.com/47l080.jpg,"""i forgot my homework at my home""; ""my homewor...",drake_hotline_bling.jpg
112670,drake_hotline_bling,https://imgflip.com/s/meme/Drake-Hotline-Bling...,https://i.imgflip.com/47l6wf.jpg,maglecture bago magpaquiz; magpaquiz bago magl...,drake_hotline_bling.jpg


## Dataset class


In [None]:
class MemesDataset(Dataset):
    """Map-style dataset pairing a meme caption with its template label.

    Each item is a dict with:
      - 'caption':   LongTensor of token ids laid out as
                     bos, up to `max_len` caption tokens, sep, then pad tokens,
                     always of length `max_len + 2`.
      - 'attn_mask': LongTensor of the same length; 1 for non-pad ids, 0 for pad ids.
      - 'template':  integer class index looked up in `template2idx`.

    Args:
        df: frame with 'meme_template_name' and 'meme_description' columns. A copy
            with a fresh 0..n-1 index is stored; the caller's frame is not modified.
        tokenizer: object exposing `tokenize`, `convert_tokens_to_ids`, the
            `bos_token` / `eos_token` / `sep_token` / `pad_token` strings and
            `pad_token_id`.
        template2idx: template name -> class index.
        idx2template: class index -> template name (stored, unused here).
        max_len: maximum number of caption tokens kept (longer captions are truncated).
        nlp: optional spaCy-style pipeline (callable returning tokens with `.text` and
            `.pos_`). Defaults to `spacy.load("en_core_web_sm")`.
    """

    def __init__(self, df, tokenizer, template2idx, idx2template, max_len, nlp=None):
        # Copy with a clean positional index so `.loc[index]` lines up with 0..len-1,
        # without mutating the caller's (possibly sliced) frame in place.
        self.df = df.reset_index(drop=True)

        self.max_len = max_len

        self.tokenizer = tokenizer
        self.nlp = nlp if nlp is not None else spacy.load("en_core_web_sm")
        # Only tokens with these coarse POS tags survive caption filtering. PUNCT keeps
        # ';' so it can be turned into an eos token below (it presumably separates the
        # meme's text boxes — confirm against the data).
        self.allowed_tokens = ["NOUN", "VERB", "AUX", "PUNCT"]

        self.template2idx = template2idx
        self.idx2template = idx2template

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, index):
        caption_idxs, attn_mask = self._get_caption_indexes(index)
        template_idx = self._get_template_index(index)

        return {
            'caption': caption_idxs,
            'attn_mask': attn_mask,
            'template': template_idx
        }

    def _get_template_index(self, index):
        """Return the class index of the template for row `index`."""
        template = self.df.loc[index, 'meme_template_name']
        return self.template2idx[template]

    def _get_caption_indexes(self, index):
        """Filter, tokenize and pad the caption of row `index`.

        Returns:
            (caption_idxs, attn_mask): two LongTensors of length `max_len + 2`.
        """
        caption = self.df.loc[index, 'meme_description']
        # pandas reads a missing description as NaN (a float), which spaCy rejects;
        # treat it as an empty caption instead.
        caption = '' if pd.isna(caption) else str(caption)

        kept = [token.text for token in self.nlp(caption)
                if token.pos_ in self.allowed_tokens]
        caption = ' '.join(kept).replace(';', self.tokenizer.eos_token)

        caption_tokens = self.tokenizer.tokenize(caption)[:self.max_len]
        # Special tokens wrap the content; padding goes last so `sep` directly
        # follows the caption rather than trailing after the pads.
        caption_tokens = ([self.tokenizer.bos_token] + caption_tokens
                          + [self.tokenizer.sep_token])
        seq_len = self.max_len + 2
        caption_tokens += [self.tokenizer.pad_token] * (seq_len - len(caption_tokens))

        caption_idxs = torch.tensor(
            self.tokenizer.convert_tokens_to_ids(caption_tokens), dtype=torch.long)

        # Compare ids against the pad *id*: comparing an int tensor with the pad token
        # string never yields a per-position mask.
        attn_mask = (caption_idxs != self.tokenizer.pad_token_id).long()

        return caption_idxs, attn_mask
        

In [None]:
import pickle


def _load_pickle(path):
    """Open `path` in binary mode and return the unpickled object.

    NOTE(review): unpickling can execute arbitrary code — only load trusted files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Label encoder (template name -> class index) and its inverse decoder.
template2idx = _load_pickle('/content/drive/MyDrive/[16cls]_template2idx.pickle')
idx2template = _load_pickle('/content/drive/MyDrive/[16cls]_idx2template.pickle')

In [None]:
# Fixed seed so the train/validation split is identical across kernel restarts.
SEED = 42
MAX_LEN = 16
TRAIN_FRACTION = 0.8

# The classifier uses the "roberta-base" encoder, so captions must be tokenized with
# the matching RoBERTa vocabulary (it also defines the bos/eos/sep/pad tokens the
# dataset relies on). `tokenizer` was otherwise never defined in this notebook.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

# Shuffle into a new frame (re-running this cell gives the same split instead of
# reshuffling an already-shuffled `df`), then cut at 80% / 20%.
df_shuffled = df.sample(frac=1, random_state=SEED)
split_at = math.ceil(TRAIN_FRACTION * len(df_shuffled.index))
df_train = df_shuffled.iloc[:split_at]
df_val = df_shuffled.iloc[split_at:]

# Creating instances of training and validation set
train_set = MemesDataset(df_train, tokenizer, template2idx, idx2template, max_len=MAX_LEN)
val_set = MemesDataset(df_val, tokenizer, template2idx, idx2template, max_len=MAX_LEN)

## Dataloaders and data splitting

In [None]:
BATCH_SIZE = 128
NUM_WORKERS = 4

# Only the training stream needs shuffling. Validation metrics do not depend on order,
# and a fixed order keeps evaluation batches reproducible between epochs.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)

# Model

## Template selector module

In [None]:
class TemplateSelector(nn.Module):
    """Classify a caption into one of `num_classes` meme templates.

    A pretrained RoBERTa encoder produces per-token hidden states. The state at
    position 0 goes through dropout -> batch-norm -> linear -> LeakyReLU, then
    dropout -> batch-norm -> linear, which yields raw (unnormalized) logits.

    Args:
        num_classes: number of output logits.
        freeze_roberta: if True, set `requires_grad = False` on every encoder parameter.
        pretrained_name: checkpoint passed to `RobertaModel.from_pretrained`.
    """

    def __init__(self, num_classes, freeze_roberta=False, pretrained_name="roberta-base"):
        super().__init__()
        self.roberta_layer = RobertaModel.from_pretrained(pretrained_name)

        if freeze_roberta:
            for p in self.roberta_layer.parameters():
                p.requires_grad = False

        # Take the width from the encoder config instead of hard-coding 768, so other
        # checkpoints work too. Attribute names are unchanged, so existing state_dict
        # files keep loading.
        hidden_size = self.roberta_layer.config.hidden_size

        self.dropout_1 = nn.Dropout(0.3)
        self.batch_norm_1 = nn.BatchNorm1d(hidden_size)
        self.pre_classifier = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.LeakyReLU(0.2)

        self.dropout = nn.Dropout(0.5)
        self.batch_norm = nn.BatchNorm1d(hidden_size)
        self.cls_layer = nn.Linear(hidden_size, num_classes)

    def forward(self, seq, attn_mask):
        """Return logits for token ids `seq` with attention mask `attn_mask`.

        NOTE: BatchNorm1d rejects a batch of a single sample in training mode.
        """
        last_hidden = self.roberta_layer(seq, attention_mask=attn_mask)[0]

        # Sentence representation: hidden state of the first token (RoBERTa's <s>).
        sent_repr = last_hidden[:, 0]

        sent_repr = self.dropout_1(sent_repr)
        sent_repr = self.batch_norm_1(sent_repr)
        sent_repr = self.pre_classifier(sent_repr)
        sent_repr = self.activation(sent_repr)

        sent_repr = self.dropout(sent_repr)
        sent_repr = self.batch_norm(sent_repr)
        logits = self.cls_layer(sent_repr)

        return logits

In [None]:
# `device` was never defined in this notebook; use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# One logit per label: the class indices come from template2idx, so size from it.
model = TemplateSelector(num_classes=len(template2idx)).to(device)

In [None]:
# Path to a saved TemplateSelector state_dict. Loading from '' always raises
# FileNotFoundError, so the load is skipped (with a warning) until this is filled in.
CHECKPOINT_PATH = ''  # TODO: set to the trained checkpoint file

if CHECKPOINT_PATH:
    # map_location puts the tensors on the model's device, so a checkpoint saved on a
    # GPU also loads on a CPU-only runtime.
    state_dict = torch.load(CHECKPOINT_PATH, map_location=next(model.parameters()).device)
    model.load_state_dict(state_dict)
else:
    warnings.warn('CHECKPOINT_PATH is empty; the model keeps its freshly initialised weights.')