In [None]:
# Mount Google Drive at /content/gdrive (only available inside a Colab runtime).
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
%cd gdrive/MyDrive/repos/knowledge-graph-completion/

In [None]:
!pip install colab-xterm
!pip install transformers==4.28.0 accelerate
%load_ext colabxterm

In [None]:
# Open a terminal inside the notebook (magic provided by the colabxterm extension
# loaded in the install cell).
%xterm

In [None]:
# Print the GPU(s) visible to this runtime, with their current memory usage.
!nvidia-smi

In [None]:
### SETTINGS ###

# Known datasets: name -> directory that holds the serialized splits.
DATASETS = {
    "FB15k-237-DECODE-ONLY-LABEL": "data/data_processed/FB15k-237/decode_only_label/",
}

# Known models: short name -> Hugging Face hub identifier.
MODELS = {
    "bart-small": "lucadiliello/bart-small",
    "bart-base": "facebook/bart-base",
    "bart-large": "facebook/bart-large",
}

# Active dataset / model selection and the names derived from it.
DATASET = "FB15k-237-DECODE-ONLY-LABEL"
MODEL = "bart-small"
MODEL_NAME = "_".join((MODEL, DATASET))
MODEL_PATH = f"models/{MODEL_NAME}"
MAX_LENGTH = 50

# Training hyper-parameters, shared by the manual loop and the HF Trainer sections.
# The manual loop reads "eval_steps" and "save_steps" as a number of epochs.
params = dict(
    # Dir
    output_dir=f"models/{MODEL_NAME}/",
    # Batch
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # Learning rate
    learning_rate=5e-5,
    seed=42,
    # Epochs
    num_train_epochs=50,
    # Logging
    logging_dir=f"models/{MODEL_NAME}/logs",
    logging_strategy="epoch",
    logging_steps=10,
    # Evaluation
    evaluation_strategy="epoch",
    eval_steps=1,
    # Checkpoint
    save_strategy="epoch",
    save_steps=2,
    save_total_limit=2,
    ddp_find_unused_parameters=False,
    warmup_steps=2,
)

### Load Model / Tokenizer

In [None]:
from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
)

import torch

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained BART checkpoint selected in the settings cell, as float32, and
# move it to `device`. No unconditional `.cuda()` call: it would raise on a CPU-only
# runtime and make the fallback above pointless.
model = (
    BartForConditionalGeneration.from_pretrained(MODELS[MODEL], use_cache=False)
    .float()
    .to(device)
)

# Tokenizer matching the selected checkpoint.
tokenizer = BartTokenizer.from_pretrained(MODELS[MODEL])

### Load data

In [None]:
# NOTE(review): DatasetKGC is not referenced by name in the cells below; presumably
# the import is needed so torch.load can unpickle the saved dataset objects --
# confirm before removing it.
from src.datasetkgc import DatasetKGC

In [None]:
import torch

# Load the serialized DEV train / validation splits of the selected dataset.
# NOTE(review): torch.load unpickles arbitrary objects -- only load files you trust.
train_ds, valid_ds = (
    torch.load(DATASETS[DATASET] + split_file)
    for split_file in ("train_ds_DEV.pth", "valid_ds_DEV.pth")
)

### Training Model

#### Using Loop

In [None]:
# NOTE(review): AdamW and get_scheduler are imported but not used by the cells shown
# here; the import is kept in case other cells rely on it.
from transformers import AdamW, get_scheduler

# Hyper-parameters for the manual training loop, read from the settings cell.
epochs = params["num_train_epochs"]
# Evaluation frequency in epochs. The original assigned the literal list
# ["eval_steps"] instead of looking the value up in `params`.
epoch_accuracy_frequency = params["eval_steps"]
lr = params["learning_rate"]

# Plain Adam over all model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
from torch.utils.data import DataLoader


