In [None]:
from collections import namedtuple

import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.signal import convolve

In [None]:
# Plotting grids for position and velocity.
N = 101
x = np.linspace(-2, 2, N)
v = np.linspace(-3, 3, N)

# Scale applied to the action and gravity forces in the dynamics.
amplification = 0.23


def f_action(a):
    """Force produced by action `a`: tanh saturates it to (-1, 1)."""
    return np.tanh(a)


# Sanity check of the action -> force mapping. A dedicated loop variable keeps
# the position grid `x` from being overwritten by the last action value.
for a_demo in [-2, -1, 0, 1, 2]:
    print(f'a: {a_demo}, force: {f_action(a_demo)}')

def f_gravity(x):
    """Gravity force at position(s) `x`; accepts a scalar, list or ndarray.

    Returns ``-phi``, where ``phi`` is piecewise:

    * x <= 0:  phi = 2x + 1
    * x >  0:  phi = (1 + 5x^2)^(-1/2) - 5x^2 (1 + 5x^2)^(-3/2) + (x/2)^4

    Both pieces equal 1 at x = 0, so phi is continuous there.
    """
    x = np.asarray(x, dtype=float)
    phi_1 = 2 * x + 1
    a = 1 + 5 * x**2
    phi_2 = a**(-0.5)
    phi_3 = -5 * x**2 * a**(-1.5)
    phi_4 = (x / 2)**4  # negative of paper
    # np.where works for 0-d input too, unlike boolean-mask assignment.
    phi = np.where(x <= 0, phi_1, phi_2 + phi_3 + phi_4)
    return -phi

# Height profile of the track: integrate the amplified gravity force over x.
N = 101
x = np.linspace(-2, 2, N)
# Spacing of linspace(-2, 2, N) is 4/(N-1); 4/N would mis-scale the integral.
w = x[1] - x[0]
g = f_gravity(x) * amplification
h = np.cumsum(-g) * w
h -= h.min()
plt.plot(x, h, label='height')

def f_friction(v):
    """Linear friction: a force opposing velocity `v` with a quarter of its magnitude."""
    return v * (-1/4)

#plt.plot(v, f_friction(v), label='friction')
def f(t, s, a):
    """Right-hand side of the ODE for state s = (position, velocity) under action a.

    d position / dt = velocity
    d velocity / dt = amplification * (f_action(a) + f_gravity(position)) + f_friction(velocity)

    `t` is unused (time-invariant dynamics) but is part of the solve_ivp signature.
    """
    position, velocity = s[0], s[1]
    acceleration = amplification * (f_action(a) + f_gravity(position)) + f_friction(velocity)
    return np.array([velocity, acceleration])


# Legend for the labelled curve(s) on the current axes.
plt.gca().legend()


In [None]:
def s1_given_s_a(s, a, debug=False, t_final=2.0):
    """Integrate the dynamics `f` from state `s` under a constant action `a`.

    Parameters
    ----------
    s : array-like of length 2
        Initial state (position, velocity).
    a : float
        Action value, held constant over [0, t_final].
    debug : bool
        If True, return the full ``solve_ivp`` result, even when the integration
        failed, so it can be inspected.
    t_final : float
        End of the integration interval; defaults to 2.0.

    Returns
    -------
    ndarray of shape (2,)
        State at time ``t_final`` (when ``debug`` is False).

    Raises
    ------
    RuntimeError
        If the integrator reports failure. Without this check, the last stored
        state of a truncated trajectory would be returned silently.
    """
    sol = solve_ivp(f, t_span=[0.0, t_final], y0=s, vectorized=True, args=(a,))
    if debug:
        return sol
    if not sol.success:
        raise RuntimeError(f'solve_ivp failed from s={s}, a={a}: {sol.message}')
    return sol.y[:, -1]

s = [-1.0, 0.]
sol = s1_given_s_a(s=s, a=-2, debug=True)

# One state transition (start state -> state at the end of the integration
# interval), drawn in the position/velocity plane.
fig, ax = plt.subplots(1, 1, figsize=(16, 16))
ax.plot([s[0], sol.y[0, -1]], [s[1], sol.y[1, -1]], 'x-')
ax.set_xlim([-2, 2])
ax.set_ylim([-3, 3])
ax.set_title('state transition')

