In [None]:
from torch.utils.data.sampler import Sampler, SequentialSampler, RandomSampler
from torch.utils.data.dataloader import DataLoader
from transformers import BertTokenizer,AdamW, get_linear_schedule_with_warmup
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn
import os
import time

from train_utils import get_torch_device, ModelSave, EarlyStop
from dataset import SeqLabelDataset
from model import BertSoftmax
from metric import seq_tag_metrics, multi_cls_log
# Expose only GPU 0 to this process. The variable is read when CUDA is first initialised,
# so it must be set before any CUDA call (importing torch by itself does not initialise CUDA).
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

device = get_torch_device()  # device selection helper from train_utils

In [None]:
def load(file_name, data_dir='/data/junmo_data/BertManual/trainsample/people_daily'):
    """Load one data split stored as JSON Lines.

    Every non-blank line of ``<data_dir>/<file_name>.txt`` is decoded with
    ``json.loads``. The record schema is whatever the upstream preprocessing wrote;
    presumably it is what ``SeqLabelDataset`` expects (TODO confirm).

    Args:
        file_name: split name without extension, e.g. ``'train'`` or ``'valid'``.
        data_dir: directory containing the split files. The default is the original
            hard-coded location, so existing one-argument callers are unaffected.

    Returns:
        A list with one decoded JSON object per non-blank line, in file order.

    Raises:
        OSError: if the split file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    import json

    path = os.path.join(data_dir, '{}.txt'.format(file_name))
    # The corpus is Chinese text: decode explicitly as UTF-8 instead of the platform locale.
    with open(path, 'r', encoding='utf-8') as f:
        # Stream the file line by line and skip blank lines (e.g. a trailing newline),
        # which json.loads would otherwise reject.
        return [json.loads(line) for line in f if line.strip()]

# ---- Hyper-parameters --------------------------------------------------------
seed = 42          # seeds torch's global RNG: RandomSampler order, dropout masks, head init
batch_size = 16
max_seq_len = 150  # maximum token length handed to SeqLabelDataset
pretrain_model = 'bert-base-chinese'

# torch.manual_seed also seeds every CUDA device, so this single call makes shuffling
# and weight initialisation repeatable under Restart Kernel -> Run All.
torch.manual_seed(seed)

train_params = {
    'lr': 5e-5,
    'eps': 1e-10,
    'epoch_size': 5,  # intended number of training epochs (also sizes the LR schedule)
    'batch_size': batch_size,
    'max_seq_len': max_seq_len,
}

model_params = {
    'pretrain_model': pretrain_model,
    'loss_fn': nn.CrossEntropyLoss(),
    'dropout': 0.5,
    'label_size': 7,  # number of tag classes; must match the label ids in the data (TODO confirm)
}

# ---- Data --------------------------------------------------------------------
tokenizer = BertTokenizer.from_pretrained(pretrain_model, do_lower_case=True)

# `load` reads <split>.txt; SeqLabelDataset presumably tokenises and encodes each record.
train_dataset = SeqLabelDataset('train', max_seq_len, tokenizer, load)
valid_dataset = SeqLabelDataset('valid', max_seq_len, tokenizer, load)

# Shuffle training batches; keep the validation order fixed so metrics are comparable.
train_sampler = RandomSampler(train_dataset)
valid_sampler = SequentialSampler(valid_dataset)

train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)
valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=batch_size)

In [None]:
# Instantiate the BERT sequence-labelling model (BertSoftmax from model.py).
model = BertSoftmax(**model_params)

# Move the parameters to `device` (chosen in the first cell, presumably the visible GPU).
model.to(device)

# Create the optimizer.
# torch.optim.AdamW replaces the deprecated transformers.AdamW. weight_decay=0.0 is passed
# explicitly because torch's default (1e-2) differs from transformers' default (0.0), so the
# optimisation matches the original setup. lr and eps come from train_params; they are
# deliberately NOT the library defaults.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=train_params['lr'],
    eps=train_params['eps'],
    weight_decay=0.0,
)

# Steps per epoch and total optimizer steps. The scheduler decays the LR to 0 over
# total_train_steps, so this should equal the number of batches the training loop will run.
train_params.update({
    'num_train_steps': len(train_loader),
    'total_train_steps': len(train_loader) * train_params['epoch_size'],
})

# Linear LR schedule: no warm-up phase (num_warmup_steps=0), then linear decay from lr to 0.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=train_params['total_train_steps'],
)

In [None]:
# Checkpoints and TensorBoard event files are both written under CKPT.
CKPT = '/data/junmo_data/BertManual/checkpoint/msra'
saver = ModelSave(CKPT, continue_train=False)

global_step = 0
saver.init()

tb = SummaryWriter(CKPT)
save_steps = 100  # validate + checkpoint every save_steps batches (and on an epoch's last batch)
log_steps = 20    # log the running train loss every log_steps batches (and on the last batch)
es = EarlyStop(monitor='acc_macro', mode='max', verbose=True)
stop_training = False  # set when EarlyStop fires so the epoch loop exits too, not just the batch loop

# Run exactly epoch_size epochs (numbered from 1), matching the scheduler's total_train_steps.
for epoch_i in range(1, train_params['epoch_size'] + 1):
    # =======================================
    #               Training
    # =======================================
    # Print the header of the result table
    print(f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Elapsed':^9}")
    print("-" * 60)

    t0_epoch, t0_batch = time.time(), time.time()
    total_loss, batch_loss, batch_counts = 0, 0, 0
    last_step = len(train_loader) - 1

    model.train()
    for step, batch in enumerate(train_loader):
        global_step += 1
        batch_counts += 1
        # Unpacks the batch dict positionally; assumes its values are ordered
        # (input_ids, token_type_ids, attention_mask, label_ids) -- TODO confirm in SeqLabelDataset.
        input_ids, token_type_ids, attention_mask, label_ids = tuple(t.to(device) for t in batch.values())
        model.zero_grad()

        logits, loss = model(input_ids, token_type_ids, attention_mask, label_ids)
        batch_loss += loss.item()
        total_loss += loss.item()

        if global_step == 1:
            # add graph to tensorboard, only do it one time
            tb.add_graph(model, (input_ids, token_type_ids, attention_mask))

        loss.backward()

        # Clip the global gradient norm to 1.0 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        if (step % log_steps == 0 and step != 0) or step == last_step:
            time_elapsed = time.time() - t0_batch
            print(f"{epoch_i:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^10} | {time_elapsed:^9.2f}")
            tb.add_scalar('loss/batch_train', batch_loss / batch_counts, global_step=global_step)
            batch_loss, batch_counts = 0, 0
            t0_batch = time.time()

        if (step % save_steps == 0 and step != 0) or step == last_step:
            # step is 0-based: step + 1 batches have been summed into total_loss so far.
            avg_train_loss = total_loss / (step + 1)
            val_metrics = seq_tag_metrics(model, valid_loader, device)
            # Evaluation may leave the model in eval mode (dropout off); switch back before
            # the remaining batches of this epoch are trained.
            model.train()
            for key, val in val_metrics.items():
                tb.add_scalar(f'metric/{key}', val, global_step=global_step)
            saver(avg_train_loss, val_metrics['val_loss'], global_step, model, optimizer, scheduler)
            tb.add_scalars('loss/train_valid', {'train': avg_train_loss,
                                                'valid': val_metrics['val_loss']}, global_step=global_step)
            if es.check(val_metrics):
                stop_training = True
                break

    avg_train_loss = total_loss / (step + 1)
    print("-" * 60)
    # No second validation pass is needed. The batch loop always evaluates on its last step
    # (step == last_step) and only breaks immediately after an evaluation, so val_metrics
    # was already computed on the current weights.
    time_elapsed = time.time() - t0_epoch
    print(f"{epoch_i:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {val_metrics['val_loss']:^10.6f} | {time_elapsed:^9.2f}")
    multi_cls_log(epoch_i, val_metrics)
    print("\n")

    if stop_training:
        break

# Flush pending TensorBoard events and release the event file.
tb.close()
