In [2]:
require 'nn'
require 'nngraph'
require 'image'

In [125]:
-- Layer aliases shared by every model builder below. They are deliberately
-- global so that later notebook cells see them. To switch to the cuDNN-backed
-- implementations, point the first three at cudnn.SpatialConvolution,
-- cudnn.SpatialAveragePooling and cudnn.ReLU respectively.
Convolution, Avg, ReLU = nn.SpatialConvolution, nn.SpatialAveragePooling, nn.ReLU

-- No cuDNN swap is provided for these two.
Max, SBatchNorm = nn.SpatialMaxPooling, nn.SpatialBatchNormalization


--- Low-level feature extractor used by both the local and global branches.
-- Six 3x3 convolutions (padding 1), each followed by an in-place ReLU.
-- The 1st, 3rd and 5th convolutions use stride 2, so each spatial side
-- shrinks by 8x overall (256x256 -> 32x32, 224x224 -> 28x28).
-- @tparam[opt=1] number nInputPlane number of input channels
-- @return nn.Sequential producing 512 feature planes
function vgg(nInputPlane)
    nInputPlane = nInputPlane or 1

    -- {nOutputPlane, stride} for each convolution, in order
    local layers = {
        { 64, 2}, {128, 1}, {128, 2},
        {256, 1}, {256, 2}, {512, 1},
    }

    -- local: the original assigned a global `model` that every builder clobbered
    local model = nn.Sequential()
    local nIn = nInputPlane
    for _, spec in ipairs(layers) do
        local nOut, stride = spec[1], spec[2]
        model:add(Convolution(nIn, nOut, 3,3, stride,stride, 1,1))
        model:add(ReLU(true))
        nIn = nOut
    end
    return model
end

--- Global-feature branch: maps vgg() features of a 224x224 image to a
-- 512-d descriptor.
-- Two stride-2 convolutions reduce the 28x28x512 vgg() output to 7x7x512.
-- That map is flattened (512 * 7 * 7 = 25088 values per sample) and passed
-- through two fully connected layers.
-- @return nn.Sequential mapping (N, 512, 28, 28) -> (N, 512)
function global_feature()
    -- flattened size of the 7x7x512 map; only valid for 224x224 inputs
    local nFlat = 512 * 7 * 7

    local model = nn.Sequential()
    model:add(Convolution(512, 512, 3,3, 2,2, 1,1))
    model:add(ReLU(true))
    model:add(Convolution(512, 512, 3,3, 1,1, 1,1))
    model:add(ReLU(true))
    model:add(Convolution(512, 512, 3,3, 2,2, 1,1))
    model:add(ReLU(true))
    model:add(Convolution(512, 512, 3,3, 1,1, 1,1))
    model:add(ReLU(true))
    model:add(nn.View(-1, nFlat))
    model:add(nn.Linear(nFlat, 1024))
    model:add(ReLU(true))
    model:add(nn.Linear(1024, 512))
    model:add(ReLU(true))
    return model
end

--- Classification head applied to the 512-d global descriptor.
-- Linear -> ReLU -> Dropout(0.5) -> Linear -> LogSoftMax.
-- @tparam number nLabels number of output classes
-- @return nn.Sequential producing log-probabilities of shape (N, nLabels)
function classification(nLabels)
    local model = nn.Sequential()
    model:add(nn.Linear(512, 512))
    model:add(ReLU(true))
    model:add(nn.Dropout(0.5))

    model:add(nn.Linear(512, nLabels))
    model:add(nn.LogSoftMax())
    return model
end


--- Mid-level feature branch applied to vgg() features of the full image.
-- Two stride-1 3x3 convolutions (padding 1) keep the spatial size and
-- reduce 512 planes to 256.
-- @return nn.Sequential mapping (N, 512, H, W) -> (N, 256, H, W)
function mid_level_feature()
    local model = nn.Sequential()
    model:add(Convolution(512, 512, 3, 3, 1, 1, 1, 1))
    model:add(ReLU(true))
    model:add(Convolution(512, 256, 3, 3, 1, 1, 1, 1))
    model:add(ReLU(true))
    return model
end

--- Fusion layer: projects the 512-d global descriptor to 256-d and tiles it
-- over a spatial grid so it can be concatenated with mid-level features.
-- For a (1, 512) input and nfeatures = 32, the output is (1, 256, 32, 32).
-- @tparam[opt=32] number nfeatures side length of the tiled grid; must match
--   the mid-level feature map (32 for a 256x256 image)
-- @return nn.Sequential
function img2feat(nfeatures)
    -- A nil size used to yield a Replicate that only failed at forward time.
    -- Default to the grid produced by mid_level_feature() on 256x256 inputs.
    nfeatures = nfeatures or 32

    local model = nn.Sequential()
    model:add(nn.Linear(512, 256))
    model:add(ReLU(true))
    -- replicate twice: once per spatial axis
    model:add(nn.Replicate(nfeatures, 2, 1))
    model:add(nn.Replicate(nfeatures, 2, 1))
    return model
end

