In [1]:
import gym
import numpy as np

In [2]:
# Build the FrozenLake-v0 environment whose space sizes are read below.
env = gym.make("FrozenLake-v0")

In [3]:
# Sizes of the environment's observation and action spaces (their `.n` attributes).
SZ_OBS_SPACE = env.observation_space.n
SZ_ACTION_SPACE = env.action_space.n

In [86]:
class QLearner:
    """Tabular Q-learning agent with an epsilon-greedy behaviour policy.

    The agent keeps a table ``Q`` of shape
    ``(env.observation_space.n, env.action_space.n)`` initialised to zeros and
    updates it in place while interacting with ``env``.

    Parameters
    ----------
    env : environment exposing ``reset()``, ``step(action)``,
        ``observation_space.n``, ``action_space.n`` and ``action_space.sample()``.
    alpha : learning rate used to blend the old estimate with the new target.
    gamma : discount applied to the best next-state value.
    epsilon : initial probability of taking a random action.
    epsilon_decay_dec : rate of the exponential decay applied to ``epsilon``
        once per training episode.
    min_epsilon : lower bound that the decayed ``epsilon`` never drops below.
    """

    def __init__(self, env, alpha=0.1, gamma=0.99, epsilon=1,
                 epsilon_decay_dec=0.001, min_epsilon=0.01):
        self.env = env
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        # Remember the starting value so the decay schedule scales from it
        # instead of silently assuming it was 1.
        self.initial_epsilon = epsilon
        self.epsilon_decay = epsilon_decay_dec
        self.min_epsilon = min_epsilon
        # Number of training episodes completed across all run() calls; the
        # decay schedule is indexed by this so repeated runs keep decaying.
        self.episodes_trained = 0
        self.Q = np.zeros((env.observation_space.n, env.action_space.n))

    @staticmethod
    def _unpack_reset(result):
        """Return the observation from ``env.reset()``, with or without an info dict."""
        # Newer Gym/Gymnasium releases return (observation, info); the classic
        # API returns the observation alone.
        if isinstance(result, tuple):
            return result[0]
        return result

    @staticmethod
    def _unpack_step(result):
        """Normalise ``env.step()`` output to ``(next_state, reward, done)``."""
        if len(result) == 5:
            # Newer API: (obs, reward, terminated, truncated, info).
            next_state, reward, terminated, truncated, _ = result
            return next_state, reward, bool(terminated or truncated)
        # Classic API: (obs, reward, done, info).
        next_state, reward, done, _ = result
        return next_state, reward, done

    def run(self, episodes):
        """Train for ``episodes`` episodes, updating ``self.Q`` in place.

        Each step takes a random action with probability ``self.epsilon`` and
        the greedy (arg-max) action otherwise. After every episode
        ``self.epsilon`` is set to
        ``max(min_epsilon, initial_epsilon * exp(-epsilon_decay * k))`` where
        ``k`` is the number of episodes trained before this one.

        Returns nothing.
        """
        for _ in range(episodes):
            state = self._unpack_reset(self.env.reset())
            done = False
            while not done:
                if np.random.uniform(0, 1) < self.epsilon:
                    action = self.env.action_space.sample()
                else:
                    action = np.argmax(self.Q[state, :])
                next_state, reward, done = self._unpack_step(self.env.step(action))

                # One-step Q-learning target: immediate reward plus the
                # discounted value of the best action in the next state.
                target = reward + self.gamma * np.max(self.Q[next_state, :])
                self.Q[state, action] = (
                    (1 - self.alpha) * self.Q[state, action] + self.alpha * target
                )
                state = next_state

            decayed = self.initial_epsilon * np.exp(
                -self.epsilon_decay * self.episodes_trained
            )
            self.epsilon = max(self.min_epsilon, decayed)
            self.episodes_trained += 1

    def evaluate(self, episodes):
        """Return the mean total reward per episode of the greedy policy.

        Runs ``episodes`` episodes always choosing ``argmax(Q[state])``; the
        Q-table and ``epsilon`` are not modified.

        Raises
        ------
        ZeroDivisionError
            If ``episodes`` is 0.
        """
        total_reward = 0

        for _ in range(episodes):
            state = self._unpack_reset(self.env.reset())
            done = False
            while not done:
                action = np.argmax(self.Q[state, :])
                state, reward, done = self._unpack_step(self.env.step(action))
                total_reward += reward
        return total_reward / episodes


In [87]:
# Recreate the environment from its id and hand it to a new learner,
# whose Q-table starts as all zeros.
env_name = "FrozenLake-v0"
env = gym.make(env_name)
x = QLearner(env=env)

In [88]:
# Train for 10,000 episodes; this updates x.Q in place and returns nothing.
x.run(10000)

In [89]:
# Show the learned action-value table: one row per state, one column per action.
x.Q

array([[0.47353034, 0.45957649, 0.46133681, 0.45795379],
       [0.31451843, 0.36923785, 0.33471545, 0.42261253],
       [0.40177954, 0.39759038, 0.37314446, 0.40593967],
       [0.30250821, 0.23689135, 0.24378612, 0.39862575],
       [0.49552109, 0.34896256, 0.25260687, 0.27780119],
       [0.        , 0.        , 0.        , 0.        ],
       [0.34614926, 0.19814225, 0.18092307, 0.16099492],
       [0.        , 0.        , 0.        , 0.        ],
       [0.35592092, 0.43026885, 0.38395919, 0.53596223],
       [0.4986949 , 0.59041604, 0.51573299, 0.47652002],
       [0.64148871, 0.38534917, 0.34256716, 0.39987162],
       [0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        ],
       [0.40979281, 0.57443168, 0.66077835, 0.56945183],
       [0.7274544 , 0.82849   , 0.73206156, 0.72575282],
       [0.        , 0.        , 0.        , 0.        ]])

In [90]:
# Mean reward per episode when acting greedily on x.Q for 100 episodes.
x.evaluate(100)

0.7