# Table of Contents
 <p><div class="lev1 toc-item"><a href="#Implementation-of-memory-networks" data-toc-modified-id="Implementation-of-memory-networks-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Implementation of memory networks</a></div><div class="lev2 toc-item"><a href="#I-Module" data-toc-modified-id="I-Module-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>I Module</a></div><div class="lev1 toc-item"><a href="#Building-the-Memory-Network" data-toc-modified-id="Building-the-Memory-Network-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Building the Memory Network</a></div><div class="lev2 toc-item"><a href="#Parametrization" data-toc-modified-id="Parametrization-2.1"><span class="toc-item-num">2.1&nbsp;&nbsp;</span>Parametrization</a></div><div class="lev2 toc-item"><a href="#Network-implementation" data-toc-modified-id="Network-implementation-2.2"><span class="toc-item-num">2.2&nbsp;&nbsp;</span>Network implementation</a></div><div class="lev1 toc-item"><a href="#Memory-module" data-toc-modified-id="Memory-module-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Memory module</a></div><div class="lev2 toc-item"><a href="#Memory-module-testing" data-toc-modified-id="Memory-module-testing-3.1"><span class="toc-item-num">3.1&nbsp;&nbsp;</span>Memory module testing</a></div><div class="lev1 toc-item"><a href="#Inference-module" data-toc-modified-id="Inference-module-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Inference module</a></div><div class="lev2 toc-item"><a href="#Inference-module-testing" data-toc-modified-id="Inference-module-testing-4.1"><span class="toc-item-num">4.1&nbsp;&nbsp;</span>Inference module testing</a></div><div class="lev1 toc-item"><a href="#Learning-algorithm" data-toc-modified-id="Learning-algorithm-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Learning algorithm</a></div><div class="lev2 toc-item"><a href="#Margin-Ranking-Loss-criterion" data-toc-modified-id="Margin-Ranking-Loss-criterion-5.1"><span class="toc-item-num">5.1&nbsp;&nbsp;</span>Margin Ranking Loss criterion</a></div><div class="lev2 toc-item"><a href="#MR-test" data-toc-modified-id="MR-test-5.2"><span class="toc-item-num">5.2&nbsp;&nbsp;</span>MR test</a></div><div class="lev2 toc-item"><a href="#Visualizing-the-implemented-network" data-toc-modified-id="Visualizing-the-implemented-network-5.3"><span class="toc-item-num">5.3&nbsp;&nbsp;</span>Visualizing the implemented network</a></div><div class="lev1 toc-item"><a href="#Testing-the-Memory-Network" data-toc-modified-id="Testing-the-Memory-Network-6"><span class="toc-item-num">6&nbsp;&nbsp;</span>Testing the Memory Network</a></div><div class="lev1 toc-item"><a href="#Scrap" data-toc-modified-id="Scrap-7"><span class="toc-item-num">7&nbsp;&nbsp;</span>Scrap</a></div><div class="lev2 toc-item"><a href="#Copying-a-clone-to-put-it-in-a-memory" data-toc-modified-id="Copying-a-clone-to-put-it-in-a-memory-7.1"><span class="toc-item-num">7.1&nbsp;&nbsp;</span>Copying a clone to put it in a memory</a></div><div class="lev2 toc-item"><a href="#Multiplying-tensors" data-toc-modified-id="Multiplying-tensors-7.2"><span class="toc-item-num">7.2&nbsp;&nbsp;</span>Multiplying tensors</a></div><div class="lev1 toc-item"><a href="#Scrap" data-toc-modified-id="Scrap-8"><span class="toc-item-num">8&nbsp;&nbsp;</span>Scrap</a></div><div class="lev1 toc-item"><a href="#Scrap-on-nnGraph-operations" data-toc-modified-id="Scrap-on-nnGraph-operations-9"><span class="toc-item-num">9&nbsp;&nbsp;</span>Scrap on nnGraph operations</a></div>

In [29]:
require 'nn';
require 'nngraph';
require 'utils.OneHot'

# Implementation of memory networks

Memory network is a sequence of 4 modules that allows better memory addressing for tasks such as Question Answering

## I Module

**input feature map** – converts the incoming input to the internal feature representation


---

OneHot Vectorization and word2vec

# Building the Memory Network

## Parametrization

In [82]:
SEQ_LENGTH = 5
VOCAB_SIZE = 30
MEM_SIZE = 3
NUM_MEM = 2
FEATURE_SIZE = 50

