In [None]:
import os
import pickle

import torch
import transformers
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import BartTokenizer, BartForConditionalGeneration, AdamW

In [None]:
# Download (or read from the local cache) the Hugging Face dataset of paired wrong/correct sentences.
dt = load_dataset(path='carolmou/random-sentences')

In [None]:
# Pick out the two splits used below: 'train' for fine-tuning, 'test' for validation.
annotated_data, test_annotated = dt['train'], dt['test']

In [None]:
# Number of examples in the train split, then in the test split (one per line).
print(
    len(annotated_data['wrong_text']),
    len(test_annotated['wrong_text']),
    sep='\n',
)

In [None]:
def to_pairs(split):
    """Return a list of ``(wrong_text, correct_text)`` tuples for one dataset split.

    ``split`` must be indexable by column name and expose the columns
    'wrong_text' and 'correct_text'. Row i of the result pairs the i-th
    wrong text with the i-th correct text (``zip`` stops at the shorter column).
    """
    return list(zip(split['wrong_text'], split['correct_text']))


train_data = to_pairs(annotated_data)
test_data = to_pairs(test_annotated)

In [None]:
class AutoCorrectionDataset(Dataset):
    """Map-style dataset of ``(source_text, target_text)`` pairs for seq2seq fine-tuning.

    Each item is tokenized on the fly. Source and target are both padded and
    truncated to ``max_length``.

    Args:
        data: Indexable sequence of ``(source_text, target_text)`` tuples.
        tokenizer: Hugging Face-style tokenizer. It is called once with the
            source text and once with ``text_target=``, and must return
            ``return_tensors='pt'`` tensors with a leading batch dimension of 1.
        max_length: Pad/truncate length used for both source and target.
        ignore_pad_token_for_loss: If True (default), padding positions in
            ``labels`` are replaced by -100. Hugging Face seq2seq models ignore
            that value when computing the loss, so padding no longer contributes
            to it. Set to False to keep the raw pad token id in ``labels``.
            Code that compares predictions directly against ``labels`` must
            skip the -100 positions.
    """

    def __init__(self, data, tokenizer, max_length, ignore_pad_token_for_loss=True):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.ignore_pad_token_for_loss = ignore_pad_token_for_loss

    def __len__(self):
        """Number of sentence pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize pair ``idx``; return 1-D tensors keyed like the tokenizer output plus 'labels'."""
        source_text, target_text = self.data[idx]
        model_inputs = self.tokenizer(
            source_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        labels = self.tokenizer(
            text_target=target_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )['input_ids']

        pad_id = getattr(self.tokenizer, 'pad_token_id', None)
        if self.ignore_pad_token_for_loss and pad_id is not None:
            # Padding the labels to max_length would otherwise make the loss
            # reward predicting <pad>; -100 is the ignore index of the model's loss.
            labels = labels.masked_fill(labels == pad_id, -100)
        model_inputs['labels'] = labels

        # Every tensor has shape (1, max_length); drop the batch dim so the
        # DataLoader's default collate can stack items into (batch, max_length).
        return {key: tensor[0] for key, tensor in model_inputs.items()}

In [None]:
# Load BART tokenizer and model.
# The tokenizer always comes from the base checkpoint. The model resumes from a
# previous fine-tuning run when that directory exists locally. Otherwise it
# starts from the base checkpoint, so Restart & Run All also works on a fresh
# machine instead of failing in from_pretrained.
# NOTE(review): the save cell further down writes to 'fine_tuned_bart_autocorrection'
# (no `_2` suffix) — confirm which directory is meant to be resumed from.
BASE_CHECKPOINT = 'facebook/bart-base'
RESUME_CHECKPOINT = 'fine_tuned_bart_autocorrection_2'

tokenizer = BartTokenizer.from_pretrained(BASE_CHECKPOINT)
init_checkpoint = RESUME_CHECKPOINT if os.path.isdir(RESUME_CHECKPOINT) else BASE_CHECKPOINT
print(f'Initialising model from: {init_checkpoint}')
model = BartForConditionalGeneration.from_pretrained(init_checkpoint)

In [None]:
# Create datasets and dataloaders.
SEED = 42
MAX_LENGTH = 128  # tokens; inputs and targets are padded/truncated to this length
BATCH_SIZE = 32

# Seed torch's global RNGs (all devices). The shuffling DataLoader draws its
# per-epoch order from them when no explicit generator is given, and dropout
# during training uses them too. Some GPU kernels may still be nondeterministic.
torch.manual_seed(SEED)

train_dataset = AutoCorrectionDataset(train_data, tokenizer, max_length=MAX_LENGTH)
test_dataset = AutoCorrectionDataset(test_data, tokenizer, max_length=MAX_LENGTH)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [None]:
# Set up GPU if available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# nn.Module.to moves the parameters in place and returns the same module.
# Binding the result, rather than leaving `model.to(device)` as the cell's last
# expression, keeps the long BART module repr out of the cell output.
model = model.to(device)
device

In [None]:
# Set up optimizer.
# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are set explicitly to the transformers.AdamW defaults (1e-6, 0.0)
# so optimisation matches the original setup; torch's own defaults are 1e-8
# and 1e-2.
LEARNING_RATE = 1e-5
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    eps=1e-6,
    weight_decay=0.0,
)

# Not used by the training loop, which takes the loss returned by the model
# (`outputs.loss`). Kept only so any cell that still references it keeps working.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Silence Hugging Face info/warning logs so the progress bars stay readable.
transformers.logging.set_verbosity_error()


def _batch_to_device(batch):
    """Return the model inputs of a collated ``batch`` moved onto ``device``.

    Only the three keys the forward pass consumes are kept.
    """
    return {key: batch[key].to(device) for key in ('input_ids', 'attention_mask', 'labels')}


# Fine-tune BART
num_epochs = 5
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0.0
    progress = tqdm(train_loader, leave=True, desc=f'Epoch {epoch+1}/{num_epochs}')
    for batch in progress:
        optimizer.zero_grad()
        # Passing `labels` makes the model compute and return the loss itself.
        loss = model(**_batch_to_device(batch)).loss
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        train_loss += step_loss
        progress.set_postfix(loss=f'{step_loss:.4f}')
    # max(..., 1) avoids ZeroDivisionError if a loader is empty.
    train_loss /= max(len(train_loader), 1)

    # ---- validation pass (no gradients) ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in test_loader:
            val_loss += model(**_batch_to_device(batch)).loss.item()
    val_loss /= max(len(test_loader), 1)

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

In [None]:
# Write the fine-tuned weights and the tokenizer into the same directory so the
# pair can be reloaded later with `from_pretrained`.
# NOTE(review): this path has no `_2` suffix, unlike the checkpoint the model was
# loaded from — confirm which directory later runs should resume from.
save_dir = 'fine_tuned_bart_autocorrection'
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

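# Accuracy

Sequence-level exact-match accuracy on the test split: a sample counts as correct only if every predicted token equals its label. The evaluation cells below are currently commented out.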

In [None]:
# model = BartForConditionalGeneration.from_pretrained('fine_tuned_bart_autocorrection')

In [None]:
# model.eval()
# model.to(device)

In [None]:
# total_samples = 0
# total_correct = 0

# with torch.no_grad():
#     loop = tqdm(test_loader, leave=True)

#     for batch in loop:
#         input_ids = batch['input_ids'].to(device)
#         attention_mask = batch['attention_mask'].to(device)
#         labels = batch['labels'].to(device)

#         outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
#         predicted_labels = outputs.logits.argmax(dim=-1)

#         # Compare the predicted values with the ground truth labels and count the matches
#         correct_mask = (predicted_labels == labels)
#         for ix,sample in enumerate(correct_mask):
#             if False not in sample:
#                 total_correct += 1 

#         batch_size = batch['input_ids'].shape[0]
#         total_samples += batch_size 

#         loop.set_postfix({'accuracy': total_correct/total_samples})