In [1]:
from deepmultilingualpunctuation import PunctuationModel
from datasets import load_dataset

import evaluate
import seqeval

import spacy
import json


In [2]:
# spaCy pipeline used by sentence_to_sample to tokenize restored sentences.
nlp = spacy.load("en_core_web_sm")

In [3]:
# NOTE: this rebinds the name `seqeval` (imported as a module in the imports
# cell) to the metric object loaded through `evaluate`; compute_metrics calls
# `seqeval.compute(...)` on this object.
seqeval = evaluate.load("seqeval")

In [4]:
# Baseline punctuation-restoration model, constructed with its default arguments.
model = PunctuationModel()



In [5]:
# Dataset identifier passed to `load_dataset`.
dataset_path = "just097/wiki-comma-placement"
wiki_comma_placement = load_dataset(dataset_path)
# Integer tag id -> label string; compute_metrics indexes this list with the
# integer tags (sentence_to_sample emits 1 for a token followed by a comma).
LABEL_LIST = ["O", "B-COMMA"]

In [6]:
# Sanity check on a single test example (index 21 is an arbitrary pick).
try_sample = wiki_comma_placement["test"][21]

# Follow the same steps as restore_text so this preview exercises the exact
# path used for the evaluation: join the tokens, let the model strip existing
# punctuation, then ask it to restore punctuation on the stripped text.
text = " ".join(try_sample["tokens"]).strip()
clean_text = " ".join(model.preprocess(text)).strip()
punct_text = model.restore_punctuation(clean_text)
print(punct_text)

Kovacs only had two campus visits: Division II, Hillsdale and Toledo ( a school 13 miles from his high school ).


In [7]:
# Alias for the test split; this is the data the baseline is run on.
test_subset = wiki_comma_placement["test"]

In [8]:
def sentence_to_sample(sentence: str) -> dict:
    """Convert a punctuated sentence into comma-free tokens and comma tags.

    The sentence is tokenized with the module-level spaCy pipeline ``nlp``.
    Comma tokens are dropped from the output; every remaining token is tagged
    1 when the token right after it in the tokenization is a comma, else 0.

    Args:
        sentence: Text to convert; surrounding whitespace is stripped first.

    Returns:
        A dict with two equally long lists: ``"tokens"`` (token strings with
        the commas removed) and ``"tags"`` (0/1 integers). Both lists are
        empty when the sentence yields no tokens.
    """
    words = [token.text for token in nlp(sentence.strip())]
    clean_words = []
    tags = []
    for index, word in enumerate(words):
        # Commas are never emitted as tokens, including a trailing one; they
        # only show up as the tag of the token that precedes them.
        if word == ",":
            continue
        followed_by_comma = index + 1 < len(words) and words[index + 1] == ","
        clean_words.append(word)
        tags.append(1 if followed_by_comma else 0)
    assert len(tags) == len(clean_words)
    return {"tokens": clean_words, "tags": tags}

In [9]:
def restore_text(example):
    """Run the baseline model on one example and re-derive tokens and tags.

    The example's ``"tokens"`` are joined with single spaces, passed through
    ``model.preprocess``, re-joined with spaces, and handed to
    ``model.restore_punctuation``. The restored sentence is then converted
    by ``sentence_to_sample``.

    Args:
        example: Mapping with a ``"tokens"`` entry holding the token strings.

    Returns:
        Whatever ``sentence_to_sample`` returns for the restored sentence.
    """
    joined_tokens = " ".join(example["tokens"]).strip()
    # preprocess' result is joined with spaces below, so it is presumably an
    # iterable of word strings -- confirm against the library if this changes.
    preprocessed_words = model.preprocess(joined_tokens)
    model_input = " ".join(preprocessed_words).strip()
    return sentence_to_sample(model.restore_punctuation(model_input))

## Generate restored sentences with the baseline punctuation model

In [10]:
# Apply restore_text to every example of the test split.
restored_texts = test_subset.map(restore_text)

Map:   0%|          | 0/19667 [00:00<?, ? examples/s]



In [18]:
# Preview the predicted tags of the first few sentences only; displaying the
# whole column would dump one list per test example into the notebook.
restored_texts[:10]["tags"]

[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
 [0,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [15]:
def compute_metrics(predictions, labels):
    """Score predicted comma tags against reference tags with seqeval.

    Args:
        predictions: One sequence of integer tag ids per sentence; every id
            is an index into ``LABEL_LIST``.
        labels: Reference tag ids in the same layout as ``predictions``.

    Returns:
        A dict with ``"precision"``, ``"recall"``, ``"f1"`` and
        ``"accuracy"``, taken from the ``overall_*`` entries of the seqeval
        result.

    Raises:
        ValueError: If the number of sentences differs, or if any sentence
            has a different number of predicted and reference tags.
    """
    if len(predictions) != len(labels):
        raise ValueError(
            f"Got {len(predictions)} predicted sentences but {len(labels)} reference sentences."
        )
    for index, (predicted, reference) in enumerate(zip(predictions, labels)):
        # A per-sentence mismatch means the tokenizations diverged; comparing
        # the tags position by position would then score shifted labels.
        if len(predicted) != len(reference):
            raise ValueError(
                f"Sentence {index}: {len(predicted)} predicted tags vs "
                f"{len(reference)} reference tags."
            )

    # Keep one label sequence per sentence instead of concatenating the whole
    # split into a single sequence, so sentence boundaries are preserved.
    true_predictions = [[LABEL_LIST[p] for p in sample] for sample in predictions]
    true_labels = [[LABEL_LIST[p] for p in sample] for sample in labels]

    results = seqeval.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

In [19]:
# Score the baseline's restored tags against the gold tags of the test split
# (test_subset is the same split object as wiki_comma_placement["test"]).
compute_metrics(
    predictions=restored_texts["tags"],
    labels=test_subset["tags"],
)

{'precision': 0.7262762566050919,
 'recall': 0.6416280864197531,
 'f1': 0.6813330875273971,
 'accuracy': 0.9690515125110907}