In [None]:
%run Imports.ipynb
%run Discrete_Agent.ipynb

In [None]:
class PolicyIterationAgent_rsa(DiscreteAgent):
    """Policy-iteration agent for an MDP whose rewards are keyed by (state, action).

    The environment is read through these attributes (inferred from how they are
    indexed below -- confirm against the environment class):

    * ``env.states`` -- iterable of states.
    * ``env.actions[state]`` -- sequence of actions available in ``state``.
    * ``env.rewards[(state, action)]`` -- immediate reward.
    * ``env.transitions[(state, action, dest)]`` -- weight of moving to ``dest``,
      presumably P(dest | state, action); missing keys contribute nothing.
    * ``env.gamma`` -- discount factor.
    """

    def __init__(self, env):
        self.env = env
        self.is_stable = False    # True once an improvement sweep changed no action
        self.values = {}          # state -> value estimate
        self.policy = {}          # state -> action chosen by the current policy
        self.delta = 1.0          # largest value change seen in the last evaluation sweep
        self.theta = 0.1          # evaluation stops once delta drops below this

        self.sweep_no = 0         # evaluation sweeps run since update() last reset it
        self.is_converged = False # whether the last evaluation reached delta < theta
        self.max_sweeps = 1000    # evaluation-sweep budget per update() call

        # Start from zero values and a uniformly random deterministic policy.
        for state in self.env.states:
            self.values[state] = 0
            actions_list = self.env.actions[state]
            self.policy[state] = random.choice(actions_list)

    def get_action(self, state):
        """Return the action the current policy prescribes for ``state``."""
        return self.policy[state]

    def _q_value(self, state, act):
        """Return the one-step backup r(s, a) + sum_dest gamma * V(dest) * T(s, a, dest)."""
        total = self.env.rewards[(state, act)]
        for dest in self.env.states:
            key = (state, act, dest)
            if key in self.env.transitions:
                total += (self.env.gamma * self.values[dest]) * self.env.transitions[key]
        return total

    def evaluate_policy(self):
        """Evaluate the current policy in place with sweeps over all states.

        Values are updated in place, so later states in a sweep already see the
        new values of earlier ones. Sweeping stops when the largest per-state
        change falls below ``theta`` or when ``sweep_no`` reaches ``max_sweeps``.
        ``is_converged`` records which of the two happened.
        """
        self.is_converged = False
        # Reset delta so every call really evaluates the *current* policy.
        # Without this, the sub-theta delta left over from the previous call made
        # the loop below exit immediately after the first policy change.
        self.delta = float("inf")
        while self.delta >= self.theta and self.sweep_no < self.max_sweeps:
            self.sweep_no += 1
            self.delta = 0
            for state in self.env.states:
                old_val = self.values[state]
                new_val = self._q_value(state, self.policy[state])
                self.values[state] = new_val
                self.delta = max(self.delta, abs(old_val - new_val))

        if self.delta < self.theta:
            self.is_converged = True

    def improve_policy(self):
        """Make the policy greedy with respect to the current values.

        The current action is kept unless another action is *strictly* better.
        Breaking ties randomly, or against V(s) instead of Q(s, pi(s)), can swap
        between equally good actions and keep ``is_stable`` False indefinitely.
        Each state's value is also replaced by its best one-step backup.
        """
        self.is_stable = True
        for state in self.env.states:
            old_act = self.policy[state]
            best_act = old_act
            best_val = self._q_value(state, old_act)

            for act in self.env.actions[state]:
                q = self._q_value(state, act)
                if q > best_val:
                    best_val, best_act = q, act

            self.values[state] = best_val
            self.policy[state] = best_act
            if best_act != old_act:
                self.is_stable = False

    def update(self):
        """Alternate evaluation and improvement until the policy is stable.

        Returns:
            tuple: ``(sweep_no, is_converged, is_stable, values, policy)``, where
            ``sweep_no`` counts the evaluation sweeps run during this call. If the
            policy is already stable, the call returns immediately with
            ``sweep_no == 0``.
        """
        self.sweep_no = 0
        while not self.is_stable:
            self.evaluate_policy()
            self.improve_policy()
        return self.sweep_no, self.is_converged, self.is_stable, self.values, self.policy

