In [17]:
import os
import pathlib

from datasets import load_dataset, concatenate_datasets
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding
import wandb

from hf_wrapper import GPTForSequenceClassification
from tokenizer import load_tokenizer
from utils import flatten_multi_features, load_random_from_pretrained_model, compute_metrics

In [18]:
# --- Paths -------------------------------------------------------------------
# Everything lives under a local ./cache directory: Hugging Face downloads,
# training checkpoints and tokenizer files.
hf_cache = pathlib.Path('./cache')
training_checkpoints = hf_cache / 'checkpoints'
tokenizer_prefix = hf_cache / 'tokenizers'

# Pretraining checkpoints for the orthographic ("normal") and IPA models.
normal_checkpoint_location = training_checkpoints / 'russian_polish_normal_12_5_50k' / 'ckpt.pt'
ipa_checkpoint_location = training_checkpoints / 'russian_polish_ipa_12_5_50k' / 'ckpt.pt'

# File-name stems of the BPE tokenizers; '-vocab.json' / '-merges.txt' are
# appended to them when the tokenizer is loaded.
ipa_tokenizer_prefix = 'bpe-rus-pol-ipa-number-preservation'
normal_tokenizer_prefix = 'bpe-rus-pol-normal-number-preservation'

# Hugging Face Hub dataset ids, keyed by language code.
dataset_name = dict(
    rus='iggy12345/russian-xnli-ipa-rosetta',
    pol='iggy12345/cdsc-e-ipa-epitran',
)

# --- Fine-tuning hyperparameters ---------------------------------------------
epochs = 3
context_size = 1024   # max tokens per example; longer inputs are truncated
batch_size = 16       # per device, used for both training and evaluation
learning_rate = 2e-5

In [19]:
def load_and_preprocess(lang: str, ipa: bool, split: str, tokenizer):
    """Load one language's NLI split from the Hub and tokenize its sentence pair.

    Args:
        lang: key into ``dataset_name`` (``'rus'`` or ``'pol'``).
        ipa: if True, read the transcribed text columns (``<column>-<suffix>``)
            instead of the orthographic ones.
        split: dataset split to load, e.g. ``'train'`` or ``'validation'``.
        tokenizer: tokenizer applied to the flattened text features.

    Returns:
        A ``datasets.Dataset`` whose columns are the tokenizer outputs plus
        ``label``; the raw text columns are dropped.
    """
    ds = load_dataset(dataset_name[lang], split=split, cache_dir=str(hf_cache))

    # The two corpora name their sentence pair differently and use a different
    # suffix for the transcribed variant of each column.
    if lang == 'pol':
        column_names = ['sentence_A', 'sentence_B']
        suffix = 'phoneme'
    else:
        column_names = ['hypothesis', 'premise']
        suffix = 'epitran'
    fields = [f'{c}-{suffix}' if ipa else c for c in column_names]

    def preprocess(examples):
        features = flatten_multi_features(examples, fields)
        encoded = tokenizer(features, truncation=True, max_length=context_size)
        encoded['label'] = examples['label']
        return encoded

    # Drop the raw text columns. The Russian and Polish datasets have disjoint
    # text schemas, so keeping them makes a later ``concatenate_datasets``
    # reject the pair or null-pad the missing columns (depending on the
    # ``datasets`` version). ``label`` survives because ``preprocess`` returns
    # it explicitly, and ``map`` keeps returned columns even if listed here.
    return ds.map(
        preprocess,
        batched=True,
        num_proc=os.cpu_count(),
        remove_columns=ds.column_names,
    )

In [20]:
def train_model(ipa: bool) -> Trainer:
    """Fine-tune a GPT sequence classifier on concatenated Russian + Polish NLI.

    The base model is built by ``load_random_from_pretrained_model`` from the
    pretraining checkpoint matching ``ipa`` (presumably a freshly initialised
    model with that checkpoint's architecture -- TODO confirm). A 3-way
    classification head is added on top. It is trained on the Russian and
    Polish train splits, evaluated and checkpointed every 1000 steps on the
    concatenated validation splits. Metrics go to a Weights & Biases run that
    is closed even if training raises.

    Args:
        ipa: if True, use the IPA checkpoint, tokenizer and transcribed text
            columns; otherwise the orthographic ("normal") ones.

    Returns:
        The ``Trainer`` after training. Because ``load_best_model_at_end=True``
        with ``metric_for_best_model="precision"``, its model is the checkpoint
        with the best validation precision.
    """
    variant = 'ipa' if ipa else 'normal'
    checkpoint = ipa_checkpoint_location if ipa else normal_checkpoint_location

    project_name = "debug-russian-polish-small-finetuning-xnli-random-initial-epitran"
    temporary_output_dir = training_checkpoints / f"{project_name}-{variant}/"
    temporary_output_dir.mkdir(parents=True, exist_ok=True)

    prefix = ipa_tokenizer_prefix if ipa else normal_tokenizer_prefix
    vocab_path = tokenizer_prefix / f'{prefix}-vocab.json'
    merges_path = tokenizer_prefix / f'{prefix}-merges.txt'
    tokenizer = load_tokenizer(vocab_path, merges_path)

    base_model = load_random_from_pretrained_model(checkpoint, 'cuda')
    # Copy the tokenizer's padding setup onto the model config; presumably read
    # by GPTForSequenceClassification when pooling padded batches -- TODO confirm.
    base_model.config.pad_token_id = tokenizer.pad_token_id
    base_model.config.padding_side = tokenizer.padding_side
    model = GPTForSequenceClassification(base_model, num_classes=3).to('cuda')

    train_dataset = concatenate_datasets([
        load_and_preprocess('rus', ipa, 'train', tokenizer),
        load_and_preprocess('pol', ipa, 'train', tokenizer),
    ])
    eval_dataset = concatenate_datasets([
        load_and_preprocess('rus', ipa, 'validation', tokenizer),
        load_and_preprocess('pol', ipa, 'validation', tokenizer),
    ])

    training_args = TrainingArguments(
        eval_strategy="steps",
        eval_steps=1000,
        output_dir=str(temporary_output_dir),
        save_strategy='steps',
        save_steps=1000,  # must line up with eval_steps for load_best_model_at_end
        metric_for_best_model="precision",
        load_best_model_at_end=True,
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        weight_decay=0.01,
        logging_steps=100,
        fp16=True,
        warmup_ratio=0.3,
        save_safetensors=False,
    )

    wrun = wandb.init(entity='aaronjencks-the-ohio-state-university', project=project_name, name=variant)
    try:
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            # NOTE(review): newer transformers releases deprecate `tokenizer=`
            # in favour of `processing_class=`; switch once the minimum
            # supported version allows it.
            tokenizer=tokenizer,
            data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
            compute_metrics=compute_metrics,
        )

        print("Training model")
        trainer.train()
    finally:
        # Close the W&B run even when construction or training raises, so the
        # run is not left open in the notebook kernel after a failure.
        wrun.finish()

    return trainer

