In [4]:
# %run 01_preprocess.ipynb
%run 00_common.ipynb

0.27.0


In [22]:
# Run configuration: the tag embedded in the JSON file names and the data sub-directory to read from.
MODEL_PATTERN = 'keep_all'
TORGO_TRAIN_TYPE = TorgoTrainType.WORD_KEEP.value

# speaker id -> DatasetDict with 'train' / 'valid' / 'test' splits
loaded_datasets = {}

for speaker in SPEAKERS:
    file_pattern = f'torgo_xlsr_finetune_{speaker}_{MODEL_PATTERN}_{{split}}.json'

    # Split name on the left, file-name suffix on the right (the validation file is tagged 'val').
    loaded_datasets[speaker] = load_dataset('json', data_files={
        split_name: os.path.join(f'data/{TORGO_TRAIN_TYPE}', file_pattern.format(split=file_suffix))
        for split_name, file_suffix in (('train', 'train'), ('valid', 'val'), ('test', 'test'))
    })

# Report the size of every split for every speaker.
for speaker in loaded_datasets:
    dataset = loaded_datasets[speaker]
    print(f"SPEAKER {speaker} - Train Set Size: {len(dataset['train'])}")
    print(f"SPEAKER {speaker} - Valid Set Size: {len(dataset['valid'])}")
    print(f"SPEAKER {speaker} - Test Set Size: {len(dataset['test'])}")

Generating train split: 0 examples [00:00, ? examples/s]

Generating valid split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating valid split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

SPEAKER F01 - Train Set Size: 257
SPEAKER F01 - Valid Set Size: 29
SPEAKER F01 - Test Set Size: 32
SPEAKER F03 - Train Set Size: 251
SPEAKER F03 - Valid Set Size: 28
SPEAKER F03 - Test Set Size: 32
SPEAKER F04 - Train Set Size: 225
SPEAKER F04 - Valid Set Size: 26
SPEAKER F04 - Test Set Size: 28
SPEAKER M01 - Train Set Size: 299
SPEAKER M01 - Valid Set Size: 34
SPEAKER M01 - Test Set Size: 37
SPEAKER M02 - Train Set Size: 420
SPEAKER M02 - Valid Set Size: 47
SPEAKER M02 - Test Set Size: 52
SPEAKER M03 - Train Set Size: 226
SPEAKER M03 - Valid Set Size: 26
SPEAKER M03 - Test Set Size: 28
SPEAKER M04 - Train Set Size: 340
SPEAKER M04 - Valid Set Size: 38
SPEAKER M04 - Test Set Size: 42
SPEAKER M05 - Train Set Size: 348
SPEAKER M05 - Valid Set Size: 39
SPEAKER M05 - Test Set Size: 43


In [23]:
# Tokenizer used by both the training and the evaluation cells below.
BART_TOKENIZER_CHECKPOINT = 'facebook/bart-base'
tokenizer = BartTokenizerFast.from_pretrained(BART_TOKENIZER_CHECKPOINT)

In [24]:
def train_model_for_speaker(speaker_id, batch_size=16, num_train_epochs=40, learning_rate=1e-5):
    """Fine-tune a freshly loaded ``facebook/bart-base`` model on one speaker's data.

    The speaker's 'train' and 'valid' splits are read from the notebook-level
    ``loaded_datasets``, passed through ``preprocess_function`` and handed to a
    ``Seq2SeqTrainer`` together with the notebook-level ``tokenizer`` and
    ``data_collator``. After training, the final model is written to the
    output directory.

    Args:
        speaker_id: Key into ``loaded_datasets`` (e.g. ``'F01'``); also used
            as the suffix of the output directory name.
        batch_size: Per-device batch size for both training and evaluation.
        num_train_epochs: Number of training epochs.
        learning_rate: Learning rate passed to the training arguments.

    Returns:
        None. Results are the printed logs and the files written under
        ``torgo_spell_correction_{MODEL_PATTERN}_{speaker_id}``.
    """
    print(speaker_id)
    # Apply preprocess_function to the training split first, then the validation split
    speaker_splits = loaded_datasets[speaker_id]
    train_data = speaker_splits['train'].map(preprocess_function, batched=True)
    val_data = speaker_splits['valid'].map(preprocess_function, batched=True)

    output_dir = f"torgo_spell_correction_{MODEL_PATTERN}_{speaker_id}"
    print(output_dir)

    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        weight_decay=0.01,
        save_total_limit=3,
        num_train_epochs=num_train_epochs,
        predict_with_generate=True,
        push_to_hub=False,
        logging_steps=100
    )

    print("Train Dataset:", train_data)
    print("Validation Dataset:", val_data)
    print("Tokenizer:", tokenizer)
    print("Training Arguments:", training_args)

    # Every speaker starts from the same pretrained weights.
    model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # The datasets are handed to the trainer directly; it receives the model,
    # training arguments, tokenizer and data_collator.
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=val_data,
        tokenizer=tokenizer,
        data_collator=data_collator
    )
    trainer.train()

    # Persist the end-of-training weights; the trainer and model are local to
    # this function and would otherwise be discarded on return.
    trainer.save_model(output_dir)

