In [70]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

class Riverswim:
    def __init__(self, states_count, actions_count=2, gamma=0.95):
        self.states_count = states_count
        self.actions_count = actions_count
        self.P = np.zeros((states_count, actions_count, states_count))
        self.R = np.zeros((states_count, actions_count))
        self.Q = np.zeros((states_count, actions_count))
        self.gamma = gamma  # Discount factor
        self.terminal_state = states_count - 1  # Define the rightmost state as terminal

        for s in range(states_count):
            if s == 0:
                self.P[s, 0, s] = 1  
                self.P[s, 1, s] = 0.6  
                self.P[s, 1, s + 1] = 0.4  
                self.R[s, 0] = 0.05
            elif s == self.terminal_state:
                self.P[s, 0, s] = 1 
                self.P[s, 1, s] = 1  
            else:
                self.P[s, 0, s - 1] = 1 
                self.P[s, 1, s] = 0.55  
                self.P[s, 1, s + 1] = 0.4 
                self.P[s, 1, s - 1] = 0.05  
        
        # Moving into RL state by right action
        self.R[states_count - 2, 1] = 1
        self.s = 0  # Start state
        self.max_state = 0

    def reset(self):
        self.s = 0
        return self.s

    def step(self, epsilon=0.2):
        """ Take an epsilon-greedy step """
        if np.random.rand() < epsilon:
            action = np.random.randint(self.actions_count)
        else:
            action = np.argmax(self.Q[self.s])

        # Transition probability for the chosen action
        p = self.P[self.s, action]
        next_state = np.random.choice(self.states_count, p=p)
        reward = self.R[self.s, action]
        
        # Update the current state
        self.s = next_state
        if self.max_state < next_state:
            print("Yay changing max state from", self.max_state, "to", next_state)
            self.max_state = next_state
        
        return next_state, reward, action
    
    def policy_always_right(self):
        """ Always go right until the final state and stay there"""
        P = np.full((self.states_count, 2), (0, 1))
        
        return P
    
    def policy_b(self):
        """ Always go right until the final state and stay there"""
        P = np.full((self.states_count, 2), (0.35, 0.65))
        
        return P
    
    def policy_matrix(self, policy):
        # Create a policy matrix that represents the transition probabilities under the given policy.
        policy_mat = np.zeros((self.states_count, self.states_count))
        for s in range(self.states_count):
            policy_prob = policy[s]
            # Calculate the expected transition probability for each state under the given policy
            policy_mat[s] = policy_prob[0] * self.P[s, 0] + policy_prob[1] * self.P[s, 1]
        return policy_mat

    def reward_vector(self, policy):
        # Create a reward vector that represents the expected immediate reward from each state under the given policy.
        reward_vec = np.zeros(self.states_count)
        for s in range(self.states_count):
            policy_prob = policy[s]
            # Calculate the expected immediate reward for each state under the given policy
            reward_vec[s] = policy_prob[0] * self.R[s, 0] + policy_prob[1] * self.R[s, 1]
        return reward_vec
    
    def value_true(self, gamma, policy):
        p = self.policy_matrix(policy)
        r = self.reward_vector(policy=policy) 
        i = np.eye(6)
        return np.linalg.inv(i - gamma * p) @ r
    
    def q_star_via_qvi(self):
        # Threshold for convergence
        threshold = 1e-10
        Q_star = np.zeros((self.states_count, self.actions_count))
        while True:
            delta = 0
            Q_temp = Q_star.copy()
            for s in range(self.states_count):
                for a in range(self.actions_count):
                    Q_star[s, a] = self.R[s, a] + self.gamma * sum(self.P[s, a, s_prime] * max(Q_temp[s_prime]) for s_prime in range(self.states_count))
                    delta = max(delta, abs(Q_star[s, a] - Q_temp[s, a]))
                    
            if delta < threshold:
                break
            
        return Q_star
    
    def mb_opo(self, alpha=0, epsilon=0.2, horizon=10e6, q_mod=1):
        horizon = horizon
        Q_error_history = []
        Q_loss_history = []
        Q_star = self.q_star_via_qvi()
        
        # Initializing variables
        N_s = np.zeros(self.states_count)
        N_sa = np.zeros((self.states_count, self.actions_count))
        N_ss_a = np.zeros((self.states_count, self.actions_count, self.states_count))
        rewards = np.zeros((self.states_count, self.actions_count))

        # Initialize estimated model matrices
        P_hat = np.zeros((self.states_count, self.actions_count, self.states_count))
        R_hat = np.zeros((self.states_count, self.actions_count))
        
        self.Q = np.full((self.states_count, self.actions_count), (0.5, 0.5))  # Reset Q
        
        for t in range(int(horizon)):
            # Store current state
            state = self.s
            
            # Take a step using the epsilon-greedy policy
            next_state, reward, action = self.step(epsilon)
            
            # Update counters
            N_s[state] += 1
            N_sa[state, action] += 1
            N_ss_a[state, action, next_state] += 1
            rewards[state, action] += reward
            
            if t % q_mod != 0:
                continue
            
            # Update estimated model with current data
            for s_prime in range(self.states_count):
                P_hat[state, action, s_prime] = N_ss_a[state, action, s_prime] / N_sa[state, action]
                
            R_hat[state, action] = rewards[state, action]  / N_sa[state, action]
            
            # Perform Q-value iteration using the estimated model
            for s in range(self.states_count):
                for a in range(self.actions_count):
                    self.Q[s, a] = R_hat[s, a] + self.gamma * np.sum(
                        [P_hat[s, a, s_prime] * np.max(self.Q[s_prime]) for s_prime in range(self.states_count)]
                    )

            # Store the max error in Q-values and return loss in history for plotting
            Q_error_history.append(np.max(np.abs(Q_star - self.Q)))
            Q_loss_history.append(np.max(Q_star[0, 1] - self.Q[0, 1]))

        # Calculate the final policy from Q-values
        optimal_policy = np.argmax(self.Q, axis=1)

        # Return error and loss histories along with the optimal policy
        return Q_error_history, Q_loss_history, optimal_policy

        

