In [230]:
import torch
import numpy

In [231]:
def generate_points(dimension, n_points):
    ''' Generate a set of input points.  Continuously sampled during training.'''
    return torch.rand(n_points, dimension)

In [260]:
class func_3layer(torch.nn.Module):
    '''
    Three layer network for phi, as per the paper (in appendix)
    '''
    def __init__(self, input_dim, output_dim, bias=True):
        torch.nn.Module.__init__(self)
        
        self.layer1 = torch.nn.Linear(input_dim, 100, bias=bias)
        self.layer2 = torch.nn.Linear(100, 100, bias=bias)
        self.layer3 = torch.nn.Linear(100, output_dim, bias=bias)
        
        self.activation = torch.relu
        
    def forward(self, inputs):
        
        x = self.layer1(inputs)
        x = self.activation(x)
        x = self.layer2(x)
        x = self.activation(x)
        x = self.layer3(x)
        
        return x

class func_2layer(torch.nn.Module):
    '''
    Two layer network for rho, as per the paper (in appendix)
    '''
    def __init__(self, input_dim, output_dim, bias=True):
        torch.nn.Module.__init__(self)
        
        self.layer1 = torch.nn.Linear(input_dim, 100, bias=bias)
        self.layer2 = torch.nn.Linear(100, output_dim, bias=bias)
        
        self.activation = torch.relu
        
    def forward(self, inputs):
        
        x = self.layer1(inputs)
        x = self.activation(x)
        x = self.layer2(x)
        
        return x

In [233]:
class func_symm(torch.nn.Module):
    '''
    Explicitly symmetric function (essentially, no rho)
    '''
    def __init__(self, input_dim, output_dim, bias=True):
        torch.nn.Module.__init__(self)
        
        self.layer1 = torch.nn.Linear(input_dim, 256, bias=bias)
        self.layer2 = torch.nn.Linear(256, 256, bias=True)
        self.layer3 = torch.nn.Linear(256, output_dim, bias=bias)
        
        self.activation = torch.relu
#        self.input_dim = input_dim
        
    def forward(self, inputs):
        if (2 == 2):
            inputs_12 = torch.zeros_like(inputs)
            inputs_12[:,0] = inputs[:,1]
            inputs_12[:,1] = inputs[:,0]
            
            x = self.layer1(inputs)
            x = self.activation(x)
            x = self.layer2(x)
            x = self.activation(x)
            x = self.layer3(x) 
            
            x_12 = self.layer1(inputs_12)
            x_12 = self.activation(x_12)
            x_12 = self.layer2(x_12)
            x_12 = self.activation(x_12)
            x_12 = self.layer3(x_12)
            
            output = ( x + x_12 ) / 2
        else: 
            output = 0
            print("PROBLEM, ONLY 2-D INPUTS FOR NOW")
        return output

In [234]:
# Define the explicitly symmetric function
symm = func_symm(input_dim=DIMENSION, output_dim=1)

In [235]:
# This should generate points between 0 and 1 at random:
torch.manual_seed(19)
xy = generate_points(2,10)
x = xy[:,0]
y = xy[:,1]
print("x=",x)
print("y=",y)
print("xy=", xy)

x= tensor([0.9686, 0.8799, 0.2161, 0.3617, 0.8503, 0.2479, 0.0261, 0.4880, 0.5819,
        0.7291])
y= tensor([0.1999, 0.6622, 0.1192, 0.0167, 0.9620, 0.2946, 0.0891, 0.0860, 0.0066,
        0.4305])
xy= tensor([[0.9686, 0.1999],
        [0.8799, 0.6622],
        [0.2161, 0.1192],
        [0.3617, 0.0167],
        [0.8503, 0.9620],
        [0.2479, 0.2946],
        [0.0261, 0.0891],
        [0.4880, 0.0860],
        [0.5819, 0.0066],
        [0.7291, 0.4305]])


In [236]:
# Now, let's compute an objective function:
def real_function(xy):
    # This is symmetric by construction
    return (xy[:,0] - xy[:,1] ) **2
#     return abs(x - y)

In [237]:
real_function(xy)

