In [None]:
! pip install transformers datasets evaluate jiwer accelerate

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import pandas as pd
import torch

from transformers import pipeline
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments, AutoModelForSeq2SeqLM, AutoTokenizer
from datasets import load_dataset, DatasetDict
import evaluate

# Run on the first CUDA device when one is present, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Per-device batch size used for training and evaluation.
batch_size = 16
# Maximum generation length (in tokens) used by the pipeline and the trainer.
max_length = 2048

ModuleNotFoundError: ignored

## Create Dataset
(We train on a random subsample of 10,000 examples, because finetuning on the full dataset is slow.)

In [None]:
SEED = 42          # shared seed for shuffling and splitting, so splits are reproducible
N_SAMPLES = 10000  # size of the random subset used for finetuning

raw_dataset = load_dataset("data", data_files="spell_correction_task.tsv")

# Shuffle before taking the subset so it is not biased toward the head of the file.
# Cap at the available rows so a smaller file does not make `select` fail.
n_keep = min(N_SAMPLES, len(raw_dataset['train']))
subset = raw_dataset['train'].shuffle(seed=SEED).select(range(n_keep))

# 80/10/10 split: hold out 20%, then halve the hold-out into test and validation.
# Both splits are seeded so the test set (and thus pre/post-finetuning WER) is
# identical across re-runs.
train_testval = subset.train_test_split(test_size=0.2, seed=SEED)
test_valid = train_testval['test'].train_test_split(test_size=0.5, seed=SEED)

dataset = DatasetDict({
    'train': train_testval['train'],
    'test': test_valid['test'],
    'val': test_valid['train'],
})

dataset.set_format(type='torch')
dataset

Downloading and preparing dataset csv/. to /root/.cache/huggingface/datasets/csv/.-741e6a5416becf81/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/.-741e6a5416becf81/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'corrected'],
        num_rows: 8000
    })
    test: Dataset({
        features: ['text', 'corrected'],
        num_rows: 1000
    })
    val: Dataset({
        features: ['text', 'corrected'],
        num_rows: 1000
    })
})

## Check model performance before finetuning

In [None]:
# Word error rate metric, used for the baseline score and inside compute_metrics.
metric = evaluate.load(path="wer")

Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

In [None]:
# Checkpoint to finetune. A previously tried alternative was
# "oliverguhr/spelling-correction-english-base".
model_name = 'google/flan-t5-small'

# The pipeline loads both the model and its tokenizer; both are reused for training below.
fix_spelling = pipeline("text2text-generation", model=model_name, device=device)

model = fix_spelling.model
tokenizer = fix_spelling.tokenizer

NameError: ignored

In [None]:
# Baseline: run the untuned pipeline on the test split and score it with WER.
# Inputs are passed as raw text with no task prompt. Batched inference avoids
# one generate call per example.
test_texts = dataset['test']['text']
outputs = fix_spelling(test_texts, max_length=max_length, batch_size=batch_size)
predictions = [item['generated_text'] for item in outputs]
references = dataset['test']['corrected']

wer = metric.compute(predictions=predictions, references=references)

# NOTE(review): an earlier run, possibly with a different checkpoint, reported ~0.106.
# The output recorded below is from flan-t5-small.
print(f'Pre-Finetuning WER: {wer}')

Pre-Finetuning WER: 1.2270769877113747


## Finetuning

In [None]:
# Pads inputs and labels per batch. The model is passed so the collator can,
# if supported, prepare decoder inputs from the labels.
collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, return_tensors="pt")

### Tokenize dataset

In [None]:

def preprocess_function(examples):
    """Tokenize a batch of (noisy text, corrected text) pairs for seq2seq training.

    Args:
        examples: batched columns from ``Dataset.map(..., batched=True)`` with keys
            ``'text'`` (model input) and ``'corrected'`` (target).

    Returns:
        The tokenizer's encoding of ``examples['text']`` with an added ``"labels"``
        entry holding the token ids of ``examples['corrected']``. Both sides use
        ``truncation=True`` without an explicit ``max_length``.
    """
    encoded = tokenizer(examples['text'], truncation=True)
    target_ids = tokenizer(text_target=examples['corrected'], truncation=True)["input_ids"]
    encoded["labels"] = target_ids
    return encoded


