In [5]:
# imports
import sys
sys.path.append('../')
import math, statistics, time
from collections import defaultdict
import numpy as np
from datetime import datetime
import pickle
import pandas as pd
import torch.nn as nn
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from sentence_transformers import SentenceTransformer, LoggingHandler, losses, InputExample
from tqdm import tqdm
from utils.sentence_bert_dataloader import SentenceBertDataloader
from utils.dataset import Dataset

# Configuration: model names, compute device, training length and output path.
# NOTE(review): base_model is not referenced in the visible cells (training loads a
# local checkpoint instead) — confirm whether it is still needed.
base_model = 'roberta-base'

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

num_epochs = 20
model_name = 'sentence_transformer_30'
model_save_path = '../models/{}'.format(model_name)

# Training labels; their keys are iterated as meme uuids when the dataframe is built.
# pickle.load can execute arbitrary code — only load trusted, locally produced files.
with open('../data/training_label_100.pkl', 'rb') as label_file:
    labels = pickle.load(label_file)

In [6]:
# Load the cleaned meme dataset. Judging by its key names it is a dict of lookup
# tables (uuid -> label / caption / image path, label -> uuid) — confirm against the producer.
# pickle.load can execute arbitrary code — only load trusted, locally produced files.
meme_dict = None
with open('../data/meme_900k_cleaned_data_v2.pkl', 'rb') as meme_file:
    meme_dict = pickle.load(meme_file)

print("Keys in meme dict dataset:", meme_dict.keys())
print("Number of uuids:", len(meme_dict['uuid_label_dic']))

Keys in meme dict dataset: dict_keys(['label_uuid_dic', 'uuid_label_dic', 'uuid_caption_dic', 'uuid_image_path_dic', 'uuid_caption_cased_dic'])
Number of uuids: 300


In [7]:
# utility functions
def clean_and_unify_caption(caption):
    """Join the first two parts of a caption into a single string.

    ``caption[0]`` and ``caption[1]`` are each stripped of surrounding
    whitespace and joined with ``'; '``; any further elements are ignored.

    Args:
        caption: indexable sequence of at least two strings (presumably the
            two text lines of a meme — TODO confirm against the dataset).

    Returns:
        str: ``"<first>; <second>"``.
    """
    first, second = caption[0].strip(), caption[1].strip()
    return '; '.join((first, second))

In [8]:
# create pandas dataframe: one row per (uuid, caption) pair for every labelled uuid.
training_uuids = labels.keys()
records = [
    [uuid, clean_and_unify_caption(caption)]
    for uuid in training_uuids
    for caption in meme_dict['uuid_caption_dic'][uuid]
]
df = pd.DataFrame(records, columns=['category', 'text'])

# split dataset: shuffle once, then take the first 90% of rows as train, the rest as test.
# Positional .iloc slicing replaces np.split on a DataFrame, which goes through the
# deprecated DataFrame.swapaxes (FutureWarning since pandas 2.1) and yields the same split.
np.random.seed(42)  # global numpy seed; kept in case downstream code draws from np.random — TODO confirm
df_shuffled = df.sample(frac=1, random_state=42)
split_idx = int(.9 * len(df))
df_train, df_test = df_shuffled.iloc[:split_idx], df_shuffled.iloc[split_idx:]
assert len(df_train) + len(df_test) == len(df), "train/test split lost rows"

print(len(df_train), len(df_test))

270000 30000


## Create datasets and data loaders

In [5]:
# Wrap each split in the project's Dataset class, both sharing the same `labels` mapping.
# NOTE: `Dataset` here is utils.dataset.Dataset — it is imported after
# torch.utils.data's Dataset in the imports cell and therefore shadows it.
train_dataset = Dataset(df_train, labels)
test_dataset = Dataset(df_test, labels)

In [6]:
# Batch each dataset with the project's SentenceBertDataloader.
# The second argument is presumably the batch size (32) — TODO confirm in utils.sentence_bert_dataloader.
# NOTE(review): test_loader is not passed to the fit() call below (no evaluator) — confirm
# whether evaluation on the held-out split was intended.
train_loader = SentenceBertDataloader(train_dataset, 32)
test_loader = SentenceBertDataloader(test_dataset, 32)

202500it [00:00, 733497.27it/s]
22500it [00:00, 715014.02it/s]


## Model Training

In [15]:
# Start from a previously saved local checkpoint (not `base_model`) on the configured device,
# and wrap it in sentence-transformers' ContrastiveLoss for training.
# NOTE(review): this loads '..._roberta_20' while model_save_path points at '..._30';
# presumably training resumes from an earlier checkpoint — confirm.
model = SentenceTransformer('../models/sentence_transformer_roberta_20', device=device)
train_loss = losses.ContrastiveLoss(model=model)

In [None]:
# Fine-tune on the training loader for `num_epochs`; output_path is where fit() saves the model.
model.fit(
    train_objectives=[(train_loader, train_loss)],
    epochs=num_epochs,
    warmup_steps=100,  # learning-rate warmup steps
    output_path=model_save_path,
)

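## Create and save category embeddings

For each saved checkpoint, compute the mean embedding of every category (meme uuid) over its training captions and pickle the resulting dictionary.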

In [11]:
def getCategoryEmbeddings(df_train, model, batch_size=512):
    """Compute the mean sentence embedding of each category in ``df_train``.

    Texts are encoded ``batch_size`` rows at a time; each row's embedding is
    added to the running sum for its ``category`` value, and every sum is
    finally divided by the number of rows seen for that category.

    Args:
        df_train: DataFrame with ``text`` and ``category`` columns.
        model: object whose ``encode(list_of_str)`` returns one embedding per
            text, indexable by position (e.g. a SentenceTransformer).
        batch_size: number of rows passed to ``model.encode`` per call
            (default 512, the previously hard-coded value).

    Returns:
        dict mapping each category value to its mean embedding; empty when
        ``df_train`` has no rows.
    """
    emb_sums = {}
    counts = defaultdict(int)

    for start in tqdm(range(0, df_train.shape[0], batch_size)):
        # .iloc makes the slice explicitly positional, independent of the
        # shuffled index labels left over from DataFrame.sample.
        batch = df_train.iloc[start:start + batch_size]
        embeddings = model.encode(list(batch.text))
        for category, emb in zip(batch.category, embeddings):
            counts[category] += 1
            if category in emb_sums:
                # `+` builds a new array, so the batch's embedding rows are never mutated.
                emb_sums[category] = emb_sums[category] + emb
            else:
                emb_sums[category] = emb

    return {category: total / counts[category] for category, total in emb_sums.items()}

In [13]:
import os

# Recompute and pickle the per-category mean embeddings for each saved checkpoint.
# Checkpoints are named by epoch count: 5, 10, ..., 25.
# Deliberately does NOT reassign `num_epochs` or `model_name` from the config cell:
# doing so would silently change what the training cell does if it is re-run.
CHECKPOINT_EPOCH_STEP = 5  # epochs between consecutive saved checkpoints
NUM_CHECKPOINTS = 5        # number of checkpoints to process

for i in range(1, NUM_CHECKPOINTS + 1):
    checkpoint_name = 'sentence_transformer_roberta_samples_100_epochs_{}'.format(i * CHECKPOINT_EPOCH_STEP)
    model_load_path = '../models/{}'.format(checkpoint_name)
    embeddings_save_path = '../models/model_utils/{}/category_embeddings.pkl'.format(checkpoint_name)

    model = SentenceTransformer(model_load_path, device=device)
    uuid_to_emb_dict = getCategoryEmbeddings(df_train, model)

    os.makedirs(os.path.dirname(embeddings_save_path), exist_ok=True)
    with open(embeddings_save_path, 'wb') as f:
        pickle.dump(uuid_to_emb_dict, f)

100%|██████████| 528/528 [02:49<00:00,  3.11it/s]
100%|██████████| 528/528 [02:47<00:00,  3.16it/s]
100%|██████████| 528/528 [02:48<00:00,  3.14it/s]
100%|██████████| 528/528 [02:47<00:00,  3.16it/s]
100%|██████████| 528/528 [02:46<00:00,  3.16it/s]
