In [6]:
import numpy as np

import datasets
import os
from tqdm import tqdm
import os
from utils import *
import argparse
import torch

from torch.utils.data import Dataset
import pickle

from tokenizers import ByteLevelBPETokenizer
from tokenizers.implementations import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing

from transformers import RobertaConfig
from transformers import RobertaTokenizerFast
from transformers import RobertaForMaskedLM
from transformers import LineByLineTextDataset
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from transformers import pipeline


In [None]:
# GET DATA ###############################################################################################################

# initialize
def initialize_data(mode='train', dir='../dataset'):
    """Load one dataset split into a plain dictionary.

    Args:
        mode: Split name; it is substituted into the file names
            ``src_{mode}.txt``, ``tgt_{mode}.txt`` and ``ref_{mode}.pkl``.
        dir: Directory the split files are read from.

    Returns:
        dict with keys ``'dir'``, ``'mode'``, ``'src'`` and ``'tgt'``; every
        split other than ``'train'`` additionally carries ``'pkl'`` (the
        unpickled content of the reference file).
    """
    split = {
        'dir': dir,
        'mode': mode,
        'src': read_data(f'src_{mode}.txt', dir),
        'tgt': read_data(f'tgt_{mode}.txt', dir),
    }

    # Only the training split has no pickled reference file.
    if mode == 'train':
        return split

    split['pkl'] = read_data(f'ref_{mode}.pkl', dir, pkl=True)
    return split

# assert length
def get_data_length(data):
    """Return the number of examples in a split.

    Asserts that ``data['src']`` and ``data['tgt']`` have the same length and
    returns that length.
    """
    n_src = len(data['src'])
    n_tgt = len(data['tgt'])
    assert n_src == n_tgt
    return n_src

# get item
def get_item(data, idx):
    """Return the example stored at position ``idx`` of a split.

    Returns:
        ``(src, tgt)`` when ``data['mode']`` is ``'train'``; otherwise
        ``(src, tgt, pkl)`` with the matching entry of ``data['pkl']``.
    """
    is_train = data['mode'] == 'train'
    example = (data['src'][idx], data['tgt'][idx])
    return example if is_train else example + (data['pkl'][idx],)

# read data
def read_data(file, dir, pkl=False):
    """Read one dataset file located at ``{dir}/{file}``.

    Args:
        file: File name inside ``dir``.
        dir: Directory that contains the file.
        pkl: When true the file is unpickled; otherwise it is read as text.

    Returns:
        The unpickled object when ``pkl`` is true, otherwise a list holding
        one whitespace-stripped string per line of the file.
    """
    path = f'{dir}/{file}'

    if pkl:
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at trusted dataset files.
        # The context manager closes the handle, which the previous
        # ``pickle.load(open(...))`` form left open.
        with open(path, 'rb') as f:
            return pickle.load(f)

    # Explicit UTF-8 so the result does not depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]


In [None]:
# train data loader
def get_dataloader(batch_size):
    """Build fresh iterators over the training and validation splits.

    Both splits are re-read from disk through ``initialize_data`` on every
    call.

    Args:
        batch_size: Batch size used by both loaders.

    Returns:
        ``(train_iter, val_iter)``: iterators (not the ``DataLoader`` objects
        themselves) over the shuffled training split and the unshuffled
        validation split.
    """
    # Local import: the notebook's import cell only brings in ``Dataset`` from
    # torch.utils.data, so ``DataLoader`` would otherwise resolve only if
    # ``from utils import *`` happened to export it.
    from torch.utils.data import DataLoader

    train_data = initialize_data(mode='train')
    val_data = initialize_data(mode='valid')

    # Options shared by both loaders; only shuffling differs.
    common = dict(batch_size=batch_size, pin_memory=True, num_workers=4)

    # NOTE(review): CustomDataset is not defined in the visible cells.
    # Replace CustomDataset with your Dataset creation function.
    train_loader = DataLoader(CustomDataset(train_data), shuffle=True, **common)
    val_loader = DataLoader(CustomDataset(val_data), shuffle=False, **common)
    return iter(train_loader), iter(val_loader)