--- Colorization decoder applied to the fused 512-plane feature map.
-- Three conv stages separated by two 2x nearest-neighbour upsamplings,
-- so spatial size grows 4x (32x32 -> 128x128). The last convolution emits
-- 2 planes squashed into (0, 1) by a Sigmoid.
-- @return nn.Sequential mapping (N, 512, H, W) -> (N, 2, 4H, 4W)
function upsample_and_color()
    local model = nn.Sequential()

    -- stage 1: 512 -> 128 planes, then 2x upsample
    model:add(Convolution(512, 256, 3,3, 1,1, 1,1))
    model:add(ReLU(true))
    model:add(Convolution(256, 128, 3,3, 1,1, 1,1))
    model:add(ReLU(true))
    model:add(nn.SpatialUpSamplingNearest(2))

    -- stage 2: 128 -> 64 planes, then 2x upsample
    model:add(Convolution(128, 64, 3,3, 1,1, 1,1))
    model:add(ReLU(true))
    model:add(Convolution( 64, 64, 3,3, 1,1, 1,1))
    model:add(ReLU(true))
    model:add(nn.SpatialUpSamplingNearest(2))

    -- stage 3: 64 -> 2 output planes in (0, 1)
    model:add(Convolution( 64, 32, 3,3, 1,1, 1,1))
    model:add(ReLU(true))
    model:add(Convolution( 32,  2, 3,3, 1,1, 1,1))
    model:add(nn.Sigmoid())
    return model
end



In [111]:
nSamples = 1
m1 = vgg()
m2 = global_feature()

-- Push a random 224x224 single-channel batch through the global branch.
data = torch.rand(nSamples, 1, 224, 224)
temp = m2:forward(m1:forward(data))
res1 = temp
print(res1:size())

   1
 512
[torch.LongStorage of size 2]



In [112]:
k1 = vgg()
k2 = mid_level_feature()

-- Push a random 256x256 single-channel batch through the mid-level branch.
data = torch.rand(nSamples, 1, 256, 256)
temp = k2:forward(k1:forward(data))
res2 = temp
print(res2:size())

   1
 256
  32
  32
[torch.LongStorage of size 4]



In [113]:
-- Tile the global descriptor over the 32x32 mid-level grid.
r = img2feat(32)
res3 = r:forward(res1)
local res3_size = res3:size()
print(res3_size)

   1
 256
  32
  32
[torch.LongStorage of size 4]



In [114]:
local join = nn.JoinTable(2)  -- concatenate along dimension 2 (feature planes)
res4 = join:forward({res2, res3})
print(res4:size())

   1
 512
  32
  32
[torch.LongStorage of size 4]



In [115]:
-- Decode the fused map; the output is 4x the input resolution, with 2 planes.
cnet = upsample_and_color()
res5 = cnet:forward(res4)


In [116]:
local res5_size = res5:size(); print(res5_size)

   1
   2
 128
 128
[torch.LongStorage of size 4]



In [126]:
-- Score the global descriptor against 10 classes (log-probabilities).
local classifier = classification(10)
print(classifier:forward(res1))

-2.3077 -2.3106 -2.3009 -2.2907 -2.3012 -2.3012 -2.2927 -2.3289 -2.3125 -2.2802
[torch.DoubleTensor of size 1x10]



In [129]:
--- Build the two-branch colorization graph as an nngraph gModule.
-- Graph inputs, in order:
--   1. the image for the local branch (exercised above at 1x256x256);
--   2. the image for the global branch (exercised above at 1x224x224).
-- Graph outputs: {2-plane sigmoid map, log-probabilities over nLabels}.
-- @param h unused; kept for backward compatibility
-- @param w unused; kept for backward compatibility
-- @tparam[opt=10] number nLabels classes predicted by the classification head
-- @tparam[opt=32] number nfeatures side of the mid-level feature grid
--   (32 for a 256x256 local-branch input)
-- @return nn.gModule
function build_net(h, w, nLabels, nfeatures)
    nLabels = nLabels or 10
    -- The tiled global feature must match the mid-level grid for JoinTable;
    -- the original called img2feat() without a size, so forward always failed.
    nfeatures = nfeatures or 32

    -- local branch: original image -> low-level -> mid-level features
    local low_feat1 = vgg()()
    local mid_feat = mid_level_feature()(low_feat1)

    -- global branch: scaled image -> low-level -> global descriptor
    local low_feat2 = vgg()()
    local global_feat = global_feature()(low_feat2)

    -- fusion: tile the global descriptor over the mid-level grid
    local fusion_layer = img2feat(nfeatures)(global_feat)

    -- classification head on the global descriptor
    local predict_label = classification(nLabels)(global_feat)

    -- concatenate along the feature-plane dimension (256 + 256 = 512)
    local mixed_res = nn.JoinTable(2)({mid_feat, fusion_layer})

    local predict_YUV = upsample_and_color()(mixed_res)

    return nn.gModule({low_feat1, low_feat2}, {predict_YUV, predict_label})
end

In [130]:
local net = build_net(1, 1)
local dot_name = 'temp'
-- graph.dot renders the forward graph; the resulting temp.svg is shown inline
graph.dot(net.fg, dot_name, dot_name)
itorch.image(dot_name .. '.svg')