# Tokenize every split in batches and drop the raw string columns,
# leaving only input_ids, attention_mask and labels.
text_columns = dataset["train"].column_names
tokenized_datasets = dataset.map(preprocess_function, batched=True, remove_columns=text_columns)
tokenized_datasets.set_format(type="torch")
tokenized_datasets

Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 8000
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1000
    })
    val: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1000
    })
})

### Define WER as the metric of interest

In [None]:
def postprocess_text(preds, labels):
    """Strip surrounding whitespace from decoded predictions and labels.

    Each label is additionally wrapped in a one-element list.

    Returns:
        tuple: ``(stripped_preds, wrapped_labels)``.
    """
    cleaned_preds = [text.strip() for text in preds]
    wrapped_labels = []
    for text in labels:
        wrapped_labels.append([text.strip()])
    return cleaned_preds, wrapped_labels

def compute_metrics(eval_preds):
    """Decode generated and reference token ids and score them with WER.

    Args:
        eval_preds: ``(predictions, label_ids)`` as passed by the Trainer. With
            ``predict_with_generate=True`` the predictions are generated token ids.
            If they arrive as a tuple, its first element is used.

    Returns:
        dict: ``{"wer": <value returned by the module-level metric>}``.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    preds = torch.as_tensor(preds)
    labels = torch.as_tensor(labels)

    # -100 is used as padding in labels, and the Trainer may also use it to pad
    # predictions. It is not a valid token id, so map it to the pad token before
    # decoding in both tensors.
    pad_id = tokenizer.pad_token_id
    preds = torch.where(preds != -100, preds, torch.full_like(preds, pad_id))
    labels = torch.where(labels != -100, labels, torch.full_like(labels, pad_id))

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Strip whitespace; postprocess_text wraps each label in a list, so unwrap it.
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
    references = [wrapped[0] for wrapped in decoded_labels]

    return {"wer": metric.compute(predictions=decoded_preds, references=references)}


### Finetune

In [None]:

# Generation length used when the trainer generates during evaluation/prediction.
model.config.max_length = max_length

args = Seq2SeqTrainingArguments(
    'speller',                      # output directory for checkpoints
    evaluation_strategy="epoch",    # newer transformers releases name this `eval_strategy`
    # Log once per epoch so the reported "Training Loss" is the average over the
    # epoch, comparable to the per-epoch validation loss. With logging_steps=1 it
    # would be the loss of only the most recent step.
    logging_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    auto_find_batch_size=True,      # lower the batch size automatically if it does not fit in memory
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=10,
    predict_with_generate=True,     # so compute_metrics receives generated token ids
    push_to_hub=False,
    seed=42,                        # the Trainer's default, made explicit
)


trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["val"],
    data_collator=collator,
    tokenizer=tokenizer,            # newer releases call this `processing_class`
    compute_metrics=compute_metrics,
)

trainer.train()

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Wer
1,0.6814,0.523798,0.15428
2,0.6314,0.487154,0.143429
3,0.6116,0.468253,0.137424
4,0.5067,0.462143,0.133684
5,0.4909,0.453325,0.130945
6,0.478,0.448163,0.129576
7,0.4653,0.443506,0.127627




Epoch,Training Loss,Validation Loss,Wer
1,0.6814,0.523798,0.15428
2,0.6314,0.487154,0.143429
3,0.6116,0.468253,0.137424
4,0.5067,0.462143,0.133684
5,0.4909,0.453325,0.130945
6,0.478,0.448163,0.129576
7,0.4653,0.443506,0.127627
8,0.6193,0.441971,0.127364
9,0.4695,0.440123,0.126732
10,0.4757,0.439634,0.126995


TrainOutput(global_step=5000, training_loss=0.5954516100734473, metrics={'train_runtime': 1046.5003, 'train_samples_per_second': 76.445, 'train_steps_per_second': 4.778, 'total_flos': 2299140234805248.0, 'train_loss': 0.5954516100734473, 'epoch': 10.0})

In [None]:
# Score the finetuned model on the held-out test split.
test_results = trainer.predict(test_dataset=tokenized_datasets['test'], max_length=max_length)
post_finetuning_wer = test_results.metrics["test_wer"]
print(f'Post-Finetuning WER: {post_finetuning_wer}')

Post-Finetuning WER: 0.12966755947691824
