In [5]:
import numpy as np
from senv import MountainCar
import copy
import scipy as sc
from tqdm.notebook import tqdm
from sfqi import fqi
from joblib import Parallel, delayed
from sexp import eval 

In [6]:
def getdata(env, tt, st=1):
    """Collect `tt` random-action rollouts: `st` successes followed by `tt - st` failures.

    Rollouts are generated `env.width` at a time with uniformly random actions in
    {0, 1, 2}. Each rollout is classified by the cost returned on its final step:
    non-zero -> failure, zero -> success. Batches are drawn until both quotas are
    filled; surplus rollouts of a batch are discarded.

    Returns an array of shape (env.H, tt, state_dim + 2) where, per step and
    rollout, [..., :-2] is the state, [..., -2] the action and [..., -1] the cost.
    Columns [0, st) hold the successes and [st, tt) the failures.
    """
    horizon, width = env.H, env.width
    state_dim = env.getstate().shape[1]
    n_fail = tt - st
    successes = np.zeros((horizon, st, state_dim + 2))  # state, action, cost per column
    failures = np.zeros((horizon, n_fail, state_dim + 2))

    need_succ, need_fail = st, n_fail
    while need_succ or need_fail:
        batch = np.zeros((horizon, width, state_dim + 2))
        # Actions are drawn before the reset, matching the RNG consumption order.
        batch[:, :, -2] = np.random.randint(3, size=(horizon, width))
        state = env.reset()
        for step in range(horizon):
            batch[step, :, :-2] = state
            cost, state = env.step(batch[step, :, -2])
            batch[step, :, -1] = cost

        # `cost` now holds the final-step cost of every parallel rollout.
        fail_cols = np.flatnonzero(cost)[:need_fail]
        start = n_fail - need_fail
        failures[:, start:start + fail_cols.size] = batch[:, fail_cols]
        need_fail -= fail_cols.size

        succ_cols = np.flatnonzero(1 - cost)[:need_succ]
        start = st - need_succ
        successes[:, start:start + succ_cols.size] = batch[:, succ_cols]
        need_succ -= succ_cols.size

    return np.append(successes, failures, axis=1)

In [7]:
def eval(env, theta):
    """Roll out the greedy policy of `theta` on a single rollout and return its total cost.

    NOTE(review): this definition shadows both the builtin `eval` and the
    `sexp.eval` imported in the first cell; callers below get this version.

    Args:
        env: environment exposing `width`, `H`, `reset(n)` and `step(action)`.
            It is temporarily reset to width 1 and restored to its original
            width afterwards, even if a step raises.
        theta: indexable by step; `theta[h] @ s.T` yields one score per action
            and the action with the smallest score is taken.

    Returns:
        The summed per-step cost, squeezed (a 0-d array for width-1 costs).
    """
    width = env.width
    s = env.reset(1)
    cost = 0
    try:
        for h in range(env.H):
            c, s = env.step(np.argmin(theta[h] @ s.T))
            cost += c
    finally:
        # Restore the caller's rollout width even when a step fails part-way.
        env.reset(width)
    # np.squeeze also handles the H == 0 case where `cost` is still a plain int.
    return np.squeeze(cost)

