In [1]:
import pandas as pd
test_df = pd.read_csv('test_data.csv')
print(f"Val length: {len(val_df)}")

from transformers import AutoModelForSeq2SeqLM, T5Tokenizer, T5ForConditionalGeneration
gec_tokenizer = T5Tokenizer.from_pretrained("./gec_model_final")
gec_model = AutoModelForSeq2SeqLM.from_pretrained("./gec_model_final")

base_tokenizer = T5Tokenizer.from_pretrained("t5-base")
base_model = T5ForConditionalGeneration.from_pretrained("t5-base")

Val length: 5000


For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [3]:
from tqdm import tqdm
from datasets import load_metric
rouge_metric = load_metric("rouge")

BATCH_SIZE = 10
corrections = []
base_corrections = []

for i in tqdm(range(0, len(test_df), BATCH_SIZE)):
    sentence_batch = list(test_df["in"])[i:i + BATCH_SIZE]
    
    batch_gec_tokens = gec_tokenizer(sentence_batch, padding='max_length', max_length=1024, return_tensors="pt")
    translated = gec_model.generate(**batch_gec_tokens,num_beams=5, num_return_sequences=1)
    corrs = gec_tokenizer.batch_decode(translated, padding="longest", skip_special_tokens=True)
    corrections.extend(corrs)
    
    base_sentence_batch = [f"correct grammar: {sentence}" for sentence in sentence_batch]
    
    base_batch_tokens = base_tokenizer(base_sentence_batch, padding='max_length', max_length=1024, return_tensors="pt")
    base_translated = base_model.generate(**base_batch_tokens, num_beams=5, num_return_sequences=1)
    base_corrs = base_tokenizer.batch_decode(base_translated, padding="longest", skip_special_tokens=True)
    base_corrections.extend(base_corrs)

# Testing to see whether data was generated properly
print(base_corrections[:4])
    
rouge_data = rouge_metric.compute(predictions=corrections, references=list(test_df["out"]), use_stemmer=True)
for key, val in rouge_data.items():
    print(key)
    print(val)

100%|███████████████████████████████████████| 500/500 [6:29:45<00:00, 46.77s/it]


['grammar: Parking, dining stopping and comfort rooms in the city Pasig Bayan.', 'For safty shide-change device id..', ': correct grammar: correct grammar: correct grammar: correct grammar: correct grammar: correct grammar', ':: Use correct spelling: Use correct grammar: Use correct grammar: Use correct grammar']
rouge1
AggregateScore(low=Score(precision=0.882235113606088, recall=0.6240742340506934, fmeasure=0.6990202902310695), mid=Score(precision=0.8857489672230149, recall=0.6311192927672602, fmeasure=0.7045775615411423), high=Score(precision=0.889071889710229, recall=0.6377160716600342, fmeasure=0.7098546001281772))
rouge2
AggregateScore(low=Score(precision=0.765615888656686, recall=0.5289343953742742, fmeasure=0.5953806220937231), mid=Score(precision=0.7716980404562785, recall=0.5360336563302136, fmeasure=0.6018972623265957), high=Score(precision=0.7776944923152849, recall=0.5431313650388109, fmeasure=0.6086543208887297))
rougeL
AggregateScore(low=Score(precision=0.8722794229518613

In [4]:
import json

to_serialize = {
    "gec_corrections": corrections,
    "base_corrections": base_corrections
}

# Serialize results to avoid having to re-generate later
json.dump(to_serialize, open("corrections.json", "w"))

In [5]:
serialized_predictions = json.load(open("corrections.json", "r"))
corrections = serialized_predictions["gec_corrections"]
base_corrections = serialized_predictions["base_corrections"]

In [6]:
rouge_gec_data = rouge_metric.compute(predictions=corrections, references=list(test_df["out"]), use_stemmer=True)
rouge_base_data = rouge_metric.compute(predictions=base_corrections, references=list(test_df["out"]), use_stemmer=True)

print("BASE DATA")
for key, val in rouge_base_data.items():
    print(f"{key} : {val}")

print("\n\nGEC DATA")
for key, val in rouge_gec_data.items():
    print(f"{key} : {val}")

BASE DATA
rouge1 : AggregateScore(low=Score(precision=0.3759902708640686, recall=0.2942409423274288, fmeasure=0.3147301813645697), mid=Score(precision=0.3873816837544118, recall=0.3037062288532276, fmeasure=0.3244145021798222), high=Score(precision=0.3980939048721737, recall=0.31279246728034243, fmeasure=0.3335493236608105))
rouge2 : AggregateScore(low=Score(precision=0.3092195474815217, recall=0.23782304253551326, fmeasure=0.25463236940898076), mid=Score(precision=0.3199181645805176, recall=0.24659805018214329, fmeasure=0.26366938854437766), high=Score(precision=0.32963022051029, recall=0.25571714015385494, fmeasure=0.2724534498323195))
rougeL : AggregateScore(low=Score(precision=0.36915925059673266, recall=0.28883047647642957, fmeasure=0.3086469656001121), mid=Score(precision=0.3807016655463552, recall=0.29927071826018903, fmeasure=0.31934303548720544), high=Score(precision=0.3934571908567431, recall=0.3096064068137392, fmeasure=0.3297915063773098))
rougeLsum : AggregateScore(low=Sco

In [7]:
bleu_metric = load_metric("sacrebleu")
bleu_gec_data = bleu_metric.compute(predictions=[corrections], references=[list(test_df["out"])])
bleu_base_data = bleu_metric.compute(predictions=[base_corrections], references=[list(test_df["out"])])

print(f"base Score: {bleu_base_data['score']}")
print(f"GEC Score: {bleu_gec_data['score']}")


base Score: 22.162359254261354
GEC Score: 33.24152826260671


In [9]:
perplexity_metric = load_metric("perplexity")
perplexity_gec_score = perplexity_metric.compute(input_texts=corrections, model_id='gpt2')
perplexity_base_score = perplexity_metric.compute(input_texts=list(filter(lambda x: x != "", base_corrections)), model_id='gpt2')
print(f"Base: {perplexity_base_score['mean_perplexity']}")
print(f"GEC: {perplexity_gec_score['mean_perplexity']}")

Using pad_token, but it is not set yet.


  0%|          | 0/313 [00:00<?, ?it/s]

Using pad_token, but it is not set yet.


  0%|          | 0/304 [00:00<?, ?it/s]

Base: 559.820864630875
GEC: 57.85254389276505
