In [None]:
import json

# Directory that contains the dataset JSON files
base_dir = './wizard_of_wikipedia/'

# Load the three splits. The encoding is passed explicitly so that decoding
# does not depend on the platform's default locale encoding.
with open(base_dir + 'train.json', encoding='utf-8') as f:
    train_data = json.load(f)
with open(base_dir + 'valid_random_split.json', encoding='utf-8') as f:
    valid_data = json.load(f)
with open(base_dir + 'test_random_split.json', encoding='utf-8') as f:
    test_data = json.load(f)

In [None]:
import torch
from transformers import GPTNeoForCausalLM, AutoTokenizer

model_id = "EleutherAI/gpt-neo-125M"

# Device preference: CUDA first, then Apple's MPS backend, then CPU.
# MPS is only used as the fallback when PyTorch reports it as available, so
# the notebook also runs on machines without Apple silicon.
default_device = 'mps' if torch.backends.mps.is_available() else 'cpu'
device = torch.device('cuda' if torch.cuda.is_available() else default_device)

# The tokenizer is configured to pad on the left; the model is moved to the
# device selected above.
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
model = GPTNeoForCausalLM.from_pretrained(model_id).to(device)

In [None]:
# Reuse the end-of-sequence token as the padding token. The tokenization cell
# further down pads its inputs (padding=True), so a pad token has to be set.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
def extract_checked_sentence(utterance):
    """Return the utterance's checked sentence formatted as a PASSAGE line.

    Args:
        utterance: Mapping describing one dialog turn. If it has a
            'checked_sentence' entry, that entry is expected to be a mapping
            whose first value is the sentence text.

    Returns:
        The string 'PASSAGE: ' followed by the first value of
        utterance['checked_sentence'] and a newline, or an empty string when
        the entry is missing, empty, or not shaped as expected.
    """
    try:
        checked_sentence = list(utterance['checked_sentence'].values())[0]
        return 'PASSAGE: ' + checked_sentence + '\n'
    except (KeyError, IndexError, AttributeError, TypeError):
        # Best-effort extraction: only the errors that missing or malformed
        # data can raise are swallowed. Anything else (KeyboardInterrupt,
        # SystemExit, genuine bugs) propagates instead of being hidden.
        return ''

def parse_dialog(dialog):
    """Render a dialog as text, one block of lines per utterance.

    Each utterance becomes a 'SPEAKER: ...' line, then the output of
    extract_checked_sentence for that utterance (possibly empty), then a
    'TEXT: ...' line. The per-utterance blocks are joined with a newline.

    Args:
        dialog: Iterable of utterance mappings with 'speaker' and 'text' keys.

    Returns:
        The rendered dialog as a single string.
    """
    rendered_turns = []
    for turn in dialog:
        speaker_line = f'SPEAKER: {turn["speaker"]}\n'
        passage_line = extract_checked_sentence(turn)
        text_line = f'TEXT: {turn["text"]}\n'
        rendered_turns.append(speaker_line + passage_line + text_line)
    return '\n'.join(rendered_turns)

def parse_data(dataset):
    """Convert every sample of a dataset split into one flat text string.

    Each string starts with a 'CHOSEN_TOPIC: ...' line and a 'PERSONA: ...'
    line, followed by the sample's dialog as rendered by parse_dialog.

    Args:
        dataset: Iterable of sample mappings with 'chosen_topic', 'persona'
            and 'dialog' keys.

    Returns:
        A list with one string per sample, in input order.
    """
    parsed_samples = []
    for sample in dataset:
        topic_line = f'CHOSEN_TOPIC: {sample["chosen_topic"]}\n'
        persona_line = f'PERSONA: {sample["persona"]}\n'
        dialog_text = parse_dialog(sample['dialog'])
        parsed_samples.append(topic_line + persona_line + dialog_text)
    return parsed_samples

In [None]:
# Convert each split into a list of flat text strings (one string per sample).
# NOTE: these names are rebound to Dataset objects in a later cell.
train_parsed = parse_data(train_data)
valid_parsed = parse_data(valid_data)
test_parsed = parse_data(test_data)

In [None]:
# from parlai.tasks.convai2.build import build

# build({'datapath': './data/'})

