# Punctation Restoration

In [20]:
from typing import Any, Union, List, Optional, Literal, Dict, Iterator, Iterable, Tuple
from dataclasses import dataclass
from pathlib import Path
import json

import logging

from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torch import nn
import torch.nn.functional as F
import torch

from hydra import compose, initialize_config_module, initialize
from omegaconf import DictConfig

In [16]:
logger = logging.getLogger(__name__)

## Data

In [17]:
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS

In [18]:
ZeroClass = Literal['O']
PunctuationMark = Literal['.', ',', '?', '!', '-', '...']
Label = Union[ZeroClass, PunctuationMark]

@dataclass(frozen=True)
class DatasetItem:
    first_token_pos: List[int]
    token_ids: List[int]
    label_ids: List[int]

    def __post_init__(self):
        assert len(self.token_ids) == len(self.label_ids), \
            'the number of tokens does not match the number of labels'
        assert len(self.first_token_pos) < len(self.token_ids), \
            'there cannot be more words, than subtokens'
        assert all(pos < len(self.token_ids) for pos in self.first_token_pos), \
            'word\'s first token position cannot exceed the number of subtokens'

In [90]:
class PunctuationRestorationDataModule(pl.LightningDataModule):
    PUNCTUATION_MARKS = ['.', ',', '?', '!', '-', '...']

    def __init__(
        self,
        cfg: DictConfig
    ):
        super().__init__()

        self.cfg = cfg

        self.tokenizer_name: str = cfg.model.encoder.tokenizer
        self.tokenizer: Optional[PreTrainedTokenizer] = None

        self.train_batch_size: int = cfg.trainer.train_batch_size
        self.val_batch_size: int = cfg.trainer.val_batch_size
        self.test_batch_size: int = cfg.trainer.test_batch_size
        self.predict_batch_size: int = cfg.trainer.predict_batch_size

        self.train_dataset: Optional[Iterable[DatasetItem]] = None
        self.val_dataset: Optional[Iterable[DatasetItem]] = None
        self.test_dataset: Optional[Iterable[DatasetItem]] = None
        self.predict_dataset: Optional[Iterable[DatasetItem]] = None

        self.idx2label: List[Label] = ['O'] + self.PUNCTUATION_MARKS
        self.label2idx: Dict[Label, int] = {label: idx for idx, label in enumerate(self.idx2label)}
        self.num_classes: int = len(self.idx2label)

    def prepare_data(self) -> None:
        AutoTokenizer.from_pretrained(self.tokenizer_name)

    def _parse_tsv_input(self, filepath: Path) -> Iterable[DatasetItem]:
        pass

    # TODO max_seq_len
    def _parse_json_input(self, filepath: Path) -> Iterable[DatasetItem]:
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        items: List[DatasetItem] = []
        for document in data:
            first_token_pos: List[int] = []
            token_ids: List[int] = []
            label_ids: List[int] = []
            for word in document['words']:
                if len(token_ids) >= self.cfg.trainer.max_seq_len-2:
                    break

                first_token_pos.append(len(token_ids))

                subtokens = self.tokenizer.tokenize(word['word'], )
                token_ids.extend(self.tokenizer.convert_tokens_to_ids(subtokens))

                label_ids.extend([self.label2idx.get(word['punctuation'], 0)] + [0] * (len(subtokens) - 1))

            token_ids = [self.tokenizer.cls_token_id] + token_ids[:self.cfg.trainer.max_seq_len-2] + [self.tokenizer.sep_token_id]
            label_ids = [0] + label_ids[:self.cfg.trainer.max_seq_len-2] + [0]

            items.append(DatasetItem(first_token_pos, token_ids, label_ids))

        return items

    # TODO other formats?

    def setup(self, stage: Optional[str] = None) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)

        if stage is None or stage == 'train':
            if self.cfg.data.train.endswith('.json'):
                self.train_dataset = self._parse_json_input(self.cfg.data.train)
            else:
                raise NotImplementedError('unknown format')

        if stage is None or stage == 'val':
            if self.cfg.data.val.endswith('.json'):
                self.val_dataset = self._parse_json_input(self.cfg.data.val)
            else:
                raise NotImplementedError('unknown format')

        if stage is None or stage == 'test':
            if self.cfg.data.test.endswith('.json'):
                self.test_dataset = self._parse_json_input(self.cfg.data.test)
            else:
                raise NotImplementedError('unknown format')

        if stage is None or stage == 'predict':
            if self.cfg.data.predict.endswith('.json'):
                self.predict_dataset = self._parse_json_input(self.cfg.data.predict)
            else:
                raise NotImplementedError('unknown format')

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        return DataLoader(self.train_dataset,
                          batch_size=self.train_batch_size,
                          shuffle=True,
                          collate_fn=self.collator)

    def val_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(self.val_dataset,
                          batch_size=self.val_batch_size,
                          shuffle=False,
                          collate_fn=self.collator)

    def test_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(self.test_dataset,
                          batch_size=self.test_batch_size,
                          shuffle=False,
                          collate_fn=self.collator)

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(self.predict_dataset,
                          batch_size=self.predict_batch_size,
                          shuffle=False,
                          collate_fn=self.collator)

    def collator(self, batch: Iterable[DatasetItem]) -> Tuple[List[List[int]], List[torch.Tensor], List[torch.Tensor]]:
        first_token_pos = [item.first_token_pos for item in batch]
        token_ids = [item.token_ids for item in batch]
        label_ids = [torch.tensor(item.label_ids) for item in batch]

        padded_token_ids = [self.tokenizer.pad(
            {'input_ids': tokens},
            padding='longest',
            return_tensors='pt'
        ) for tokens in token_ids]

        return first_token_pos, padded_token_ids, label_ids

