In [2]:
import pandas as pd
import os
import torch
from transformers import Trainer, BertTokenizer
import numpy as np

In [3]:
# Move into the folder that holds the checked development sets and show its contents.
os.chdir("/Users/lucasvilsen/Desktop/GrammatikTAK/GrammatiktakDatasets/checkedDatasets/")
os.listdir(os.curdir)

['CommaDevelopmentset20.csv',
 'CommaDevelopmentset8.csv',
 'CommaDevelopmentset6.csv',
 'CommaDevelopmentset10.csv']

In [4]:
def load_data(filename, data_dir="/Users/lucasvilsen/Desktop/GrammatikTAK/GrammatiktakDatasets/checkedDatasets/"):
    """Read a semicolon-separated dataset and return its texts and labels.

    The file is resolved against ``data_dir`` rather than by changing the
    process working directory, so calling this function has no hidden effect
    on later relative-path operations.

    Args:
        filename: Name of the CSV file inside ``data_dir``. An absolute path is
            used as-is.
        data_dir: Directory that contains the file. Defaults to the checked
            datasets folder this notebook reads from.

    Returns:
        A tuple ``(X, Y)`` of two lists: the values of the ``comment_text``
        column and the values of the ``label`` column, in file order.

    Raises:
        FileNotFoundError: If the resolved path does not exist.
        KeyError: If one of the two expected columns is missing.
    """
    data = pd.read_csv(os.path.join(data_dir, filename), sep=";")

    X = list(data["comment_text"].values)
    Y = list(data["label"].values)

    return X, Y

class Dataset(torch.utils.data.Dataset):
    """Map-style torch dataset over tokenizer encodings with optional labels.

    ``encodings`` is a mapping from field name to a per-example sequence. It
    must contain an ``"input_ids"`` entry, whose length defines the dataset
    size. ``labels``, when given, is indexed in parallel with the encodings.
    """

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def _has_labels(self):
        """Return True when a non-empty label container was supplied."""
        # Explicit None/length check instead of `if self.labels:` because the
        # truth value of a NumPy array or tensor with several elements raises.
        return self.labels is not None and len(self.labels) > 0

    def __getitem__(self, idx):
        """Return a dict of tensors for example ``idx`` (plus ``labels`` if present)."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if self._has_labels():
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Number of examples, taken from the ``input_ids`` field."""
        return len(self.encodings["input_ids"])



In [5]:
def load_model():
    """Load the saved comma models and wrap each one in a ``Trainer``.

    Changes the working directory to the models folder, deserializes every
    model onto the CPU, then puts each one in eval mode, moves it to the
    module-level ``device`` and wraps it in a ``Trainer``.

    Returns:
        A 5-tuple ``(punc_trainers, scope, padding, models, model_par)``: the
        list of trainers, the fixed scope (6), the padding derived from it,
        the model file names, and one parameter pair per model. Elsewhere in
        this notebook the first entry of a pair selects the development-set
        file and the second is passed to the tokenizer as ``max_length``.
    """
    os.chdir("/Users/lucasvilsen/Desktop/GrammatikTAK/FineTuneModels/Models/")
    model_names = ["commaModel4", "commaModel7", "commaModel8",  "commaModel9"]
    model_par = [(6, 15), (10, 21), (20, 55), (10, 21)]

    # NOTE(review): torch.load unpickles whole objects, which can execute
    # arbitrary code -- only point this at model files you trust.
    loaded_models = [
        torch.load(name, map_location=torch.device('cpu')) for name in model_names
    ]

    # Punctuation model: padding and scope should match the model and dataset.
    # The padding should be such that every word except the last is checked;
    # otherwise the logic in "find_comma_mistakes" needs to be changed.
    scope = 6
    padding = int(scope/2-1)

    def _to_trainer(net):
        # Inference only: eval mode, target device, then the Trainer wrapper.
        net.eval()
        net.to(device)
        return Trainer(net)

    punc_trainers = [_to_trainer(net) for net in loaded_models]
    return punc_trainers, scope, padding, model_names, model_par

# Prefer Apple's Metal backend (MPS), but fall back to the CPU when it is not
# available so that `model.to(device)` in load_model works on other machines too.
device = (
    "mps"
    if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available()
    else "cpu"
)
torch.device(device)

device(type='mps')

In [6]:
from tqdm import tqdm
from sklearn.metrics import f1_score, recall_score, precision_score

