In [17]:
from datasets import Dataset
import evaluate
from collections import Counter
from itertools import chain
from spacy.lang.en import English
from transformers import CONFIG_MAPPING, AutoModelForCausalLM

In [5]:
# Build the word-level corpus, vocabulary and train/validation split.
nlp = English()
tokenizer = nlp.tokenizer  # spaCy's rule-based tokenizer only; no pipeline components are run
path = 'data/e2e_data/src1_train.txt'
MIN_FREQ = 10    # a word enters the vocab only if seen MORE than MIN_FREQ times; rarer words map to UNK
SPLIT_SEED = 42  # makes the train/validation split reproducible across Restart & Run All

# Each non-blank line must contain '||'; the second '||'-separated field is the text to model.
train_dataset = []
with open(path, 'r', encoding='utf-8') as ff:
    for row in ff:
        if not row.strip():
            continue  # tolerate blank lines (e.g. a trailing empty line) instead of raising IndexError
        # Strip so the space after '||' and the line's trailing '\n' are not
        # emitted by spaCy as standalone whitespace tokens in every sequence.
        text = row.split('||')[1].strip()
        train_dataset.append([tok.text for tok in tokenizer(text)])

counter = Counter(chain.from_iterable(train_dataset))

# Special tokens get fixed ids 0-3; remaining words are numbered in first-seen order.
vocab = {'START': 0, 'END': 1, 'UNK': 2, 'PAD': 3}
for word, freq in counter.items():
    # `word not in vocab` stops a corpus word spelled like a special token
    # (e.g. 'END') from overwriting that token's reserved id.
    if freq > MIN_FREQ and word not in vocab:
        vocab[word] = len(vocab)

train_datasets = Dataset.from_dict({'text': train_dataset})
raw_datasets = train_datasets.train_test_split(test_size=0.01, seed=SPLIT_SEED)
raw_datasets.vocab = vocab  # ad-hoc attribute; read by the model-setup and tokenization cells
raw_datasets['validation'] = raw_datasets['test']

In [6]:
# The word-level vocab doubles as the "tokenizer" (token -> id); the reverse
# map decodes generated ids back to words.
tokenizer = raw_datasets.vocab
reverse_tokenizer = {v: k for k, v in tokenizer.items()}

# GPT-2's default config describes its BPE vocab (bos/eos id 50256), which is
# out of range for this small vocab. Size the embeddings to the word vocab and
# point the special-token ids at the custom START/END/PAD entries. Because the
# model is built at the right size, no resize_token_embeddings call is needed.
config = CONFIG_MAPPING['gpt2'](
    vocab_size=len(tokenizer),
    bos_token_id=tokenizer['START'],
    eos_token_id=tokenizer['END'],
    pad_token_id=tokenizer['PAD'],
)
model = AutoModelForCausalLM.from_config(config)

column_names = raw_datasets["train"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]

In [None]:
def tokenize_function(examples, vocab_dict=None):
    """Map each word sequence in a batch to vocab ids, wrapped in START/END.

    Args:
        examples: batch mapping from ``Dataset.map(batched=True)``;
            ``examples['text']`` is a list of word lists.
        vocab_dict: token -> id mapping containing 'START', 'END' and 'UNK'.
            Defaults to ``raw_datasets.vocab`` so existing ``map`` calls keep working.

    Returns:
        ``{'input_ids': [[START, id_1, ..., id_n, END], ...]}``, one list per input
        sequence; words missing from the vocab map to the UNK id.
    """
    if vocab_dict is None:
        vocab_dict = raw_datasets.vocab
    # Read the special ids from the vocab instead of hard-coding 0/1.
    start_id = vocab_dict['START']
    end_id = vocab_dict['END']
    unk_id = vocab_dict['UNK']
    input_ids = [
        [start_id] + [vocab_dict.get(word, unk_id) for word in seq] + [end_id]
        for seq in examples['text']
    ]
    return {'input_ids': input_ids}

# Run tokenize_function over every split in 4 worker processes. Dropping the
# original columns leaves only the 'input_ids' column it produces.
tokenized_datasets = raw_datasets.map(
    function=tokenize_function,
    remove_columns=column_names,
    num_proc=4,
    batched=True,
)

In [None]:
block_size = 64  # tokens per training example


def group_texts(examples):
    """Concatenate a batch of tokenized sequences and re-cut it into fixed-size blocks.

    Every column is flattened into one long list and sliced into consecutive
    ``block_size`` chunks. If the batch holds at least ``block_size`` tokens,
    the remainder that would not fill a whole block is dropped. Otherwise the
    whole short batch becomes one chunk, and an empty batch yields no chunks.
    A ``labels`` column is added as a copy of the ``input_ids`` chunk list.
    """
    flat = {col: list(chain.from_iterable(seqs)) for col, seqs in examples.items()}
    n_tokens = len(flat[list(examples)[0]])
    if n_tokens >= block_size:
        n_tokens -= n_tokens % block_size  # round down to a whole number of blocks
    grouped = {}
    for col, tokens in flat.items():
        grouped[col] = [
            tokens[start:start + block_size]
            for start in range(0, n_tokens, block_size)
        ]
    grouped["labels"] = list(grouped["input_ids"])
    return grouped

# Re-chunk every split into block_size-token examples; group_texts also adds 'labels'.
lm_datasets = tokenized_datasets.map(function=group_texts, num_proc=4, batched=True)

train_dataset, eval_dataset = lm_datasets["train"], lm_datasets["validation"]

In [None]:
def preprocess_logits_for_metrics(logits, labels):
    """Reduce model logits to predicted token ids before they are accumulated.

    Shaped for the ``Trainer`` ``preprocess_logits_for_metrics`` hook (presumably
    passed to the Trainer in a later cell). Returning argmax ids instead of full
    logits keeps the accumulated eval predictions small.

    Args:
        logits: logits tensor, or a tuple whose first element is the logits
            (any other model outputs in the tuple are ignored). Assumed to be
            (batch, seq_len, vocab_size); TODO confirm.
        labels: unused; required by the hook signature.

    Returns:
        ``argmax`` of the logits over the last dimension.
    """
    # The former debug print indexed logits[1] unconditionally. For a plain
    # tensor that is the batch axis, so it raised IndexError at batch size 1,
    # and it also flooded the output on every eval step.
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)

metric = evaluate.load("accuracy")


def compute_metrics(eval_preds):
    """Next-token accuracy over all evaluated positions.

    Args:
        eval_preds: ``(preds, labels)`` pair of 2-D arrays. ``preds`` are assumed to
            already be argmax token ids (e.g. produced by
            ``preprocess_logits_for_metrics``).

    Returns:
        The dict returned by ``metric.compute`` (accuracy).

    The prediction at position i is scored against the label at position i + 1,
    so labels drop their first column and predictions drop their last.
    """
    ignore_index = -100  # HF ignore/padding label; Trainer also pads with it when eval batch lengths differ
    preds, labels = eval_preds
    labels = labels[:, 1:].reshape(-1)
    preds = preds[:, :-1].reshape(-1)
    # Drop ignored/padded positions: both preds and labels may carry -100 there,
    # and those spurious matches would otherwise inflate the accuracy.
    keep = labels != ignore_index
    return metric.compute(predictions=preds[keep], references=labels[keep])