From 3a4a56e15dc30e5f7d9c4c7706f5710d1da73e52 Mon Sep 17 00:00:00 2001
From: Bolei Zhou
Date: Wed, 2 May 2018 15:49:01 -0400
Subject: [PATCH] make places demo compatible with pytorch0.4

---
 run_placesCNN_basic.py   | 24 +++++++-----------------
 run_placesCNN_unified.py | 22 ++++++++++++++--------
 wideresnet.py            |  6 ++++--
 3 files changed, 25 insertions(+), 27 deletions(-)

diff --git a/run_placesCNN_basic.py b/run_placesCNN_basic.py
index bf2ffa2..3abaac1 100644
--- a/run_placesCNN_basic.py
+++ b/run_placesCNN_basic.py
@@ -15,28 +15,18 @@
 arch = 'resnet18'
 
 # load the pre-trained weights
-model_file = 'whole_%s_places365_python36.pth.tar' % arch
+model_file = '%s_places365.pth.tar' % arch
 if not os.access(model_file, os.W_OK):
     weight_url = 'http://places2.csail.mit.edu/models_places365/' + model_file
     os.system('wget ' + weight_url)
 
-useGPU = 1
-if useGPU == 1:
-    model = torch.load(model_file)
-else:
-    model = torch.load(model_file, map_location=lambda storage, loc: storage) # model trained in GPU could be deployed in CPU machine like this!
-
-## assume all the script in python36, so the following is not necessary
-## if you encounter the UnicodeDecodeError when use python3 to load the model, add the following line will fix it. Thanks to @soravux
-#from functools import partial
-#import pickle
-#pickle.load = partial(pickle.load, encoding="latin1")
-#pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
-#model = torch.load(model_file, map_location=lambda storage, loc: storage, pickle_module=pickle)
-#torch.save(model, 'whole_%s_places365_python36.pth.tar'%arch)
-
+model = models.__dict__[arch](num_classes=365)
+checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
+state_dict = {str.replace(k,'module.',''): v for k,v in checkpoint['state_dict'].items()}
+model.load_state_dict(state_dict)
 model.eval()
 
+
 # load the image transformer
 centre_crop = trn.Compose([
         trn.Resize((256,256)),
@@ -63,7 +53,7 @@
     os.system('wget ' + img_url)
 
 img = Image.open(img_name)
-input_img = V(centre_crop(img).unsqueeze(0), volatile=True)
+input_img = V(centre_crop(img).unsqueeze(0))
 
 # forward pass
 logit = model.forward(input_img)
diff --git a/run_placesCNN_unified.py b/run_placesCNN_unified.py
index fd87b70..3c21fc2 100644
--- a/run_placesCNN_unified.py
+++ b/run_placesCNN_unified.py
@@ -86,15 +86,19 @@ def returnTF():
 
 def load_model():
     # this model has a last conv feature map as 14x14
-    model_file = 'whole_wideresnet18_places365_python36.pth.tar'
+
+    model_file = 'wideresnet18_places365.pth.tar'
     if not os.access(model_file, os.W_OK):
         os.system('wget http://places2.csail.mit.edu/models_places365/' + model_file)
         os.system('wget https://raw.githubusercontent.com/csailvision/places365/master/wideresnet.py')
-    useGPU = 0
-    if useGPU == 1:
-        model = torch.load(model_file)
-    else:
-        model = torch.load(model_file, map_location=lambda storage, loc: storage) # allow cpu
+
+    import wideresnet
+    model = wideresnet.resnet18(num_classes=365)
+    checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
+    state_dict = {str.replace(k,'module.',''): v for k,v in checkpoint['state_dict'].items()}
+    model.load_state_dict(state_dict)
+    model.eval()
+
+    # the following is deprecated, everything is migrated to python36
 
@@ -132,17 +136,19 @@ def load_model():
 img_url = 'http://places2.csail.mit.edu/imgs/12.jpg'
 os.system('wget %s -q -O test.jpg' % img_url)
 img = Image.open('test.jpg')
-input_img = V(tf(img).unsqueeze(0), volatile=True)
+input_img = V(tf(img).unsqueeze(0))
 
 # forward pass
 logit = model.forward(input_img)
 h_x = F.softmax(logit, 1).data.squeeze()
 probs, idx = h_x.sort(0, True)
+probs = probs.numpy()
+idx = idx.numpy()
 
 print('RESULT ON ' + img_url)
 
 # output the IO prediction
-io_image = np.mean(labels_IO[idx[:10].numpy()]) # vote for the indoor or outdoor
+io_image = np.mean(labels_IO[idx[:10]]) # vote for the indoor or outdoor
 if io_image < 0.5:
     print('--TYPE OF ENVIRONMENT: indoor')
 else:
diff --git a/wideresnet.py b/wideresnet.py
index 0c50e30..91b47c0 100644
--- a/wideresnet.py
+++ b/wideresnet.py
@@ -115,8 +115,10 @@ def __init__(self, block, layers, num_classes=1000):
                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                 m.weight.data.normal_(0, math.sqrt(2. / n))
             elif isinstance(m, nn.BatchNorm2d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
+                #m.weight.data.fill_(1)
+                #m.bias.data.zero_()
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
 
     def _make_layer(self, block, planes, blocks, stride=1):
         downsample = None
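
Note for reviewers: the sketch below is a minimal, standalone version of the loading
pattern this patch adopts in run_placesCNN_basic.py, handy for testing the new
state_dict checkpoints outside the demo scripts. It assumes resnet18_places365.pth.tar
has already been fetched from the URL in the diff and that test.jpg is any RGB image
(both placeholders); the transform mirrors the demo's centre_crop. One further
PyTorch 0.4 note: volatile=True is gone, and wrapping inference in torch.no_grad()
(rather than simply dropping the flag, as the patch does) restores the memory savings.

# Minimal sketch of the PyTorch 0.4 loading pattern used by this patch
# (state_dict checkpoint instead of a pickled whole model). File names
# follow run_placesCNN_basic.py; the image path is a placeholder.
import torch
import torchvision.models as models
from torchvision import transforms as trn
from PIL import Image

arch = 'resnet18'
model_file = '%s_places365.pth.tar' % arch

# Build the architecture first, then load weights only. The checkpoint was
# saved from a DataParallel model, so each key carries a 'module.' prefix
# that load_state_dict() will not accept until it is stripped.
model = models.__dict__[arch](num_classes=365)
checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
model.load_state_dict(state_dict)
model.eval()

centre_crop = trn.Compose([
    trn.Resize((256, 256)),
    trn.CenterCrop(224),
    trn.ToTensor(),
    trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

input_img = centre_crop(Image.open('test.jpg')).unsqueeze(0)

# torch.no_grad() replaces the removed volatile=True: autograd records
# nothing inside this block, so inference stays memory-lean.
with torch.no_grad():
    logit = model(input_img)
    h_x = torch.nn.functional.softmax(logit, 1).squeeze()
    probs, idx = h_x.sort(0, True)
print(probs[:5].numpy(), idx[:5].numpy())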