In [None]:
# def parse_pc(path):
#     # Open file
#     with open('./data/ConvAI2/train_self_original.txt') as f:
#         # Read raw file lines
#         data = [line.strip() for line in f]
#     # Data set container
#     persona_chat = list()
#     # Now we iterate through lines and build the data set
#     for line in data:
#         # Split line data from initial index
#         line_idx, line_data = line.split(' ', 1)
#         # Check if new conversation is started
#         if line_idx == '1':
#             # Add a new empty dialogue to the data set
#             persona_chat.append(
#                 {'persona_a': list(), 'persona_b': list(), 'utterances': list()}
#             )
#         # If the line is from Speaker A persona
#         if line_data.startswith('your persona: '):
#             # Append it to Persona A
#             persona_chat[-1]['persona_a'].append(line_data[len('your persona: '):])
#         # Else if the line is from Speaker B persona
#         elif line_data.startswith('partner\'s persona: '):
#             # Append it to Persona B
#             persona_chat[-1]['persona_b'].append(line_data[len('partner\'s persona: '):])
#         # Else the line is a regular dialogue line
#         else:
#             # Split utterances from distractors and separate A and B
#             utt_a, utt_b = line_data.split('\t\t')[0].split('\t')
#             # Append to dialogue utterances
#             persona_chat[-1]['utterances'].append(
#                 {'speaker': 'A', 'text': utt_a}
#             )
#             persona_chat[-1]['utterances'].append(
#                 {'speaker': 'B', 'text': utt_b}
#             )
            
#     return persona_chat

In [None]:
# training_data = parse_pc('./data/ConvAI2/train_both_original.txt')
# validation_data = parse_pc('./data/ConvAI2/valid_both_original.txt')

# training_data[0]

In [None]:
# def sample_to_string(sample, eos_token):
#     # Join strings of Persona A
#     persona_a = ' '.join(sample['persona_a'])
#     # Join strings of Persona B
#     persona_b = ' '.join(sample['persona_b'])
#     # Join dialogue strings
#     dialogue = eos_token.join(f"{utterance['speaker']}: {utterance['text']}" for utterance in sample['utterances'])
#     # Build the dialogue string
#     dialogue_string = f"Persona A: {persona_a}{eos_token}Persona B: {persona_b}{eos_token}{dialogue}{eos_token}"
    
#     return dialogue_string

In [None]:
# training_data_str = [sample_to_string(dialogue, tokenizer.eos_token) for dialogue in training_data]
# validation_data_str = [sample_to_string(dialogue, tokenizer.eos_token) for dialogue in validation_data]

# training_data_str[0]

In [None]:
# from datasets import Dataset

# train_data = Dataset.from_dict({'text': training_data_str})
# valid_data = Dataset.from_dict({'text': validation_data_str})

In [None]:
from datasets import Dataset

# Wrap each split's list of strings in a Dataset under the 'text' key.
# NOTE(review): this rebinds the *_parsed names from lists to Dataset objects,
# so re-running only this cell feeds it Dataset objects rather than the lists
# produced by parse_data; re-run the parsing cell first.
train_parsed = Dataset.from_dict({'text': train_parsed})
valid_parsed = Dataset.from_dict({'text': valid_parsed})
test_parsed = Dataset.from_dict({'text': test_parsed})

In [None]:
from datasets import DatasetDict

# Bundle the three splits into one DatasetDict keyed by split name.
data = DatasetDict()
data['train'] = train_parsed
data['validation'] = valid_parsed
# The test split comes from test_parsed; assigning valid_parsed here would
# make the 'test' entry a second copy of the validation data.
data['test'] = test_parsed

In [None]:
def tokenize_function(examples):
    """Tokenize a batch of examples, keeping ids and attention mask.

    Args:
        examples: Batch mapping whose "text" entry holds the strings to
            tokenize (this function is applied with ``batched=True``).

    Returns:
        A dict with the tokenizer's 'input_ids' and 'attention_mask' for the
        batch, padded and truncated by the tokenizer.
    """
    input_encodings = tokenizer(examples["text"], padding=True, truncation=True)
    # The batch is padded here and the pad token is the EOS token, so the
    # attention mask is the only record of which positions are padding.
    # Keep it alongside the ids instead of dropping it.
    return {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
    }

tokenized_data = data.map(tokenize_function, batched=True)

In [None]:
from transformers import DataCollatorForLanguageModeling

# Collator that batches the tokenized samples for training. The model loaded
# above is a causal LM (GPTNeoForCausalLM), so masked-LM mode is turned off.
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

In [None]:
from transformers import TrainingArguments

# Training configuration for the Trainer built in the next cell.
training_args = TrainingArguments(
    "cooler_trainer_name",  # first positional argument (output directory in the HF API)
    evaluation_strategy="steps",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    learning_rate=6.25e-5,
    lr_scheduler_type="linear",
    # Request the MPS backend only when the device selected earlier really is
    # MPS, so this stays consistent with the CUDA-first choice made for `device`.
    use_mps_device=(device.type == 'mps'),
    gradient_checkpointing=True,
)

In [None]:
from transformers import Trainer

# Assemble the Trainer from the model, the training arguments, the tokenized
# train/validation splits and the collator. The 'test' split is not passed here.
trainer = Trainer(
    model=model, 
    args=training_args, 
    train_dataset=tokenized_data['train'], 
    eval_dataset=tokenized_data['validation'],
    data_collator=data_collator
)

In [None]:
# Run training with the Trainer configured above.
trainer.train()