In [None]:
import warnings
warnings.filterwarnings('ignore')

import os


os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModel, BertTokenizer

import tqdm

from datetime import datetime
from pathlib import Path
import time

import numpy as np


In [None]:
# Tokenizer for the voidful/albert_chinese_large checkpoint (the same checkpoint
# name the model class loads). It is loaded with BertTokenizer rather than
# AutoTokenizer; NOTE(review): presumably because the checkpoint ships a BERT
# WordPiece vocab instead of a SentencePiece model — confirm before switching.
tokenizer = BertTokenizer.from_pretrained("voidful/albert_chinese_large")


In [None]:
def load_tag_vocab(path):
    """Read a tag vocabulary file and map each tag to a contiguous integer id.

    The first whitespace-separated field of every non-blank line is the tag;
    any further fields are ignored. Ids are assigned in file order starting at
    0. A tag listed more than once keeps the id of its first occurrence, so the
    ids are always ``0 .. len(result) - 1`` — the model uses ``len(result)`` as
    its number of output classes, so gaps would cause out-of-range ids.

    Args:
        path: path of the vocabulary file.

    Returns:
        dict mapping tag string -> id.
    """
    tag2id = {}
    with open(path, encoding='utf-8') as fp:
        for line in fp:
            fields = line.split()
            if not fields:  # tolerate blank lines instead of crashing on fields[0]
                continue
            tag2id.setdefault(fields[0], len(tag2id))
    return tag2id


ner_tag_file = "pos.tgt.dict"
ner_tag2id = load_tag_vocab(ner_tag_file)

# Inverse mapping id -> tag (used to name CRF states).
ner_id2tag = {v: k for k, v in ner_tag2id.items()}

print(len(ner_tag2id))

In [None]:
# Group tags by the text after their last '-' (presumably the POS category,
# with B-/I-/E-/S- style prefixes stripped). Categories are numbered in the
# order they first appear in ner_tag2id.
pos_categories = dict.fromkeys(tag.rsplit('-', 1)[-1] for tag in ner_tag2id)
pos_tag2id = {category: idx for idx, category in enumerate(pos_categories)}

print(len(pos_tag2id))

In [None]:
# Parallel corpus files: line i of the label file holds the labels for the
# whitespace-separated tokens on line i of the input file.
input_file, label_file = (
    "crawler_data/word_seg/data/cna_total_token_input",
    "crawler_data/word_seg/data/cna_total_token_label",
)

In [None]:
def read_parallel_token_files(token_path, label_path):
    """Read a token file and its label file line by line, in lockstep.

    Each line is split on whitespace; line *i* of ``label_path`` must hold
    exactly one label per token on line *i* of ``token_path``.

    Args:
        token_path: path of the token file.
        label_path: path of the label file.

    Returns:
        ``(token_rows, label_rows)`` — two equally long lists, each holding one
        list of strings per line.

    Raises:
        ValueError: if the files have different numbers of lines, or if a line
            pair has a different number of tokens and labels.
    """
    from itertools import zip_longest

    token_rows, label_rows = [], []
    # Explicit UTF-8 so decoding does not depend on the machine's locale.
    with open(token_path, encoding='utf-8') as fp_tok, \
            open(label_path, encoding='utf-8') as fp_lab:
        # zip_longest (unlike zip) surfaces a line-count mismatch instead of
        # silently dropping the tail of the longer file.
        for line_no, (token_line, label_line) in enumerate(zip_longest(fp_tok, fp_lab), start=1):
            if token_line is None or label_line is None:
                raise ValueError(
                    f'{token_path} and {label_path} have different numbers of lines '
                    f'(one of them ends before line {line_no})')
            tokens = token_line.split()
            labels = label_line.split()
            if len(tokens) != len(labels):
                raise ValueError(
                    f'line {line_no}: {len(tokens)} tokens but {len(labels)} labels')
            token_rows.append(tokens)
            label_rows.append(labels)
    return token_rows, label_rows


total_token_input, total_token_label = read_parallel_token_files(input_file, label_file)

print(len(total_token_input), len(total_token_label))

