In [260]:
import numpy as np
import random
from tqdm.notebook import tqdm
import copy

In [261]:
class Environment(object):
    '''
    Minimal base class for the RL environments in this notebook.

    Concrete environments override ``reset`` and ``advance``; the versions
    here are inert placeholders.
    '''

    def __init__(self):
        # The base class holds no state.
        return None

    def reset(self):
        '''Return the environment to its start configuration (no-op here).'''
        return None

    def advance(self, action):
        '''
        Take a single step in the environment.
        Args:
            action - the action to apply (ignored by this base class)
        Returns:
            reward - double - reward
            newState - int - new state
            pContinue - 0/1 - flag for end of the episode
        The base implementation always yields (0, 0, 0).
        '''
        reward, new_state, end_flag = 0, 0, 0
        return reward, new_state, end_flag

In [262]:
def make_riverSwim(epLen=20, nState=5):
    '''
    Makes the benchmark RiverSwim MDP.

    States 0..nState-1 form a chain.
      * Action 0 moves one state toward state 0 with certainty (state 0 stays put).
      * Action 1 in an interior state moves away from state 0 with prob 0.3,
        stays with prob 0.6 and moves toward state 0 with prob 0.1.
        At state 0 it stays with prob 0.3 / advances with prob 0.7.
        At the last state it stays with prob 0.9 / falls back with prob 0.1.
    Rewards are noise-free: 0.05 for action 0 in state 0, 1 for action 1 in
    the last state, 0 everywhere else.

    Args:
        epLen  - int - episode length (default 20)
        nState - int - number of states; must be at least 2 (default 5)
    Returns:
        riverSwim - TabularMDP - environment with R and P filled in, already reset
    Raises:
        ValueError - if nState < 2 (the end-of-chain transitions need two states)
    '''
    if nState < 2:
        raise ValueError('make_riverSwim requires nState >= 2, got {}'.format(nState))

    nAction = 2
    R_true = {}
    P_true = {}

    # Start every (state, action) pair with zero reward (mean 0, sd 0) and an
    # all-zero transition row; the non-zero entries are filled in below.
    for s in range(nState):
        for a in range(nAction):
            R_true[s, a] = (0, 0)
            P_true[s, a] = np.zeros(nState)

    # Rewards as (mean, sd); sd 0 means the reward is deterministic.
    R_true[0, 0] = (5 / 100, 0)
    R_true[nState - 1, 1] = (1, 0)

    # Action 0: deterministic step toward state 0, clamped at 0.
    for s in range(nState):
        P_true[s, 0][max(0, s - 1)] = 1.

    # Action 1 in interior states (s - 1 and s + 1 are always in range here).
    for s in range(1, nState - 1):
        P_true[s, 1][s + 1] = 0.3
        P_true[s, 1][s] = 0.6
        P_true[s, 1][s - 1] = 0.1

    # Action 1 at the two ends of the chain.
    P_true[0, 1][0] = 0.3
    P_true[0, 1][1] = 0.7
    P_true[nState - 1, 1][nState - 1] = 0.9
    P_true[nState - 1, 1][nState - 2] = 0.1

    riverSwim = TabularMDP(nState, nAction, epLen)
    riverSwim.R = R_true
    riverSwim.P = P_true
    riverSwim.reset()

    return riverSwim

