# Q-Learning in Windy Gridworld Environment

In [1]:
import gym
import numpy as np
import sys
import matplotlib.pyplot as plt
from collections import defaultdict
from gym.envs.toy_text import discrete

In [2]:
# Action ids used as keys of the per-state transition table.
UP, RIGHT, DOWN, LEFT = range(4)


class WindyGridworldEnv(discrete.DiscreteEnv):
    """Deterministic 7x10 gridworld in which some columns push the agent upward.

    States are the flattened (row, column) cells of the grid. Every episode
    starts at cell (3, 0), the goal is cell (3, 7), and every transition
    carries a reward of -1.0.
    """

    def __init__(self):
        self.shape = (7, 10)
        self.goal = (3, 7)

        n_rows, n_cols = self.shape
        nS = n_rows * n_cols
        nA = 4

        # Upward wind strength of every cell; it only varies by column.
        winds = np.zeros(self.shape)
        winds[:, [3, 4, 5, 8]] = 1
        winds[:, [6, 7]] = 2

        # Initial state distribution: all mass on row 3, column 0.
        isd = np.zeros(nS)
        start_state = np.ravel_multi_index((3, 0), self.shape)
        isd[start_state] = 1.0

        # [row delta, column delta] applied by each action before the wind:
        # UP is one row up, RIGHT one column right, DOWN one row down,
        # LEFT one column left.
        moves = {
            UP: [-1, 0],
            RIGHT: [0, 1],
            DOWN: [1, 0],
            LEFT: [0, -1],
        }

        # P[state][action] -> list of (probability, next state, reward, done).
        P = {}
        for s in range(nS):
            cell = np.unravel_index(s, self.shape)
            P[s] = {
                action: self._calc_transition_probabilities(cell, delta, winds)
                for action, delta in moves.items()
            }

        super().__init__(nS, nA, P, isd)

    def _calc_transition_probabilities(self, pos, delta, winds):
        """Return the single transition produced by moving ``delta`` from ``pos``.

        Args:
            pos: (row, column) of the current cell.
            delta: [row delta, column delta] of the chosen action.
            winds: array of grid shape holding the upward push of each cell.

        Returns:
            A one-element list ``[(probability, new_state, reward, done)]``.
        """
        # tuple() is a no-op for the tuples produced by np.unravel_index; it
        # keeps this a single-cell lookup if a list or array is passed instead.
        wind_strength = winds[tuple(pos)]
        pushed_up = np.array([-1, 0]) * wind_strength

        target = np.array(pos) + np.array(delta) + pushed_up
        target = self._limit_coordinates(target).astype(int)
        landed = tuple(target)

        new_state = np.ravel_multi_index(landed, self.shape)
        is_done = landed == self.goal

        # (probability, new state, reward, done)
        return [(1.0, new_state, -1.0, is_done)]

    def _limit_coordinates(self, coord):
        """Clamp ``coord`` to the valid row/column index range of the grid."""
        highest = np.array(self.shape) - 1
        return np.clip(coord, (0, 0), highest)

    def render(self, mode="human"):
        """Write the grid to stdout: ``s`` is the agent, ``G`` the goal.

        ``mode`` is accepted but not read by this method.
        """
        outfile = sys.stdout
        last_col = self.shape[1] - 1

        for s in range(self.nS):
            pos = np.unravel_index(s, self.shape)

            if s == self.s:
                output = " s "
            elif pos == self.goal:
                output = " G "
            else:
                output = " _ "

            # Trim the outer padding of each row and end it with a newline.
            col = pos[1]
            if col == 0:
                output = output.lstrip()
            if col == last_col:
                output = output.rstrip() + "\n"

            outfile.write(output)
        outfile.write("\n")


In [3]:
# Build the windy gridworld; its transition table is precomputed in __init__.
env = WindyGridworldEnv()

In [4]:
# Reset the environment and print the grid ("s" marks the agent, "G" the goal).
env.reset()
env.render()

_  _  _  _  _  _  _  _  _  _
_  _  _  _  _  _  _  _  _  _
_  _  _  _  _  _  _  _  _  _
s  _  _  _  _  _  _  G  _  _
_  _  _  _  _  _  _  _  _  _
_  _  _  _  _  _  _  _  _  _
_  _  _  _  _  _  _  _  _  _