In [None]:
# test data loader
def get_testloader(batch_size):
    """Build an iterator over the test split.

    Args:
        batch_size: Batch size used by the loader.

    Returns:
        An iterator (not the ``DataLoader`` itself) over the unshuffled test
        split. No ``num_workers`` is passed here, so the DataLoader default
        applies.
    """
    # Local import: the notebook's import cell only brings in ``Dataset`` from
    # torch.utils.data, so ``DataLoader`` would otherwise resolve only if
    # ``from utils import *`` happened to export it.
    from torch.utils.data import DataLoader

    test_data = initialize_data(mode='test')

    # NOTE(review): CustomDataset is not defined in the visible cells.
    # Replace CustomDataset with your Dataset creation function.
    test_loader = DataLoader(
        CustomDataset(test_data),
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True
    )
    return iter(test_loader)


In [None]:
def encode_batch(encoderTokenizer, decoderTokenizer, src, tgt, max_len=100):
    """Tokenize a batch of source and target texts to fixed-length tensors.

    Both tokenizers are called with the same options: pad to ``max_len``,
    truncate, and return PyTorch tensors.

    Args:
        encoderTokenizer: Callable tokenizer applied to ``src``.
        decoderTokenizer: Callable tokenizer applied to ``tgt``.
        src: Batch of source texts.
        tgt: Batch of target texts.
        max_len: Length every sequence is padded / truncated to.

    Returns:
        ``(src_ids, src_mask, tgt_ids, tgt_mask, labels)`` where ``labels`` is
        a clone of ``tgt_ids``.
    """
    tokenizer_kwargs = dict(
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    encoded_src = encoderTokenizer(src, **tokenizer_kwargs)
    encoded_tgt = decoderTokenizer(tgt, **tokenizer_kwargs)

    # NOTE(review): padding positions are left in the labels unchanged; if the
    # model's loss is supposed to skip padding they would need masking - confirm.
    labels = encoded_tgt.input_ids.clone()

    return (
        encoded_src.input_ids,
        encoded_src.attention_mask,
        encoded_tgt.input_ids,
        encoded_tgt.attention_mask,
        labels,
    )


In [None]:
# Load the training split (src_train.txt / tgt_train.txt) from ../dataset
train_data = initialize_data('train', '../dataset')


In [None]:
# TRAINING ###############################################################################################################

# Module-level settings read by train() and main().
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

# Run on the GPU whenever CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Placeholders; main() assigns the real tokenizers from select_model().
encoderTokenizer = decoderTokenizer = None

# Checkpoints are only written for epochs with index >= WARMUP_EPOCHS.
WARMUP_EPOCHS = 2
# Early-stopping threshold compared against the distance to the best epoch.
PATIENCE = 4

def train(model, optimizer, args):
    """Fine-tune ``model`` and checkpoint the epoch with the best validation SARI.

    Every epoch rebuilds the loaders through ``get_dataloader``, runs one pass
    over the training batches, then generates outputs for the validation
    batches and scores them with ``sari_score``, ``bleu_score`` and
    ``fkgl_score``.

    A checkpoint is written to ``args.save_path`` when the epoch index is at
    least ``WARMUP_EPOCHS`` and the epoch's mean SARI equals the best mean SARI
    seen so far (warm-up epochs included in that comparison). Otherwise,
    training stops once ``len(scores) - argmax(scores)`` exceeds ``PATIENCE``.

    Args:
        model: Called as ``model(input_ids=..., attention_mask=..., labels=...)``
            with the loss as first output; must also provide ``generate``.
        optimizer: Optimizer stepped once per training batch.
        args: Namespace providing ``epochs``, ``batch_size``, ``max_length``
            and ``save_path``.
    """
    global encoderTokenizer, decoderTokenizer, DEVICE

    scores = []
    for epoch in range(args.epochs):
        print(f"Epoch {epoch+1}/{args.epochs}")

        trainLoader, valLoader = get_dataloader(args.batch_size)
        metric = {'loss': [], 'sari': [], 'bleu': [], 'fkgl': []}
        loss = None                                                             # stays None if no training batch is produced

        # Put the model in training mode explicitly; the validation pass below
        # switches it to eval mode, so it must be restored every epoch.
        model.train()

        # A plain for-loop already ends when the iterator is exhausted, so no
        # StopIteration handling is needed around it.
        for source, target in tqdm(trainLoader):
            src_inp, src_mask, _, _, labels = encode_batch(
                encoderTokenizer, decoderTokenizer, source, target
            )

            optimizer.zero_grad(set_to_none=True)

            # Inputs are padded to a fixed length by encode_batch; the
            # attention mask keeps the model from attending to that padding.
            loss = model(
                input_ids=src_inp.to(DEVICE),
                attention_mask=src_mask.to(DEVICE),
                labels=labels.to(DEVICE)
            )[0]

            metric['loss'].append(loss.item())
            loss.backward()
            optimizer.step()
            # scheduler.step()

        # get model performance on val set (eval mode, no gradients)
        model.eval()
        with torch.no_grad():
            for source, target, ref in tqdm(valLoader):
                ref = np.array(ref).T.tolist()                                  # transpose ref, order gets changed in datagen

                src_inp, src_mask, _, _, _ = encode_batch(
                    encoderTokenizer, decoderTokenizer, source, target
                )

                generated = model.generate(
                    input_ids=src_inp.to(DEVICE),
                    attention_mask=src_mask.to(DEVICE),
                    max_length=args.max_length
                )
                outputs = decoderTokenizer.batch_decode(
                    generated, skip_special_tokens=True
                )

                sari = sari_score(source, outputs, ref)
                bleu = bleu_score(outputs, ref)
                fkgl = fkgl_score(outputs)

                metric['sari'] += sari
                metric['bleu'] += bleu
                metric['fkgl'] += fkgl

        log = []
        for key in metric.keys():
            log.append(f'{key}: {np.mean(metric[key]):.4f}')
        print(' - '.join(log))

        scores.append(np.mean(metric['sari']))
        # save checkpoint for only the best model
        if epoch >= WARMUP_EPOCHS and scores[-1] == np.max(scores):
            torch.save({'epoch': epoch+1,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        # Stored detached and on CPU so the checkpoint does not
                        # hold a device tensor from the training step.
                        'loss': loss.detach().cpu() if loss is not None else None,
                        }, args.save_path)
            print('checkpoint saved.')
        # early stopping
        elif len(scores) - np.argmax(scores) > PATIENCE:
            print('stopping training.')
            break

def main(args):
    """Build the model and optimizer, optionally restore a checkpoint, and train.

    Sets the module-level ``encoderTokenizer`` / ``decoderTokenizer`` from
    ``select_model`` so that ``train`` can use them.

    Args:
        args: Namespace providing ``model``, ``max_length``, ``lr``,
            ``init_epoch``, ``save_path`` plus the fields read by ``train``
            (``epochs``, ``batch_size``).
    """
    global encoderTokenizer, decoderTokenizer, DEVICE

    print('using device:', DEVICE)
    print('save path:', args.save_path)
    print('model:', args.model)

    encoderTokenizer, decoderTokenizer, model = select_model(mod=args.model)

    model.config.max_length = args.max_length
    model.config.no_repeat_ngram_size = 3
    model = model.to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    # The learning-rate scheduler is currently disabled; kept for reference.
    # STEPS = 138500 // args.batch_size                                         # total training samples / batch size
    # scheduler = torch.optim.lr_scheduler.OneCycleLR(
    #     optimizer,
    #     max_lr=args.lr*10,
    #     steps_per_epoch=STEPS,
    #     pct_start=0.15,
    #     epochs=args.epochs
    # )

    # start from last checkpoint
    if args.init_epoch > 0:
        print('loading from', args.save_path)
        # map_location lets a checkpoint written on a GPU be restored on a
        # CPU-only machine (and vice versa).
        checkpoint = torch.load(args.save_path, map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # NOTE: train() always counts epochs from 0; the stored epoch is only
        # reported here, it does not shorten the run.
        print('checkpoint was saved at epoch', checkpoint['epoch'])

    train(model, optimizer, args)

# Entry point: parse the training options and start main().
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Arguments for training.')
    parser.add_argument(
        '--model', default='gpt2', type=str,
        choices=['gpt2', 'bert', 'bert_gpt2', 'gpt2_bert'],
        help='model type'
    )
    parser.add_argument(
        '--max_length', default=80, type=int,
        help='maximum length for encoder'
    )
    parser.add_argument(
        '--epochs', default=40, type=int,
        help='number of training epochs'
    )
    parser.add_argument(
        '--init_epoch', default=0, type=int,
        help='epoch to resume the training from'
    )
    parser.add_argument(
        '--batch_size', default=50, type=int,
        help='batch size for training'
    )
    parser.add_argument(
        '--lr', default=1e-4, type=float,
        help='learning rate for training'
    )
    parser.add_argument(
        '--save_path', default='../checkpoint/model.pt', type=str,
        help='model save path'
    )
    # parse_known_args instead of parse_args: inside a notebook the kernel is
    # started with its own command line (e.g. ``-f <connection file>``), which
    # parse_args() rejects with SystemExit. Unrecognized arguments are ignored.
    args, _unknown_args = parser.parse_known_args()
    main(args)


In [None]:
# Pre-train a RoBERTa masked-language model on a dump of the English OSCAR
# corpus: dump text -> train byte-level BPE tokenizer -> train model -> sanity check.
VOCAB_SIZE = 50000
oscar_path = '../dataset/oscar.en.txt'
tokenizer_path = '../tokenizer'
model_path = '../RoBERTa'

oscar = datasets.load_dataset(
    'oscar',
    'unshuffled_deduplicated_en',
    split='train',
    streaming=True
)

print('[INFO] reading oscar_en corpus')
# Re-create the dump when it is missing or smaller than 1,000,000 bytes.
if not os.path.exists(oscar_path) or os.path.getsize(oscar_path) < 1_000_000:
    # Make sure the target directory exists before opening the file for writing.
    os.makedirs(os.path.dirname(oscar_path), exist_ok=True)
    # Explicit UTF-8: the corpus is web text and LineByLineTextDataset /
    # the tokenizer trainer read the file back as UTF-8.
    with open(oscar_path, 'w', encoding='utf-8') as f:
        for num, batch in enumerate(oscar):
            f.write(batch['text'] + '\n')
            if num > 1_000_000:
                break
    print('[INFO] saved corpora, file size', os.path.getsize(oscar_path))

print('[INFO] training tokenizer')
tokenizer = ByteLevelBPETokenizer()
tokenizer.train(
    files=[oscar_path],
    vocab_size=VOCAB_SIZE,
    min_frequency=2,
    special_tokens=[
        "<s>",
        "<pad>",
        "</s>",
        "<unk>",
        "<mask>",
    ]
)
# Create the output directory so saving vocab.json / merges.txt cannot fail on a fresh checkout.
os.makedirs(tokenizer_path, exist_ok=True)
tokenizer.save_model(tokenizer_path)
print('[INFO] saved tokenizer')

# NOTE(review): this reloaded ByteLevelBPETokenizer is replaced by the
# RobertaTokenizerFast below before it is used for anything else.
tokenizer = ByteLevelBPETokenizer(
    f'{tokenizer_path}/vocab.json', f'{tokenizer_path}/merges.txt'
)
tokenizer._tokenizer.post_processor = BertProcessing(
    ("</s>", tokenizer.token_to_id("</s>")),
    ("<s>", tokenizer.token_to_id("<s>")),
)
tokenizer.enable_truncation(max_length=512)

config = RobertaConfig(
    vocab_size=VOCAB_SIZE,
    max_position_embeddings=514,
    num_attention_heads=12,
    num_hidden_layers=6,
    type_vocab_size=1,
)

tokenizer = RobertaTokenizerFast.from_pretrained(tokenizer_path, max_len=512)

# Model is initialised from the config (random weights), not from a checkpoint.
model = RobertaForMaskedLM(config=config)
print('[INFO] model parameters:', model.num_parameters())

dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path=oscar_path,
    block_size=128,
)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

training_args = TrainingArguments(
    output_dir=model_path,
    overwrite_output_dir=True,
    num_train_epochs=1,
    # per_device_* replaces the deprecated per_gpu_train_batch_size argument.
    per_device_train_batch_size=64,
    save_steps=10_000,
    save_total_limit=2,
    prediction_loss_only=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

print(f'[INFO] training RoBERTa on gpu: {torch.cuda.is_available()}')
trainer.train()
trainer.save_model(model_path)
print('[INFO] model saved')

fill_mask = pipeline(
    "fill-mask",
    model=model_path,
    tokenizer=tokenizer_path
)

print('[INFO] sanity check')
print(fill_mask('Let children play <mask>.'))
# Every fill-mask prompt must contain the mask token; the previous '<east>'
# placeholder had none.
print(fill_mask('Sun rises in the <mask>.'))
print(fill_mask('David went to a <mask> store to buy the toilet paper.'))
print('done')
