In [1]:
#!/usr/bin/env python
"""
    Main training workflow
"""
from __future__ import division

import argparse
import glob
import os
import sys
import random
import signal
import time

import torch
from pytorch_pretrained_bert import BertConfig

import distributed
from models import data_loader, model_builder
from models.data_loader import load_dataset
from models.model_builder import Summarizer
from models.trainer import build_trainer
from others.logging import logger, init_logger

model_flags = ['hidden_size', 'ff_size', 'heads', 'inter_layers','encoder','ff_actv', 'use_interval','rnn_size']


In [2]:

def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

## command from BertSum github

```python
python train.py 
    -mode train 
    -encoder classifier 
    -dropout 0.1 
    -bert_data_path ../bert_data/cnndm 
    -model_path ../models/bert_classifier 
    -lr 2e-3 
    -visible_gpus 0,1,2  # microfocus serve have 2 tesla gpu, should change from 0,1,2 to 0,1
    -gpu_ranks 0,1,2     # microfocus serve have 2 tesla gpu, should change from 0,1,2 to 0,1
    -world_size 3        # microfocus serve have 2 tesla gpu, should change from 3 to 2
    -report_every 50 
    -save_checkpoint_steps 1000 
    -batch_size 3000 
    -decay_method noam 
    -train_steps 50000 
    -accum_count 2 
    -log_file ../logs/bert_classifier 
    -use_interval true 
    -warmup_steps 10000
```

In [81]:
!$sys.executable -m python3 python train.py -mode train -encoder classifier -dropout 0.1 \
-bert_data_path ../bert_data/cnndm \
-model_path ../models/bert_classifier \
-lr 2e-3 \
-visible_gpus 0,1 \
-gpu_ranks 0,1 \
-world_size 2 \
-report_every 50 \
-save_checkpoint_steps 1000 \
-batch_size 3000 \
-decay_method noam \
-train_steps 50000 \
-accum_count 2 \
-log_file ../logs/bert_classifier \
-use_interval true \
-warmup_steps 10000\

/usr/bin/python3: No module named python3


In [73]:
#if __name__ == '__main__':
parser = argparse.ArgumentParser()

parser.add_argument("-encoder", default='classifier', type=str, choices=['classifier','transformer','rnn','baseline'])
parser.add_argument("-mode", default='train', type=str, choices=['train','validate','test'])
parser.add_argument("-bert_data_path", default='../bert_data/cnndm')
parser.add_argument("-model_path", default='../models/')
parser.add_argument("-result_path", default='../results/cnndm')
parser.add_argument("-temp_dir", default='../temp')
parser.add_argument("-bert_config_path", default='../bert_config_uncased_base.json')

parser.add_argument("-batch_size", default=1000, type=int)

parser.add_argument("-use_interval", type=str2bool, nargs='?',const=True,default=True)
parser.add_argument("-hidden_size", default=128, type=int)
parser.add_argument("-ff_size", default=512, type=int)
parser.add_argument("-heads", default=4, type=int)
parser.add_argument("-inter_layers", default=2, type=int)
parser.add_argument("-rnn_size", default=512, type=int)

parser.add_argument("-param_init", default=0, type=float)
parser.add_argument("-param_init_glorot", type=str2bool, nargs='?',const=True,default=True)
parser.add_argument("-dropout", default=0.1, type=float)
parser.add_argument("-optim", default='adam', type=str)
parser.add_argument("-lr", default=1, type=float)
parser.add_argument("-beta1", default= 0.9, type=float)
parser.add_argument("-beta2", default=0.999, type=float)
parser.add_argument("-decay_method", default='', type=str)
parser.add_argument("-warmup_steps", default=8000, type=int)
parser.add_argument("-max_grad_norm", default=0, type=float)

parser.add_argument("-save_checkpoint_steps", default=5, type=int)
parser.add_argument("-accum_count", default=1, type=int)
parser.add_argument("-world_size", default=1, type=int)
parser.add_argument("-report_every", default=1, type=int)
parser.add_argument("-train_steps", default=1000, type=int)
parser.add_argument("-recall_eval", type=str2bool, nargs='?',const=True,default=False)


