In [1]:
import numpy as np
import matplotlib.pyplot as plt
import numba
import sys
import time

# Model constants shared by the cells below.
sigma = 0.01                      # standard deviation of the additive observation noise
dim = 2                           # dimension of the input space, u: R^dim -> R
obs_min, obs_max = np.array([0, 0]), np.array([1, 1])   # lower / upper corners of [0, 1]^2

In [2]:
# ---------------------------------------------------------------------------
# FUNCTIONS - LIKELIHOOD
# ---------------------------------------------------------------------------

m = 21  # number of grid nodes per side; node i sits at x_1 = i/(m-1)

# Fixed (Dirichlet) values of h on the two rows where it is prescribed:
#   bottom row, x_2 = 0:  h = x_1
#   top row,    x_2 = 1:  h = 1 - x_1
h_first_row = np.arange(m) / (m - 1)
h_last_row = 1 - np.arange(m) / (m - 1)
    
@numba.njit()
def A_init(k):
    """Assemble the five-point finite-difference matrix with coefficient k.

    Grid nodes are flattened as i + m*j, with i along x_1 and j along x_2.
    For a node c with west neighbour w = c-1 and south neighbour s = c-m:
      * the east (c+1) and north (c+m) couplings use k[c],
      * the west and south couplings use k[w] and k[s],
      * the diagonal is the sum of the negated off-diagonal entries.
    On the columns i = 0 and i = m-1 (zero normal derivative, h_x = 0) the
    coupling to the missing neighbour is simply dropped. Rows with j = 0 and
    j = m-1 (which also contain the corner nodes) are left zero: G fixes h
    there and only uses the matrix columns that couple to those rows.

    NOTE(review): the final scaling uses a spacing of 6/(m-1), although the
    grid spans [0, 1] elsewhere in this notebook (h_first_row, the evaluation
    of k in data_prediction_check). With no source term the constant factor
    cancels in G's linear solve, so only the magnitude of A is affected --
    confirm before relying on A's absolute scale.

    ``m`` is a module global; numba freezes its value at first compilation.

    Parameters
    ----------
    k : 1-D float array of length m**2, flattened like the grid.

    Returns
    -------
    A : (m**2, m**2) float array.
    """
    A = np.zeros((m**2, m**2))
    for row in range(1, m - 1):                 # j = 1 .. m-2
        # Interior nodes of this row (i = 1 .. m-2).
        for col in range(1, m - 1):
            c = col + m*row
            west = c - 1
            south = c - m
            east = c + 1
            north = c + m
            A[c, c] = 2*k[c] + k[west] + k[south]
            A[c, west] = -k[west]
            A[c, south] = -k[south]
            A[c, east] = -k[c]
            A[c, north] = -k[c]

        # Right edge, i = m-1: no east neighbour (h_x = 0).
        c = (m - 1) + m*row
        A[c, c] = k[c] + k[c - m] + k[c - 1]
        A[c, c - m] = -k[c - m]
        A[c, c + m] = -k[c]
        A[c, c - 1] = -k[c - 1]

        # Left edge, i = 0: no west neighbour (h_x = 0).
        c = m*row
        A[c, c] = 2*k[c] + k[c - m]
        A[c, c - m] = -k[c - m]
        A[c, c + 1] = -k[c]
        A[c, c + m] = -k[c]

    # Multiply by 1/Delta^2 (see the NOTE above about Delta).
    A = 1/((6/(m-1))**2)*A
    return A