In [None]:
# Discretised state space: n_x position bins times n_v velocity bins.
n_x, n_v = 32, 32
n_s = n_x * n_v
bounds_x = np.array([-2, 2])
bounds_v = np.array([-3, 3])
# Width of one bin along each axis.
cell_x, cell_v = (bounds_x[1] - bounds_x[0]) / n_x, (bounds_v[1] - bounds_v[0]) / n_v
def index(x, bounds, nbins):
    """Index of the bin containing coordinate `x`.

    The interval ``[bounds[0], bounds[1]]`` is split into `nbins` equal-width
    bins. Values outside the interval saturate to the first or last bin.

    Returns
    -------
    int in ``[0, nbins - 1]``
    """
    lo, hi = bounds[0], bounds[1]
    x = np.clip(x, lo, hi)
    idx = int((x - lo) / (hi - lo) * nbins)
    # x == hi (or rounding just below it) yields nbins. Clamp explicitly rather
    # than shrinking the range by a fixed epsilon, which misbins fine grids.
    return min(idx, nbins - 1)

def value(idx, bounds, nbins):
    """Centre coordinate of bin `idx` when [bounds[0], bounds[1]] is split into `nbins` bins."""
    span = bounds[1] - bounds[0]
    half_cell = span / nbins / 2
    return float(idx) / nbins * span + bounds[0] + half_cell

def index_x(x):
    """Position-bin index of coordinate `x` on the (bounds_x, n_x) grid."""
    return index(x, bounds=bounds_x, nbins=n_x)

def value_x(x):
    """Centre position of position-bin `x` on the (bounds_x, n_x) grid."""
    return value(x, bounds=bounds_x, nbins=n_x)

def index_v(v):
    """Velocity-bin index of velocity `v` on the (bounds_v, n_v) grid."""
    return index(v, bounds=bounds_v, nbins=n_v)

def value_v(v):
    """Centre velocity of velocity-bin `v` on the (bounds_v, n_v) grid."""
    return value(v, bounds=bounds_v, nbins=n_v)

def idx_s_from_idx_xv(x, v):
    """Flatten (position bin `x`, velocity bin `v`) into one state index; `x` varies fastest."""
    return x + n_x * v

def idx_xv_from_idx_s(s):
    """Split flat state index `s` into ``[x_bin, v_bin]`` (inverse of idx_s_from_idx_xv)."""
    v_bin, x_bin = divmod(s, n_x)
    return [int(x_bin), int(v_bin)]

def s_from_index_s(i_s):
    """Continuous state ``[x, v]`` at the centre of the grid cell with flat index `i_s`."""
    ix, iv = idx_xv_from_idx_s(i_s)
    x_centre = value_x(ix)
    v_centre = value_v(iv)
    return [x_centre, v_centre]

def index_s_from_s(s):
    """Flat index of the grid cell containing continuous state ``s = (x, v)``."""
    x_bin = index_x(s[0])
    v_bin = index_v(s[1])
    return idx_s_from_idx_xv(x_bin, v_bin)

print(n_x, bounds_x, cell_x)
# Bin-centre coordinates along each axis, and the discrete action values.
xx = np.linspace(start=bounds_x[0] + cell_x / 2, stop=bounds_x[1] - cell_x / 2, num=n_x)
vv = np.linspace(start=bounds_v[0] + cell_v / 2, stop=bounds_v[1] - cell_v / 2, num=n_v)
aa = np.arange(-2, 3, dtype=int)

In [None]:
def compute_transition():
    """Build the deterministic discretised forward model p(s1 | s0, a).

    For every action in `aa` and every grid-cell centre (x, v) in `xx` x `vv`,
    the dynamics are integrated with `s1_given_s_a`. The end state is then
    snapped to its grid cell with `index_s_from_s`.

    Returns
    -------
    ndarray, shape (len(aa), n_s, n_s)
        ``p[i, i_s0, i_s1] == 1`` when action ``aa[i]`` moves cell ``i_s0`` to
        cell ``i_s1``; all other entries are 0.
    """
    p_s1_given_s_a = np.zeros(shape=(aa.shape[0], n_s, n_s))
    for i, a in enumerate(aa):
        print('action', a)
        for j, x in enumerate(xx):
            for k, v in enumerate(vv):
                # Use the shared flattening helper so the state layout cannot
                # drift from index_s_from_s / idx_xv_from_idx_s.
                i_s0 = idx_s_from_idx_xv(j, k)
                s1 = s1_given_s_a(s=np.array([x, v]), a=a)
                p_s1_given_s_a[i, i_s0, index_s_from_s(s1)] = 1
    return p_s1_given_s_a
            