#### Testing

In [91]:
!ls

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
configs  dataloader.py		 main.py   outputs
data	 developing-model.ipynb  model.py  scripts


In [92]:
with initialize(version_base=None, config_path='configs/'):
    cfg = compose(config_name='config.yaml')
    datamodule = PunctuationRestorationDataModule(cfg)
    datamodule.setup()

    dataloader = datamodule.test_dataloader()

    for item in dataloader:
        first_token_pos, token_ids, label_ids = item

        print(first_token_pos[0], end='\n\n')
        print(token_ids[0], end='\n\n')
        print(label_ids[0], end='\n\n')

        break

[0, 2, 4, 8, 9, 10, 12, 14, 15, 16, 19, 22, 23, 26, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 42, 44, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 62, 63, 64, 65, 66, 67, 68, 71, 72, 73, 74, 75, 77, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125]

{'input_ids': tensor([    0,  2181, 10347,  7990,  1055,  2289,    76,  6425,  4323,  4714,
         2282,  4737,  2178,  7418,  5657,  6105, 13087, 25528,  3453,  1024,
         2003,    93, 17861,  4506,  2003,    93,  2572, 16425,  2040,  5060,
        42135,  2642, 26168,  2408, 49699,  9791,  9843,  1019,  2194,  2254,
        11938,  2181, 10347,  7990,  1055,  2289,    76,  6425,  4323, 13944,
        37028, 35023, 13376, 43704,  5874, 12593, 34122,  1029,  1019,  2289,
           76,  6425,  1997, 14324,  2040, 27743, 19801,  1998, 12093, 25528,
         3453,  1024,

## Model

In [3]:
class RestorationModel(pl.LightningModule):
    def __init__(self, cfg: DictConfig, num_classes: int) -> None:
        super().__init__()

        self.num_classes = num_classes
        self.save_hyperparameters(cfg)

        self.encoder = AutoModel.from_pretrained(cfg.model.encoder.name)
        self.head = self._construct_head()

    def _construct_head(self) -> nn.Module:
        if self.hparams.model.head.architecture == 'mlp':
            mlp_config = self.hparams.model.head.mlp

            head = nn.Sequential()
            prev_layer_dim = self.encoder.config.hidden_size
            for i in range(mlp_config.num_layers - 1):
                layer = nn.Linear(prev_layer_dim, mlp_config.num_hiddens)
                prev_layer_dim = layer.out_features

                if mlp_config.nonlinearity == 'relu':
                    nonlinearity = nn.ReLU()
                else:
                    raise ValueError('unknown nonlinearity')

                head.add_module(f'cls{i}', layer)
                head.add_module(f'cls{i}_nonlinearity', nonlinearity)

            head.add_module('cls_pred', nn.Linear(prev_layer_dim, self.num_classes))

            return head

        else:
            raise ValueError('unknown head architecture')

    def common_step(self, batch):
        first_token_pos, model_inputs_batch, label_ids_batch = batch

        batch_preds = []
        batch_golds = []

        for model_inputs, label_ids in zip(model_inputs_batch, label_ids_batch):
            encoded = self.encoder(
                input_ids=model_inputs['input_ids'].unsqueeze(0),
                attention_mask=model_inputs['attention_mask'].unsqueeze(0),
                return_dict=True
            )['last_hidden_state'].squeeze()

            logits = self.head(encoded)

            batch_preds.append(logits)
            batch_golds.append(label_ids)

        batch_preds_tensor = torch.tensor(batch_preds)
        batch_golds_tensor = torch.tensor(batch_golds)

        loss = F.cross_entropy(batch_preds_tensor.view(-1, batch_preds_tensor.size(-1)),
                               batch_golds_tensor.view(-1))

        return batch_preds, batch_golds, loss

    def training_step(self, batch, batch_idx):
        preds, golds, loss = self.common_step(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        preds, golds, loss = self.common_step(batch)
        self.log('val_loss', loss)
        return preds, golds

    # FIXME shouldn't it be validation_step_end instead? https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#validating-with-dataparallel
    def validation_epoch_end(self, validation_step_outputs) -> None:
        print(validation_step_outputs)

        # TODO softmax
        all_preds = torch.stack(validation_step_outputs)

        # TODO calculate and log metrics


    def test_step(self, batch, batch_idx):
        preds, golds, loss = self.common_step(batch)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.training.learning_rate)
        return optimizer


In [94]:
model_inputs = {'input_ids': torch.tensor([    0,  2181, 10347,  7990,  1055,  2289,    76,  6425,  4323,  4714,
                                      2282,  4737,  2178,  7418,  5657,  6105, 13087, 25528,  3453,  1024,
                                      2003,    93, 17861,  4506,  2003,    93,  2572, 16425,  2040,  5060,
                                      42135,  2642, 26168,  2408, 49699,  9791,  9843,  1019,  2194,  2254,
                                      11938,  2181, 10347,  7990,  1055,  2289,    76,  6425,  4323, 13944,
                                      37028, 35023, 13376, 43704,  5874, 12593, 34122,  1029,  1019,  2289,
                                      76,  6425,  1997, 14324,  2040, 27743, 19801,  1998, 12093, 25528,
                                      3453,  1024,  5114,  2022,  2408, 49699,  9791,  9843, 10387,  8905,
                                      13789, 12593,  1046,  3131,  2431,  2291,  1998,  2553,  2291,  3420,
                                      20488,  1998,  6530,  3723, 25062, 29592,  2944, 12462,  1009,  3938,
                                      11938, 14324,  2040, 13171, 31267,  1998,  2194, 12093,  9460,  2022,
                                      1019,  5235,  6105, 13087, 15922,  2040,  6012, 15895, 12593, 32710,
                                      2413,  1998,  9302, 11351,  8112,  2634,  3695,     2]), 'attention_mask': torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                                                                                                         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                                                                                                         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                                                                                                         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                                                                                                         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                                                                                                         1, 1, 1, 1, 1, 1, 1, 1])}


In [75]:
encoder = AutoModel.from_pretrained('allegro/herbert-base-cased')

Downloading pytorch_model.bin:   0%|          | 0.00/624M [00:00<?, ?B/s]

Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.sso.sso_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.sso.sso_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [78]:
token_ids = token_ids.unsqueeze(0)

In [99]:
r = encoder(token_ids[0]['input_ids'].unsqueeze(0), return_dict=True)

In [102]:
r.keys()

odict_keys(['last_hidden_state', 'pooler_output'])

In [106]:
r['last_hidden_state'].size(), r['pooler_output'].size()

(torch.Size([1, 128, 768]), torch.Size([1, 768]))

In [100]:
len(r)

2

In [86]:
r[0].shape, r[1].shape

(torch.Size([1, 128, 768]), torch.Size([1, 768]))

In [88]:
r[0]

tensor([[[-0.0703, -0.0108, -0.0752,  ..., -0.7747,  0.0734,  0.2588],
         [-0.3318,  0.2007, -0.4598,  ...,  0.6457,  0.1525,  0.0627],
         [ 0.3860,  0.2852, -0.7873,  ..., -1.2532,  0.2973,  1.0523],
         ...,
         [ 0.1820, -0.0097, -0.1943,  ..., -0.4040,  0.1927, -0.1603],
         [ 0.2354, -0.1440, -0.0039,  ..., -0.9941, -0.1513, -0.6591],
         [ 0.3854,  0.2836, -0.7765,  ..., -1.2490,  0.3054,  1.0535]]],
       grad_fn=<NativeLayerNormBackward>)

In [89]:
type(encoder)

transformers.models.bert.modeling_bert.BertModel

In [108]:
r['last_hidden_state'].squeeze().size()

torch.Size([128, 768])

In [115]:
r['last_hidden_state'].size(-1)

768

In [117]:
r['last_hidden_state'].view(-1).size()

torch.Size([98304])