@numba.njit()
def G(k):
    """Forward map k -> h: solve the finite-difference system built by A_init.

    Parameters
    ----------
    k : 1-D float array of length m**2
        Coefficient values at the grid nodes, flattened as k[i + m*j]
        (i along x_1, j along x_2). The callers in this notebook build it
        as exp(u) evaluated on the grid.

    Returns
    -------
    h : 1-D float array of length m**2, flattened like k.
        Row j = 0 equals h_first_row and row j = m-1 equals h_last_row
        (fixed values); rows j = 1..m-2 solve A_k h = 0 with those rows
        substituted as known boundary values.

    Notes
    -----
    ``m``, ``h_first_row`` and ``h_last_row`` are module globals; numba
    freezes their values when G is first compiled.
    """
    A_k = A_init(k)
    lo, hi = m, m * (m - 1)   # flat index range of the unknown rows j = 1..m-2

    # Move the fixed boundary values to the right-hand side. Only the rows
    # adjacent to the fixed rows (j = 1 and j = m-2) couple to them.
    b = np.zeros(m**2)
    for i in range(m):
        b[i + m] -= A_k[i + m, i] * h_first_row[i]
        b[i + m*(m-2)] -= A_k[i + m*(m-2), i + m*(m-1)] * h_last_row[i]

    # The first and last rows are fixed, so only the interior sub-system is
    # handed to the linear solver.
    h = np.zeros(m**2)
    h[lo:hi] = np.linalg.solve(A_k[lo:hi, lo:hi], b[lo:hi])
    h[0:m] = h_first_row
    h[hi:] = h_last_row
    return h

In [3]:
"""

FUNCTIONS - VALUE FUNCTION AND POLICIES

"""

def progressBar(value, endvalue, bar_length=40):
    """Redraw a one-line text progress bar on stdout.

    The line starts with a carriage return and ends without a newline, so
    repeated calls overwrite the same terminal line.

    Parameters
    ----------
    value : number
        Current progress count.
    endvalue : number
        Count that corresponds to 100 %; must be non-zero.
    bar_length : int, optional
        Width of the bar in characters (default 40).
    """
    fraction = float(value) / endvalue
    tip = '-' * int(round(fraction * bar_length) - 1) + '>'
    bar = tip.ljust(bar_length)          # pad with spaces up to bar_length
    shown_percent = int(round(fraction * 100))
    out = sys.stdout
    out.write("\rPercent: [{0}] {1}%".format(bar, shown_percent))
    out.flush()
    
@numba.njit()
def u(xi, x):
    """Evaluate the truncated cosine expansion with coefficients ``xi`` at ``x``.

        u(x) = 2 * sum_{i, j < N_trunc} xi[i, j]
                   * cos(pi (i + 1/2) x_1) * cos(pi (j + 1/2) x_2)

    Parameters
    ----------
    xi : 2-D float array
        Expansion coefficients indexed xi[i, j] for i, j < N_trunc
        (presumably shape (N_trunc, N_trunc) -- confirm against the saved
        samples).
    x : length-2 sequence
        Point (x_1, x_2); callers pass either a tuple or a NumPy array.

    Notes
    -----
    ``N_trunc`` is a module global assigned in a later cell; numba reads its
    value when ``u`` is first compiled and does not see later changes.
    """
    # The x_2 factors do not depend on i: evaluate each cosine once rather than
    # once per (i, j) pair (2*N_trunc cos calls instead of 2*N_trunc**2).
    cos_x2 = np.empty(N_trunc)
    for j in range(N_trunc):
        cos_x2[j] = np.cos(np.pi*(j+1/2)*x[1])

    u_sum = 0.0
    for i in range(N_trunc):
        cos_x1 = np.cos(np.pi*(i+1/2)*x[0])
        for j in range(N_trunc):
            # Same multiplication order as the direct formula -> identical result.
            u_sum += xi[i, j]*cos_x1*cos_x2[j]
    return 2*u_sum

