In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
from pprint import pprint

In [None]:
import torch
import math
import transformers
from transformers import (
    RobertaConfig,
    RobertaModel,
    AutoTokenizer,
    pipeline,
    AutoModel,
    RobertaTokenizerFast
)

In [None]:
import numpy as np
import pandas as pd

In [None]:
import random

In [None]:
import pathlib

In [None]:
# Reference RoBERTa configuration for inspection. It does not appear to be used
# elsewhere in this notebook: the model fine-tuned below is loaded with
# AutoModelForMaskedLM.from_pretrained.
configuration = RobertaConfig()
configuration.vocab_size = 65536
configuration.bos_token_id = 0
# NOTE(review): `device` is not a RobertaConfig field; it is only stored as an
# extra attribute and presumably has no effect on where the model runs.
configuration.device = "cpu"
# Keep <pad> distinct from <s>. RobertaConfig's defaults are bos=0, pad=1, eos=2.
# With pad == bos == 0, RoBERTa's embedding layer (padding_idx = pad_token_id)
# would give every <s> the padding position id and a frozen embedding row.
# TODO: confirm these ids against the custom tokenizer's special tokens.
configuration.pad_token_id = 1
configuration.eos_token_id = 2
pprint(configuration)

# Fine-tune

## Load meronyms (part-whole pairs)

In [None]:
# !pip install datasets

In [None]:
# text = "ice is a form of <mask>"

In [None]:
# inputs = tokenizer(text, return_tensors="pt")
# token_logits = model_ft(**inputs).logits

# # Find the location of [MASK] and extract its logits
# mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
# mask_token_logits = token_logits[0, mask_token_index, :]

# # Pick the [MASK] candidates with the highest logits
# k = 10
# top_k_tokens = torch.topk(mask_token_logits, k, dim=1).indices[0].tolist()

# for token in top_k_tokens:
#     print(f"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'")

In [None]:
def _load_meronym_columns(path, first_row, first_col, names):
    """Read the relationships sheet (CSV) and return one adjacent column pair.

    Shared implementation of the ``load_meronyms_*`` loaders below.

    Args:
        path: anything ``pd.read_csv`` accepts (file path or text buffer).
        first_row: number of leading data rows (after the CSV header) to skip.
        first_col: 0-based position of the first of the two adjacent columns.
        names: two labels for the selected columns, in sheet order.

    Returns:
        A new DataFrame with columns ``names``, rows containing any missing
        value dropped, and a fresh 0..n-1 index.
    """
    sheet = pd.read_csv(path)
    pairs = sheet.iloc[first_row:, first_col:first_col + 2].copy()
    pairs.columns = names
    return pairs.dropna().reset_index(drop=True)


def load_meronyms_bc(path):
    """Load (part, whole) pairs from the 1st column pair (sheet columns B, C).

    NOTE(review): the original comment said "skip 1st 4 text data", but the
    code skips 8 leading data rows. 8 is kept here; confirm against the sheet.
    """
    return _load_meronym_columns(path, first_row=8, first_col=1, names=["part", "whole"])


def load_meronyms_de(path):
    """Load (whole, part) pairs from the 2nd column pair (sheet columns D, E), skipping 4 leading rows."""
    return _load_meronym_columns(path, first_row=4, first_col=3, names=["whole", "part"])


def load_meronyms_fg(path):
    """Load (whole, part) pairs from the 3rd column pair (sheet columns F, G), skipping 4 leading rows."""
    return _load_meronym_columns(path, first_row=4, first_col=5, names=["whole", "part"])

In [None]:
meronyms = load_meronyms_bc("data/ENVO Tags - Relationships_Lexico.csv")
# meronyms = load_meronyms_de("data/ENVO Tags - Relationships_Lexico.csv")
# meronyms = load_meronyms_fg("data/ENVO Tags - Relationships_Lexico.csv")
meronyms

In [None]:
meronyms.shape

In [None]:
meronyms.head()

