# Install Libraries

In [None]:
!pip install gym

In [None]:
!git clone https://github.com/zhpinkman/armed-bandit.git

In [None]:
!pip install ./armed-bandit

# Impor Libraries

In [1]:
from amalearn.reward import RewardBase
from amalearn.agent import AgentBase

In [2]:
from amalearn.environment import EnvironmentBase
import gym




In [471]:
import random
from copy import deepcopy
import matplotlib.pyplot as plt
import numpy as np
from prettytable import PrettyTable

# Environment

In [437]:
from gym.spaces import Discrete, Box

# Action:
# 0 1 2
# 3 4 5
# 6 7 8

class Environment(EnvironmentBase):
    def __init__(self, obstacle = [] ,id = 0, action_count=9, actionPrice = -1, goalReward = 100
                 , punish=-10, j_limit = 10, i_limit = 10, p = 0.8, container=None):
        """
        initialize your variables
        """
        
        self.obstacle = obstacle
        
        self.x_min = 1
        self.x_max = i_limit
        
        self.y_min = 1
        self.y_max = j_limit
        
        self.reset()
        
        self.action_count = action_count
        self.actionPrice = actionPrice
        self.goalReward = goalReward
        self.punish = punish
        self.p = p
        
        action_space = Discrete(action_count)
        state_space = Box(low=1, high=max(i_limit, j_limit), shape=(1,2), dtype=int)
        
        self.action_list = list(range(1,10))
        self.state_list = []
        
        for i in range(1, i_limit+1):
            for j in range(1, j_limit+1):
                self.state_list.append(np.array([i, j]))
        
        super(Environment, self).__init__(action_space=action_space, state_space=state_space, id=id ,container=container)

        
    def isStatePossible(self, state):
        """if given state is possible (not out of the grid and not obstacle) return ture"""
        if self.x_min <= state[0] <= self.x_max and self.y_min <= state[1] <= self.y_max:
            for obstacle_item in self.obstacle:
                if (state==obstacle_item).all():
                    return False
            return True
        else:
            return False
    
    
    def isAccessible(self, state, state_p):
        """if given state is Accesible (we can reach state_p by doing an action from state) return true"""
        return abs(state[0]-state_p[0]) <= 1 and abs(state[1] - state_p[1]) <= 1 and self.isStatePossible(state_p)
            
    def getTransitionStatesAndProbs(self, state, action, state_p):
        """return probability of transition or T(sp,a,s)"""
        
        actions = self.available_actions_state(state)
        
        if (self.calculate_next_state(state,action)==state_p).all():
            if self.isAccessible(state, state_p):
                return self.p + (1-self.p) / len(actions)
            else:
                return self.p
        else:
            if self.isAccessible(state, state_p):
                return (1-self.p) / len(actions)
            else:
                return 0

    
    def getReward(self, state, action, state_p):
        """return reward of transition"""
        if self.terminated_state(state_p):
            return self.actionPrice + self.goalReward
        
        if self.isStatePossible(state_p):
            return self.actionPrice
        else:
            return self.actionPrice + self.punish
        
    def sample_all_rewards(self):
        return 
    
    def calculate_reward(self, action):
        return self.getReward(self.current_state, action, self.calculate_next_state(self.current_state, action))

    def available_states_state(self, state):
        states = []
        for i in [-1, 0, +1]:
            for j in [-1, 0, +1]:
                new_state = np.array([state[0]+i, state[1]+j])
                
                if self.isAccessible(state, new_state):
                    states.append(new_state)
        return states
    
    def terminated(self):
        return self.terminated_state(self.current_state)
    
    def terminated_state(self, state):
        return (state==np.array([1,1])).all()
        
    def observe(self):
        return self.current_state 

    def available_actions(self):
        return self.available_actions_state(self.current_state)
    
    def available_actions_state(self, state):
        output_actions = []
        for action in range(self.action_count):
            next_state = self.calculate_next_state(state, action)
            
            if self.isAccessible(state, next_state):
                output_actions.append(action)
        
        return output_actions
        
    
    def calculate_next_state(self, state, action):
        return np.array([state[0] + (action%3 -1), state[1] + (int(action/3)-1) ])
        
    def next_state(self, action):
        actions = self.available_actions()
        
        if action not in actions:
            actions.append(action)
                
        probabilities = []
                
        for action2 in actions:
            state2 = self.calculate_next_state(self.current_state, action2)
            probabilities.append(self.getTransitionStatesAndProbs(self.current_state, action, state2))
        
        final_action = random.choices(population=actions, weights=probabilities, k=1)[0]
        
        real_next_state = self.calculate_next_state(self.current_state, final_action)
        
        if not self.isStatePossible(real_next_state):
            real_next_state = self.current_state
        
        self.last_action = action
        
        self.sliped = not (final_action==action)
        
        self.current_state = real_next_state
        
        return

    def reset(self):
        self.current_state = np.array([15, 15])
        
        self.last_action = None
        self.sliped = None

    def render(self, mode='human'):
        print(f"{self.current_state} \t {self.last_action} \t {self.sliped}")
        return 

    def close(self):
        return