In [5]:
# Take one step with action 1 (RIGHT), print what step() returns, then redraw the grid.
print(env.step(1))
env.render()

(31, -1.0, False, {'prob': 1.0})
_  _  _  _  _  _  _  _  _  _
_  _  _  _  _  _  _  _  _  _
_  _  _  _  _  _  _  _  _  _
_  s  _  _  _  _  _  G  _  _
_  _  _  _  _  _  _  _  _  _
_  _  _  _  _  _  _  _  _  _
_  _  _  _  _  _  _  _  _  _



# It is Q-Learning Time! 

In [6]:
def epsilon_greedy_policy(Q, state, nA, epsilon):
    """Return the epsilon-greedy action probabilities for ``state``.

    Args:
        Q: mapping from state to an array of ``nA`` action values.
        state: key of the state to act in.
        nA: number of actions.
        epsilon: probability mass spread uniformly over all actions.

    Returns:
        Array of length ``nA``; the greedy action (first maximum of
        ``Q[state]``) holds ``1 - epsilon + epsilon / nA``, every other
        action ``epsilon / nA``.
    """
    # The uniform share goes to all nA actions, the greedy one included, so
    # the divisor is nA (not nA - 1): the entries then sum to
    # epsilon + (1 - epsilon) = 1.
    action_probs = np.ones(nA) * epsilon / nA

    greedy_action = np.argmax(Q[state])
    action_probs[greedy_action] = action_probs[greedy_action] + (1 - epsilon)

    return action_probs

In [7]:
def Q_learning(
    env: gym.Env,
    episodes: int,
    learning_rate: float,
    gamma: float,
    epsilon: float,
    max_steps: int = int(1e5),
):
    """Run tabular Q-learning with a fixed epsilon-greedy behaviour policy.

    The environment is used through the interface where ``reset()`` returns
    the initial state and ``step(action)`` returns
    ``(next_state, reward, done, info)``.

    Args:
        env: environment with a discrete ``action_space``.
        episodes: number of episodes to run.
        learning_rate: step size of the Q-value update.
        gamma: discount factor.
        epsilon: exploration rate of the epsilon-greedy policy.
        max_steps: upper bound on the number of steps in one episode.

    Returns:
        ``(x, y, Q)`` where ``x`` holds the episode indices, ``y[ep]`` the
        number of steps taken in episode ``ep`` (``max_steps`` if the episode
        was cut off before ``done``), and ``Q`` maps each visited state to its
        array of action values.
    """
    n_actions = env.action_space.n
    # Unseen states start with all-zero action values.
    Q = defaultdict(lambda: np.zeros(n_actions))

    x = np.arange(episodes)
    y = np.zeros(episodes)

    for ep in range(episodes):
        state = env.reset()

        for step in range(max_steps):
            probs = epsilon_greedy_policy(Q, state, n_actions, epsilon)
            action = np.random.choice(np.arange(n_actions), p=probs)

            next_state, reward, done, _ = env.step(action)

            # Off-policy target: bootstrap from the best next action.
            td_target = reward + gamma * np.max(Q[next_state])
            Q[state][action] += learning_rate * (td_target - Q[state][action])

            if done:
                # `step` is a 0-based index, so the step count is step + 1.
                y[ep] = step + 1
                break

            state = next_state
        else:
            # The cap was hit without finishing; record the cap instead of
            # leaving a misleading 0.
            y[ep] = max_steps

    return x, y, Q


In [8]:
# x, y, Q = Q_learning(env, 1000, 0.01, .99, .05)

In [9]:
# plt.figure(figsize=(18, 5))
# plt.plot(x, y)

# plt.xlabel("episodes")
# plt.ylabel("steps required")

In [10]:
# state, done = env.reset(), False
# env.render()
# rewards = []

# while not done:
#     action = np.argmax(Q[state])
#     # print(f"Taking Action: {action}")
#     # actions.append(action)
#     state, reward, done, _ = env.step(action)
#     # env.render()

#     rewards.append(reward)


