# preprocess crawler data

In [1]:
from tqdm import tqdm_notebook, tnrange
from pathlib import Path
import json
import re

In [2]:
cna_path = "/mnt/disk3/m10615110/crawler_data/cna_news/"
cts_path = "/mnt/disk3/m10615110/crawler_data/cts_news/"
udn_path = "/mnt/disk3/m10615110/crawler_data/udn_news/"
ptt_path = "/mnt/disk3/m10615110/crawler_data/ptt_gossiping/"

## cna data

In [3]:
file_list = list(Path(cna_path).glob('*.json'))
cna_news = {}

for i in tnrange(len(file_list)):
    
    if 'focus' in file_list[i].stem:
        with open(file_list[i], 'r') as f:
            data = json.load(f)

        for title, article in data.items():
            right_gua = article.find('）')
            article = article[right_gua+1:]
            article = article[::-1]
            last_period = article.find('。')
            article = article[last_period:]
            article = article[::-1]
            if title not in cna_news:
                cna_news[title] = article
                
    if 'news' in file_list[i].stem:
        with open(file_list[i], 'r') as f:
            data = json.load(f)
            
        for cate, articles in data.items():
            for title, article in articles.items():
                right_gua = article.find('）')
                article = article[right_gua+1:]
                article = article[::-1]
                last_period = article.find('。')
                article = article[last_period:]
                article = article[::-1]
                if title not in cna_news:
                    cna_news[title] = article
    
print('total {} news'.format(len(cna_news)))

HBox(children=(IntProgress(value=0, max=2506), HTML(value='')))


total 76537 news


## cts data

In [4]:
file_list = list(Path(cts_path).glob('*.json'))
cts_news = {}

for i in tnrange(len(file_list)):
    
    if 'hots' in file_list[i].stem:
        try:
            with open(file_list[i], 'r') as f:
                data = json.load(f)
            for title, article in data.items():
                if title not in cts_news:
                    cts_news[title] = article
        except:
            pass
                
    if 'news' in file_list[i].stem:
        try:
            with open(file_list[i], 'r') as f:
                data = json.load(f)
            for cate, articles in data.items():
                for title, article in articles.items():
                    first_slide = article.find('/')
                    article = article[first_slide+1:].strip()
                    if title not in cts_news:
                        cts_news[title] = article
        except:
            pass
        
print('total {} news'.format(len(cts_news)))

HBox(children=(IntProgress(value=0, max=2447), HTML(value='')))


total 29114 news


## udn data

In [5]:
file_list = list(Path(udn_path).glob('*.json'))
udn_news = {}

for i in tnrange(len(file_list)):
    
    if 'news' in file_list[i].stem:
        try:
            with open(file_list[i], 'r') as f:
                data = json.load(f)
            for cate, articles in data.items():
                for title, article in articles.items():
                    first_slide = title.find('／')
                    title = title[first_slide+1:] if first_slide != -1 else title
                    right_gua = title.find('】')
                    title = title[right_gua+1:] if right_gua != -1 else title
                    
                    if title not in udn_news:
                        udn_news[title] = article
                
        except:
            pass
        
print('total {} news'.format(len(udn_news)))

HBox(children=(IntProgress(value=0, max=785), HTML(value='')))


total 117625 news


## ptt data

In [6]:
file_list = list(Path(ptt_path).glob('*.json'))
ptt_articles = {}

for i in tnrange(len(file_list)):
    
    try:
        with open(file_list[0], 'r') as f:
            data = json.load(f)
        for article_id, article in data.items():
            title = article['article_title']
            content = article['content']
            content = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '', content, flags=re.MULTILINE)
            messages = ''
            for message in article['messages']:
                message_content = message['push_content']
                messages += message_content
                messages += '。'
                
            if title not in ptt_articles:
                ptt_articles[title] = [content, messages]
                
            
    except:
        pass

HBox(children=(IntProgress(value=0, max=526), HTML(value='')))




## simply combine all data

In [7]:
total_data = []
for title, article in cna_news.items():
    total_data.append(title)
    total_data.append(article)
    
for title, article in cts_news.items():
    total_data.append(title)
    total_data.append(article)
    
for title, article in udn_news.items():
    total_data.append(title)
    total_data.append(article) 
    
# for title, [content, messages] in ptt_articles.items():
#     total_data.append(title)
#     total_data.append(content)
#     total_data.append(messages)

print(len(total_data))

