In [1]:
import re
import pandas as pd
from datasets import Dataset, DatasetDict
from transformers import BertTokenizerFast
import os
from tqdm import tqdm

# 1. define punctuation map
punctuation_map = {
    ',': 'COMMA',
    '.': 'PERIOD',
    '?': 'QUESTION',
    '!': 'EXCLAMATION',
    ';': 'SEMICOLON',
    ':': 'COLON',
    '-': 'HYPHEN',
    '–': 'EN_DASH',
    '—': 'EM_DASH',
    '(': 'LEFT_PAREN',
    ')': 'RIGHT_PAREN',
    '[': 'LEFT_BRACKET',
    ']': 'RIGHT_BRACKET',
    '{': 'LEFT_BRACE',
    '}': 'RIGHT_BRACE',
    '"': 'DOUBLE_QUOTE',
    "'": 'SINGLE_QUOTE',
    '…': 'ELLIPSIS',
    '/': 'SLASH',
    '\\': 'BACKSLASH',
    '@': 'AT_SYMBOL',
    '#': 'HASH',
    '$': 'DOLLAR',
    '%': 'PERCENT',
    '&': 'AMPERSAND',
    '*': 'ASTERISK',
    '+': 'PLUS',
    '=': 'EQUALS',
    '<': 'LESS_THAN',
    '>': 'GREATER_THAN',
    '|': 'PIPE',
    '^': 'CARET',
    '`': 'BACKTICK',
    '~': 'TILDE'
}

# Automatically create label_list from punctuation_map
label_list = ["O"] + list(punctuation_map.values())
label_to_id = {l: i for i, l in enumerate(label_list)}

print("Label list:", label_list)

  from .autonotebook import tqdm as notebook_tqdm


Label list: ['O', 'COMMA', 'PERIOD', 'QUESTION', 'EXCLAMATION', 'SEMICOLON', 'COLON', 'HYPHEN', 'EN_DASH', 'EM_DASH', 'LEFT_PAREN', 'RIGHT_PAREN', 'LEFT_BRACKET', 'RIGHT_BRACKET', 'LEFT_BRACE', 'RIGHT_BRACE', 'DOUBLE_QUOTE', 'SINGLE_QUOTE', 'ELLIPSIS', 'SLASH', 'BACKSLASH', 'AT_SYMBOL', 'HASH', 'DOLLAR', 'PERCENT', 'AMPERSAND', 'ASTERISK', 'PLUS', 'EQUALS', 'LESS_THAN', 'GREATER_THAN', 'PIPE', 'CARET', 'BACKTICK', 'TILDE']


In [2]:
punctuation_reverse_map = {v: k for k, v in punctuation_map.items()}
punctuation_reverse_map["O"] = " "   # no punctuation

In [3]:
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForTokenClassification

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Load tokenizer and model
model_name = "thenlpresearcher/bert_punct_model"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)
model.to(device)

Using device: cuda


2025-11-25 11:23:58.392136: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-25 11:23:58.456113: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-11-25 11:24:00.572300: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024

In [4]:
def get_word_and_prediction_ids(text: str, model, tokenizer, device) -> tuple[np.ndarray, np.ndarray]:
    """
    Takes an unpunctuated text string and returns the word IDs and prediction IDs 
    for all tokens in the sequence.
    """
    words = text.lower().split()
    
    if not words:
        return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
    
    encoded_input = tokenizer(
        words, 
        is_split_into_words=True, 
        return_tensors="pt", 
        padding=True, 
        truncation=True
    )
    
    print(encoded_input)
    print(tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][0]))
    
    encoded_input = encoded_input.to(device)

    with torch.no_grad():
        outputs = model(**encoded_input)
    
    logits = outputs.logits
    
    # Squeeze to handle single-item batch
    pred_ids = torch.argmax(logits, dim=-1).squeeze().cpu().numpy()
    
    # Convert list of word IDs (including None) to a numpy array, 
    # replacing None with a placeholder -1 for special tokens
    word_ids_list = encoded_input.word_ids()
    word_ids_array = np.array([w if w is not None else -1 for w in word_ids_list], dtype=np.int64)

    return word_ids_array, pred_ids

