In [10]:
import os
import pathlib

from datasets import load_dataset, concatenate_datasets
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding
import wandb

from hf_wrapper import GPTForSequenceClassification
from tokenizer import load_tokenizer
from utils import flatten_multi_features, load_pretrained_model, compute_metrics

In [11]:
normal_checkpoint_location = pathlib.Path('./cache/checkpoints/russian_polish_normal_12_5_50k/ckpt.pt')
ipa_checkpoint_location = pathlib.Path('./cache/checkpoints/russian_polish_ipa_12_5_50k/ckpt.pt')
hf_cache = pathlib.Path('./cache')
training_checkpoints = pathlib.Path('./cache/checkpoints')
tokenizer_prefix = pathlib.Path('./cache/tokenizers')
ipa_tokenizer_prefix = 'bpe-rus-pol-ipa-number-preservation'
normal_tokenizer_prefix = 'bpe-rus-pol-normal-number-preservation'

dataset_name = 'iggy12345/russian-xnli-ipa-rosetta'

epochs = 3
context_size = 1024
batch_size = 16
learning_rate = 2e-5

In [12]:
def load_and_preprocess(suffix, split, tokenizer):
    ds = load_dataset(dataset_name, split=split, cache_dir=str(hf_cache))
    fields = [
        'hypothesis' if suffix == 'normal' else f'hypothesis-{suffix}',
        'premise' if suffix == 'normal' else f'premise-{suffix}'
    ]

    def preprocess(examples):
        features = flatten_multi_features(examples, fields)
        encoded = tokenizer(features, truncation=True, max_length=context_size)
        encoded['label'] = examples['label']
        return encoded

    return ds.map(preprocess, batched=True, num_proc=os.cpu_count())

In [13]:
def finetune_transcription(suffix: str):
    ipa = suffix == 'normal'

    checkpoint = ipa_checkpoint_location if ipa else normal_checkpoint_location

    project_name = f"debug-russian-rosetta-small-finetuning-xnli"
    temporary_output_dir = training_checkpoints / f"{project_name}/"
    temporary_output_dir.mkdir(parents=True, exist_ok=True)

    vocab_path = tokenizer_prefix / f'{ipa_tokenizer_prefix if ipa else normal_tokenizer_prefix}-vocab.json'
    merges_path = tokenizer_prefix / f'{ipa_tokenizer_prefix if ipa else normal_tokenizer_prefix}-merges.txt'
    tokenizer = load_tokenizer(vocab_path, merges_path)

    base_model = load_pretrained_model(checkpoint, 'cuda')
    base_model.config.pad_token_id = tokenizer.pad_token_id
    base_model.config.padding_side = tokenizer.padding_side
    model = GPTForSequenceClassification(base_model, num_classes=3).to('cuda')

    train_dataset = load_and_preprocess(suffix, 'train', tokenizer)
    eval_dataset = load_and_preprocess(suffix, 'validation', tokenizer)

    training_args = TrainingArguments(
        eval_strategy="steps",
        eval_steps=1000,
        output_dir=str(temporary_output_dir),
        save_strategy='steps',
        save_steps=1000,
        metric_for_best_model="precision",
        load_best_model_at_end=True,
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        weight_decay=0.01,
        logging_steps=100,
        fp16=True,
        warmup_ratio=0.3,
        save_safetensors=False,
        # disable_tqdm=True,
    )

    wrun = wandb.init(entity='aaronjencks-the-ohio-state-university', project=project_name, name=suffix)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
    )

    print(f"Training model on {suffix}")
    trainer.train()

    print(f"Final evaluation on {suffix}")
    results = trainer.evaluate()
    print(results)

    wrun.finish()


In [14]:
for suffix in ["normal", "phonemizer", "epitran", "goruut"]:
    finetune_transcription(suffix)

number of parameters: 123.35M


0,1
eval/accuracy,▁
eval/f1,▁
eval/loss,▁
eval/precision,▁
eval/recall,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁▂▃▃▄▅▆▆▇██
train/global_step,▁▂▃▃▄▅▆▆▇██

