In [1]:
import torch
import pickle
import re
import json
import pandas as pd
from tqdm import tqdm
from transformers import BartForConditionalGeneration, BartTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_checkpoint = '../training/checkpoint.pt'

In [3]:
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model.load_state_dict(torch.load(model_checkpoint))

<All keys matched successfully>

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
model.to(device)

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50265, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=

In [6]:
data = load_dataset("carolmou/random-sentences")["test"]

Found cached dataset parquet (/home/carolmou/.cache/huggingface/datasets/carolmou___parquet/carolmou--random-sentences-b36071ffaba43c26/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)
100%|██████████| 2/2 [00:00<00:00, 187.00it/s]


In [7]:
class AutoCorrectionDataset(Dataset):
    def __init__(self, data, tokenizer, max_length):
        self._data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        inputs, target = self._data[idx]
        model_inputs = self.tokenizer(inputs, max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt')
        labels = self.tokenizer(text_target=target, max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt')
        model_inputs['labels'] = labels['input_ids']

        return {type: data[0] for type, data in model_inputs.items()}

In [8]:
wrong_text = data['wrong_text']
correct_text = data['correct_text']
test_data = list(zip(wrong_text, correct_text))

test_dataset = AutoCorrectionDataset(test_data, tokenizer, max_length=128)

In [9]:
# sanity check
dic = test_dataset[0]
input_ids = dic["input_ids"]
labels = dic["labels"]

print(tokenizer.decode(input_ids, skip_special_tokens=True))
print(tokenizer.decode(labels, skip_special_tokens=True))

dão tivera intenção de dizer isso
Não tivera intenção de dizer isso


In [10]:
test_loader = DataLoader(test_dataset, batch_size=32)

# Vocabulary of mistakes

Save all the bad words that the spellchecker has to fix.

In [86]:
vocab_wrong = {}

In [87]:
for sent_wrong, sent_right in test_data:
    wrong = sent_wrong.split()
    right = sent_right.split()

    for w1, w2 in zip(wrong, right):
        if w1 != w2:
            vocab_wrong[(w1,w2)] = vocab_wrong.get(w1, 0) + 1

In [1]:
vocab_wrong[:10]

NameError: name 'vocab_wrong' is not defined

In [89]:
df = {'wrong': [], 'correction': []}

for w1, w2 in vocab_wrong:
    df['wrong'].append(w1)
    df['correction'].append(w2)

In [90]:
df = pd.DataFrame(df)

In [91]:
df

Unnamed: 0,wrong,correction
0,perdoo,Perdão
1,perdao,Perdão
2,declamar.,reclamar.
3,reclamam.,reclamar.
4,"Legal,","legal,"
...,...,...
451914,t.vo,t.v
451915,"Jae-it,","Jae-in,"
451916,"Jae-is,","Jae-in,"
451917,lb:電気,lo:電気


In [92]:
with open('../../data/test_wrong_vocabulary.pickle', 'wb') as file:
    pickle.dump(df, file)

# Accuracy measurement

In [17]:
total_samples = 0
total_correct = 0

with torch.no_grad():
    loop = tqdm(test_loader, leave=True)

    for batch in loop:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        predicted_labels = outputs.logits.argmax(dim=-1)

        # Compare the predicted values with the ground truth labels and count the matches
        correct_mask = (predicted_labels == labels)
        for ix,sample in enumerate(correct_mask):
            if False not in sample:
                total_correct += 1

        batch_size = batch['input_ids'].shape[0]
        total_samples += batch_size 

        loop.set_postfix({'accuracy': total_correct/total_samples})

100%|██████████| 1239/1239 [02:26<00:00,  8.47it/s, accuracy=0.311]


In [24]:
inputs = tokenizer("Ele é engraçada", max_length=128, padding='max_length', truncation=True, return_tensors='pt')

In [25]:
id = inputs["input_ids"]
attention = inputs["attention_mask"]

In [26]:
id = id.to(device)
attention = attention.to(device)

In [27]:
outputs = model(input_ids = id, attention_mask = attention)
predicted_labels = outputs.logits.argmax(dim=-1)


In [28]:
predicted_labels

tensor([[    0, 28888,  7935, 20407,   763,  3381,  2102,     2,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,  

In [29]:
tokenizer.decode(predicted_labels[0], skip_special_tokens=True)

'Ele é engraçado'