In [5]:
def restore_punctuation(text: str, model, tokenizer, label_list, device) -> str:
    """
    Restores punctuation to an unpunctuated text string using the BERT Punctuation model.
    """
    words = text.strip().split()
    w_ids, p_ids = get_word_and_prediction_ids(text, model, tokenizer, device)
    w_ids = w_ids[1:-1]
    print(f"Word ids: {w_ids}")
    
    p_ids = p_ids[1:-1]
    print(f"Prediction ids {p_ids}")
    
    final_output = []
    seen = [False]*(len(words))
    for w_id, p_id in zip(w_ids, p_ids):
        if not seen[w_id]:
            punct = punctuation_reverse_map[label_list[p_id]]
            if punct != " ":
                punct = punct + " "
            seen[w_id] = True
            final_output.extend([words[w_id], punct])
                 
    # Join the words back into a sentence string, capitalizing the first letter.
    result = "".join(final_output).strip()
    
    if result:
        # Capitalize the first letter
        return result[0].upper() + result[1:]
    return ""

print("\n--- Punctuation Restoration Results ---")

# # Example 1
# input_text_1 = "how old are you i am a language model"
# punctuated_text_1 = restore_punctuation(input_text_1, model, tokenizer, label_list, device)
# print(f"Original: {input_text_1}")
# print(f"Punctuated: {punctuated_text_1}")

# # Example 2
# input_text_2 = "what is the capital of france it is paris"
# punctuated_text_2 = restore_punctuation(input_text_2, model, tokenizer, label_list, device)
# print(f"\nOriginal: {input_text_2}")
# print(f"Punctuated: {punctuated_text_2}")

# # Example 3 (for demonstration of comma prediction)
# input_text_3 = "if you want to know more ask me anything"
# punctuated_text_3 = restore_punctuation(input_text_3, model, tokenizer, label_list, device)
# print(f"\nOriginal: {input_text_3}")
# print(f"Punctuated: {punctuated_text_3}")

# Example 4 (for demonstration of exclamation)
input_text_4 = "wow that is amazing isnt it"
punctuated_text_4 = restore_punctuation(input_text_4, model, tokenizer, label_list, device)
print(f"\nOriginal: {input_text_4}")
print(f"Punctuated: {punctuated_text_4}")

# # Example 4 (for demonstration of exclamation)
# input_text_5 = "i am kaustubh who are you"
# punctuated_text_5 = restore_punctuation(input_text_5, model, tokenizer, label_list, device)
# print(f"\nOriginal: {input_text_5}")
# print(f"Punctuated: {punctuated_text_5}")


--- Punctuation Restoration Results ---
{'input_ids': tensor([[  101, 10166,  2008,  2003,  6429,  3475,  2102,  2009,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])}
['[CLS]', 'wow', 'that', 'is', 'amazing', 'isn', '##t', 'it', '[SEP]']
Word ids: [0 1 2 3 4 4 5]
Prediction ids [ 1  0  0  1 17  0  3]

Original: wow that is amazing isnt it
Punctuated: Wow, that is amazing, isnt' it?


In [14]:
def restore_punctuation(text: str, model, tokenizer, label_list, device, punctuation_reverse_map) -> str:
    """
    Restores punctuation to an unpunctuated text string, utilizing tokenizer subwords
    and word_ids() for accurate mapping.
    """
    # 1. Tokenize the input text
    words = re.findall(r"\w+|[^\w\s]", text.strip())
    print(words)
    encoded_input = tokenizer(
        words, 
        is_split_into_words=True, 
        return_tensors="pt", 
        padding=True, 
        truncation=True
    ).to(device)
    
    # 2. Extract word IDs and perform inference
    word_ids_list = encoded_input.word_ids() # Maps subword token index to original word index (or None for special tokens)
    tokens = tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][0])
    tokens = tokens[1:-1]
    # Run model inference to get logits/predictions
    with torch.no_grad():
        outputs = model(**encoded_input)
        logits = outputs.logits
    
    # Get the predicted label index (p_id) for each token
    # We take the first element [0] because the input is a batch of size 1
    p_ids = torch.argmax(logits, dim=-1).squeeze().tolist()
    
    p_ids = p_ids[1:-1]
    word_ids_list = word_ids_list[1:-1]
    print(word_ids_list)
    print(p_ids)
    print(tokens)
    
    final_output = []
    
    i = 0
    for t, p in zip(tokens, p_ids):
        punct = punctuation_reverse_map[label_list[p]]
        t = t.strip('#')
        if t == punct:
            continue
        if punct != " " and i < len(word_ids_list) - 1 and word_ids_list[i] != word_ids_list[i+1]:
            punct = punct + " "
        if punct == " " and i < len(word_ids_list) - 1 and word_ids_list[i] == word_ids_list[i+1]:
            punct = ""
        
        if i < len(word_ids_list) - 1 and word_ids_list[i] == word_ids_list[i+1] and p_ids[i] == p_ids[i+1]:
            final_output.append(t)
        else:
            final_output.extend([t, punct])
        i = i + 1
        
    # 6. Final cleanup
    result = "".join(final_output).strip()
    
    if result:
        # Capitalize the first letter
        return result[0].upper() + result[1:]
    return ""

