In [32]:
# This notebook takes samples from a multivariate normal distribution 
# and fits the correct parameters via PDMM in a distributed fashion.
import autograd.numpy as np
from autograd.numpy.random import multivariate_normal, normal, randint
%matplotlib inline
from matplotlib import pyplot as plt
from MRF_Node import Node
from Neighbour import Neighbour 

In [33]:
# This function returns a local objective function for a particular set of samples.
def f(x_i,x_j,var):
    return lambda a,b,c: np.sum(np.power(x_i-a,2) + c*(x_i-a)*(x_j-b)/(2*var)) - np.log(1-np.power(c,2))/4

In [34]:
def g(x):
    if x>=1:
        return 1 - 1e-10
    if x<=-1:
        return -1 + 1e-10
    else:
        return x

In [35]:
#Initialise parameters
u = np.asarray([1,1])
E = np.asarray([[1,0.5],[0.5,1]])
n = 1000
n_sample = 100
x = multivariate_normal(u,E,n).T

In [None]:
# Initialise PDMM Graph

N_nodes = 2;
N_dim = 3
msg_dim = 3

G = []
var = 1
sample_index = randint(n,size=n_sample)
x_s = x[:,sample_index]
d_T = 1e-15
p = 1e-5

c = {2:[-0.99,0.99]}

#initialise nodes

obj = f(x_s[0],x_s[1],var)
G.append(Node(0,N_dim,obj,p,d_T,c))
obj = f(x_s[1],x_s[0],var)
G.append(Node(1,N_dim,obj,p,d_T,c))

#insert forward neighbours
A_forward = np.asarray([[0,1,0],[1,0,0],[0,0,1]])
A_backward = -np.eye(N_dim)
c_ij = np.zeros([N_dim,1])
P_ij = np.eye(N_dim)

for i in np.arange(N_nodes-1):
    neighbour = Neighbour(G[i+1],i+1,A_forward,c_ij,P_ij,msg_dim)
    G[i].Neighbours.append(neighbour)

#insert backward neighbours 
for i in np.arange(N_nodes-1)+1:
    neighbour = Neighbour(G[i-1],i-1,A_backward,c_ij,P_ij,msg_dim)
    G[i].Neighbours.append(neighbour)

In [None]:
#Train
for i in range(10):
    sample_index = randint(n,size=n_sample)
    x_s = x[:,sample_index]

    obj = f(x_s[0],x_s[1],var)
    G[0].f= obj
    obj = f(x_s[1],x_s[0],var)
    G[1].f= obj

    for node in G:
        node.update()
        
    for node in G:
        node.finalise()
        
    print(i)
    print(G[0].x)
    print(G[1].x)

5000 0.00178012873136
5000 0.00269785828582
0
[[ 0.24319903]
 [-2.84632638]
 [-0.99      ]]
[[-0.10879591]
 [-3.62151513]
 [-0.99      ]]
5000 0.00528388868361
5000 0.00767963605246
1
[[-0.55896828]
 [-5.63666563]
 [-0.99      ]]
[[-0.93172533]
 [-6.97870687]
 [-0.99      ]]
5000 0.0145879575357
5000 0.0215542214421
2
[[ -1.61640063]
 [-10.14119745]
 [ -0.99      ]]
[[ -2.35432485]
 [-12.40488204]
 [ -0.99      ]]
5000 0.0434975039645