In [None]:
class AttrDict(dict):
    """A ``dict`` whose keys can also be read and written as attributes.

    The instance ``__dict__`` is pointed at the mapping itself, so ``d.key`` and
    ``d['key']`` share the same storage.
    See https://stackoverflow.com/questions/4984647/accessing-dict-keys-like-an-attribute
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Attribute access now reads/writes the dict entries directly.
        self.__dict__ = self

# Hyper-parameters for the whole run, accessible as attributes (opts.batch_size).
opts = AttrDict(
    # --- model ---
    vocab_size=tokenizer.vocab_size,
    output_size=len(ner_tag2id),
    # --- optimization ---
    learning_rate=1.5e-4,
    bert_lr=5e-6,
    weight_decay=0.01,  # L2 weight regularization
    max_grad_norm=1.0,
    batch_size=20,
    # --- training ---
    max_seq_len=256,
    num_epochs=20,
    warmup_steps=4000,
    gradient_accumulation=6,
    load_pretrain=True,
)

In [None]:
from sklearn.model_selection import train_test_split

# NOTE(review): 9-digit seed — possibly a typo for a date such as 20201004;
# kept as-is because changing it changes the train/dev split.
random_seed = 202001004

# Hold out 5% of the sentences as a dev set; the fixed seed makes the split reproducible.
train_token, dev_token, train_label, dev_label = train_test_split(
    total_token_input,
    total_token_label,
    test_size=0.05,
    random_state=random_seed,
    shuffle=True,
)

print(*map(len, (train_token, train_label, dev_token, dev_label)))

In [None]:
# Sanity check: every dev sentence still has exactly one label per token.
assert all(len(tokens) == len(labels) for tokens, labels in zip(dev_token, dev_label))

In [None]:
class Dataset:
    """Map-style dataset over parallel (token list, label list) sentences.

    Each item is ``(tokens, labels, input_ids, label_ids)`` where ``tokens`` and
    ``labels`` are wrapped in ``[CLS]`` ... ``[SEP]``. Token ids come from the
    module-level ``tokenizer`` and label ids from the module-level
    ``ner_tag2id`` mapping.
    """

    def __init__(self, token_input, token_label):
        self.token_inputs = token_input
        self.token_labels = token_label
        print('total examples {} ...'.format(len(self.token_inputs)))

    def __len__(self):
        return len(self.token_inputs)

    def __getitem__(self, index):
        # Concatenation builds new lists, so the stored sentences are not mutated.
        wrapped_tokens = ['[CLS]'] + self.token_inputs[index] + ['[SEP]']
        wrapped_labels = ['[CLS]'] + self.token_labels[index] + ['[SEP]']
        return (
            wrapped_tokens,
            wrapped_labels,
            self.token2ids(wrapped_tokens),
            self.label2ids(wrapped_labels),
        )

    def token2ids(self, token_input):
        """Convert tokens to vocabulary ids with the global tokenizer."""
        return tokenizer.convert_tokens_to_ids(token_input)

    def label2ids(self, token_label):
        """Convert label strings to ids; raises KeyError for an unknown label."""
        return list(map(ner_tag2id.__getitem__, token_label))
        
        
def collate_fn(data, pad_label_id=None, max_seq_len=None):
    """Pad a list of ``Dataset`` items into batch tensors.

    Args:
        data: sequence of ``(tokens, labels, input_ids, label_ids)`` tuples as
            returned by ``Dataset.__getitem__``.
        pad_label_id: label id written into padded target positions. Defaults
            to ``ner_tag2id['[PAD]']``, the ``ignore_index`` of the model's
            cross-entropy loss, so padding never contributes to that loss.
        max_seq_len: longer sequences are truncated to this many positions.
            Defaults to ``opts.max_seq_len``.

    Returns:
        ``(inputs, labels, inputs_ids, labels_ids, input_seqs, input_mask,
        targets, lens)``: the first four are the un-padded per-example tuples;
        ``input_seqs`` and ``targets`` are long tensors of shape
        ``(batch, max_len)``; ``input_mask`` is a float tensor with 1.0 on real
        positions and 0.0 on padding; ``lens`` is the list of truncated lengths.
    """
    if pad_label_id is None:
        pad_label_id = ner_tag2id['[PAD]']
    if max_seq_len is None:
        max_seq_len = opts.max_seq_len

    inputs, labels, inputs_ids, labels_ids = zip(*data)

    # NOTE(review): truncation also cuts off the trailing [SEP] of long sentences.
    lens = [min(len(tokens), max_seq_len) for tokens in inputs]
    batch_size, max_len = len(inputs), max(lens)

    # Input ids are padded with 0; those positions are excluded via input_mask.
    input_seqs = torch.zeros(batch_size, max_len, dtype=torch.long)
    input_mask = torch.zeros(batch_size, max_len)
    # Targets are padded with the [PAD] label id (not 0) so CrossEntropyLoss
    # with ignore_index=[PAD] really ignores the padding.
    targets = torch.full((batch_size, max_len), pad_label_id, dtype=torch.long)

    for row, (seq_len, input_ids, label_ids) in enumerate(zip(lens, inputs_ids, labels_ids)):
        input_seqs[row, :seq_len] = torch.tensor(input_ids[:seq_len], dtype=torch.long)
        input_mask[row, :seq_len] = 1.0
        targets[row, :seq_len] = torch.tensor(label_ids[:seq_len], dtype=torch.long)

    return inputs, labels, inputs_ids, labels_ids, input_seqs, input_mask, targets, lens

In [None]:
# Wrap both splits (train first, then dev); each constructor prints its size.
train_dataset, dev_dataset = Dataset(train_token, train_label), Dataset(dev_token, dev_label)

In [None]:
# Seed NumPy and PyTorch (CPU and every CUDA device) so shuffling is reproducible.
for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(random_seed)

# Settings shared by both loaders.
_loader_kwargs = dict(num_workers=8, collate_fn=collate_fn)

train_iter = DataLoader(dataset=train_dataset, batch_size=opts.batch_size, shuffle=True, **_loader_kwargs)
dev_iter = DataLoader(dataset=dev_dataset, batch_size=8, shuffle=False, **_loader_kwargs)

In [None]:
from typing import List, Optional, Mapping

class CRF(nn.Module):
    """Conditional random field.
    This module implements a conditional random field [LMP]. The forward computation
    of this class computes the log likelihood of the given sequence of tags and
    emission score tensor. This class also has ``decode`` method which finds the
    best tag sequence given an emission score tensor using `Viterbi algorithm`_.
    Arguments
    ---------
    num_tags : int
        Number of tags.
    batch_first : bool, optional
        Whether the first dimension corresponds to the size of a minibatch.
    tag_id_to_name : Mapping[int, str], optional
        Tag id -> tag name (e.g. ``'B-Na'``, ``'[CLS]'``). When given, each name is
        split on ``'-'`` into a prefix and a suffix, and transitions that are
        impossible under the B/I/E/S + ``[CLS]``/``[SEP]``/``[PAD]`` scheme get a
        very low score (see ``set_transition_constraint``).
    Attributes
    ----------
    start_transitions : :class:`~torch.nn.Parameter`
        Start transition score tensor of size ``(num_tags,)``.
    end_transitions : :class:`~torch.nn.Parameter`
        End transition score tensor of size ``(num_tags,)``.
    transitions : :class:`~torch.nn.Parameter`
        Transition score tensor of size ``(num_tags, num_tags)``.
    tag_id_to_name : dict
        Tag id -> ``(prefix, suffix)``; empty when no tag names were given.
    References
    ----------
    .. [LMP] Lafferty, J., McCallum, A., Pereira, F. (2001).
             "Conditional random fields: Probabilistic models for segmenting and
             labeling sequence data". *Proc. 18th International Conf. on Machine
             Learning*. Morgan Kaufmann. pp. 282–289.
    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
    """

    def __init__(self, num_tags: int, batch_first: bool = False,
                 tag_id_to_name: Optional[Mapping[int, str]] = None) -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))

        self.reset_parameters()

        # Always define the attribute so it exists even without tag names.
        self.tag_id_to_name = dict()
        if not tag_id_to_name:
            return

        for t_i, t_name in tag_id_to_name.items():
            # 'B-Na' -> ('B', 'Na'); 'B-Na-x' -> ('B', 'Na-x'); '[CLS]' -> ('[CLS]', '[CLS]')
            t_name_part = t_name.split('-')
            assert len(t_name_part) <= 3, "tag name {} error".format(t_name)
            t_pref = t_name_part[0]
            t_suff = t_name_part[1] if len(t_name_part) > 1 else t_name_part[0]
            if len(t_name_part) == 3:
                t_suff += '-' + t_name_part[2]
            self.tag_id_to_name[t_i] = (t_pref, t_suff)

        self.set_transition_constraint()

    def reset_parameters(self) -> None:
        """Initialize the transition parameters.
        The parameters will be initialized randomly from a uniform distribution
        between -0.01 and 0.01.
        """
        nn.init.uniform_(self.start_transitions, -0.01, 0.01)
        nn.init.uniform_(self.end_transitions, -0.01, 0.01)
        nn.init.uniform_(self.transitions, -0.01, 0.01)

    def set_transition_constraint(self) -> None:
        """Set impossible transitions
        Set transitions between impossible labels to a very low score (-10000),
        effectively disabling them; explicitly allowed ones are set to 0.001.
        """
        if len(self.tag_id_to_name) < 1:
            print("no tag id to name dict")
            return
        forbidden, allowed = -10000., 0.001
        trans = self.transitions.data
        start = self.start_transitions.data
        tag_items = list(self.tag_id_to_name.items())
        for source_i, (source_pref, source_suff) in tag_items:
            if source_pref in ['[PAD]', '[SEP]']:
                # cannot start here; can only go to PAD
                start[source_i] = forbidden
                trans[source_i].fill_(forbidden)
                for targ_i, (targ_pref, _) in tag_items:
                    if targ_pref in ['[PAD]']:
                        trans[source_i, targ_i] = allowed
            if source_pref in ['[CLS]']:
                # the first real tag must be S- or B-
                trans[source_i].fill_(forbidden)
                for targ_i, (targ_pref, _) in tag_items:
                    if targ_pref in ['S', 'B']:
                        trans[source_i, targ_i] = allowed
                    if targ_pref in ['I', 'E']:
                        trans[source_i, targ_i] = forbidden
            if source_pref in ['B', 'I']:
                # must continue the same chunk type with I- or E-
                trans[source_i].fill_(forbidden)
                for targ_i, (targ_pref, targ_suff) in tag_items:
                    if (targ_suff == source_suff) and (targ_pref in ['I', 'E']):
                        trans[source_i, targ_i] = allowed
            if source_pref in ['S', 'E']:
                # a finished chunk cannot be followed by I- or E-
                for targ_i, (targ_pref, _) in tag_items:
                    if targ_pref in ['I', 'E']:
                        trans[source_i, targ_i] = forbidden
            if source_pref in ['I', 'E']:
                # cannot be start transitions
                start[source_i] = forbidden

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def forward(
            self,
            emissions: torch.Tensor,
            tags: torch.LongTensor,
            mask: Optional[torch.Tensor] = None,
            reduction: str = 'sum',
    ) -> torch.Tensor:
        """Compute the conditional log likelihood of a sequence of tags given emission scores.
        Arguments
        ---------
        emissions : :class:`~torch.Tensor`
            Emission score tensor of size ``(seq_length, batch_size, num_tags)`` if
            ``batch_first`` is ``False``, ``(batch_size, seq_length, num_tags)`` otherwise.
        tags : :class:`~torch.LongTensor`
            Sequence of tags tensor of size ``(seq_length, batch_size)`` if
            ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
        mask : :class:`~torch.Tensor`, optional
            Bool or uint8 mask tensor of size ``(seq_length, batch_size)`` if
            ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
        reduction : str, optional
            Specifies  the reduction to apply to the output: 'none'|'sum'|'mean'|'token_mean'.
            'none': no reduction will be applied. 'sum': the output will be summed over batches.
            'mean': the output will be averaged over batches. 'token_mean': the output will be
            averaged over tokens.
        Returns
        -------
        :class:`~torch.Tensor`
            The log likelihood. This will have size ``(batch_size,)`` if reduction is 'none',
            ``()`` otherwise.
        """
        self._validate(emissions, tags=tags, mask=mask)
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.bool)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        # shape: (batch_size,)
        numerator = self._compute_score(emissions, tags, mask)
        # shape: (batch_size,)
        denominator = self._compute_normalizer(emissions, mask)
        # shape: (batch_size,)
        llh = numerator - denominator

        if reduction == 'none':
            return llh
        if reduction == 'sum':
            return llh.sum()
        if reduction == 'mean':
            return llh.mean()
        assert reduction == 'token_mean'
        return llh.sum() / mask.float().sum()

    def decode(self, emissions: torch.Tensor,
               mask: Optional[torch.Tensor] = None) -> List[List[int]]:
        """Find the most likely tag sequence using Viterbi algorithm.
        Arguments
        ---------
        emissions : :class:`~torch.Tensor`
            Emission score tensor of size ``(seq_length, batch_size, num_tags)`` if
            ``batch_first`` is ``False``, ``(batch_size, seq_length, num_tags)`` otherwise.
        mask : :class:`~torch.Tensor`, optional
            Bool or uint8 mask tensor of size ``(seq_length, batch_size)`` if
            ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
        Returns
        -------
        List[List[int]]
            List of list containing the best tag sequence for each batch.
        """
        self._validate(emissions, mask=mask)
        if mask is None:
            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.bool)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        return self._viterbi_decode(emissions, mask)

    def _validate(
            self,
            emissions: torch.Tensor,
            tags: Optional[torch.LongTensor] = None,
            mask: Optional[torch.Tensor] = None) -> None:
        if emissions.dim() != 3:
            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')

        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')

        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
            no_empty_seq = not self.batch_first and mask[0].all()
            no_empty_seq_bf = self.batch_first and mask[:, 0].all()
            if not no_empty_seq and not no_empty_seq_bf:
                raise ValueError('mask of the first timestep must all be on')

    def _compute_score(
            self, emissions: torch.Tensor, tags: torch.LongTensor,
            mask: torch.Tensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.float()

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]

        for i in range(1, seq_length):
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]

            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

        # End transition score
        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        # shape: (batch_size,)
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        # shape: (batch_size,)
        score += self.end_transitions[last_tags]

        return score

    def _compute_normalizer(
            self, emissions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        # torch.where requires a bool condition on recent PyTorch (uint8 is rejected).
        mask = mask.bool()
        seq_length = emissions.size(0)

        # Start transition score and first emission; score has size of
        # (batch_size, num_tags) where for each batch, the j-th column stores
        # the score that the first timestep has tag j
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # Broadcast score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the sum of scores of all
            # possible tag sequences so far that end with transitioning from tag i to tag j
            # and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum over all possible current tags, but we're in score space, so a sum
            # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
            # all possible tag sequences so far, that end in tag i
            # shape: (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return torch.logsumexp(score, dim=1)

    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.Tensor) -> List[List[int]]:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        # torch.where requires a bool condition on recent PyTorch (uint8 is rejected).
        mask = mask.bool()
        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history = []

        # score is a tensor of size (batch_size, num_tags) where for every batch,
        # value at column j stores the score of the best tag sequence so far that ends
        # with tag j
        # history saves where the best tags candidate transitioned from; this is used
        # when we trace back the best tag sequence

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tag
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)
            history.append(indices)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Now, compute the best path for each sample

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        best_tags_list = []

        for idx in range(batch_size):
            # Find the tag which maximizes the score at the last timestep; this is our best tag
            # for the last timestep
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]

            # We trace back where the best last tag comes from, append that to our best tag
            # sequence, and trace it back again, and so on
            for hist in reversed(history[:seq_ends[idx].item()]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())

            # Reverse the order because we start from the last timestep
            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list

In [None]:
class Albert_CRF(nn.Module):
    """ALBERT encoder + linear emission layer + CRF for sequence tagging.

    The training loss is ``ce_loss + crf_weight * crf_loss`` where ``ce_loss`` is
    token-level cross-entropy (targets equal to the ``[PAD]`` id are ignored)
    and ``crf_loss`` is the CRF negative log-likelihood averaged over the batch.

    Args:
        pretrained_name: Hugging Face checkpoint for the encoder.
        crf_weight: weight of the CRF loss term in the combined loss.
    """

    def __init__(self, pretrained_name="voidful/albert_chinese_large", crf_weight=1e-2):
        super().__init__()
        num_tags = len(ner_tag2id)
        # Plain float attribute (not a parameter/buffer), so state_dict is unchanged.
        self.crf_weight = crf_weight

        # Submodule registration order kept as before so parameters() order
        # (and therefore optimizer state) is unchanged.
        self.albert = AutoModel.from_pretrained(pretrained_name, return_dict=True)
        self.classifier = nn.Linear(self.albert.config.hidden_size, num_tags)
        self.crf = CRF(num_tags, batch_first=True, tag_id_to_name=ner_id2tag)
        self.CELoss_fn = nn.CrossEntropyLoss(ignore_index=ner_tag2id['[PAD]'])

    def forward(self, input_seqs, input_mask, targets):
        """Compute emissions and losses for one batch.

        Args:
            input_seqs: token ids, as built by ``collate_fn`` ``(batch, seq_len)``.
            input_mask: 1.0 on real tokens and 0.0 on padding, same shape.
            targets: gold label ids, same shape.

        Returns:
            ``(emissions, ce_loss, crf_loss, loss)`` where ``emissions`` has one
            score per tag in its last dimension and ``crf_loss`` is the negated
            CRF log-likelihood.
        """
        hidden = self.albert(input_ids=input_seqs, attention_mask=input_mask)['last_hidden_state']
        output = self.classifier(hidden)

        ce_loss = self.CELoss_fn(output.view(-1, output.size(-1)), targets.view(-1))

        # A bool mask is required by torch.where inside the CRF on recent PyTorch;
        # it also supports the .float()/.long() conversions the CRF performs.
        crf_loss = -self.crf(emissions=output, tags=targets, mask=input_mask.bool(), reduction='mean')

        loss = ce_loss + self.crf_weight * crf_loss

        return output, ce_loss, crf_loss, loss
        
        

In [None]:
model = Albert_CRF()

# Parameter counts: all parameters, then only those that will receive gradients.
n_total_params = sum(p.numel() for p in model.parameters())
n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('total parms : ', n_total_params)
print('trainable parms : ', n_trainable_params)

In [None]:
# Display the nn.Module repr (the submodule tree) of the model.
model

In [None]:
# Use the GPU only when one is actually available; forcing True makes
# model.cuda() fail on CPU-only machines.
USE_CUDA = torch.cuda.is_available()

In [None]:
# Single-process multi-GPU training with nn.DataParallel: every batch is split
# along dim 0 across the visible GPUs (e.g. a batch of 30 on 3 GPUs -> 3 x 10).
# The model is wrapped unconditionally because later cells reach the
# underlying network through `model.module`.
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = nn.DataParallel(model)

if USE_CUDA:
    model.cuda()

In [None]:
# Root folder for experiment outputs; each run creates its own timestamped
# sub-directory beneath it.
exp_dir = "albert_ner/exp/"

In [None]:
last_epoch = -1
model_name = 'albert_ner_len{}_batch_{}'.format(opts.max_seq_len, opts.batch_size)
# Timestamp without spaces or colons: str(datetime.now()) would contain both,
# and ':' is not allowed in Windows file names.
now = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
experiment_name = '{}_{}'.format(model_name, now)
experiment_dir = Path(exp_dir) / experiment_name
experiment_dir.mkdir(exist_ok=True, parents=True)
print(experiment_dir)

In [None]:
def log2file(log_file, msg):
    """Append ``msg`` plus a newline to ``log_file``, creating the file if needed."""
    with open(log_file, mode='a') as handle:
        handle.writelines((msg, '\n'))

# Per-run log files, both placed inside the experiment directory.
experiment_trainlog, experiment_devlog = (
    experiment_dir / log_name for log_name in ('train_log.txt', 'dev_log.txt')
)

In [None]:
print(opts.learning_rate)
print(opts.bert_lr)

# Two parameter groups: the pretrained ALBERT encoder (parameter names start
# with 'albert.') is fine-tuned with the smaller opts.bert_lr, while the newly
# initialised classifier and CRF use opts.learning_rate. Both values, and
# opts.weight_decay, come from the config cell and are applied here.
encoder_params, head_params = [], []
for param_name, param in model.module.named_parameters():
    (encoder_params if param_name.startswith('albert.') else head_params).append(param)

# torch.optim.AdamW replaces the deprecated transformers.AdamW; eps=1e-6 keeps
# that class's default so the update rule is unchanged.
optimizer = torch.optim.AdamW(
    [
        {'params': encoder_params, 'lr': opts.bert_lr},
        {'params': head_params, 'lr': opts.learning_rate},
    ],
    lr=opts.learning_rate,
    weight_decay=opts.weight_decay,
    eps=1e-6,
)

# NOTE(review): if the training loop steps the optimizer/scheduler only every
# opts.gradient_accumulation batches, num_training_steps should be divided by
# that factor — confirm against the training loop.
scheduler = transformers.get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=opts.warmup_steps,
    num_training_steps=len(train_iter) * opts.num_epochs,
)

In [None]:
class EarlyStopping:
    """Early-stops training when the validation loss stops improving.

    Call the instance with each new validation loss. A loss that is not worse
    than the best seen so far resets the patience counter and records the
    epoch. Otherwise the counter grows; once it reaches `patience`,
    `early_stop` becomes True and `exp_dir/best_model` is (re)created as a
    symlink to the best epoch's checkpoint file.
    """

    def __init__(self, patience=7, verbose=False, delta=0.001, exp_dir='',
                 ckpt_template='epoch_{}.ckpt'):
        """
        Args:
            patience (int): How many consecutive non-improving calls to tolerate
                before stopping. Default: 7
            verbose (bool): If True, print and log a message for each
                validation loss improvement. Default: False
            delta (float): Intended minimum change that counts as an improvement.
                NOTE: stored but not used by the comparison in __call__
                (any loss <= the best counts as an improvement). Default: 0.001
            exp_dir (str | Path): Experiment directory holding `train_log.txt`,
                the checkpoint files and the `best_model` symlink.
            ckpt_template (str): Checkpoint file name inside `exp_dir`, formatted
                with the best epoch, that `best_model` points to.
                Default: 'epoch_{}.ckpt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.best_epoch = 0
        self.exp_dir = Path(exp_dir)
        self.ckpt_template = ckpt_template

    def __call__(self, val_loss, model, epoch):
        """Update the state with a new validation loss observed at `epoch`."""
        score = -val_loss

        if self.best_score is None:
            # First evaluation is the best so far by definition; record its epoch
            # too (matters when training resumes at an epoch other than 0).
            self.best_score = score
            self.best_epoch = epoch
            self.save_checkpoint(val_loss, model, epoch)
        elif score < self.best_score:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
                msg = 'best epoch : {}'.format(self.best_epoch)
                print(msg)
                log2file(self.exp_dir / 'train_log.txt', msg)
                self._link_best_model()
        else:
            self.best_score = score
            self.best_epoch = epoch
            self.save_checkpoint(val_loss, model, epoch)
            self.counter = 0

    def _link_best_model(self):
        """Point `exp_dir/best_model` at the best epoch's checkpoint, replacing any old link."""
        link = self.exp_dir / 'best_model'
        # A relative symlink target is resolved against the link's own directory,
        # so use the bare file name; prefixing exp_dir would point to exp_dir/exp_dir/...
        target = self.ckpt_template.format(self.best_epoch)
        # Replace an existing (possibly dangling) link instead of raising FileExistsError.
        if link.is_symlink() or link.exists():
            link.unlink()
        link.symlink_to(target)

    def save_checkpoint(self, val_loss, model, epoch):
        """Record `val_loss` as the new best and, if verbose, report the improvement.

        No model file is written here (checkpoint files are saved by the training
        loop); `model` and `epoch` are accepted for interface compatibility.
        """
        if self.verbose:
            msg = f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f})'
            print(msg)
            log2file(self.exp_dir / 'train_log.txt', msg)
        self.val_loss_min = val_loss
        