In [74]:
# Create the environment
states_count = 5
env = Riverswim(states_count)

Q_error_history, Q_loss_history, optimal_policy = env.mb_opo(alpha=0, epsilon=0.2, horizon=20000, q_mod=1)
Q_loss_history

Yay changing max state from 0 to 1
Yay changing max state from 1 to 2


[1.540766611449996,
 1.540766611449996,
 1.540766611449996,
 1.540766611449996,
 1.325858502074996,
 1.289103907543746,
 1.2541870427390587,
 1.2210160211746055,
 1.189503550688375,
 1.159566703726456,
 1.1311266991126332,
 1.1041086947295013,
 1.078441590565526,
 1.0540578416097497,
 1.030893280101762,
 1.0088869466691737,
 0.9879809299082147,
 0.9681202139853038,
 0.9492525338585384,
 0.9313282377381112,
 0.9143001564237054,
 0.89812347917502,
 0.8827556357887688,
 0.8681561845718302,
 0.8542867059157385,
 0.8411107011924513,
 0.8285934967053286,
 0.816702152442562,
 0.8054053753929337,
 0.7946734371957868,
 0.7844780959084973,
 0.7747925216855722,
 0.7655912261737934,
 0.7568499954376036,
 0.7485458262382232,
 0.7406568654988118,
 0.733162352796371,
 0.7260425657290522,
 0.7192787680150994,
 0.7128531601868442,
 0.7067488327500018,
 0.7009497216850015,
 0.6954405661732512,
 0.6902068684370884,
 0.8991177945532993,
 0.8955752353981342,
 0.6878992017843043,
 0.683636322267589,
 0.6795

In [72]:
Q_error_history

[2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
 2.293019473123776,
