In [1]:
import os
import pathlib

from datasets import load_dataset, concatenate_datasets
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding
import wandb

from hf_wrapper import GPTForSequenceClassification
from tokenizer import load_tokenizer
from utils import flatten_multi_features, load_random_from_pretrained_model, compute_metrics

In [2]:
normal_checkpoint_location = pathlib.Path('./cache/checkpoints/russian_polish_normal_12_5_50k/ckpt.pt')
ipa_checkpoint_location = pathlib.Path('./cache/checkpoints/russian_polish_ipa_12_5_50k/ckpt.pt')
hf_cache = pathlib.Path('./cache')
training_checkpoints = pathlib.Path('./cache/checkpoints')
tokenizer_prefix = pathlib.Path('./cache/tokenizers')
ipa_tokenizer_prefix = 'bpe-rus-pol-ipa-number-preservation'
normal_tokenizer_prefix = 'bpe-rus-pol-normal-number-preservation'

dataset_name = 'iggy12345/russian-xnli-ipa-rosetta'

epochs = 3
context_size = 1024
batch_size = 16
learning_rate = 2e-5

In [3]:
def load_and_preprocess(suffix, split, tokenizer):
    ds = load_dataset(dataset_name, split=split, cache_dir=str(hf_cache))
    fields = [
        'hypothesis' if suffix == 'normal' else f'hypothesis-{suffix}',
        'premise' if suffix == 'normal' else f'premise-{suffix}'
    ]

    def preprocess(examples):
        features = flatten_multi_features(examples, fields)
        encoded = tokenizer(features, truncation=True, max_length=context_size)
        encoded['label'] = examples['label']
        return encoded

    return ds.map(preprocess, batched=True, num_proc=os.cpu_count())

In [4]:
def finetune_transcription(suffix: str):
    ipa = suffix == 'normal'

    checkpoint = ipa_checkpoint_location if ipa else normal_checkpoint_location

    project_name = f"debug-russian-rosetta-small-finetuning-xnli-random-initial"
    temporary_output_dir = training_checkpoints / f"{project_name}/"
    temporary_output_dir.mkdir(parents=True, exist_ok=True)

    vocab_path = tokenizer_prefix / f'{ipa_tokenizer_prefix if ipa else normal_tokenizer_prefix}-vocab.json'
    merges_path = tokenizer_prefix / f'{ipa_tokenizer_prefix if ipa else normal_tokenizer_prefix}-merges.txt'
    tokenizer = load_tokenizer(vocab_path, merges_path)

    base_model = load_random_from_pretrained_model(checkpoint, 'cuda')
    base_model.config.pad_token_id = tokenizer.pad_token_id
    base_model.config.padding_side = tokenizer.padding_side
    model = GPTForSequenceClassification(base_model, num_classes=3).to('cuda')

    train_dataset = load_and_preprocess(suffix, 'train', tokenizer)
    eval_dataset = load_and_preprocess(suffix, 'validation', tokenizer)

    training_args = TrainingArguments(
        eval_strategy="steps",
        eval_steps=1000,
        output_dir=str(temporary_output_dir),
        save_strategy='steps',
        save_steps=1000,
        metric_for_best_model="precision",
        load_best_model_at_end=True,
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        weight_decay=0.01,
        logging_steps=100,
        fp16=True,
        warmup_ratio=0.3,
        save_safetensors=False,
        # disable_tqdm=True,
    )

    wrun = wandb.init(entity='aaronjencks-the-ohio-state-university', project=project_name, name=suffix)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
    )

    print(f"Training model on {suffix}")
    trainer.train()

    print(f"Final evaluation on {suffix}")
    results = trainer.evaluate()
    print(results)

    wrun.finish()


In [5]:
for suffix in ["normal", "phonemizer", "epitran", "goruut"]:
    finetune_transcription(suffix)

number of parameters: 123.35M