446552


In [8]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import transformers
from transformers import GPT2Config, GPT2Tokenizer, GPT2Model, GPT2LMHeadModel
from transformers import BertTokenizer, BertTokenizerFast

import numpy as np
from datetime import datetime
import os
import time

I0811 02:44:51.347485 139809267230528 file_utils.py:41] PyTorch version 1.2.0 available.


In [9]:
exp_dir = "/mnt/disk3/m10615110/gpt2_chinese/exp/"

bert_tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')

I0811 02:44:54.059484 139809267230528 tokenization_utils.py:979] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt from cache at /home/m10615110/.cache/torch/transformers/8a0c070123c1f794c42a29c6904beb7c1b8715741e235bee04aca2c7636fc83f.9b42061518a39ca00b8b52059fd2bede8daa613f8a8671500e518a8c29de8c00


In [10]:
class AttrDict(dict):
    """ Access dictionary keys like attribute 
        https://stackoverflow.com/questions/4984647/accessing-dict-keys-like-an-attribute
    """
    def __init__(self, *av, **kav):
        dict.__init__(self, *av, **kav)
        self.__dict__ = self

opts = AttrDict()

# Configure models
opts.vocab_size = bert_tokenizer.vocab_size
opts.emb = 768


# Configure optimization
opts.learning_rate = 1.5e-4
opts.bert_lr = 5e-6
opts.weight_decay = 0.01 # L2 weight regularization
opts.max_grad_norm = 1.0

opts.batch_size = 4

# Configure training
opts.max_seq_len = 512
opts.num_epochs = 300
opts.warmup_steps = 4000
opts.gradient_accumulation = 20

opts.load_pretrain = True

In [11]:
class TextDataset():
    def __init__(self, total_data):
        
        print('='*50)
        print('Dataset preprocessing log:')
        self.sents = total_data
        print('- Number of sentences: {}'.format(len(self.sents)))
        count = 0
        for sent in self.sents:
            count += len(sent)
        print('- Number of words: {}'.format(count))
        
        
    def __len__(self):
        return len(self.sents)
    
    def __getitem__(self, index):
        sent = self.sents[index]
        
        return sent
    
def collate_fn(data):
    
    def _pad_sequences(seqs):
        lens = [len(seq)-1 for seq in seqs]
        input_seqs = torch.zeros(len(seqs), max(lens)).long()
        target_seqs = torch.zeros(len(seqs), max(lens)).long()
        for i, seq in enumerate(seqs):
            input_seqs[i, :len(seq)-1] = torch.LongTensor(seq[:-1])
            target_seqs[i, :len(seq)-1] = torch.LongTensor(seq[1:])
            
        return input_seqs, target_seqs, lens
    
    def bert_tokenize(tokenizer, article, max_length=1024):
    
        sents = re.split('。|，| ', article)
        while '' in sents:
            sents.remove('')

        bert_sent = '[CLS]'
        for sent in sents:
            bert_sent += sent
            bert_sent += '[SEP]'

        tokens = tokenizer.tokenize(bert_sent)

        truncat_tokens = []
        if len(tokens) < max_length:
            truncat_tokens.append(tokens)
        else:
            truncat_tokens = []
            while len(tokens) > max_length:
                truncat_tokens.append(tokens[:max_length])
                tokens = tokens[max_length:]
            truncat_tokens.append(tokens)

        return truncat_tokens
    
    sents = data
    
    bert_tokens = []
    for sent in sents:
        tokens = bert_tokenize(bert_tokenizer, sent, max_length=opts.max_seq_len)
        bert_tokens.append(tokens[0])
        
    bert_idxs = []
    for bert_token in bert_tokens:
        idxs = bert_tokenizer.convert_tokens_to_ids(bert_token)
        bert_idxs.append(idxs)
        
    input_seqs, target_seqs, lens = _pad_sequences(bert_idxs)
    
    return sents, bert_tokens, bert_idxs, input_seqs, target_seqs, lens

In [12]:
random_seed = 202000810

In [13]:
from sklearn.model_selection import train_test_split

train_data, dev_data = train_test_split(total_data, test_size=0.05, random_state=random_seed, shuffle=True)

print(len(train_data), len(dev_data))

424224 22328


In [14]:
train_dataset = TextDataset(train_data)
dev_dataset = TextDataset(dev_data)