# Assume model, tokenizer, label_list, device, and punctuation_reverse_map are defined
# NOTE: You will need to define or pass the 'punctuation_reverse_map' dictionary for this to run.

In [16]:
input_text_4 = "A six-month-old calf was submitted for examination, showing lameness in all four legs which had been present since soon after birth."
import string
input_text_4 = input_text_4.strip('.')
punctuated_text_4 = restore_punctuation(input_text_4, model, tokenizer, label_list, device, punctuation_reverse_map)
print(f"\nOriginal: {input_text_4}")
print(f"Punctuated: {punctuated_text_4}")

['A', 'six', '-', 'month', '-', 'old', 'calf', 'was', 'submitted', 'for', 'examination', ',', 'showing', 'lameness', 'in', 'all', 'four', 'legs', 'which', 'had', 'been', 'present', 'since', 'soon', 'after', 'birth']
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
[0, 7, 7, 7, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2]
['a', 'six', '-', 'month', '-', 'old', 'calf', 'was', 'submitted', 'for', 'examination', ',', 'showing', 'lame', '##ness', 'in', 'all', 'four', 'legs', 'which', 'had', 'been', 'present', 'since', 'soon', 'after', 'birth']

Original: A six-month-old calf was submitted for examination, showing lameness in all four legs which had been present since soon after birth
Punctuated: A six- month- old calf was submitted for examination , showing lame ness inall four legs, which had been present since soon after birth.


In [8]:
input_text_4 = "good company is in tiruchirapalli i am good how are you"
punctuated_text_4 = restore_punctuation(input_text_4, model, tokenizer, label_list, device, punctuation_reverse_map)
print(f"\nOriginal: {input_text_4}")
print(f"Punctuated: {punctuated_text_4}")

['good', 'company', 'is', 'in', 'tiruchirapalli', 'i', 'am', 'good', 'how', 'are', 'you']
[0, 1, 2, 3, 4, 4, 4, 4, 4, 4, 5, 6, 7, 8, 9, 10]
[0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 0, 2, 0, 0, 3]
['good', 'company', 'is', 'in', 'ti', '##ru', '##chi', '##ra', '##pal', '##li', 'i', 'am', 'good', 'how', 'are', 'you']

Original: good company is in tiruchirapalli i am good how are you
Punctuated: Good company is in tiruchirapalli. i am good. how are you?


In [43]:
def create_token_labels(sentence):
    tokens = []
    labels = []
    parts = re.findall(r"\w+|[^\w\s]", sentence)
    for i, part in enumerate(parts):
        if re.match(r"\w+", part):  # token
            tokens.append(part)
            if i+1 < len(parts) and parts[i+1] in punctuation_map:
                labels.append(punctuation_map[parts[i+1]])
            else:
                labels.append("O")
    return tokens, labels