In [438]:
grid_states =  [np.array([7, 1]), np.array([8, 1]), np.array([7, 2]), np.array([8, 2])
                ,np.array([7, 3]), np.array([8, 3]), np.array([7, 4]), np.array([8, 4])
                ,np.array([13, 8]), np.array([14, 8]), np.array([15, 8])
                ,np.array([13, 9]), np.array([14, 9]), np.array([15, 9])
                ,np.array([6, 12]), np.array([7, 12]), np.array([6, 13]), np.array([7, 13])
                ,np.array([6, 14]), np.array([7, 14]), np.array([6, 15]), np.array([7, 15])]

In [439]:
base_environment = Environment(obstacle = grid_states ,id = 0, action_count=9, actionPrice = -1, goalReward = 100
                                , punish=-10, j_limit = 15, i_limit = 15, p = 0.8, container=None)

# Agent

In [515]:
import numpy as np

class Agent(AgentBase):
    def __init__(self, id, environment, discount, theta):
        
        # initialize a random policy and V(s) = 0 for each state
        self.environment = environment
        
        # mapp states to its ids
#         self.mapp = {}
        
        # init V
        self.V = {}
        
        # init policy
        self.policy = {}
        
        super(Agent, self).__init__(id, environment)
        
        self.discount = discount
        
        self.theta = theta
                
        self.value_initialization()
        
        self.policy_initialization()
    
    def value_initialization(self):
        for state in self.environment.state_list:
            self.V[tuple(state)] = 0
        
    def policy_initialization(self):
        for state in self.environment.state_list:
            self.policy[tuple(state)] = random.choice(self.environment.action_list)
        
    def policy_evaluation(self):
        pass
    
    def policy_improvement(self):
        pass
    
    def value_iteration(self):
        for iter in range(self.theta["max_iter"]):
            new_V = {}

            delta = 0
            for state in self.environment.state_list:
                new_V[tuple(state)] = -math.inf

                available_actions = self.environment.available_actions_state(state)
                available_states  = self.environment.available_states_state(state)

                for action in available_actions:
                    sum = 0
                    for state_p in available_states:
                        p_sp = self.environment.getTransitionStatesAndProbs(state, action, state_p)
                        r_sp = self.environment.getReward(state, action, state_p)
                        v_sp = self.V[tuple(state_p)]

                        sum += p_sp * (r_sp + self.discount * v_sp)

                    new_V[tuple(state)] = max(new_V[tuple(state)], sum)
                delta = max(delta, abs(self.V[tuple(state)] - new_V[tuple(state)]))
                
            print(f"iter = {iter} -> delta = {round(delta, 2)}")
            self.V = deepcopy(new_V)

            if delta < self.theta["delta_treshold"]:
                break
    
    def policy_extraction(self):
        for state in self.environment.state_list:

            available_actions = self.environment.available_actions_state(state)
            available_states  = self.environment.available_states_state(state)
            
            max_value = -math.inf
            argmax = None
            
            for action in available_actions:
                sum = 0
                for state_p in available_states:
                    p_sp = self.environment.getTransitionStatesAndProbs(state, action, state_p)
                    v_sp = self.V[tuple(state_p)]

                    sum += p_sp * v_sp
            
                if (state==np.array([1,1])).all():
                    print(action, sum)
            
                if sum > max_value:
                    max_value = sum
                    argmax = action
                    
            self.policy[tuple(state)] = argmax
    
    def print_value(self):
        p = PrettyTable()

        for j in range(1,16):
            row = []
            for i in range(1, 16):
                row.append(int(self.V[(i,j)]))
            p.add_row(row)

        print (p.get_string(header=False, border=True))
    
    def print_policy(self):
        p = PrettyTable()

        for j in range(1,16):
            row = []
            for i in range(1, 16):
                row.append(self.action_symbol(int(self.policy[(i,j)])))
            p.add_row(row)

        print (p.get_string(header=False, border=True))
    
    def take_action(self) -> (object, float, bool, object):
        # observation, reward, done, info
        return self.environment.step(random.choice(self.environment.action_list))
    
    def action_symbol(self, action):
        if action == 0:
            return "↖"
        elif action == 1:
            return "↑"
        elif action == 2:
            return "↗"
        elif action == 3:
            return "←"
        elif action == 4:
            return "•"
        elif action == 5:
            return "→"
        elif action == 6:
            return "↙"
        elif action == 7:
            return "↓"
        elif action == 8:
            return "↘"

