diff --git a/train.py b/train.py index 5810686..882d4d9 100644 --- a/train.py +++ b/train.py @@ -130,7 +130,16 @@ def build_model(args): encoder = FeatureExtractor(args) decoder = RefineMask(args) if args.base_model == 'resnet101': - encoder_dict, decoder_dict,_,_,_ = load_checkpoint('../../experiments/models/one-shot-model-youtubevos/') + pretrained_path = 'experiments/dmmnet/pretrained_rvos/one-shot-model-youtubevos/') + if not os.path.isdir(pretrained_path): + msg = 'pretrained model from rvos not found in %s; please run '%pretrained_path + msg += '\n wget https://imatge.upc.edu/web/sites/default/files/projects/segmentation/public_html/rvos-pretrained-models/one-shot-model-youtubevos.zip' + msg += '\n zip -rq one-shot-model-youtubevos.zip ' + ppath = 'experiments/dmmnet/pretrained_rvos/' + msg += '\n mkdir -p %s && mv one-shot-model-youtubevos %s/'%(ppath, ppath) + print(msg) + exit() + encoder_dict, decoder_dict,_,_,_ = load_checkpoint(pretrained_path) enc_dict_new = encoder.state_dict() decoder_dict_new = decoder.state_dict() for name, param in enc_dict_new.items(): # named_parameters():