In [None]:
# Show the directory where this run's log files and checkpoints are written.
experiment_dir

In [None]:
# Early stopping on the dev loss. It is consulted several times per epoch
# (in-epoch checkpoints + end of epoch), hence the large patience.
early_stopping = EarlyStopping(patience=50, verbose=True, exp_dir=str(experiment_dir))

# Record every hyper-parameter at the top of the training log.
for k, v in opts.items():
    log_msg = f'- {k}: {v}'
    log2file(str(experiment_trainlog), log_msg)
    print(log_msg)

# One progress bar for training batches, one for dev batches; both are reset per pass.
pbar_train, pbar_dev = (tqdm.notebook.tqdm(total=len(loader)) for loader in (train_iter, dev_iter))

# Separator followed by the optimizer configuration.
for log_msg in ('=' * 50, 'optim : \n' + str(optimizer)):
    print(log_msg)
    log2file(str(experiment_trainlog), log_msg)

# Batch indices (within an epoch) at which to log, checkpoint and evaluate:
# s - 1 evenly spaced points; the end-of-epoch evaluation covers the rest.
s = 5
checkpoint = [int(len(train_iter) / s * i) for i in range(1, s)]

# Running count of CUDA out-of-memory errors that were skipped.
oom_time = 0

print('check point : ', checkpoint)

