In [1]:
import torch
import numpy

In [19]:
def generate_points(dimension, n_points):
    ''' Generate a set of input points.  Continuously sampled during training.'''
    return torch.rand(n_points, dimension)

In [20]:
class func_3layer(torch.nn.Module):
    '''
    Three layer network for phi, as per the paper (in appendix)
    '''
    def __init__(self, input_dim, output_dim, bias=False):
        torch.nn.Module.__init__(self)
        
        self.layer1 = torch.nn.Linear(input_dim, 1000, bias=bias)
        self.layer2 = torch.nn.Linear(1000, 1000, bias=bias)
        self.layer3 = torch.nn.Linear(1000, output_dim, bias=bias)
        
        self.activation = torch.relu
        
    def forward(self, inputs):
        
        x = self.layer1(inputs)
        x = self.activation(x)
        x = self.layer2(x)
        x = self.activation(x)
        x = self.layer3(x)
        
        return x

class func_2layer(torch.nn.Module):
    '''
    Two layer network for rho, as per the paper (in appendix)
    '''
    def __init__(self, input_dim, output_dim, bias=False):
        torch.nn.Module.__init__(self)
        
        self.layer1 = torch.nn.Linear(input_dim, 1000, bias=bias)
        self.layer2 = torch.nn.Linear(1000, output_dim, bias=bias)
        
        self.activation = torch.relu
        
    def forward(self, inputs):
        
        x = self.layer1(inputs)
        x = self.activation(x)
        x = self.layer2(x)
        
        return x
    


In [21]:
class func_symm(torch.nn.Module):
    '''
    Explicitly symmetric function (essentially, no rho)
    '''
    def __init__(self, input_dim, output_dim, bias=False):
        torch.nn.Module.__init__(self)
        
        self.layer1 = torch.nn.Linear(input_dim, 256, bias=bias)
#         self.layer2 = torch.nn.Linear(256, 256, bias=True)
        self.layer3 = torch.nn.Linear(256, output_dim, bias=bias)
        
        self.activation = torch.relu
        
    def forward(self, x, y):
        
        x = self.layer1(x)
        x = self.activation(x)
#         x = self.layer2(x)
#         x = self.activation(x)
        x = self.layer3(x)
        
        y = self.layer1(y)
        y = self.activation(y)
#         y = self.layer2(y)
#         y = self.activation(y)
        y = self.layer3(y)
        
        return x + y
    



In [43]:
# Set the dimension of input data and the latent size:
DIMENSION   = 1
LATENT_SIZE = 1

In [44]:
# Create the 3 NN to train
phi = func_3layer(input_dim=DIMENSION, output_dim=LATENT_SIZE)
rho = func_2layer(input_dim=LATENT_SIZE, output_dim=1)
symm = func_symm(input_dim=DIMENSION, output_dim=1)

In [45]:
# This should generate points between 0 and 1 at random:
x = generate_points(1,10)
y = generate_points(1,10)

In [46]:
x, y

(tensor([[0.1916],
         [0.2112],
         [0.3870],
         [0.8464],
         [0.3852],
         [0.6821],
         [0.3257],
         [0.6382],
         [0.2114],
         [0.0537]]), tensor([[0.9295],
         [0.8013],
         [0.6532],
         [0.0630],
         [0.4893],
         [0.9669],
         [0.7427],
         [0.6462],
         [0.3641],
         [0.6877]]))

In [72]:
# Now, let's compute an objective function:
def real_function(x, y):
    # This is symmetric by construction
    return (x - y ) **2
#     return abs(x - y)

In [73]:
# With the same input, it goes to 0:
print(x.shape)
print(real_function(x,y).shape)
real_function(x,y) - real_function(y, x)

torch.Size([10, 1])
torch.Size([10, 1])


tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]])

In [74]:
# Ten points by DIMENSION each
x.shape

torch.Size([10, 1])

In [75]:
# Ten points by LATENT_SPACe each
phi(x).shape

torch.Size([10, 1])

In [76]:
# Ten points by single-valued each
rho(phi(x)).shape

torch.Size([10, 1])

In [77]:
# Ten points by single-valued each
symm(x,y).shape

torch.Size([10, 1])

Train the decomposed networks:

In [78]:
# For training the net, create optimizers:
params = list(phi.parameters()) + list(rho.parameters())
optimizer =  torch.optim.Adam(params, lr=0.0001)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
loss_fn = torch.nn.MSELoss()