tensor([0.5910, 0.0474, 0.0094, 0.1190, 0.0125, 0.0022, 0.0040, 0.1617, 0.3310,
        0.0892])

In [238]:
#symm = func_symm(input_dim=2, output_dim=1)
target = torch.t(symm(xy))[0]
print(target)

tensor([0.0942, 0.1036, 0.0643, 0.0673, 0.1070, 0.0779, 0.0522, 0.0765, 0.0766,
        0.0997], grad_fn=<SelectBackward>)


In [239]:
# For training the net, create optimizers:
params = list(symm.parameters())
optimizer =  torch.optim.Adam(params, lr=0.001)

In [248]:
BATCH_SIZE = 4096
xy = generate_points(2,BATCH_SIZE)
for i in range(10000):
    
    optimizer.zero_grad()  
    correct_answer = real_function(xy)
#    approximation = symm(xy)
    approximation = torch.t(symm(xy))[0]    
#     print(correct_answer)
#     print(approximation)
    
    loss = torch.nn.MSELoss()(target=correct_answer, input=approximation)
    if i % 1000 == 0:
#        print("correct_answer", correct_answer)
#        print("approximation", approximation)
        print("loss=",loss)
    loss.backward()
    optimizer.step()

loss= tensor(7.7134e-05, grad_fn=<MseLossBackward>)
loss= tensor(6.7934e-08, grad_fn=<MseLossBackward>)
loss= tensor(1.7408e-06, grad_fn=<MseLossBackward>)
loss= tensor(4.5348e-08, grad_fn=<MseLossBackward>)
loss= tensor(4.4011e-08, grad_fn=<MseLossBackward>)
loss= tensor(3.9494e-08, grad_fn=<MseLossBackward>)
loss= tensor(7.1034e-08, grad_fn=<MseLossBackward>)
loss= tensor(3.4735e-08, grad_fn=<MseLossBackward>)
loss= tensor(3.5059e-08, grad_fn=<MseLossBackward>)
loss= tensor(2.9682e-08, grad_fn=<MseLossBackward>)


In [277]:
# This should generate points between 0 and 1 at random to check the correctness of the answer:
torch.manual_seed(19)
xy = generate_points(2,20)
x = xy[:,0]
y = xy[:,1]
correct_answer = real_function(xy)
approximation = torch.t(symm(xy))[0]
print("correct_answer", correct_answer)
print("approximation", approximation)

correct_answer tensor([0.5910, 0.0474, 0.0094, 0.1190, 0.0125, 0.0022, 0.0040, 0.1617, 0.3310,
        0.0892, 0.0051, 0.0786, 0.1675, 0.1472, 0.0839, 0.0013, 0.0018, 0.0209,
        0.1559, 0.4953])
approximation tensor([0.5905, 0.0478, 0.0092, 0.1180, 0.0122, 0.0021, 0.0042, 0.1616, 0.3310,
        0.0874, 0.0054, 0.0773, 0.1683, 0.1462, 0.0833, 0.0019, 0.0025, 0.0206,
        0.1556, 0.4954], grad_fn=<SelectBackward>)


In [259]:
# Check for symmetry
yx = torch.index_select(xy, 1, torch.LongTensor([1,0]))
correct_answer = real_function(yx)
approximation = torch.t(symm(yx))[0]
print("correct_answer", correct_answer)
print("approximation", approximation)

correct_answer tensor([0.5910, 0.0474, 0.0094, 0.1190, 0.0125, 0.0022, 0.0040, 0.1617, 0.3310,
        0.0892, 0.0051, 0.0786, 0.1675, 0.1472, 0.0839, 0.0013, 0.0018, 0.0209,
        0.1559, 0.4953])
approximation tensor([0.5915, 0.0474, 0.0094, 0.1190, 0.0123, 0.0021, 0.0039, 0.1617, 0.3311,
        0.0893, 0.0051, 0.0787, 0.1674, 0.1473, 0.0837, 0.0012, 0.0018, 0.0210,
        0.1559, 0.4951], grad_fn=<SelectBackward>)


In [296]:
# Set the dimension of input data and the latent size:
DIMENSION   = 1
LATENT_SIZE = 2