## Network implementation

In [73]:
function create_network(SEQ_LENGTH,VOCAB_SIZE,MEM_SIZE,NUM_MEM)
    ------------------ Initialization -------------------
    SEQ_LENGTH = SEQ_LENGTH or 5
    MEM_SIZE = MEM_SIZE or 3
    NUM_MEM = NUM_MEM or 2
    VOCAB_SIZE = VOCAB_SIZE or 30
    ------------------ I Module -------------------
    local mem_net = nn.Sequential()
    local branch_net = nn.ConcatTable()
    local mlp = nn.Sequential()
    local net = nn.Parallel(1,1)
    for i=1,SEQ_LENGTH do
        net:add(OneHot(VOCAB_SIZE))
    end
    mlp:add(net)    
    ------------------ G Module -------------------    
    -- m = MemoryModule.new(NUM_MEM,MEM_SIZE,VOCAB_SIZE)
    local g_mod = MemoryModule.new(NUM_MEM,MEM_SIZE,VOCAB_SIZE)
    mlp:add(g_mod)
    ------------------ O Module -------------------    
    local o_mod = InferenceModule.new(VOCAB_SIZE,3*VOCAB_SIZE)
    branch_net:add(mlp)
    branch_net:add(nn.Identity())
    mem_net:add(branch_net)
    mem_net:add(o_mod)
    --[[
    ------------------ R Module -------------------
    table.insert(outputs,o_mod)
    ]]
    return mem_net
end 

In [74]:
local mlp = create_network()
--print(mlp)

{
  1 : DoubleTensor - size: 3x30
  2 : DoubleTensor - size: 3x30
}


In [75]:
local mlp = create_network()
local res = mlp:forward(torch.Tensor(5,30):fill(1))
print('-----')
print(res)

{
  1 : DoubleTensor - size: 3x30
  2 : DoubleTensor - size: 3x30
}


-----	
{
  1 : DoubleTensor - size: 3x90
  2 : DoubleTensor - size: 5x90
}


# Memory module

In [337]:
MemoryModule, parent = torch.class('MemoryModule','nn.Module')

function MemoryModule:__init(NUM_MEM,MEM_SIZE,VOCAB_SIZE)
    parent.__init(self)
    self.num_mem = NUM_MEM
    self.mem_size = MEM_SIZE
    self.memory = {}
    for i=1,NUM_MEM do table.insert(self.memory,torch.Tensor(MEM_SIZE,VOCAB_SIZE):fill(0)) end
end

function MemoryModule:updateOutput(input)
    -- Replace the index memory with the input it as received
    assert(input:size(2) == self.memory[1]:size(2), "input size and memory size are differents")
    local input = input:clone()dd
    local loaded_mem = 0
    for i=1,#self.memory do
       for j=1,self.mem_size do
            if j + loaded_mem > input:size(1) then
                break
            end
            self.memory[i][{j}] = input[{j + loaded_mem,{}}]
        end
        loaded_mem = loaded_mem + self.mem_size
    end
    return self.memory
end

function MemoryModule:getIndex(index)
    return self.memory[index]
end

function MemoryModule:getMemorySize()
    return #self.memory
end

function MemoryModule:getMemory()
    return nn.JoinTable(1):forward(mem)
end

## Memory module testing

In [338]:
m = MemoryModule.new(NUM_MEM,MEM_SIZE,VOCAB_SIZE)
mem = m:forward(r)

{
  1 : DoubleTensor - size: 3x30
  2 : DoubleTensor - size: 3x30
}


# Inference module

In [112]:
InferenceModule, parent = torch.class('InferenceModule','nn.Module')

function InferenceModule:__init(voc_size, feature_dim)
    parent.__init(self)
    local inputs = {}
    local outputs = {}
    table.insert(inputs, nn.Identity()())
    table.insert(inputs, nn.Identity()())
    -----  
    local lin1 = nn.Linear(voc_size,feature_dim)(inputs[1])
    local lin2 = nn.Linear(voc_size,feature_dim)(inputs[2])
    table.insert(outputs,lin1)
    table.insert(outputs,lin2)
    self.mlp = nn.gModule(inputs, outputs) 
end