def _log(msg):
    """Print `msg` and append it to the experiment's train log."""
    print(msg)
    log2file(str(experiment_trainlog), msg)


def _unpack_batch(batch):
    """Return (input_seqs, input_mask, targets, batch_size) for one loader batch.

    Only the model inputs are kept from the 8-field batch; they are moved to
    the GPU when USE_CUDA is set.
    """
    _, _, _, _, input_seqs, input_mask, targets, _ = batch
    batch_size = input_seqs.size(0)
    assert batch_size == targets.size(0)
    if USE_CUDA:
        input_seqs, input_mask, targets = input_seqs.cuda(), input_mask.cuda(), targets.cuda()
    return input_seqs, input_mask, targets, batch_size


def _handle_runtime_error(exception):
    """Log a RuntimeError raised by a forward/backward pass; re-raise if not recoverable.

    CUDA out-of-memory errors (counted in the global `oom_time`) and
    DataParallel "Gather got an ..." errors are logged and swallowed so the
    caller skips the batch. Any other RuntimeError is logged and re-raised.
    """
    global oom_time
    msg = str(exception)
    if "out of memory" in msg:
        oom_time += 1
        _log("WARNING: ran out of memory,times: {}".format(oom_time))
        torch.cuda.empty_cache()
    else:
        _log(msg)
        if "Gather got an" not in msg:
            raise exception


