### Model Training Setup: build and cache the tokenized train/valid datasets

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext jupyter_black

In [2]:
import os

# The data paths in DATASETS and the `src` package are resolved relative to the
# project root, one directory above this notebook. Only move up when we are not
# already at the root, so re-running this cell is a no-op instead of climbing one
# more directory each time (which would break every relative path below).
if not os.path.isdir("src"):
    os.chdir("../")

In [20]:
### SETTINGS ###

# Processed dataset directories, keyed by the name selected in DATASET below.
DATASETS = dict(
    [
        ("FB15k-237-DECODE-ONLY-LABEL", "data/data_processed/FB15k-237/decode_only_label/"),
        ("ALL_DATA-DECODE-ONLY-LABEL", "data/data_processed/FB15k_FB15k237_WN18_WN18RR/"),
    ]
)

# Pretrained checkpoint identifiers, keyed by the name selected in MODEL below.
MODELS = dict(
    [
        ("bart-small", "lucadiliello/bart-small"),
        ("bart-base", "facebook/bart-base"),
        ("bart-large", "facebook/bart-large"),
    ]
)

# Active selection: keys into DATASETS and MODELS respectively.
DATASET, MODEL = "ALL_DATA-DECODE-ONLY-LABEL", "bart-small"

# MAX_LENGTH is passed to the dataset builder; BATCH_SIZE is not referenced in
# the cells shown here (presumably consumed by the training step).
MAX_LENGTH, BATCH_SIZE = 50, 1

# Development mode: when dev is True, only the first DEV_BATCH rows are kept
# (a DEV_BATCH of -1 keeps every row).
dev, DEV_BATCH = False, 100

### Load data

In [21]:
import os

import pandas as pd
from src.utils import load_fb15k237

# Show all columns, don't wrap wide frames, and never truncate cell text, so the
# demonstration strings can be inspected in full.
pd.set_option("display.max_columns", None)
pd.set_option("display.expand_frame_repr", False)
pd.set_option("display.max_colwidth", None)

# Load the versioned, preprocessed data for the selected DATASET. The DATASETS
# paths already end in "/", so os.path.join is used instead of "+ '/...'" to
# avoid a doubled separator.
processed_data = pd.read_csv(os.path.join(DATASETS[DATASET], "processed_data.csv"))

### Load the tokenizer

In [22]:
import torch
from transformers import BartTokenizer, DataCollatorForSeq2Seq

# Only the tokenizer is needed to build the datasets; it is loaded from the
# checkpoint selected by MODEL.
tokenizer = BartTokenizer.from_pretrained(MODELS[MODEL])

### Build masked inputs and labels

In [23]:
# Model input: the demonstration text followed by the tokenizer's mask token and
# a period; target: the tail entity's text.
processed_data["data_input"] = (
    processed_data["demonstration_input"] + f"{tokenizer.mask_token}."
)
processed_data["data_label"] = processed_data["tail_text"]

# In dev mode keep only the first DEV_BATCH rows; DEV_BATCH == -1 keeps all rows.
if dev and DEV_BATCH != -1:
    processed_data = processed_data.head(DEV_BATCH)

In [24]:
from src.datasetkgc import DatasetKGC, generate_train_valid_dataset

In [16]:
%%time
# Build the train/valid datasets from the masked inputs/labels with the project
# helper (presumably tokenizing to MAX_LENGTH and splitting; see src.datasetkgc
# to confirm). The recorded run on the full data took ~3 minutes wall time.
train_ds, valid_ds = generate_train_valid_dataset(processed_data, tokenizer, MAX_LENGTH)

  0%|          | 0/462294 [00:00<?, ?it/s]

  0%|          | 0/51367 [00:00<?, ?it/s]

CPU times: total: 2min 46s
Wall time: 2min 47s


In [18]:
import os

# Cache the built datasets next to processed_data.csv so later notebooks can skip
# the slow generation step. os.path.join avoids the doubled "/" that
# "dir/" + "/file" produced.
# NOTE: torch.save pickles the whole dataset objects (presumably DatasetKGC
# instances). Reloading needs that class importable and, on recent PyTorch,
# torch.load(..., weights_only=False). Only load these files from trusted sources.
dataset_dir = DATASETS[DATASET]
torch.save(train_ds, os.path.join(dataset_dir, "train_ds.pth"))
torch.save(valid_ds, os.path.join(dataset_dir, "valid_ds.pth"))