In [263]:
class TabularMDP(Environment):
    '''
    Finite-horizon tabular MDP.

    Attributes:
        R - dict keyed by (state, action) -> (meanReward, sdReward)
        P - dict keyed by (state, action) -> transition probability vector of length nState
    '''

    def __init__(self, nState, nAction, epLen):
        '''
        Create a tabular episodic MDP with placeholder dynamics.

        Every (state, action) pair starts with reward parameters (1, 1) and a
        uniform transition vector; callers such as make_riverSwim replace R and P.
        Args:
            nState  - int - number of states
            nAction - int - number of actions
            epLen   - int - episode length
        Returns:
            Environment object
        '''
        self.nState = nState
        self.nAction = nAction
        self.epLen = epLen

        self.timestep = 0
        self.state = 0

        pairs = [(s, a) for s in range(nState) for a in range(nAction)]
        self.R = {pair: (1, 1) for pair in pairs}
        self.P = {pair: np.ones(nState) / nState for pair in pairs}

    def reset(self):
        '''Rewind to time step 0 in state 0.'''
        self.state = 0
        self.timestep = 0

    def advance(self, action):
        '''
        Take one step from the current state with the given action.
        Args:
            action - int - chosen action
        Returns:
            reward - double - reward
            newState - int - new state
            pContinue - 0/1 - 1 when this step ended the episode (the MDP then resets itself)
        '''
        params = self.R[self.state, action]
        mean, sd = params[0], params[1]
        # A (near-)zero standard deviation is treated as a deterministic reward.
        if sd < 1e-9:
            reward = mean
        else:
            reward = np.random.normal(loc=mean, scale=sd)

        next_state = np.random.choice(self.nState, p=self.P[self.state, action])

        self.state = next_state
        self.timestep += 1

        episode_over = self.timestep == self.epLen
        if episode_over:
            self.reset()

        return reward, next_state, 1 if episode_over else 0

    def argmax(self, b):
        '''Index of a maximal entry of array b, breaking ties uniformly at random.'''
        ties = np.nonzero(b == b.max())[0]
        return np.random.choice(ties)

In [264]:
class deep_sea(Environment):
    '''
    Description:
        A deep-sea environment: a diver descends one row per step and, at
        every step, must decide whether to move left or right. The episode
        ends after a fixed number of steps.

    Observation:
        (horizontal position, vertical position) tuple

    Actions:
        2 possible actions:
        0 - left
        1 - right

    Rewards:
        Action 0 is free. Action 1 costs 0.01/num_steps, except in the
        bottom-right cell (num_steps-1, num_steps-1), where it pays 0.99.

    Starting State:
        position (0, 0), i.e. horizontal 0 at depth 0

    Episode termination:
        Env terminates after num_steps steps (when the depth reaches num_steps)
    '''

    def __init__(self, num_steps):
        self.num_steps = num_steps
        self.epLen = num_steps
        # NOTE(review): flip_mask is drawn but never read by advance(); it is
        # kept so the attribute and the global NumPy random stream are unchanged.
        self.flip_mask = 2*np.random.binomial(1,0.5,(num_steps,num_steps))-1
        self.nAction = 2
        self.nState = num_steps
        # R[(horizontal, vertical), action] = (meanReward, sdReward).
        self.R = {}
        for s in range(self.nState):
            for s_ in range(self.nState):
                self.R[(s,s_), 0] = (0, 0)
                self.R[(s,s_), 1] = (-0.01/self.nState, 0)
        self.R[(self.num_steps-1,self.num_steps-1),1] = (0.99,0)
        # Start in a valid state so advance() works even before an explicit reset().
        self.reset()

    def name(self):
        return  "deep sea"

    def reset(self):
        '''Put the diver back at (0, 0) and return a copy of that state.'''
        self.state = (0,0)
        return copy.deepcopy(self.state)

    def advance(self,action):
        '''
        Move one row down and one column left (action 0) or right (action 1).

        The horizontal position is clamped at 0. The reward is the mean reward
        of the state the step was taken from, paired with the action.
        Returns:
            reward   - float - mean reward of (previous state, action)
            newState - tuple - copy of (horizontal, vertical) after the move
            done     - bool  - True once the vertical position reaches num_steps
        '''
        assert action in [0,1], "invalid action"
        self.state_prev = self.state
        step_horizontal = (2*action-1)
        horizontal = max(self.state[0] + step_horizontal, 0)
        vertical = self.state[1] + 1
        done =  bool(vertical == self.num_steps)
        self.state = (horizontal, vertical)
        return self.R[self.state_prev,action][0], copy.deepcopy(self.state), done

    def argmax(self,b):
        '''Index of a maximal entry of array b, breaking ties uniformly at random.'''
        return np.random.choice(np.where(b == b.max())[0])

