In [1]:
from typing import Optional, List, Tuple, Dict, Union, Any

# Training

- Train a BERT token-classification model that predicts, for each token, whether it is an eos (end-of-sentence) token.

## Source Data

- Data  
    - Uses https://www.kaggle.com/datasets/Cornell-University/arxiv  
    - Metadata for about 1.7 million articles (only the abstracts of the first 100,000 are used below)

In [2]:
import json
from tqdm import tqdm
import re

def clean(text: str) -> str:
    """
    Cleaning strategy:
        1. Replace line breaks with spaces.
        2. Remove the content of '(...)' and '[...]' groups.
        3. Remove inline LaTeX formulas written as '$...$'.
        4. Remove remaining LaTeX commands (backslash sequences) outside '$...$'.
        5. Remove all decimal numbers.
        6. Remove every character other than letters, periods and spaces.
        7. Normalize each period to have no space before it and one space after it.
        8. Collapse runs of spaces into a single space.
    Args:
        text: input text
    Returns:
        cleaned text
    """
    regex = [(r'\n+', ' '),
             (r'\([^\(\)]*\)|\[[^\[\]]*\]', ' '),
             (r'\$[^$]*\$', ' '),   # one formula at a time; a greedy '.*' would also delete the text between formulas
             (r'\\[^\s]+', ' '),
             (r'\d+\.\d+', ' '),
             (r'[^a-zA-Z\. ]', ' '),
             (r' *\. *', '. '),
             (r' +', ' ')]
    
    cleaned = text
    for pattern, repl in regex:
        cleaned = re.sub(pattern, repl, cleaned)
    
    return cleaned.strip()

def get_metadata(data_file='../data/arxiv-metadata-oai-snapshot.json'):
    """
    Yield the snapshot one line (one JSON record) at a time,
    so the whole file never has to be loaded into memory.
    """
    with open(data_file, 'r', encoding='utf-8') as f:
        for line in f:
            yield line

def get_papers(n: int = 100, show_idx: int = 38) -> List[str]:
    """
    Args:
        n: number of papers to return
        show_idx: index of the paper whose abstract is printed before and after cleaning
    Returns:
        list of cleaned, lowercased abstracts
    """
    papers = []
    for i, line in enumerate(tqdm(get_metadata())):
        if i == n:
            break
        abstract = json.loads(line)['abstract']
        papers.append(clean(abstract.lower()))
        if i == show_idx:
            print(f'=========== original ===========\n{abstract}')
            print(f'=========== cleaned ===========\n{papers[-1]}')

    return papers
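Step 3 of `clean` is sensitive to regex greediness. Once line breaks are gone, a greedy `\$.*\$` spans from the first `$` to the last one in the abstract and deletes the prose between formulas, whereas `\$[^$]*\$` removes each inline formula on its own. A small check (the sample sentence is made up, and `strip_inline_math` is a hypothetical helper):

```python
import re

GREEDY = r'\$.*\$'          # matches from the first '$' to the last '$'
PER_FORMULA = r'\$[^$]*\$'  # matches one '$...$' pair at a time

def strip_inline_math(text: str, pattern: str) -> str:
    """Replace every match of `pattern` with a single space."""
    return re.sub(pattern, ' ', text)

sample = 'we compute $x^2$ for all models and find $y$ here'
greedy_out = strip_inline_math(sample, GREEDY)
per_formula_out = strip_inline_math(sample, PER_FORMULA)
print(repr(greedy_out))       # 'we compute   here'
print(repr(per_formula_out))  # 'we compute   for all models and find   here'
```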

In [3]:
papers = get_papers(100000)

2536it [00:00, 12213.06it/s]

  The quadratic pion scalar radius, \la r^2\ra^\pi_s, plays an important role
for present precise determinations of \pi\pi scattering. Recently, Yndur\'ain,
using an Omn\`es representation of the null isospin(I) non-strange pion scalar
form factor, obtains \la r^2\ra^\pi_s=0.75\pm 0.07 fm^2. This value is larger
than the one calculated by solving the corresponding Muskhelishvili-Omn\`es
equations, \la r^2\ra^\pi_s=0.61\pm 0.04 fm^2. A large discrepancy between both
values, given the precision, then results. We reanalyze Yndur\'ain's method and
show that by imposing continuity of the resulting pion scalar form factor under
tiny changes in the input \pi\pi phase shifts, a zero in the form factor for
some S-wave I=0 T-matrices is then required. Once this is accounted for, the
resulting value is \la r^2\ra_s^\pi=0.65\pm 0.05 fm^2. The main source of error
in our determination is present experimental uncertainties in low energy S-wave
I=0 \pi\pi phase shifts. Another important contribution 

99999it [00:07, 13563.36it/s]


In [4]:
papers[0]

'a fully differential calculation in perturbative quantum chromodynamics is presented for the production of massive photon pairs at hadron colliders. all next to leading order perturbative contributions from quark antiquark gluon quark and gluon gluon subprocesses are included as well as all orders resummation of initial state gluon radiation valid at next to next to leading logarithmic accuracy. the region of phase space is specified in which the calculation is most reliable. good agreement is demonstrated with data from the fermilab tevatron and predictions are made for more detailed tests with cdf and do data. predictions are shown for distributions of diphoton pairs produced at the energy of the large hadron collider. distributions of the diphoton pairs from the decay of a higgs boson are contrasted with those produced from qcd processes at the lhc showing that enhanced sensitivity to the signal can be obtained with judicious selection of events.'

## Target Data

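- Labels
    - whether each token is an eos (end-of-sentence) token

Example:  
Our hope is to em ##power new use case ##s. However, we are not sure how to do this yet.  
0   0    0  0  -1 0       0   0   -1   1    0        0  0   0   0    0   0  0  0    1

Here -1 stands for -100 (shortened for readability), a label value that is ignored automatically. It is therefore enough to give a meaningful label to just one subword of each word (here the last one); the remaining subwords can be ignored.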
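The -100 convention comes from PyTorch's `torch.nn.CrossEntropyLoss`, whose `ignore_index` defaults to -100; the token-classification heads in `transformers` compute their loss with it. The NumPy sketch below is a re-implementation for illustration (hypothetical `masked_cross_entropy`, not the library code) of what that means for a mean-reduced loss: ignored positions neither add to the sum nor count in the average, so their logits have no effect.

```python
import numpy as np

def masked_cross_entropy(logits, labels, ignore_index=-100):
    """Mean cross-entropy over the positions whose label is not `ignore_index`.

    logits: (num_tokens, num_classes); labels: (num_tokens,)
    """
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    keep = labels != ignore_index
    # numerically stable log-softmax over the class axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # log-probability of the true class, only at the kept positions
    picked = log_probs[keep, labels[keep]]
    return float(-picked.mean())

logits = np.array([[5.0, -5.0],   # e.g. [CLS]: label ignored
                   [0.0, 2.0],    # eos token
                   [1.0, 0.0]])   # ordinary token
labels = np.array([-100, 1, 0])
full = masked_cross_entropy(logits, labels)
print(np.isclose(full, masked_cross_entropy(logits[1:], labels[1:])))  # True
```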

In [5]:
from datasets import DatasetDict, Features, ClassLabel
from datasets.arrow_dataset import Dataset
from datasets import features as ds_features

Put the data into a dict so it can be converted to the Hugging Face `Dataset` format.

In [6]:
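papers = {'original': papers,
          'source': [paper.replace('.', '') for paper in papers],
          # one label per word of `source`; period-only tokens (e.g. from '...') are skipped so labels stay aligned
          'is_eos': [[int(w.endswith('.')) for w in paper.split() if w.replace('.', '')] for paper in papers],
          'id': list(range(len(papers)))}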

`is_eos`: per-word label saying whether the word ends a sentence (1) or not (0).  
Since it only takes the values 0 and 1, the Dataset feature is turned into a `ClassLabel`.
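How one cleaned abstract turns into a `source` string and word-level `is_eos` labels can be sketched as follows. `make_example` is a hypothetical helper and the sample sentences are made up. It also drops "words" that consist only of periods (the period normalization in `clean` produces them from an ellipsis such as `...`), because they vanish from `source.split()` and would otherwise shift every later label by one.

```python
from typing import List, Tuple

def make_example(text: str) -> Tuple[str, List[int]]:
    """Turn a cleaned, period-punctuated text into (source, is_eos)."""
    # drop tokens that are nothing but periods: they disappear from the source
    words = [w for w in text.split() if w.replace('.', '')]
    source = text.replace('.', '')                   # model input: no periods at all
    is_eos = [int(w.endswith('.')) for w in words]  # 1 if the word ends a sentence
    return source, is_eos

source, is_eos = make_example('our hope is to empower new use cases. however we are not sure yet.')
print(source)  # our hope is to empower new use cases however we are not sure yet
print(is_eos)  # [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]
```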

In [7]:
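papers = Dataset.from_dict(papers)
# cast_column changes the stored schema; assigning into `papers.features` does not cast the data
papers = papers.cast_column('is_eos', ds_features.Sequence(ds_features.ClassLabel(names=['0', '1'])))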

train: 50%, valid: 40%, test: 10%

In [8]:
train_testvalid = papers.train_test_split(test_size=0.5)
test_valid = train_testvalid['test'].train_test_split(test_size=0.2)

papers = DatasetDict({'train': train_testvalid['train'],
                        'valid': test_valid['train'],
                        'test': test_valid['test']})

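DistilBERT is a distilled version of BERT: smaller and faster, while keeping most of BERT's accuracy.  
Since the task is to decide, for each token, whether it is an eos token, we load it with `AutoModelForTokenClassification`, which adds a classification head on top of every token (two labels by default).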

In [9]:
import transformers
from transformers import AutoTokenizer, AutoModelForTokenClassification

pretrained_model_name = "distilbert-base-uncased"
# pretrained_model_name = "seg-model-distilBERT-finetuned"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
model = AutoModelForTokenClassification.from_pretrained(pretrained_model_name)

Tokenization splits words into several subwords, so the per-word `is_eos` tags no longer line up with the token positions.  
We therefore align them with the tokens and store the result as `labels`.
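The alignment rule can be checked without a tokenizer by feeding `word_ids` directly (the format a fast tokenizer's `word_ids()` returns, with `None` for special tokens). The hypothetical `align_labels` below is an alternative to reversing the list: a subword is the last of its word when the next position belongs to a different word. The `word_ids` in the example are made up and mimic one word split into two subwords.

```python
from typing import List, Optional

def align_labels(word_ids: List[Optional[int]], word_labels: List[int],
                 label_pos: str = "last", ignore: int = -100) -> List[int]:
    """Spread word-level labels over subword positions; unlabeled positions get `ignore`."""
    assert label_pos in ("all", "first", "last")
    aligned = []
    for pos, word_idx in enumerate(word_ids):
        if word_idx is None:                       # special tokens such as [CLS] / [SEP]
            aligned.append(ignore)
            continue
        prev_idx = word_ids[pos - 1] if pos > 0 else None
        next_idx = word_ids[pos + 1] if pos + 1 < len(word_ids) else None
        is_first = word_idx != prev_idx            # first subword of its word
        is_last = word_idx != next_idx             # last subword of its word
        keep = (label_pos == "all"
                or (label_pos == "first" and is_first)
                or (label_pos == "last" and is_last))
        aligned.append(word_labels[word_idx] if keep else ignore)
    return aligned

# [CLS], word 0, word 1 split into two subwords, word 2, [SEP]
word_ids = [None, 0, 1, 1, 2, None]
word_labels = [0, 1, 1]
print(align_labels(word_ids, word_labels, "last"))   # [-100, 0, -100, 1, 1, -100]
print(align_labels(word_ids, word_labels, "first"))  # [-100, 0, 1, -100, 1, -100]
```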

In [10]:
def tokenize_and_align_labels(batch: Dict[str, List], label_pos: str = 'last'):
    """
    Tokenize a batch and align the word-level `is_eos` labels with the subword tokens.

    Args:
        batch: a batch of examples (dict of lists) with 'source' and 'is_eos' columns
        label_pos: which subword(s) of each word keep the word's label: 'first', 'last' or 'all';
            the other subwords get -100 and are ignored by the loss
    Returns:
        the tokenizer output (BatchEncoding) with an added 'labels' field
    """
    assert label_pos in ['all', 'first', 'last']

    tokenized_inputs = tokenizer(batch['source'], truncation=True)
    labels = []
    for i, label in enumerate(batch['is_eos']):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []

        # For 'last', walk the tokens backwards so that the first subword seen
        # for each word is actually its last subword.
        if label_pos == 'last':
            word_ids = word_ids[::-1]

        for word_idx in word_ids:
            if word_idx is None:                    # special tokens ([CLS], [SEP])
                label_ids.append(-100)
            elif word_idx != previous_word_idx:     # first subword seen for this word
                label_ids.append(label[word_idx])
            else:                                   # remaining subwords of the word
                label_ids.append(label[word_idx] if label_pos == 'all' else -100)
            previous_word_idx = word_idx

        # undo the reversal so the labels line up with the tokens again
        labels.append(label_ids[::-1] if label_pos == 'last' else label_ids)

    tokenized_inputs['labels'] = labels
    return tokenized_inputs

In [11]:
tokenize_and_align_labels(papers['train'][:2])

{'input_ids': [[101, 1996, 9575, 16014, 2015, 2008, 1037, 10713, 2275, 2004, 5097, 2057, 8980, 2195, 7680, 4031, 9872, 2015, 2525, 1999, 1996, 3906, 102], [101, 1996, 7613, 1997, 12702, 7770, 7741, 3463, 2875, 1996, 2312, 17454, 19036, 2594, 6112, 2145, 3464, 6801, 2317, 28984, 2031, 2042, 3818, 2000, 4863, 2122, 3463, 1998, 6516, 2000, 9002, 6022, 2000, 1996, 3742, 5166, 1997, 2256, 9088, 2174, 2195, 14679, 2006, 1996, 2535, 2209, 2011, 3180, 6351, 7722, 2317, 28984, 4839, 5294, 2860, 16584, 2063, 28984, 2024, 2245, 2000, 2022, 2081, 1997, 1037, 8150, 1997, 7722, 1998, 16231, 7978, 2135, 2037, 11520, 3446, 2003, 3469, 2084, 2216, 1997, 5171, 6351, 7722, 2317, 28984, 1998, 2027, 12985, 2000, 1999, 11365, 13464, 1999, 2460, 2335, 9289, 2229, 8821, 2027, 12346, 1037, 2204, 4018, 2005, 9990, 1996, 12702, 7770, 7741, 3463, 2182, 2057, 11628, 1999, 6987, 2023, 10744, 2011, 2478, 1996, 2087, 3522, 1998, 2039, 2000, 3058, 11520, 3162, 2005, 5294, 2317, 28984, 1998, 1037, 10125, 9758, 25837, 2

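`map` applies a function to every split of a `DatasetDict`; with `batched=True` the function receives a batch of examples (a dict of lists) at a time.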

In [12]:
tokenized_datasets = papers.map(tokenize_and_align_labels, batched=True)



  0%|          | 0/50 [00:00<?, ?ba/s]

  0%|          | 0/40 [00:00<?, ?ba/s]

  0%|          | 0/10 [00:00<?, ?ba/s]

For training, the model only needs the inputs it accepts: `input_ids`, `attention_mask` and `labels`.

In [13]:
tokenized_datasets = tokenized_datasets.remove_columns(['original', 'source', 'id', 'is_eos'])
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 50000
    })
    valid: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 40000
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 10000
    })
})

In [14]:
tokenized_datasets['train'].features['labels']

Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None)

## Data Collator

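The object that assembles examples into batches. `DataCollatorForTokenClassification` pads `input_ids` and `attention_mask` with the tokenizer and pads `labels` with -100, so padded positions are ignored by the loss.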
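Roughly what the collator does for one batch, as a plain-Python sketch. `pad_batch` is hypothetical; it assumes right padding and `pad_token_id=0` (the `[PAD]` id in the `distilbert-base-uncased` vocabulary) and returns lists instead of tensors. The token ids in the example are made up apart from 101/102 (`[CLS]`/`[SEP]`).

```python
from typing import Dict, List

def pad_batch(features: List[Dict[str, List[int]]], pad_token_id: int = 0,
              label_pad_id: int = -100) -> Dict[str, List[List[int]]]:
    """Right-pad every field to the length of the longest example in the batch."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch: Dict[str, List[List[int]]] = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        batch["attention_mask"].append(f["attention_mask"] + [0] * n_pad)  # 0 = do not attend
        batch["labels"].append(f["labels"] + [label_pad_id] * n_pad)        # -100 = ignored by the loss
    return batch

features = [
    {"input_ids": [101, 7, 102], "attention_mask": [1, 1, 1], "labels": [-100, 1, -100]},
    {"input_ids": [101, 102], "attention_mask": [1, 1], "labels": [-100, -100]},
]
print(pad_batch(features)["labels"])  # [[-100, 1, -100], [-100, -100, -100]]
```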

In [14]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)



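The seqeval metric is meant for NER, so it is forced to fit binary eos classification here: label 0 is mapped to `'O'` and label 1 to `'B-PER'`, which turns every eos token into a one-token "entity".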

In [15]:
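# `load_metric` was removed from newer `datasets` releases; there, use `evaluate.load("seqeval")` instead
from datasets import load_metric

metric = load_metric("seqeval")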

In [16]:
metric.compute(predictions=[['O', 'B-PER']], references=[['B-PER', 'O']])

{'PER': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1},
 'overall_precision': 0.0,
 'overall_recall': 0.0,
 'overall_f1': 0.0,
 'overall_accuracy': 0.0}

In [17]:
import numpy as np

# seqeval works on tag strings: label 0 -> 'O' (not eos), label 1 -> 'B-PER' (eos),
# so every eos token becomes a one-token "entity"
label_list = ['O', 'B-PER']

def compute_metrics(eval_pred):
    """
    Args:
        eval_pred: (logits, labels) pair that Trainer passes in during evaluation
    Returns:
        metrics: dictionary of metrics
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=2)

    # Remove ignored positions (-100: special tokens, non-final subwords, padding)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }
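With this label mapping seqeval never sees an `I-` tag, so every entity is a single eos token and the overall precision, recall and F1 reduce to ordinary binary scores for label 1 (and `overall_accuracy` to token accuracy). A NumPy-only sketch of the same computation could therefore stand in for seqeval; `binary_eos_metrics` is a hypothetical helper with the same `(logits, labels)` input as `compute_metrics`, expected (not verified against the logged runs) to give the same numbers.

```python
import numpy as np

def binary_eos_metrics(eval_pred, ignore_index: int = -100):
    """Precision/recall/F1 for the eos class (label 1) plus token accuracy."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    keep = labels != ignore_index          # skip special tokens, non-final subwords, padding
    p, y = preds[keep], labels[keep]
    tp = int(np.sum((p == 1) & (y == 1)))
    fp = int(np.sum((p == 1) & (y == 0)))
    fn = int(np.sum((p == 0) & (y == 1)))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = float(np.mean(p == y)) if p.size else 0.0
    return {"precision": precision, "recall": recall, "f1": f1, "accuracy": accuracy}

# one sequence of 5 tokens; argmax predictions are [0, 0, 1, 0, 1]
logits = np.array([[[2.0, 0.0], [3.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 4.0]]])
labels = np.array([[-100, 0, 1, 1, -100]])
print(binary_eos_metrics((logits, labels)))  # precision 1.0, recall 0.5, f1 and accuracy 2/3
```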

Train with the `Trainer` API.

In [18]:
from transformers import TrainingArguments, Trainer

batch_size = 16
model_name = pretrained_model_name.split("/")[-1]
args = TrainingArguments(f"{model_name}-finetuned",
                         evaluation_strategy='epoch',
                         learning_rate=2e-5,
                         per_device_train_batch_size=batch_size,
                         per_device_eval_batch_size=batch_size,
                         num_train_epochs=3,
                         weight_decay=0.01,
                         push_to_hub=False)

trainer = Trainer(model, args,
                  train_dataset=tokenized_datasets['train'],
                  eval_dataset=tokenized_datasets['valid'],
                  data_collator=data_collator,
                  tokenizer=tokenizer,
                  compute_metrics=compute_metrics)

In [20]:
trainer.train()

***** Running training *****
  Num examples = 50000
  Num Epochs = 3
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 2346


| Epoch | Training Loss | Validation Loss | Precision | Recall | F1 | Accuracy |
|---|---|---|---|---|---|---|
| 1 | 0.053 | 0.030014 | 0.888854 | 0.864517 | 0.876516 | 0.989262 |
| 2 | 0.029 | 0.027655 | 0.891486 | 0.885924 | 0.888697 | 0.990217 |
| 3 | 0.0267 | 0.026981 | 0.89003 | 0.89336 | 0.891692 | 0.990433 |


Saving model checkpoint to distilbert-base-uncased-finetuned/checkpoint-500
Configuration saved in distilbert-base-uncased-finetuned/checkpoint-500/config.json
Model weights saved in distilbert-base-uncased-finetuned/checkpoint-500/pytorch_model.bin
tokenizer config file saved in distilbert-base-uncased-finetuned/checkpoint-500/tokenizer_config.json
Special tokens file saved in distilbert-base-uncased-finetuned/checkpoint-500/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 40000
  Batch size = 64
Saving model checkpoint to distilbert-base-uncased-finetuned/checkpoint-1000
Configuration saved in distilbert-base-uncased-finetuned/checkpoint-1000/config.json
Model weights saved in distilbert-base-uncased-finetuned/checkpoint-1000/pytorch_model.bin
tokenizer config file saved in distilbert-base-uncased-finetuned/checkpoint-1000/tokenizer_config.json
Special tokens file saved in distilbert-base-uncased-finetuned/checkpoint-1000/special_tokens_map.json
Saving model ch

TrainOutput(global_step=2346, training_loss=0.034029456558942996, metrics={'train_runtime': 1067.9166, 'train_samples_per_second': 140.46, 'train_steps_per_second': 2.197, 'total_flos': 1.220557177589568e+16, 'train_loss': 0.034029456558942996, 'epoch': 3.0})
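The step count in the log can be reproduced from the batch sizes. A total train batch of 64 with 16 examples per device implies 4 devices; that is an inference from the log, not something the notebook sets. The `checkpoint-500`, `checkpoint-1000` and later folders come from the `TrainingArguments` default `save_steps=500`.

```python
import math

num_examples = 50_000   # "Num examples" in the log
per_device_batch = 16   # per_device_train_batch_size
num_devices = 4         # assumption: 64 (total batch) / 16 (per device)
num_epochs = 3

total_batch = per_device_batch * num_devices             # examples per optimization step
steps_per_epoch = math.ceil(num_examples / total_batch)  # the last, partial batch still counts as a step
total_steps = steps_per_epoch * num_epochs
print(total_batch, steps_per_epoch, total_steps)  # 64 782 2346
```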

# Evaluation

In [21]:
trainer.evaluate()

***** Running Evaluation *****
  Num examples = 40000
  Batch size = 64


{'eval_loss': 0.026981081813573837,
 'eval_precision': 0.8900300888088372,
 'eval_recall': 0.8933598771044299,
 'eval_f1': 0.8916918744132679,
 'eval_accuracy': 0.9904326962501993,
 'eval_runtime': 138.0474,
 'eval_samples_per_second': 289.756,
 'eval_steps_per_second': 4.527,
 'epoch': 3.0}

Save the model.

In [24]:
trainer.save_model('seg-model-distilBERT-finetuned')

Saving model checkpoint to seg-model-distilBERT-finetuned
Configuration saved in seg-model-distilBERT-finetuned/config.json
Model weights saved in seg-model-distilBERT-finetuned/pytorch_model.bin
tokenizer config file saved in seg-model-distilBERT-finetuned/tokenizer_config.json
Special tokens file saved in seg-model-distilBERT-finetuned/special_tokens_map.json


# Demonstration

In [4]:
sample_text = [
    """USB was designed to standardize the connection of peripherals to personal computers, both to communicate with and to supply electric power. It has largely replaced interfaces such as serial ports and parallel ports, and has become commonplace on a wide range of devices. Examples of peripherals that are connected via USB include computer keyboards and mice, video cameras, printers, portable media players, mobile (portable) digital telephones, disk drives, and network adapters. USB connectors have been increasingly replacing other types as charging cables of portable devices.""",
    """Historical linguistics is the study of language changes in history, particularly with regard to a specific language or a group of languages. Western trends in historical linguistics date back to roughly the late 18th century, when the discipline grew out of philology, the study of ancient texts and oral traditions.[15] Historical linguistics emerged as one of the first few sub-disciplines in the field, and was most widely practiced during the late 19th century.[16] Despite a shift in focus in the twentieth century towards formalism and generative grammar, which studies the universal properties of language, historical research today still remains a significant field of linguistic inquiry. Subfields of the discipline include language change and grammaticalisation.[17]""",
    """in the formation education that is carried out within the scope of undergraduate and non thesis graduate programs within the same university different criteria are used to evaluate students success. in this study classification accuracy of letter grades that are generated to evaluate students success using relative and absolute criteria and decisions for students passing or failing a course were examined. within the scope of this study it was also intended to determine the cut off point required for students to pass a course. in this regard midterm and final grades of a total of students. first correct classification percentages of the letter grades that the students scored with absolute and relative evaluations were calculated. then classification percentages for decisions regarding passing or failing a course were examined.""",
    """Hi. I'm Phil from BBC learning English. Today, I'm going to tell you how to use make and do not. I can be tricky and there are some exceptions, but Hero full things to remember use make when we create something.  Play Miss Kate. We use due to talk about an activity. What are you doing? We can use mic to talk about something that causes a reaction. This music reading makes me want to sing. We can also use do with General activities. What are you doing, tomorrow? I will doing anything."""
]

In [24]:
precision = load_metric('precision')
recall = load_metric('recall')
f1 = load_metric('f1')
accuracy = load_metric('accuracy')

In [20]:
from pprint import pprint

def demo(texts: List[str]) -> List[str]:
    """
    Args:
        texts: list of texts to demonstrate.
    Returns:
        punctuated texts built from the model's predictions.
    """
    if not isinstance(texts, list):
        texts = [texts]
    
    clean_texts = [clean(text).lower().strip() for text in texts]
    text_dataset = {'original': clean_texts,
                    'source': [text.replace('.', '') for text in clean_texts],
                    'is_eos': [[int(w.endswith('.')) for w in text.split() if w.replace('.', '')] for text in clean_texts]}
    
    text_dataset = Dataset.from_dict(text_dataset)
    text_dataset = text_dataset.cast_column('is_eos', ds_features.Sequence(ds_features.ClassLabel(names=['0', '1'])))
    
    samples = DatasetDict({'sample': text_dataset})
    samples = samples.map(tokenize_and_align_labels, batched=True)
    samples = samples.remove_columns(['original', 'source', 'is_eos'])['sample']
    
    preds, labels, infos = trainer.predict(samples)
    preds = np.argmax(preds, axis=2)
    pprint(infos)
    print()
    
    pred_texts = []
    
    for i, (pred, label) in enumerate(zip(preds, labels)):
        print(f'========= text [{i}] =========')
        seqlen = len(samples['input_ids'][i])
        # score only positions with a real label (not special tokens, non-final subwords or padding)
        valid = label[:seqlen] != -100
        print(f'accuracy: {np.mean(pred[:seqlen][valid] == label[:seqlen][valid])*100:.2f}')

        print(f'----- original text -----')
        print(text_dataset['original'][i].replace('.', '.\n'))
        
        pred_text = tokenizer.convert_ids_to_tokens(samples['input_ids'][i])
        pred_text = [w if l==0 else w+'.' for w,l in zip(pred_text, pred)]
        pred_text = ' '.join(pred_text[1:-1]).replace(' ##', '')
        pred_texts.append(pred_text)
        print(f'----- predicted text -----')
        print(pred_text.replace('.', '.\n'))
    
    return pred_texts
    
demo(sample_text)

  0%|          | 0/1 [00:00<?, ?ba/s]

***** Running Prediction *****
  Num examples = 4
  Batch size = 64


  0%|          | 0/1 [00:00<?, ?it/s]

{'test_accuracy': 0.9786729857819905,
 'test_f1': 0.8085106382978724,
 'test_loss': 0.05699896812438965,
 'test_precision': 0.8636363636363636,
 'test_recall': 0.76,
 'test_runtime': 6.3409,
 'test_samples_per_second': 0.631,
 'test_steps_per_second': 0.158}

accuracy: 100.00
----- original text -----
usb was designed to standardize the connection of peripherals to personal computers both to communicate with and to supply electric power.
 it has largely replaced interfaces such as serial ports and parallel ports and has become commonplace on a wide range of devices.
 examples of peripherals that are connected via usb include computer keyboards and mice video cameras printers portable media players mobile digital telephones disk drives and network adapters.
 usb connectors have been increasingly replacing other types as charging cables of portable devices.

----- predicted text -----
usb was designed to standardize the connection of peripherals to personal computers both to communicate 

['usb was designed to standardize the connection of peripherals to personal computers both to communicate with and to supply electric power. it has largely replaced interfaces such as serial ports and parallel ports and has become commonplace on a wide range of devices. examples of peripherals that are connected via usb include computer keyboards and mice video cameras printers portable media players mobile digital telephones disk drives and network adapters. usb connectors have been increasingly replacing other types as charging cables of portable devices.',
 'historical linguistics is the study of language changes in history particularly with regard to a specific language or a group of languages. western trends in historical linguistics date back to roughly the late th century when the discipline grew out of philology the study of ancient texts and oral traditions. historical linguistics emerged as one of the first few sub disciplines in the field and was most widely practiced during

# API

Only the `punctuate()` function is needed.

In [1]:
from typing import Optional, Tuple, Dict, List, Any

from transformers import (
    PreTrainedModel,
    AutoTokenizer, 
    AutoModelForTokenClassification, 
    DataCollatorForTokenClassification
    )

def load_sent_seg_model(model_name: str):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
    return tokenizer, model, data_collator



In [2]:
model_name = 'seg-model-distilBERT-finetuned'
# tokenizer, model, data_collator = load_sent_seg_model(model_name)

In [12]:
import re
import torch as th
from torch.utils.data import DataLoader
from datasets import Dataset

# model, tokenizer, data collator needed!

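def clean(text: str) -> str:
    """Text cleaning applied before tokenization; should match `clean` from the Training section."""
    regex = [(r'\n+', ' '),
           (r'\([^\(\)]*\)|\[[^\[\]]*\]', ' '),
           (r'\$[^$]*\$', ' '),   # one formula at a time; a greedy '.*' would also delete the text between formulas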
           (r'\\[^\s]+', ' '),
           (r'\d+\.\d+', ' '),
           (r'[^a-zA-Z\. ]', ' '),
           (r' *\. *', '. '),
           (r' +', ' ')]
    
    cleaned = text
    for pattern, repl in regex:
        cleaned = re.sub(pattern, repl, cleaned)
    
    return cleaned.strip()

def construct_dataset(tokenizer, texts: List[str], punctuated: bool = True) -> Dataset:
    """
    Args:
        tokenizer: tokenizer of the segmentation model
        texts: raw texts
        punctuated: if True, the texts already contain periods and aligned `labels` are built
            (this path uses `tokenize_and_align_labels` from the Training section);
            if False, the texts are only cleaned and tokenized
    Returns:
        tokenized Dataset
    """
    clean_texts = [clean(text).lower().strip() for text in texts]
    if punctuated:
        dataset = Dataset.from_dict({'original': clean_texts,
                                     'source': [text.replace('.', '') for text in clean_texts],
                                     'is_eos': [[int(w.endswith('.')) for w in text.split() if w.replace('.', '')]
                                                for text in clean_texts]})

        dataset = dataset.map(tokenize_and_align_labels, batched=True)
        dataset = dataset.remove_columns(['original', 'source', 'is_eos'])
    
    else:
        dataset = Dataset.from_dict(dict(tokenizer(clean_texts, truncation=True)))
    
    return dataset


def punctuate(texts: List[str]) -> List[str]:
    """
    Args:
        texts: unpunctuated sentences from speech-to-text (STT); must be passed as a List[str].
    Returns:
        the given texts with periods inserted.
    Note:
        inputs longer than the model's maximum length (512 tokens) are truncated,
        so the returned text is cut off as well.
    """
    tokenizer, model, data_collator = load_sent_seg_model(model_name)
    dataset = construct_dataset(tokenizer, texts, punctuated=False)
    
    dataloader = DataLoader(dataset, 
                            batch_size=1, 
                            shuffle=False, 
                            collate_fn=data_collator)
    
    def insert_punct(token_ids, preds):
        # append '.' to every subword predicted as eos (label 1)
        subwords = tokenizer.convert_ids_to_tokens(token_ids.squeeze(0).tolist())
        for idx in th.nonzero(preds == 1).flatten().tolist():
            subwords[idx] += '.'
        return subwords
            
    pred_texts = []
    with th.no_grad():  # inference only: skip gradient bookkeeping
        for batch in dataloader:
            input_ids = batch['input_ids'].to(model.device)
            attention_mask = batch['attention_mask'].to(model.device)
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            pred = th.argmax(logits, dim=2).squeeze(0)

            pred_text = insert_punct(input_ids, pred)
            # drop [CLS]/[SEP] and glue WordPiece continuations ('##') back onto their words
            pred_text = ' '.join(pred_text[1:-1]).replace(' ##', '')
            pred_texts.append(pred_text)
    
    return pred_texts
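Joining with `' '.join(...).replace(' ##', '')` works when the eos prediction lands on the last subword of a word, which is where the labels were placed during training. If a period is ever predicted on an inner subword, the string join yields something like `em.power`. A word-level reconstruction avoids that. The sketch below (hypothetical `tokens_to_punctuated_text`, assuming WordPiece `##` continuation tokens and BERT's `[CLS]`/`[SEP]`/`[PAD]` special tokens; tokens and predictions in the example are made up) keeps only the prediction of each word's last subword.

```python
from typing import List, Sequence

SPECIAL_TOKENS = {"[CLS]", "[SEP]", "[PAD]"}

def tokens_to_punctuated_text(tokens: Sequence[str], preds: Sequence[int]) -> str:
    """Merge WordPiece tokens into words and append '.' to words predicted as eos."""
    words: List[str] = []
    word_eos: List[int] = []
    for tok, pred in zip(tokens, preds):
        if tok in SPECIAL_TOKENS:
            continue                      # never print special tokens
        if tok.startswith("##") and words:
            words[-1] += tok[2:]          # continuation: glue onto the current word
            word_eos[-1] = int(pred)      # the last subword's prediction wins
        else:
            words.append(tok)
            word_eos.append(int(pred))
    return " ".join(w + "." if eos else w for w, eos in zip(words, word_eos))

tokens = ["[CLS]", "our", "hope", "is", "to", "em", "##power", "us", "[SEP]"]
preds = [0, 0, 0, 0, 0, 1, 0, 1, 1]
print(tokens_to_punctuated_text(tokens, preds))  # our hope is to empower us.
```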
    

In [13]:
result = punctuate([text.replace('.', ' ') for text in sample_text])

In [14]:
for i, res in enumerate(result):
    print(f'===== doc {i} =====')
    print(res.replace('.', '.\n'))

===== doc 0 =====
usb was designed to standardize the connection of peripherals to personal computers both to communicate with and to supply electric power.
 it has largely replaced interfaces such as serial ports and parallel ports and has become commonplace on a wide range of devices.
 examples of peripherals that are connected via usb include computer keyboards and mice video cameras printers portable media players mobile digital telephones disk drives and network adapters.
 usb connectors have been increasingly replacing other types as charging cables of portable devices.

===== doc 1 =====
historical linguistics is the study of language changes in history particularly with regard to a specific language or a group of languages.
 western trends in historical linguistics date back to roughly the late th century when the discipline grew out of philology the study of ancient texts and oral traditions.
 historical linguistics emerged as one of the first few sub disciplines in the field 