def _evaluate():
    """Compute per-sample mean (loss, CE loss, CRF loss) over `dev_iter`.

    Runs in eval mode under torch.no_grad(), so no autograd graph is kept.
    Batches skipped after a recoverable RuntimeError are excluded from the means.
    If no batch succeeds, the means are +inf so early stopping does not count
    the result as an improvement.

    Returns:
        (mean_loss, mean_celoss, mean_crfloss, elapsed_seconds)
    """
    start = time.time()
    totals = np.zeros(3)
    n_samples = 0
    pbar_dev.reset()
    model.eval()
    with torch.no_grad():
        for batch in dev_iter:
            input_seqs, input_mask, targets, batch_size = _unpack_batch(batch)
            try:
                _, ce_loss, crf_loss, loss = model(input_seqs, input_mask, targets)
                # Reduce the gathered losses (presumably one entry per DataParallel
                # replica) to scalars, weighted by batch size for per-sample means.
                totals += [loss.mean().item() * batch_size,
                           ce_loss.mean().item() * batch_size,
                           crf_loss.mean().item() * batch_size]
                n_samples += batch_size
            except RuntimeError as exception:
                _handle_runtime_error(exception)
            pbar_dev.update(1)
    means = totals / n_samples if n_samples else np.full(3, np.inf)
    return (*means, time.time() - start)