In [10]:
# Fine-tune the correction model for speaker F01 (output goes to torgo_spell_correction_{MODEL_PATTERN}_F01).
train_model_for_speaker('F01')

F01


Map:   0%|          | 0/257 [00:00<?, ? examples/s]

Map:   0%|          | 0/29 [00:00<?, ? examples/s]

torgo_spell_correction_keep_all_F01
Train Dataset: Dataset({
    features: ['references_phoneme', 'predictions_phoneme', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 257
})
Validation Dataset: Dataset({
    features: ['references_phoneme', 'predictions_phoneme', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 29
})
Tokenizer: BartTokenizerFast(name_or_path='facebook/bart-base', vocab_size=50265, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)}, clean_up_tokenization_spaces=True)
Training Arguments: Seq2SeqTrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss
1,No log,4.878552
2,No log,4.145248
3,No log,3.429711
4,No log,2.755198
5,No log,2.436084
6,5.236900,2.257398
7,5.236900,2.069376
8,5.236900,1.971072
9,5.236900,1.824376
10,5.236900,1.714227


In [11]:
# Fine-tune the correction model for speaker F03 (output goes to torgo_spell_correction_{MODEL_PATTERN}_F03).
train_model_for_speaker('F03')


F03


Map:   0%|          | 0/251 [00:00<?, ? examples/s]

Map:   0%|          | 0/28 [00:00<?, ? examples/s]

torgo_spell_correction_keep_all_F03
Train Dataset: Dataset({
    features: ['references_phoneme', 'predictions_phoneme', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 251
})
Validation Dataset: Dataset({
    features: ['references_phoneme', 'predictions_phoneme', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 28
})
Tokenizer: BartTokenizerFast(name_or_path='facebook/bart-base', vocab_size=50265, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)}, clean_up_tokenization_spaces=True)
Training Arguments: Seq2SeqTrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss
1,No log,5.482511
2,No log,4.805676
3,No log,4.081069
4,No log,3.250211
5,No log,2.642261
6,No log,2.313696
7,4.638400,2.126961
8,4.638400,1.987642
9,4.638400,1.823175
10,4.638400,1.705786


In [12]:
# Fine-tune the correction model for speaker F04 (output goes to torgo_spell_correction_{MODEL_PATTERN}_F04).
train_model_for_speaker('F04')


F04


Map:   0%|          | 0/225 [00:00<?, ? examples/s]

Map:   0%|          | 0/26 [00:00<?, ? examples/s]

torgo_spell_correction_keep_all_F04
Train Dataset: Dataset({
    features: ['references_phoneme', 'predictions_phoneme', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 225
})
Validation Dataset: Dataset({
    features: ['references_phoneme', 'predictions_phoneme', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 26
})
Tokenizer: BartTokenizerFast(name_or_path='facebook/bart-base', vocab_size=50265, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)}, clean_up_tokenization_spaces=True)
Training Arguments: Seq2SeqTrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss
1,No log,4.631664
2,No log,4.13206
3,No log,3.496891
4,No log,2.7421
5,No log,2.365767
6,No log,2.20317
7,5.337600,2.046761
8,5.337600,1.908385
9,5.337600,1.792579
10,5.337600,1.749781