0,1
eval/accuracy,0.33574
eval/f1,0.30685
eval/loss,1.19436
eval/precision,0.33723
eval/recall,0.33574
eval/runtime,1.3116
eval/samples_per_second,1898.419
eval/steps_per_second,118.937
train/epoch,0.04074
train/global_step,1000.0


  trainer = Trainer(


Training model on normal


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1000,1.0936,1.133202,0.351004,0.359593,0.351004,0.332191
2000,1.0236,1.087112,0.395181,0.427678,0.395181,0.371076
3000,0.951,0.993132,0.510442,0.547379,0.510442,0.49992
4000,0.9094,0.982215,0.531325,0.590106,0.531325,0.510108
5000,0.8403,0.978217,0.553012,0.640688,0.553012,0.533311
6000,0.7986,0.836782,0.624096,0.659098,0.624096,0.624751
7000,0.8163,0.815164,0.636546,0.651711,0.636546,0.63675
8000,0.7942,0.886788,0.600402,0.679925,0.600402,0.589708
9000,0.7538,0.799834,0.635341,0.665448,0.635341,0.635406
10000,0.7752,0.788722,0.648193,0.667364,0.648193,0.648867


Final evaluation on normal


{'eval_loss': 0.6823201775550842, 'eval_accuracy': 0.7192771084337349, 'eval_precision': 0.7325102004555184, 'eval_recall': 0.7192771084337349, 'eval_f1': 0.719870922061566, 'eval_runtime': 1.3234, 'eval_samples_per_second': 1881.512, 'eval_steps_per_second': 117.878, 'epoch': 3.0}


0,1
eval/accuracy,▁▂▄▅▆▆▇▇▇▇▇▇▇▇▇███▇█████████████████████
eval/f1,▁▂▄▄▆▆▇▇▇▇▇█▇█▇▇▇███████▇███████████████
eval/loss,█▆▆▃▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▂▁▂▃▃▃▂▃▃▃▃▃▃▃▃▃▃▁
eval/precision,▁▅▅▇▇▇▇▇▇▇▇▇▇▇▇█████████████████████████
eval/recall,▁▄▄▅▆▆▇▇▇▇▇█▇▇▇█████████████████████████
eval/runtime,▁▄▁▅▃▆▆▅▄▆█▆▅▆▆▆▆▆▂▄▃▂▂▃▃▇▇▃▄▃▁▇▂▁▆▅▄▄▃▅
eval/samples_per_second,█▆█▇█▆█▄▄▅▄▄▄▂▁▄▄▄▄▄▅▆█▆▆▃▃▂▁▆▇▇▃▇█▄▅▆▆▅
eval/steps_per_second,▇▅▇▇▇▅▆▇▄▄▃▄▂▂▃▃▃▄▃▄▆▅▄▆▇▆▆▃▁▁▅▇█▇▃▇▅▅▅▆
train/epoch,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
train/global_step,▁▁▁▂▂▂▂▂▂▂▂▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇████

0,1
eval/accuracy,0.71928
eval/f1,0.71987
eval/loss,0.68232
eval/precision,0.73251
eval/recall,0.71928
eval/runtime,1.3234
eval/samples_per_second,1881.512
eval/steps_per_second,117.878
total_flos,0.0
train/epoch,3.0


number of parameters: 123.35M


Map (num_proc=16):   0%|          | 0/392702 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2490 [00:00<?, ? examples/s]

  trainer = Trainer(


Training model on phonemizer


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1000,1.0891,1.094731,0.371486,0.390546,0.371486,0.319042
2000,1.0421,1.046531,0.458635,0.486189,0.458635,0.444652
3000,1.0063,1.015625,0.495984,0.541695,0.495984,0.483915
4000,0.9692,1.050555,0.479518,0.524072,0.479518,0.453335
5000,0.9447,1.012856,0.514458,0.540213,0.514458,0.503715
6000,0.9216,1.009262,0.494779,0.581409,0.494779,0.471037
7000,0.9194,0.938041,0.545783,0.549685,0.545783,0.543531
8000,0.901,0.972601,0.543775,0.606139,0.543775,0.530642
9000,0.8714,0.957498,0.557831,0.593701,0.557831,0.556275
10000,0.8871,0.942699,0.553414,0.572882,0.553414,0.549616


Final evaluation on phonemizer


{'eval_loss': 0.8471311926841736, 'eval_accuracy': 0.6240963855421687, 'eval_precision': 0.6620727339267688, 'eval_recall': 0.6240963855421687, 'eval_f1': 0.6243196783331949, 'eval_runtime': 5.5041, 'eval_samples_per_second': 452.387, 'eval_steps_per_second': 28.342, 'epoch': 3.0}


0,1
eval/accuracy,▁▄▃▄▄▅▄▄▄▆▆▅▄▅▆▆▇▅▇▆▇█▆▆▇▇▇█▇████▇▇█▇█▇▇
eval/f1,▁▄▅▄▅▆▆▆▆▆▇▇▆▇▇▇▆▇▇▇▇▇▇█▇█▇█████████████
eval/loss,█▇▆▆▄▅▄▃▄▄▃▃▃▄▃▃▃▄▂▄▂▃▂▄▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▂
eval/precision,▁▃▃▅▄▅▅▅▅▆▆▅▆▆▇▆▆▇▇▆▇▇▇▇▇█▇▇███▇██▇▇█▇██
eval/recall,▁▃▄▅▆▆▆▆▆▇▇▆▇▇▇▆▇▇█▇▇▇██▇█▇█████████████
eval/runtime,▃▃▃▃▄▃▃▃▂▁▂▂▂▅▅▃▁▅▇▇█▇▇▇▆▄▇▅█▇▃▃▃▃▃▃▃▃▃▄
eval/samples_per_second,▇▆▆▆▆████▄▄▄▄▃▆▄▂▂▁▂▂▇▃█▃▆▂▃▁▆▆▆▆▆▆▆▆▆▇▅
eval/steps_per_second,▆▆▅▅▆▅▆▇█▇▇▇▃▃█▄▁▂▁▁▇▂▂▅▄▄▁▂▅▆▆▆▆▅▆▅▅▆▆▄
train/epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇█
train/global_step,▁▁▁▁▁▂▂▂▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇█████

0,1
eval/accuracy,0.6241
eval/f1,0.62432
eval/loss,0.84713
eval/precision,0.66207
eval/recall,0.6241
eval/runtime,5.5041
eval/samples_per_second,452.387
eval/steps_per_second,28.342
total_flos,0.0
train/epoch,3.0


number of parameters: 123.35M


Map (num_proc=16):   0%|          | 0/392702 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2490 [00:00<?, ? examples/s]

  trainer = Trainer(


Training model on epitran


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1000,1.0901,1.1237,0.349398,0.392101,0.349398,0.252658
2000,1.0498,1.093651,0.400402,0.488999,0.400402,0.340479
3000,1.0083,1.016973,0.489558,0.526365,0.489558,0.480032
4000,0.9767,1.015113,0.499197,0.538537,0.499197,0.483013
5000,0.9431,1.048002,0.484739,0.546521,0.484739,0.459334
6000,0.9287,0.978682,0.515663,0.583335,0.515663,0.505116
7000,0.9174,0.93705,0.55502,0.571515,0.55502,0.552935
8000,0.9123,0.969351,0.54498,0.609914,0.54498,0.530153
9000,0.8834,0.941808,0.556225,0.609054,0.556225,0.550186
10000,0.8792,0.922701,0.573494,0.592135,0.573494,0.570557


Final evaluation on epitran


{'eval_loss': 0.8374748826026917, 'eval_accuracy': 0.6421686746987951, 'eval_precision': 0.6672181625450206, 'eval_recall': 0.6421686746987951, 'eval_f1': 0.6405633549479972, 'eval_runtime': 2.9624, 'eval_samples_per_second': 840.526, 'eval_steps_per_second': 52.659, 'epoch': 3.0}


0,1
eval/accuracy,▂▁▄▄▅▅▅▅▅▅▇▅▆▄▆▅▆▆▆▆▆▇▇▆▇▇▇▇▇▇███▇███▇██
eval/f1,▁▃▅▅▅▆▆▆▇▇▆▇▆▇▇▇▆▇▇▇▇▇▇▇▇▇█▇████████████
eval/loss,█▇▆▄▄▃▂▃▃▂▃▃▄▃▂▃▂▂▂▂▁▂▁▂▁▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁
eval/precision,▁▄▅▆▇▆▆▇▇▇▆▇▇▆▇▇▇▇▇▇▇█████▇▇████████████
eval/recall,▁▄▄▅▅▆▆▆▆▆▆▇▇▆▆▇▇▇▇▇▇▇██▇▇██████████████
eval/runtime,▃▂▃▂▃█▂▁▂▁▁▁▁▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eval/samples_per_second,▆▆▆▆▆▆▁▇█▇█▇███▇████▇███████████████████
eval/steps_per_second,▁▁▁▁▁▁▁▄█▇▇▇▅▇▆▇▇▇▇▅▇▇▇██▇▇▇▇▇█▇▇▇█▇▇▇▇▇
train/epoch,▁▁▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇█████
train/global_step,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▆▆▆▇▇▇▇▇████

0,1
eval/accuracy,0.64217
eval/f1,0.64056
eval/loss,0.83747
eval/precision,0.66722
eval/recall,0.64217
eval/runtime,2.9624
eval/samples_per_second,840.526
eval/steps_per_second,52.659
total_flos,0.0
train/epoch,3.0


number of parameters: 123.35M


Map (num_proc=16):   0%|          | 0/392702 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2490 [00:00<?, ? examples/s]

  trainer = Trainer(


Training model on goruut


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1000,1.0872,1.099726,0.375502,0.399593,0.375502,0.353725
2000,1.0451,1.078119,0.42249,0.477549,0.42249,0.384651
3000,1.0029,1.006109,0.491968,0.497328,0.491968,0.491358
4000,0.9959,1.03899,0.472289,0.500922,0.472289,0.453848
5000,0.9544,1.015684,0.508434,0.552518,0.508434,0.494596
6000,0.9326,0.97583,0.5249,0.56128,0.5249,0.519308
7000,0.9266,0.946005,0.551807,0.553757,0.551807,0.55095
8000,0.9192,0.981134,0.534538,0.577221,0.534538,0.525405
9000,0.883,0.964637,0.553414,0.577052,0.553414,0.55219
10000,0.891,0.95929,0.541365,0.554454,0.541365,0.53839


Final evaluation on goruut


{'eval_loss': 0.8661487698554993, 'eval_accuracy': 0.6253012048192771, 'eval_precision': 0.6469924392644436, 'eval_recall': 0.6253012048192771, 'eval_f1': 0.624724640446597, 'eval_runtime': 4.0586, 'eval_samples_per_second': 613.507, 'eval_steps_per_second': 38.437, 'epoch': 3.0}


0,1
eval/accuracy,▁▂▅▆▅▆▇▆▆▆▆▇▆▇▆▅▆▇▆▇▇▇▇▇▇▇▇▇█▇▇█████████
eval/f1,▁▄▅▆▅▅▆▆▇▆▆▆▇▆▆▇▇▇▇▇▆▇▆▇▇▇▇█▇████▇██████
eval/loss,█▇▅▆▅▄▄▃▄▄▃▃▃▃▃▃▃▂▄▃▃▂▃▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▂
eval/precision,▁▃▄▄▆▆▇▆▇▆▇▇▆▇▆▇▇▇▇▇▇▇▇▇▇▇█▇████████████
eval/recall,▁▂▅▆▆▆▆▆▆▆▆▇▆▅▇▇▇▇▇▇▇▇▇▇▇██▇████████████
eval/runtime,▃▁▄▃▂▁▂▃▃▄▃▃▄▆▄▃▃▂▃▅▃▃▄█▂▃▅▂▅▆▃▇▄▃▃▃▄▄▃▅
eval/samples_per_second,▅▅▇▄▇█▇▅▅▄▃▅▆▅▅▄▁▄▄▇▄▆▃▄▅▄▆▄▇▂▆▇▅▆▄▆▅▄▅▂
eval/steps_per_second,▆▆▅▅▅█▇▆▅▆▆▆▅▅▇▆▇▄▆▁▅▇▇▅▇▆▇▆▂▇▆▆▄▆▆▅▅▆▅▃
train/epoch,▁▁▁▁▁▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇█
train/global_step,▁▁▁▁▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇████

0,1
eval/accuracy,0.6253
eval/f1,0.62472
eval/loss,0.86615
eval/precision,0.64699
eval/recall,0.6253
eval/runtime,4.0586
eval/samples_per_second,613.507
eval/steps_per_second,38.437
total_flos,0.0
train/epoch,3.0
