In [1]:
import jax.numpy as jnp
import jax

In [2]:
def move(s,n):
    branches = jnp.array([[1,0],[-1,0],[0,1],[0,-1],[0,0]])
    s2 = s + branches[n]
    s2 = jnp.clip(s2,0,9)
    return s2

def policy(q,s,epsilon):
    policy = jnp.ones((5))*epsilon/5
    best = q[s[0],s[1]] == jnp.max(q[s[0],s[1]])
    policy += best*(1-epsilon)/best.sum()
    return policy

def max_action(q,s,epsilon,key):
    key, subkey = jax.random.split(key)
    a = jax.random.choice(subkey,jnp.arange(5),p=policy(q,s,epsilon))
    return move(s,a), a, key

def update_Q(q, trace, s0, a0, q1, qo, reward, params):
    trace = trace*params["lambda"]*params["gamma"]
    trace = trace.at[s0[0],s0[1],a0].set(1.0 - params["alpha"]*params["lambda"]*params["gamma"]*trace[s0[0],s0[1],a0])
    q += params["alpha"]*trace*(params["gamma"]*q1 + reward - qo)
    q = q.at[s0[0],s0[1],a0].add(-params["alpha"]*(q[s0[0],s0[1],a0]-qo))
    return q, trace

def step(args):
    q,trace,s,qo,key,i,params = args
    s2, a, key = max_action(q,s,params["epsilon"],key)
    expected = jnp.dot(q[s2[0],s2[1]],policy(q,s2,params["epsilon"]))
    reward = jax.lax.cond(jnp.all(s2 == jnp.array([9,9])),lambda _:1.0,lambda _:-0.01, None)
    q, trace = update_Q(q,trace,s,a,expected,qo, reward, params)
    s = s2
    qo = expected
    i += 1
    return q,trace,s,qo,key,i,params

def episode(episode_n, args):
    q, mean, key, params = args
    #params["epsilon"] = 1/episode_n
    trace = jnp.zeros_like(q)
    s = jnp.array([0,0])
    qo = 0
    i = 0

    def cond(a):
        q,trace,s,qo,key,i,params = a
        return jnp.any(s != jnp.array([9,9]))

    q,trace,s,qo,key,i,params = jax.lax.while_loop(cond,step,(q,trace,s,qo,key,i,params))
    mean += i
    return q, mean, key, params

@jax.jit
def train(seed):
    steps = 50
    params = {
        "alpha": 0.1,
        "lambda": 0.95,
        "gamma": 1.0,
        "epsilon" : 0.1
    }
    q = jnp.zeros([10,10,5])
    mean = 0
    key = jax.random.PRNGKey(seed)
    q, mean, key, params = jax.lax.fori_loop(1, steps+1, episode, (q, mean, key, params))
    return mean/steps

@jax.jit
def multi_train(seeds):
    return jnp.mean(jax.vmap(train)(seeds))

@jax.jit
def parameter_search(seeds,args):
    return jax.vmap(multi_train)(seeds,args)

In [3]:
mean = train(12)
print(mean)

58.899998


In [None]:
n_runs = 100
seeds = jnp.arange(n_runs)
param = jnp.array([[0.05],[0.1],[0.3]])
seeds = jnp.tile(seeds, (len(param),1))
param = jnp.repeat(param, n_runs, axis=1)
res = parameter_search(seeds,None)
print(res)

In [73]:
n_runs = 100
seeds = jnp.arange(n_runs)
res = multi_train(seeds)
res