In [None]:
import torch
import pickle
import re
import json
import pandas as pd
from tqdm import tqdm
from transformers import BartForConditionalGeneration, BartTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset


In [None]:
# Fine-tuned weights (a state_dict), resolved relative to the notebook's working directory.
model_checkpoint = "../training/bart_best_checkpoint.pt"

In [None]:
# Architecture and vocabulary both come from the same pretrained BART base model;
# the fine-tuned weights are loaded into `model` in a later cell.
pretrained_name = "facebook/bart-base"
tokenizer = BartTokenizer.from_pretrained(pretrained_name)
model = BartForConditionalGeneration.from_pretrained(pretrained_name)

In [None]:
# Load the fine-tuned weights onto the CPU. `model` is still on the CPU when
# `load_state_dict` runs, and `model.to(device)` moves it to the chosen device
# afterwards. Hard-coding a GPU here ('cuda:1') crashes on machines without that
# GPU, and it needlessly stages the tensors in GPU memory only to copy them back.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load(model_checkpoint, map_location='cpu')

In [None]:
# Copy the fine-tuned weights into the model; strict=True (the default) raises if
# any parameter name is missing from, or unexpected in, the checkpoint.
model.load_state_dict(state_dict, strict=True)

In [None]:
# Prefer GPU index 1 (the original setup). Fall back to the last available GPU when
# the machine has fewer devices, and to the CPU when CUDA is unavailable. A plain
# `'cuda:1' if torch.cuda.is_available()` would select a non-existent device on
# single-GPU machines and fail on the first `.to(device)`.
PREFERRED_GPU_INDEX = 1
if torch.cuda.is_available():
    device = f'cuda:{min(PREFERRED_GPU_INDEX, torch.cuda.device_count() - 1)}'
else:
    device = 'cpu'

In [None]:
# Move every parameter and buffer to the evaluation device (nn.Module.to works in place).
model.to(torch.device(device))

In [None]:
# Only the held-out "test" split is evaluated in this notebook.
all_splits = load_dataset("carolmou/random-sentences")
data = all_splits["test"]

In [None]:
class AutoCorrectionDataset(Dataset):
    """Map-style dataset of (uncorrected sentence, corrected sentence) pairs.

    Items are tokenized on the fly. The first element of each pair becomes the
    encoder input, and the second element's token ids become ``labels``. Both are
    padded/truncated to exactly ``max_length`` tokens.

    Args:
        data: indexable collection of 2-item ``(source_text, target_text)`` pairs.
        tokenizer: callable tokenizer that accepts ``text_target=`` and
            ``return_tensors='pt'`` (e.g. a Hugging Face tokenizer).
        max_length: fixed token length for both the input and the labels.
    """

    def __init__(self, data, tokenizer, max_length):
        self._data, self.tokenizer, self.max_length = data, tokenizer, max_length

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        """Tokenize pair ``idx``.

        Returns:
            dict mapping each tokenizer output field (e.g. ``input_ids``,
            ``attention_mask``) plus ``labels`` to a 1-D tensor.
        """
        source_text, target_text = self._data[idx]
        encode_options = dict(
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        encoded = self.tokenizer(source_text, **encode_options)
        encoded['labels'] = self.tokenizer(text_target=target_text, **encode_options)['input_ids']
        # return_tensors='pt' adds a leading batch dimension of size 1; strip it so
        # the DataLoader can stack single items into its own batches.
        return {field: batched[0] for field, batched in encoded.items()}

In [None]:
# Pair every uncorrected sentence with its correction, then wrap the pairs so they
# are tokenized to a fixed length of 128 tokens.
wrong_text, correct_text = data['wrong_text'], data['correct_text']
test_data = [(wrong, right) for wrong, right in zip(wrong_text, correct_text)]

test_dataset = AutoCorrectionDataset(test_data, tokenizer, max_length=128)

In [None]:
# Sanity check: decode the first example's encoder input and its labels back to
# text, so the uncorrected and corrected sentences can be compared by eye.
first_example = test_dataset[0]
input_ids, labels = first_example["input_ids"], first_example["labels"]

for token_ids in (input_ids, labels):
    print(tokenizer.decode(token_ids, skip_special_tokens=True))

In [None]:
# Evaluation order does not matter, so keep the dataset order (shuffle=False is the default).
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

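# Accuracy measurement

Sentence-level exact-match accuracy on the test split, computed with teacher forcing. The gold correction is passed as `labels`, so the decoder sees the gold prefix at every step. The argmax token at each position is then compared with the gold token. Because of the gold prefix, this is an optimistic estimate of accuracy under free-running generation.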

In [None]:
# Teacher-forced, sentence-level exact-match accuracy over `test_loader`.
# `labels` is passed to the model, so BART builds its decoder inputs from the gold
# correction. The per-position argmax is then compared with the gold token at that
# position.


def count_exact_matches(predicted, gold, pad_token_id):
    """Count rows of `predicted` that equal `gold` at every non-padding position.

    Args:
        predicted: integer tensor of predicted token ids, same shape as `gold`
            (assumed (batch, seq_len) — TODO confirm for other callers).
        gold: integer tensor of reference token ids, padded with `pad_token_id`.
        pad_token_id: id used to pad `gold`; those positions are ignored, since
            anything the model predicts after the end of the target sentence is
            not part of the correction.

    Returns:
        Number of rows (sentences) that match exactly, as a Python int.
    """
    position_ok = (predicted == gold) | (gold == pad_token_id)
    return int(position_ok.all(dim=-1).sum().item())


model.eval()  # make sure dropout is disabled while evaluating

total_samples = 0
total_correct = 0

with torch.no_grad():
    loop = tqdm(test_loader, leave=True)

    for batch in loop:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        predicted_labels = outputs.logits.argmax(dim=-1)

        total_correct += count_exact_matches(predicted_labels, labels, tokenizer.pad_token_id)
        total_samples += input_ids.shape[0]

        loop.set_postfix({'accuracy': total_correct / total_samples})

# Final (unrounded) accuracy as the cell's output; NaN if the loader was empty.
accuracy = total_correct / total_samples if total_samples else float('nan')
accuracy

In [None]:
# Qualitative check on one hand-written sentence. "Ele é engraçada" pairs a masculine
# pronoun with a feminine adjective — presumably a deliberate agreement error for the
# model to fix. Tokenized with the same settings as the test dataset.
inputs = tokenizer(
    "Ele é engraçada",
    max_length=128,
    padding='max_length',
    truncation=True,
    return_tensors='pt',
)

In [None]:
# NOTE(review): `id` shadows the Python builtin of the same name; other cells read it
# under this name, so renaming it must be done in all of them together.
id, attention = inputs["input_ids"], inputs["attention_mask"]

In [None]:
# Put the example's tensors on the same device the model was moved to.
id, attention = id.to(device), attention.to(device)

In [None]:
# Autocorrect the example with autoregressive decoding.
# A single forward pass (`model(input_ids=..., attention_mask=...)` + argmax) does
# not do this: without labels or decoder_input_ids, BART builds the decoder input by
# shifting the *encoder* input right. The argmax is then a one-step prediction over
# the uncorrected sentence, not a generated correction.
# no_grad avoids building an autograd graph for pure inference.
# Decoding settings other than max_length (beam count, n-gram blocking, ...) come
# from the model's generation config — TODO confirm they suit this task.
with torch.no_grad():
    predicted_labels = model.generate(id, attention_mask=attention, max_length=128)


In [None]:
# Display the raw predicted token ids (a tensor; presumably shape (1, seq_len) — one row per input sentence).
predicted_labels

In [None]:
# Turn the first (and only) predicted sequence back into text, dropping special
# tokens such as <s>, </s> and padding.
tokenizer.decode(token_ids=predicted_labels[0], skip_special_tokens=True)