def mean_func_plot(name):
    """Plot the sample mean of exp(u) over the posterior samples in the global ``xi``.

    Evaluates exp(u(xi[it], x)) on a regular grid with step 0.01 on [0, 1)^2,
    averages over every sample stored in ``xi`` and saves a colour map to
    ``name + '.pdf'``. Also relies on the globals ``u`` and ``dim``.

    Parameters
    ----------
    name : str
        Output path without the '.pdf' extension.
    """
    grid_1 = np.arange(0, 1, 0.01)
    grid_2 = np.arange(0, 1, 0.01)
    X, Y = np.meshgrid(grid_1, grid_2)
    Z = np.zeros(X.shape)

    n_samples = len(xi)          # one term per stored sample
    point = np.zeros(dim)        # evaluation point, reused instead of reallocated per node
    for it in range(n_samples):
        for i in range(X.shape[0]):
            for j in range(X.shape[1]):
                point[0] = X[i, j]
                point[1] = Y[i, j]
                # Accumulate the mean term by term, in sample order.
                Z[i, j] += np.exp(u(xi[it], point)) / n_samples

    fig, ax = plt.subplots()
    c = ax.pcolormesh(X, Y, Z, cmap='cool', vmin=Z.min(), vmax=Z.max())

    # set the limits of the plot to the limits of the data
    ax.axis([X.min(), X.max(), Y.min(), Y.max()])
    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')
    fig.colorbar(c, ax=ax)
    fig.savefig(name + '.pdf', dpi=300, bbox_inches='tight')
    plt.close(fig)
    
def data_prediction_check(name):
    """Posterior predictive check of h at the observation nodes.

    For every sample in the global ``xi`` this builds k = exp(u) at the m x m
    grid nodes, solves h = G(k), reads h at the observation nodes and, when
    the global flag ``noise`` is true, adds sigma * N(0, 1) noise. The
    predictions are drawn as one box per observation with the observed data
    overlaid as dots, and the figure is saved to ``name + '.pdf'``.

    Reads 'GWF_data.npy' and 'GWF_obs_indices.npy' from the working directory
    and relies on the globals xi, m, noise, sigma, u, G and progressBar.

    Parameters
    ----------
    name : str
        Output path without the '.pdf' extension.
    """
    # load data and observation points
    N_data = 33
    data = np.load('GWF_data.npy')
    obs_indices = np.load('GWF_obs_indices.npy')

    # One list of predictions per observation, labelled with the (x_1, x_2)
    # coordinates of its grid node (flat index = i + m*j).
    y_rep = [[] for _ in range(N_data)]
    labels = []
    for i in range(N_data):
        idx = obs_indices[i]
        x_1 = (idx % m) / (m - 1)
        x_2 = ((idx - idx % m) / m) / (m - 1)
        # float() keeps labels like "(0.5, 0.25)"; with NumPy >= 2 the tuple
        # repr of NumPy scalars would read "(np.float64(0.5), ...)".
        labels.append(str((float(x_1), float(x_2))))

    n_samples = len(xi)
    for it in range(n_samples):
        progressBar(it, n_samples)

        k = np.zeros(m**2)
        for j in range(m):
            for i in range(m):
                k[i + m*j] = np.exp(u(xi[it], (i/(m-1), j/(m-1))))
        y_h = G(k)[obs_indices]

        # Create data predictions (one noise draw per observation if enabled).
        for i in range(N_data):
            if noise:
                y_rep[i].append(y_h[i] + sigma*np.random.normal())
            else:
                y_rep[i].append(y_h[i])

    fig, ax = plt.subplots()
    ax.boxplot(y_rep)
    ax.set_xticklabels(labels, rotation=90)

    # Overlay the true data points; boxplot places box i at x = i + 1.
    ax.plot(range(1, N_data + 1), data[:N_data], "o")

    fig.savefig(name + '.pdf', dpi=300, bbox_inches='tight')
    plt.close(fig)

# MAIN PROGRAMME

In [4]:
# Load the pCN posterior samples, run the posterior predictive checks with
# and without observation noise, and plot the posterior mean of exp(u).

method = 'pCN'
N_data = 33
N_trunc = 25

xi = [
    np.load('np_saved/GWF/samples_policy_learning/KL_'+str(N_trunc)+'_'+method+'_NData'+str(N_data)+'_sampleNo'+str(it)+'.npy')
    for it in range(8000)
]

# data_prediction_check reads the module-level flag `noise`; the loop variable
# below is that global, so it is set immediately before each call.
for noise, check_name in ((True, 'figs/GWF/GWF_KL_prediction_check'),
                          (False, 'figs/GWF/GWF_KL_prediction_check_no_noise')):
    data_prediction_check(check_name)

mean_func_plot('figs/GWF/GWF_KL_mean')

Percent: [--------------------------------------->] 100%