# Forward model for every action in aa. This runs one ODE integration per
# (action, grid cell) pair, so it can take a while.
p_forward = compute_transition()

In [None]:
# Inspect one forward transition: snap a continuous start state to the grid
# and look up its successor cell in p_forward.
s0 = [-0.6, .8]
a = 3  # index into aa (aa[3] == 1), not the action value itself
i_s0 = index_s_from_s(s0)
i_s0_xv = idx_xv_from_idx_s(i_s0)
s0_discrete = s_from_index_s(i_s0)
print('s0', s0, i_s0_xv, i_s0, s0_discrete)

i_s1 = np.argmax(p_forward[a, i_s0])
i_s1_xv = idx_xv_from_idx_s(i_s1)
s1 = s_from_index_s(i_s1)
print('s1', s1, i_s1_xv, i_s1)

fig, ax = plt.subplots(1, 2, figsize=(2*8, 6))
plt.sca(ax[0])
plt.plot([s0[0], s1[0]], [s0[1], s1[1]])
plt.scatter(s1[0], s1[1], marker='o')
# Match the plot window to the discretised state space. Velocities span
# bounds_v = [-3, 3]; a [-2, 2] window would hide part of the range.
plt.xlim(bounds_x)
plt.ylim(bounds_v)

plt.sca(ax[1])

def show_belief(b, ax=None):
    """Show a flat distribution `b` over grid cells as a grey-scale image in the (x, v) plane.

    `b` is reshaped to (n_v, -1), with one row per velocity bin. Rows are
    reversed so the highest-velocity row is drawn at the top (imshow's default
    upper origin). If `ax` is given, it is made the current axes first.
    """
    if ax is not None:
        plt.sca(ax)
    grid = b.reshape(n_v, -1)[::-1]
    extent = bounds_x.tolist() + bounds_v.tolist()
    plt.imshow(grid, cmap='gray', extent=extent)
    
show_belief(p_forward[a, i_s0])
# Red cross-hair through the successor cell centre s1.
plt.plot(bounds_x, [s1[1], s1[1]], 'r-')
plt.plot([s1[0], s1[0]], bounds_v, 'r-')

In [None]:
def get_p_pullback(p_forward):
    """Turn a forward model into a backward ("pullback") model.

    For each action i, ``p_forward[i]`` is transposed so that rows index the
    successor cell s1. Each row is then normalised:

        p_pullback[i, s1, s0] = p_forward[i, s0, s1] / sum_s0 p_forward[i, s0, s1]

    Rows without any predecessor mass (all zeros) stay all zeros.

    Parameters
    ----------
    p_forward : array-like, shape (n_actions, n_states, n_states)

    Returns
    -------
    ndarray of float, same shape as `p_forward` (a new array).
    """
    # Cast to float up front so integer-valued inputs are not truncated to 0
    # when normalised.
    p_t = np.asarray(p_forward, dtype=float).transpose(0, 2, 1)
    # Floor row sums so empty rows divide by a tiny number instead of zero.
    row_mass = np.maximum(p_t.sum(axis=2, keepdims=True), 1e-6)
    return p_t / row_mass

# Backward model: for each action, rows index the successor cell s1 and hold
# the predecessor cells s0, normalised to sum to 1 wherever a row has mass.
p_pullback = get_p_pullback(p_forward)

In [None]:
def normalize(p):
    """Scale each row of the 2-D array `p` to sum to 1.

    Row sums are floored at 1e-6, so an all-zero row comes back as all zeros.
    """
    row_mass = np.maximum(p.sum(axis=1).reshape(-1, 1), 1e-6)
    return p / row_mass

