In [1]:
!pip install pandas
!pip install tqdm



In [2]:
from tqdm import tqdm
import csv

with open("gec_data_truncated.csv", "w") as f:
    write = csv.writer(f, escapechar="\\")
    write.writerow(["in", "out"])
    for i in tqdm(range(10)):
        with open(f"./gec_data/C4_200M.tsv-0000{i}-of-00010") as d:
            inp_out_pairs = d.readlines()
            d_data = [list(map(str.strip, inp_out_pairs[j].split("\t"))) for j in tqdm(range(5_000))]
            write.writerows(d_data)

  0%|                                                                                                            | 0/10 [00:00<?, ?it/s]
  0%|                                                                                                          | 0/5000 [00:00<?, ?it/s][A
  1%|█▎                                                                                               | 70/5000 [00:01<01:23, 59.23it/s][A
 15%|██████████████▍                                                                                | 757/5000 [00:01<00:08, 491.99it/s][A
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:02<00:00, 2258.84it/s][A
 10%|██████████                                                                                          | 1/10 [00:12<01:50, 12.29s/it]
  0%|                                                                                                          | 0/5000 [00:00<?, ?it/s][A
 12%|███████████▏         

In [1]:
import pandas as pd
gec_df = pd.read_csv('gec_data_truncated.csv')

In [2]:
gec_df.head()

Unnamed: 0,in,out
0,"Bitcoin is for $7,094 this morning, which Coin...","Bitcoin goes for $7,094 this morning, accordin..."
1,The effect of widespread dud targets two face ...,"1. The effect of ""widespread dud"" targets two ..."
2,tax on sales of stores for non residents are s...,Capital Gains tax on the sale of properties fo...
3,Much many brands and sellers still in the market.,Many brands and sellers still in the market.
4,this is is the latest Maintenance release of S...,This is is the latest maintenance release of S...


In [3]:
!pip install sentencepiece
from transformers import T5ForConditionalGeneration, T5Tokenizer


model_name = 't5-base'
tokenizer = T5Tokenizer.from_pretrained(model_name)
token_model = T5ForConditionalGeneration.from_pretrained(model_name)



For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [4]:
from sklearn.model_selection import train_test_split
train_df, test_df = train_test_split(gec_df, test_size=0.20, shuffle=True)
validation_df, test_df = train_test_split(test_df, test_size=0.50, shuffle=True)
validation_df.to_csv("validation_data.csv", index=False)
test_df.to_csv("test_data.csv", index=False)

In [5]:
validation_df['token_len'] = validation_df['in'].apply(lambda inp: len(tokenizer(inp).input_ids))
validation_df.describe()

Unnamed: 0,token_len
count,5000.0
mean,33.5058
std,26.028451
min,6.0
25%,17.0
50%,27.0
75%,42.0
max,484.0


In [6]:
from datasets import Dataset
train_dataset = Dataset.from_pandas(train_df)
validation_dataset = Dataset.from_pandas(validation_df)
test_dataset = Dataset.from_pandas(test_df)

In [7]:
from torch.utils.data import Dataset, DataLoader

class GrammarDataset(Dataset):
    def __init__(self, dataset, tokenizer,print_text=False):         
        self.dataset = dataset
        self.pad_to_max_length = False
        self.tokenizer = tokenizer
        self.print_text = print_text
        self.max_len = 128
    
    def __len__(self):
        return len(self.dataset)
    
    def tokenize_data(self, in_out_pair):
        input_, target_ = in_out_pair['in'], in_out_pair['out']

        tokenized_inputs = self.tokenizer(input_, pad_to_max_length=self.pad_to_max_length, 
                                            max_length=self.max_len,
                                            return_attention_mask=True,
                                            truncation=True)
    
        tokenized_targets = self.tokenizer(target_, pad_to_max_length=self.pad_to_max_length, 
                                            max_length=self.max_len,
                                            return_attention_mask=True,
                                            truncation=True)

        inputs={"input_ids": tokenized_inputs['input_ids'],
            "attention_mask": tokenized_inputs['attention_mask'],
            "labels": tokenized_targets['input_ids']
        }
        
        return inputs

    def __getitem__(self, index):
        inputs = self.tokenize_data(self.dataset[index])
        
        if self.print_text:
            for k in inputs.keys():
                print(k, len(inputs[k]))

        return inputs


In [8]:
train_gec_data = GrammarDataset(train_dataset, tokenizer, True)
print(train_gec_data[4])

input_ids 43
attention_mask 43
labels 45
{'input_ids': [37, 682, 28, 16009, 9952, 1195, 38, 3, 9, 1253, 56, 817, 1937, 19, 24, 34, 22, 7, 66, 396, 514, 12, 24460, 1737, 135, 38, 6, 497, 6, 8929, 16023, 6, 38, 25, 653, 12, 1344, 3, 9, 12803, 8109, 5, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [332, 3, 88, 682, 28, 16009, 9952, 1195, 38, 3, 9, 1253, 12, 817, 1937, 19, 24, 34, 22, 7, 66, 396, 514, 12, 24460, 1737, 135, 38, 6, 497, 6, 8929, 16023, 6, 38, 25, 653, 12, 1344, 3, 9, 12803, 8109, 5, 1]}


In [9]:
from transformers import DataCollatorForSeq2Seq
data_collator = DataCollatorForSeq2Seq(tokenizer, model=token_model, padding='longest', return_tensors='pt')

In [10]:
from transformers import Seq2SeqTrainingArguments
batch_size = 32
args = Seq2SeqTrainingArguments(output_dir="./gec_out",
                        evaluation_strategy="epoch",
                        per_device_train_batch_size=batch_size,
                        per_device_eval_batch_size=batch_size,
                        learning_rate=1e-5,
                        num_train_epochs=1,
                        weight_decay=0.01,
                        predict_with_generate=True,
                        save_steps = 500)

In [11]:
!pip install rouge_score
from datasets import load_metric
rouge_metric = load_metric("rouge")

from nltk import sent_tokenize
import numpy as np

def eval_metrics(in_out_pairs):
    preds, labels = in_out_pairs
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    decoded_preds = ["\n".join(sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(sent_tokenize(label.strip())) for label in decoded_labels]
    
    rouge_data = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    result = {key: value.mid.fmeasure * 100 for key, value in rouge_data.items()}
    
    pred_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(pred_lens)
    return {k: round(v, 4) for k, v in result.items()}



  rouge_metric = load_metric("rouge")


In [12]:
from transformers import Seq2SeqTrainer
gec_model = Seq2SeqTrainer(model=token_model, 
                args=args, 
                train_dataset=GrammarDataset(train_dataset, tokenizer),
                eval_dataset=GrammarDataset(test_dataset, tokenizer),
                tokenizer=tokenizer,
                data_collator=data_collator,
                compute_metrics=eval_metrics)

gec_model.train()



Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,0.6904,0.593096,70.1315,59.7146,69.3441,69.3889,17.3692


TrainOutput(global_step=1266, training_loss=0.718926319767137, metrics={'train_runtime': 13667.3974, 'train_samples_per_second': 2.963, 'train_steps_per_second': 0.093, 'total_flos': 4761165812244480.0, 'train_loss': 0.718926319767137, 'epoch': 1.0})

In [13]:
gec_model.save_model("./gec_model_final")

In [8]:
from transformers import AutoModelForSeq2SeqLM, T5Tokenizer
gec_tokenizer = T5Tokenizer.from_pretrained("./gec_model_final")
gec_model = AutoModelForSeq2SeqLM.from_pretrained("./gec_model_final")

In [9]:
input_texts = [
    "She ain't gonna goes to the store.", 
    "They has been playing soccer all day.", 
    "I goed to the movies last night.",
    "He don't likes it.",
    "He don't visits us.",
    "We was watching TV when the power went out.",
    "The dog chases the cat up the tree.",
    "Me and him is best friends.",
    "Oh, I'm goin' to the city for visit my friend. How 'bout you?",
    "I meet my friend for weekend. Family is nice, yes?",
    "The affect of the tsunami is incomprehensible",
    "This cinnamom powder cut thru the grease.",
    "I uh, likes pizza and also soda."
]

batch = gec_tokenizer(input_texts, truncation=True, padding='max_length', max_length=64, return_tensors="pt")
translated = gec_model.generate(**batch, num_beams=5, num_return_sequences=1, early_stopping=True)
print(translated.shape)
corrs = gec_tokenizer.batch_decode(translated, skip_special_tokens=True)

for inp, out in zip(input_texts, corrs):
    print(f"Orig: {inp}\nCorr: {out}\n")



torch.Size([13, 20])
Orig: She ain't gonna goes to the store.
Corr: She ain't gonna go to the store.

Orig: They has been playing soccer all day.
Corr: They have been playing soccer all day.

Orig: I goed to the movies last night.
Corr: I went to the movies last night.

Orig: He don't likes it.
Corr: He doesn't like it.

Orig: He don't visits us.
Corr: He doesn't visit us.

Orig: We was watching TV when the power went out.
Corr: We were watching TV when the power went out.

Orig: The dog chases the cat up the tree.
Corr: The dog chases the cat up the tree.

Orig: Me and him is best friends.
Corr: Me and him are best friends.

Orig: Oh, I'm goin' to the city for visit my friend. How 'bout you?
Corr: Oh, I'm going to the city to visit my friend. How'bout you

Orig: I meet my friend for weekend. Family is nice, yes?
Corr: I meet my friend for the weekend. Family is nice, yes?

Orig: The affect of the tsunami is incomprehensible
Corr: The impact of the tsunami is incomprehensible.

Orig: T