In [79]:
BATCH_SIZE = 32
for i in range(1000):
    # Sample fresh points
    x = generate_points(1,BATCH_SIZE)
    y = generate_points(1,BATCH_SIZE)
    

    # Clear gradients
    optimizer.zero_grad()
    
    # Compute the true answer:
    correct_answer = real_function(x,y)
    
    # Forward pass
    phi_x = phi(x)
    phi_y = phi(y)
    approximation = rho(phi_x + phi_y)
    
    # MSE loss
    loss = loss_fn(target=correct_answer, input=approximation)
    if i % 100 == 0:
        print(loss)
    
    # Back prop:
    loss.backward()
    # Update weights:
    optimizer.step()
    # Decay LR
    scheduler.step()

tensor(0.6948, grad_fn=<MseLossBackward>)
tensor(0.0731, grad_fn=<MseLossBackward>)
tensor(0.0515, grad_fn=<MseLossBackward>)
tensor(0.0801, grad_fn=<MseLossBackward>)
tensor(0.0945, grad_fn=<MseLossBackward>)
tensor(0.0880, grad_fn=<MseLossBackward>)
tensor(0.0580, grad_fn=<MseLossBackward>)
tensor(0.0833, grad_fn=<MseLossBackward>)
tensor(0.0734, grad_fn=<MseLossBackward>)
tensor(0.0639, grad_fn=<MseLossBackward>)


In [80]:
x = generate_points(1, 10)
y = generate_points(1, 10)
print("x: ", x)
print("y: ", y)
print(real_function(x,y))
print(rho(phi(x) + phi(y)))
print("Error:", loss_fn(target=real_function(x,y), input=rho(phi(x) + phi(y))))

x:  tensor([[0.9394],
        [0.6550],
        [0.5906],
        [0.2541],
        [0.7624],
        [0.0547],
        [0.7827],
        [0.0242],
        [0.8106],
        [0.2213]])
y:  tensor([[0.2046],
        [0.0011],
        [0.4268],
        [0.1612],
        [0.3806],
        [0.5792],
        [0.4267],
        [0.1000],
        [0.4402],
        [0.0619]])
tensor([[0.7349],
        [0.6539],
        [0.1638],
        [0.0930],
        [0.3818],
        [0.5245],
        [0.3560],
        [0.0758],
        [0.3704],
        [0.1594]])
tensor([[0.3271],
        [0.1876],
        [0.2909],
        [0.1188],
        [0.3268],
        [0.1812],
        [0.3458],
        [0.0355],
        [0.3576],
        [0.0810]], grad_fn=<MmBackward>)
Error: tensor(0.0529, grad_fn=<MseLossBackward>)


In [16]:
# For training the net, create optimizers:
params = list(symm.parameters())
optimizer =  torch.optim.Adam(params, lr=0.003)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

In [17]:
BATCH_SIZE = 32
for i in range(500):
    
    optimizer.zero_grad()
    
    x = generate_points(1,BATCH_SIZE)
    y = generate_points(1,BATCH_SIZE)
    
#     print(x, y)
    
    correct_answer = real_function(x,y)
    
    
#     print(phi_x)
#     print(phi_y)


    approximation = symm(x,y)
    
#     print(correct_answer)
#     print(approximation)
    
    loss = torch.nn.MSELoss()(target=correct_answer, input=approximation)
    if i % 100 == 0:
        print(loss)
    
    loss.backward()
    optimizer.step()
    scheduler.step()

tensor(0.0812, grad_fn=<MseLossBackward>)
tensor(0.0083, grad_fn=<MseLossBackward>)
tensor(0.0087, grad_fn=<MseLossBackward>)
tensor(0.0088, grad_fn=<MseLossBackward>)
tensor(0.0081, grad_fn=<MseLossBackward>)


In [18]:
x = generate_points(1, 10)
y = generate_points(1, 10)
print("x: ", x)
print("y: ", y)
print(real_function(x,y))
print(symm(x,y))

x:  tensor([[0.9528],
        [0.1039],
        [0.3522],
        [0.7515],
        [0.7408],
        [0.1683],
        [0.9277],
        [0.1361],
        [0.4304],
        [0.2179]])
y:  tensor([[0.6248],
        [0.0087],
        [0.2785],
        [0.0422],
        [0.4817],
        [0.2681],
        [0.0803],
        [0.6927],
        [0.7939],
        [0.9807]])
tensor([[0.5953],
        [0.0009],
        [0.0981],
        [0.0317],
        [0.3569],
        [0.0451],
        [0.0745],
        [0.0943],
        [0.3417],
        [0.2137]])
tensor([[ 0.5452],
        [-0.1965],
        [ 0.0766],
        [ 0.1474],
        [ 0.3670],
        [-0.0318],
        [ 0.2567],
        [ 0.1618],
        [ 0.3710],
        [ 0.3534]], grad_fn=<AddBackward0>)
