In [31]:
from environments import DeepSea
import numpy as np
from tqdm.notebook import tqdm
import random
random.seed(42)

In [25]:
H = 3
env = DeepSea(H)
cost = 0.0
tuples = []
for k in tqdm(range(int(1000*2**H))):
    s,done = env.reset()
    while not done:
        a = np.random.binomial(1,0.1)
        c,s_,done = env.step(s,a)
        cost += c
        tuples.append([s,a,c,s_])
        s = s_

  0%|          | 0/8000 [00:00<?, ?it/s]

In [26]:
class Fitted_Q_Iteration(object):
    def __init__(self,data,loss,iters,nState,nAction,gamma,nEnsemble):
        self.data = data.copy()
        self.loss = loss
        self.iters = iters
        self.nState = nState
        self.nAction = nAction
        self.gamma = gamma
        self.nEnsemble = nEnsemble
        self.d = nState*nAction
        self.n = len(data)
        self.phi = np.identity(nState*nAction)
        self.f = np.zeros((nState,nAction))
        self.f_ = np.zeros((nState,nAction)) 
        self.Vf = np.zeros(nState)
        self.get_feature_idx()
        self.Q_ensemble = np.random.uniform(size=(self.nEnsemble,nState,nAction))
        
    
    def get_feature_idx(self):
        self.tuple_to_idx = {}
        self.idx_to_tuple = {}
        k = 0
        for s in range(self.nState):
            for a in range(self.nAction):
                self.tuple_to_idx[s,a] = k
                self.tuple_to_idx[k] = [s,a]
                k += 1
        
    
    def get_targets(self):
        self.tar = np.zeros( self.n )
        data = self.data
        for i in range(self.n):
            c, s_ = data[i][2], data[i][3]
            if s_ != None:
                self.tar[i] = max(0,min(c + self.gamma * min(self.f_[s_,:]),1))
            else:
                self.tar[i] = c
                
    def find_Q_sq(self):
        data = self.data.copy()
        self.q_score_sq = np.zeros(self.nEnsemble)
        #for i in range(self.nEnsemble):
        for j in range(self.n):
            s,a = data[j][0],data[j][1]
            self.q_score_sq[:] += (self.Q_ensemble[:,s,a] - self.tar[j]) ** 2
                
                    
                
        best_ensemble = np.random.choice(np.flatnonzero(self.q_score_sq == self.q_score_sq.min()))
        return best_ensemble
    
    
    def find_Q_log(self):
        data = self.data.copy()
        self.q_score = np.zeros(self.nEnsemble)
        #for i in range(self.nEnsemble):
        for j in range(self.n):
            s,a = data[j][0],data[j][1]
            self.q_score[:] += -1.0 * (self.tar[j] * np.log(self.Q_ensemble[:,s,a]) + (1 - self.tar[j]) * np.log(1 - self.Q_ensemble[:,s,a]))
                
        best_ensemble = np.random.choice(np.flatnonzero(self.q_score == self.q_score.min()))
        return best_ensemble
                    
                
        best_ensemble = np.random.choice(np.flatnonzero(self.q_score == self.q_score.min()))
        return best_ensemble
            
    
    def update_Q_log(self):
        self.get_targets()
        best_q = self.find_Q_log()
        self.f_ = self.Q_ensemble[best_q]
    
    def update_Q_sq(self):
        self.get_targets()
        best_q = self.find_Q_sq()
        self.f_ = self.Q_ensemble[best_q]
    
    def update_Q(self):
        if self.loss == 'log':
            self.update_Q_log()
        else:
            self.update_Q_sq()
    
    def run(self):
        for k in tqdm(range(self.iters)):
            self.update_Q()
        

In [36]:
nState = len(env.state_to_tuple)
agent = Fitted_Q_Iteration(tuples[:10000],'log',200,nState,2,0.99,1000)
agent.run()
print(agent.f_)

  0%|          | 0/200 [00:00<?, ?it/s]

[[0.60536612 0.76488246]
 [0.73284368 0.52049331]
 [0.87079687 0.14756819]
 [0.86698985 0.66309174]
 [0.80098791 0.81934321]
 [0.03993512 0.07689024]]


