In [None]:
# Install the third-party packages this notebook imports (Colab runtime).
# NOTE(review): apart from the commented-out TPU variant, nothing here is
# version-pinned, so a fresh runtime may resolve different releases than the
# ones the checkpoints were trained with. Consider pinning, and `%pip` instead
# of `!pip` so the install always targets the kernel's environment.
# # For TPUs
# !pip install torch==1.9.0 torchvision torchaudio
# !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

!pip install torch torchvision torchaudio

!pip install transformers
!pip install timm
# MADGRAD is the optimizer used in the optimizer cell further down.
!pip install madgrad
# CLIP is installed straight from its GitHub repository.
!pip install git+https://github.com/openai/CLIP.git

In [None]:
# Mount Google Drive under /content/drive/; the dataset archive and the
# `mamipy` package directory added to sys.path are both read from there.
from google.colab import drive
drive.mount('/content/drive/')
# Show the working directory: the next cell copies from the relative path
# `drive/MyDrive/...`, so it only works when run from the mount's parent.
!pwd

In [None]:
# Copy the MAMI archive from Drive to the local disk, unpack it quietly and
# delete the local copy of the archive afterwards.
!cp drive/MyDrive/MAMI/MAMI_data.zip . && unzip -qq MAMI_data.zip && rm MAMI_data.zip
# Same steps for the Facebook hateful-memes archive (currently disabled).
# !cp drive/MyDrive/MAMI/fb_hateful_meme_data.zip . && unzip -qq fb_hateful_meme_data.zip && rm fb_hateful_meme_data.zip

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from madgrad import MADGRAD

import matplotlib.pyplot as plt