In [52]:
class Agent:
    """Tabular Q-learning agent with an epsilon-greedy policy and linear epsilon decay.

    Args:
        lr: step size of the Q-value update.
        gamma: discount factor.
        n_actions: number of discrete actions.
        n_states: number of discrete states (keys ``0 .. n_states - 1``).
        eps_start: initial exploration rate.
        eps_end: lower bound of the exploration rate.
        eps_dec: amount subtracted from epsilon after every ``learn`` call.
    """

    def __init__(
        self, lr, gamma, n_actions, n_states, eps_start, eps_end, eps_dec
    ) -> None:
        self.lr = lr
        self.gamma = gamma
        self.n_actions = n_actions
        self.n_states = n_states
        self.epsilon = eps_start
        self.eps_end = eps_end
        self.eps_dec = eps_dec

        # One array of action values per state, drawn from a standard normal.
        self.Q = {s: np.random.randn(self.n_actions) for s in range(self.n_states)}

        # Number of learn() calls so far.
        self.total_steps = 0

    def _epsilon_greedy_policy(self, state):
        """Return the epsilon-greedy action probabilities for ``state``."""
        # Every action, the greedy one included, gets epsilon / n_actions, so
        # the divisor is n_actions (not n_actions - 1) and the entries sum to 1.
        probs = np.ones(self.n_actions) * self.epsilon / self.n_actions

        best_action = np.argmax(self.Q[state])
        probs[best_action] += 1 - self.epsilon

        return probs

    def choose_action(self, state):
        """Sample an action for ``state`` from the epsilon-greedy distribution."""
        probs = self._epsilon_greedy_policy(state)
        return np.random.choice(np.arange(self.n_actions), p=probs)

    def eps_update(self):
        """Decrease epsilon by ``eps_dec`` without going below ``eps_end``."""
        self.epsilon = max(self.epsilon - self.eps_dec, self.eps_end)

    # Original (misspelled) name, kept so existing callers keep working.
    eps_udpate = eps_update

    def learn(self, state, action, reward, state_, done=False):
        """Apply one Q-learning update for the transition and decay epsilon.

        Args:
            state: state the action was taken in.
            action: action that was taken.
            reward: reward received.
            state_: state that was reached.
            done: pass True when ``state_`` is a true terminal state. The
                target is then just ``reward``. This matters because the
                Q-values of terminal states are random at initialisation and
                are never updated, so bootstrapping from them would add noise
                to the target. Defaults to False, which always bootstraps.
        """
        self.total_steps += 1

        future_value = 0.0 if done else self.gamma * np.amax(self.Q[state_])
        self.Q[state][action] += self.lr * (
            (reward + future_value) - self.Q[state][action]
        )

        self.eps_update()


In [54]:
from tqdm import tqdm


# Train the tabular agent on FrozenLake and plot how often it reaches the goal.
env = gym.make("FrozenLake-v1")

# Epsilon is decayed by eps_dec on every learn() call, so it drops from
# eps_start to eps_end after 90 calls.
agent = Agent(
    lr=0.001,
    gamma=0.99,
    n_actions=env.action_space.n,
    n_states=env.observation_space.n,
    eps_start=0.1,
    eps_end=0.01,
    eps_dec=1e-3,
)

ngames = 50_000
scores = []  # total reward of every episode
win_pct = []  # mean score of the latest (up to) 100 episodes, sampled every 100 episodes

for i in tqdm(range(ngames)):
    state = env.reset()
    done = False
    score = 0

    while not done:
        action = agent.choose_action(state)
        next_state, reward, done, _ = env.step(action)
        agent.learn(state, action, reward, next_state)

        # Periodic peek at the action values of the state just updated.
        if not agent.total_steps % 1000:
            print(f"Q[{state}]: {agent.Q[state]}")

        score += reward
        state = next_state

    scores.append(score)
    if not i % 100:
        average = np.mean(scores[-100:])
        win_pct.append(average)


plt.figure(figsize=(18, 5))
# Both curves use one point per 100 episodes so they share the x axis.
plt.plot(scores[::100])
plt.plot(win_pct)
plt.legend(["Scores", "Win Percentages"])
plt.show()


  1%|          | 503/50000 [00:00<00:19, 2534.13it/s]

