diff --git a/scripts/train/train_101.sh b/scripts/train/train_101.sh index ed435be..efd789f 100644 --- a/scripts/train/train_101.sh +++ b/scripts/train/train_101.sh @@ -1,3 +1,14 @@ +pretrain_path=experiments/dmmnet/pretrained_rvos/ +# if [ ! -d $dir ] +if [ ! -e $pretrain_path ]; then + echo "start downloading pretrained model to ${pretrain_path}" + if [ ! -e one-shot-model-youtubevos.zip ]; then + wget https://imatge.upc.edu/web/sites/default/files/projects/segmentation/public_html/rvos-pretrained-models/one-shot-model-youtubevos.zip + fi + # unzip -rq one-shot-model-youtubevos.zip # && rm one-shot-model-youtubevos.zip + mkdir -p ${pretrain_path} && mv one-shot-model-youtubevos ${pretrain_path} +fi + BS=4 NGPU=4 train_h=255 diff --git a/train.py b/train.py index 882d4d9..45b7427 100644 --- a/train.py +++ b/train.py @@ -130,7 +130,7 @@ def build_model(args): encoder = FeatureExtractor(args) decoder = RefineMask(args) if args.base_model == 'resnet101': - pretrained_path = 'experiments/dmmnet/pretrained_rvos/one-shot-model-youtubevos/') + pretrained_path = 'experiments/dmmnet/pretrained_rvos/one-shot-model-youtubevos/' if not os.path.isdir(pretrained_path): msg = 'pretrained model from rvos not found in %s; please run '%pretrained_path msg += '\n wget https://imatge.upc.edu/web/sites/default/files/projects/segmentation/public_html/rvos-pretrained-models/one-shot-model-youtubevos.zip' @@ -138,7 +138,8 @@ def build_model(args): ppath = 'experiments/dmmnet/pretrained_rvos/' msg += '\n mkdir -p %s && mv one-shot-model-youtubevos %s/'%(ppath, ppath) print(msg) - exit() + raise FileNotFoundError(pretrained_path) + encoder_dict, decoder_dict,_,_,_ = load_checkpoint(pretrained_path) enc_dict_new = encoder.state_dict() decoder_dict_new = decoder.state_dict()