parser.add_argument('-visible_gpus', default='-1', type=str)
parser.add_argument('-gpu_ranks', default='0', type=str)
parser.add_argument('-log_file', default='../logs/cnndm.log')
parser.add_argument('-dataset', default='')
parser.add_argument('-seed', default=666, type=int)

parser.add_argument("-test_all", type=str2bool, nargs='?',const=True,default=False)
parser.add_argument("-test_from", default='')
parser.add_argument("-train_from", default='')
parser.add_argument("-report_rouge", type=str2bool, nargs='?',const=True,default=True)
parser.add_argument("-block_trigram", type=str2bool, nargs='?', const=True, default=True)

args = parser.parse_args()
# ipython notebook import parser issue, solution and root cause
# https://blog.csdn.net/u012869752/article/details/72513141
#args = parser.parse_args(args=[])

#args.gpu_ranks = [int(i) for i in args.gpu_ranks.split(',')]
args.gpu_ranks = [int(i) for i in args.gpu_ranks.split(',')]
os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpus

init_logger(args.log_file)
device = "cpu" if args.visible_gpus == '-1' else "cuda"
device_id = 0 if device == "cuda" else -1



_StoreAction(option_strings=['-block_trigram'], dest='block_trigram', nargs='?', const=True, default=True, type=<function str2bool at 0x7f85cdd88730>, choices=None, help=None, metavar=None)

# decide which component should be executed
# this part could be choice manually 

```python
if(args.world_size>1):
    multi_main(args)
elif (args.mode == 'train'):
    train(args, device_id)
elif (args.mode == 'validate'):
    wait_and_validate(args, device_id)
elif (args.mode == 'lead'):
    baseline(args, cal_lead=True)
elif (args.mode == 'oracle'):
    baseline(args, cal_oracle=True)
elif (args.mode == 'test'):
    cp = args.test_from
    try:
        step = int(cp.split('.')[-2].split('_')[-1])
    except:
        step = 0
    test(args, device_id, cp, step)
```

In [None]:
def train(args, device_id):
    init_logger(args.log_file)

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)


    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    def train_iter_fct():
        return data_loader.Dataloader(args, load_dataset(args, 'train', shuffle=True), args.batch_size, device,
                                                 shuffle=True, is_test=False)

    model = Summarizer(args, device, load_pretrained_bert=True)
    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
        model.load_cp(checkpoint)
        optim = model_builder.build_optim(args, model, checkpoint)
    else:
        optim = model_builder.build_optim(args, model, None)

    logger.info(model)
    trainer = build_trainer(args, device_id, model, optim)
    trainer.train(train_iter_fct, args.train_steps)

In [83]:
device_id

0

In [82]:
def train(args, device_id):
    print(device_id)
    print(args)

In [84]:
train(args, device_id)

0
Namespace(accum_count=2, batch_size=3000, bert_config_path='../bert_config_uncased_base.json', bert_data_path='../bert_data/cnndm', beta1=0.9, beta2=0.999, block_trigram=True, dataset='', decay_method='noam', dropout=0.1, encoder='classifier', ff_size=512, gpu_ranks=[0, 1, 2], heads=4, hidden_size=128, inter_layers=2, log_file='../logs/bert_classifier ', lr=0.002, max_grad_norm=0, mode='train', model_path='../models/bert_classifier', optim='adam', param_init=0, param_init_glorot=True, recall_eval=False, report_every=50, report_rouge=True, result_path='../results/cnndm', rnn_size=512, save_checkpoint_steps=1000, seed=666, temp_dir='../temp', test_all=False, test_from='', train_from='', train_steps=50000, use_interval=True, visible_gpus='0,1,2', warmup_steps=10000, world_size=3)


In [85]:
init_logger(args.log_file)

<RootLogger root (INFO)>