In [521]:
theta = {"max_iter": 50, "delta_treshold": 3}
agent = Agent(id=0, environment=base_environment, discount=0.9, theta=theta)

In [522]:
agent.value_iteration()

iter = 0 -> delta = 84.0
iter = 1 -> delta = 75.33
iter = 2 -> delta = 66.97
iter = 3 -> delta = 60.19
iter = 4 -> delta = 54.09
iter = 5 -> delta = 48.66
iter = 6 -> delta = 43.79
iter = 7 -> delta = 39.4
iter = 8 -> delta = 35.46
iter = 9 -> delta = 31.91
iter = 10 -> delta = 28.72
iter = 11 -> delta = 25.85
iter = 12 -> delta = 23.27
iter = 13 -> delta = 20.94
iter = 14 -> delta = 18.84
iter = 15 -> delta = 16.96
iter = 16 -> delta = 15.26
iter = 17 -> delta = 13.74
iter = 18 -> delta = 12.36
iter = 19 -> delta = 11.13
iter = 20 -> delta = 10.01
iter = 21 -> delta = 9.01
iter = 22 -> delta = 8.11
iter = 23 -> delta = 7.3
iter = 24 -> delta = 6.57
iter = 25 -> delta = 5.91
iter = 26 -> delta = 5.32
iter = 27 -> delta = 4.79
iter = 28 -> delta = 4.31
iter = 29 -> delta = 3.88
iter = 30 -> delta = 3.49
iter = 31 -> delta = 3.14
iter = 32 -> delta = 2.83


In [523]:
agent.print_value()

+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
| 801 | 793 | 693 | 605 | 528 | 466 | 415 | 192 | 192 | 192 | 190 | 185 | 163 | 141 | 122 |
| 793 | 788 | 690 | 604 | 527 | 466 | 415 | 220 | 220 | 219 | 216 | 190 | 165 | 141 | 122 |
| 693 | 690 | 681 | 600 | 527 | 465 | 415 | 257 | 255 | 254 | 222 | 193 | 165 | 141 | 122 |
| 605 | 604 | 600 | 588 | 522 | 461 | 408 | 343 | 297 | 258 | 224 | 193 | 165 | 141 | 122 |
| 527 | 527 | 526 | 521 | 508 | 454 | 402 | 347 | 300 | 260 | 225 | 193 | 166 | 142 | 123 |
| 459 | 459 | 459 | 458 | 452 | 438 | 393 | 348 | 301 | 260 | 225 | 193 | 166 | 143 | 123 |
| 399 | 399 | 399 | 399 | 397 | 391 | 378 | 340 | 301 | 260 | 225 | 194 | 167 | 143 | 123 |
| 347 | 347 | 347 | 347 | 346 | 344 | 338 | 325 | 293 | 259 | 224 | 195 | 169 | 143 | 123 |
| 300 | 300 | 300 | 300 | 300 | 300 | 298 | 292 | 280 | 252 | 223 | 194 | 169 | 142 | 123 |
| 260 | 260 | 260 | 260 | 260 | 260 | 259 | 257 | 251 | 240 | 216 | 192 | 167 | 

In [524]:
agent.policy_extraction()

4 800.12854447486
5 793.8866168122071
7 793.8865284348302
8 789.5489453900622


In [525]:
agent.print_policy()

+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+
| • | ← | ← | ← | ← | ← | ← | ↘ | ↓ | ↙ | ↙ | ↙ | ↙ | ↙ | ↙ |
| ↑ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↘ | ↓ | ↙ | ↙ | ↙ | ↙ | ↙ | ↙ |
| ↑ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↘ | ↓ | ↙ | ↙ | ↙ | ↙ | ↙ | ↙ |
| ↑ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↙ | ↙ | ↙ | ↙ | ↙ | ↙ | ↙ | ↙ |
| ↑ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ← | ↙ | ↙ | ↙ | ↙ | ↙ | ↙ | ↙ |
| ↑ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ← | ← | ← | ↙ | ↙ | ↙ | ↙ |
| ↑ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ← | ↙ | ← | ← |
| ↑ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ← | ↖ | ↖ |
| ↑ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↙ | ↙ |
| ↑ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ← | ← |
| ↗ | ↗ | ↗ | ↗ | ↑ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ |
| ↗ | ↗ | ↗ | ↗ | ↗ | ↑ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ |
| ↗ | ↗ | ↗ | ↗ | ↑ | ↖ | ↗ | ↑ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ |
| ↗ | ↗ | ↗ | ↗ | ↑ | ↖ | ↗ | ↑ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ |
| ↗ | ↗ | ↗ | ↗ | ↑ | ↖ | ↗ | ↑ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ | ↖ |
+---+---