In [None]:
import random
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm_notebook as tqdm
%matplotlib inline

In [None]:
def random_position(max_height, max_width):
    """Pick a uniformly random grid cell.

    Returns a ``(row, col)`` pair with ``0 <= row < max_height`` and
    ``0 <= col < max_width``, drawn from numpy's global RNG (row first).
    """
    row = np.random.randint(0, max_height)
    col = np.random.randint(0, max_width)
    return row, col

In [None]:
def compare(x, y):
    """Return True when ``x`` and ``y`` have the same shape and identical elements."""
    same = np.array_equal(x, y)
    return same

In [None]:
class Gridworld:
    """A small grid world with a player and, optionally, a wall, a pit and a goal.

    ``state`` is an integer array of shape ``(height, width, 4)``. Each cell
    holds one of the one-hot codes ``goal``, ``pit``, ``wall`` or ``player``,
    or all zeros when the cell is empty. Actions are the integers
    ``up=0``, ``right=1``, ``down=2`` and ``left=3``.
    """

    def __init__(self, height, width):
        self.height = height
        self.width = width

        # Snapshot written by save_state() and restored by reset().
        self.initial_state = None
        self.initial_player_pos = None
        self.state = np.zeros((height, width, 4), dtype=int)
        # One-hot cell codes.
        self.player = np.array([0, 0, 0, 1])
        self.wall = np.array([0, 0, 1, 0])
        self.pit = np.array([0, 1, 0, 0])
        self.goal = np.array([1, 0, 0, 0])
        # (row, col) of the player, or None before it is first placed.
        self.player_position = None
        # NOTE(review): not read or written by any method of this class.
        self.world_ = None

        self.up = 0
        self.right = 1
        self.down = 2
        self.left = 3

    def set_player_position(self, position):
        """Move the player to ``position`` (row, col) and clear its previous cell.

        Whatever occupied ``position`` (e.g. the goal or the pit) is overwritten
        by the player code.
        """
        if self.player_position is not None:
            self.state[tuple(self.player_position)] = np.zeros(4)
        self.player_position = tuple(position)
        self.state[tuple(position)] = self.player

    def save_state(self):
        """Record the current grid and player position as the state reset() restores."""
        self.initial_state = np.copy(self.state)
        self.initial_player_pos = np.copy(self.player_position)

    @classmethod
    def deterministic_easy(cls, height=4, width=4):
        """Player at (0, 1), goal at (3, 3); no wall or pit."""
        world = cls(height, width)
        world.set_player_position((0, 1))
        world.state[3, 3] = world.goal
        world.save_state()
        return world

    @classmethod
    def deterministic(cls, height=4, width=4):
        """Player at (0, 1), pit at (1, 1), wall at (2, 2), goal at (3, 3)."""
        world = cls(height, width)
        world.set_player_position((0, 1))
        world.state[2, 2] = world.wall
        world.state[1, 1] = world.pit
        world.state[3, 3] = world.goal
        world.save_state()
        return world

    @classmethod
    def random_player_pos(cls, height=4, width=4):
        """Fixed wall (2, 2), pit (1, 1) and goal (1, 2); the player starts on a random free cell."""
        world = cls(height, width)
        pos = random_position(height, width)
        # Re-draw until the player does not sit on one of the fixed objects.
        while pos in ((2, 2), (1, 1), (1, 2)):
            pos = random_position(height, width)
        world.set_player_position(pos)
        world.state[2, 2] = world.wall
        world.state[1, 1] = world.pit
        world.state[1, 2] = world.goal
        world.save_state()
        return world

    @classmethod
    def random(cls, height=4, width=4):
        """Place player, wall, pit and goal on four distinct, uniformly random cells."""
        world = cls(height, width)
        places = [(i, j) for i in range(height) for j in range(width)]
        positions = random.sample(places, 4)
        world.set_player_position(positions[0])
        world.state[positions[1]] = world.wall
        world.state[positions[2]] = world.pit
        world.state[positions[3]] = world.goal
        world.save_state()
        return world

    def step(self, action):
        """Apply ``action`` and return ``(old_state, reward, new_state, done)``.

        Moving off the grid or into a wall leaves the player where it is.
        Every step costs -1. Entering the pit gives -10 and entering the goal
        gives +10; both end the episode. ``old_state`` is a copy taken before
        the move. ``new_state`` is the live ``self.state`` array, not a copy.
        """
        row, col = (int(v) for v in self.player_position)
        d_row, d_col = 0, 0
        if action == self.up and row > 0:
            d_row = -1
        elif action == self.right and col < self.width - 1:
            d_col = 1
        elif action == self.down and row < self.height - 1:
            d_row = 1
        elif action == self.left and col > 0:
            d_col = -1

        old_pos = (row, col)
        new_pos = (row + d_row, col + d_col)
        done = False
        reward = -1
        destination = self.state[new_pos]
        if compare(destination, self.wall):
            new_pos = old_pos
        elif compare(destination, self.pit):
            done = True
            reward = -10
        elif compare(destination, self.goal):
            done = True
            reward = 10
        old_state = np.copy(self.state)
        self.set_player_position(new_pos)
        return old_state, reward, self.state, done

    def reset(self):
        """Restore the grid saved by save_state() and return the (live) state array.

        Raises:
            RuntimeError: if save_state() has never been called.
        """
        if self.initial_state is None:
            raise RuntimeError("save_state() must be called before reset()")
        self.state = np.copy(self.initial_state)
        # Keep player_position a tuple of ints, exactly as set_player_position() stores it.
        self.player_position = tuple(int(v) for v in self.initial_player_pos)
        return self.state

    def display(self):
        """Return a (height, width) array of one-character cell labels.

        '@' player, 'W' wall, '+' goal, '^' pit, ' ' empty.
        """
        symbols = (
            (self.player, '@'),
            (self.wall, 'W'),
            (self.goal, '+'),
            (self.pit, '^'),
        )
        grid = np.empty((self.height, self.width), dtype=str)
        for i in range(self.height):
            for j in range(self.width):
                cell = self.state[i, j]
                grid[i, j] = next(
                    (char for code, char in symbols if compare(cell, code)), ' '
                )
        return grid

