In [None]:
import json
from transformers import BartTokenizer, BartForConditionalGeneration, Trainer, TrainingArguments, EarlyStoppingCallback
import torch
from sklearn.model_selection import train_test_split

# Load the dataset
with open('grammar_correction_data.json', 'r') as f:
    data = json.load(f)

# Initialize the tokenizer
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')

# Tokenize the data
def tokenize_data(data):
    inputs = tokenizer([item['incorrect'] for item in data], padding=True, truncation=True, return_tensors='pt')
    labels = tokenizer([item['corrected'] for item in data], padding=True, truncation=True, return_tensors='pt')
    return inputs, labels

inputs, labels = tokenize_data(data)

In [3]:

# Create a dataset class
class GrammarCorrectionDataset(torch.utils.data.Dataset):
    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels

    def __len__(self):
        return len(self.inputs['input_ids'])

    def __getitem__(self, idx):
        return {
            'input_ids': self.inputs['input_ids'][idx],
            'attention_mask': self.inputs['attention_mask'][idx],
            'labels': self.labels['input_ids'][idx]
        }

# Split data into train and validation sets
train_inputs, val_inputs, train_labels, val_labels = train_test_split(inputs, labels, test_size=0.1)

# Create datasets
train_dataset = GrammarCorrectionDataset(train_inputs, train_labels)
val_dataset = GrammarCorrectionDataset(val_inputs, val_labels)

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=10,
    logging_dir='./logs',
    evaluation_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    weight_decay=0.01,
    logging_steps=10,
    save_total_limit=3,
    gradient_clipping=1.0
)

# Initialize Trainer with EarlyStoppingCallback
trainer = Trainer(
    model=BartForConditionalGeneration.from_pretrained('facebook/bart-large'),
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]  # Early stopping patience
)

# Train the model
trainer.train()

# Save the trained model and tokenizer
model_path = "./grammar_correction_model1"
trainer.save_model(model_path)
tokenizer.save_pretrained(model_path)

Step,Training Loss


Step,Training Loss
500,0.6748


Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}
Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}
Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


('./grammar_correction_model/tokenizer_config.json',
 './grammar_correction_model/special_tokens_map.json',
 './grammar_correction_model/vocab.json',
 './grammar_correction_model/merges.txt',
 './grammar_correction_model/added_tokens.json')

In [16]:
# Define a function to generate corrected sentences
def generate_correction(model, tokenizer, sentence):
    inputs = tokenizer(sentence, return_tensors='pt')
    output = model.generate(**inputs)
    corrected_sentence = tokenizer.decode(output[0], skip_special_tokens=True)
    return corrected_sentence

# Test the model
test_sentence = "he were very much happyier "
print(f"Original: {test_sentence}")
print(f"Corrected: {generate_correction(model, tokenizer, test_sentence)}")


Original: he were very much happyier 
Corrected: he was very much happy.