def _as_loader(dataset, batch_size):
    """Wrap `dataset` in a non-shuffling DataLoader unless it already is one.

    This cell rebinds `train_ds` / `valid_ds`, so the guard keeps it idempotent:
    re-running it must not wrap a DataLoader inside another DataLoader.
    """
    if isinstance(dataset, DataLoader):
        return dataset
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)


# NOTE(review): the training loader is not shuffled (shuffle=False); confirm that a
# fixed sample order is intended for training.
train_ds = _as_loader(train_ds, params["per_device_train_batch_size"])
valid_ds = _as_loader(valid_ds, params["per_device_eval_batch_size"])

In [None]:
# Verify if checkpoint exists

import os
import re

# Directory where the manual loop stores its `epoch_<N>.pth` checkpoints.
checkpoint_dir = f"{params['output_dir']}loop_trainer/checkpoint/"
os.makedirs(checkpoint_dir, exist_ok=True)


def _checkpoint_epoch(filename):
    """Return the epoch number encoded in an `epoch_<N>.pth` file name, or -1."""
    match = re.fullmatch(r"epoch_(\d+)\.pth", filename)
    return int(match.group(1)) if match else -1


# Only real checkpoint files count: the directory may exist but still be empty
# (it is created before the first checkpoint is written).
checkpoint_files = [
    name for name in os.listdir(checkpoint_dir) if _checkpoint_epoch(name) >= 0
]

if checkpoint_files:
    # Pick the highest epoch numerically; a plain string sort would rank
    # "epoch_8.pth" after "epoch_10.pth".
    checkpoint_path = checkpoint_dir + max(checkpoint_files, key=_checkpoint_epoch)

    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
    # map_location lets a checkpoint saved on GPU be restored on the current device.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    start_epoch = checkpoint["epoch"]
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    train_epoch_loss = checkpoint["loss"]
    train_losses = checkpoint["train_losses"]
    valid_losses = checkpoint["valid_losses"]

else:
    # Fresh run: nothing to restore.
    start_epoch = 0
    train_epoch_loss = 0
    train_losses = []
    valid_losses = []

In [None]:
# Override the epoch count from the settings cell for a short run. `params` is
# mutated in place, so the Trainer section and the plot annotations below read this
# value as well.
params["num_train_epochs"] = 4

In [None]:
%%time

from tqdm.auto import tqdm

pbar = tqdm(range(start_epoch + 1, params["num_train_epochs"] + 1), desc="Epochs")

for epoch in pbar:
    pbar.set_description("Epoch %s" % epoch)
    pbar.refresh()

    # Checkpoint
    if epoch % params["save_steps"] == 0:
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": train_epoch_loss,
                "train_losses": train_losses,
                "valid_losses": valid_losses,
            },
            f"{params['output_dir']}/loop_trainer/checkpoint/epoch_{epoch}.pth",
        )

    model.train()

    train_epoch_loss = 0

    for batch in train_ds:
        optimizer.zero_grad()

        input_ids, attention_mask, labels = (
            batch["input_ids"],
            batch["attention_mask"],
            batch["labels"],
        )
        label = batch["labels"]

        outputs = model(
            input_ids.to(device),
            labels=label.to(device),
            attention_mask=attention_mask.to(device),
            return_dict=True,
        )

        loss = outputs.loss
        train_epoch_loss += loss.item()

        loss.backward()
        optimizer.step()

    train_losses.append(train_epoch_loss)

    # Flag to avoid multiple trains
    start_epoch += 1

    model.eval()

    if epoch % params["eval_steps"] == 0:
        valid_loss = 0

        for batch in valid_ds:
            input_ids, attention_mask, labels = (
                batch["input_ids"],
                batch["attention_mask"],
                batch["labels"],
            )
            label = batch["labels"]

            outputs = model(
                input_ids.to(device),
                labels=label.to(device),
                attention_mask=attention_mask.to(device),
                return_dict=True,
            )

            loss = outputs.loss
            valid_loss += loss.item()

        valid_losses.append(valid_loss)

    else:
        valid_losses.append(0)

    pbar.set_postfix(loss=train_epoch_loss)

    torch.cuda.empty_cache()

