### Model Training Setup

In [None]:
# Automatically reload edited project modules (e.g. `src.utils`) before each cell runs.
%load_ext autoreload
%autoreload 2
# Auto-format code cells with black via the jupyter_black extension.
%load_ext jupyter_black

In [None]:
# Run from the repository root so that the relative data paths ("data/...")
# and the `src` package imported below resolve. Guarded so that re-running
# this cell does not keep climbing the directory tree (a bare `cd ..` would).
import os

if not os.path.isdir("src"):
    os.chdir("..")
print(os.getcwd())

#### Bart Example

In [None]:
# from transformers import BartForConditionalGeneration, BartTokenizer

# model = BartForConditionalGeneration.from_pretrained("facebook/bart-large", forced_bos_token_id=0)
# tok = BartTokenizer.from_pretrained("facebook/bart-large")
# example_english_phrase = "Dominican Republic has form of government of <mask>"
# batch = tok(example_english_phrase, return_tensors="pt")
# generated_ids = model.generate(batch["input_ids"])
# tok.batch_decode(generated_ids, skip_special_tokens=True)

### Load data

In [None]:
import pandas as pd
from src.utils import load_fb15k237

# FB15k-237 splits, the entity -> Wikidata mapping, and the processed_data.csv
# table whose `demonstration_input` / `tail_text` columns are used below.
PATH_FB15k237 = "data/datasets_knowledge_embedding/FB15k-237"

train, valid, test, entity2wikidata = load_fb15k237(PATH_FB15k237)

processed_data_file = f"{PATH_FB15k237}/processed_data.csv"
processed_data = pd.read_csv(processed_data_file)

### Load the model

In [None]:
from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    BartConfig,
    DataCollatorForSeq2Seq,
)

import torch

# Tokenizer padding/truncation length and batch size.
MAX_LENGHT = 512
BATCH_SIZE = 8

# Checkpoint to fine-tune. Alternatives:
#   "facebook/bart-large"
#   "facebook/bart-base"
MODEL = "lucadiliello/bart-small"

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Matching tokenizer (capped at MAX_LENGHT) and the seq2seq model on `device`.
tok = BartTokenizer.from_pretrained(MODEL, model_max_length=MAX_LENGHT)
model = BartForConditionalGeneration.from_pretrained(MODEL).to(device)

### Tokenizer sanity checks (disabled)

In [None]:
# sample = processed_data.iloc[5]

# text_sample = sample.demonstration_input + sample.tail_text + "."
# print(text_sample)

# to_mask = sample.tail_text
# print(to_mask)

# ids_to_mask = tok.convert_tokens_to_ids(to_mask)
# print(ids_to_mask)
# print(tok.convert_ids_to_tokens(ids_to_mask))

# print(
#     tok.encode(
#         "i want to mask Hello world",
#         add_special_tokens=True,
#         max_length=MAX_LENGHT,
#         truncation=True,
#         padding="max_length",
#     )[:10]
# )

# print(
#     tok.encode(
#         "i want to mask <mask>",
#         add_special_tokens=True,
#         max_length=MAX_LENGHT,
#         truncation=True,
#         padding="max_length",
#     )[:10]
# )

# print(tok.all_special_tokens)
# print(tok.all_special_ids)

# tok.mask_token_id#

### Masking data

In [None]:
# Input: the demonstration prompt followed by BART's mask token and a period.
# Label: the same prompt with the true tail entity text in place of the mask.
mask_suffix = f"{tok.mask_token}."
processed_data["data_input"] = processed_data["demonstration_input"] + mask_suffix
processed_data["data_label"] = (
    processed_data["demonstration_input"] + processed_data["tail_text"] + "."
)

In [None]:
# Development switch: when True, keep only the first 10 rows for a fast run.
dev = True
processed_data = processed_data.head(10) if dev else processed_data

In [None]:
from transformers import DataCollatorForLanguageModeling

# Tokenizer settings shared by inputs and labels: every sequence is padded /
# truncated to exactly MAX_LENGHT tokens and returned as a PyTorch tensor.
tokenizer_kwargs = dict(
    padding="max_length",
    truncation=True,
    return_tensors="pt",
    max_length=MAX_LENGHT,
    add_special_tokens=True,
)

# Encode the masked inputs ("... <mask>.") and the full target sentences as
# BART token sequences. The input attention mask is kept so the encoder can
# skip padding; the labels do not need one.
encoded_input = tok(list(processed_data["data_input"]), **tokenizer_kwargs)
encoded_label = tok(
    list(processed_data["data_label"]),
    return_attention_mask=False,
    **tokenizer_kwargs,
)

# Padding positions in the targets become -100, the default ignore_index of
# torch.nn.CrossEntropyLoss (and of the model's built-in loss).
labels = encoded_label["input_ids"].clone()
labels[labels == tok.pad_token_id] = -100