[34m[1mwandb[0m: Currently logged in as: [33maaronjencks[0m ([33maaronjencks-the-ohio-state-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


  trainer = Trainer(


Training model on normal


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1000,1.1176,1.126177,0.346586,0.34707,0.346586,0.326792
2000,1.1108,1.11679,0.340964,0.343917,0.340964,0.298535
3000,1.0975,1.100971,0.382329,0.380973,0.382329,0.366453
4000,1.0966,1.091415,0.386345,0.382336,0.386345,0.377602
5000,1.0739,1.104519,0.395984,0.429739,0.395984,0.378754
6000,1.0424,1.071045,0.4249,0.440293,0.4249,0.406962
7000,1.0538,1.045136,0.438554,0.458285,0.438554,0.409338
8000,1.0273,1.02836,0.461446,0.473191,0.461446,0.463879
9000,1.0229,1.019781,0.473092,0.507633,0.473092,0.455885
10000,1.0126,1.01621,0.486747,0.486057,0.486747,0.485747


Final evaluation on normal


{'eval_loss': 0.9586529731750488, 'eval_accuracy': 0.5421686746987951, 'eval_precision': 0.5853163478210806, 'eval_recall': 0.5421686746987951, 'eval_f1': 0.5228079880826221, 'eval_runtime': 1.2896, 'eval_samples_per_second': 1930.778, 'eval_steps_per_second': 120.964, 'epoch': 3.0}


0,1
eval/accuracy,▁▁▂▄▄▆▆▆▇▆▇▆▆▇▆▇▇▇▇▇▇▇▇█▇▇███████▇██████
eval/f1,▁▂▃▃▄▆▅▆▇▇▇▄▅▇▅▇▇▇▇██▇▆▇██▇▇▇█▇▇██▇▇▇███
eval/loss,█▇█▆▅▄▄▃▃▄▂▄▄▂▃▃▃▃▃▂▂▂▂▃▂▁▁▂▂▂▁▂▁▂▁▂▁▁▁▂
eval/precision,▁▃▄▄▅▅▆▅▆▆▇▇▇▆▇▆▇▇▇▇▇██▇▇▇███▇▇▇██▇█▇▇██
eval/recall,▁▁▂▄▅▆▆▆▆▇▆▇▆▇▆▇▇▇▇▇█▇▇▇▇██▇█▇██▇███████
eval/runtime,▂▁▁▁▂▃▃▃▃▃▄▃▃▃▄▄▄▄▄▄▄▄▄▄▆▆▇▆▇█▆▇▇▆▆▆▇█▄▃
eval/samples_per_second,██▇█▆▆▆▆▅▆▆▆▆▅▅▅▅▅▅▅▅▅▅▃▃▂▂▃▂▁▂▂▄▃▂▇▁▅▅▆
eval/steps_per_second,█▇▆▆▅▆▆▆▅▅▅▅▅▅▅▅▅▅▄▅▄▅▄▄▃▂▂▂▁▂▁▃▂▂▁▄▂▂▅▆
train/epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇████
train/global_step,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇███

0,1
eval/accuracy,0.54217
eval/f1,0.52281
eval/loss,0.95865
eval/precision,0.58532
eval/recall,0.54217
eval/runtime,1.2896
eval/samples_per_second,1930.778
eval/steps_per_second,120.964
total_flos,0.0
train/epoch,3.0


number of parameters: 123.35M


  trainer = Trainer(


Training model on phonemizer


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1000,1.128,1.130153,0.329719,0.330247,0.329719,0.327625
2000,1.1141,1.114747,0.339357,0.3467,0.339357,0.316361
3000,1.1049,1.108775,0.352209,0.351473,0.352209,0.314064
4000,1.0981,1.094742,0.36988,0.367098,0.36988,0.336058
5000,1.0807,1.081045,0.393173,0.390794,0.393173,0.389228
6000,1.0787,1.074916,0.415663,0.414659,0.415663,0.409481
7000,1.0781,1.081992,0.397189,0.605258,0.397189,0.320595
8000,1.0575,1.058813,0.434538,0.455439,0.434538,0.398043
9000,1.0449,1.05379,0.439759,0.441125,0.439759,0.415537
10000,1.0518,1.054469,0.427711,0.427986,0.427711,0.416265


Final evaluation on phonemizer


{'eval_loss': 1.0819917917251587, 'eval_accuracy': 0.39718875502008033, 'eval_precision': 0.605257788806176, 'eval_recall': 0.39718875502008033, 'eval_f1': 0.32059547713808934, 'eval_runtime': 5.2548, 'eval_samples_per_second': 473.854, 'eval_steps_per_second': 29.687, 'epoch': 3.0}


0,1
eval/accuracy,▁▂▃▃▅▆▅▅▅▆▆▇▆▆▆▇▇▇▇▇█▇▇▇▇▇▇▆▇▇▇▇██▇▇███▃
eval/f1,▁▄▁▄▆▆▆▆▆▆▇▇▇▇▇▇▇▇█▇▇█▇▇▇▇▇█▇▇███▇█▇███▁
eval/loss,█▇▆▅▄▃▅▃▃▄▃▂▃▃▂▂▂▂▂▂▁▂▁▁▁▂▂▁▁▁▂▁▁▂▁▁▁▁▁▆
eval/precision,▁▂▂▂▃▅▅▄▆▆▇▅▇▆▇▅▇█▇▇▇▇▇▇▇█▇██▇▇█▇▇█▇▇█▇█
eval/recall,▁▃▃▄▄▆▆▅▅▆▆▅▆▆▆▇▇▇▇▇▇▇█▇▇▇▇▇▇▇██▇▇▇▇███▃
eval/runtime,▃█▅▃▅▅▄▃▁▂▂▃▁▄▃▄▄▄▂▃▄▅▃▁▁▃▂▃▂▂▂▂▂▂▂▂▃▂▂▂
eval/samples_per_second,▅▅▁▄▆▆▅▆▄▆▆▆▅█▅▅▅▄▄▅▇▅▅▅▄▆██▆▆▆▆▆▆▆▆▆▆▆▆
eval/steps_per_second,▅▁▆▄▆▄▅▇▇▆▅▅▅▅█▄▄▅▅▅▅▄▄▅▆█▆▆▆▆▆▆▆▆▆▆▆▆▆▆
train/epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇██
train/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▆▆▆▆▆▆▆▇▇▇█████

0,1
eval/accuracy,0.39719
eval/f1,0.3206
eval/loss,1.08199
eval/precision,0.60526
eval/recall,0.39719
eval/runtime,5.2548
eval/samples_per_second,473.854
eval/steps_per_second,29.687
total_flos,0.0
train/epoch,3.0


number of parameters: 123.35M


  trainer = Trainer(


Training model on epitran


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1000,1.1275,1.157328,0.320884,0.320935,0.320884,0.302282
2000,1.1162,1.129776,0.332129,0.330255,0.332129,0.29037
3000,1.1125,1.108303,0.336546,0.336777,0.336546,0.329328
4000,1.1093,1.103633,0.351406,0.35127,0.351406,0.347
5000,1.088,1.10602,0.36988,0.40216,0.36988,0.342595
6000,1.0719,1.089834,0.393574,0.425842,0.393574,0.375555
7000,1.0696,1.064749,0.422892,0.420803,0.422892,0.417637
8000,1.05,1.046804,0.434137,0.430154,0.434137,0.420998
9000,1.0407,1.042567,0.450201,0.470645,0.450201,0.421927
10000,1.0311,1.045683,0.449799,0.446536,0.449799,0.441809


Final evaluation on epitran


{'eval_loss': 0.9490161538124084, 'eval_accuracy': 0.5582329317269076, 'eval_precision': 0.5650241232109442, 'eval_recall': 0.5582329317269076, 'eval_f1': 0.557263009666008, 'eval_runtime': 3.0979, 'eval_samples_per_second': 803.763, 'eval_steps_per_second': 50.356, 'epoch': 3.0}


0,1
eval/accuracy,▁▂▃▄▅▆▆▆▆▆▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇█▇▇██████████
eval/f1,▁▂▂▂▃▄▅▆▆▆▆▅▆▆▅▅▅▆▇▇▆▇▇▇▇▇▇▇▇███████████
eval/loss,█▇▇▇▆▅▅▃▄▃▄▃▃▄▃▄▄▄▃▆▃▂▂▃▂▂▂▂▁▁▂▂▂▁▁▁▁▁▁▁
eval/precision,▁▁▂▄▄▅▆▆▅▆▆▆▆▆▆▆▆▆▇▇▇▆▇▇▇▇▇█▇▇▇▇▇███████
eval/recall,▁▂▂▄▅▅▅▅▅▆▆▅▆▆▆▆▇▇▆▇▇▇█▇▇▇▇▇████████████
eval/runtime,▁▁▁▁▂▂▃▃▅▆▅▅▅▅▅▆▆▆▅▆▅▆▅▆▆▅▆▆▅▆▅▆▅▅▆▆▆▅▆█
eval/samples_per_second,▇████▆▅▅▅▅▅▅▄▁▄▄▅▅▄▄▅▄▄▄▄▄▄▅▅▅▅▅▄▄▄▄▄▃▅▄
eval/steps_per_second,▇████▆▄▃▃▃▃▂▂▂▂▂▄▃▃▂▃▃▂▃▃▃▃▂▃▃▃▃▃▃▂▃▁▄▃▂
train/epoch,▁▁▂▂▂▂▂▂▂▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇███
train/global_step,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇█

0,1
eval/accuracy,0.55823
eval/f1,0.55726
eval/loss,0.94902
eval/precision,0.56502
eval/recall,0.55823
eval/runtime,3.0979
eval/samples_per_second,803.763
eval/steps_per_second,50.356
total_flos,0.0
train/epoch,3.0


number of parameters: 123.35M


  trainer = Trainer(


Training model on goruut


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1000,1.1482,1.148797,0.341365,0.342031,0.341365,0.337475
2000,1.1157,1.129809,0.33253,0.334253,0.33253,0.3212
3000,1.1139,1.11293,0.342972,0.345284,0.342972,0.341261
4000,1.1164,1.10934,0.351004,0.350062,0.351004,0.34258
5000,1.1079,1.099129,0.373494,0.375049,0.373494,0.371561
6000,1.0953,1.093369,0.381526,0.377712,0.381526,0.375965
7000,1.0855,1.095535,0.383133,0.387394,0.383133,0.366837
8000,1.0782,1.082809,0.41004,0.416262,0.41004,0.393703
9000,1.07,1.076824,0.424096,0.431363,0.424096,0.382152
10000,1.0629,1.060483,0.420482,0.418383,0.420482,0.415115


Final evaluation on goruut


{'eval_loss': 0.9838324189186096, 'eval_accuracy': 0.5140562248995983, 'eval_precision': 0.5362580400252271, 'eval_recall': 0.5140562248995983, 'eval_f1': 0.5157914915374591, 'eval_runtime': 4.0921, 'eval_samples_per_second': 608.494, 'eval_steps_per_second': 38.123, 'epoch': 3.0}


0,1
eval/accuracy,▁▁▁▂▃▄▅▅▆▅▅▅▆▆▇▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇█▇█████
eval/f1,▂▁▂▂▃▃▄▄▆▆▅▄▇▆▆▆▇▇▆▇▆▇▇▇▇▇▇▇▇▇▇▇█▇██████
eval/loss,██▇▆▅▄▄▅▃▅▃▃▃▃▂▄▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▂▁▁▁▁▁▁▁▁
eval/precision,▁▁▁▂▂▄▄▆▆▆▅▅▇▅▆▅▇▇▅▆▇▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇█
eval/recall,▁▁▂▃▃▄▅▆▅▆▅▅▆▆▆▆▇▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇████████
eval/runtime,▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▇▆▆▆▆▆▆█▇▂▄▄▃▁▃
eval/samples_per_second,▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▁▃▃▃▃▃▃▃▃▃▃▃▃▂█▆▅▆▇
eval/steps_per_second,▃▃▃▃▂▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▃▂▃▃▂▂▂▃▁▃▂▇▆▅▆██▆
train/epoch,▁▂▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▁▁▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████

0,1
eval/accuracy,0.51406
eval/f1,0.51579
eval/loss,0.98383
eval/precision,0.53626
eval/recall,0.51406
eval/runtime,4.0921
eval/samples_per_second,608.494
eval/steps_per_second,38.123
total_flos,0.0
train/epoch,3.0