for epoch in range(last_epoch + 1, opts.num_epochs):

    pbar_train.reset()
    pbar_dev.reset()
    _log('=' * 50)

    # Per-sample sums of (loss, CE loss, CRF loss) over batches that trained
    # successfully this epoch, and how many samples those batches contained.
    train_totals = np.zeros(3)
    train_samples = 0
    time_tracker = [time.time()]

    # `global_step` is the batch index within the current epoch.
    for global_step, batch in enumerate(train_iter):

        input_seqs, input_mask, targets, batch_size = _unpack_batch(batch)
        model.train()

        try:
            _, ce_loss, crf_loss, loss = model(input_seqs, input_mask, targets)
            ce_loss, crf_loss, loss = ce_loss.mean(), crf_loss.mean(), loss.mean()

            # Scale so the accumulated gradient is the mean over micro-batches,
            # not their sum.
            (loss / opts.gradient_accumulation).backward()

            is_update_step = ((global_step + 1) % opts.gradient_accumulation == 0
                              or global_step == len(train_iter) - 1)
            if is_update_step:
                # Clip the fully accumulated gradient once, right before the update;
                # clipping after every micro-batch distorts the accumulated sum.
                nn.utils.clip_grad_norm_(model.parameters(), opts.max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            train_totals += [loss.item() * batch_size,
                             ce_loss.item() * batch_size,
                             crf_loss.item() * batch_size]
            train_samples += batch_size

        except RuntimeError as exception:
            _handle_runtime_error(exception)
            # A failed forward/backward can leave partial gradients behind.
            optimizer.zero_grad()

        # In-epoch log + checkpoint + dev evaluation at the indices in `checkpoint`.
        if global_step in checkpoint:
            time_tracker.append(time.time())
            cur_avg_loss, cur_avg_celoss, cur_avg_crfloss = train_totals / max(train_samples, 1)
            _log("{} | Batch {:d}/{:d} | Mean Loss {:5.5f} | Mean CE Loss {:5.5f} | Mean CRF Loss {:5.5f} | time cost {:d} s"
                 .format('TRAIN', global_step, len(train_iter), cur_avg_loss, cur_avg_celoss,
                         cur_avg_crfloss, int(time_tracker[-1] - time_tracker[-2])))

            now_percent = checkpoint.index(global_step) + 1
            ckpt = {
                "net": model.state_dict(),
                'optimizer': optimizer.state_dict(),
                "epoch": epoch,
            }
            # NOTE(review): named after epoch - 1 (relative to the last finished epoch,
            # so epoch 0 produces "epoch_-1_k"); kept for compatibility, confirm intent.
            torch.save(ckpt, experiment_dir / 'epoch_{}_{}.ckpt'.format(epoch - 1, now_percent))

            mean_loss, mean_celoss, mean_crfloss, total_time = _evaluate()
            _log("{}   | Batch {:d}/{:d} | Mean Loss {:5.5f} | Mean CE Loss {:5.5f} | Mean CRF Loss {:5.5f} | Total time cost {:d} s"
                 .format('DEV', global_step, len(train_iter), mean_loss, mean_celoss, mean_crfloss, int(total_time)))

            early_stopping(mean_loss, model, epoch)
            if early_stopping.early_stop:
                print("Early stopping")
                break

        pbar_train.update(1)

    # Epoch summary built from this epoch's *training* losses (not the dev ones).
    total_time = time.time() - time_tracker[0]
    mean_loss, mean_celoss, mean_crfloss = train_totals / max(train_samples, 1)
    _log("{} | Epoch {:d}/{:d} | Mean Loss {:5.5f} | Mean CE Loss {:5.5f} | Mean CRF Loss {:5.5f} | Total time cost {:d} s"
         .format('TRAIN', epoch, opts.num_epochs, mean_loss, mean_celoss, mean_crfloss, int(total_time)))

    if early_stopping.early_stop:
        # Triggered by an in-epoch evaluation: stop now rather than evaluating
        # again and calling early_stopping a second time.
        break

    mean_loss, mean_celoss, mean_crfloss, total_time = _evaluate()
    _log("{}   | Epoch {:d}/{:d} | Mean Loss {:5.5f}  | Mean CE Loss {:5.5f}  | Mean CRF Loss {:5.5f}  | Total time cost {:d} s"
         .format('DEV', epoch, opts.num_epochs, mean_loss, mean_celoss, mean_crfloss, int(total_time)))

    early_stopping(mean_loss, model, epoch)
    if early_stopping.early_stop:
        print("Early stopping")
        break

    ckpt = {
        "net": model.state_dict(),
        'optimizer': optimizer.state_dict(),
        "epoch": epoch,
    }
    torch.save(ckpt, experiment_dir / 'epoch_{}.ckpt'.format(epoch))

    print("=" * 50)