In [1]:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
__author__ = 'Shining'
__email__ = 'shining.shi@alibaba-inc.com'

# POS-Tagger-for-Punctuation-Restoration
Demo for the paper [Incorporating External POS Tagger for Punctuation Restoration](https://arxiv.org/abs/2106.06731) in the proceedings of the [*2021 Conference of the International Speech Communication Association (INTERSPEECH)*](https://www.interspeech2021.org/).

***Note***
1. Please download a sample checkpoint from ***[here](https://drive.google.com/drive/folders/1mJ0ReoAToiENozPLhygJiX2FwVjcBNl7?usp=sharing)*** and save it as ***main/res/check_points/funnel-transformer-xlarge/xfmr/pretrained/random/linear.pt***.
2. The file path of the checkpoint depends on your ***config***.
3. For errors regarding ipywidgets, please try:  
```Python
pip install ipywidgets
```
4. The first run takes a while, because the flair POS tagger and the language model have to be downloaded.
5. The download progress bar may not render in Jupyter; if so, try running train.py first.
6. The model in the original paper was trained and evaluated with a maximum sequence length of 256, so the demo results may vary with the sequence length of your test samples.

# Dependencies

In [2]:
# public
from flair.data import Sentence
import torch
from torch.utils import data as torch_data
from tqdm import tqdm
# private
from train import Restorer
from src.utils import pipeline

# Helper Functions

In [3]:
class Dataset(torch_data.Dataset):
    """Map-style dataset over pre-tokenized sequences.

    Item ``idx`` is the triple ``(xs[idx], y_masks[idx], y_tags[idx])``.
    The three lists are stored by reference, not copied.
    """

    def __init__(self, xs, y_masks, y_tags):
        super().__init__()
        self.xs, self.y_masks, self.y_tags = xs, y_masks, y_tags
        # The size is fixed at construction time from ``xs``.
        self.data_size = len(xs)

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        return tuple(seq[idx] for seq in (self.xs, self.y_masks, self.y_tags))

def collate_fn(data):
    """Pad, mask and id-encode one batch; used as the DataLoader's ``collate_fn``.

    Relies on the notebook-global ``re`` (the ``Restorer`` instance) for its
    config, tokenizer and POS tag dictionary.

    Args:
        data: list of ``(x, y_mask, y_tag)`` triples as returned by
            ``Dataset.__getitem__``: sub-word tokens, 0/1 word-end flags and
            POS tag strings of one sequence.

    Returns:
        ``(raw_xs, raw_y_masks, raw_y_tags)``: the batch items as received
        (unpadded, in batch order).
        ``(xs, x_masks, y_masks, y_tags)``: the same items padded up to
        ``re.config.max_seq_len``, with tokens converted to ids by the
        tokenizer, tags converted to ids by the tag dictionary, and
        ``x_masks`` set to 0 on padding positions.
    """
    config = re.config
    # Batch order is kept as received so predictions line up with the inputs.
    raw_xs, raw_y_masks, raw_y_tags = zip(*data)
    xs, x_masks, y_masks, y_tags = [], [], [], []
    for raw_x, raw_y_mask, raw_y_tag in zip(raw_xs, raw_y_masks, raw_y_tags):
        pad_len = max(config.max_seq_len - len(raw_x), 0)
        # Build fresh lists; padding the originals in place would modify the
        # lists held by the Dataset itself.
        x = list(raw_x) + [config.PAD_TOKEN] * pad_len
        y_mask = list(raw_y_mask) + [0] * pad_len
        y_tag = list(raw_y_tag) + [config.X_TAG] * pad_len
        x_masks.append([0 if token == config.PAD_TOKEN else 1 for token in x])
        xs.append(re.tokenizer.convert_tokens_to_ids(x))
        y_masks.append(y_mask)
        y_tags.append(re.pos_tagger.tag_dictionary.get_idx_for_items(y_tag))
    return (raw_xs, raw_y_masks, raw_y_tags), (xs, x_masks, y_masks, y_tags)

def translate(seq: list, trans_dict: dict) -> list:
    """Map every element of ``seq`` through ``trans_dict``, preserving order.

    Raises:
        KeyError: if an element of ``seq`` has no entry in ``trans_dict``.
    """
    return list(map(trans_dict.__getitem__, seq))

def post_process(xs, x_masks, y_masks, ys_, tokenizer, config):
    """Turn one model batch back into readable, unpadded Python lists.

    Args:
        xs: token-id tensor of the batch.
        x_masks: 0/1 mask tensor; each row's sum is the sequence's real length.
        y_masks: word-end mask; anything with ``.tolist()`` (a tensor here).
        ys_: model scores; the arg-max over dim 2 is the predicted label id.
        tokenizer: provides ``convert_ids_to_tokens``.
        config: provides ``idx2label_dict`` (label id -> label name).

    Returns:
        ``(xs, y_masks, ys_)`` as lists of per-sequence lists (tokens, mask
        flags, label names), each cut to its sequence's real length.
    """
    def to_list(tensor):
        return tensor.cpu().detach().numpy().tolist()

    id_rows = to_list(xs)
    mask_rows = to_list(x_masks)
    pred_rows = to_list(torch.argmax(ys_, dim=2))
    lengths = list(map(sum, mask_rows))

    trimmed_ids = [row[:n] for row, n in zip(id_rows, lengths)]
    word_masks = [row[:n] for row, n in zip(y_masks.tolist(), lengths)]
    trimmed_preds = [row[:n] for row, n in zip(pred_rows, lengths)]

    tokens = [tokenizer.convert_ids_to_tokens(row) for row in trimmed_ids]
    labels = [translate(row, config.idx2label_dict) for row in trimmed_preds]
    return tokens, word_masks, labels

def restore_pun(x, y_mask, y, normal_token=None):
    """Interleave predicted punctuation labels into a token sequence.

    After every token whose ``y_mask`` flag is truthy and whose predicted
    label differs from ``normal_token``, the label itself is inserted.

    Args:
        x: list of tokens.
        y_mask: per-token flags aligned with ``x``; only flagged positions
            may receive punctuation.
        y: per-token predicted labels aligned with ``x``.
        normal_token: label meaning "no punctuation"; defaults to
            ``re.config.NORMAL_TOKEN`` of the notebook-global restorer.

    Returns:
        A new list; ``x`` itself is left unchanged.
    """
    if normal_token is None:
        normal_token = re.config.NORMAL_TOKEN
    restored = []
    # Append into a separate list so that position i in x, y_mask and y keeps
    # referring to the same token throughout the loop.
    for i, token in enumerate(x):
        restored.append(token)
        if y_mask[i] and y[i] != normal_token:
            restored.append(y[i])
    return restored

# Demo
## Initialization

In [4]:
# setup model class
# NOTE(review): the name `re` shadows Python's stdlib `re` module for the rest of
# this notebook. Later cells and helpers (collate_fn, restore_pun, the pipeline
# cells) read `re.config`, `re.tokenizer` and `re.pos_tagger` from this instance,
# so a rename has to be applied to all of them at once.
re = Restorer()

2021-08-14 03:55:05,040 loading file /Users/shining/.flair/models/upos-english-fast/b631371788604e95f27b6567fe7220e4a7e8d03201f3d862e6204dbf90f9f164.0afb95b43b32509bf4fcc3687f7c64157d8880d08f813124c1bd371c3d8ee3f7


In [5]:
# restore model from checkpoint
model = pipeline.pick_model(re.config)
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint_to_load = torch.load(re.config.SAVE_POINT, map_location=re.config.device)
model.load_state_dict(checkpoint_to_load['model'])
# Input batches are sent to re.config.device during inference, so make sure the
# weights live there too (a no-op if pick_model already placed the model).
model = model.to(re.config.device)
model.eval()
print('Model restored from {}.'.format(re.config.SAVE_POINT))

Model restored from /Users/shining/Library/Mobile Documents/com~apple~CloudDocs/Desktop/ShiningLab/project/Punctuation Restoration/exp/main/res/check_points/funnel-transformer-xlarge/xfmr/pretrained/random/linear.pt.


## Data

In [6]:
# test cases: lower-cased sentences without punctuation, written as adjacent
# string literals (Python joins them with no separator, hence trailing spaces).
x1 = (
    "it can be a very complicated thing the ocean "
    "and it can be a very complicated thing what human health is and "
    "bringing those two together might seem a very daunting task"
)

x2 = (
    "i 'm as a font or more precisely a high-functioning autistic "
    "savant it 's a rare condition and rarer still when accompanied as in "
    "my case by self-awareness and a mastery of language very often when i "
    "meet someone and they've learned is about me there 's a certain kind of awkwardness"
)

raw_xs = [x1, x2]

In [7]:
# tokenization: split every sentence on whitespace, then each word into sub-word
# tokens. y_mask is 1 on the LAST sub-word of each word and 0 on the others.
xs, y_masks = [], []
for raw_x in raw_xs:
    x, y_mask = [], []
    for word in raw_x.split():
        tokens = re.tokenizer.tokenize(word)
        if not tokens:
            # The tokenizer may return no pieces for a word (e.g. one made only
            # of characters it strips); skip it rather than fail on tokens[-1].
            continue
        x.extend(tokens)
        y_mask.extend([0] * (len(tokens) - 1) + [1])
    xs.append(x)
    y_masks.append(y_mask)

In [8]:
# POS tagging: run the flair tagger over each sub-word sequence. The predicted tag
# is kept on word-final tokens (y_mask == 1); every other position gets X_TAG.
y_tags = []
for x, y_mask in zip(xs, y_masks):
    sent = Sentence(x)
    re.pos_tagger.predict(sent)
    tags = [span.tag for span in sent.get_spans('pos')]
    # zip() below would silently truncate on a length mismatch and misalign the
    # tags with the tokens, so fail loudly instead.
    assert len(tags) == len(y_mask), \
        'POS tagger returned {} tags for {} tokens'.format(len(tags), len(y_mask))
    y_tags.append([tag if m else re.config.X_TAG for tag, m in zip(tags, y_mask)])

In [9]:
# pre-processing: wrap every sequence as BOS ... EOS. Both boundary tokens are
# never punctuation targets (mask 0) and carry the placeholder X_TAG.
# NOTE(review): this rebinds xs / y_masks / y_tags from their own current values,
# so running the cell a second time wraps the sequences twice.
data = (xs, y_masks, y_tags)
wrapped_xs, wrapped_y_masks, wrapped_y_tags = [], [], []
for tokens, masks, tags in zip(*data):
    # zip() keeps only the length common to the three sequences.
    body = list(zip(tokens, masks, tags))
    wrapped_xs.append(
        [re.config.BOS_TOKEN] + [tok for tok, mask, tag in body] + [re.config.EOS_TOKEN])
    wrapped_y_masks.append([0] + [mask for tok, mask, tag in body] + [0])
    wrapped_y_tags.append(
        [re.config.X_TAG] + [tag for tok, mask, tag in body] + [re.config.X_TAG])
xs, y_masks, y_tags = wrapped_xs, wrapped_y_masks, wrapped_y_tags

In [10]:
dataset = Dataset(xs, y_masks, y_tags)
# Order-preserving loader (no shuffling, last partial batch kept) so every
# input sentence gets exactly one prediction, in input order.
loader_kwargs = dict(
    batch_size=re.config.batch_size,
    collate_fn=collate_fn,
    shuffle=False,
    num_workers=re.config.num_workers,
    pin_memory=re.config.pin_memory,
    drop_last=False,
)
dataloader = torch_data.DataLoader(dataset, **loader_kwargs)

## Inference

In [11]:
all_xs, all_ys, all_y_masks, all_ys_ = [], [], [], []
dl = tqdm(dataloader)
with torch.no_grad():
    for data_pair in dl:
        raw_data, data = data_pair
        xs, x_masks, y_masks, y_tags = (torch.LongTensor(_).to(re.config.device) for _ in data)
        ys_ = model(xs, x_masks, y_tags)
        
#         print(raw_data[0][0])
#         print(len(xs[0])) 
#         print(re.tokenizer.convert_ids_to_tokens(xs[0]))
#         print(x_masks[0])
#         print([re.config.idx2label_dict[label] for label in ys[0].tolist()])
#         print(y_masks[0])
#         print(pipeline.translate(y_tags[0], re.pos_tagger.tag_dictionary.idx2item))
        
        xs, y_masks, ys_ = post_process(xs, x_masks, y_masks, ys_, re.tokenizer, re.config)
        all_xs += xs
        all_y_masks += y_masks
        all_ys_ += ys_

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.17s/it]


In [12]:
# Print, for every test sentence, its sub-word tokens (x), word-end mask
# (y_mask) and predicted punctuation labels (y_).
for i, tokens in enumerate(all_xs):
    print('x: {}'.format(tokens))
    print('y_mask: {}'.format(all_y_masks[i]))
    print('y_: {}'.format(all_ys_[i]))
    print()

x: ['<s>', 'it', 'can', 'be', 'a', 'very', 'complicated', 'thing', 'the', 'ocean', 'and', 'it', 'can', 'be', 'a', 'very', 'complicated', 'thing', 'what', 'human', 'health', 'is', 'and', 'bringing', 'those', 'two', 'together', 'might', 'seem', 'a', 'very', 'da', '##unt', '##ing', 'task', '</s>']
y_mask: [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0]
y_: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'COMMA', 'O', 'COMMA', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'COMMA', 'O', 'O', 'O', 'PERIOD', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'PERIOD', 'O']

