In [2]:
-- Configuration
-- Reproducible RNG, float tensors by default, and an extra search-path entry so
-- that the helper modules stored under models/ can be required by name.
torch.manualSeed(1)
torch.setdefaulttensortype("torch.FloatTensor")
package.path = package.path .. ";models/?.lua"
-- Load library packages first, then the local helper modules (order preserved).
for _, module_name in ipairs({"nn", "image", "visualize_features_utils", "SpatialReplicatePadding"}) do
    require(module_name)
end

In [3]:
-- Initialize OverFeat model
-- Builds the pretrained "small" OverFeat net, strips its non-convolutional
-- tail and rewrites the remaining feature extractor so that every convolution
-- runs with stride 1 behind an explicit replicate-padding layer. The result is
-- stored back into the global conv_net used by the later cells.
overfeat_factory = require "OverFeatFactory"
overfeat_params = {type = "small", pretrained = true, weights_path = "models/net_weight_0"}
conv_net, softmax_net = overfeat_factory.create(overfeat_params)
-- Remove non-convolutional layers (remove() with no index drops the last module)
for _ = 1, 5 do conv_net:remove() end
-- "Fix" convolutional layers
new_conv_net = nn.Sequential()
-- ipairs (not pairs): modules must be re-added in their original order, and
-- pairs() makes no ordering guarantee.
for _, m in ipairs(conv_net.modules) do
    local module_type = torch.typename(m)
    if module_type == "nn.SpatialConvolutionMM" then
        -- Force stride 1 in both directions
        m.dH = 1
        m.dW = 1
        -- Disable the module's own padding; it is replaced by the layer added below
        m.padH = 0
        m.padW = 0
        -- Insert a replicate-padding layer in front of the convolution.
        -- NOTE(review): assumes models/SpatialReplicatePadding takes (padW, padH)
        -- and that kernels are odd-sized, so (k-1)/2 is an integer -- confirm.
        new_conv_net:add(nn.SpatialReplicatePadding((m.kW-1)/2, (m.kH-1)/2))
    elseif module_type == "nn.SpatialMaxPooling" then
        -- Make the module know some unpoolers will be attached
        -- (it won't set iwidth and iheight during forward() otherwise).
        -- The unpooler object itself is discarded; only the constructor's
        -- effect on m is needed here.
        nn.SpatialMaxUnpooling(m)
    end
    -- Keep the (possibly modified) module itself
    new_conv_net:add(m)
end
conv_net = new_conv_net

In [None]:
-- Test image (bee)
-- Load bee.jpg in the [0,255] range, scale it with the "^dim" size spec, take a
-- dim x dim center crop and round down to integer pixel values. An unnormalized
-- copy is kept in orig_img for display; img is then mean/std normalized in place
-- and fed through the network.
dim = 231
local pixel_mean, pixel_std = 118.380948, 61.896913
img = image.load("bee.jpg"):mul(255)
img = image.crop(image.scale(img, "^" .. dim), "c", dim, dim):floor()
orig_img = img:clone()
img:add(-pixel_mean):div(pixel_std)
itorch.image(orig_img)
-- Forward input image (kept as the cell's last statement)
conv_net:forward(img)

In [None]:
-- Show maps
-- For each observed layer: select active neurons, reconstruct each one back to
-- input space with a reconstruction (deconvolution) net, then display
--   (1) the original image with each neuron's crop outlined in red, and
--   (2) every reconstruction next to the matching crop of the original image.
-- Side length (pixels) of every displayed crop
resize = 300
-- Interesting layers (presumably module indices into conv_net -- see buildReconstructionNet)
observed_layers = {4, 8, 11, 14, 18} --, 6, 8, 10, 13} -- 5
-- Number of neurons to show, one entry per observed layer
limits = {6, 6, 6, 6, 6}
-- Configure map overlap
allow_map_overlap = true
overlap_margin = 5

-- Lua 5.2+ moved the global unpack into the table library; LuaJIT/5.1 keep it global.
local unpack = unpack or table.unpack

-- Reconstruct one neuron: run the reconstruction net, crop the region where the
-- reconstruction is active, and scale both that region and the same region of
-- orig_img to resize x resize. Returns rec_output, orig_output, crop.
local function reconstructNeuron(deconv_net, rec_input)
    local rec_output = deconv_net:forward(rec_input)
    -- Compute activity mask and crop points
    local activity_mask = localizeActivity(rec_output, {activity_thr = 0.0001})
    local crop = getMaskCoords(activity_mask)
    -- Crop reconstruction and original with the same box, then scale both
    rec_output = image.scale(image.crop(rec_output, unpack(crop)), resize, resize)
    local orig_output = image.scale(image.crop(orig_img, unpack(crop)), resize, resize)
    return rec_output, orig_output, crop
end

-- Process all layers; ipairs keeps the layer order (pairs does not guarantee it)
for target_layer_idx, target_layer in ipairs(observed_layers) do
    print("Layer " .. target_layer_idx .. " (" .. target_layer .. ")")
    -- Build reconstruction network
    local deconv_net = buildReconstructionNet(conv_net, target_layer)
    -- Select layer and map (nil: no specific map -- presumably "all maps"; confirm in utils)
    local target_map = nil
    -- Get active neurons
    local active_values, active_coords = getActiveNeurons(conv_net, target_layer, target_map)
    -- Filter neurons down to the number to show for this layer.
    -- This may take a while if overlap is not allowed and there are not enough
    -- non-overlapping neurons available.
    local n = limits[target_layer_idx]
    active_values, active_coords = filterNeurons(active_values, active_coords, n, overlap_margin, allow_map_overlap)
    -- Compute reconstruction inputs
    local rec_inputs = getReconstructionInput(active_values, active_coords, conv_net, target_layer, target_map)
    -- Compute reconstruction outputs
    local rec_outputs, orig_outputs, crops = {}, {}, {}
    for i = 1, #rec_inputs do
        rec_outputs[i], orig_outputs[i], crops[i] = reconstructNeuron(deconv_net, rec_inputs[i])
        collectgarbage()
    end
    -- Highlight active neurons in original images
    local draw_crops = orig_img:clone()
    for _, crop in ipairs(crops) do
        -- Draw rectangle from (crop[1], crop[2]) to (crop[3], crop[4])
        drawRectangle(draw_crops, {crop[1], crop[2]}, {crop[3], crop[4]}, torch.Tensor({255,0,0}))
    end
    itorch.image(draw_crops, {min = 0, max = 255})
    -- Show reconstruction besides original patches
    if #rec_outputs > 0 then
        local rec_outputs_img = image.toDisplayTensor({input = rec_outputs, padding = 2})
        local orig_outputs_img = image.toDisplayTensor({input = orig_outputs, padding = 2, min = 0, max = 255})
        -- Concatenate along dimension 3 (width for CxHxW images): reconstructions left, originals right
        itorch.image(torch.cat(rec_outputs_img, orig_outputs_img, 3))
    else
        -- toDisplayTensor cannot build a grid from an empty list
        print("No neurons to reconstruct for this layer")
    end
    -- Fixes some problems on my machine which caused a bad ordering between prints and itorch.image() calls
    -- Uncomment if you have the sys package (and add a "require")
    --sys.sleep(3)
end