In [21]:
def finetune_transcription(eval_lang: str, ipa: bool, model: Trainer):
    """Evaluate an already-trained ``Trainer`` on one or both validation sets.

    Despite the name, no fine-tuning happens here. It rebuilds the tokenizer
    selected by ``ipa``, tokenizes the validation split(s) and runs
    ``model.evaluate``.

    Args:
        eval_lang: a key of ``dataset_name`` (``'rus'`` or ``'pol'``), or
            ``'both'`` to evaluate on the two validation splits concatenated.
        ipa: must match the setting ``model`` was trained with. It selects the
            tokenizer and text columns, and a mismatch is not detected.
        model: a trained ``Trainer``, e.g. the one returned by ``train_model``.

    Returns:
        The metrics dict returned by ``Trainer.evaluate`` (it is also printed).
    """
    prefix = ipa_tokenizer_prefix if ipa else normal_tokenizer_prefix
    vocab_path = tokenizer_prefix / f'{prefix}-vocab.json'
    merges_path = tokenizer_prefix / f'{prefix}-merges.txt'
    tokenizer = load_tokenizer(vocab_path, merges_path)

    if eval_lang == 'both':
        eval_dataset = concatenate_datasets([
            load_and_preprocess('rus', ipa, 'validation', tokenizer),
            load_and_preprocess('pol', ipa, 'validation', tokenizer),
        ])
    else:
        eval_dataset = load_and_preprocess(eval_lang, ipa, 'validation', tokenizer)

    print(f"Final evaluation on {eval_lang}")
    results = model.evaluate(eval_dataset=eval_dataset)
    print(results)
    return results


In [None]:
# Train the orthographic ("normal") variant; `model` holds the returned Trainer.
model = train_model(ipa=False)

number of parameters: 123.35M


0,1
eval/accuracy,▁▁▂▂▃▃▄▅▅▆▆▅▆▆▆▇▇▇▇▇▆▇▇█▇▇▇██▇█
eval/f1,▁▁▁▃▄▃▅▄▆▆▇▃▇▆▅▇▇▆▆█▅▇▇█▇▇▇██▇▇
eval/loss,█▇▆▅▅▅▄▅▃▄▂▅▂▃▃▂▂▃▃▂▃▂▂▁▂▂▃▁▁▁▁
eval/precision,▁▁▂▂▃▃▃▄▅▅▆▇▆▆█▇▇▆▆▇▇▇▆▇▇▇▇█▇▇█
eval/recall,▁▁▂▂▃▃▄▅▅▆▆▅▆▆▆▇▇▇▇▇▆▇▇█▇▇▇██▇█
eval/runtime,▃▃▃▃▃▃▃▃▃▃▂▁▂▁▂▁▁▁▃▃▃▃▃▃▃▁▁▁▃▃█
eval/samples_per_second,▆▆▆▆▆▆▆▆▆▆▇█▇█▇███▆▆▆▆▆▆▆███▆▆▁
eval/steps_per_second,▆▆▆▆▆▆▆▆▆▆▇█▇█▇███▆▆▆▆▆▆▆███▆▆▁
train/epoch,▁▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇████
train/global_step,▁▁▁▁▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇██

0,1
eval/accuracy,0.5004
eval/f1,0.48737
eval/loss,0.98853
eval/precision,0.52566
eval/recall,0.5004
eval/runtime,3.3262
eval/samples_per_second,748.599
eval/steps_per_second,46.9
train/epoch,1.25379
train/global_step,31400.0


  trainer = Trainer(


Training model


Step,Training Loss,Validation Loss


In [None]:
# Per-language validation metrics for the orthographic model. Run this before
# the IPA training cell, which rebinds `model`.
for lang in ('rus', 'pol'):
    finetune_transcription(eval_lang=lang, ipa=False, model=model)

In [None]:
# Train the IPA variant; this rebinds `model` to the new Trainer.
model = train_model(ipa=True)

In [None]:
# Per-language validation metrics for the IPA model.
for lang in ('rus', 'pol'):
    finetune_transcription(eval_lang=lang, ipa=True, model=model)