# Imports and data

In [1]:




# Setup cell: imports, project-wide globals, pandas display options and device.
# NOTE(review): statement order matters — the `set_global(...)` calls run before
# the facebook_hateful_memes_detector model/preprocessing modules are imported;
# keep that order unless you confirm those modules do not read globals at import.
import pandas as pd
import numpy as np
import jsonlines
import seaborn as sns
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
# Binds the name `optim` to the torch_optimizer package; later cells should not
# reuse `optim` as a variable name.
import torch_optimizer as optim
import random
from transformers import AutoModelWithLMHead, AutoTokenizer, AutoModel

# Display every bare expression in a cell, not only the last one.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
from importlib import reload
pd.set_option('display.max_rows', 500)
pd.set_option('display.float_format', '{:0.3f}'.format)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
# Overrides the display width of 1000 set on the previous line.
pd.options.display.width = 0
import warnings
import torchvision
# NOTE(review): silences *all* warnings, including library deprecation notices.
warnings.filterwarnings('ignore')

from facebook_hateful_memes_detector.utils.globals import set_global, get_global
# Machine-specific absolute paths; adjust for your environment.
set_global("cache_dir", "/home/ahemf/cache/cache")
# set_global("cache_dir", "/Users/ahemf/mygit/facebook-hateful-memes/cache")
set_global("dataloader_workers", 8)
set_global("use_autocast", True)
set_global("models_dir", "/home/ahemf/cache/")

from facebook_hateful_memes_detector.utils import read_json_lines_into_df, in_notebook, set_device, random_word_mask, dict2sampleList, run_simclr, load_stored_params
# Bare expression: shows the configured cache dir (because ast_node_interactivity = "all").
get_global("cache_dir")
from facebook_hateful_memes_detector.models import Fasttext1DCNNModel, MultiImageMultiTextAttentionEarlyFusionModel, LangFeaturesModel, AlbertClassifer
from facebook_hateful_memes_detector.preprocessing import TextImageDataset, get_datasets, get_image2torchvision_transforms, TextAugment
from facebook_hateful_memes_detector.preprocessing import DefinedRotation, QuadrantCut, ImageAugment, DefinedAffine, HalfSwap, get_transforms_for_bbox_methods
from facebook_hateful_memes_detector.preprocessing import NegativeSamplingDataset, ImageFolderDataset, ZipDatasets
from facebook_hateful_memes_detector.models.MultiModal.VilBertVisualBert import VilBertVisualBertModel
from facebook_hateful_memes_detector.models.MultiModal import VilBertVisualBertModelV2, MLMSimCLR
# NOTE(review): names used later but not imported explicitly (get_device,
# get_cosine_schedule_with_warmup, convert_dataframe_to_dataset, train,
# group_wise_finetune, group_wise_lr, fb_1d_loss_builder, ...) presumably come
# from this star import — confirm.
from facebook_hateful_memes_detector.training import *
import facebook_hateful_memes_detector
from facebook_hateful_memes_detector.utils import get_vgg_face_model, get_torchvision_classification_models, init_fc, my_collate, merge_sample_lists
# NOTE(review): importlib.reload re-executes the package module only; submodules
# imported above are not reloaded — confirm this call is still useful.
reload(facebook_hateful_memes_detector)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_device(device)

scheduler_init_fn = get_cosine_schedule_with_warmup()
# TODO: consider using mixup in SSL training; maybe also UDA.