In [13]:
# Fine-tune the correction model for speaker M01 (output goes to torgo_spell_correction_{MODEL_PATTERN}_M01).
train_model_for_speaker('M01')

M01


Map:   0%|          | 0/299 [00:00<?, ? examples/s]

Map:   0%|          | 0/34 [00:00<?, ? examples/s]

torgo_spell_correction_keep_all_M01
Train Dataset: Dataset({
    features: ['references_phoneme', 'predictions_phoneme', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 299
})
Validation Dataset: Dataset({
    features: ['references_phoneme', 'predictions_phoneme', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 34
})
Tokenizer: BartTokenizerFast(name_or_path='facebook/bart-base', vocab_size=50265, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)}, clean_up_tokenization_spaces=True)
Training Arguments: Seq2SeqTrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss
1,No log,5.873299
2,No log,4.79385
3,No log,3.377897
4,No log,2.806203
5,No log,2.436958
6,5.289800,2.199484
7,5.289800,1.990299
8,5.289800,1.825338
9,5.289800,1.664042
10,5.289800,1.509864


In [14]:
# Fine-tune the correction model for speaker M02 (output goes to torgo_spell_correction_{MODEL_PATTERN}_M02).
train_model_for_speaker('M02')


M02


Map:   0%|          | 0/420 [00:00<?, ? examples/s]

Map:   0%|          | 0/47 [00:00<?, ? examples/s]

torgo_spell_correction_keep_all_M02
Train Dataset: Dataset({
    features: ['references_phoneme', 'predictions_phoneme', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 420
})
Validation Dataset: Dataset({
    features: ['references_phoneme', 'predictions_phoneme', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 47
})
Tokenizer: BartTokenizerFast(name_or_path='facebook/bart-base', vocab_size=50265, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)}, clean_up_tokenization_spaces=True)
Training Arguments: Seq2SeqTrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss
1,No log,5.095481
2,No log,3.413417
3,No log,2.526167
4,5.344000,2.155766
5,5.344000,1.923764
6,5.344000,1.694354
7,5.344000,1.494337
8,2.341100,1.340538
9,2.341100,1.183524
10,2.341100,1.121591


In [15]:
# Fine-tune the correction model for speaker M03 (output goes to torgo_spell_correction_{MODEL_PATTERN}_M03).
train_model_for_speaker('M03')


M03


Map:   0%|          | 0/226 [00:00<?, ? examples/s]

Map:   0%|          | 0/26 [00:00<?, ? examples/s]

torgo_spell_correction_keep_all_M03
Train Dataset: Dataset({
    features: ['references_phoneme', 'predictions_phoneme', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 226
})
Validation Dataset: Dataset({
    features: ['references_phoneme', 'predictions_phoneme', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 26
})
Tokenizer: BartTokenizerFast(name_or_path='facebook/bart-base', vocab_size=50265, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)}, clean_up_tokenization_spaces=True)
Training Arguments: Seq2SeqTrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss
1,No log,5.757432
2,No log,5.033386
3,No log,4.170192
4,No log,3.172798
5,No log,2.690087
6,No log,2.417926
7,5.411100,2.226091
8,5.411100,2.065876
9,5.411100,1.932645
10,5.411100,1.782761


In [25]:
# Fine-tune the correction model for speaker M04 (output goes to torgo_spell_correction_{MODEL_PATTERN}_M04).
train_model_for_speaker('M04')


M04


Map:   0%|          | 0/340 [00:00<?, ? examples/s]

Map:   0%|          | 0/38 [00:00<?, ? examples/s]

