In [1]:
import matplotlib.pyplot as plt
%matplotlib inline
import os
import json
import pickle
from transformers import AutoTokenizer, BertTokenizer
import torch
import numpy as np
import re

import warnings
warnings.filterwarnings("ignore")

from tqdm.notebook import tqdm

In [2]:
# Root directory of the Szeged corpus and of the generated pickles.
# Override it with the PUNCTUATOR_DATA_PATH environment variable instead of
# editing the notebook; the default is the original location.
# os.path.join(..., "") guarantees a trailing separator, because later cells
# build paths with plain string concatenation (`data_path + "<file>"`).
data_path = os.path.join(
    os.environ.get("PUNCTUATOR_DATA_PATH", "D:/Data/neural-punctuator/szeged/"), ""
)
file_path = data_path + "szeged.txt"

In [3]:
# Read the whole corpus as a list of raw lines; the trailing '\n' on each
# line is kept here and removed later by clean_text.
with open(file_path, mode='r', encoding='utf-8') as f:
    text = list(f)

In [4]:
# Sanity check: number of raw lines and one sample line before cleaning.
(len(text), text[1])

(82100,
 ' A szállásunk egy Balaton melletti kis üdülőfaluban, Zamárdiban volt, a Postának az üdülőházában.\n')

In [5]:
# Uncased multilingual BERT WordPiece tokenizer.
# The torch.hub 'tokenizer' entry point of huggingface/pytorch-transformers is
# a thin wrapper around AutoTokenizer.from_pretrained. Calling it directly
# avoids downloading and executing hubconf.py from an unpinned GitHub branch.
tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-uncased')

Using cache found in C:\Users\gbenc/.cache\torch\hub\huggingface_pytorch-transformers_master


In [6]:
def clean_text(text):
    """Normalize one raw corpus line for punctuation restoration.

    Removes TED-style sound annotations (e.g. " (Laughter)") and '♫'.
    Drops "-- --" / "----" runs, then maps punctuation onto the label set
    used later: '!' -> '.', and ':' / '--' / '-' -> ','.
    Finally it collapses all whitespace runs to a single space, removes
    spaces before commas and strips both ends.

    Args:
        text: a single raw line (may include a trailing newline).

    Returns:
        The cleaned line.
    """
    escape_words = (' (Laughter)', ' (Applause)', ' (Music)', ' (Cheering)', ' (Singing)', ' (Video)')

    for ew in escape_words:
        text = text.replace(ew, '')

    # Must run before the '--' / '-' replacements below. Previously it ran
    # afterwards, when no '-' was left, so it could never match.
    text = re.sub(r'--\s?--', '', text)

    text = text.replace('!', '.')
    text = text.replace(':', ',')
    text = text.replace('--', ',')
    # NOTE(review): this also turns intra-word hyphens (e.g. "1990-ben",
    # "well-known") into commas, producing spurious comma labels; confirm
    # this is intended for this corpus.
    text = text.replace('-', ',')
    text = text.replace(' ,', ',')
    text = text.replace('♫', '')

    text = re.sub(r'\s+', ' ', text)

    # Repeated after whitespace collapsing so spaces that were separated from
    # a comma by other whitespace are removed as well.
    text = text.replace(' ,', ',')

    return text.strip()

In [7]:
# Clean every line (rebinds `text`) and show the same sample line again.
text = list(map(clean_text, text))
text[1]

'A szállásunk egy Balaton melletti kis üdülőfaluban, Zamárdiban volt, a Postának az üdülőházában.'

In [8]:
# Look up the vocabulary ids of '.', '?' and ','; these are the ids
# hard-coded as keys of `id2target` below (the first and last ids are the
# special tokens the tokenizer adds by default).
tokenizer.encode(text=".?,")

[101, 119, 136, 117, 102]

In [17]:
# Raw label -> class index.
# Keys 119/136/117 are the vocabulary ids of '.', '?' and ',' as printed by
# `tokenizer.encode(".?,")` above. The raw label -1 means "no punctuation
# follows this token"; -2 means "masked" and maps to class -1.
id2target = {-1: 0,
             119: 1,  # .
             136: 2,  # ?
             117: 3,  # ,
             -2: -1,  # will be masked
            }