import transformers
from transformers import (
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

import sys
sys.path.append('/content/drive/MyDrive/MAMI/MisogynyMemeClassifierV16')

import mamipy as mami
from mamipy import data, model, trainer

In [None]:
# Debug switch read by the data-preparation cell: when True that cell passes
# limit = 10 * train_batch_size to mami.data.prepare and derives the gradient
# accumulation / evaluation intervals from that limit instead of using the
# fixed values 20 and 7.
is_debug = False

In [None]:
# Pick the compute device: 'cpu' by default, 'cuda' when a GPU is available,
# and the XLA device object when the torch_xla package is installed.
#
# `importlib.util` is a submodule and must be imported explicitly; a bare
# `import importlib` only exposes it if some other module happened to import
# it first, otherwise `importlib.util.find_spec` raises AttributeError.
import importlib.util

device = 'cpu'

if importlib.util.find_spec("torch.cuda") is not None:
    from torch import cuda
    if cuda.is_available():
        device = 'cuda'

# `tpu` stays None unless torch_xla is importable; when it is, the XLA device
# takes precedence over CUDA/CPU.
tpu = None
if importlib.util.find_spec("torch_xla") is not None:
    import torch_xla.core.xla_model as xm
    tpu = xm.xla_device()
    if tpu is not None:
        device = tpu

# Last expression of the cell: display the selected device.
device

In [None]:
# Load the MAMI train and test listings through the project loader and show
# them. MAMI_df feeds `data_config`, MAMI_test_df feeds `test_data_config`.
MAMI_df = mami.data.load_data(name='MAMI', is_train=True)
display(MAMI_df)

MAMI_test_df = mami.data.load_data(name='MAMI', is_train=False)
display(MAMI_test_df)

In [None]:
# Load the Facebook hateful-memes listing.
# NOTE(review): the cell that copies/unpacks fb_hateful_meme_data.zip is
# commented out and this frame is only referenced by a commented-out entry of
# `data_config` -- confirm the data is present, or skip this cell, otherwise
# a fresh "Run All" presumably fails here.
fb_hateful_meme_df = mami.data.load_data(name='fb_hateful_meme')
display(fb_hateful_meme_df)

In [None]:
# Central configuration cell: hyper-parameters, dataset/task definitions and
# encoder choices consumed by the data, model, optimizer and trainer cells.

# clip_model, preprocess = clip.load("RN50x4", device=device, jit=False)

# Label columns, in the order predict_all() iterates over them and
# write_preds() indexes them.
targets_cols=['misogynous', 'shaming', 'stereotype', 'objectification', 'violence']
# Forwarded as `patches=` to data preparation and `clip_num_patches=` to the model.
image_patches=(2, 2)
# Not referenced by any other visible cell.
num_labels = len(targets_cols)
max_tokens = 120 # 364
max_grad_norm = 0.5
train_batch_size = 16
eval_batch_size = 16
image_size = 288 # preprocess.transforms[0].size # 288
num_train_epochs = 15

# Training datasets. Each entry carries a dataset name, its dataframe, its
# label columns and a list of tasks, where a task is a named subset of the
# labels. The commented-out tasks are alternative task splits.
# NOTE: the data-preparation cell rebinds `data_config` to the value returned
# by mami.data.prepare, so the model/trainer cells see that returned value.
data_config = [
    {
        "name": "MAMI",
        "df": MAMI_df,
        "labels": ['misogynous', 'shaming', 'stereotype', 'objectification', 'violence'],
        "tasks": [
            # {"name": "Task_A", "labels": ['misogynous']},
            # {"name": "Task_B", "labels": ['shaming', 'stereotype', 'objectification', 'violence']},

            # {"name": "Task_B_shaming", "labels": ['shaming']},
            # {"name": "Task_B_stereotype", "labels": ['stereotype']},
            # {"name": "Task_B_objectification", "labels": ['objectification']},
            # {"name": "Task_B_violence", "labels": ['violence']},

            # {"name": "Task_B_shaming", "labels": ['misogynous', 'shaming']},
            # {"name": "Task_B_stereotype", "labels": ['misogynous', 'stereotype']},
            # {"name": "Task_B_objectification", "labels": ['misogynous', 'objectification']},
            # {"name": "Task_B_violence", "labels": ['misogynous', 'violence']},

            {"name": "MAMI", "labels": ['misogynous', 'shaming', 'stereotype', 'objectification', 'violence']},
            {"name": "Task_B", "labels": ['shaming', 'stereotype', 'objectification', 'violence']},
            {"name": "Task_A", "labels": ['misogynous']},
        ]
    },
    # {
    #     "name": "fb_hateful_meme",
    #     "df": fb_hateful_meme_df,
    #     "labels": ['hateful'],
    #     "tasks": [
    #         {"name": "Hateful_Meme", "labels": ['hateful']}
    #     ]
    # }
]

# Test-split counterpart of `data_config`; it lists Task_B and Task_A only
# (no combined "MAMI" task). Also rebound by the data-preparation cell.
test_data_config = [
    {
        "name": "MAMI",
        "df": MAMI_test_df,
        "labels": ['misogynous', 'shaming', 'stereotype', 'objectification', 'violence'],
        "tasks": [
            {"name": "Task_B", "labels": ['shaming', 'stereotype', 'objectification', 'violence']},
            {"name": "Task_A", "labels": ['misogynous']},
        ]
    }
]

# Hugging Face checkpoint name used for both the tokenizer and the classifier
# (alternative: bert-base-uncased).
bert_model_name = 'Hate-speech-CNERG/bert-base-uncased-hatexplain'
# Image encoders handed to the data preparation (and to prepare_eval in for_comp).
image_encoders = ['clip', 'detr']
# One entry per image encoder above.
image_encoders_finetune = [None, None]

In [None]:
# Tokenizer for the text checkpoint named in the config cell.
tokenizer = AutoTokenizer.from_pretrained(bert_model_name, do_lower_case=True)

# Debug runs pass a limit of 10 training batches' worth of samples; full runs
# pass None.
data_limit = (10 * train_batch_size) if is_debug else None

# Build the train/eval/test loaders.
# NOTE: this rebinds `data_config` and `test_data_config` to the values
# returned by prepare(); re-running the cell without re-running the config
# cell therefore feeds the returned configs back in as inputs.
train_loader, eval_loader, data_config, test_loader, test_data_config = mami.data.prepare(
    data_config,
    test_data_config,
    tokenizer,
    max_tokens=max_tokens,
    patches=image_patches,
    image_size=image_size,
    limit=data_limit,
    train_frac=0.8,
    train_batch_size=train_batch_size,
    eval_batch_size=eval_batch_size,
    image_encoders=image_encoders
)

# Full runs: accumulate gradients over 20 batches and set eval_every to one
# seventh of the number of training batches.
# NOTE(review): in debug mode `min(1, q)` can only evaluate to 1 (q >= 1) or
# 0 (q == 0). A 0 would raise ZeroDivisionError in the `eval_every` line below
# and in the optimizer cell's `len(train_loader) // gradient_accumulation_steps`.
# If a lower bound of 1 was intended this should be `max(1, ...)` -- confirm.
gradient_accumulation_steps = min(1, (20 * data_limit) // len(train_loader)) if is_debug else 20
eval_every=len(train_loader) // ((min(1, (7 * data_limit) // len(train_loader) )) if is_debug else 7)

In [None]:
# import matplotlib.pyplot as plt
# d = [len(tokenizer.encode_plus(t)['input_ids']) for t in list(df['text_transcription'])]
# plt.hist(d)
# print(max(d))
# print(min(d))

In [None]:
# Build the classifier and move it to the selected device.
#
# The image-encoder settings are taken from the config cell (`image_encoders`,
# `image_encoders_finetune`) rather than repeated as literals, so the model is
# always built with the same encoder list that the data loaders were prepared
# with. The values are identical to the previously hard-coded
# ['clip', 'detr'] / [None, None].
#
# NOTE: from here on the name `model` refers to this instance and no longer to
# the `mamipy.model` module imported at the top; the module stays reachable as
# `mami.model`.
model = mami.model.MAMIClassifierV16(
    bert_model_name,
    data_config=data_config,
    tokenizer=tokenizer,
    finetune_txt=False,
    enable_kl_loss=False,
    enable_embed_loss=False,
    negatives_are_close=False,
    classifier_layers_count=2,
    classifier_hidden_size=768,
    share_transformer_encoder=False,
    pool_txt=False,
    modal_transformer_encoder_nhead=8,
    modal_transformer_encoder_num_layers=6,
    pool_output=False,
    decoder_nhead=8,
    decoder_num_layers=6,
    image_encoders=image_encoders,
    image_encoders_finetune=image_encoders_finetune,
    clip_num_patches=image_patches,
    detr_fallback_topk=4,
    projection_alignment=True
)
# Trailing semicolon suppresses the module repr in the cell output.
model.to(device);

In [None]:
# Load run 01's `last.pt` checkpoint for `model`; the returned state is
# discarded (predict_all() shows it is a dict holding 'best_val_accuracy').
# NOTE(review): this reads from .../01/ while the trainer below writes to
# .../00/, and on a fresh runtime the file may not exist yet -- confirm this
# cell is meant to run as part of "Run All".
_ = mami.trainer.load_checkpoint("/content/runs/MisogynyMemeClassifierV16/01/last.pt", model)

In [None]:
# Optimizer and learning-rate schedule (linear warmup followed by linear decay).
no_decay = ["bias", "LayerNorm.weight"]
weight_decay = 0.0005

# Split the parameters by name: any parameter whose name contains one of the
# `no_decay` fragments is exempt from weight decay, all others receive it.
decayed_params, undecayed_params = [], []
for param_name, param in model.named_parameters():
    if any(fragment in param_name for fragment in no_decay):
        undecayed_params.append(param)
    else:
        decayed_params.append(param)

optimizer_grouped_parameters = [
    {"params": decayed_params, "weight_decay": weight_decay},
    {"params": undecayed_params, "weight_decay": 0.0},
]

# Number of accumulated-batch groups per pass over train_loader, times the
# number of epochs; the first tenth of those steps is used for warmup.
t_total = (len(train_loader) // gradient_accumulation_steps) * num_train_epochs
warmup_steps = t_total // 10

optimizer = MADGRAD(optimizer_grouped_parameters, lr=2e-4)

scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, t_total)

In [None]:
# Wire model, optimizer, scheduler, loaders and configs into the project's
# trainer. `path_prefix` points at run directory 00 under /content/runs.
# NOTE: this rebinds the name `trainer`, which the import cell had bound to the
# `mamipy.trainer` module; later cells reach the module as `mami.trainer`.
trainer = mami.trainer.ModelTrainer(
    model,
    optimizer,
    scheduler,
    data_config,
    train_loader,
    eval_loader,
    test_data_config,
    test_loader,
    epochs=num_train_epochs,
    #path_prefix="/content/drive/MyDrive/MAMI/MisogynyMemeClassifierV15/runs/00/",
    path_prefix="/content/runs/MisogynyMemeClassifierV16/00/",
    threshold=0.5,
    eval_every=eval_every,
    gradient_accumulation_steps=gradient_accumulation_steps,
    max_grad_norm=max_grad_norm,
    device=device
)

In [None]:
# Run the trainer configured in the previous cell.
trainer.run()

In [None]:
# Call the trainer's resume entry point.
# NOTE(review): on "Restart & Run All" this executes immediately after
# trainer.run() -- presumably it is meant to be run by hand after an
# interrupted session; confirm.
trainer.resume()

In [None]:
# Print the name of every parameter that currently requires gradients.
for param_name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    print(param_name)

In [None]:
# Load the TensorBoard notebook extension and open it on the parent directory
# of the run folders (.../00/, .../01/) used by the trainer and predict_all.
%load_ext tensorboard
%tensorboard --logdir /content/runs/MisogynyMemeClassifierV16/

In [None]:
import torch
import torch.nn as nn

In [None]:
from tqdm.notebook import tqdm

def predict(model, eval_data_config, loader, threshold=0.5):
    """Run `model` in eval mode over every batch of `loader`.

    Args:
        model: the classifier to evaluate; it is switched to eval mode.
        eval_data_config: dataset/task configuration handed to BatchForward.
        loader: iterable of batches to feed through the model.
        threshold: decision threshold handed to BatchForward.

    Returns:
        The tuple (preds, proba, indicies) read from the BatchForward helper
        after all batches were processed. predict_all() indexes `preds` and
        `indicies` by task name.
    """
    model.eval()

    forward_pass = mami.trainer.BatchForward(
        model,
        device,
        eval_data_config,
        threshold,
        emb_tracker=None,
        writer=None,
        log=None,
    )

    # Inference only: no gradients are tracked while iterating the loader.
    with torch.no_grad():
        for batch in tqdm(loader):
            forward_pass.forward_eval(batch)

    return forward_pass.preds, forward_pass.proba, forward_pass.indicies

def predict_all(paths, eval_data_config, loader):
    """Predict every label in `targets_cols` with the best checkpoint found for it.

    For each label, `<path>best_<label>.pt` is loaded from every entry of
    `paths`; the checkpoint whose recorded `best_val_accuracy` value for that
    label is highest is then re-loaded into the global `model` and used to
    predict over `loader`.

    Args:
        paths: run directory prefixes. Each is concatenated directly with the
            checkpoint file name, so it must end with a path separator.
        eval_data_config: dataset configuration list; only the tasks of its
            first entry are consulted to locate a label's column.
        loader: evaluation loader; `loader.dataset.images` is indexed with the
            sample indices returned by predict() to obtain file names.

    Returns:
        dict mapping file name -> list of '1'/'0' strings, one entry appended
        per label in `targets_cols` order.

    Raises:
        ValueError: if no checkpoint could be selected for a label, or if the
            task recorded in the selected checkpoint does not list the label
            in `eval_data_config` (previously this surfaced as an obscure
            failure while loading `None...` or indexing with None).
    """
    def find_ci(task_name, label):
        """Return the column index of `label` inside task `task_name`, or None."""
        for task in eval_data_config[0]['tasks']:
            if task['name'] == task_name:
                for col, task_label in enumerate(task['labels']):
                    if task_label == label:
                        return col
        return None

    preds = {}

    for c in targets_cols:
        print(f"Predicting for {c}")

        # Loading is the only way to read a checkpoint's score, so every
        # candidate is loaded once just to compare scores.
        max_score = -1
        max_path = None
        for path in paths:
            state = mami.trainer.load_checkpoint(f"{path}best_{c}.pt", model)
            score = state['best_val_accuracy'][c]['value']
            if score > max_score:
                max_score = score
                max_path = path

        if max_path is None:
            raise ValueError(
                f"No usable checkpoint found for label '{c}' in paths {list(paths)!r}"
            )

        # Re-load the winner: a later, worse candidate may have overwritten it.
        state = mami.trainer.load_checkpoint(f"{max_path}best_{c}.pt", model)
        print('for', c, 'path', max_path, 'score', max_score)

        task_name = state['best_val_accuracy'][c]['task']
        c_i = find_ci(task_name, c)

        print('for', c, 'c_i is', c_i, 'task_name', task_name)

        if c_i is None:
            raise ValueError(
                f"Label '{c}' is not listed under task '{task_name}' in eval_data_config"
            )

        preds_c, _, indicies_c = predict(model, eval_data_config, loader)
        print(preds_c)

        task_indices = indicies_c[task_name]
        for j, row in enumerate(preds_c[task_name]):
            file_name = loader.dataset.images[task_indices[j]]
            preds.setdefault(file_name, []).append('1' if row[c_i] else '0')

    return preds

def write_preds(path, preds):
    """Write the prediction table to the two tab-separated submission files.

    Args:
        path: prefix prepended verbatim to the output file names (include the
            trailing path separator when it denotes a directory).
        preds: mapping file name -> sequence of label strings, as built by
            predict_all(); each sequence needs at least five entries.

    Writes:
        `<path>task_a_6.tsv`: file name plus label column 0.
        `<path>answer_6.txt`: file name plus label columns 0 through 4.
        Rows are joined with newlines; no trailing newline is written.
    """
    # (output file name, label columns to keep)
    outputs = (
        ("task_a_6.tsv", (0,)),
        ("answer_6.txt", (0, 1, 2, 3, 4)),
    )

    def render(column_ids):
        rows = []
        for image_name, labels in preds.items():
            picked = "\t".join([labels[k] for k in column_ids])
            rows.append(f"{image_name}\t" + picked)
        return "\n".join(rows)

    # Render everything before touching the file system, so malformed input
    # fails before any file is created or truncated.
    rendered = [(out_name, render(column_ids)) for out_name, column_ids in outputs]

    for out_name, text in rendered:
        with open(f"{path}{out_name}", 'w') as handle:
            handle.write(text)

def for_comp():
    """Build the competition submission files for the MAMI test split.

    Steps: read `<ROOT_PATH>/MAMI/test/list.csv` (tab-separated), build an
    evaluation loader from it, predict every label with the best checkpoint
    found in run directories 00 and 01, print the predictions and write them
    with write_preds() into the Drive project folder.

    Relies on the notebook globals `tokenizer`, `max_tokens`, `image_patches`,
    `image_size`, `data_limit` and `image_encoders`.
    """
    import pandas as pd

    ROOT_PATH = '/content/data'
    # ROOT_PATH = '/home/shared/users/ahmed-mahran/MAMI/data'

    # Test listing; the transcription column is renamed to the snake_case
    # name 'text_transcription'.
    listing_dir = f"{ROOT_PATH}/MAMI/test/"
    dfd = pd.read_csv(listing_dir + 'list.csv', sep='\t').rename(
        columns={'Text Transcription': 'text_transcription'}
    )
    display(dfd)

    # Tasks that predict_all() may look a label up in.
    # NOTE(review): the training config also defines a "Task_B" task; if a
    # best checkpoint was recorded under that task name, predict_all() will
    # not find it in this list -- confirm.
    eval_data_config = [
        {
            "name": "MAMI",
            "df": dfd,
            "labels": ['misogynous', 'shaming', 'stereotype', 'objectification', 'violence'],
            "tasks": [
                {"name": "MAMI", "labels": ['misogynous', 'shaming', 'stereotype', 'objectification', 'violence']},
                {"name": "Task_A", "labels": ['misogynous']},
            ],
        }
    ]

    loader = mami.data.prepare_eval(
        dfd,
        tokenizer,
        max_tokens=max_tokens,
        patches=image_patches,
        image_size=image_size,
        limit=data_limit,
        frac=None,
        batch_size=16,
        image_encoders=image_encoders,
    )

    run_dirs = [
        '/content/runs/MisogynyMemeClassifierV16/00/',
        '/content/runs/MisogynyMemeClassifierV16/01/',
    ]
    preds = predict_all(run_dirs, eval_data_config, loader)
    print(preds)

    write_preds('/content/drive/MyDrive/MAMI/MisogynyMemeClassifierV16/', preds)

# Run inference on the test split and write the submission files.
for_comp()