torgo_spell_correction_keep_all_M04
Train Dataset: Dataset({
    features: ['references_phoneme', 'predictions_phoneme', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 340
})
Validation Dataset: Dataset({
    features: ['references_phoneme', 'predictions_phoneme', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 38
})
Tokenizer: BartTokenizerFast(name_or_path='facebook/bart-base', vocab_size=50265, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)}, clean_up_tokenization_spaces=True)
Training Arguments: Seq2SeqTrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss
1,No log,5.318884
2,No log,4.111108
3,No log,2.848101
4,No log,2.363951
5,5.311800,2.105358
6,5.311800,1.899361
7,5.311800,1.703793
8,5.311800,1.531732
9,5.311800,1.410571
10,2.382300,1.270169


In [26]:
# Fine-tune the correction model for speaker M05 (output goes to torgo_spell_correction_{MODEL_PATTERN}_M05).
train_model_for_speaker('M05')

M05


Map:   0%|          | 0/348 [00:00<?, ? examples/s]

Map:   0%|          | 0/39 [00:00<?, ? examples/s]

torgo_spell_correction_keep_all_M05
Train Dataset: Dataset({
    features: ['references_phoneme', 'predictions_phoneme', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 348
})
Validation Dataset: Dataset({
    features: ['references_phoneme', 'predictions_phoneme', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 39
})
Tokenizer: BartTokenizerFast(name_or_path='facebook/bart-base', vocab_size=50265, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)}, clean_up_tokenization_spaces=True)
Training Arguments: Seq2SeqTrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss
1,No log,4.698354
2,No log,3.683243
3,No log,2.659508
4,No log,2.258935
5,5.205200,2.041618
6,5.205200,1.82717
7,5.205200,1.672076
8,5.205200,1.555382
9,5.205200,1.410544
10,2.315600,1.324415


In [27]:
import os

from transformers import BartForConditionalGeneration

# NOTE(review): SPEAKERS is already used by the dataset-loading cell earlier in
# this notebook, so it is presumably defined in 00_common.ipynb as well —
# confirm the two lists agree and keep a single definition.
SPEAKERS = ['F01', 'F03', 'F04', 'M01', 'M02', 'M03', 'M04', 'M05']

# Checkpoint sub-directory loaded for every speaker. The same step number is
# used for all speakers even though their training sets differ in size, so it
# corresponds to a different epoch for each of them.
CHECKPOINT_NAME = 'checkpoint-500'

# speaker id -> fine-tuned model loaded from that speaker's checkpoint directory
models = {}

for speaker in SPEAKERS:
    model_name = f"torgo_spell_correction_{MODEL_PATTERN}_{speaker}/{CHECKPOINT_NAME}"
    # Fail with an explicit message when the local checkpoint is missing; the
    # path has the "namespace/name" shape of a Hub repo id, so a missing
    # directory would otherwise surface as a confusing Hub lookup error.
    if not os.path.isdir(model_name):
        raise FileNotFoundError(
            f"Checkpoint directory {model_name!r} not found for speaker {speaker}; "
            "train this speaker first or change CHECKPOINT_NAME."
        )
    models[speaker] = BartForConditionalGeneration.from_pretrained(model_name)

In [28]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move every loaded model to that device and switch it to evaluation mode.
for speaker in models:
    model = models[speaker]
    models[speaker] = model.to(device)
    models[speaker].eval()