Dataset preprocessing log:
- Number of sentences: 424224
- Number of words: 132709780
Dataset preprocessing log:
- Number of sentences: 22328
- Number of words: 7098516


In [15]:
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)

train_iter = DataLoader(dataset=train_dataset,
                        batch_size=opts.batch_size,
                        shuffle=True,
                        num_workers=16,
#                         sampler=train_sampler,
                        collate_fn=collate_fn)

dev_iter = DataLoader(dataset=dev_dataset,
                        batch_size=2,
                        shuffle=False,
                        num_workers=16,
#                         sampler=train_sampler,
                        collate_fn=collate_fn)

In [16]:
if opts.load_pretrain:
    model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
    model.resize_token_embeddings(bert_tokenizer.vocab_size)
#     model.half()
else:
    gpt2_config = GPT2Config()
    model = GPT2LMHeadModel(config=gpt2_config)
    model.resize_token_embeddings(bert_tokenizer.vocab_size)
#     model.half()

print('total parms : ', sum(p.numel() for p in model.parameters()))
print('trainable parms : ', sum(p.numel() for p in model.parameters() if p.requires_grad))

I0811 02:44:56.785063 139809267230528 configuration_utils.py:286] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json from cache at /home/m10615110/.cache/torch/transformers/98aa65385e18b0efd17acd8bf64dcdf21406bb0c99c801c2d3c9f6bfd1f48f29.250d6dc755ccb17d19c7c1a7677636683aa35f0f6cb5461b3c0587bc091551a0
I0811 02:44:56.787188 139809267230528 configuration_utils.py:322] Model config GPT2Config {
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 1024,
  "n_head": 16,
  "n_layer": 24,
  "n_positions": 1024,
  "n_special": 0,
  "predict_special_tokens": true,
  "resid_pdrop": 0.1,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",


total parms :  324995072
trainable parms :  324995072


In [17]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(21128, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2):

In [18]:
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
USE_CUDA = torch.cuda.is_available()
USE_CUDA = True

In [19]:
## distribute data parallel

# dist_backend = 'nccl'
# dist_url = 'tcp://127.0.0.1:45655'
# world_size = 1
# rank = 0

# torch.distributed.init_process_group(backend=dist_backend, 
#                                      init_method=dist_url, 
#                                      world_size=world_size, 
#                                      rank=rank)


# bertlm = torch.nn.parallel.DistributedDataParallel(bertlm, find_unused_parameters=False)

# if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
# dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
model = nn.DataParallel(model)

if USE_CUDA:
    model.cuda()

Let's use 2 GPUs!


In [20]:
last_epoch = -1
model_name = 'gpt2_medium_noptt_len{}_batch_{}'.format(opts.max_seq_len, opts.batch_size)
now = str(datetime.now()).split('.')[0]
experiment_name = '{}_{}'.format(model_name, now)
experiment_dir = Path(exp_dir) / experiment_name
experiment_dir.mkdir(exist_ok=True, parents=True)
print(experiment_dir)

/mnt/disk3/m10615110/gpt2_chinese/exp/gpt2_medium_noptt_len512_batch_4_2020-08-11 02:45:14


In [21]:
def log2file(log_file, msg):
    with open(log_file, 'a') as fw:
        fw.write(msg)
        fw.write('\n')

experiment_trainlog = experiment_dir / 'train_log.txt'
experiment_devlog = experiment_dir / 'dev_log.txt'

In [22]:
print(opts.learning_rate)
print(opts.bert_lr)

optimizer = transformers.AdamW([
    {'params': model.module.parameters(), 'lr':opts.learning_rate},
], lr=opts.learning_rate)

scheduler = transformers.get_linear_schedule_with_warmup(optimizer, 
                                                         num_warmup_steps=opts.warmup_steps, 
                                                         num_training_steps=len(train_iter)*opts.num_epochs)

criterion = torch.nn.CrossEntropyLoss(ignore_index=bert_tokenizer.pad_token_id,)

0.00015
5e-06


In [23]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0.001, exp_dir=''):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.best_epoch = 0
        self.exp_dir=Path(exp_dir)

    def __call__(self, val_loss, model, epoch):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch)
        elif score < self.best_score:
