In [30]:
import pandas as pd
import os
import torch
from transformers import Trainer, BertTokenizer
import numpy as np

In [31]:
# Directory holding the checked datasets. Defaults to the original author's
# location; set GRAMMATIKTAK_DATASETS_DIR to run the notebook elsewhere.
DATASETS_DIR = os.environ.get(
    "GRAMMATIKTAK_DATASETS_DIR",
    "/Users/lucasvilsen/Desktop/GrammatikTAK/GrammatiktakDatasets/checkedDatasets/",
)
# Read via an explicit path instead of os.chdir so the kernel's working
# directory is not silently changed for every later cell.
data = pd.read_csv(os.path.join(DATASETS_DIR, "CommaDevelopmentset.csv"), sep=";")

In [60]:
# Parallel lists: input texts (X) and their target labels (Y) from the development set.
X, Y = (list(data[column].values) for column in ("comment_text", "label"))

class Dataset(torch.utils.data.Dataset):
    """Map-style torch dataset over tokenizer output, optionally with labels.

    Args:
        encodings: Mapping from feature name to a per-example sequence, such as
            the output of the tokenizer called on a list of texts. Must
            contain an ``"input_ids"`` entry (used by ``__len__``).
        labels: Optional per-example labels indexable by position (e.g. a
            list or numpy array). When ``None``, items carry no ``"labels"``
            key, which is what prediction-only use needs.
    """

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors, plus ``"labels"`` if labels were given."""
        item = {key: torch.tensor(values[idx]) for key, values in self.encodings.items()}
        # Compare against None explicitly: the truthiness of a numpy array is
        # ambiguous (raises ValueError), and an empty list would silently drop labels.
        if self.labels is not None:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Number of examples, taken from the length of ``encodings["input_ids"]``."""
        return len(self.encodings["input_ids"])

In [61]:
def load_model(
    models_dir="/Users/lucasvilsen/Desktop/GrammatikTAK/FineTuneModels/Models/",
    model_names=("commaModel2", "commaModel4"),
    scope=6,
):
    """Load the fine-tuned comma models and wrap each one in a ``Trainer``.

    Args:
        models_dir: Directory containing the serialized model files.
        model_names: File names (inside ``models_dir``) of the models to load.
        scope: Width of the word window; must match what the models and the
            dataset were built with.

    Returns:
        Tuple ``(trainers, scope, padding, model_names)``:
        ``trainers`` is a list with one ``Trainer`` per model, in the order of
        ``model_names``. ``padding`` is the number of ``<PAD>`` tokens added
        on each side of a word list. ``model_names`` is returned as a list.

    Note:
        Reads the notebook-level ``device`` variable, which must be defined
        before this function is called.
    """
    # Padding and scope should match the model and the dataset. Padding should
    # be such that every word except the last is checked; otherwise the logic
    # in "find_comma_mistakes" needs to be changed.
    # Integer division gives the same value as int(scope/2 - 1) for scope >= 2.
    padding = scope // 2 - 1

    punc_trainers = []
    for name in model_names:
        # Build the path instead of calling os.chdir, so loading models does not
        # change the kernel's working directory as a side effect.
        model_path = os.path.join(models_dir, name)
        # torch.load unpickles the whole model object: only load trusted files.
        # NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects
        # fully pickled models; pass weights_only=False on those versions.
        loaded_model = torch.load(model_path, map_location=torch.device('cpu'))
        loaded_model.eval()
        loaded_model.to(device)
        punc_trainers.append(Trainer(loaded_model))
    return punc_trainers, scope, padding, list(model_names)

# Prefer Apple's Metal (MPS) backend, but fall back to CPU so that
# load_model's `model.to(device)` still works where MPS is unavailable.
_mps_backend = getattr(torch.backends, "mps", None)
device = "mps" if _mps_backend is not None and _mps_backend.is_available() else "cpu"
torch.device(device)

device(type='mps')

In [64]:
class PunctuationCorrector():
    """Run and evaluate the fine-tuned comma models returned by ``load_model``.

    Attributes set in ``__init__``:
        models: ``Trainer`` objects, one per loaded model.
        scope, padding: Window settings returned by ``load_model``.
        model_names: Model names, parallel to ``models``.
        tokenizer: ``BertTokenizer`` for "Maltehb/danish-bert-botxo".
    """

    def __init__(self) -> None:
        self.models, self.scope, self.padding, self.model_names = load_model()
        self.tokenizer = BertTokenizer.from_pretrained("Maltehb/danish-bert-botxo")

    def add_padding(self, words):
        """Return ``words`` with ``self.padding`` ``"<PAD>"`` tokens added on each side."""
        return ["<PAD>"]*self.padding + words + ["<PAD>"]*self.padding

    def get_predictions(self, test_data, max_length: int = 15) -> list:
        """Predict a class index for every text in ``test_data`` with each model.

        Args:
            test_data: List of input texts.
            max_length: Token limit passed to the tokenizer; longer inputs are truncated.

        Returns:
            A list with one numpy array of predicted class indices per model,
            in the same order as ``self.models``.
        """
        # Tokenization does not depend on the model, so do it once rather than
        # once per model.
        tokenized_data = self.tokenizer(test_data, padding=True, truncation=True, max_length=max_length)
        final_dataset = Dataset(tokenized_data)
        final_predictions = []
        for model in self.models:
            raw_predictions, _, _ = model.predict(final_dataset)
            # Presumably (n_examples, n_classes) scores; argmax over axis 1 picks the class.
            final_predictions.append(np.argmax(raw_predictions, axis=1))
        return final_predictions

    def get_accuracy_for_each_model(self, X, y):
        """Print, for each model, the fraction of ``X`` it labels the same as ``y``.

        Args:
            X: List of input texts.
            y: Target labels, one per text in ``X``.

        Raises:
            ValueError: If ``X`` and ``y`` differ in length. Without this check, a
                length-1 ``y`` would broadcast silently and give a wrong accuracy.
        """
        if len(X) != len(y):
            raise ValueError(f"X has {len(X)} examples but y has {len(y)} labels")
        y_true = np.asarray(y)
        predictions_lsts = self.get_predictions(X)
        accuracy = [np.mean(predictions == y_true) for predictions in predictions_lsts]
        print([f"{name} accuracy: {acc}" for name, acc in zip(self.model_names, accuracy)])

# Build the corrector. This runs load_model() (loads every model file and moves
# it to `device`) and loads the Danish BERT tokenizer, possibly downloading it.
puncCorrector = PunctuationCorrector()

In [65]:
# Evaluate every loaded model on the development set; prints one accuracy per model.
puncCorrector.get_accuracy_for_each_model(X, Y)

  0%|          | 0/170 [00:00<?, ?it/s]

  0%|          | 0/170 [00:00<?, ?it/s]

['commaModel2 accuracy: 0.8294117647058824', 'commaModel4 accuracy: 0.825735294117647']