In [312]:
class RLSVI(object):
    '''
    Tabular Randomized Least-Squares Value Iteration agent.

    For every (state, action) key of env.R the agent stores observed
    (reward, nextState) transitions; a nextState of None marks a terminal
    transition. learn_from_buffer() perturbs each stored reward with Gaussian
    noise and draws one Gaussian prior sample per key. It then runs env.epLen
    sweeps of value iteration, where each Q-value is the Gaussian
    posterior-mean combination of the perturbed targets and the prior sample.
    act() is greedy with respect to the resulting Q, with random tie-breaking
    delegated to env.argmax.

    Args:
        env            - environment exposing R (dict keyed by (s, a)), nAction,
                         epLen and argmax(array)
        prior_variance - float - variance of the per-key prior sample (default 100.0)
        noise_mean     - float - mean of the reward perturbation (default -0.2)
        noise_variance - float - variance of the reward perturbation (default 0.05)
    '''

    def __init__(self, env, prior_variance=100.0, noise_mean=-0.2, noise_variance=0.05):
        self.env = env
        # One transition list per (state, action) key known to the environment.
        self.buffer = {key: [] for key in self.env.R.keys()}
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self.buf = []
        self.Q = {key: 0.0 for key in self.buffer.keys()}
        self.prior_variance = prior_variance
        # NOTE(review): a non-zero mean shifts every perturbed reward; confirm intended.
        self.noise_mean = noise_mean
        self.noise_variance = noise_variance
        # Derived from the env so it agrees with act(), which iterates range(nAction).
        self.action_set = tuple(range(self.env.nAction))

    def _random_argmax(self, action_values):
        '''Action from action_set with maximal value, ties broken uniformly at random.'''
        argmax_list = np.where(action_values==np.max(action_values))[0]
        return self.action_set[argmax_list[np.random.randint(argmax_list.size)]]

    def act(self, s):
        '''Greedy action for state s under the current Q (random tie-breaking).'''
        values = np.array([self.Q[(s, a)] for a in range(self.env.nAction)])
        # Use this agent's own environment, not whatever global `env` is bound.
        return self.env.argmax(values)

    def update_buffer(self, data):
        '''Record a transition given as (s, a, r, s_); pass s_=None for a terminal step.'''
        s, a, r, s_ = data[0], data[1], data[2], data[3]
        self.buffer[(s, a)].append((r, s_))

    def learn_from_buffer(self):
        '''Resample a Q function from the perturbed data and store it in self.Q.'''
        noise_sd = np.sqrt(self.noise_variance)
        prior_sd = np.sqrt(self.prior_variance)
        perturbed_buffer = {key: [(r + noise_sd * np.random.randn() + self.noise_mean, s_)
                                  for (r, s_) in self.buffer[key]]
                            for key in self.buffer.keys()}
        random_Q = {key: prior_sd * np.random.randn() for key in self.buffer.keys()}

        Q = {key: 0.0 for key in self.buffer.keys()}
        for _ in range(self.env.epLen):
            Q_new = {}
            for key in self.buffer.keys():
                q = 0.0
                for r, s_ in perturbed_buffer[key]:
                    if s_ is None:
                        # Terminal transition: no bootstrap beyond the episode end.
                        q += r
                    else:
                        q += r + max(Q[(s_, a)] for a in self.action_set)
                n = len(self.buffer[key])
                # Gaussian posterior mean: precision-weighted blend of data and prior sample.
                Q_new[key] = (1.0 / ((n / self.noise_variance) + (1.0 / self.prior_variance))) \
                             * ((q / self.noise_variance) + (random_Q[key] / self.prior_variance))
            Q = Q_new
        self.Q = Q
        

