make places demo compatible with pytorch0.4
zhoubolei committed May 2, 2018
1 parent 7c728f5 commit 3a4a56e
Showing 3 changed files with 25 additions and 27 deletions.
run_placesCNN_basic.py (24 changes: 7 additions & 17 deletions)
@@ -15,28 +15,18 @@
 arch = 'resnet18'

 # load the pre-trained weights
-model_file = 'whole_%s_places365_python36.pth.tar' % arch
+model_file = '%s_places365.pth.tar' % arch
 if not os.access(model_file, os.W_OK):
     weight_url = 'http://places2.csail.mit.edu/models_places365/' + model_file
     os.system('wget ' + weight_url)

-useGPU = 1
-if useGPU == 1:
-    model = torch.load(model_file)
-else:
-    model = torch.load(model_file, map_location=lambda storage, loc: storage) # model trained in GPU could be deployed in CPU machine like this!
-
-## assume all the script in python36, so the following is not necessary
-## if you encounter the UnicodeDecodeError when use python3 to load the model, add the following line will fix it. Thanks to @soravux
-#from functools import partial
-#import pickle
-#pickle.load = partial(pickle.load, encoding="latin1")
-#pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
-#model = torch.load(model_file, map_location=lambda storage, loc: storage, pickle_module=pickle)
-#torch.save(model, 'whole_%s_places365_python36.pth.tar'%arch)
-
+model = models.__dict__[arch](num_classes=365)
+checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
+state_dict = {str.replace(k,'module.',''): v for k,v in checkpoint['state_dict'].items()}
+model.load_state_dict(state_dict)
+model.eval()

 # load the image transformer
 centre_crop = trn.Compose([
     trn.Resize((256,256)),
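The checkpoints hosted on places2.csail.mit.edu were evidently saved from models wrapped in torch.nn.DataParallel, which is why every key in their state_dict carries a 'module.' prefix; the new loading code strips that prefix before calling load_state_dict on a plain model. A minimal standalone sketch of the same pattern (the file name and 365-class count come from the diff above; everything else is illustrative):

```python
import torch
import torchvision.models as models

# rebuild the architecture, then load only the weights; unlike pickling the
# whole model object, a state_dict survives PyTorch version changes
model = models.resnet18(num_classes=365)
checkpoint = torch.load('resnet18_places365.pth.tar',
                        map_location=lambda storage, loc: storage)  # stay on CPU

# checkpoints saved from nn.DataParallel prefix every key with 'module.';
# strip it so the keys match the unwrapped model
state_dict = {k.replace('module.', ''): v
              for k, v in checkpoint['state_dict'].items()}
model.load_state_dict(state_dict)
model.eval()  # inference mode: fixes batch-norm statistics and dropout
```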
@@ -63,7 +53,7 @@
     os.system('wget ' + img_url)

 img = Image.open(img_name)
-input_img = V(centre_crop(img).unsqueeze(0), volatile=True)
+input_img = V(centre_crop(img).unsqueeze(0))

 # forward pass
 logit = model.forward(input_img)
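Dropping volatile=True is the core PyTorch 0.4 change here: Variable and Tensor were merged in 0.4, and the volatile flag became a no-op that only emits a warning. The demo keeps the V(...) wrapper (now a harmless alias), but the idiomatic 0.4 way to disable autograd for inference is the torch.no_grad() context. A sketch, assuming the model, image, and centre_crop transform from the diff above:

```python
import torch
import torch.nn.functional as F

input_img = centre_crop(img).unsqueeze(0)  # add a batch dimension

# torch.no_grad() replaces volatile=True: no graph is built, saving memory
with torch.no_grad():
    logit = model(input_img)

h_x = F.softmax(logit, 1).squeeze()
probs, idx = h_x.sort(0, True)  # class probabilities, best first
```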
run_placesCNN_unified.py (22 changes: 14 additions & 8 deletions)
@@ -86,15 +86,19 @@ def returnTF():
 def load_model():
     # this model has a last conv feature map as 14x14

-    model_file = 'whole_wideresnet18_places365_python36.pth.tar'
+    model_file = 'wideresnet18_places365.pth.tar'
     if not os.access(model_file, os.W_OK):
         os.system('wget http://places2.csail.mit.edu/models_places365/' + model_file)
+        os.system('wget https://raw.githubusercontent.com/csailvision/places365/master/wideresnet.py')
-    useGPU = 0
-    if useGPU == 1:
-        model = torch.load(model_file)
-    else:
-        model = torch.load(model_file, map_location=lambda storage, loc: storage) # allow cpu

+    import wideresnet
+    model = wideresnet.resnet18(num_classes=365)
+    checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
+    state_dict = {str.replace(k,'module.',''): v for k,v in checkpoint['state_dict'].items()}
+    model.load_state_dict(state_dict)
+    model.eval()

     # the following is deprecated, everything is migrated to python36
@@ -132,17 +136,19 @@ def load_model():
 img_url = 'http://places2.csail.mit.edu/imgs/12.jpg'
 os.system('wget %s -q -O test.jpg' % img_url)
 img = Image.open('test.jpg')
-input_img = V(tf(img).unsqueeze(0), volatile=True)
+input_img = V(tf(img).unsqueeze(0))

 # forward pass
 logit = model.forward(input_img)
 h_x = F.softmax(logit, 1).data.squeeze()
 probs, idx = h_x.sort(0, True)
+probs = probs.numpy()
+idx = idx.numpy()

 print('RESULT ON ' + img_url)

 # output the IO prediction
-io_image = np.mean(labels_IO[idx[:10].numpy()]) # vote for the indoor or outdoor
+io_image = np.mean(labels_IO[idx[:10]]) # vote for the indoor or outdoor
 if io_image < 0.5:
     print('--TYPE OF ENVIRONMENT: indoor')
 else:
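The unified demo now converts probs and idx to NumPy arrays once, right after the sort, so downstream indexing such as labels_IO[idx[:10]] no longer needs the per-use .numpy() call that the 0.3-era code carried. A sketch of the indoor/outdoor vote with hypothetical stand-in data (in the real script labels_IO is read from IO_places365.txt, with 0 meaning indoor and 1 outdoor):

```python
import numpy as np

# hypothetical stand-ins: one indoor(0)/outdoor(1) flag per scene category,
# and a sorted index array as produced by idx.numpy() in the diff above
labels_IO = np.random.randint(0, 2, size=365)
idx = np.random.permutation(365)

# average the flags of the top-10 predicted categories: a majority vote
io_image = np.mean(labels_IO[idx[:10]])
print('--TYPE OF ENVIRONMENT: ' + ('indoor' if io_image < 0.5 else 'outdoor'))
```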
wideresnet.py (6 changes: 4 additions & 2 deletions)
@@ -115,8 +115,10 @@ def __init__(self, block, layers, num_classes=1000):
                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                 m.weight.data.normal_(0, math.sqrt(2. / n))
             elif isinstance(m, nn.BatchNorm2d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
+                #m.weight.data.fill_(1)
+                #m.bias.data.zero_()
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)

     def _make_layer(self, block, planes, blocks, stride=1):
         downsample = None
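The wideresnet.py tweak swaps direct .data manipulation for the nn.init helpers: PyTorch 0.4 standardized on the trailing-underscore in-place initializers (nn.init.constant_ and friends) and deprecated the older spellings. A self-contained sketch of the same initialization loop on a toy module:

```python
import math
import torch.nn as nn

net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True),
)

for m in net.modules():
    if isinstance(m, nn.Conv2d):
        # He-style initialization, as in wideresnet.py
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm starts as the identity: unit scale, zero shift
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
```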