# Create the 3 NN to train
phi = func_3layer(input_dim=DIMENSION, output_dim=LATENT_SIZE)
rho = func_2layer(input_dim=LATENT_SIZE, output_dim=1)

In [297]:
# For training the net, create optimizers:
params = list(phi.parameters()) + list(rho.parameters())
optimizer =  torch.optim.Adam(params, lr=0.0001)
BATCH_SIZE = 16
xy = generate_points(2,BATCH_SIZE)
x = xy[:,0]
y = xy[:,1]
print("x=", x)
print("phi_x=", phi(x))

x= tensor([0.9357, 0.9449, 0.5773, 0.5564, 0.2720, 0.7857, 0.4088, 0.1350, 0.3043,
        0.1808, 0.9956, 0.1764, 0.4303, 0.6732, 0.4642, 0.5476])


RuntimeError: size mismatch, m1: [1 x 16], m2: [1 x 100] at ../aten/src/TH/generic/THTensorMath.cpp:41

In [312]:
# For training the net, create optimizers:
params = list(phi.parameters()) + list(rho.parameters())
optimizer =  torch.optim.Adam(params, lr=0.0001)
BATCH_SIZE = 1024
xy = generate_points(2, BATCH_SIZE)
x = xy[:,0].view(-1,1)
y = xy[:,1].view(-1,1)
#print("x=", x)
#print("phi_x=", phi(x))

for i in range(10000):
    optimizer.zero_grad()  
    correct_answer = real_function(xy)
    phi_x = phi(x)
#    print("phi_x", phi_x)
    phi_y = phi(y)
    approximation = torch.t(rho(phi_x + phi_y))[0]    
#    print("correct_answer", correct_answer)
#    print("approximation", approximation)
    
    loss = torch.nn.MSELoss()(target=correct_answer, input=approximation)
    if i % 500 == 0:
#        print("correct_answer", correct_answer)
#        print("approximation", approximation)
        print("loss=",loss)
    loss.backward()
    optimizer.step()

loss= tensor(0.0012, grad_fn=<MseLossBackward>)
loss= tensor(6.2478e-05, grad_fn=<MseLossBackward>)
loss= tensor(2.4542e-05, grad_fn=<MseLossBackward>)
loss= tensor(1.0793e-05, grad_fn=<MseLossBackward>)
loss= tensor(6.6826e-06, grad_fn=<MseLossBackward>)
loss= tensor(4.6332e-06, grad_fn=<MseLossBackward>)
loss= tensor(3.2047e-06, grad_fn=<MseLossBackward>)
loss= tensor(2.3724e-06, grad_fn=<MseLossBackward>)
loss= tensor(1.8893e-06, grad_fn=<MseLossBackward>)
loss= tensor(1.5880e-06, grad_fn=<MseLossBackward>)
loss= tensor(1.3524e-06, grad_fn=<MseLossBackward>)
loss= tensor(1.2036e-06, grad_fn=<MseLossBackward>)
loss= tensor(2.4821e-06, grad_fn=<MseLossBackward>)
loss= tensor(1.0527e-06, grad_fn=<MseLossBackward>)
loss= tensor(9.8178e-07, grad_fn=<MseLossBackward>)
loss= tensor(9.3043e-07, grad_fn=<MseLossBackward>)
loss= tensor(8.9621e-07, grad_fn=<MseLossBackward>)
loss= tensor(8.6352e-07, grad_fn=<MseLossBackward>)
loss= tensor(8.3760e-07, grad_fn=<MseLossBackward>)
loss= tensor(8.2

In [313]:
# This should generate points between 0 and 1 at random to check the correctness of the answer:
torch.manual_seed(19)
xy = generate_points(2,2)
x = xy[:,0].view(-1,1)
y = xy[:,1].view(-1,1)
correct_answer = real_function(xy)
phi_x = phi(x)
phi_y = phi(y)
approximation = torch.t(rho(phi_x + phi_y))[0]
print("correct_answer", correct_answer)
print("approximation", approximation)

correct_answer tensor([0.5910, 0.0474])
approximation tensor([0.5922, 0.0443], grad_fn=<SelectBackward>)
