Skip to content

Commit

Permalink
Merge pull request #230 from CosmiQ/iss228_check_model_dl
Browse files Browse the repository at this point in the history
ISS228: removing gdal from setup.py; fix model download check
  • Loading branch information
nrweir committed Aug 11, 2019
2 parents bf7f100 + 8b6616b commit f09bf7e
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 164 deletions.
161 changes: 0 additions & 161 deletions docs/tutorials/notebooks/api_inference_spacenet.ipynb

This file was deleted.

1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def check_output(cmd):
'opencv-python==4.1.0.25',
'numpy>=1.16.4',
'tqdm>=4.32.2',
'GDAL>=2.4.0',
'rtree>=0.8.3',
'networkx>=2.3',
'rasterio>=1.0.18',
Expand Down
25 changes: 23 additions & 2 deletions solaris/nets/model_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,33 @@ def _load_model_weights(model, path, framework):
try:
model.load_weights(path)
except OSError:
raise FileNotFoundError("{} doesn't exist.".format(path))
# first, check to see if the weights are in the default sol dir
default_path = os.path.join(weights_dir, os.path.split(path)[1])
try:
model.load_weights(default_path)
except OSError:
# if they can't be found anywhere, raise the error.
raise FileNotFoundError("{} doesn't exist.".format(path))

elif framework.lower() in ['torch', 'pytorch']:
# pytorch already throws the right error on failed load, so no need
# to fix exception
loaded = torch.load(path)
if torch.cuda.is_available():
try:
loaded = torch.load(path)
except FileNotFoundError:
# first, check to see if the weights are in the default sol dir
default_path = os.path.join(weights_dir,
os.path.split(path)[1])
loaded = torch.load(path)
else:
try:
loaded = torch.load(path, map_location='cpu')
except FileNotFoundError:
default_path = os.path.join(weights_dir,
os.path.split(path)[1])
loaded = torch.load(path, map_location='cpu')

if isinstance(loaded, torch.nn.Module): # if it's a full model already
model.load_state_dict(loaded.state_dict())
else:
Expand Down

0 comments on commit f09bf7e

Please sign in to comment.