Skip to content

Commit

Permalink
Rename preloaded model.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamvitJ committed Jun 10, 2019
1 parent 4424d1f commit bc4089d
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 10 deletions.
3 changes: 2 additions & 1 deletion dff_deeplab/config/config.py
Expand Up @@ -26,7 +26,8 @@
# network related params
config.network = edict()
config.network.pretrained = ''
config.network.pretrained_flow = ''
config.network.pretrained_base = ''
config.network.pretrained_ec = ''
config.network.pretrained_epoch = 0
config.network.PIXEL_MEANS = np.array([0, 0, 0])
config.network.IMAGE_STRIDE = 0
Expand Down
10 changes: 5 additions & 5 deletions dff_deeplab/train_end2end.py
Expand Up @@ -49,7 +49,7 @@ def parse_args():
from utils.lr_scheduler import WarmupMultiFactorScheduler


def train_net(args, ctx, pretrained, pretrained_flow, pretrained_ec, epoch, prefix, begin_epoch, end_epoch, lr, lr_step):
def train_net(args, ctx, pretrained, pretrained_base, pretrained_ec, epoch, prefix, begin_epoch, end_epoch, lr, lr_step):
logger, final_output_path = create_logger(config.output_path, args.cfg, config.dataset.image_set)
prefix = os.path.join(final_output_path, prefix)

Expand Down Expand Up @@ -95,9 +95,9 @@ def train_net(args, ctx, pretrained, pretrained_flow, pretrained_ec, epoch, pref
else:
print pretrained
arg_params, aux_params = load_param(pretrained, epoch, convert=True)
arg_params_flow, aux_params_flow = load_param(pretrained_flow, epoch, convert=True)
arg_params.update(arg_params_flow)
aux_params.update(aux_params_flow)
arg_params_base, aux_params_base = load_param(pretrained_base, epoch, convert=True)
arg_params.update(arg_params_base)
aux_params.update(aux_params_base)
arg_params_ec, aux_params_ec = load_param(pretrained_ec, epoch, convert=True, argprefix=config.TRAIN.arg_prefix)
arg_params.update(arg_params_ec)
aux_params.update(aux_params_ec)
Expand Down Expand Up @@ -162,7 +162,7 @@ def train_net(args, ctx, pretrained, pretrained_flow, pretrained_ec, epoch, pref
def main():
print('Called with argument:', args)
ctx = [mx.gpu(int(i)) for i in config.gpus.split(',')]
train_net(args, ctx, config.network.pretrained, config.network.pretrained_flow, config.network.pretrained_ec, config.network.pretrained_epoch, config.TRAIN.model_prefix,
train_net(args, ctx, config.network.pretrained, config.network.pretrained_base, config.network.pretrained_ec, config.network.pretrained_epoch, config.TRAIN.model_prefix,
config.TRAIN.begin_epoch, config.TRAIN.end_epoch, config.TRAIN.lr, config.TRAIN.lr_step)

if __name__ == '__main__':
Expand Down
Expand Up @@ -11,7 +11,7 @@ default:
kvstore: device
network:
pretrained: "./model/rfcn_dff_flownet_vid"
pretrained_flow: "./model/pretrained/deeplab-101"
pretrained_base: "./model/pretrained/deeplab-101"
pretrained_ec: "./model/pretrained/deeplab-101"
pretrained_epoch: 0
PIXEL_MEANS:
Expand Down
Expand Up @@ -11,7 +11,7 @@ default:
kvstore: device
network:
pretrained: "./model/rfcn_dff_flownet_vid"
pretrained_flow: "./model/pretrained/deeplab-101"
pretrained_base: "./model/pretrained/deeplab-101"
pretrained_ec: "./model/pretrained/deeplab-18"
pretrained_epoch: 0
PIXEL_MEANS:
Expand Down
Expand Up @@ -11,7 +11,7 @@ default:
kvstore: device
network:
pretrained: "./model/rfcn_dff_flownet_vid"
pretrained_flow: "./model/pretrained/deeplab-101"
pretrained_base: "./model/pretrained/deeplab-101"
pretrained_ec: "./model/pretrained/deeplab-34"
pretrained_epoch: 0
PIXEL_MEANS:
Expand Down
Expand Up @@ -11,7 +11,7 @@ default:
kvstore: device
network:
pretrained: "./model/rfcn_dff_flownet_vid"
pretrained_flow: "./model/pretrained/deeplab-101"
pretrained_base: "./model/pretrained/deeplab-101"
pretrained_ec: "./model/pretrained/deeplab-50"
pretrained_epoch: 0
PIXEL_MEANS:
Expand Down

0 comments on commit bc4089d

Please sign in to comment.