In [1]:
import os

import numpy as np
from scipy.stats import sem
import importlib
import gymnasium as gym
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import patches
import matplotlib.animation as manimation

import gym_env
from utils import softmax

In [2]:
class TwoStepStochastic:
    """
    Two-step task environment with optionally stochastic transitions.

    States are integers. From state 0, action 0 leads to state 1 and action 1
    leads to state 2 (``done = 0``). From state 1 the actions lead to states
    3/4 and from state 2 to states 5/6 (``done = 1``). For any state listed in
    ``stochastic_states`` the chosen action reaches its "common" successor
    with probability ``prob_common`` and the other action's successor
    otherwise.
    """

    def __init__(self, size=7, prob_common=0.5, seed=None, stoch_states={1}):
        """
        Two-step task environment with stochastic transitions.
        
        Args:
            size: Number of states (rows of the transition table). The table
                hard-codes transitions between states 0-6.
            prob_common: Probability of common transition (default 0.5)
            seed: Random seed for reproducibility
            stoch_states: Iterable of states whose transitions are stochastic
                (default {1}). It is copied, so the caller's object and the
                shared default are never aliased by the instance.
        """
        self.size = size
        self.prob_common = prob_common
        self.rng = np.random.RandomState(seed)
        
        self.envstep = self._build_transition_table()
        # Copy into a fresh set: storing the default argument itself would let
        # a mutation on one instance leak into every later instance.
        self.stochastic_states = set(stoch_states)
        
    def _build_transition_table(self):
        """
        Build the base transition lookup table.

        Returns:
            Integer array of shape (size, 2, 2) where ``table[state, action]``
            holds ``[next_state, done]``. Entries that are not assigned below
            stay ``[0, 0]``.
        """
        envstep = np.zeros((self.size, 2, 2), dtype=int)
        
        # State 0 -> 1, 2 (deterministic)
        envstep[0, 0] = [1, 0]
        envstep[0, 1] = [2, 0]
        
        # State 1 -> 3, 4 (stochastic under the default stoch_states)
        envstep[1, 0] = [3, 1]  # common for action 0
        envstep[1, 1] = [4, 1]  # common for action 1
        
        # State 2 -> 5, 6 (deterministic under the default stoch_states)
        envstep[2, 0] = [5, 1]  # common for action 0
        envstep[2, 1] = [6, 1]  # common for action 1
        
        return envstep
        
    def step_deterministic(self, state, action):
        """Return the table entry ``(next_state, done)`` for (state, action), ignoring stochasticity."""
        state, done = self.envstep[state, action]
        return state, done

    def step(self, state, action):
        """
        Take ``action`` in ``state`` and return ``(next_state, done)``.

        For states in ``stochastic_states`` one random number is drawn from
        the environment's RNG: below ``prob_common`` the common successor is
        used, otherwise the common successor of the other action. ``done`` is
        always the flag stored for the chosen (state, action) pair.
        """
        # Get the "common" transition for this action
        common_state, done = self.envstep[state, action]
        
        # Handle stochastic transitions
        if state in self.stochastic_states:
            if self.rng.random() < self.prob_common:
                # Common transition
                next_state = common_state
            else:
                # Rare transition (flip to the other action's common state)
                rare_action = 1 - action
                next_state = self.envstep[state, rare_action][0]
        else:
            # Deterministic transition
            next_state = common_state
            
        return next_state, done
    
    def reset(self):
        """Reset to initial state."""
        return 0
    
    def get_transition_type(self, state, action, next_state):
        """
        Classify an observed transition.

        Returns:
            'deterministic' if ``state`` is not a stochastic state, 'common'
            if ``next_state`` is the table successor for (state, action), and
            'rare' otherwise.
        """
        if state not in self.stochastic_states:
            return 'deterministic'
        
        common_state = self.envstep[state, action][0]
        if next_state == common_state:
            return 'common'
        else:
            return 'rare'


# Usage example: walk one two-step episode and print each transition.
# The environment is created without a seed, so the stochastic outcome of the
# second step (and therefore the printed output) can differ between runs.
if __name__ == "__main__":
    env = TwoStepStochastic(size=7, prob_common=0.5)
    
    # Run a simple episode
    state = env.reset()
    print(f"Start state: {state}")
    
    # First step (state 0 is not stochastic by default, so action 0 leads to state 1)
    action = 0
    state, done = env.step(state, action)
    print(f"After action {action}: state={state}, done={done}")
    
    # Second step (state 1 is stochastic by default: common or rare successor)
    action = 1
    next_state, done = env.step(state, action)
    # Label the observed transition as common / rare / deterministic
    transition_type = env.get_transition_type(state, action, next_state)
    print(f"After action {action}: state={next_state}, done={done}, transition={transition_type}")

