Skip to content

Commit

Permalink
Merge pull request #4 from BorisTestov/cpu-key
Browse files Browse the repository at this point in the history
Cpu key
  • Loading branch information
Tramac committed May 13, 2019
2 parents bac15d8 + ae41051 commit a197f5e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
8 changes: 5 additions & 3 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
help='path to the input picture')
parser.add_argument('--outdir', default='./test_result', type=str,
help='path to save the predict result')

parser.add_argument('--cpu', dest='cpu', action='store_true')
parser.set_defaults(cpu=False)

args = parser.parse_args()


Expand All @@ -36,10 +40,8 @@ def demo():
])
image = Image.open(args.input_pic).convert('RGB')
image = transform(image).unsqueeze(0).to(device)

model = get_fast_scnn(args.dataset, pretrained=True, root=args.weights_folder).to(device)
model = get_fast_scnn(args.dataset, pretrained=True, root=args.weights_folder, map_cpu=args.cpu).to(device)
print('Finished loading model!')

model.eval()
with torch.no_grad():
outputs = model(image)
Expand Down
7 changes: 5 additions & 2 deletions models/fast_scnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def forward(self, x):
return x


def get_fast_scnn(dataset='citys', pretrained=False, root='./weights', **kwargs):
def get_fast_scnn(dataset='citys', pretrained=False, root='./weights', map_cpu=False, **kwargs):
acronyms = {
'pascal_voc': 'voc',
'pascal_aug': 'voc',
Expand All @@ -244,7 +244,10 @@ def get_fast_scnn(dataset='citys', pretrained=False, root='./weights', **kwargs)
from data_loader import datasets
model = FastSCNN(datasets[dataset].NUM_CLASS, **kwargs)
if pretrained:
model.load_state_dict(torch.load(os.path.join(root, 'fast_scnn_%s_best_model.pth' % acronyms[dataset])))
if(map_cpu):
model.load_state_dict(torch.load(os.path.join(root, 'fast_scnn_%s.pth' % acronyms[dataset]), map_location='cpu'))
else:
model.load_state_dict(torch.load(os.path.join(root, 'fast_scnn_%s.pth' % acronyms[dataset])))
return model


Expand Down

0 comments on commit a197f5e

Please sign in to comment.