def _fold_border_inward(p):
    """Fold the outermost rows and columns of `p` onto their inner neighbours, in place.

    Rows are folded first, then columns, so corner mass ends up in the inner
    corner cell. Afterwards the outer ring of `p` is zero.
    """
    p[1, :] += p[0, :]
    p[0, :] = 0
    p[-2, :] += p[-1, :]
    p[-1, :] = 0
    p[:, 1] += p[:, 0]
    p[:, 0] = 0
    p[:, -2] += p[:, -1]
    p[:, -1] = 0


def smooth_transitions(p_pullback, show_image=True):
    """Blur each transition row of `p_pullback` over neighbouring grid cells.

    For every action i and successor cell i_s1, the row ``p_pullback[i, i_s1]``
    is viewed as an (n_v, -1) image and convolved with the kernel
    [0.25, 0.5, 0.25] along both axes. Mass that spills past the grid edge is
    folded back onto the border cells, and each action's matrix is finally
    row-normalised. A row with (almost) no mass first gets a self-transition
    i_s1 -> i_s1, so every row ends up a distribution.

    Parameters
    ----------
    p_pullback : ndarray, shape (n_actions, n_s, n_s)
        Not modified; a smoothed copy is returned.
    show_image : bool
        If True, display one randomly chosen smoothed row with `show_belief`
        (debugging aid).

    Returns
    -------
    ndarray of float with the same shape as `p_pullback`.
    """
    kernel = np.array([.25, 0.5, 0.25]).reshape((-1, 1))
    # Work on a copy so re-running the calling cell does not smooth twice.
    p_smooth = np.array(p_pullback, dtype=float)
    n_actions, n_states = p_smooth.shape[0], p_smooth.shape[1]

    for i in range(n_actions):
        for i_s1 in range(n_states):
            if p_smooth[i, i_s1].sum() < 1e-6:
                p_smooth[i, i_s1, i_s1] = 1

            # The default 'full' convolution grows the grid by one cell per side;
            # the padding is folded back and cropped below.
            p = convolve(p_smooth[i, i_s1].reshape(n_v, -1), kernel)
            p = convolve(p, kernel.T)
            # Saturate transitions at the boundaries instead of losing mass.
            _fold_border_inward(p)
            transition = p[1:-1, 1:-1]

            if show_image and np.random.uniform() < 0.1:
                show_belief(transition.reshape(-1))
                show_image = False

            p_smooth[i, i_s1] = transition.reshape(-1)

        p_smooth[i] = normalize(p_smooth[i])

    return p_smooth
                
# Blur each row of the pullback model over neighbouring grid cells and
# re-normalise; the returned array replaces p_pullback.
p_pullback = smooth_transitions(p_pullback)

In [None]:
# Propagate a point-mass belief at s1 backwards through the pullback model to
# see which earlier cells lead to it.
s1 = np.array([1, -1])
i_s1 = index_s_from_s(s1)
b = np.zeros(n_s)
b[i_s1] = 1.

action_idx = 2  # index into aa; aa[2] == 0
n_steps = 40
mass = [b.sum()]
for _ in range(n_steps):
    # b_new[s0] = sum_s1 b[s1] * p_pullback[action_idx, s1, s0]
    b = np.dot(b, p_pullback[action_idx])
    mass.append(b.sum())

show_belief(b)
# If every row of p_pullback[action_idx] sums to 1, the total mass stays at 1.
# Report the range once instead of printing every step.
print('belief mass: min', min(mass), 'max', max(mass))

In [None]:
# Most likely predecessor cell after the backward propagation, its mass, and the total mass.
idx = b.argmax()
print(idx, b[idx], b.sum())

In [None]:
# Column i_s1 of the forward model for action aa[0]: total predecessor mass and the first predecessor cell.
print('forward', p_forward[0][:, i_s1].sum(), 'state', p_forward[0][:, i_s1].argmax())

In [None]:
# Row i_s1 of the pullback model for action aa[0]: total mass and the most likely predecessor cell.
print('backward', p_pullback[0][i_s1].sum(), 'state', p_pullback[0][i_s1].argmax())

In [None]:
# 626 is a hard-coded flat state index (x-bin 18, v-bin 19 when n_x == 32).
# NOTE(review): presumably a predecessor of i_s1 found above; confirm or derive it.
print(p_forward[0][626][i_s1])
print(p_pullback[0][i_s1][626])