In [8]:
class FQI(object):
    """Finite-horizon fitted Q-iteration with a sigmoid-linear model per (step, action).

    For every step h and action a in {0, 1, 2}, the cost-to-go is modelled as
    sigmoid(phi @ theta[h, a]). Steps are fitted backwards from h = H-1 to 0
    against the target c_h + min_a' sigmoid(phi_{h+1} @ theta[h+1, a']), or just
    c_h at the last step.

    Args:
        data: array indexed as data[h, i, :] for step h and sample i. Its layout is:
            - [..., :-2]: the d features;
            - [..., -2]: the action (0, 1 or 2);
            - [..., -1]: the cost.
            This is the layout produced by getdata. Sample i at step h+1 is
            treated as the successor of sample i at step h.
        d: feature dimension (must equal data.shape[-1] - 2).
        H: horizon, i.e. the number of steps fitted.

    theta_ls / theta_log have shape (H+1, 3, d). Slot H is never written and
    stays all-zero, so theta[h+1] is always a valid index.
    """

    def __init__(self, data, d, H):
        self.data = data
        self.d = d
        self.H = H

        self.theta_ls = np.zeros((self.H+1, 3, self.d))
        self.theta_log = np.zeros((self.H+1, 3, self.d))

    def sigmoid(self, x):
        """Logistic function with its argument clipped to [-36, 36].

        The clip prevents np.exp overflow and keeps the output strictly inside
        (0, 1), so log_loss never evaluates log(0). A new array is returned;
        `x` itself is never modified.
        """
        return 1 / (1 + np.exp(-np.clip(x, -36, 36)))

    def get_targets(self, theta, h):
        """Regression targets for step h, as a dict {action: 1-D target array}.

        For the samples that took action a at step h (in sample order), the
        target is c_h at the last step (h == H-1). Otherwise it is
        c_h + min over a' of sigmoid(phi_{h+1} @ theta[h+1, a']).
        """
        # Read-only access: no need to copy the whole (H, N, d+2) dataset each step.
        actions, costs = self.data[h, :, -2], self.data[h, :, -1]

        if h != self.H - 1:
            phi_next = self.data[h+1, :, :-2]
            # Greedy next-step value: the smallest predicted cost over the 3 actions.
            v = np.minimum(
                np.minimum(self.sigmoid(phi_next @ theta[h+1, 0]),
                           self.sigmoid(phi_next @ theta[h+1, 1])),
                self.sigmoid(phi_next @ theta[h+1, 2]),
            )
            costs = costs + v

        return {a: costs[actions == a] for a in range(3)}

    def ls_loss(self, theta, X, Y):
        """Sum of squared errors between sigmoid(X @ theta) and Y."""
        return np.sum((self.sigmoid(X @ theta) - Y) ** 2)

    def ls_grad(self, theta, X, Y):
        """Gradient of ls_loss with respect to theta."""
        p = self.sigmoid(X @ theta)
        der = p * (1 - p)
        scalar = 2 * (p - Y) * der
        return scalar.T @ X

    def log_loss(self, theta, X, Y):
        """Binary cross-entropy of sigmoid(X @ theta) against (soft) labels Y."""
        p = self.sigmoid(X @ theta)
        return -1.0 * np.sum(Y * np.log(p) + (1 - Y) * np.log(1 - p))

    def log_grad(self, theta, X, Y):
        """Gradient of log_loss with respect to theta (ignoring the sigmoid clip)."""
        p = self.sigmoid(X @ theta)
        scalar = (p - Y)
        return scalar.T @ X

    def _minimize(self, loss, grad, features, tar):
        """Run BFGS from theta = 0, store the scipy result in self.sol, and return its x."""
        self.sol = sc.optimize.minimize(
            loss,
            x0=np.zeros(self.d),
            args=(features, tar),
            jac=grad,
            method='bfgs',
            options={'gtol': 1e-5},
        )
        return self.sol.x

    def ls_solve(self, features, tar, theta0):
        """Least-squares fit of sigmoid(features @ theta) to tar.

        `theta0` is accepted for interface compatibility but ignored: the
        optimisation always starts from zeros.
        """
        # NOTE(review): callers pass theta[h+1, a], which suggests a warm start was
        # intended; using it as x0 would change results, so behaviour is kept.
        return self._minimize(self.ls_loss, self.ls_grad, features, tar)

    def log_solve(self, features, tar, theta0):
        """Log-loss fit of sigmoid(features @ theta) to tar (`theta0` ignored, see ls_solve)."""
        return self._minimize(self.log_loss, self.log_grad, features, tar)

    def _features(self, h, a):
        """Feature rows, shape (k, d), of the samples that took action a at step h.

        A boolean mask keeps the result 2-D for every k, including k == 1 and
        d == 1, where np.squeeze would collapse it. Row order matches the
        order of get_targets(...)[a].
        """
        return self.data[h, self.data[h, :, -2] == a, :-2]

    def run_log(self):
        """Fit theta_log backwards from h = H-1 to 0 with the log loss and return it.

        The targets of the most recently fitted step are left in self.tar.
        """
        for h in tqdm(range(self.H-1, -1, -1)):
            self.tar = self.get_targets(self.theta_log, h)
            for a in range(3):
                self.theta_log[h, a] = self.log_solve(
                    self._features(h, a), self.tar[a], self.theta_log[h+1, a])
        return self.theta_log

    def run_ls(self):
        """Fit theta_ls backwards from h = H-1 to 0 with the squared loss and return it."""
        for h in tqdm(range(self.H-1, -1, -1)):
            tar = self.get_targets(self.theta_ls, h)
            for a in range(3):
                self.theta_ls[h, a] = self.ls_solve(
                    self._features(h, a), tar[a], self.theta_ls[h+1, a])
        return self.theta_ls