In [None]:
# meronyms["part"] = meronyms.apply(lambda r: "_".join(r.part), axis=1)
# meronyms["whole"] = meronyms.apply(lambda r: "_".join(r.whole), axis=1)

## Preprocess texts for MLM

In [None]:
from typing import List

In [None]:
def augment_pairs(part: str, whole: str) -> List[str]:
    """Return the two meronym template sentences for a single (part, whole) pair.

    NOTE(review): the next cell redefines ``augment_pairs`` to take a whole
    DataFrame, so this string-pair version is shadowed once that cell runs.
    """
    sentences = [f" {part} is a part of {whole}."]
    sentences.append(f" {part} is a component of {whole}.")
    return sentences

In [None]:
# Sentence templates that turn one (part, whole) pair into MLM training text.
# Each is filled with ``str.format(part=..., whole=...)``. Other variants that
# were tried, e.g. " {whole} is composed of {part}." and
# " {whole} consists of {part}.", can be passed via ``templates``.
MERONYM_TEMPLATES = (
    " {part} is a part of {whole}.",
    " {part} is a component of {whole}.",
)


def augment_pairs(data: pd.DataFrame, templates=MERONYM_TEMPLATES) -> pd.DataFrame:
    """Expand each (part, whole) row into one row per template sentence.

    Args:
        data: frame with at least ``part`` and ``whole`` columns; it is not
            modified.
        templates: sentence templates with ``{part}``/``{whole}`` placeholders,
            filled via ``str.format``. Defaults to ``MERONYM_TEMPLATES``.

    Returns:
        A copy of ``data`` with an extra ``augmented`` column. Each input row is
        repeated once per template, the copies hold one sentence each, and the
        result has a fresh 0..n-1 index.
    """
    templates = tuple(templates)  # accept any iterable; it is reused for every row
    # Build the sentences in plain Python rather than a row-wise DataFrame.apply;
    # this also works for an empty frame.
    sentences = [
        [template.format(part=part, whole=whole) for template in templates]
        for part, whole in zip(data["part"], data["whole"])
    ]
    augmented = data.copy()
    augmented["augmented"] = pd.Series(sentences, index=data.index, dtype=object)
    return augmented.explode("augmented").reset_index(drop=True)

In [None]:
meronyms = augment_pairs(meronyms)

In [None]:
meronyms.shape[0]

In [None]:
meronyms.head()

In [None]:
# One training sentence per line. LineByLineTextDataset (below) reads the file
# as UTF-8 and skips blank lines, so write it explicitly as UTF-8. A trailing
# newline is therefore harmless.
with open("data/meronyms-augmented.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(meronyms["augmented"]) + "\n")

## Create the MLM dataset and data collator

In [None]:
from transformers import LineByLineTextDataset
from transformers import DataCollatorForLanguageModeling

In [None]:
from dataclasses import dataclass

In [None]:
from datasets import load_dataset

In [None]:
# tokenizer = AutoTokenizer.from_pretrained("./nasa-wiki-weighted-tokenizer-10-3-22/")
# tokenizer = AutoTokenizer.from_pretrained("roberta-base")
# tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
tokenizer = AutoTokenizer.from_pretrained("data/sq2-v6/train-watbertv6-squad-2ep/")

In [None]:
tokenizer

In [None]:
dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path="data/meronyms-augmented.txt",
    block_size=128,
)

In [None]:
len(dataset)

In [None]:
# @dataclass
# class CustomDataCollatorForLanguageModeling(DataCollatorForLanguageModeling):
#     pass

In [None]:
# Randomly masks 50% of tokens per example (the library default is 15%).
# NOTE(review): presumably raised because each training sentence is very short;
# confirm intent.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.5,
)

## Trainer

In [None]:
from transformers import AutoModelForMaskedLM

In [None]:
from transformers import Trainer, TrainingArguments

In [None]:
import wandb

In [None]:
model = AutoModelForMaskedLM.from_pretrained("data/sq2-v6/train-watbertv6-squad-2ep/")
# model = AutoModelForMaskedLM.from_pretrained("roberta-base")