Q[8]: [-1.05641711 -1.25800186 -0.66419868  0.27465538]
Q[4]: [-0.85861905  0.17171838 -0.79147759 -0.96480823]
Q[9]: [-1.105696   -0.101317    1.08191194  0.63386408]
Q[6]: [-0.29655543  0.38539145  2.03970749  0.05484432]


  3%|▎         | 1306/50000 [00:00<00:18, 2572.70it/s]

Q[3]: [-0.073238   -0.99465814  1.48234467 -0.60721787]
Q[4]: [-0.85741122  0.43183953 -0.788412   -0.95997756]
Q[4]: [-0.85611448  0.47713673 -0.78644296 -0.95997756]
Q[0]: [-0.31364068  0.94795345 -0.00097085 -0.93631903]


  4%|▎         | 1830/50000 [00:00<00:18, 2575.06it/s]

Q[0]: [-3.13640683e-01  8.94859904e-01 -1.01262660e-04 -9.32652649e-01]
Q[1]: [ 0.90075047  0.19119475 -0.6507247   0.18386366]
Q[0]: [-3.12446141e-01  8.38153126e-01 -1.01262660e-04 -9.27263977e-01]
Q[0]: [-3.11314722e-01  8.24912638e-01  5.40826491e-04 -9.27263977e-01]


  5%|▌         | 2597/50000 [00:01<00:19, 2442.86it/s]

Q[9]: [-1.0978026  -0.10088687  1.19257495  0.63424034]
Q[4]: [-0.8480244   0.72214814 -0.78150582 -0.94931054]
Q[0]: [-3.10191958e-01  8.18976518e-01  5.40826491e-04 -9.25526463e-01]


  6%|▌         | 3096/50000 [00:01<00:19, 2460.32it/s]

Q[0]: [-3.09117778e-01  8.21262270e-01  5.40826491e-04 -9.23663232e-01]
Q[1]: [ 0.95194483  0.1928004  -0.64551717  0.18632446]
Q[8]: [-1.04666679 -1.25194645 -0.65989138  0.69706452]


  7%|▋         | 3590/50000 [00:01<00:19, 2351.63it/s]

Q[10]: [-2.18676118  0.27550662 -0.9106375   0.82888905]
Q[0]: [-0.30682156  0.85722435  0.00482537 -0.91490773]
Q[0]: [-0.30682156  0.86639073  0.00567818 -0.91302903]


  8%|▊         | 4090/50000 [00:01<00:18, 2425.57it/s]

Q[13]: [-1.65393187 -0.66707895  1.90816053 -0.37837019]
Q[8]: [-1.04293853 -1.24946448 -0.65834734  0.79554199]
Q[0]: [-0.29974252  0.89419423  0.00936031 -0.91302903]
Q[0]: [-0.29974252  0.90151859  0.0102528  -0.91302903]


 10%|▉         | 4883/50000 [00:01<00:18, 2506.33it/s]

Q[1]: [ 0.99983851  0.19786819 -0.63448316  0.18698036]
Q[8]: [-1.03731521 -1.24740953 -0.65644736  0.86327634]
Q[4]: [-0.82978938  0.95377232 -0.76919167 -0.92964636]
Q[0]: [-0.29732731  0.9371232   0.01770274 -0.9055868 ]


 11%|█         | 5397/50000 [00:02<00:17, 2513.91it/s]

Q[4]: [-0.82806332  0.97138369 -0.76749065 -0.92964636]
Q[0]: [-0.29610044  0.95157628  0.01770274 -0.9055868 ]
Q[4]: [-0.82806332  0.98711526 -0.76749065 -0.92775348]
Q[1]: [ 1.02517386  0.19995095 -0.62422187  0.1937963 ]


 12%|█▏        | 6196/50000 [00:02<00:16, 2614.18it/s]

Q[0]: [-0.29485017  0.97274706  0.02349259 -0.90002303]
Q[4]: [-0.82628459  1.01593834 -0.76210986 -0.9258453 ]
Q[4]: [-0.82628459  1.0246435  -0.76210986 -0.9258453 ]
Q[9]: [-1.08336073 -0.09492182  1.26572539  0.63478672]


 13%|█▎        | 6716/50000 [00:02<00:17, 2541.62it/s]

