In [None]:
# Enable IPython's autoreload extension so that edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload modules before each cell is executed.
%autoreload 2

In [None]:
from config.rnn import default as default_config
from models import RNN
from utils import DataManager
from utils.validation import (Accumulator, AggregateMetric, calcu_metrics)


from functional import seq
import json
from multiprocessing import Process, Pipe
from pathlib import Path
import pickle
from tqdm import tqdm
import tensorflow as tf
from tensorpack import (TrainConfig, SyncMultiGPUTrainerParameterServer as Trainer,
                        PredictConfig, SaverRestore, logger)
from tensorpack.train import launch_train_with_config
from tensorpack.predict import SimpleDatasetPredictor as Predictor
from tensorpack.callbacks import (
    ScheduledHyperParamSetter, MaxSaver, ModelSaver, DataParallelInferenceRunner as InfRunner)
from tensorpack.tfutils.common import get_default_sess_config
from tensorpack.tfutils.sesscreate import ReuseSessionCreator

In [None]:
# Checkpoint file used to initialise the session (passed to SaverRestore as
# `model_path` in the training cell).
RESNET_LOC = "data/resnet_v2_101/resnet_v2_101.ckpt"

# Directory handed to tensorpack's logger; vocabulary.json is written here too.
LOG_LOC = "train_log/train_with_all"

# Variable names passed to SaverRestore(ignore=...) in the training cell.
ignore_restore = [
    'learning_rate',
    'global_step',
    'logits/weights',
    'logits/biases',
    'hidden_fc/weights',
    'hidden_fc/biases',
]

# File name given to MaxSaver for the checkpoint tracked on 'macro_f1'.
save_name = "max-training-macro-f1.tfmodel"

logger.set_logger_dir(LOG_LOC)

In [None]:
# Start from the project's default RNN config and override the fields that
# differ for this run. NOTE(review): `config` is the same object as
# `default_config`, so these assignments modify the imported default in place.
config = default_config
config.stages = [2, 3, 4, 5, 6]
config.proportion = {
    'train': 0.8,
    'val': 0.2,
    'test': 0.0,
}  # presumably the dataset split fractions - confirm against DataManager
config.annotation_number = 30
config.batch_size = 16

# Threshold forwarded to AggregateMetric in the training cell.
threshold = 0.4

# Build the data manager from the config and request its two streams; the
# validation stream is asked for with batch_size=32.
dm = DataManager.from_config(config)
train_data = dm.get_train_stream()
validation_data = dm.get_validation_stream(batch_size=32)

# Store the vocabulary next to the training logs as a JSON list.
with open(LOG_LOC + '/vocabulary.json', 'w') as vocab_file:
    json.dump(list(dm.get_vocabulary()), vocab_file)

In [None]:
# Start from an empty default graph *before* anything model-related is
# constructed, so nothing created while building the model can be left behind
# in a graph that is about to be discarded.
tf.reset_default_graph()

# GPU ids used for both training and validation inference; defined once so
# the trainer and the inference runner cannot drift apart.
train_gpus = [0, 1]

# Fine-tuning model; per-label weights come from the training-split
# imbalance ratio reported by the data manager.
model = RNN(config, is_finetuning=True,
            label_weights=dm.get_imbalance_ratio().train.values)

train_callbacks = [
    # Schedule entries are (epoch, value) pairs: 1e-4 from the start,
    # 1e-5 from epoch 40.
    ScheduledHyperParamSetter('learning_rate', [(0, 1e-4), (40, 1e-5)]),
    # Run the validation stream through the aggregate metrics on the same GPUs.
    InfRunner(validation_data,
              [AggregateMetric(config.validation_metrics, threshold)],
              train_gpus),
    ModelSaver(max_to_keep=5),
    # Track the 'macro_f1' statistic and save its best checkpoint as save_name.
    MaxSaver('macro_f1', save_name),
]

train_config = TrainConfig(
    model=model,
    dataflow=train_data,
    callbacks=train_callbacks,
    # Initialise from the ResNet checkpoint, skipping the variables listed in
    # ignore_restore.
    session_init=SaverRestore(model_path=RESNET_LOC, ignore=ignore_restore),
    max_epoch=65,
)
launch_train_with_config(train_config, Trainer(gpus=train_gpus))