In [None]:
class QDistFunction(nn.Module):
    """Categorical distributional Q-network.

    For a single flattened state it outputs one probability distribution over
    ``n_atoms`` return values per action. The result is the concatenation of
    the per-action distributions and has length ``n_actions * n_atoms``.
    Defaults (64 inputs, 4 actions, 21 atoms) match a 4x4 Gridworld state of
    shape (4, 4, 4) and the support -10..10.
    """

    def __init__(self, n_inputs: int = 64, n_actions: int = 4, n_atoms: int = 21):
        super().__init__()
        self.n_actions = n_actions
        self.n_atoms = n_atoms
        self.l1 = nn.Linear(n_inputs, 164)
        self.l2 = nn.Linear(164, 150)
        self.l3 = nn.Linear(150, n_atoms * n_actions)
        # Normalises each action's logits over its atoms (last dimension).
        self.smax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return the concatenated per-action atom probabilities for one state."""
        x = x.view(-1)  # a single state is flattened; batches are not supported
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        logits = self.l3(x).view(self.n_actions, self.n_atoms)
        return self.smax(logits).reshape(-1)

    @staticmethod
    def _project_target(dist, reward, range_, gamma, done=False):
        """Project the distribution of ``reward + gamma * Z`` onto the atoms ``range_``.

        ``dist[i]`` is the probability of atom ``range_[i]``. ``range_`` is
        assumed ascending and evenly spaced. On terminal transitions
        (``done=True``) the target collapses onto ``reward``, because there is
        no future return to bootstrap. The returned tensor has the same total
        mass as ``dist``.
        """
        n_atoms = len(range_)
        if n_atoms < 2:
            raise ValueError("range_ must contain at least two atoms")
        v_min, v_max = float(range_[0]), float(range_[-1])
        delta = (v_max - v_min) / (n_atoms - 1)
        target = torch.zeros(n_atoms, dtype=dist.dtype)
        for i, z in enumerate(range_):
            tz = reward if done else reward + gamma * z
            tz = min(max(tz, v_min), v_max)
            # Fractional atom index; clamped against floating-point drift at the edges.
            b = min(max((tz - v_min) / delta, 0.0), n_atoms - 1.0)
            lower, upper = int(np.floor(b)), int(np.ceil(b))
            if lower == upper:
                # Exactly on an atom: the linear split below would assign 0 to both sides.
                target[lower] += dist[i]
            else:
                target[lower] += dist[i] * (upper - b)
                target[upper] += dist[i] * (b - lower)
        return target

    def fit_step(
        self,
        old_state,
        new_state,
        action,
        reward,
        range_,
        gamma,
        loss_fun,
        optimizer,
        grad_clip: Optional[float] = None,
        done: bool = False,
    ) -> None:
        """Take one optimiser step towards the projected Bellman target for a transition.

        Args:
            old_state: state tensor in which ``action`` was taken.
            new_state: state tensor reached after the action.
            action: index of the action taken (int or 0-dim tensor).
            reward: immediate reward of the transition.
            range_: atom values (the distribution support), ascending and evenly spaced.
            gamma: discount factor.
            loss_fun: called as ``loss_fun(predicted_probs, target_probs)``.
            optimizer: optimiser over this module's parameters.
            grad_clip: if given, max gradient norm passed to ``clip_grad_norm_``.
            done: True if ``new_state`` is terminal; then no future return is bootstrapped.

        The module is left in eval mode afterwards.
        """
        action = int(action)
        n_atoms = len(range_)
        self.train()
        optimizer.zero_grad()

        # The target is a fixed label, so build it without tracking gradients.
        with torch.no_grad():
            next_dist = self(new_state).view(-1, n_atoms)
            support = torch.as_tensor(range_, dtype=next_dist.dtype)
            # Greedy next action by expected return (first maximum on ties).
            best = int(torch.argmax(next_dist @ support))
            target = self._project_target(next_dist[best], reward, range_, gamma, done)

        outputs = self(old_state)
        predicted = outputs[n_atoms * action:n_atoms * (action + 1)]
        loss = loss_fun(predicted, target)
        loss.backward()
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.parameters(), grad_clip)
        optimizer.step()
        self.eval()

In [None]:
class Agent:
    """Online trainer for a QDistFunction on a Gridworld-like environment.

    The environment must expose ``state``, ``step(action)`` returning
    ``(old_state, reward, new_state, done)``, ``reset()`` and ``display()``.
    """

    def __init__(self, env, epsilon, gamma, lr=0.01, max_steps=100):
        self.q = QDistFunction()
        self.env = env
        self.epsilon = epsilon  # probability of a uniformly random action during training
        self.gamma = gamma
        self.max_steps = max_steps  # cap on steps per episode (training and replay)
        self.loss_fun = nn.BCELoss()
        self.optimizer = torch.optim.Adam(self.q.parameters(), lr)
        self.range_ = list(range(-10, 11))  # atom values of the return distribution

    def _state_tensor(self):
        """Return the environment's current state as a float32 tensor."""
        return torch.tensor(self.env.state, dtype=torch.float32)

    def run_episode(self):
        """Play one episode, updating the network after every step, and return the total reward.

        The episode ends when the environment reports ``done`` or after
        ``max_steps`` steps, whichever comes first. This prevents an endless
        loop when the policy never reaches a terminal cell.
        """
        n_atoms = len(self.range_)
        total_reward = 0.0
        for _ in range(self.max_steps):
            with torch.no_grad():
                qdist = self.q(self._state_tensor()).view(-1, n_atoms)
            n_actions = qdist.shape[0]
            if random.random() < self.epsilon:
                action = random.randrange(n_actions)
            else:
                # Sample one atom per action from its distribution and act greedily on the samples.
                samples = torch.multinomial(qdist, 1).view(-1)
                action = int(samples.argmax())

            old_state, reward, _, done = self.env.step(action)
            total_reward += reward

            self.q.fit_step(
                torch.tensor(old_state, dtype=torch.float32),
                self._state_tensor(),
                action,
                reward,
                self.range_,
                self.gamma,
                self.loss_fun,
                self.optimizer,
            )
            if done:
                break
        return total_reward

    def run_model(self, world=0):
        """Replay the greedy (max expected return) policy from the initial state.

        The grid is printed before the first move and after every move. At
        most ``max_steps`` moves are made. ``world`` is accepted for backward
        compatibility and is not used.
        """
        n_atoms = len(self.range_)
        support = torch.tensor(self.range_, dtype=torch.float32)
        self.env.reset()
        print(self.env.display())
        for _ in range(self.max_steps):
            with torch.no_grad():
                qdist = self.q(self._state_tensor()).view(-1, n_atoms)
                # Expected return per action, sum_z z * p(z); first maximum wins ties.
                action = int(torch.argmax(qdist @ support))

            _, _, _, done = self.env.step(action)
            print(self.env.display())
            if done:
                break

    def train(self, epochs=1000):
        """Run ``epochs`` training episodes and return their total rewards as a pandas Series."""
        self.env.reset()
        rewards = np.zeros(epochs)
        for i in tqdm(range(epochs)):
            rewards[i] = self.run_episode()
            self.env.reset()
        return pd.Series(rewards)

In [None]:
# Seed every RNG the notebook draws from (python `random`, numpy, torch) so a
# Restart & Run All draws the same random numbers each time.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

agent = Agent(
    Gridworld.deterministic_easy(),
    epsilon=0.1,
    gamma=0.9,
)

In [None]:
N_EPISODES = 2000  # number of training episodes
rewards = agent.train(epochs=N_EPISODES)

In [None]:
# Running mean of the per-episode total reward, to show whether returns improve during training.
ax = rewards.expanding().mean().plot()
ax.set(
    title="Running mean of episode reward",
    xlabel="Episode",
    ylabel="Mean total reward",
)
plt.show()

In [None]:
# Replay the trained agent once from the initial grid, printing each move.
agent.run_model(world=0)