Q[6]: [-0.29559783  0.38618668  1.61417536  0.05605925]
Q[1]: [ 1.04712454  0.20577563 -0.62241703  0.19462416]
Q[14]: [1.46846653 0.80570792 0.15331257 0.7162826 ]
Q[0]: [-0.28713462  1.01576196  0.03039346 -0.89232285]


 15%|█▌        | 7532/50000 [00:02<00:15, 2659.39it/s]

Q[9]: [-1.08058271 -0.09202535  1.2715346   0.63512225]
Q[0]: [-0.28450928  1.02600172  0.03340424 -0.8884429 ]
Q[0]: [-0.28316574  1.0312837   0.03544192 -0.88463353]
Q[8]: [-1.02970682 -1.2269118  -0.6479003   1.05731113]


 16%|█▌        | 8061/50000 [00:03<00:16, 2582.38it/s]

Q[4]: [-0.82256594  1.08220793 -0.74932898 -0.91964187]
Q[4]: [-0.8169459   1.08601172 -0.74932898 -0.91964187]
Q[0]: [-0.27919471  1.04997011  0.03644367 -0.88271994]
Q[4]: [-0.8131794   1.0942077  -0.74572927 -0.9176425 ]


 18%|█▊        | 8837/50000 [00:03<00:16, 2522.17it/s]

Q[0]: [-0.27651127  1.0571728   0.03851266 -0.87886989]
Q[8]: [-1.02549721 -1.22033992 -0.64599615  1.09185641]
Q[1]: [ 1.08909895  0.21744988 -0.605704    0.19713591]
Q[8]: [-1.02370923 -1.21538788 -0.64599615  1.10254425]


 19%|█▊        | 9341/50000 [00:03<00:16, 2459.85it/s]

Q[0]: [-0.26847441  1.07125312  0.04059199 -0.87497465]
Q[0]: [-0.26847441  1.07463165  0.04271207 -0.87497465]
Q[1]: [ 1.09820411  0.21841305 -0.6037844   0.19913839]


 20%|█▉        | 9836/50000 [00:03<00:16, 2430.70it/s]

Q[14]: [1.43641101 0.80697174 0.1535462  0.7162826 ]
Q[8]: [-1.02158141 -1.21341001 -0.64234591  1.12433915]
Q[0]: [-0.26713265  1.08548982  0.04791628 -0.865258  ]


 21%|██        | 10333/50000 [00:04<00:16, 2449.49it/s]

Q[4]: [-0.80371484  1.12689001 -0.73100277 -0.90739143]
Q[0]: [-0.26302538  1.0907229   0.0499946  -0.86329717]
Q[4]: [-0.79995067  1.13151931 -0.73100277 -0.90131982]


 22%|██▏       | 11125/50000 [00:04<00:14, 2596.19it/s]

Q[0]: [-0.26302538  1.09569318  0.05315693 -0.86329717]
Q[4]: [-0.79806565  1.13616001 -0.7291881  -0.89721693]
Q[1]: [ 1.1155146   0.22796343 -0.5932408   0.20183231]
Q[8]: [-1.0112311  -1.21341001 -0.63906098  1.14761285]
Q[0]: [-0.25755075  1.10455353  0.05735951 -0.85741161]


 24%|██▍       | 11903/50000 [00:04<00:14, 2555.83it/s]

Q[4]: [-0.79617863  1.14372756 -0.72545684 -0.89522903]
Q[1]: [ 1.12147526  0.22891609 -0.59146694  0.20294669]
Q[4]: [-0.79041326  1.14665998 -0.72358892 -0.88909156]
Q[0]: [-0.25075621  1.11129836  0.06265499 -0.85350622]


 25%|██▍       | 12425/50000 [00:04<00:14, 2522.46it/s]

Q[1]: [ 1.12634966  0.22891609 -0.58235294  0.20406021]
Q[8]: [-1.00908289 -1.20869633 -0.63906098  1.15945636]





KeyboardInterrupt: 