# Punctation Restoration

In [39]:
from typing import Any, Union, List, Optional, Literal, Dict, Iterator, Iterable, Tuple
from dataclasses import dataclass
from pathlib import Path
import json, sys, os

import logging

import numpy as np
import numpy.typing as npt
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torch import nn
import torch.nn.functional as F
import torch

from sklearn.metrics import precision_recall_fscore_support

from hydra import compose, initialize
from omegaconf import DictConfig

In [41]:
np.set_printoptions(threshold=sys.maxsize)
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

In [7]:
logger = logging.getLogger('lightning')

## Data

In [8]:
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS

In [9]:
ZeroClass = Literal['O']
PunctuationMark = Literal['.', ',', '?', '!', '-', '...']
Label = Union[ZeroClass, PunctuationMark]

ValidationEpochOutputs = Union[
    List[
        Union[torch.Tensor, Dict[str, Any]]
    ],

    List[List[
        Union[torch.Tensor, Dict[str, Any]]
    ]]
]

@dataclass(frozen=True)
class DatasetItem:
    first_token_pos: List[int]
    token_ids: List[int]
    label_ids: List[int]

    def __post_init__(self):
        assert len(self.token_ids) == len(self.label_ids), \
            'the number of tokens does not match the number of labels'
        assert len(self.first_token_pos) < len(self.token_ids), \
            'there cannot be more words, than subtokens'
        assert all(pos < len(self.token_ids) for pos in self.first_token_pos), \
            'word\'s first token position cannot exceed the number of subtokens'

In [34]:
class PunctuationRestorationDataModule(pl.LightningDataModule):
    PUNCTUATION_MARKS = ['.', ',', '?', '!', '-', '...']

    def __init__(
        self,
        cfg: DictConfig
    ):
        super().__init__()

        self.cfg = cfg

        self.tokenizer_name: str = cfg.model.encoder.tokenizer
        self.tokenizer: Optional[PreTrainedTokenizer] = None

        self.train_batch_size: int = cfg.trainer.train_batch_size
        self.val_batch_size: int = cfg.trainer.val_batch_size
        self.test_batch_size: int = cfg.trainer.test_batch_size
        self.predict_batch_size: int = cfg.trainer.predict_batch_size

        self.train_dataset: Optional[Iterable[DatasetItem]] = None
        self.val_dataset: Optional[Iterable[DatasetItem]] = None
        self.test_dataset: Optional[Iterable[DatasetItem]] = None
        self.predict_dataset: Optional[Iterable[DatasetItem]] = None

        self.idx2label: List[Label] = ['O'] + self.PUNCTUATION_MARKS
        self.label2idx: Dict[Label, int] = {label: idx for idx, label in enumerate(self.idx2label)}
        self.num_classes: int = len(self.idx2label)

    def prepare_data(self) -> None:
        tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        del tokenizer

    def _parse_tsv_input(self, filepath: Path) -> Iterable[DatasetItem]:
        pass

    # TODO max_seq_len
    def _parse_json_input(self, filepath: Path) -> Iterable[DatasetItem]:
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        items: List[DatasetItem] = []
        for document in data[:100]:
            first_token_pos: List[int] = []
            token_ids: List[int] = []
            label_ids: List[int] = []
            for word in document['words']:
                if len(token_ids) >= self.cfg.trainer.max_seq_len-2:
                    break

                first_token_pos.append(len(token_ids))

                subtokens = self.tokenizer.tokenize(word['word'], )
                token_ids.extend(self.tokenizer.convert_tokens_to_ids(subtokens))

                label_ids.extend([self.label2idx.get(word['punctuation'], 0)] + [0] * (len(subtokens) - 1))

            token_ids = [self.tokenizer.cls_token_id] + token_ids[:self.cfg.trainer.max_seq_len-2] + [self.tokenizer.sep_token_id]
            label_ids = [0] + label_ids[:self.cfg.trainer.max_seq_len-2] + [0]

            items.append(DatasetItem(first_token_pos, token_ids, label_ids))

        return items

    # TODO other formats?

    def setup(self, stage: Optional[str] = None) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)

        if stage is None or stage == 'train':
            if self.cfg.data.train.endswith('.json'):
                self.train_dataset = self._parse_json_input(self.cfg.data.train)
            else:
                raise NotImplementedError('unknown format')

        if stage is None or stage == 'val':
            if self.cfg.data.val.endswith('.json'):
                self.val_dataset = self._parse_json_input(self.cfg.data.val)
            else:
                raise NotImplementedError('unknown format')

        if stage is None or stage == 'test':
            if self.cfg.data.test.endswith('.json'):
                self.test_dataset = self._parse_json_input(self.cfg.data.test)
            else:
                raise NotImplementedError('unknown format')

        if stage is None or stage == 'predict':
            if self.cfg.data.predict.endswith('.json'):
                self.predict_dataset = self._parse_json_input(self.cfg.data.predict)
            else:
                raise NotImplementedError('unknown format')

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        return DataLoader(self.train_dataset,
                          batch_size=self.train_batch_size,
                          shuffle=True,
                          collate_fn=self.collator,
                          num_workers=6)

    def val_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(self.val_dataset,
                          batch_size=self.val_batch_size,
                          shuffle=False,
                          collate_fn=self.collator,
                          num_workers=3)

    def test_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(self.test_dataset,
                          batch_size=self.test_batch_size,
                          shuffle=False,
                          collate_fn=self.collator,
                          num_workers=3)

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(self.predict_dataset,
                          batch_size=self.predict_batch_size,
                          shuffle=False,
                          collate_fn=self.collator,
                          num_workers=3)

    def collator(self, batch: Iterable[DatasetItem]) -> Tuple[List[List[int]], List[torch.Tensor], List[torch.Tensor]]:
        first_token_pos = [item.first_token_pos for item in batch]
        token_ids = [item.token_ids for item in batch]
        label_ids = [torch.tensor(item.label_ids) for item in batch]

        padded_token_ids = [self.tokenizer.pad(
            {'input_ids': tokens},
            padding='longest',
            return_tensors='pt'
        ) for tokens in token_ids]

        return first_token_pos, padded_token_ids, label_ids

