In [None]:
import random
from collections import defaultdict
from typing import Tuple, List
import torch
import torch.nn as nn

## Intro to RL

### MDP Recap
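A Markov decision process (MDP) is a tuple $(\mathcal{S}, \mathcal{A}, P, R, \gamma)$ of states, actions, transition probabilities, rewards, and a discount factor. The goal is to find a policy that maximizes the expected discounted return $\mathbb{E}\left[\sum_t \gamma^t r_t\right]$.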

### Q-learning
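We learn an action-value table $Q(s,a)$ with the *off-policy* Bellman optimality update (off-policy because the target uses the greedy action at $s'$, regardless of which action the exploratory policy actually takes next):
$$Q(s,a) \leftarrow (1 - \alpha)\, Q(s,a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') \right],$$
where the $\max_{a'}$ term is taken to be $0$ when $s'$ is terminal.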
We select actions $\epsilon$-greedily. With probability $\epsilon$ we take a uniformly random (exploratory) action; with probability $1-\epsilon$ we take the action with the highest $Q(s,a)$, breaking ties at random. During training, $\epsilon$ decays linearly from `eps_start` to `eps_end`. The hyperparameters $\alpha \in (0,1]$ and $\gamma \in (0,1]$ are the learning rate and discount factor, respectively.

In [None]:
# ----- Gridworld -----
class GridWorld:
    """Deterministic n x n gridworld with a single terminal goal cell.

    States are ``(row, col)`` tuples. Actions are ``0..3`` = up, right, down, left.
    A move that would leave the grid keeps the agent where it is. Every step
    costs ``-1.0``; the step that enters ``goal`` yields ``+10.0`` and sets ``done``.

    Args:
        n: Side length of the square grid (must be >= 1).
        start: Cell the agent is placed in by ``reset()``.
        goal: Terminal cell.

    Raises:
        ValueError: If ``n < 1`` or ``start``/``goal`` lie outside the grid.
            An unreachable goal would otherwise make every episode run forever.
    """

    # Action id -> (row delta, col delta).
    _MOVES = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}

    def __init__(self, n=4, start=(0,0), goal=(3,3)):
        if n < 1:
            raise ValueError(f"grid size must be >= 1, got {n}")
        for label, cell in (("start", start), ("goal", goal)):
            if not self._in_bounds(cell, n):
                raise ValueError(f"{label}={cell} is outside the {n}x{n} grid")
        self.n = n
        # Normalise to tuples: states are tuples, so a list goal would never
        # compare equal to the current state and episodes would never end.
        self.start = tuple(start)
        self.goal = tuple(goal)
        self.state = self.start
        self.actions = [0, 1, 2, 3]  # up, right, down, left

    @staticmethod
    def _in_bounds(cell, n):
        """True if ``cell`` = (row, col) lies inside an n x n grid."""
        r, c = cell
        return 0 <= r < n and 0 <= c < n

    def reset(self):
        """Put the agent back on ``start`` and return that state."""
        self.state = self.start
        return self.state

    def step(self, a: int):
        """Apply action ``a`` and return ``(next_state, reward, done, info)``.

        Raises:
            ValueError: If ``a`` is not one of ``self.actions``.
        """
        if a not in self._MOVES:
            raise ValueError(f"unknown action {a!r}; expected one of {self.actions}")
        dr, dc = self._MOVES[a]
        r, c = self.state
        # Clamp to the grid: bumping into a wall leaves the agent in place.
        r = min(max(r + dr, 0), self.n - 1)
        c = min(max(c + dc, 0), self.n - 1)

        next_state = (r, c)
        done = next_state == self.goal
        reward = 10.0 if done else -1.0
        self.state = next_state
        return next_state, reward, done, {}

    def state_space(self) -> List[Tuple[int, int]]:
        """All n*n cells in row-major order."""
        return [(r, c) for r in range(self.n) for c in range(self.n)]

    def action_space(self) -> List[int]:
        """A copy of the action ids, so callers cannot mutate the env's list."""
        return list(self.actions)


# ----- Q-Learning ------
def epsilon_greedy(Q, s, actions, eps):
    """Choose an action for state ``s`` with an epsilon-greedy rule.

    With probability ``eps`` a uniformly random action is returned (explore).
    Otherwise the action with the largest ``Q[(s, a)]`` is returned (exploit).
    Ties are broken uniformly at random so that unexplored states do not
    always favour the same action.
    """
    explore = random.random() < eps
    if explore:
        return random.choice(actions)

    scored = [(Q[(s, act)], act) for act in actions]
    top_value = max(value for value, _ in scored)
    tied = [act for value, act in scored if value == top_value]
    return random.choice(tied)

def _linear_epsilon(ep, eps_start, eps_end, eps_decay_episodes):
    """Epsilon for episode ``ep`` under a linear decay schedule clipped at ``eps_end``."""
    if eps_decay_episodes <= 0:
        # A zero-length decay window means the schedule is already finished.
        return eps_end
    progress = ep / eps_decay_episodes
    return max(eps_end, eps_start - (eps_start - eps_end) * progress)


def train_q_learning(
    env,
    episodes=2000,
    alpha=0.1,
    gamma=0.99,
    eps_start=1.0,
    eps_end = 0.05,
    eps_decay_episodes=1500,
    max_steps=None,
):
    """Train a tabular Q-learning agent on ``env`` with linearly decaying epsilon.

    Args:
        env: Environment exposing ``reset() -> state``,
            ``step(a) -> (next_state, reward, done, info)`` and
            ``action_space() -> list of actions``.
        episodes: Number of training episodes.
        alpha: Learning rate in (0, 1].
        gamma: Discount factor in (0, 1].
        eps_start: Exploration rate at episode 0.
        eps_end: Floor that epsilon decays to.
        eps_decay_episodes: Number of episodes over which epsilon decays
            linearly from ``eps_start`` to ``eps_end``. A value <= 0 means no
            decay period, so ``eps_end`` is used from the first episode.
        max_steps: Optional cap on steps per episode. ``None`` (the default)
            runs each episode until ``done``. Set a cap to guarantee
            termination when the goal may be hard or impossible to reach.

    Returns:
        ``(Q, returns)``. ``Q`` is a ``defaultdict(float)`` keyed by
        ``(state, action)``, and ``returns`` is the undiscounted sum of
        rewards for each episode.
    """
    Q = defaultdict(float)
    actions = env.action_space()
    returns = []

    for ep in range(episodes):
        eps = _linear_epsilon(ep, eps_start, eps_end, eps_decay_episodes)
        s = env.reset()
        done = False
        G = 0.0
        steps = 0

        # Hitting max_steps is a truncation, not a terminal state: every update
        # made so far bootstrapped normally, so no special handling is needed.
        while not done and (max_steps is None or steps < max_steps):
            a = epsilon_greedy(Q, s, actions, eps)
            s_next, r, done, _ = env.step(a)
            G += r
            steps += 1

            # Q-learning update. A terminal next state has no future value;
            # otherwise bootstrap from the greedy value of s_next (off-policy).
            max_next = 0.0 if done else max(Q[(s_next, a2)] for a2 in actions)
            td_target = r + gamma * max_next
            Q[(s, a)] += alpha * (td_target - Q[(s, a)])

            s = s_next

        returns.append(G)

    return Q, returns

# ----- Derive a greedy policy and visualize it -----
ARROWS = {0: "↑", 1: "→", 2: "↓", 3: "←"}

def greedy_policy(Q, env):
    """Map every state to the glyph of its greedy action.

    Args:
        Q: Mapping ``(state, action) -> value``, e.g. the table returned by
            ``train_q_learning``. Missing pairs are treated as 0.0, and ``Q``
            is only read, never modified.
        env: Environment exposing ``state_space()``, ``action_space()`` and ``goal``.

    Returns:
        dict mapping each state to an arrow from ``ARROWS``, or ``"G"`` for the goal.
    """
    policy = {}
    for s in env.state_space():
        if s == env.goal:
            policy[s] = "G"
            continue
        # argmax over actions: max() ranks actions by their Q-value and returns
        # the FIRST action with the highest value. Ties, including never-visited
        # states where every value is 0.0, therefore resolve to the earliest
        # action in action_space().
        # Q.get is used instead of Q[...] so that reading a defaultdict does
        # not insert placeholder entries into the learned table.
        best_a = max(env.action_space(), key=lambda a: Q.get((s, a), 0.0))
        policy[s] = ARROWS[best_a]
    return policy

def print_policy(pi, n=4):
    """Print policy ``pi`` as an ``n`` x ``n`` grid, one space-separated row per line.

    ``pi`` maps ``(row, col)`` to a display string and must contain every cell
    with ``0 <= row, col < n``; a missing cell raises ``KeyError``.
    """
    for row in range(n):
        cells = (pi[(row, col)] for col in range(n))
        print(" ".join(cells))

# --- Test time ---
# Seed Python's global RNG, which drives epsilon_greedy's exploration and
# tie-breaking, so the policy and returns below are reproducible on
# Restart & Run All.
SEED = 0
random.seed(SEED)

env = GridWorld(n=10, start=(1,3), goal=(7,9))
Q, returns = train_q_learning(env, episodes=3000, alpha=0.1, gamma=0.99)
pi = greedy_policy(Q, env)
print("Greedy policy after training: ")
print_policy(pi, n=env.n)

# Average over the last (up to) 100 episodes. Dividing by the actual window
# size keeps the figure correct when fewer than 100 episodes were run.
window = returns[-100:]
avg_return = sum(window) / len(window) if window else float("nan")
print(f"Avg return (last {len(window)}): ", avg_return)


Greedy policy after training: 
→ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓
→ → → ↓ ↓ ↓ ↓ ↓ ↓ ↓
→ → → → ↓ ↓ → ↓ ↓ ↓
→ → → → → → → ↓ ↓ ↓
→ ↓ ↓ ↓ → → → ↓ ↓ ↓
↓ ↓ ↓ → ↓ ↓ → → ↓ ↓
→ → → → → → → → ↓ ↓
→ → → → → → → → → G
→ ↑ → ↑ → ↑ ↑ ↑ ↑ ↑
→ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑
Avg return (last 100):  -1.76
