In [25]:
require 'torch'
require 'nn'
require 'nngraph'
require 'load'
CVAE = require 'CVAE'
require 'KLDCriterion'
require 'Sampler'

In [26]:
kld = nn.KLDCriterion()
bce = nn.BCECriterion()
bce.sizeAverage = false

In [27]:
data = loadmnist()
train = data.train

In [28]:
batch_size = 3
indices = torch.randperm(train:size(1)):long():split(batch_size)

In [None]:
xs = 784
ys = 784
zs = 2
hs = 400
encoder = CVAE.create_encoder_network(xs, ys, zs, hs)
prior = CVAE.create_prior_network(xs, zs, hs)
decoder = CVAE.create_decoder_network(xs, ys, zs, hs)
sampler = nn.Sampler()

# Manual forwarding

In [None]:
minix = train:index(1, indices[1])
miniy = train:index(1, indices[1])
-- mask first quarter of x
minix[{{},{1,392}}] = 0.

In [None]:
code_blueprint = encoder:forward({minix, miniy})
prior_code_blueprint = prior:forward(minix)
code = sampler:forward(code_blueprint)
recon = decoder:forward({minix, code})

# Constructing a graph

In [None]:
x_input = nn.Identity()()
y_input = nn.Identity()()
mu, logv = encoder({x_input, y_input}):split(2)
pmu, plogv = prior(x_input):split(2)
code = sampler({mu, logv})
recon = decoder({x_input, code})
model = nn.gModule({x_input, y_input}, {recon, mu, logv, pmu, plogv})

In [None]:
recon, mu, logv, pmu, plogv = unpack(model:forward({minix, miniy}))

In [None]:
bce:forward(recon, miniy)

In [None]:
kld:forward({pmu, plogv}, {mu, logv})

In [None]:
drecon = bce:backward(recon, miniy)
dpmu, dplogv, dmu, dlogv = unpack(kld:backward({pmu, plogv}, {mu, logv}))

In [None]:
error_grads = {drecon, dmu, dlogv, dpmu, dplogv}

In [None]:
model:zeroGradParameters()
model:backward({minix, miniy}, error_grads)

In [None]:
params, grads = model:getParameters()

In [None]:
grads:size()

# Check Encapsulation

In [32]:
indices[1]

 29973
 46402
 14022
[torch.LongTensor of size 3]



In [33]:
indices[#indices] = nil

In [35]:
#indices

16666	


In [43]:
indices[#indices] = nil
print(#indices)

16659	


In [20]:
opfunc = function(var)
    print(var)
end
naiveopfunc = function()
    print(var)
end

In [23]:
i = 0
while i < 10 do
    local var = i
    local zfunc = function()
        opfunc(var)
    end
    zfunc(var)
    naiveopfunc()
    i = i + 1
end


0	
nil	
1	
nil	
2	
nil	
3	
nil	
4	
nil	
5	
nil	
6	
nil	
7	
nil	
8	
nil	
9	
nil	


# Inplace v. cloned when indexing

In [44]:
x = torch.randn(3, 3)

In [46]:
y = x:index(1, torch.Tensor(1):fill(2):long())

In [51]:
x

-0.0577  0.4570 -0.6838
-0.8594 -0.2602  1.4491
-0.3079 -1.1362  1.1505
[torch.DoubleTensor of size 3x3]



In [48]:
y

-0.8594 -0.2602  1.4491
[torch.DoubleTensor of size 1x3]



In [52]:
y[{1,1}] = 1

In [53]:
y

 1.0000 -0.2602  1.4491
[torch.DoubleTensor of size 1x3]



In [54]:
x

-0.0577  0.4570 -0.6838
-0.8594 -0.2602  1.4491
-0.3079 -1.1362  1.1505
[torch.DoubleTensor of size 3x3]