In [None]:
model.device

In [None]:
wandb.login()

In [None]:
wandb.init(
    project="llm-test",
    entity="nish-test",
    tags=[
        "envo-mlm", 
        pathlib.Path(model.name_or_path).stem,
    ]
)

In [None]:
# MLM fine-tuning run. Checkpoints go to a folder named after the base
# checkpoint directory.
# NOTE(review): with a small dataset, save_steps=10_000 may never trigger; the
# final model is saved explicitly in a later cell.
training_args = TrainingArguments(
    output_dir=f"tmp/finetuned/envo-mlm/{pathlib.Path(model.name_or_path).stem}",
    overwrite_output_dir=True,
    num_train_epochs=5,
    learning_rate=3e-5,
    weight_decay=0.01,
    per_device_train_batch_size=128,
    logging_steps=1,
    report_to="wandb",
    save_steps=10_000,
    save_total_limit=2,
    prediction_loss_only=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator,
)

In [None]:
%%time
trainer.train()

In [None]:
trainer.save_model("tmp/finetuned/envo-mlm/test/")

# Accuracy check

In [None]:
from typing import List, Dict

In [None]:
unmasker = pipeline(
    "fill-mask",
    model="tmp/finetuned/envo-mlm/test/",
    tokenizer="data/sq2-v6/train-watbertv6-squad-2ep/"
#     tokenizer=tokenizer
#     tokenizer="roberta-base"
)

In [None]:
data = load_meronyms_de("data/ENVO Tags - Relationships_Lexico.csv")

In [None]:
data.head()

In [None]:
# Mask the whole. Keep only single-word wholes: the fill-mask pipeline fills each
# mask with one token, so multi-word answers can never match (a single word is
# still not guaranteed to be a single token).
data = data[data["whole"].str.split().str.len() == 1].reset_index(drop=True)
# Use the pipeline tokenizer's own mask token instead of a hard-coded "<mask>",
# since the tokenizer is a custom checkpoint.
mask_token = unmasker.tokenizer.mask_token
data["test"] = [f" {part} is a part of {mask_token}." for part in data["part"]]

In [None]:
# mask the part
# data = data[(data.part.str.split().str.len())==1].reset_index(drop=True)
# data["test"] = data.apply(lambda r: f" <mask> is a part of {r.whole}.", axis=1)

In [None]:
data.shape

In [None]:
data.head()

In [None]:
len(data)

In [None]:
predictions = unmasker(data.test.to_list(), top_k=5)

In [None]:
predictions

In [None]:
def analyze(predictions: List[List[dict]], gts: List[str]) -> float:
    """Top-k hit rate of fill-mask predictions against ground-truth words.

    A sample counts as a hit when its stripped ground truth equals the stripped
    ``token_str`` of any of its predictions (exact, case-sensitive match).

    Args:
        predictions: output of a ``fill-mask`` pipeline called on single-mask
            texts: one list of candidate dicts (each with a ``"token_str"`` key)
            per input. A flat list of dicts, which the pipeline returns when
            given a single input, is treated as one sample.
        gts: expected word for each input, in the same order.

    Returns:
        Fraction (0.0-1.0) of samples whose ground truth is among the predictions.

    Raises:
        ValueError: if ``gts`` is empty, or if its length differs from the number
            of prediction lists (``zip`` would otherwise silently truncate).
    """
    if predictions and isinstance(predictions[0], dict):
        predictions = [predictions]
    if not gts:
        raise ValueError("gts must not be empty")
    if len(predictions) != len(gts):
        raise ValueError(
            f"got {len(predictions)} prediction lists for {len(gts)} ground truths"
        )
    hits = sum(
        gt.strip() in {pred["token_str"].strip() for pred in preds}
        for gt, preds in zip(gts, predictions)
    )
    return hits / len(gts)

In [None]:
# if "whole" was masked
analyze(predictions, data.whole.to_list()) * 100

In [None]:
# if "part" was masked
analyze(predictions, data.part.to_list()) * 100