#### Testing

In [29]:
!ls

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
configs        datatypes.py	       main.py	 __pycache__
data	       developing-model.ipynb  model.py  requirements.txt
datamodule.py  lightning_logs	       outputs	 scripts


In [12]:
with initialize(version_base=None, config_path='configs/'):
    cfg = compose(config_name='config.yaml')
    datamodule = PunctuationRestorationDataModule(cfg)
    datamodule.setup()

    dataloader = datamodule.test_dataloader()

    for item in dataloader:
        first_token_pos, token_ids, label_ids = item

        print(first_token_pos[0], end='\n\n')
        print(token_ids[0], end='\n\n')
        print(label_ids[0], end='\n\n')

        break

[0, 2, 4, 8, 9, 10, 12, 14, 15, 16, 19, 22, 23, 26, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 42, 44, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 62, 63, 64, 65, 66, 67, 68, 71, 72, 73, 74, 75, 77, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125]

{'input_ids': tensor([    0,  2181, 10347,  7990,  1055,  2289,    76,  6425,  4323,  4714,
         2282,  4737,  2178,  7418,  5657,  6105, 13087, 25528,  3453,  1024,
         2003,    93, 17861,  4506,  2003,    93,  2572, 16425,  2040,  5060,
        42135,  2642, 26168,  2408, 49699,  9791,  9843,  1019,  2194,  2254,
        11938,  2181, 10347,  7990,  1055,  2289,    76,  6425,  4323, 13944,
        37028, 35023, 13376, 43704,  5874, 12593, 34122,  1029,  1019,  2289,
           76,  6425,  1997, 14324,  2040, 27743, 19801,  1998, 12093, 25528,
         3453,  1024,

## Model

In [68]:
class RestorationModel(pl.LightningModule):
    def __init__(self, cfg: DictConfig, num_classes: int) -> None:
        super().__init__()

        self.num_classes = num_classes
        self.save_hyperparameters(cfg)

        self.loss_weights = torch.tensor([self.hparams.trainer.zero_class_weight] + [1.0 for _ in range(self.num_classes - 1)])

        self.encoder = AutoModel.from_pretrained(cfg.model.encoder.name)
        self.head = self._construct_head()

        self.head.apply(self._init_weights)

    @staticmethod
    def _init_weights(module: nn.Module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform(module.weight)
            module.bias.data.fill_(0.01)

    def _construct_head(self) -> nn.Module:
        if self.hparams.model.head.architecture == 'mlp':
            mlp_config = self.hparams.model.head.mlp

            head = nn.Sequential()
            prev_layer_dim = self.encoder.config.hidden_size
            for i in range(mlp_config.num_layers - 1):
                layer = nn.Linear(prev_layer_dim, mlp_config.num_hiddens)
                prev_layer_dim = layer.out_features

                if mlp_config.nonlinearity == 'relu':
                    nonlinearity = nn.ReLU()
                else:
                    raise ValueError('unknown nonlinearity')

                head.add_module(f'cls{i}', layer)
                head.add_module(f'cls{i}_nonlinearity', nonlinearity)

            head.add_module('cls_pred', nn.Linear(prev_layer_dim, self.num_classes))

            return head

        else:
            raise ValueError('unknown head architecture')

    def _calculate_metrics(self, preds: npt.NDArray[np.int], golds: npt.NDArray[np.int]) -> Dict[str, Any]:
        labels = list(range(1, self.trainer.datamodule.num_classes))
        pr_micro, rc_micro, f1_micro, _ = \
            precision_recall_fscore_support(golds, preds, average='micro', zero_division=0, labels=labels)
        pr_macro, rc_macro, f1_macro, _ = \
            precision_recall_fscore_support(golds, preds, average='macro', zero_division=0, labels=labels)
        pr_weighted, rc_weighted, f1_weighted, _ = \
            precision_recall_fscore_support(golds, preds, average='weighted', zero_division=0, labels=labels)
        pr_per_label, rc_per_label, f1_per_label, _ = \
            precision_recall_fscore_support(golds, preds, average=None, zero_division=0, labels=labels)

        metrics = {
                f'pr_{self.trainer.datamodule.idx2label[key]}': val
                for key, val in zip(range(1, self.trainer.datamodule.num_classes), pr_per_label)
            } | {
                f'rc_{self.trainer.datamodule.idx2label[key]}': val
                for key, val in zip(range(1, self.trainer.datamodule.num_classes), rc_per_label)
            } | {
                f'f1_{self.trainer.datamodule.idx2label[key]}': val
                for key, val in zip(range(1, self.trainer.datamodule.num_classes), f1_per_label)
        }

        metrics.update({
            'pr_micro': pr_micro,
            'rc_micro': rc_micro,
            'f1_micro': f1_micro,
            'pr_macro': pr_macro,
            'rc_macro': rc_macro,
            'f1_macro': f1_macro,
            'pr_weighted': pr_weighted,
            'rc_weighted': rc_weighted,
            'f1_weighted': f1_weighted,
        })
        return metrics

    def common_step(self, batch) -> (List[torch.Tensor], List[torch.Tensor], float):
        first_token_pos, model_inputs_batch, label_ids_batch = batch

        batch_preds = []
        batch_golds = []
        loss = 0

        for model_inputs, label_ids in zip(model_inputs_batch, label_ids_batch):
            encoded = self.encoder(
                input_ids=model_inputs['input_ids'].unsqueeze(0),
                attention_mask=model_inputs['attention_mask'].unsqueeze(0),
                return_dict=True
            )['last_hidden_state'].squeeze()

            logits = self.head(encoded)

            batch_preds.append(logits)
            batch_golds.append(label_ids)

            loss += F.cross_entropy(logits, label_ids, weight=self.loss_weights)

        return batch_preds, batch_golds, loss

    def training_step(self, batch, batch_idx):
        preds, golds, loss = self.common_step(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return {
            'loss': loss
        }

    def validation_step(self, batch, batch_idx):
        preds, golds, loss = self.common_step(batch)
        self.log('val_loss', loss)
        return {
            'preds': preds,
            'golds': golds
        }

    # FIXME shouldn't it be validation_step_end instead? https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#validating-with-dataparallel
    def validation_epoch_end(self, validation_step_outputs: ValidationEpochOutputs) -> None:
        # TODO as a future feature we can aggregate results for each dataloader separately
        # flattening outputs from dataloaders if there are multiple
        if validation_step_outputs and isinstance(validation_step_outputs[0], list):
            outputs = [
                step_output
                for dataloader_output in validation_step_outputs
                for step_output in dataloader_output
            ]
        else:
            outputs = validation_step_outputs

        preds_logits = torch.cat([pred for step_output in outputs for pred in step_output['preds']])
        preds: torch.Tensor = torch.max(preds_logits, dim=1).indices
        golds: torch.Tensor = torch.cat([gold for step_output in outputs for gold in step_output['golds']])

        preds_numpy = preds.cpu().numpy()
        golds_numpy = golds.cpu().numpy()

        # rank_zero_info(str(preds_numpy))
        # rank_zero_info(str(golds_numpy))
        rank_zero_info(f'tp: {np.sum(np.where(golds_numpy != 0, preds_numpy == golds_numpy, False))}')
        rank_zero_info(f'nonzero count:, {np.count_nonzero(preds_numpy)}')
        rank_zero_info(f'preds_logits: {preds_logits[0]}')

        metrics = self._calculate_metrics(preds_numpy, golds_numpy)
        self.log_dict(metrics, on_epoch=True)

    def test_step(self, batch, batch_idx):
        preds, golds, loss = self.common_step(batch)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.trainer.learning_rate)
        return optimizer

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  def _calculate_metrics(self, preds: npt.NDArray[np.int], golds: npt.NDArray[np.int]) -> Dict[str, Any]:


## Callbacks

In [14]:
from pytorch_lightning.utilities import rank_zero_info

In [15]:
class MetricsLoggingCallback(pl.Callback):
    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        micro_recall = trainer.callback_metrics['rc_micro']
        micro_precision = trainer.callback_metrics['pr_micro']
        micro_f1 = trainer.callback_metrics['f1_micro']

        weighted_recall = trainer.callback_metrics['rc_micro']
        weighted_precision = trainer.callback_metrics['pr_micro']
        weighted_f1 = trainer.callback_metrics['f1_weighted']

        rank_zero_info(f'micro // f1: {100*micro_f1:.2f}, recall: {100*micro_recall:.2f}, precision: {100*micro_precision:.2f}')
        rank_zero_info(f'weighted // f1: {100*weighted_f1:.2f}, recall: {100*weighted_recall:.2f}, precision: {100*weighted_precision:.2f}')

## Training

In [69]:
with initialize(version_base=None, config_path='configs/'):
    cfg = compose(config_name='config.yaml')

In [70]:
datamodule = PunctuationRestorationDataModule(cfg)
datamodule.setup()

                                                                   

In [71]:
model = RestorationModel(cfg, num_classes=7)

Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.sso.sso_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.sso.sso_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  torch.nn.init.xavier_uniform(module.weight)


In [None]:
trainer = pl.Trainer(accelerator='cpu', devices=1, max_epochs=20, callbacks=[MetricsLoggingCallback()])
trainer.fit(model, datamodule=datamodule)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name    | Type       | Params
---------------------------------------
0 | encoder | BertModel  | 124 M 
1 | head    | Sequential | 99.3 K
---------------------------------------
124 M     Trainable params
0         Non-trainable params
124 M     Total params
498.169   Total estimated model params size (MB)


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:01<00:01,  1.34s/it]

  value = torch.tensor(value, device=self.device)


Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:02<00:00,  1.26s/it]

tp: 21
nonzero count:, 3567
preds_logits: tensor([-0.5580, -0.3248,  0.7525,  1.0948, -2.0867, -1.1107,  0.2378])
micro // f1: 1.06, recall: 5.41, precision: 0.59
weighted // f1: 4.03, recall: 5.41, precision: 0.59


                                                                           

  rank_zero_warn(


Epoch 0:   0%|          | 0/14 [00:00<?, ?it/s] 

  value = torch.tensor(value, device=self.device)


Epoch 0:  50%|█████     | 7/14 [00:54<00:54,  7.75s/it, loss=1.28e+03, v_num=62, train_loss_step=310.0]  
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/7 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/7 [00:00<?, ?it/s][A
Epoch 0:  57%|█████▋    | 8/14 [00:56<00:42,  7.02s/it, loss=1.28e+03, v_num=62, train_loss_step=310.0]
Epoch 0:  64%|██████▍   | 9/14 [00:57<00:31,  6.38s/it, loss=1.28e+03, v_num=62, train_loss_step=310.0]
Epoch 0:  71%|███████▏  | 10/14 [00:58<00:23,  5.87s/it, loss=1.28e+03, v_num=62, train_loss_step=310.0]
Epoch 0:  79%|███████▊  | 11/14 [01:00<00:16,  5.50s/it, loss=1.28e+03, v_num=62, train_loss_step=310.0]
Epoch 0:  86%|████████▌ | 12/14 [01:01<00:10,  5.15s/it, loss=1.28e+03, v_num=62, train_loss_step=310.0]
Epoch 0:  93%|█████████▎| 13/14 [01:03<00:04,  4.87s/it, loss=1.28e+03, v_num=62, train_loss_step=310.0]
Epoch 0: 100%|██████████| 14/14 [01:03<00:00,  4.56s/it, loss=1.28e+03, v_num=62, train_loss_step=310.0]

tp: 0
nonzero count:, 0
preds_logits: tensor([  88.6416,   85.6447,   22.0767, -124.7560, -112.0170,   13.8141,
        -119.8799])
micro // f1: 0.00, recall: 0.00, precision: 0.00
weighted // f1: 0.00, recall: 0.00, precision: 0.00


Epoch 0: 100%|██████████| 14/14 [01:03<00:00,  4.56s/it, loss=1.28e+03, v_num=62, train_loss_step=310.0]
Epoch 1:   0%|          | 0/14 [00:00<?, ?it/s, loss=1.28e+03, v_num=62, train_loss_step=310.0, train_loss_epoch=1.45e+3]         

  value = torch.tensor(value, device=self.device)


Epoch 1:  50%|█████     | 7/14 [01:28<01:28, 12.62s/it, loss=698, v_num=62, train_loss_step=7.870, train_loss_epoch=1.45e+3]     
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/7 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/7 [00:00<?, ?it/s][A
Epoch 1:  57%|█████▋    | 8/14 [01:30<01:07, 11.28s/it, loss=698, v_num=62, train_loss_step=7.870, train_loss_epoch=1.45e+3]
Epoch 1:  64%|██████▍   | 9/14 [01:32<00:51, 10.24s/it, loss=698, v_num=62, train_loss_step=7.870, train_loss_epoch=1.45e+3]
Epoch 1:  71%|███████▏  | 10/14 [01:34<00:37,  9.41s/it, loss=698, v_num=62, train_loss_step=7.870, train_loss_epoch=1.45e+3]
Epoch 1:  79%|███████▊  | 11/14 [01:35<00:26,  8.70s/it, loss=698, v_num=62, train_loss_step=7.870, train_loss_epoch=1.45e+3]
Epoch 1:  86%|████████▌ | 12/14 [01:37<00:16,  8.12s/it, loss=698, v_num=62, train_loss_step=7.870, train_loss_epoch=1.45e+3]
Epoch 1:  93%|█████████▎| 13/14 [01:39<00:07,  7.62s/it, loss=698, v_num=62, train_loss_s

tp: 513
nonzero count:, 11265
preds_logits: tensor([  7.6541,   5.6231,   9.2959,  -9.4332, -13.3155,   0.7387, -16.3373])
micro // f1: 8.20, recall: 41.17, precision: 4.55
weighted // f1: 3.59, recall: 41.17, precision: 4.55


Epoch 1: 100%|██████████| 14/14 [01:39<00:00,  7.11s/it, loss=698, v_num=62, train_loss_step=7.870, train_loss_epoch=1.45e+3]
Epoch 2:   0%|          | 0/14 [00:00<?, ?it/s, loss=698, v_num=62, train_loss_step=7.870, train_loss_epoch=112.0]           

  value = torch.tensor(value, device=self.device)


Epoch 2:  50%|█████     | 7/14 [01:46<01:46, 15.22s/it, loss=497, v_num=62, train_loss_step=5.590, train_loss_epoch=112.0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/7 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/7 [00:00<?, ?it/s][A
Epoch 2:  57%|█████▋    | 8/14 [01:48<01:21, 13.60s/it, loss=497, v_num=62, train_loss_step=5.590, train_loss_epoch=112.0]
Epoch 2:  64%|██████▍   | 9/14 [01:50<01:01, 12.27s/it, loss=497, v_num=62, train_loss_step=5.590, train_loss_epoch=112.0]
Epoch 2:  71%|███████▏  | 10/14 [01:51<00:44, 11.19s/it, loss=497, v_num=62, train_loss_step=5.590, train_loss_epoch=112.0]
Epoch 2:  79%|███████▊  | 11/14 [01:53<00:31, 10.34s/it, loss=497, v_num=62, train_loss_step=5.590, train_loss_epoch=112.0]
Epoch 2:  86%|████████▌ | 12/14 [01:55<00:19,  9.62s/it, loss=497, v_num=62, train_loss_step=5.590, train_loss_epoch=112.0]
Epoch 2:  93%|█████████▎| 13/14 [01:58<00:09,  9.08s/it, loss=497, v_num=62, train_loss_step=5.590, train_

tp: 0
nonzero count:, 0
preds_logits: tensor([ 1.3348,  0.6628,  0.2702, -0.7144, -1.4205,  0.2358, -1.5312])
micro // f1: 0.00, recall: 0.00, precision: 0.00
weighted // f1: 0.00, recall: 0.00, precision: 0.00


Epoch 2: 100%|██████████| 14/14 [01:58<00:00,  8.47s/it, loss=497, v_num=62, train_loss_step=5.590, train_loss_epoch=112.0]
Epoch 3:   0%|          | 0/14 [00:00<?, ?it/s, loss=497, v_num=62, train_loss_step=5.590, train_loss_epoch=29.30]         

  value = torch.tensor(value, device=self.device)


Epoch 3:  50%|█████     | 7/14 [02:12<02:12, 18.91s/it, loss=44.6, v_num=62, train_loss_step=5.680, train_loss_epoch=29.30]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/7 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/7 [00:00<?, ?it/s][A
Epoch 3:  57%|█████▋    | 8/14 [02:14<01:41, 16.87s/it, loss=44.6, v_num=62, train_loss_step=5.680, train_loss_epoch=29.30]
Epoch 3:  64%|██████▍   | 9/14 [02:16<01:15, 15.17s/it, loss=44.6, v_num=62, train_loss_step=5.680, train_loss_epoch=29.30]
Epoch 3:  71%|███████▏  | 10/14 [02:19<00:55, 13.94s/it, loss=44.6, v_num=62, train_loss_step=5.680, train_loss_epoch=29.30]
Epoch 3:  79%|███████▊  | 11/14 [02:20<00:38, 12.79s/it, loss=44.6, v_num=62, train_loss_step=5.680, train_loss_epoch=29.30]
Epoch 3:  86%|████████▌ | 12/14 [02:22<00:23, 11.85s/it, loss=44.6, v_num=62, train_loss_step=5.680, train_loss_epoch=29.30]
Epoch 3:  93%|█████████▎| 13/14 [02:23<00:11, 11.05s/it, loss=44.6, v_num=62, train_loss_step=5.680,

tp: 513
nonzero count:, 11265
preds_logits: tensor([  8.2577,   6.5615,   8.5260,  -9.3597, -27.3703,   5.3974, -26.0650])
micro // f1: 8.20, recall: 41.17, precision: 4.55
weighted // f1: 3.59, recall: 41.17, precision: 4.55


Epoch 3: 100%|██████████| 14/14 [02:24<00:00, 10.30s/it, loss=44.6, v_num=62, train_loss_step=5.680, train_loss_epoch=29.30]
Epoch 4:   0%|          | 0/14 [00:00<?, ?it/s, loss=44.6, v_num=62, train_loss_step=5.680, train_loss_epoch=18.30]         

  value = torch.tensor(value, device=self.device)


Epoch 4:  50%|█████     | 7/14 [02:25<02:25, 20.73s/it, loss=22.4, v_num=62, train_loss_step=4.960, train_loss_epoch=18.30]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/7 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/7 [00:00<?, ?it/s][A
Epoch 4:  57%|█████▋    | 8/14 [02:27<01:50, 18.42s/it, loss=22.4, v_num=62, train_loss_step=4.960, train_loss_epoch=18.30]
Epoch 4:  64%|██████▍   | 9/14 [02:29<01:22, 16.57s/it, loss=22.4, v_num=62, train_loss_step=4.960, train_loss_epoch=18.30]
Epoch 4:  71%|███████▏  | 10/14 [02:30<01:00, 15.05s/it, loss=22.4, v_num=62, train_loss_step=4.960, train_loss_epoch=18.30]
Epoch 4:  79%|███████▊  | 11/14 [02:33<00:41, 13.92s/it, loss=22.4, v_num=62, train_loss_step=4.960, train_loss_epoch=18.30]
Epoch 4:  86%|████████▌ | 12/14 [02:34<00:25, 12.88s/it, loss=22.4, v_num=62, train_loss_step=4.960, train_loss_epoch=18.30]
Epoch 4:  93%|█████████▎| 13/14 [02:35<00:11, 12.00s/it, loss=22.4, v_num=62, train_loss_step=4.960,

tp: 513
nonzero count:, 11265
preds_logits: tensor([  6.3035,   6.5229,   6.7505,  -9.1935, -30.4175,   5.4487, -30.4645])
micro // f1: 8.20, recall: 41.17, precision: 4.55
weighted // f1: 3.59, recall: 41.17, precision: 4.55


Epoch 4: 100%|██████████| 14/14 [02:36<00:00, 11.18s/it, loss=22.4, v_num=62, train_loss_step=4.960, train_loss_epoch=18.30]
Epoch 5:   0%|          | 0/14 [00:00<?, ?it/s, loss=22.4, v_num=62, train_loss_step=4.960, train_loss_epoch=18.90]         

  value = torch.tensor(value, device=self.device)


Epoch 5:  50%|█████     | 7/14 [02:36<02:36, 22.39s/it, loss=18.6, v_num=62, train_loss_step=5.340, train_loss_epoch=18.90]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/7 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/7 [00:00<?, ?it/s][A
Epoch 5:  57%|█████▋    | 8/14 [02:39<01:59, 19.94s/it, loss=18.6, v_num=62, train_loss_step=5.340, train_loss_epoch=18.90]
Epoch 5:  64%|██████▍   | 9/14 [02:41<01:29, 17.91s/it, loss=18.6, v_num=62, train_loss_step=5.340, train_loss_epoch=18.90]
Epoch 5:  71%|███████▏  | 10/14 [02:43<01:05, 16.32s/it, loss=18.6, v_num=62, train_loss_step=5.340, train_loss_epoch=18.90]
Epoch 5:  79%|███████▊  | 11/14 [02:45<00:45, 15.00s/it, loss=18.6, v_num=62, train_loss_step=5.340, train_loss_epoch=18.90]
Epoch 5:  86%|████████▌ | 12/14 [02:46<00:27, 13.87s/it, loss=18.6, v_num=62, train_loss_step=5.340, train_loss_epoch=18.90]
Epoch 5:  93%|█████████▎| 13/14 [02:48<00:12, 12.96s/it, loss=18.6, v_num=62, train_loss_step=5.340,

tp: 0
nonzero count:, 0
preds_logits: tensor([  3.5999,   2.6459,   2.2143,  -2.6841,  -9.5338,   1.1630, -13.0484])
micro // f1: 0.00, recall: 0.00, precision: 0.00
weighted // f1: 0.00, recall: 0.00, precision: 0.00


Epoch 5: 100%|██████████| 14/14 [02:49<00:00, 12.07s/it, loss=18.6, v_num=62, train_loss_step=5.340, train_loss_epoch=18.90]
Epoch 6:   0%|          | 0/14 [00:00<?, ?it/s, loss=18.6, v_num=62, train_loss_step=5.340, train_loss_epoch=17.70]         

  value = torch.tensor(value, device=self.device)


In [20]:
model.__dict__.keys()

dict_keys(['training', '_parameters', '_buffers', '_non_persistent_buffers_set', '_backward_hooks', '_is_full_backward_hook', '_forward_hooks', '_forward_pre_hooks', '_state_dict_hooks', '_load_state_dict_pre_hooks', '_load_state_dict_post_hooks', '_modules', 'prepare_data_per_node', 'allow_zero_length_dataloader_with_multiple_devices', '_log_hyperparams', '_dtype', '_device', '_trainer', '_use_amp', 'precision', '_example_input_array', '_current_fx_name', '_automatic_optimization', '_truncated_bptt_steps', '_param_requires_grad_state', '_metric_attributes', '_should_prevent_trainer_and_dataloaders_deepcopy', '_running_torchscript', 'num_classes', '_hparams_name', '_hparams', '_hparams_initial'])

In [None]:
trainer.predict(model, datamodule=datamodule)