Skip to content

Commit

Permalink
switching to main rwightman repo with hundreds of models
Browse files Browse the repository at this point in the history
  • Loading branch information
achaiah committed Jun 8, 2020
1 parent 6a0bfd5 commit 5d26a94
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions pywick/models/model_utils.py
Expand Up @@ -9,6 +9,8 @@
import os
import errno

rwightman_repo = 'rwightman/pytorch-image-models'


class ModelType(Enum):
"""
Expand Down Expand Up @@ -77,9 +79,9 @@ def get_model(model_type, model_name, num_classes, pretrained=True, **kwargs):
print("INFO: Loading Model: -- " + model_name + " with number of classes: " + str(num_classes))

if model_type == ModelType.CLASSIFICATION:
torch_hub_names = torch.hub.list('rwightman/gen-efficientnet-pytorch')
torch_hub_names = torch.hub.list(rwightman_repo)
if model_name in torch_hub_names:
model = torch.hub.load('rwightman/gen-efficientnet-pytorch', model_name, pretrained=pretrained, num_classes=num_classes)
model = torch.hub.load(rwightman_repo, model_name, pretrained=pretrained, num_classes=num_classes)
else:
# 1. Load model (pretrained or vanilla)
import ssl
Expand Down Expand Up @@ -239,7 +241,7 @@ def get_supported_models(type):
pt_excludes.append(modname.split('.')[-1])
pt_names = [x for x in torch_models.__dict__.keys() if '__' not in x and x not in pt_excludes] # includes directory and filenames

torch_hub_names = torch.hub.list('rwightman/gen-efficientnet-pytorch')
torch_hub_names = torch.hub.list(rwightman_repo)

return pywick_names + pt_names + torch_hub_names
else:
Expand Down

0 comments on commit 5d26a94

Please sign in to comment.