In [None]:
class PolicyIterationAgent_rsas(DiscreteAgent):
    """Policy-iteration agent for an MDP whose rewards are keyed by (state, action, dest).

    The environment is read through these attributes (inferred from how they are
    indexed below -- confirm against the environment class):

    * ``env.states`` -- iterable of states.
    * ``env.actions[state]`` -- sequence of actions available in ``state``.
    * ``env.rewards[(state, action, dest)]`` -- reward for that transition; only
      looked up for keys that are also present in ``env.transitions``.
    * ``env.transitions[(state, action, dest)]`` -- weight of moving to ``dest``,
      presumably P(dest | state, action); missing keys contribute nothing.
    * ``env.gamma`` -- discount factor.
    """

    def __init__(self, env):
        self.env = env
        self.is_stable = False    # True once an improvement sweep changed no action
        self.values = {}          # state -> value estimate
        self.policy = {}          # state -> action chosen by the current policy
        self.delta = 1.0          # largest value change seen in the last evaluation sweep
        self.theta = 0.1          # evaluation stops once delta drops below this

        self.sweep_no = 0         # evaluation sweeps run since update() last reset it
        self.is_converged = False # whether the last evaluation reached delta < theta
        self.max_sweeps = 1000    # evaluation-sweep budget per update() call

        # Start from zero values and a uniformly random deterministic policy.
        for state in self.env.states:
            self.values[state] = 0
            actions_list = self.env.actions[state]
            self.policy[state] = random.choice(actions_list)

    def get_action(self, state):
        """Return the action the current policy prescribes for ``state``."""
        return self.policy[state]

    def _q_value(self, state, act):
        """Return the backup sum_dest T(s, a, dest) * (gamma * V(dest) + r(s, a, dest))."""
        total = 0.0
        for dest in self.env.states:
            key = (state, act, dest)
            if key in self.env.transitions:
                prob = self.env.transitions[key]
                total += (self.env.gamma * self.values[dest]) * prob
                total += self.env.rewards[key] * prob
        return total

    def evaluate_policy(self):
        """Evaluate the current policy in place with sweeps over all states.

        Values are updated in place, so later states in a sweep already see the
        new values of earlier ones. Sweeping stops when the largest per-state
        change falls below ``theta`` or when ``sweep_no`` reaches ``max_sweeps``.
        ``is_converged`` records which of the two happened.
        """
        self.is_converged = False
        # Reset delta so every call really evaluates the *current* policy.
        # Without this, the sub-theta delta left over from the previous call made
        # the loop below exit immediately after the first policy change.
        self.delta = float("inf")
        while self.delta >= self.theta and self.sweep_no < self.max_sweeps:
            self.sweep_no += 1
            self.delta = 0
            for state in self.env.states:
                old_val = self.values[state]
                new_val = self._q_value(state, self.policy[state])
                self.values[state] = new_val
                self.delta = max(self.delta, abs(old_val - new_val))

        if self.delta < self.theta:
            self.is_converged = True

    def improve_policy(self):
        """Make the policy greedy with respect to the current values.

        The current action is kept unless another action is *strictly* better.
        Breaking ties randomly, or against V(s) instead of Q(s, pi(s)), can swap
        between equally good actions and keep ``is_stable`` False indefinitely.
        Each state's value is also replaced by its best one-step backup.
        """
        self.is_stable = True
        for state in self.env.states:
            old_act = self.policy[state]
            best_act = old_act
            best_val = self._q_value(state, old_act)

            for act in self.env.actions[state]:
                q = self._q_value(state, act)
                if q > best_val:
                    best_val, best_act = q, act

            self.values[state] = best_val
            self.policy[state] = best_act
            if best_act != old_act:
                self.is_stable = False

    def update(self):
        """Alternate evaluation and improvement until the policy is stable.

        Returns:
            tuple: ``(sweep_no, is_converged, is_stable, values, policy)``, where
            ``sweep_no`` counts the evaluation sweeps run during this call. If the
            policy is already stable, the call returns immediately with
            ``sweep_no == 0``.
        """
        self.sweep_no = 0
        while not self.is_stable:
            self.evaluate_policy()
            self.improve_policy()
        return self.sweep_no, self.is_converged, self.is_stable, self.values, self.policy