#         elif score < self.best_score or score < self.best_score * (1-self.delta):
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
                msg = 'best epoch : {}'.format(self.best_epoch)
                print(msg)
                log2file(self.exp_dir / 'train_log.txt', msg)
                (self.exp_dir / 'best_model').symlink_to(self.exp_dir / 'epoch_{}.mdl'.format(self.best_epoch))
        else:
            self.best_score = score
            self.best_epoch = epoch
            self.save_checkpoint(val_loss, model, epoch)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, epoch):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            msg = f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f})'
            print(msg)
            log2file(self.exp_dir / 'train_log.txt', msg)
#         torch.save(model.state_dict(), self.exp_dir / 'checkpoint.pt')
        self.val_loss_min = val_loss
        

In [24]:
experiment_dir

PosixPath('/mnt/disk3/m10615110/gpt2_chinese/exp/gpt2_medium_noptt_len512_batch_4_2020-08-11 02:45:14')

In [None]:
early_stopping = EarlyStopping(patience=50, verbose=True, exp_dir=str(experiment_dir))

for k,v in opts.items():
    log_msg = '- {}: {}'.format(k, v)
    log2file(str(experiment_trainlog), log_msg)
    print(log_msg)
    
pbar_train = tqdm_notebook(total=len(train_iter))
pbar_dev = tqdm_notebook(total=len(dev_iter))
    
log_msg = '='*50
print(log_msg)
log2file(str(experiment_trainlog), log_msg)
log_msg = 'optim : \n' + str(optimizer)
print(log_msg)   
log2file(str(experiment_trainlog), log_msg)


s = 10
checkpoint = [int(len(train_iter)/s*i) for i in range(1, s)]

oom_time = 0

print('check point : ', checkpoint)

for epoch in range(last_epoch+1,  opts.num_epochs, 1):
    
    pbar_train.reset()
    pbar_dev.reset()
    
    log_msg = '='*50
    print(log_msg)
    log2file(str(experiment_trainlog), log_msg)
    loss_tracker = []
    time_tracker = []
    time_tracker.append(time.time())
    
    global_step = 0
    
    
    for iteration, batch in enumerate(train_iter):
        
        sents, bert_tokens, bert_idxs, input_seqs, target_seqs, lens = batch
        
        batch_size = input_seqs.size(0)
        assert(batch_size == target_seqs.size(0))
        
        if USE_CUDA:
            input_seqs = input_seqs.cuda()
    #         lens = lens.cuda()
            target_seqs = target_seqs.cuda()
        
        model.train()
        
        try:

            outputs = model(input_seqs)

            loss = criterion(outputs[0].view(outputs[0].size(0)*outputs[0].size(1), -1), 
                             target_seqs.view(target_seqs.size(0)*target_seqs.size(1)))

            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), opts.max_grad_norm)

            if (iteration + 1) % opts.gradient_accumulation == 0 or iteration == len(train_iter)-1:

                optimizer.step()

                optimizer.zero_grad()

                scheduler.step()

            loss_tracker.append(loss.item()*batch_size)

#             torch.cuda.empty_cache()
            
        except RuntimeError as exception:
            
            if "out of memory" in str(exception):
                oom_time += 1
                log_msg = "WARNING: ran out of memory,times: {}".format(oom_time)
                print(log_msg)   
                log2file(str(experiment_trainlog), log_msg)
                torch.cuda.empty_cache()
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()
            else:
                log_msg = str(exception)
                print(log_msg)   
                log2file(str(experiment_trainlog), log_msg)
                raise exception
                
        #=================================
        
        if global_step in checkpoint:
            
            now_time = time.time()
            time_tracker.append(time.time())
            cur_avg_loss = np.sum(np.array(loss_tracker)) / (global_step * opts.batch_size)
            log_msg = "{} | Batch {:d}/{:d} | Mean Loss {:5.5f} | time cost {:d} s"  \
                    .format('train'.upper(), global_step, len(train_iter), cur_avg_loss, int(time_tracker[-1] - time_tracker[-2]))
            print(log_msg)
            log2file(str(experiment_trainlog), log_msg)
            now_percent = checkpoint.index(global_step)+1
            torch.save(model.state_dict(), experiment_dir / 'epoch_{}_{}.mdl'.format(epoch-1, now_percent))
            
            loss_tracker = []

            start = time.time()

            pbar_dev.reset()

            for iteration, batch in enumerate(dev_iter):

                sents, bert_tokens, bert_idxs, input_seqs, target_seqs, lens = batch

                batch_size = input_seqs.size(0)
                assert(batch_size == target_seqs.size(0))

                if USE_CUDA:
                    input_seqs = input_seqs.cuda()
            #         lens = lens.cuda()
                    target_seqs = target_seqs.cuda()

                model.eval()

                try:

                    outputs = model(input_seqs)

                    loss = criterion(outputs[0].view(outputs[0].size(0)*outputs[0].size(1), -1), 
                                     target_seqs.view(target_seqs.size(0)*target_seqs.size(1)))

                    loss_tracker.append(loss.item()*batch_size)

                except RuntimeError as exception:

                    if "out of memory" in str(exception):
                        oom_time += 1
                        log_msg = "WARNING: ran out of memory,times: {}".format(oom_time)
                        print(log_msg)   
                        log2file(str(experiment_trainlog), log_msg)
                        torch.cuda.empty_cache()
                        if hasattr(torch.cuda, 'empty_cache'):
                            torch.cuda.empty_cache()
                    else:
                        log_msg = str(exception)
                        print(log_msg)   
                        log2file(str(experiment_trainlog), log_msg)
                        raise exception