In [2]:
import random
def get_preprocess_text():
    """Build the stochastic text-augmentation pipeline.

    Four TextAugment stages are constructed (character, word, sentence and
    gibberish/punctuation level). The returned ``process(text, **kwargs)``
    callable applies them in this order:

    1. the sentence-level stage, only with probability 0.25;
    2. the word-level stage;
    3. the character-level stage;
    4. the gibberish/punctuation stage.

    The first positional list passed to each TextAugment is interpreted by
    TextAugment itself (not visible here).
    """
    char_probas = {
        "keyboard": 0.1,
        "char_substitute": 0.4,
        "char_insert": 0.2,
        "char_swap": 0.2,
        "ocr": 0.0,
        "char_delete": 0.1,
    }
    char_aug = TextAugment([0.1, 0.4, 0.5], char_probas)

    word_probas = {
        "split": 0.2,
        "stopword_insert": 0.0,
        "word_join": 0.2,
        "punctuation_continue": 0.5,
    }
    word_aug = TextAugment([0.1, 0.4, 0.5], word_probas,
                           fasttext_file="wiki-news-300d-1M-subword.bin")

    # Disabled options kept for reference: "glove_twitter": 0.75, "word_cutout": 0.5,
    # and idf_file="/home/ahemf/cache/tfidf_terms.csv".
    sentence_probas = {
        "text_rotate": 0.0,
        "sentence_shuffle": 0.0,
        "one_third_cut": 0.25,
        "half_cut": 0.0,
        "part_select": 0.25,
    }
    sentence_aug = TextAugment([0.1, 0.9], sentence_probas)

    gibberish_probas = {
        "gibberish_insert": 0.25,
        "punctuation_insert": 0.75,
        "punctuation_replace": 0.25,
        "punctuation_strip": 0.5,
    }
    gibberish_aug = TextAugment([0.25, 0.75], gibberish_probas)

    # A translation stage ("dab") used to alternate with the sentence stage;
    # it is currently disabled.
    always_applied = (word_aug, char_aug, gibberish_aug)

    def process(text, **kwargs):
        if random.random() < 0.25:
            text = sentence_aug(text, **kwargs)
        for stage in always_applied:
            text = stage(text, **kwargs)
        return text

    return process


# Build the per-sample augmenters once and wrap them for whole-batch use.
preprocess_text = get_preprocess_text()
transforms_for_bbox_methods = get_transforms_for_bbox_methods()

# `otypes=[str]` fixes the output dtype up front. Without it, np.vectorize
# calls `preprocess_text` on the first element one extra time just to infer
# the output type and discards that result. For a random augmenter this
# wastes work and consumes extra random draws. It also makes vectorize raise
# ValueError on an empty batch.
vectorized_text_processor = np.vectorize(preprocess_text, otypes=[str])


def vectorized_image_processor(images):
    """Apply `transforms_for_bbox_methods` to each image in `images`; returns a list."""
    return [transforms_for_bbox_methods(i) for i in images]

def augment_method(sampleList):
    """Return a freshly augmented copy of a batch.

    The input is converted with `dict2sampleList` and copied. Its `image`
    and `text` fields are then regenerated from `original_image` /
    `original_text` using the module-level vectorized processors. Mixup is
    switched off for every sample, and the batch is moved to the current
    device.
    """
    batch = dict2sampleList(sampleList, device=get_device()).copy()
    batch.image = vectorized_image_processor(batch.original_image)
    batch.text = vectorized_text_processor(batch.original_text)
    batch.mixup = [False for _ in batch.text]
    return batch.to(get_device())


def get_view_1():
    """Build an extra batch "view" transform (used as a `view_transforms` entry).

    Returns ``aug(sampleList)``, which converts and copies the batch, then:

    - horizontally flips every image in `sampleList.image` (p=1.0, so always);
    - re-augments `original_text` with character/punctuation-level
      TextAugment operations;
    - disables mixup;
    - moves the batch to the current device.
    """
    augs = {"keyboard": 0.4, "char_substitute": 0.4, "char_insert": 0.2, "char_swap": 0.2,
            "ocr": 0.0, "char_delete": 0.1, "gibberish_insert": 0.25, "punctuation_insert": 0.75,
            "punctuation_replace": 0.25, "punctuation_strip": 0.5, "word_join": 0.2, "punctuation_continue": 0.5}
    text_augs = TextAugment([0.1, 0.9,], augs)
    imtrans = transforms.RandomHorizontalFlip(p=1.0)
    # otypes=[str]: prevents np.vectorize from calling the random augmenter an
    # extra time on the first element to infer the dtype, and allows empty batches.
    vtp = np.vectorize(text_augs, otypes=[str])

    def vip(images):
        return [imtrans(i) for i in images]

    def aug(sampleList):
        sampleList = dict2sampleList(sampleList, device=get_device())
        sampleList = sampleList.copy()
        # NOTE(review): flips `image`, whereas augment_method rebuilds images from
        # `original_image` — confirm which field is intended here.
        sampleList.image = vip(sampleList.image)
        sampleList.text = vtp(sampleList.original_text)
        sampleList.mixup = [False] * len(sampleList.text)
        sampleList = sampleList.to(get_device())
        return sampleList
    return aug