# The tensors are used directly rather than going through
# DataCollatorForLanguageModeling(mlm=False): that collator overwrites
# "labels" with a copy of "input_ids", which would discard the targets above
# and train the model to reproduce the masked input.
prepared_data = {
    "input_ids": encoded_input["input_ids"],
    "attention_mask": encoded_input["attention_mask"],
    "labels": labels,
}

### Training Model

In [None]:
from torch.utils.data import Dataset
import copy
import random


class DatasetKGC(Dataset):
    """Map-style dataset over pre-tokenized (input_ids, labels) tensors.

    Args:
        data: mapping with "input_ids" and "labels" tensors whose first
            dimension indexes examples. Other keys are ignored.

    Both tensors are moved to the module-level ``device`` once, up front, so
    batches drawn through a DataLoader are already on that device.
    ``__getitem__`` returns an ``(input_ids, labels)`` tuple for one example.

    Raises:
        ValueError: if "input_ids" and "labels" have different row counts.
    """

    def __init__(self, data):
        # Build a new dict instead of writing back into ``data`` so the
        # caller's mapping (and its tensors) is left untouched.
        self.data = {key: data[key].to(device) for key in ("input_ids", "labels")}
        self.num_rows = self.data["input_ids"].shape[0]
        num_label_rows = self.data["labels"].shape[0]
        if num_label_rows != self.num_rows:
            raise ValueError(
                "input_ids and labels must have the same number of rows, got "
                f"{self.num_rows} and {num_label_rows}"
            )

    def __len__(self):
        return self.num_rows

    def __getitem__(self, idx):
        # Row indexing already drops the example dimension; an extra
        # squeeze(0) would collapse a length-1 sequence into a scalar.
        return self.data["input_ids"][idx], self.data["labels"][idx]

In [None]:
# Wrap the prepared tensors in the map-style dataset defined above.
datasetKGC = DatasetKGC(data=prepared_data)

In [None]:
from transformers import AdamW, get_scheduler
from torch.utils.data import DataLoader

# Shuffled batches of (input_ids, labels) drawn from the dataset built in the
# previous cell, using the BATCH_SIZE configured with the model.
# NOTE(review): `AdamW` / `get_scheduler` imported above are unused here, and
# `transformers.AdamW` is deprecated in favour of `torch.optim.AdamW`.
data_loader = DataLoader(datasetKGC, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Sanity check: draw one batch and confirm the feature/label tensor shapes.
train_features, train_labels = next(iter(data_loader))
for kind, batch_tensor in (("Feature", train_features), ("Labels", train_labels)):
    print(f"{kind} batch shape: {batch_tensor.shape}")

In [None]:
# Training hyper-parameters, loss and optimizer.
SEED = 42
# Seed torch's RNG so DataLoader shuffling (and dropout in train mode) is
# reproducible across Restart & Run All.
torch.manual_seed(SEED)

epochs = 5
loss_epoch = []  # one entry appended per epoch by the training cell
# NOTE(review): 1e-3 is high for fine-tuning a pretrained BART; values around
# 1e-5 to 5e-5 are typical. Confirm this is intentional.
lr = 1e-3
# Default ignore_index=-100 skips label positions marked as padding.
cross = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
# Train

from tqdm import tqdm

# from tqdm.auto import tqdm

pbar = tqdm(range(epochs), desc="Epochs")

# from_pretrained() returns the model in eval mode; switch to train mode so
# dropout is active while fine-tuning.
model.train()

for epoch in pbar:
    pbar.set_description(f"Epoch {epoch}")
    epoch_loss = 0.0
    num_batches = 0
    for _input, label in data_loader:
        optimizer.zero_grad()

        # Passing `labels` makes BART build decoder_input_ids by shifting the
        # *target* right (teacher forcing) and compute the token-level loss,
        # ignoring positions labelled -100. Without labels, BART shifts the
        # masked *input* into the decoder, which no longer lines up with the
        # target once the tail entity spans more than one token.
        outputs = model(
            input_ids=_input,
            # Do not attend to padding tokens in the encoder.
            attention_mask=_input.ne(tok.pad_token_id).long(),
            # Any raw pad ids still in the labels are excluded from the loss.
            labels=label.masked_fill(label == tok.pad_token_id, -100),
            return_dict=True,
        )
        loss = outputs.loss

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1
        pbar.set_postfix(loss=loss.item())
    torch.cuda.empty_cache()
    # Store the mean batch loss so the curve is comparable across dataset sizes.
    loss_epoch.append(epoch_loss / max(num_batches, 1))

In [None]:
import matplotlib.pyplot as plt

# Per-epoch training loss curve, using the explicit Figure/Axes interface.
fig, ax = plt.subplots()
ax.set_title("Transformer Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.plot(loss_epoch)