#                 torch.cuda.empty_cache()
                        
                pbar_dev.update(1)


            total_time = time.time() - start

            mean_loss = np.sum(np.array(loss_tracker)) / dev_dataset.__len__()
            log_msg = "{}   | Batch {:d}/{:d} | Mean Loss {:5.5f} | Total time cost {:d} s"  \
                .format('dev'.upper(), global_step, len(train_iter), mean_loss, int(total_time))
            print(log_msg)
            log2file(str(experiment_trainlog), log_msg)

            val_loss = mean_loss

            early_stopping(val_loss, model, epoch)

            if early_stopping.early_stop:
                print("Early stopping")
                break
        
        global_step += 1
        pbar_train.update(1)

    
    total_time = time.time() - time_tracker[0]    
    mean_loss = np.sum(np.array(loss_tracker)) / train_dataset.__len__()
    log_msg = "{} | Epoch {:d}/{:d} | Mean Loss {:5.5f} | Total time cost {:d} s"  \
        .format('train'.upper(), epoch, opts.num_epochs, mean_loss, int(total_time))
    print(log_msg)
    log2file(str(experiment_trainlog), log_msg)

    #-----------------------

    loss_tracker = []

    start = time.time()
    
    pbar_dev.reset()

    for iteration, batch in enumerate(dev_iter):

        sents, bert_tokens, bert_idxs, input_seqs, target_seqs, lens = batch
        
        batch_size = input_seqs.size(0)
        assert(batch_size == target_seqs.size(0))
        
        if USE_CUDA:
            input_seqs = input_seqs.cuda()
    #         lens = lens.cuda()
            target_seqs = target_seqs.cuda()
        
        
        model.eval()
        
        try:
        
            outputs = model(input_seqs)

            loss = criterion(outputs[0].view(outputs[0].size(0)*outputs[0].size(1), -1), 
                             target_seqs.view(target_seqs.size(0)*target_seqs.size(1)))

            loss_tracker.append(loss.item()*batch_size)
            
        except RuntimeError as exception:
            
            if "out of memory" in str(exception):
                oom_time += 1
                log_msg = "WARNING: ran out of memory,times: {}".format(oom_time)
                print(log_msg)   
                log2file(str(experiment_trainlog), log_msg)
                torch.cuda.empty_cache()
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()
            else:
                log_msg = str(exception)
                print(log_msg)   
                log2file(str(experiment_trainlog), log_msg)
                raise exception
                
#         torch.cuda.empty_cache()
        pbar_dev.update(1)

    total_time = time.time() - start

    mean_loss = np.sum(np.array(loss_tracker)) / dev_dataset.__len__()
    log_msg = "{}   | Epoch {:d}/{:d} | Mean Loss {:5.5f}  | Total time cost {:d} s"  \
        .format('dev'.upper(), epoch, opts.num_epochs, mean_loss, int(total_time))
    print(log_msg)
    log2file(str(experiment_trainlog), log_msg)
    
    val_loss = mean_loss
    
    early_stopping(val_loss, model, epoch)

    if early_stopping.early_stop:
        print("Early stopping")
        break
        
    torch.save(model.state_dict(), experiment_dir / 'epoch_{}.mdl'.format(epoch))

    print("="*50)

- vocab_size: 21128
- emb: 768
- learning_rate: 0.00015
- bert_lr: 5e-06
- weight_decay: 0.01
- max_grad_norm: 1.0
- batch_size: 4
- max_seq_len: 512
- num_epochs: 300
- warmup_steps: 4000
- gradient_accumulation: 20
- load_pretrain: True