In [9]:
def run_exp(H, trials, seed=None):
    """Run one experiment: collect `trials` rollouts and fit all four estimators.

    Args:
        H: horizon, passed to MountainCar and FQI.
        trials: number of rollouts collected by getdata.
        seed: optional. If given, NumPy's global RNG (which getdata uses for
            the random actions) is seeded first, making the run reproducible.
            The default None leaves the RNG untouched, as before.

    Returns:
        A list of four entries, in this order:
        - this notebook's FQI log-loss solution;
        - sfqi.fqi 'log';
        - this notebook's FQI squared-loss solution;
        - sfqi.fqi 'sq'.
        The notebook's FQI solutions have the unused terminal slot dropped.
    """
    if seed is not None:
        np.random.seed(seed)
    env = MountainCar(H, 2, trials)
    data = getdata(env, trials)
    # Feature dimension = columns stored by getdata minus the (action, cost) pair,
    # rather than a hard-coded 9 that had to match the env's state size by hand.
    agent = FQI(data, data.shape[-1] - 2, H)

    # Per-step (features, actions, costs) triples, as passed to sfqi.fqi.
    sac = list(zip(data[:, :, :-2], data[:, :, -2], data[:, :, -1]))
    thetalog = fqi(sac, env.A, 'log')
    thetasq = fqi(sac, env.A, 'sq')

    theta_sq = agent.run_ls()
    theta_log = agent.run_log()

    # Slot H of the FQI parameter arrays is never fitted (always zero); drop it.
    return [theta_log[:-1], thetalog, theta_sq[:-1], thetasq]

In [10]:
# Experiment configuration.
num_runs = 20          # independent repetitions of run_exp
H = 800                # horizon of every rollout
trajectories = 15000   # rollouts collected per repetition

# Fan the repetitions out over at most 45 joblib workers; each returns a list of four parameter arrays.
thetas = Parallel(n_jobs=min(num_runs, 45))(
    delayed(run_exp)(H, trajectories) for _run in tqdm(range(num_runs))
)

  0%|          | 0/20 [00:00<?, ?it/s]

In [11]:
# Greedy-policy cost of every fitted theta, one list per estimator, following
# run_exp's return order:
#   a_log / a_sq -> this notebook's FQI (log / squared loss)
#   s_log / s_sq -> sfqi.fqi (log / squared loss)
a_log, s_log, a_sq, s_sq = [], [], [], []

env = MountainCar(H, 2)
env.reset(0)

# The env is shared across calls, so keep the run-major, estimator-minor order.
for i in range(num_runs):
    run_thetas = thetas[i]
    for slot, bucket in enumerate((a_log, s_log, a_sq, s_sq)):
        bucket.append(eval(env, run_thetas[slot]))


In [12]:
# Total greedy-policy cost over all runs: this notebook's FQI, log loss.
sum(a_log, start=0)

7.0

In [13]:
# Total greedy-policy cost over all runs: sfqi.fqi, log loss.
sum(s_log, start=0)

8.0

In [14]:
# Total greedy-policy cost over all runs: this notebook's FQI, squared loss.
sum(a_sq, start=0)

17.0

In [15]:
# Total greedy-policy cost over all runs: sfqi.fqi, squared loss.
sum(s_sq, start=0)

10.0

  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
  0%|          | 0/800 [00:00<?, ?it/s]