class PunctuationCorrector():
    """Evaluates every loaded comma model on its matching development set.

    On construction the trainers, scope, padding, model names and per-model
    parameters are taken from ``load_model()`` and the Danish BERT tokenizer
    is loaded.
    """

    def __init__(self) -> None:
        self.models, self.scope, self.padding, self.model_names, self.model_par = load_model()
        self.tokenizer = BertTokenizer.from_pretrained("Maltehb/danish-bert-botxo")

    def add_padding(self, words):
        """Return ``words`` with ``self.padding`` ``"<PAD>"`` tokens on each side."""
        pad = ["<PAD>"] * self.padding
        return pad + words + pad

    @staticmethod
    def _as_percent(fraction):
        """Format a 0..1 fraction as a percentage with at most two decimals."""
        # Round AFTER scaling: round(x, 4) * 100 can print float noise such as
        # 97.39999999999999 instead of 97.4.
        return f"{round(float(fraction) * 100, 2)}%"

    def _print_section(self, heading, label, values):
        """Print a heading followed by one ``<model> <label>: <value>`` line per model."""
        print(heading)
        print(*[f"{self.model_names[i]} {label}: {value}" for i, value in enumerate(values)], sep="\n")

    # prepares dataset and get predictions
    def get_predictions(self) -> tuple:
        """Run every model on its development set.

        For model ``i`` the file ``CommaDevelopmentset<model_par[i][0]>.csv`` is
        loaded and tokenized with ``max_length=model_par[i][1]``.

        Returns:
            A tuple ``(final_predictions, ys)``: one array of argmax class
            predictions per model and the matching list of true labels.
        """
        final_predictions = []
        ys = []
        for i in tqdm(range(len(self.models))):
            dataset_id = self.model_par[i][0]
            max_length = self.model_par[i][1]
            test_data, y = load_data(f"CommaDevelopmentset{dataset_id}.csv")
            ys.append(y)
            trainer = self.models[i]
            tokenized_data = self.tokenizer(test_data, padding=True, truncation=True, max_length=max_length)
            final_dataset = Dataset(tokenized_data)
            raw_predictions, _, _ = trainer.predict(final_dataset)
            final_predictions.append(np.argmax(raw_predictions, axis=1))
        return final_predictions, ys

    def get_accuracy_for_each_model(self):
        """Print accuracy, macro F1, macro recall and macro precision per model."""
        predictions_lsts, y = self.get_predictions()
        # np.asarray makes the elementwise comparison explicit for list labels.
        accuracy = [np.mean(predictions == np.asarray(y[i])) for i, predictions in enumerate(predictions_lsts)]
        f1_scores = [f1_score(y[i], predictions, average="macro") for i, predictions in enumerate(predictions_lsts)]
        recall = [recall_score(y[i], predictions, average="macro") for i, predictions in enumerate(predictions_lsts)]
        precision = [precision_score(y[i], predictions, average="macro") for i, predictions in enumerate(predictions_lsts)]

        self._print_section("Accuracy:\n", "accuracy", [self._as_percent(value) for value in accuracy])
        self._print_section("\nF1-score:\n", "f1-score", [round(value, 3) for value in f1_scores])
        self._print_section("\nRecall:\n", "recall", [round(value, 3) for value in recall])
        self._print_section("\nPrecision:\n", "precision", [round(value, 3) for value in precision])

# Instantiating the corrector calls load_model() and loads the Danish BERT tokenizer.
puncCorrector = PunctuationCorrector()

In [7]:
# Run every model on its development set and print accuracy, F1, recall and precision.
puncCorrector.get_accuracy_for_each_model()

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/170 [00:00<?, ?it/s]

 25%|██▌       | 1/4 [00:25<01:17, 25.87s/it]

  0%|          | 0/170 [00:00<?, ?it/s]

 50%|█████     | 2/4 [00:51<00:51, 25.84s/it]

  0%|          | 0/170 [00:00<?, ?it/s]

 75%|███████▌  | 3/4 [01:42<00:37, 37.03s/it]

  0%|          | 0/170 [00:00<?, ?it/s]

100%|██████████| 4/4 [02:04<00:00, 31.12s/it]

Accuracy:

commaModel4 accuracy: 94.19%
commaModel7 accuracy: 97.64%
commaModel8 accuracy: 97.41%
commaModel9 accuracy: 98.09%

F1-score:

commaModel4 f1-score: 0.828
commaModel7 f1-score: 0.915
commaModel8 f1-score: 0.905
commaModel9 f1-score: 0.931

Recall:

commaModel4 recall: 0.913
commaModel7 recall: 0.927
commaModel8 recall: 0.911
commaModel9 recall: 0.943

Precision:

commaModel4 precision: 0.777
commaModel7 precision: 0.904
commaModel8 precision: 0.9
commaModel9 precision: 0.92



