In [1]:
from transformers import BartTokenizer, BartForConditionalGeneration
from sklearn.model_selection import train_test_split
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch
from transformers import AdamW
from tqdm import tqdm

# Load the dataset
dataset_path = 'Dataset/sinhala_dataset.csv'
data = pd.read_csv(dataset_path)

# Prepare the data
input_texts = "grammar_error: " + data['grammar_error_sentence']
target_texts = data['corrected_sentence']

# Split the data into train and test sets
train_inputs, test_inputs, train_targets, test_targets = train_test_split(
    input_texts, target_texts, test_size=0.2, random_state=42
)

# Load BART tokenizer
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

# Define a Dataset class
class GrammarCorrectionDataset(Dataset):
    def __init__(self, inputs, targets, tokenizer, max_len=128):
        self.inputs = inputs
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_text = self.inputs.iloc[idx]
        target_text = self.targets.iloc[idx]

        input_encoding = self.tokenizer(
            input_text, truncation=True, padding="max_length", max_length=self.max_len, return_tensors="pt"
        )
        target_encoding = self.tokenizer(
            target_text, truncation=True, padding="max_length", max_length=self.max_len, return_tensors="pt"
        )

        return {
            "input_ids": input_encoding["input_ids"].squeeze(0),
            "attention_mask": input_encoding["attention_mask"].squeeze(0),
            "labels": target_encoding["input_ids"].squeeze(0),
        }

# Create datasets and dataloaders
train_dataset = GrammarCorrectionDataset(train_inputs, train_targets, tokenizer)
test_dataset = GrammarCorrectionDataset(test_inputs, test_targets, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8)

# Load BART model
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# Training loop
epochs = 3
for epoch in range(epochs):
    model.train()
    loop = tqdm(train_loader, leave=True)
    for batch in loop:
        optimizer.zero_grad()

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        loop.set_description(f"Epoch {epoch}")
        loop.set_postfix(loss=loss.item())

# Save the model
model.save_pretrained("Models/Advanced_Bart/bart_sinhala_grammar_checker")
tokenizer.save_pretrained("TokenizerAdvanced_Bart/bart_sinhala_grammar_checker")



  from .autonotebook import tqdm as notebook_tqdm
Epoch 0: 100%|██████████| 1443/1443 [1:53:36<00:00,  4.72s/it, loss=0.0226] 
Epoch 1: 100%|██████████| 1443/1443 [1:56:06<00:00,  4.83s/it, loss=0.0041] 
Epoch 2: 100%|██████████| 1443/1443 [2:07:19<00:00,  5.29s/it, loss=0.00603]  


('TokenizerAdvanced_Bart/bart_sinhala_grammar_checker\\tokenizer_config.json',
 'TokenizerAdvanced_Bart/bart_sinhala_grammar_checker\\special_tokens_map.json',
 'TokenizerAdvanced_Bart/bart_sinhala_grammar_checker\\vocab.json',
 'TokenizerAdvanced_Bart/bart_sinhala_grammar_checker\\merges.txt',
 'TokenizerAdvanced_Bart/bart_sinhala_grammar_checker\\added_tokens.json')

In [8]:
# Define a function for inference
def correct_sentence(input_sentence):
    input_text = "grammar_error: " + input_sentence
    input_encoding = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    input_encoding = input_encoding.to(device)
    outputs = model.generate(input_encoding["input_ids"], max_length=128, num_beams=4, early_stopping=True)
    corrected_sentence = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return corrected_sentence

# Test the model
test_sentence = "මම ගෙදර යව"
print("Corrected Sentence:", correct_sentence(test_sentence))

Corrected Sentence: වාහන ගෙදර ළයි