x: ['<s>', 'i', "'", 'm', 'as', 'a', 'font', 'or', 'more', 'precisely', 'a', 'high', '-', 'functioning', 'au', '##tist', '##ic', 'sava', '##nt', 'it', "'", 's', 'a', 'rare', 'condition', 'and', 'rare', '##r', 'still', 'when', 'accompanied', 'as', 'in', 'my', 'case', 'by', 'self', '-', 'awareness', 'and', 'a', 'mastery', 'of', 'language', 'very', 'often', 'when', 'i', 'meet', '

In [13]:
# Token-aligned view: a header per test, then one "<token>\t<label>" line per position.
for idx, tokens in enumerate(all_xs):
    print('***Test {}***'.format(idx))
    labels = all_ys_[idx]
    for word, pun in zip(tokens, labels):
        print('{}\t{}'.format(word, pun))
    print()

***Test 0***
<s>	O
it	O
can	O
be	O
a	O
very	O
complicated	O
thing	COMMA
the	O
ocean	COMMA
and	O
it	O
can	O
be	O
a	O
very	O
complicated	O
thing	COMMA
what	O
human	O
health	O
is	PERIOD
and	O
bringing	O
those	O
two	O
together	O
might	O
seem	O
a	O
very	O
da	O
##unt	O
##ing	O
task	PERIOD
</s>	O

***Test 1***
<s>	O
i	O
'	O
m	COMMA
as	O
a	O
font	COMMA
or	O
more	O
precisely	COMMA
a	O
high	O
-	O
functioning	O
au	O
##tist	O
##ic	O
sava	O
##nt	PERIOD
it	O
'	O
s	O
a	O
rare	O
condition	COMMA
and	O
rare	O
##r	O
still	O
when	O
accompanied	COMMA
as	O
in	O
my	O
case	COMMA
by	O
self	O
-	O
awareness	O
and	O
a	O
mastery	O
of	O
language	PERIOD
very	O
often	O
when	O
i	O
meet	O
someone	O
and	O
they	O
'	O
ve	O
learned	O
is	O
about	O
me	COMMA
there	O
'	O
s	O
a	O
certain	O
kind	O
of	O
awkward	O
##ness	O
</s>	O