In [334]:
# Train RLSVI on a 5-state RiverSwim for 200 episodes, refitting Q after each one.
env = make_riverSwim(epLen = 20, nState = 5)
agent = RLSVI(env)
np.random.seed(0)
for episode in tqdm(range(200)):
    env.reset()
    done = 0
    while done != 1:
        state = env.state
        action = agent.act(state)
        reward, next_state, done = env.advance(action)
        # The final step of an episode is stored with next state None, so
        # learn_from_buffer does not bootstrap past the episode end.
        agent.update_buffer((state, action, reward, None if done == 1 else next_state))
    agent.learn_from_buffer()

HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))




In [341]:
# Show the learned Q-value for every (state, action) pair.
print(str(agent.Q))

{(0, 0): -0.31917201459469463, (0, 1): -0.016623200082171365, (1, 0): -0.3650414373407999, (1, 1): 0.4504026992197599, (2, 0): 0.10109746692039821, (2, 1): 1.5431141011224674, (3, 0): 1.1623573926565136, (3, 1): 3.1491526565583454, (4, 0): 1.7721835425419805, (4, 1): 5.232086947468123}


In [321]:
# NOTE(review): hard-coded list of (reward, nextState) pairs, presumably pasted
# from an agent buffer on a 15-state RiverSwim; no code in this notebook produces it.
x=[(1, 14), (1, 14), (1, 14), (1, 14), (1, 14), (1, 14), (1, 14), (1, 14), (1, 13), (1, 14), (1, 14), (1, 13), (1, 14), (1, 13), (1, 14), (1, 14), (1, 14), (1, 14), (1, 14), (1, None), (1, 13), (1, 14), (1, 14), (1, 14), (1, 13), (1, 14), (1, 14), (1, 14), (1, 14), (1, 14), (1, 14), (1, 14), (1, 14), (1, 14), (1, 14), (1, 14), (1, 14), (1, 14), (1, 13), (1, 14), (1, None)]

In [322]:
# Number of entries in the snapshot list `x`.
len(x)

41

In [331]:
# Baseline: roll out the fixed "action 0" policy on a 15-state RiverSwim and
# print the cumulative reward after each episode. No RLSVI agent is created
# here, so the trained `agent` from the earlier cells is not overwritten.
# NOTE(review): the saved output (progress max=300, +3.0 per episode) does not
# match this code (150 episodes, 100 steps x 0.05 = 5.0 per episode); re-run to refresh.
N_BASELINE_EPISODES = 150
BASELINE_ACTION = 0

env = make_riverSwim(epLen = 100, nState = 15)
np.random.seed(0)
R = 0  # cumulative reward over all episodes so far
for episode in tqdm(range(N_BASELINE_EPISODES)):
    env.reset()
    done = 0
    while done != 1:
        r, s_, done = env.advance(BASELINE_ACTION)
        R = R + r
    print(episode, R)

HBox(children=(FloatProgress(value=0.0, max=300.0), HTML(value='')))

0 2.9999999999999973
1 5.999999999999987
2 8.999999999999993
3 12.000000000000036
4 15.000000000000078
5 18.00000000000012
6 21.000000000000163
7 24.000000000000206
8 27.00000000000025
9 30.00000000000029
10 33.00000000000026
11 36.00000000000009
12 38.99999999999992
13 41.99999999999975
14 44.99999999999958
15 47.99999999999941
16 50.99999999999924
17 53.99999999999907
18 56.9999999999989
19 59.99999999999873
20 62.99999999999856
21 65.9999999999984
22 68.99999999999822
23 71.99999999999805
24 74.99999999999788
25 77.99999999999771
26 80.99999999999754
27 83.99999999999737
28 86.9999999999972
29 89.99999999999703
30 92.99999999999686
31 95.99999999999669
32 98.99999999999652
33 101.99999999999635
34 104.99999999999618
35 107.999999999996
36 110.99999999999584
37 113.99999999999567
38 116.9999999999955
39 119.99999999999532
40 122.99999999999515
41 125.99999999999498
42 128.99999999999508
43 131.99999999999577
44 134.99999999999645
45 137.99999999999713
46 140.9999999999978
47 143.9999

In [None]:
# NOTE(review): stray literal in an unexecuted scratch cell; it only echoes the
# number when run. Consider deleting this cell.
1257