Skip to content

Commit

Permalink
fixed the netdissect pytorch issues of image directory and input model
Browse files Browse the repository at this point in the history
  • Loading branch information
Bolei Zhou committed Sep 29, 2017
1 parent 6c6fc51 commit 735ad4f
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 13 deletions.
1 change: 0 additions & 1 deletion script/rundissect_pytorch_external.sh
Expand Up @@ -4,7 +4,6 @@ DIR=pytorch_alexnet_imagenet
ARCH='alexnet' # [alexnet,squeezenet1_1,resnet18,...]. It should work for all the models in https://github.com/pytorch/vision/tree/master/torchvision/models
LAYERS="features"
DATASET=dataset/broden1_224
WEIGHTS="none"
NUMCLASSES=1000

# default setting
Expand Down
15 changes: 5 additions & 10 deletions src/netprobe_pytorch.py
Expand Up @@ -16,9 +16,6 @@
import sys

os.environ['GLOG_minloglevel'] = '2'
import caffe
from caffe.proto import caffe_pb2
from google.protobuf import text_format
from scipy.misc import imresize, imread
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.interpolation import zoom
Expand All @@ -29,8 +26,6 @@
import rotate
import expdir

caffe.set_mode_gpu()
caffe.set_device(0)

def create_probe(
directory, dataset, definition, weights, mean, blobs,
Expand All @@ -47,8 +42,8 @@ def create_probe(
'''
directory: where to place the probe_conv5.mmap files.
data: the AbstractSegmentation data source to draw upon
definition: the filename for the caffe prototxt
weights: the filename for the caffe model weights
definition: the filename for the pytorch
weights: the filename for the weights
mean: to use to normalize rgb values for the network
blobs: ['conv3', 'conv4', 'conv5'] to probe
'''
Expand All @@ -57,7 +52,7 @@ def create_probe(
data = loadseg.SegmentationData(args.dataset)

# the network to dissect
if args.weights == None:
if args.weights == 'none':
# load the imagenet pretrained model
net = torchvision.models.__dict__[args.definition](pretrained=True)
else:
Expand Down Expand Up @@ -158,7 +153,7 @@ def hook_feature(module, input, output):
# previous feedforward case
inp = inp[:,::-1,:,:]
inp_tensor = V(torch.from_numpy(inp.copy()))
inp_tensor.div_(255.0*0.224) # approximately normalize the input to make the images scaled at around 1.
inp_tensor.div_(255.0*0.224) # hack: approximately normalize the input to make the images scaled at around 1.
inp_tensor = inp_tensor.cuda()
result = net.forward(inp_tensor)
# output the hooked feature
Expand Down Expand Up @@ -227,7 +222,7 @@ def report(txt):
import loadseg

parser = argparse.ArgumentParser(description=
'Probe a caffe network and save results in a directory.')
'Probe a pytorch network and save results in a directory.')
parser.add_argument(
'--directory',
default='.',
Expand Down
4 changes: 2 additions & 2 deletions src/viewprobe.py
Expand Up @@ -109,13 +109,13 @@ def generate_html_summary(self, layer,
col*(imsize+1):col*(imsize+1)+imsize,:] = vis
imfn = 'image/%s%s-%04d.jpg' % (
expdir.fn_safe(layer), gridname, unit)
imsave(self.ed.filename(['html', imfn]), tiled)
imsave(self.ed.filename(os.path.join('html', imfn)), tiled)
labels = '; '.join(['%s (%s, %f)' %
(name_pciou[c][unit], categories[c], score_pciou[c, unit])
for c in bestcat_pciou[:,unit]])
html.extend([
'<h6>%s unit %d: %s</h6>' % (layer, unit + 1, labels),
'<img src="%s" height="%d">' % (imfn, imscale)
'<img src="%s" height="%d">' % (os.path.join('html', imfn), imscale)
])
html.extend([
'</div>', '</body>', '</html>', ''])
Expand Down

0 comments on commit 735ad4f

Please sign in to comment.