function InferenceModule:updateOutput(input)
    local ind = input[3] or 1
    local input1 = input[1][ind]:clone()
    local input2 = input[2]:clone()
    local i4 = self.mlp:forward{input1,input2}
    local ll = i4[2]:transpose(1,2)
    local lll = i4[1]:reshape(1,90)
    local glo = lll*ll
    return glo
end

## Inference module testing

In [113]:
infer = InferenceModule.new(30,3*30)

In [116]:
i = torch.Tensor(3,30)
ii = torch.Tensor(40,30)
iii = {i,ii,2}

In [None]:
local i4 = infer:forward(iii)
print(i4:size())

# Learning algorithm

## Margin Ranking Loss criterion

Memory Networks training is based on Stochastic Gradient Descent and [Margin Ranking](https://github.com/torch/nn/blob/master/doc/criterion.md#nn.MarginRankingCriterion) Loss that is already implemented in torch

## MR test

In [None]:
p1_mlp = nn.Linear(5, 2)
p2_mlp = p1_mlp:clone('weight', 'bias')

In [120]:
prl = nn.ParallelTable()
prl:add(p1_mlp)
prl:add(p2_mlp)

mlp1 = nn.Sequential()
mlp1:add(prl)
mlp1:add(nn.DotProduct())

mlp2 = mlp1:clone('weight', 'bias')

mlpa = nn.Sequential()
prla = nn.ParallelTable()
prla:add(mlp1)
prla:add(mlp2)
mlpa:add(prla)

crit = nn.MarginRankingCriterion(0.1)

x=torch.randn(5)
y=torch.randn(5)
z=torch.randn(5)

-- Use a typical generic gradient update function
function gradUpdate(mlp, x, y, criterion, learningRate)
   local pred = mlp:forward(x)
   local err = criterion:forward(pred, y)
   local gradCriterion = criterion:backward(pred, y)
   mlp:zeroGradParameters()
   mlp:backward(x, gradCriterion)
   mlp:updateParameters(learningRate)
end

for i = 1, 100 do
   gradUpdate(mlpa, {{x, y}, {x, z}}, 1, crit, 0.01)
   if true then
      o1 = mlp1:forward{x, y}[1]
      o2 = mlp2:forward{x, z}[1]
      o = crit:forward(mlpa:forward{{x, y}, {x, z}}, 1)
      print(o1, o2, o)
   end
end

-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	

-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	

0	
-0.27708781058275	-0.40454109808238	

0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	

0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	

0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	
-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


-0.27708781058275	-0.40454109808238	0	


## Visualizing the implemented network

In [80]:
-- some handy models are defined in optnet.models
-- like alexnet, googlenet, vgg and resnet

local mlp = create_network()

generateGraph = require 'optnet.graphgen'

-- visual properties of the generated graph
-- follows graphviz attributes
g = generateGraph(mlp)
graph.dot(g,"MemoryNetwork","MemoryNetwork")

{
  1 : DoubleTensor - size: 3x30
  2 : DoubleTensor - size: 3x30
}


...rs/david/torch/install/share/lua/5.1/optnet/graphgen.lua:155: bad argument #1 to 'ipairs' (table expected, got nil)
stack traceback:
	[C]: in function 'ipairs'
	...rs/david/torch/install/share/lua/5.1/optnet/graphgen.lua:155: in function 'createBoundaryNode'
	...rs/david/torch/install/share/lua/5.1/optnet/graphgen.lua:220: in function 'generateGraph'
	[string "-- some handy models are defined in optnet.mo..."]:10: in main chunk
	[C]: in function 'xpcall'
	/Users/david/torch/install/share/lua/5.1/itorch/main.lua:210: in function </Users/david/torch/install/share/lua/5.1/itorch/main.lua:174>
	/Users/david/torch/install/share/lua/5.1/lzmq/poller.lua:80: in function 'poll'
	/Users/david/torch/install/share/lua/5.1/lzmq/impl/loop.lua:307: in function 'poll'
	/Users/david/torch/install/share/lua/5.1/lzmq/impl/loop.lua:325: in function 'sleep_ex'
	/Users/david/torch/install/share/lua/5.1/lzmq/impl/loop.lua:370: in function 'start'
	/Users/david/torch/install/share/lua/5.1/itorch/main.lua:389: in main chunk
	[C]: in function 'require'
	(command line):1: in main chunk
	[C]: at 0x01061f8a10: 

# Testing the Memory Network

In [418]:
a = torch.Tensor(1,5)
for i=1,a:size(2) do
    a[1][i] = i
end

In [419]:
mlp:forward(a)

{
  1 : DoubleTensor - size: 3x30
  2 : DoubleTensor - size: 3x30
}


# Scrap

## Copying a clone to put it in a memory

In [192]:
local a = torch.Tensor(10,4)
local mem = torch.Tensor(16,4)

print('valeur de a \n')
print(a)
print('valeur de mem \n')
print(mem)

for i=1,a:size(1) do
    --print(i)
    local cl = a[{i,{}}]:clone()
    mem[{i}] = cl
end

print('valeur de a \n')
print(a)
print('valeur de mem \n')
print(mem)

valeur de a 
	
  2.0000e+00   2.0000e+00  2.2727e-322   0.0000e+00
  0.0000e+00  6.2043e+223  6.2104e+175  1.3662e+161
 7.6284e+228  1.0626e+248  3.3552e-110  1.6934e-152
 7.6284e+228  1.0626e+248  3.3553e-110  7.3587e+223
 3.2167e+257  5.9526e+135  1.7530e+243  4.0719e+223
 1.4243e+261  2.6099e+180  4.1005e+223  7.3587e+223
 4.0719e+223   7.1415e-13  6.8498e+180  4.1114e+223
 1.7258e+243  4.0719e+223  1.3662e+161  7.1345e+159
 2.2476e+142  2.0289e-110  3.7768e+180   2.9070e-14
 6.0143e+175  7.3587e+223  3.2167e+257  5.9526e+135
[torch.DoubleTensor of size 10x4]

valeur de mem 
	
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
[torch.DoubleTensor of size 16x4]

valeur de a 
	
  2.0000e+00   2.0000e+00  2.2727e-322   0.0000e+00
  0.0000e+00  6.2043e+223  6.2104e+175  1.3662e+161
 7.6284e+228  1.0626e+248  3.3552e-110  1.6934e-152
 7.6284e+228  1

## Multiplying tensors

In [439]:
a = torch.Tensor(10,2)
aa = torch.Tensor(2,8)

# Scrap

In [96]:
print(r)
print(r:select(1,2))

Columns 1 to 26
 0  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
 0  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
 0  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
 0  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
 0  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0

Columns 27 to 30
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
[torch.DoubleTensor of size 5x30]

 0
 1
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
[torch.DoubleTensor of size 30]



# Scrap on nnGraph operations

In [711]:
local inputs = {}
local outputs = {}
local par = {}
for i=1,6 do 
    table.insert(inputs,nn.Identity()()) 
    par[#par+1] = nn.ParallelTable()(inputs[i])
end
local out = nn.JoinTable()(par)
table.insert(outputs,out)
l = nn.gModule(inputs,outputs)

In [712]:
inp = nn.SplitTable(1):forward(torch.LongTensor(1,6):squeeze():fill(1))
l:forward{unpack(inp)}

/Users/david/torch/install/share/lua/5.1/nn/JoinTable.lua:13: attempt to compare nil with number
stack traceback:
	/Users/david/torch/install/share/lua/5.1/nn/JoinTable.lua:13: in function '_getPositiveDimension'
	/Users/david/torch/install/share/lua/5.1/nn/JoinTable.lua:22: in function 'func'
	...rs/david/torch/install/share/lua/5.1/nngraph/gmodule.lua:345: in function 'neteval'
	...rs/david/torch/install/share/lua/5.1/nngraph/gmodule.lua:380: in function 'forward'
	[string "inp = nn.SplitTable(1):forward(torch.LongTens..."]:2: in main chunk
	[C]: in function 'xpcall'
	/Users/david/torch/install/share/lua/5.1/itorch/main.lua:210: in function </Users/david/torch/install/share/lua/5.1/itorch/main.lua:174>
	/Users/david/torch/install/share/lua/5.1/lzmq/poller.lua:80: in function 'poll'
	/Users/david/torch/install/share/lua/5.1/lzmq/impl/loop.lua:307: in function 'poll'
	/Users/david/torch/install/share/lua/5.1/lzmq/impl/loop.lua:325: in function 'sleep_ex'
	/Users/david/torch/install/share/lua/5.1/lzmq/impl/loop.lua:370: in function 'start'
	/Users/david/torch/install/share/lua/5.1/itorch/main.lua:389: in main chunk
	[C]: in function 'require'
	(command line):1: in main chunk
	[C]: at 0x01061f8a10: 