# Load the splits with every transform disabled. The original text and image
# are kept so augmentations can be applied on the fly later.
dataset_options = dict(
    train_text_transform=None,
    train_image_transform=None,
    test_text_transform=None,
    test_image_transform=None,
    train_torchvision_pre_image_transform=None,
    test_torchvision_pre_image_transform=None,
    cache_images=False,
    use_images=True,
    dev=False,
    test_dev=True,
    keep_original_text=True,
    keep_original_image=True,
    keep_processed_image=False,
    keep_torchvision_image=False,
    train_mixup_config=None,
)
data = get_datasets(data_dir="../data/", **dataset_options)

# Test rows get the placeholder label -1 (presumably "unlabelled").
data["test"]["label"] = -1

# Stack train, dev and test into a single frame for self-supervised training.
df = pd.concat([data[split] for split in ("train", "dev", "test")])

In [3]:
# Wrap the combined frame plus the dataset metadata into a dataset object.
# NOTE(review): the meaning of the positional `True` is defined by
# convert_dataframe_to_dataset (not visible here) — confirm.
dataset = convert_dataframe_to_dataset(
    df,
    data["metadata"],
    True,
)

In [4]:
# Per-model dropout / gaussian-noise settings, keyed by model name.
per_model_settings = {
    "lxmert": dict(dropout=0.05, gaussian_noise=0.01),
    "vilbert": dict(dropout=0.1, gaussian_noise=0.05),
    "visual_bert": dict(dropout=0.1, gaussian_noise=0.05),
    "mmbt_region": dict(dropout=0.1, gaussian_noise=0.05),
}

model_params = dict(
    model_name=per_model_settings,
    num_classes=2,
    gaussian_noise=0.0,
    dropout=0.0,
    word_masking_proba=0.125,
    featurizer="pass",
    final_layer_builder=fb_1d_loss_builder,
    internal_dims=768,
    classifier_dims=768,
    n_tokens_in=128,
    n_tokens_out=128,
    n_layers=0,
    attention_drop_proba=0.0,
    loss="focal",
    dice_loss_coef=0.0,
    auc_loss_coef=0.0,
    bbox_swaps=1,
    bbox_copies=1,
    bbox_deletes=0,
    bbox_gaussian_noise=0.01,
    view_transforms=[get_view_1()],
    finetune=False,
)

model_class = VilBertVisualBertModelV2
model = model_class(**model_params).to(get_device())




Overriding option config to projects/hateful_memes/configs/vilbert/from_cc.yaml
Overriding option model to vilbert
Overriding option datasets to hateful_memes
Overriding option run_type to val
Overriding option checkpoint.resume_zoo to vilbert.finetuned.hateful_memes.from_cc_original
Overriding option evaluation.predict to true


Some weights of the model checkpoint at bert-base-uncased were not used when initializing ViLBERTBase: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
- This IS expected if you are initializing ViLBERTBase from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing ViLBERTBase from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViLBERTBase were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.v_embeddings.image_embeddings.weight', 'bert.v_embeddings.image_embeddings.bias', 'bert.v_embeddings.image_location_embeddings.weight', 'bert.v_embeddings.image_location_embeddings.bias', 'bert.v_embeddings.LayerNorm.weight', 'bert.v_embeddings.LayerNorm.