In [29]:
def evaluate_speaker(model, tokenizer, speaker_id):
    """Generate corrections for one speaker's test split and report WER / CER.

    Each test example's ``predictions_phoneme`` text is encoded, passed through
    ``model.generate`` and decoded; the decoded strings are scored against the
    ``references_phoneme`` column with the notebook-level ``wer`` and ``cer``
    functions. Progress and scores are printed.

    Args:
        model: Model exposing ``eval()`` and ``generate()``. Inputs are moved
            to ``model.device`` when the model has that attribute, otherwise
            to the notebook-level ``device``.
        tokenizer: Tokenizer exposing ``encode`` and ``decode``.
        speaker_id: Key into the notebook-level ``loaded_datasets``.

    Returns:
        dict with keys ``'speaker_id'``, ``'wer'``, ``'cer'`` (the raw values
        returned by ``wer`` / ``cer``, not percentages) and ``'num_examples'``.
    """
    predictions = []
    references = []
    test_dataset = loaded_datasets[speaker_id]['test']
    test_dataset = test_dataset.map(preprocess_function, batched=True)
    model.eval()

    # Send inputs to wherever the model lives instead of depending on a
    # notebook global having been set by an earlier cell.
    target_device = getattr(model, 'device', None)
    if target_device is None:
        target_device = device

    for example in test_dataset:
        input_text = str(example['predictions_phoneme'])
        with torch.no_grad():
            input_ids = tokenizer.encode(input_text, return_tensors='pt').to(target_device)
            outputs = model.generate(input_ids=input_ids, max_length=max_length)

        predicted_sentence = tokenizer.decode(outputs[0], skip_special_tokens=True)

        references.append(example['references_phoneme'])
        predictions.append(predicted_sentence)

    # Verify that the number of predictions and references are the same
    if len(predictions) == len(references):
        print("Number of predictions and references are the same.")
    else:
        print("Mismatch in the number of predictions and references.")

    # Print the number of predictions and references
    print("Number of predictions:", len(predictions))
    print("Number of references:", len(references))
    # print the length of the dataset
    print("Number of rows in dataset:", len(test_dataset))

    # NOTE(review): `wer` and `cer` are defined outside this cell. If they are
    # jiwer's functions, their positional order is (reference, hypothesis) and
    # the arguments below are swapped, which changes the length the error count
    # is normalised by — confirm against the definition in 00_common.ipynb.
    # Calculate Word Error Rate (WER)
    wer_value = wer(predictions, references)
    wer_percentage = wer_value * 100
    print(f"WER for {speaker_id}: {wer_percentage:.2f}%")

    # Calculate Character Error Rate (CER)
    cer_value = cer(predictions, references)
    cer_percentage = cer_value * 100
    print(f"CER for {speaker_id}: {cer_percentage:.2f}%")

    # Hand the scores back so callers can aggregate them across speakers.
    return {
        'speaker_id': speaker_id,
        'wer': wer_value,
        'cer': cer_value,
        'num_examples': len(predictions),
    }

In [30]:
# Score every loaded model on its own speaker's test split.
for speaker_id in models:
    model = models[speaker_id]
    evaluate_speaker(model, tokenizer, speaker_id)

Map:   0%|          | 0/32 [00:00<?, ? examples/s]

Number of predictions and references are the same.
Number of predictions: 32
Number of references: 32
Number of rows in dataset: 32
WER for F01: 61.11%
CER for F01: 45.58%


Map:   0%|          | 0/32 [00:00<?, ? examples/s]

Number of predictions and references are the same.
Number of predictions: 32
Number of references: 32
Number of rows in dataset: 32
WER for F03: 60.55%
CER for F03: 48.70%


Map:   0%|          | 0/28 [00:00<?, ? examples/s]

Number of predictions and references are the same.
Number of predictions: 28
Number of references: 28
Number of rows in dataset: 28
WER for F04: 65.22%
CER for F04: 50.51%


Map:   0%|          | 0/37 [00:00<?, ? examples/s]

Number of predictions and references are the same.
Number of predictions: 37
Number of references: 37
Number of rows in dataset: 37
WER for M01: 73.68%
CER for M01: 56.90%


Map:   0%|          | 0/52 [00:00<?, ? examples/s]

Number of predictions and references are the same.
Number of predictions: 52
Number of references: 52
Number of rows in dataset: 52
WER for M02: 61.54%
CER for M02: 43.90%


Map:   0%|          | 0/28 [00:00<?, ? examples/s]

Number of predictions and references are the same.
Number of predictions: 28
Number of references: 28
Number of rows in dataset: 28
WER for M03: 65.56%
CER for M03: 48.69%


Map:   0%|          | 0/42 [00:00<?, ? examples/s]

Number of predictions and references are the same.
Number of predictions: 42
Number of references: 42
Number of rows in dataset: 42
WER for M04: 66.89%
CER for M04: 48.87%


Map:   0%|          | 0/43 [00:00<?, ? examples/s]

Number of predictions and references are the same.
Number of predictions: 43
Number of references: 43
Number of rows in dataset: 43
WER for M05: 66.67%
CER for M05: 54.17%