HBox(children=(IntProgress(value=0, max=106056), HTML(value='')))

HBox(children=(IntProgress(value=0, max=11164), HTML(value='')))

optim : 
AdamW (
Parameter Group 0
    betas: (0.9, 0.999)
    correct_bias: True
    eps: 1e-06
    initial_lr: 0.00015
    lr: 0.0
    weight_decay: 0.0
)
check point :  [10605, 21211, 31816, 42422, 53028, 63633, 74239, 84844, 95450]
TRAIN | Batch 10605/106056 | Mean Loss 7.35276 | time cost 8167 s
DEV   | Batch 10605/106056 | Mean Loss 6.12653 | Total time cost 3420 s
Validation loss decreased (inf --> 6.126530)
TRAIN | Batch 21211/106056 | Mean Loss 4.49267 | time cost 11608 s
DEV   | Batch 21211/106056 | Mean Loss 5.66984 | Total time cost 3425 s
Validation loss decreased (6.126530 --> 5.669838)
TRAIN | Batch 31816/106056 | Mean Loss 2.75569 | time cost 11589 s
DEV   | Batch 31816/106056 | Mean Loss 5.21175 | Total time cost 3423 s
Validation loss decreased (5.669838 --> 5.211749)
TRAIN | Batch 42422/106056 | Mean Loss 1.90314 | time cost 11616 s
DEV   | Batch 42422/106056 | Mean Loss 4.86590 | Total time cost 3429 s
Validation loss decreased (5.211749 --> 4.865902)
TRAIN | Batch 

TRAIN | Batch 10605/106056 | Mean Loss 2.82035 | time cost 8177 s
DEV   | Batch 10605/106056 | Mean Loss 2.95687 | Total time cost 3427 s
Validation loss decreased (2.970645 --> 2.956875)
TRAIN | Batch 21211/106056 | Mean Loss 2.18348 | time cost 11586 s
DEV   | Batch 21211/106056 | Mean Loss 2.94529 | Total time cost 3438 s
Validation loss decreased (2.956875 --> 2.945290)
TRAIN | Batch 31816/106056 | Mean Loss 1.45058 | time cost 11615 s
DEV   | Batch 31816/106056 | Mean Loss 2.93581 | Total time cost 3427 s
Validation loss decreased (2.945290 --> 2.935807)
TRAIN | Batch 42422/106056 | Mean Loss 1.08507 | time cost 11577 s
DEV   | Batch 42422/106056 | Mean Loss 2.92653 | Total time cost 3425 s
Validation loss decreased (2.935807 --> 2.926532)
TRAIN | Batch 53028/106056 | Mean Loss 0.86466 | time cost 11602 s
DEV   | Batch 53028/106056 | Mean Loss 2.91725 | Total time cost 3423 s
Validation loss decreased (2.926532 --> 2.917250)
TRAIN | Batch 63633/106056 | Mean Loss 0.71969 | time co

TRAIN | Batch 31816/106056 | Mean Loss 1.30705 | time cost 11543 s
DEV   | Batch 31816/106056 | Mean Loss 2.68171 | Total time cost 3460 s
Validation loss decreased (2.683124 --> 2.681710)
TRAIN | Batch 42422/106056 | Mean Loss 0.98026 | time cost 11629 s
DEV   | Batch 42422/106056 | Mean Loss 2.67883 | Total time cost 3463 s
Validation loss decreased (2.681710 --> 2.678827)
TRAIN | Batch 53028/106056 | Mean Loss 0.78209 | time cost 11728 s
DEV   | Batch 53028/106056 | Mean Loss 2.67230 | Total time cost 3469 s
Validation loss decreased (2.678827 --> 2.672295)
TRAIN | Batch 63633/106056 | Mean Loss 0.65063 | time cost 11741 s
DEV   | Batch 63633/106056 | Mean Loss 2.66769 | Total time cost 3476 s
Validation loss decreased (2.672295 --> 2.667690)
TRAIN | Batch 74239/106056 | Mean Loss 0.55713 | time cost 11742 s
DEV   | Batch 74239/106056 | Mean Loss 2.66190 | Total time cost 3472 s
Validation loss decreased (2.667690 --> 2.661897)
TRAIN | Batch 84844/106056 | Mean Loss 0.48702 | time c