Overriding option config to projects/hateful_memes/configs/visual_bert/from_coco.yaml
Overriding option model to visual_bert
Overriding option datasets to hateful_memes
Overriding option run_type to val
Overriding option checkpoint.resume_zoo to visual_bert.finetuned.hateful_memes.from_coco
Overriding option evaluation.predict to true


Some weights of VisualBERTBase were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.embeddings.token_type_embeddings_visual.weight', 'bert.embeddings.position_embeddings_visual.weight', 'bert.embeddings.projection.weight', 'bert.embeddings.projection.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


LXRT encoder with 9 l_layers, 5 x_layers, and 5 r_layers.
Overriding option config to projects/hateful_memes/configs/mmbt/with_features.yaml
Overriding option model to mmbt
Overriding option datasets to hateful_memes
Overriding option run_type to val
Overriding option checkpoint.resume_zoo to mmbt.hateful_memes.features
Overriding option evaluation.predict to true


Config '/local/home/ahemf/mygit/facebook-hateful-memes/facebook_hateful_memes_detector/utils/faster_rcnn_R_101_C4_attr_caffemaxpool.yaml' has no VERSION. Assuming it to be compatible with latest v2.
Config '/local/home/ahemf/mygit/facebook-hateful-memes/facebook_hateful_memes_detector/utils/faster_rcnn_R_101_C4_attr_caffemaxpool.yaml' has no VERSION. Assuming it to be compatible with latest v2.


N tokens Out =  164 Classifier Dims =  768 Matches embedding_dims:  True


# Unimodal MLM

In [5]:

# Wrap `model` in MLMSimCLR. The optimizer class and its hyper-parameters are
# configured in the training cell that follows. The `optimizer` /
# `optimizer_params` assignments that used to live here were overwritten there
# before any use, so they were removed.
from facebook_hateful_memes_detector.models.MultiModal.VilBertVisualBertV2 import positive, negative
# NOTE(review): 0.1 and the {1: negative, 0: positive} mapping are interpreted by
# MLMSimCLR (not visible here). `augment_method` is passed twice, presumably
# once per view — confirm.
mlm_model = MLMSimCLR(model, 0.1, {1: negative, 0: positive}, augment_method, augment_method)
mlm_model = mlm_model.to(get_device())


In [None]:

# Stage 1: every backbone is marked finetune=False; only the MLM and SimCLR
# parts of `mlm_model` are marked for fine-tuning.
lr_strategy = {
    "model": {
        "vilbert": {"finetune": False,},
        "visual_bert": {"finetune": False,},
        "mmbt_region": {"finetune": False,},
        "lxmert": {"finetune": False,},
    },
    # NOTE(review): key is "mlms" here but "mlm" in the stage-2 cell — confirm
    # which name matches the MLMSimCLR submodule.
    "mlms": {"finetune": True},
    "simclr_layer": {"finetune": True},
}
epochs = 1
batch_size = 3  # very small batch; looks like a quick warm-up/smoke run — TODO confirm
optimizer_class = torch.optim.AdamW
optimizer_params = dict(lr=1e-4, betas=(0.9, 0.98), eps=1e-08, weight_decay=1e-3)

# Order matters: apply the finetune flags, build parameter groups from the same
# strategy, then create the optimizer over those groups.
_ = group_wise_finetune(mlm_model, lr_strategy)
params_conf, _ = group_wise_lr(mlm_model, lr_strategy)
optimizer = optimizer_class(params_conf, **optimizer_params)
# The recorded run warned that the number of batches (3334) is not divisible by
# accumulation_steps=4.
train_losses, learning_rates, _ = train(mlm_model, optimizer, scheduler_init_fn, batch_size, epochs, dataset,
                                     model_call_back=None, accumulation_steps=4, plot=True,
                                     sampling_policy=None, class_weights=None)



mlm_model.plot_loss_acc_hist()
# Evaluated on the same `dataset` used for training.
mlm_model.test_accuracy(batch_size, dataset)