In [38]:
nState = len(env.state_to_tuple)
agent = Fitted_Q_Iteration(tuples[:10000],'square',200,nState,2,0.99,1000)
agent.run()
print(agent.f_)

  0%|          | 0/200 [00:00<?, ?it/s]

[[0.11967657 0.46083258]
 [0.65699108 0.67552491]
 [0.50364398 0.68490746]
 [0.94614328 0.87869146]
 [0.70570375 0.62351095]
 [0.50199527 0.75937238]]


In [40]:
agent.q_score_sq

array([1246.95340075, 2883.05786822, 2529.09837598, 2583.31756627,
       2809.06241014, 1765.02232053, 4620.90427345, 1074.00915762,
       3561.73180837, 1633.97145615, 1398.58407025, 2336.83867674,
       1176.69098174, 2160.9717937 , 2412.29345514, 1117.55930794,
       2380.24711498, 1550.67851995, 1727.76926431, 2301.58885174,
       3100.1357839 , 2485.90197355,  747.46605157, 1686.92063619,
       1051.11150684, 1026.47638471, 2234.50487873, 1672.9144656 ,
       1636.24025294, 2786.44706974,  650.18716143, 1985.8229527 ,
       2881.72725023, 2646.72605531, 2275.22604785,  402.87537281,
       2312.25929348, 1405.47289461, 2493.3285485 , 2300.61294596,
       3267.50620147, 1800.92439078, 2450.04923593, 1741.12154413,
       1735.74034682, 3826.10120504, 1466.54536269, 2713.13766656,
       3782.70849008, 2882.78568755, 1425.44357531, 1667.63172245,
       1247.02662142, 2211.30694431, 2896.89103366, 2465.97988452,
       1746.44251829, 1908.67214074, 1149.59955739, 1227.17384

[[0, 0, 0, 1],
 [1, 0, 0, 3],
 [3, 1, 0, 7],
 [7, 0, 0, 11],
 [11, 0, 1, None],
 [0, 1, 0, 2],
 [2, 0, 0, 4],
 [4, 0, 0, 7],
 [7, 0, 0, 11],
 [11, 0, 1, None],
 [0, 0, 0, 1],
 [1, 0, 0, 3],
 [3, 0, 0, 6],
 [6, 1, 0, 11],
 [11, 0, 1, None],
 [0, 0, 0, 1],
 [1, 0, 0, 3],
 [3, 0, 0, 7],
 [7, 0, 0, 11],
 [11, 0, 1, None],
 [0, 0, 0, 1],
 [1, 0, 0, 3],
 [3, 0, 0, 6],
 [6, 0, 0, 10],
 [10, 0, 1, None],
 [0, 0, 0, 1],
 [1, 0, 0, 3],
 [3, 0, 0, 6],
 [6, 0, 0, 11],
 [11, 0, 1, None],
 [0, 0, 0, 1],
 [1, 1, 0, 4],
 [4, 0, 0, 7],
 [7, 0, 0, 11],
 [11, 0, 1, None],
 [0, 0, 0, 1],
 [1, 0, 0, 4],
 [4, 0, 0, 7],
 [7, 0, 0, 11],
 [11, 0, 1, None],
 [0, 0, 0, 1],
 [1, 0, 0, 3],
 [3, 1, 0, 7],
 [7, 0, 0, 11],
 [11, 0, 1, None],
 [0, 1, 0, 2],
 [2, 0, 0, 5],
 [5, 0, 0, 8],
 [8, 0, 0, 12],
 [12, 0, 1, None],
 [0, 0, 0, 1],
 [1, 0, 0, 3],
 [3, 0, 0, 6],
 [6, 0, 0, 10],
 [10, 0, 1, None],
 [0, 0, 0, 1],
 [1, 0, 0, 3],
 [3, 0, 0, 6],
 [6, 0, 0, 10],
 [10, 1, 1, None],
 [0, 0, 0, 2],
 [2, 0, 0, 4],
 [4, 0, 0,