In [None]:
# Save the model as trained by the manual loop under <output_dir>loop_trainer/trained_model/.
model.save_pretrained(f"{params['output_dir']}loop_trainer/trained_model/")

#### Using Trainer

In [None]:
import torch

# Reload the DEV splits from disk: the manual-loop section rebound `train_ds` and
# `valid_ds` to DataLoader objects, while the Trainer below is given what is loaded here.
# NOTE(review): torch.load unpickles arbitrary objects -- only load files you trust.
train_ds, valid_ds = [
    torch.load(DATASETS[DATASET] + split_file)
    for split_file in ("train_ds_DEV.pth", "valid_ds_DEV.pth")
]

In [None]:
from collections.abc import Mapping

import numpy as np

# History of the losses computed by `compute_metrics`, one float per evaluation.
loss_values = []


def compute_metrics(eval_pred):
    """Compute the evaluation cross-entropy loss and record it in `loss_values`.

    Args:
        eval_pred: either a mapping with a "loss" entry (the form the original code
            expected), or an object exposing `predictions` and `label_ids` such as
            the `EvalPrediction` the HF Trainer passes. `predictions` may be the
            logits array or a tuple whose first element is the logits.

    Returns:
        dict: `{"loss": <float>}`.

    Label positions equal to -100 are ignored, and the loss is the mean
    cross-entropy over the remaining tokens of the whole evaluation set. The value
    is NaN when no label is left.
    """
    global loss_values

    if isinstance(eval_pred, Mapping):
        loss = float(eval_pred["loss"])
    else:
        logits = eval_pred.predictions
        if isinstance(logits, tuple):
            # Models that return several outputs put the logits first.
            logits = logits[0]
        logits = np.asarray(logits)
        logits = logits.reshape(-1, logits.shape[-1])
        labels = np.asarray(eval_pred.label_ids).reshape(-1)

        keep = labels != -100
        if not keep.any():
            loss = float("nan")
        else:
            # Upcast only the kept rows (logits may be float16 under fp16 evaluation).
            kept_logits = logits[keep].astype(np.float64)
            kept_labels = labels[keep].astype(np.int64)
            # Numerically stable log-softmax.
            shifted = kept_logits - kept_logits.max(axis=-1, keepdims=True)
            log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
            loss = float(-log_probs[np.arange(kept_labels.size), kept_labels].mean())

    # Store the scalar only: keeping the whole prediction object (as before) would
    # pin every logits array in memory.
    loss_values.append(loss)

    return {"loss": loss}

In [None]:
# `os` is used below but was never imported anywhere in the notebook.
import os

from transformers import Trainer
from transformers import TrainingArguments
from transformers import DataCollatorWithPadding

os.makedirs(params["output_dir"] + "hf_trainer/", exist_ok=True)
os.makedirs(params["output_dir"] + "hf_trainer/logs/", exist_ok=True)

# Trainer configuration, taken from the shared `params` dict of the settings cell.
# NOTE(review): fp16 / fp16_full_eval are expected to need a CUDA device -- confirm
# before running on a CPU runtime.
training_args = TrainingArguments(
    # Dir
    output_dir=params["output_dir"] + "hf_trainer/checkpoint/",
    # Batch
    per_device_train_batch_size=params["per_device_train_batch_size"],
    per_device_eval_batch_size=params["per_device_eval_batch_size"],
    # Learning Rate
    learning_rate=params["learning_rate"],
    seed=params["seed"],
    # Epoch
    num_train_epochs=params["num_train_epochs"],
    # logging
    logging_dir=params["output_dir"] + "hf_trainer/logs",
    logging_strategy=params["logging_strategy"],
    logging_steps=params["logging_steps"],
    # Evaluation (currently disabled, so `compute_metrics` only runs on an explicit
    # evaluation call)
    # evaluation_strategy=params["evaluation_strategy"],
    # eval_steps=params["eval_steps"],
    # Checkpoint
    save_strategy=params["save_strategy"],
    save_steps=params["save_steps"],
    save_total_limit=params["save_total_limit"],
    # pretraining
    ddp_find_unused_parameters=params["ddp_find_unused_parameters"],
    warmup_steps=params["warmup_steps"],
    fp16=True,
    fp16_full_eval=True,
    # test
    eval_accumulation_steps=1,
)