Autocast =  True Epochs =  1 Divisor = 1 Examples = 10000 Batch Size =  3
Training Samples =  10000 Weighted Sampling =  False Num Batches =  3334 Accumulation steps =  4
[WARN]: Number of training batches not divisible by accumulation steps, some training batches will be wasted due to this.


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Batches', max=3334.0, style=ProgressStyle(description_wid…

FeatureExtractor : Loaded Model...
Modifications for VG in RPN (modeling/proposal_generator/rpn.py):
	Use hidden dim 512 instead fo the same dim as Res4 (1024).

Modifications for VG in RoI heads (modeling/roi_heads/fast_rcnn.py))
	Embedding: 1601 --> 256	Linear: 2304 --> 512	Linear: 512 --> 401

LXMERTFeatureExtractor : Loaded Model...


In [None]:
# torch.save(mlm_model.state_dict(), "lxmert-mlm-init.pth")
# mlm_model.load_state_dict(torch.load("lxmert-mlm-init.pth"))


In [None]:

# Stage 2: fine-tune the backbone together with the MLM part at a lower lr.
epochs = 5
batch_size = 48
optimizer_class = torch.optim.AdamW
optimizer_params = dict(lr=1e-5, betas=(0.9, 0.98), eps=1e-08, weight_decay=1e-4)

# NOTE(review): unlike stage 1, "simclr_layer" is not listed and the MLM key is
# "mlm" (stage 1 used "mlms") — confirm both are intended.
lr_strategy = {
    "model": {
        "finetune": True,
        "lr": optimizer_params["lr"]
    },
    "mlm": {
        "finetune": True
    }
}

# Same order as stage 1: finetune flags -> parameter groups -> optimizer -> train.
_ = group_wise_finetune(mlm_model, lr_strategy)
params_conf, _ = group_wise_lr(mlm_model, lr_strategy)
optimizer = optimizer_class(params_conf, **optimizer_params)
train_losses, learning_rates, _ = train(mlm_model, optimizer, scheduler_init_fn, batch_size, epochs, dataset,
                                     model_call_back=None, accumulation_steps=5, plot=True,
                                     sampling_policy=None, class_weights=None)



mlm_model.plot_loss_acc_hist()
# Evaluated on the same `dataset` used for training.
acc = mlm_model.test_accuracy(batch_size, dataset)


In [None]:
# Recomputes the accuracy already computed at the end of the previous cell
# (same model, batch_size and dataset). Redundant if that cell has just run.
acc = mlm_model.test_accuracy(batch_size, dataset)

In [None]:
# Persist only the wrapped backbone (`mlm_model.model`), not the rest of the
# MLMSimCLR wrapper. It is loaded back into `model` in the AugSim section.
torch.save(mlm_model.model.state_dict(), "lxmert-mlm.pth")


# AugSim
- Combine unimodal and bimodal AugSim, choosing between them at random with `random.random`.
- Take ideas from SimCLR.
- Text x Image pairing is also possible (TODO: implement a CrissCrossDataset for AugSim).


In [None]:
# Start AugSim from the backbone weights saved after MLM training.
load_stored_params(model, "lxmert-mlm.pth")
# Enable cache writes from here on (the flag is consumed by project code not shown here).
set_global("cache_allow_writes", True)


In [None]:
# AdamW settings used by the AugSim training cell.
adamw_params = dict(lr=1e-5, betas=(0.9, 0.98), eps=1e-08, weight_decay=1e-3)
adamw = torch.optim.AdamW
optimizer_class, optimizer_params = adamw, adamw_params


In [None]:
epochs = 10
batch_size = 64

# NOTE(review): `lr_strategy_model` was used below but never defined in this
# notebook, so the cell failed with NameError on a fresh kernel. Assumed
# intent: fine-tune the whole model (same form as `lr_strategy_post` in the
# SimCLR section) — confirm.
lr_strategy_model = {
    "finetune": True,
}

_ = group_wise_finetune(model, lr_strategy_model)
params_conf, _ = group_wise_lr(model, lr_strategy_model)
# Named `optimizer` (as in the MLM cells) rather than `optim`. The old name
# shadowed the `torch_optimizer` package imported as `optim` in the setup cell.
optimizer = optimizer_class(params_conf, **optimizer_params)