Start state: 0
After action 0: state=1, done=0
After action 1: state=3, done=1, transition=rare


In [3]:
class SR_IS_TwoStep:
    """
    Agent that learns a default representation (DR) on the two-step task.

    The agent walks the ``TwoStepStochastic`` environment for ``num_steps``
    steps, updates one row of the DR after every step (optionally weighted by
    importance sampling) and derives Z-values and state values V from it.
    """

    def __init__(self, alpha=0.25, beta=1.0, _lambda=10, num_steps=250, policy="softmax", imp_samp=True, seed=None):
        """
        Args:
            alpha: Learning rate of the DR update.
            beta: Divisor applied to the state values inside the softmax policy.
            _lambda: Scale used in exp(r / _lambda) and in V = _lambda * log(Z).
            num_steps: Number of environment steps performed by learn().
            policy: Decision policy, "softmax" or "random".
            imp_samp: If True, weight DR updates by the importance-sampling ratio.
            seed: Seed for the environment RNG and, in learn(), the global NumPy RNG.
        """
        # Hard code start and end locations as well as size
        self.start_loc = 0
        self.target_locs = [3,4,5,6]
        self.start_locs = [0]
        self.size = 7
        self.agent_loc = self.start_loc
        self.seed = seed

        # Construct the transition probability matrix and env
        self.T = self.construct_T()
        self.env = TwoStepStochastic(size=7, prob_common=0.5, seed=self.seed)
        
        # Get terminal states (states that transition to themselves with probability 1)
        self.terminals = np.diag(self.T) == 1
        # Calculate P = T_{NT}: transitions from non-terminal to terminal states
        self.P = self.T[~self.terminals][:,self.terminals]

        # Set reward
        self.reward_nt = -0.1
        self.r = np.full(len(self.T), self.reward_nt)
        # Reward of terminal states depends on if we are replicating reward revaluation or policy revaluation
        self.r[self.terminals] = [5,-5,0,1]

        # Precalculate exp(r) for use with LinearRL equations
        self.expr_t = np.exp(self.r[self.terminals] / _lambda)
        self.expr_nt = np.exp(self.reward_nt / _lambda)

        # Params
        self.alpha = alpha
        self.beta = beta
        self.gamma = self.expr_nt
        self._lambda = _lambda
        self.num_steps = num_steps
        self.policy = policy
        self.imp_samp = imp_samp

        # Model
        self.DR = self.get_DR()
        self.Z = np.full(self.size, 0.01)

        self.V = np.zeros(self.size)
        self.one_hot = np.eye(self.size)

    def construct_T(self):
        """
        Manually construct the transition matrix of the default policy:
        uniform over the two successors of states 0, 1 and 2, and absorbing
        for states 3-6.
        """
        # For NHB two-step task
        T = np.zeros((self.size, self.size))
        T[0, 1:3] = 0.5
        T[1, 3:5] = 0.5
        T[2, 5:7] = 0.5
        T[3:7, 3:7] = np.eye(4)

        return T

    def get_DR(self):
        """
        Returns the DR initialization based on what decision policy we are using, values are filled with 0.01 if using softmax to avoid div by zero
        """
        if self.policy == "softmax":
            DR = np.full((self.size, self.size), 0.01)
            np.fill_diagonal(DR, 1)
            # Diagonal entries of terminal states start at 1/gamma
            DR[np.where(self.terminals)[0], np.where(self.terminals)[0]] = (1/(self.gamma))
        else:
            DR = np.eye(self.size)

        return DR

    def update_Z(self):
        """Recompute Z for non-terminal states from the DR and set terminal Z to exp(r / _lambda)."""
        self.Z[~self.terminals] = self.DR[~self.terminals][:,~self.terminals] @ self.P @ self.expr_t
        self.Z[self.terminals] = self.expr_t

    def update_V(self):
        """Compute state values V = _lambda * log(Z)."""
        self.V = np.log(self.Z) * self._lambda
    
    def get_successor_states(self, state):
        """
        Return the indices of all states reachable from ``state`` under T
        """
        return np.where(self.T[state, :] != 0)[0]

    def importance_sampling(self, state, s_prob):
        """
        Performs importance sampling P(x'|x)/u(x'|x). P(.) is the default policy, u(.) is the decision policy

        Args:
            state: Current state; the default policy is uniform over its successors.
            s_prob: Probability the decision policy assigned to the chosen action.
        """
        successor_states = self.get_successor_states(state)
        p = 1/len(successor_states)
        w = p/s_prob
                
        return w

    def select_action(self, state):
        """
        Action selection based on our policy
        Options are: [random, softmax]

        Returns:
            For "random": the action (0 or 1).
            For "softmax": a tuple ``(action, s_prob)`` where ``s_prob`` is
            the probability with which the action was chosen.

        Raises:
            ValueError: if ``self.policy`` is neither "random" nor "softmax".
        """
        if self.policy == "random":
            action = np.random.choice([0,1])

            return action
        
        elif self.policy == "softmax":
            successor_states = self.get_successor_states(state)
            action_probs = np.full(2, 0.0)   # We can hardcode this because every state has 2 actions

            v_sum = sum(np.exp((np.log(self.Z[s] + 1e-20) * self._lambda) / self.beta) for s in successor_states)

            # if we don't have enough info, random action
            if v_sum == 0:
                # Keep the (action, s_prob) contract of the softmax branch:
                # a uniform choice between two actions has probability 0.5.
                return np.random.choice([0,1]), 0.5

            for action in [0,1]:
                new_state, _ = self.env.step_deterministic(state, action)
                action_probs[action] = np.exp((np.log(self.Z[new_state] + 1e-20) * self._lambda) / self.beta ) / v_sum
                
            action = np.random.choice([0,1], p=action_probs)
            s_prob = action_probs[action]

            return action, s_prob

        # Fail loudly instead of silently returning None for an unknown policy
        raise ValueError(f"Unknown policy {self.policy!r}; expected 'random' or 'softmax'")

    def get_D_inv(self):
        """
        Calculates the DR directly using matrix inversion, used for testing
        """
        I = np.eye(self.size)
        D_inv = np.linalg.inv(I-self.gamma*self.T)
        
        return D_inv

    def learn(self):
        """
        Agent explores the maze according to its decision policy and updates its DR as it goes
        """
        # Seeds the *global* NumPy RNG used by select_action
        if self.seed is not None:
            np.random.seed(seed=self.seed)

        # Iterate through number of steps
        for i in range(self.num_steps):
            # Agent gets some knowledge of terminal state values
            if i == 2:
                self.Z[self.terminals] = self.expr_t
            # Current state
            state = self.agent_loc

            # Choose action
            if self.policy == "softmax":
                action, s_prob = self.select_action(state)
            else:
                action = self.select_action(state)
                # No action probability is reported for non-softmax policies
                s_prob = None
        
            # Take action
            next_state, done = self.env.step(state, action)
            # Importance sampling
            if self.imp_samp and s_prob is not None:
                w = self.importance_sampling(state, s_prob)
                w = 1 if np.isnan(w) or w == 0 else w
            else:
                # Either importance sampling is disabled, or the decision
                # policy is the uniform random one, which matches the uniform
                # default policy and therefore has a ratio of 1.
                w = 1
            
            # Update default representation
            target = self.one_hot[state] + self.gamma * self.DR[next_state]
            self.DR[state] = (1 - self.alpha) * self.DR[state] + self.alpha * target * w

            # Update Z-Values
            self.Z[~self.terminals] = self.DR[~self.terminals][:,~self.terminals] @ self.P @ self.expr_t
            
            if done:
                # Episode finished: restart from the start state
                self.agent_loc = self.start_loc
                continue
            
            # Update state
            state = next_state
            self.agent_loc = state

        # Final refresh of Z (including terminal states) and of the values V
        self.update_Z()
        self.update_V()