# Pads every batch to a fixed length of MAX_LENGTH tokens.
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer, padding="max_length", max_length=MAX_LENGTH
)


# NOTE(review): `model` is the object created in the model-loading cell; if the
# manual loop was run first, the Trainer continues from those weights. Reload the
# model for an independent run.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Verify if a checkpoint exists, then train (resuming from it when there is one).
import os

from transformers.trainer_utils import get_last_checkpoint

hf_checkpoint_dir = f"{params['output_dir']}hf_trainer/checkpoint/"

# The directory existing is not enough: it can be present without any
# `checkpoint-<step>` folder inside, in which case `resume_from_checkpoint=True`
# raises. `get_last_checkpoint` returns None when there is nothing to resume from.
last_checkpoint = (
    get_last_checkpoint(hf_checkpoint_dir) if os.path.isdir(hf_checkpoint_dir) else None
)

# `resume_from_checkpoint=None` starts a fresh training run.
trainer.train(resume_from_checkpoint=last_checkpoint)

In [None]:
# Save the model after the Trainer run under <output_dir>hf_trainer/trained_model/.
model.save_pretrained(f"{params['output_dir']}hf_trainer/trained_model/")

### Train plots

#### Loop

In [None]:
import matplotlib.pyplot as plt

# Loss curves recorded by the manual training loop (one value per epoch; each value
# is the sum of the batch losses of that epoch).
plt.title(f"{f'{MODEL} - {DATASET}'} train")
plt.xlabel("Epoch")
plt.ylabel("Loss - Cross Entropy")
plt.plot(train_losses, label="Training Loss")
plt.plot(valid_losses, label="Validation Loss")
# The curves carry labels, but without this call no legend is drawn and the two
# lines cannot be told apart.
plt.legend()

# Annotate the run settings near the top-right corner of the current axes.
xmin, xmax, ymin, ymax = plt.axis()

plt.text(
    max(xmin, xmax) * 0.9,
    max(ymin, ymax) * 0.9,
    f'epochs = {params["num_train_epochs"]}\nlr={params["learning_rate"]}',
    horizontalalignment="center",
    verticalalignment="center",
)

plt.show()

In [None]:
import matplotlib.pyplot as plt

# Training-loss curve from the HF Trainer log. Only records that carry a "loss" key
# are used. Slicing off the last record (as before) assumed every other record had
# one, which raises KeyError as soon as the history holds records without it.
# NOTE(review): evaluation records and the final summary are expected to lack
# "loss" -- confirm against the Trainer log format.
train_logs = [entry for entry in trainer.state.log_history if "loss" in entry]
x = [entry["epoch"] for entry in train_logs]
y = [entry["loss"] for entry in train_logs]

plt.title(f"{f'{MODEL} - {DATASET}'} Loss")
if train_logs:
    # Annotate the run settings near the top-right corner of the data
    # (skipped when nothing was logged, where max() would fail).
    plt.text(
        max(x) * 0.99,
        max(y) * 0.99,
        f'epochs = {params["num_train_epochs"]}\nlr={params["learning_rate"]}',
        ha="right",
        va="top",
    )
plt.plot(x, y)
plt.xlabel("Epoch")
plt.ylabel("Loss - Cross Entropy")