# AugSim training on views produced by `augment_method`.
_ = train_for_augment_similarity(model,
                                 optimizer,
                                 scheduler_init_fn,
                                 batch_size,
                                 epochs,
                                 dataset,
                                 augment_method=augment_method,
                                 model_call_back=None,
                                 collate_fn=my_collate,
                                 accumulation_steps=4,
                                 plot=True)
# Values recorded from earlier runs: 0.001580, 0.000527
# TODO: try AugSim with L2-normed / LayerNormed vectors.


In [None]:
# Save the backbone after AugSim training.
torch.save(model.state_dict(), "lxmert-augsim.pth")
# To restore instead of retraining:
# model.load_state_dict(torch.load("lxmert-augsim.pth"))

# SimCLR-style training or a differentiator
- Combine unimodal and bimodal variants, choosing between them with some probability.
- In the unimodal differentiator, only one modality changes: either the text or the image.
- Add the ability to use non-overlapping image sections.


In [None]:
# Resume from a previously saved SimCLR checkpoint, if there is one. On a fresh
# run this file is only written by the last cell of the notebook, so an
# unguarded load here would fail during Restart & Run All.
import os

smclr_checkpoint = "lxmert-smclr.pth"
if os.path.exists(smclr_checkpoint):
    load_stored_params(model, smclr_checkpoint)


In [None]:
from facebook_hateful_memes_detector.utils import SimCLR

def rotate_texts_by_one(texts):
    """Return `texts` as a new list shifted left by one position.

    Used to pair each image with a different sample's text. For any batch of
    two or more items no element stays in its original position, unlike plain
    reversal, which leaves the middle element of an odd-sized batch paired with
    its own text. Empty and single-element inputs are returned unchanged (as a list).
    """
    texts = list(texts)
    return texts[1:] + texts[:1]


def simclr_aug(sampleList):
    """Augment a batch and append a text-mismatched copy of it.

    The batch is augmented with `augment_method`. A copy is then made whose
    texts are rotated by one position, so each image is paired with another
    sample's text. Both are combined with `merge_sample_lists`.
    """
    sampleList = augment_method(sampleList.copy())
    s2 = sampleList.copy()
    s2.text = rotate_texts_by_one(s2.text)
    return merge_sample_lists(sampleList, s2)

# set_global("cache_allow_writes", False)


In [None]:
# SimCLR wrapper around the backbone. `simclr_aug` from the previous cell is
# passed for both views.
# NOTE(review): 768 matches `classifier_dims` in the model config; 256 and 0.05
# are interpreted by SimCLR (not visible here) — confirm.
smclr = SimCLR(model, 768, 256, 0.05, simclr_aug, simclr_aug)
smclr = smclr.to(get_device())

# "pre" phase: the backbone (`model`) is marked finetune=False.
lr_strategy_pre = {
    "finetune": True,
    "model": {
        "finetune": False,
    },
}

# "post" phase: everything is marked finetune=True.
lr_strategy_post = {
    "finetune": True,
}

pre_lr, post_lr = 5e-5, 5e-5
pre_batch_size, post_batch_size = 256, 32
pre_epochs, full_epochs = 2, 5
collate_fn = my_collate


In [None]:
# Two-phase SimCLR training ("pre" then "post" strategy).
# NOTE(review): `dataset` is passed for both dataset arguments (presumably
# training and evaluation) — confirm.
res = run_simclr(
    smclr,
    dataset,
    dataset,
    lr_strategy_pre,
    lr_strategy_post,
    pre_lr,
    post_lr,
    pre_batch_size,
    post_batch_size,
    pre_epochs,
    full_epochs,
    collate_fn,
)

res

# Value recorded from an earlier run: 0.3268


In [None]:
# Save the backbone weights after SimCLR training (`smclr` was built around `model`).
# This is the checkpoint that the resume cell at the start of the SimCLR section loads.
torch.save(model.state_dict(), "lxmert-smclr.pth")