In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import tensorflow as tf
from functional import seq
from tensorpack import (TrainConfig, SyncMultiGPUTrainerParameterServer as Trainer,
                        PredictConfig, MultiProcessDatasetPredictor as Predictor,
                        SaverRestore, logger)
from tensorpack.callbacks import (ScheduledHyperParamSetter, MaxSaver, ModelSaver,
                                  DataParallelInferenceRunner as InfRunner)
from tensorpack.predict import SimpleDatasetPredictor
from tensorpack.tfutils.common import get_default_sess_config

from config.rnn import default
from models import RNN, RNNV2
from utils import DataManager
from utils.validation import (Accumulator, AggregateMetric, calcu_metrics)

# tensorpack log directory for this run; set before any training starts.
# NOTE(review): if this directory already exists, tensorpack's
# set_logger_dir may ask interactively what to do with it -- confirm this
# is acceptable for an unattended "Restart & Run All".
log_dir = './train_log/rnn_v2/'
logger.set_logger_dir(log_dir)

# Checkpoint used to initialise the session (SaverRestore in the training
# cell). Presumably a pretrained ResNet-v2-101, going by the path.
resnet_loc = "./data/resnet_v2_101/resnet_v2_101.ckpt"

In [None]:
# File name MaxSaver uses when the monitored 'micro_auc' stat hits a new max.
save_name = "all-stages-max-micro-auc.tfmodel"
# Variable names SaverRestore skips when loading the initial checkpoint.
ignore_restore = ['learning_rate', 'global_step']
# NOTE(review): this binds the imported `default` object itself, not a copy,
# so every `config.<attr> = ...` in later cells also modifies `default`.
config = default

In [None]:
# First split: 55% train / 45% held-out test, with no validation portion.
# The training portion is re-split into train/val further down.
config.stages = [2, 3, 4, 5, 6]
config.proportion = dict(train=0.55, val=0.0, test=0.45)
config.annotation_number = None
dm = DataManager.from_config(config)
train_set, test_set = dm.get_train_set(), dm.get_test_set()

In [None]:
# Second split: divide the earlier training portion 80% train / 20% val.
# The held-out test_set from the first split is passed along as-is.
config.proportion = dict(train=0.8, val=0.2, test=0.0)
config.annotation_number = 30
dm = DataManager.from_dataset(train_set, test_set, config)

In [None]:
# Hyper-parameter overrides applied on top of `config` for this run.
config.weight_decay = 0.0
config.dropout_keep_prob = 0.5
config.gamma = 2
config.use_glimpse = True
config.read_time = 5
config.batch_size = 64

# Decision threshold handed to AggregateMetric during validation.
threshold = 0.4

# Dataflows for training and validation, built from the 80/20 re-split.
train_data = dm.get_train_stream()
val_data = dm.get_validation_stream()

# RNNV2 is imported from `models` in the import cell.
model = RNNV2(config, is_finetuning=False)

In [None]:
# GPU ids shared by the training towers and the data-parallel inference
# runner; kept in one place so the two lists cannot drift apart.
gpus = [0, 1]

# Start from a clean graph; the trainer builds the model graph itself.
tf.reset_default_graph()
train_config = TrainConfig(
    model=model,
    dataflow=train_data,
    callbacks=[
        # Schedule of (epoch, value): lr 1e-4 from epoch 0, 1e-5 from epoch 15.
        ScheduledHyperParamSetter('learning_rate', [(0, 1e-4), (15, 1e-5)]),
        # Validation metrics on val_data, evaluated across the same GPUs.
        InfRunner(val_data, [AggregateMetric(config.validation_metrics, threshold)], gpus),
        ModelSaver(var_collections='model_variables'),
        # Save to `save_name` whenever 'micro_auc' reaches a new maximum.
        MaxSaver('micro_auc', save_name),
    ],
    # Initialise from resnet_loc, skipping the variables in ignore_restore.
    session_init=SaverRestore(model_path=resnet_loc, ignore=ignore_restore),
    max_epoch=20,
    tower=gpus,
)
Trainer(train_config).train()