target2id = {value: key for key, value in id2target.items()}


def create_target(encoded, tok=None):
    """Split an encoded id stream into (token ids, punctuation labels).

    Every id that is a key of `id2target` (the punctuation ids) is removed
    from the stream. Each remaining token is labelled with the punctuation
    mark that directly followed it, or class 0 if none did. When several
    marks follow each other, only the last one is kept.

    A token that is followed by a '##' continuation piece gets the masked
    label (-1), so only the last piece of a word carries the real label.
    The first token (the [CLS] position) is always masked, even when the
    text starts with a punctuation mark.

    Args:
        encoded: list of token ids; the first element is kept as-is and
            assumed to be [CLS] (TODO confirm for other callers).
        tok: object providing `convert_ids_to_tokens(int) -> str`. Defaults
            to the notebook-level `tokenizer`.

    Returns:
        (text, targets): the ids without punctuation, and one class label
        per id. Both lists have the same length.
    """
    if tok is None:
        tok = tokenizer

    text = [encoded[0]]
    targets = []
    target = -2  # raw label of the first token; it is forced to masked below

    for word in encoded[1:]:
        if word in id2target:
            # Punctuation: remember it as the label of the previous token.
            target = word
            continue
        if tok.convert_ids_to_tokens(word).startswith('##'):
            # Continuation piece: the previous piece is mid-word, so mask it.
            target = -2
        targets.append(target)
        target = -1
        text.append(word)

    targets.append(target)  # label of the last kept token
    # Leading punctuation would otherwise have overwritten the initial -2.
    targets[0] = -2

    return text, [id2target[t] for t in targets]

In [10]:
# Contiguous line-based split: the first train_n lines are training data,
# the next valid_n lines are validation data and the rest is the test set.
train_n, valid_n = 73_000, 8_500

train_text, valid_text, test_text = (
    ' '.join(text[:train_n]),
    ' '.join(text[train_n:train_n + valid_n]),
    ' '.join(text[train_n + valid_n:]),
)

# Approximate word counts per split.
tuple(len(part.split(' ')) for part in (train_text, valid_text, test_text))

(1119582, 124808, 10817)

In [13]:
# Encode each split as one long id sequence with the default special tokens.
# The tokenizer warns because the sequences exceed the model's 512-token
# limit; presumably they are chunked before reaching the model (confirm in
# the training code).
train_tokens, valid_tokens, test_tokens = [
    tokenizer.encode(split_text) for split_text in (train_text, valid_text, test_text)
]

Token indices sequence length is longer than the specified maximum sequence length for this model (2487559 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (274098 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (25593 > 512). Running this sequence through the model will result in indexing errors


In [18]:
# Remove punctuation ids from each split and build the aligned label lists.
# NOTE: this rebinds *_tokens, so the cell is not re-runnable without first
# re-running the encoding cell.
(train_tokens, train_targets), (valid_tokens, valid_targets), (test_tokens, test_targets) = [
    create_target(ids) for ids in (train_tokens, valid_tokens, test_tokens)
]

In [20]:
# For backward compatibility: wrap each split's single long sequence in a
# one-element list (presumably the consumers expect a list of sequences).
# NOTE: not idempotent; running this cell twice nests the lists again.
train_tokens = [train_tokens]
train_targets = [train_targets]
valid_tokens = [valid_tokens]
valid_targets = [valid_targets]
test_tokens = [test_tokens]
test_targets = [test_targets]

In [24]:
# Persist every split as a pickled (tokens, targets) tuple under data_path.
for file_name, split in (('train_data.pkl', (train_tokens, train_targets)),
                         ('valid_data.pkl', (valid_tokens, valid_targets)),
                         ('test_data.pkl', (test_tokens, test_targets))):
    with open(data_path + file_name, 'wb') as f:
        pickle.dump(split, f)