In [4]:
# Experiment configuration. `set_val` is shared by lambda and beta so both
# are always changed together.
set_val = 1
lambd = set_val  # passed to the agent as `_lambda`
alpha = 0.05  # learning rate of the DR update
beta = set_val  # divisor used inside the agent's softmax policy
num_steps = 250  # number of environment steps taken by agent.learn()
# Softmax agent with importance sampling; the fixed seed is passed to both the
# environment RNG and (inside learn()) the global NumPy RNG.
agent =  SR_IS_TwoStep(_lambda=lambd, alpha=alpha, beta=beta, num_steps=num_steps, policy="softmax", imp_samp=True, seed=1234)

In [5]:
# Run the learning loop; this updates agent.DR, agent.Z and agent.V in place.
agent.learn()

In [6]:
# Report the learned values of the two second-stage states (indices 1 and 2,
# labelled S2 and S3 in the output) and their softmax.
# NOTE(review): `softmax` comes from the project-local `utils` module; its exact definition is not visible here.
print(f"Value of state S2: {agent.V[1]} | Value of state S3: {agent.V[2]}\nSoftmax: {softmax(x=np.array([agent.V[1], agent.V[2]]))}")

Value of state S2: 3.6255867495541376 | Value of state S3: 0.9313754246840471
Softmax: [0.9366842 0.0633158]
