In [1]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_
import torch.distributions as td

import numpy as np
from numpy.random import normal, randint, choice

import scipy as sp
from scipy.special import softmax

import matplotlib.pyplot as plt
from collections import defaultdict

In [2]:
def batch_to_one_hot(s, S):
    """
    Arguments
    ---------
    s : torch.tensor [bs, 1]
    S : int
    
    Returns
    -------
    s_OH : torch.tensor [bs, S]
    """
    s_OH = torch.zeros((len(s), S))
    s_OH[range(len(s)), s] = 1
    return s_OH

In [3]:
def reverse_one_hot(x):
    """
    Arguments
    ---------
    x : torch.tensor[bs, d_x]
    
    Returns
    x_int : torch.tensor[bs]
    """
    x_int = x.argmax(-1)
    return x_int

In [4]:
def t_every(s, s_t):
    """
    Arguments
    ---------
    s : torch.tensor [T]
    s_t : int
    
    Returns
    -------
    t_every : list(int, int, ...)
    """
    return [t.item() for t in np.argwhere(s[:-1] == s_t).squeeze(0)]

def t_first(s, s_t):
    """
    Arguments
    ---------
    s : torch.tensor [T]
    s_t : int
    
    Returns
    -------
    t_first : list(int)
    """
    return [t_every(s, s_t)[0]]

In [5]:
def loss_fn(V_pred, V_gt):
    """
    Arguments
    ---------
    G : torch.tensor[1]
    V : torch.tensor[1]

    Returns
    -------
    MSE : torch.tensor[1]
    """
    MSE = ((V_pred - V_gt) ** 2).mean()
    return MSE

In [6]:
class MDP():
    """
    Attributes
    ----------
    S : int
        number of states
        
    d_x : int
        number of state features
        
    A : int
        number of actions
    
    P : torch.tensor [S, A, S]
        transition probabilities
        
    P_dist : td.Categorical
        used to sample new state when agent performs action
        
    R : torch.tensor [S, A]
        rewards
        
    gamma : float
        discount factor
        
        
    Property
    --------
    pi_opt
    V_opt
    Q_opt
    """
    def __init__(self, S, A):
        self.A = A
        self.S = S
        self.d_x = S
        self.P = nn.Softmax(-1)(torch.randn((S,A,S)))
        self.P_dist = td.Categorical(probs = self.P)
        
        self.R = torch.randint(0,10,size = (S,A))
        
        self.gamma = 0.5
    
    def V_next(self, V, pi):
        """ V(s) - Value-State, k-step

        Arguments
        ---------
        pi : torch.tensor[S, A]
        V : torch.tensor[S]

        Returns
        -------
        V : torch.tensor[S]
        """
        V_next = torch.stack([self.R[s, pi[s].argmax()] \
                            + self.gamma * (self.P[s, pi[s].argmax()] * V).sum()
                              for s in range(self.S)])
        return V_next 
    
    
    def Q_next(self, V, a):
        """ Q(s,a) - Action-Value-State, k-step

        Arguments
        ---------
        a : int
        V : torch.tensor[S]

        Returns
        -------
        Q_next : torch.tensor[S]
        """
        Q_next = torch.stack([
            self.R[s, a] + self.gamma * (self.P[s, a] * V).sum()
            for s in range(self.S)
        ])
        return Q_next
    
    
    def a_next_opt(self, V):
        """
        Arguments
        ---------
        V : torch.tensor[S]

        Returns
        -------
        a_max : torch.tensor[S]
        """
        Q_a = torch.stack([self.Q_next(V, a) for a in range(self.A)], dim = -1)
        a_max = Q_a.argmax(-1)
        return a_max
    
    
    def V(self, pi):
        """ V(s) - State-Value
        Arguments
        ---------
        pi : torch.tensor [S,A]
        
        Returns
        -------
        V : torch.tensor[S]
        """
        k = 0
        while True:
            if k == 0:
                V = torch.zeros(self.S)
            if k >= 1:
                V_prev = V
                V = self.V_next(V, pi)
                if np.linalg.norm(V - V_prev) < 0.01:
                    break
            k+=1
        return V  
    
    
    def Q(self, a, pi):
        """ Q(s,a) - Action-State-Value
        
        Arguments
        ---------
        pi : torch.tensor [S,A]
        a : int
        
        Returns
        -------
        Q : np.array[S]
        """
        Q = torch.stack([self.R[s,a] \
                       + self.gamma * (self.P[s,a] * self.V(pi)).sum() 
                         for s in range(self.S)])
        return Q
    
    @property
    def pi_opt(self):
        k=0
        while True:
            if k == 0:
                V = torch.zeros(self.S)
                # the initial pi is not necessary for the algorithm
                # it just makes it easer to define when to stop the iteration
                pi = nn.Softmax(-1)(torch.randn((self.S, self.A)))
            if k >= 1:
                pi_prev = pi
                a_opt = self.a_next_opt(V)
                pi = batch_to_one_hot(a_opt, self.A)
                V = self.V_next(V, pi)
                
                if torch.equal(pi, pi_prev):
                    break  
            k+=1
        return pi
    
    @property
    def V_opt(self):
        return self.V(self.pi_opt)
    
    @property
    def Q_opt(self):
        return torch.stack([self.Q(a, self.pi_opt) for a in range(self.A)], dim = 1)

In [7]:
class MDPWorld(MDP):
    def __init__(self, S, A):
        super(MDPWorld, self).__init__(S, A)
      
    @property
    def initial_state(self):
        """
        Returns
        -------
        s : torch.tensor[1]
        """
        return torch.tensor([0])
    
    def step(self, s, a):
        """
        Arguments
        ---------
        s : torch.tensor[1]
        a : torch.tensor[1]

        Returns
        -------
        r : torch.tensor[1]
        s : torch.tensor[1]
        """
        r = self.R[s,a]
        s_next = self.P_dist.sample()[s,a]
        return r, s_next

In [8]:
class Q(nn.Module):
    def __init__(self, d_x, S, A):
        super().__init__()
        self.S = S
        self.A = A
        self.MLP = nn.Sequential(
            nn.Linear(d_x, A)
        )
        
    def forward(self, s, a = None):
        """
        Arguments
        ---------
        s : torch.tensor[bs]
        a : torch.tensor[bs]
        max : bool
        
        Returns
        -------
        Q : torch.tensor[bs, A]
        or 
        Q_a : torch.tensor[bs]
        or 
        Q_max : torch.tensor[bs], a_max 
        """
        x = batch_to_one_hot(s, self.S)
        Q = self.MLP(x)
        
        if a != None:
            Q_a = Q[range(len(a)), a]
            return Q_a 
        return Q
    
    def greedy(self, s, eps = 0.):
        """
        Arguments
        ---------
        s : torch.tensor[1]
        eps : float
        
        Returns
        -------
        Q_max : torch.tensor[1]
        pi_s : torch.tensor[A]
        """
        Q = self(s)
        a_max = Q.argmax(-1)
        Q_max = Q[range(len(a_max)), a_max]
        pi_s = torch.ones(self.A) * (eps/(self.A-1))
        pi_s[a_max] = 1-eps
        return Q_max, a_max, pi_s

In [55]:
class Agent(nn.Module):
    def __init__(self):
        super().__init__()
        self.x = 0
        
    def place(self, world):
        """place the agent into the world and initialize Q function and state
        
        Arguments
        ---------
        world : World
        
        Returns
        -------
        -
        """
        self.world = world
        self.Q = Q(d_x = world.d_x, S = world.S, A = world.A)
        self.Q_target = Q(d_x = world.d_x, S = world.S, A = world.A)
        self.s = world.initial_state
        self.pi = td.Categorical(probs = nn.Softmax(-1)(torch.randn((world.S, world.A))))
        
    def take_action(self):
        """
        Returns
        -------
        list(s,a,r,s)
        s : torch.tensor[1]
        a : torch.tensor[1]
        r : torch.tensor[1]
        s_next : torch.tensor[1]
        """
        e = list()
        a = self.pi.sample()[self.s]
        r,s_next = self.world.step(self.s, a)
        
        e.append(self.s)
        e.append(a)
        e.append(r)
        e.append(s_next)
        
        self.s = s_next
        return e
    
    def take_episode(self, T):
        """Sample s,a,r for T timesteps
        
        Arguments
        ---------
        T : int
        
        Returns
        -------
        s : torch.tensor[T,   1]
        a : torch.tensor[T-1, 1]
        r : torch.tensor[T-1, 1]
        """
        s = list()
        a = list()
        r = list()
        for t in range(T):
            s_t,a_t,r_t,s_t_next = self.take_action()
            s.append(s_t)
            a.append(a_t)
            r.append(r_t)
        s.append(s_t_next)
        return (torch.stack(s),
                torch.stack(a), 
                torch.stack(r))
    
    def freeze_weights(self, model):
        for w in model.parameters():
            w.requires_grad = False
            
    def unfreeze_weights(self, model):
        for w in model.parameters():
            w.requires_grad = True
        
    def update_Q_target(self):
        w_target = list(self.Q_target.parameters())
        for i, w in enumerate(self.Q.parameters()):
            w_target[i].copy_(w)
            
    def learn_step(self, t, e):
        s,a,r,s_next = e
        # greedy
        _, a_max, pi_s_next = self.Q.greedy(s_next, eps = (1/(t+1))**0.5)

        # update policy
        self.pi.probs[s_next] = pi_s_next

        # loss
        Y = r + self.world.gamma * self.Q_target(s, a_max)
        Q_sa = self.Q(s,a)
        loss = (Y - Q_sa)**2
        return loss

In [162]:
class DoubleDQN(Agent):
    def __init__(self):
        super(DoubleDQN, self).__init__()
        
    def learn(self, T = 10000, K_freeze = 100, lr = 0.01):
        """
        Arguments
        ---------
        T : exploration steps
        K_freeze : update period for Q_target
        lr : learning rate
        """
        optimizer = Adam(self.Q.parameters(), lr)
        self.freeze_weights(self.Q_target)

        for t in range(T):
            # take action
            e = self.take_action()
            loss = self.learn_step(t, e)
            
            # update Q weights
            loss.backward()
            clip_grad_norm_(self.Q.parameters(), 1)
            optimizer.step()
            optimizer.zero_grad()
            
            if t % K_freeze == 0:
                self.update_Q_target()  
                
        self.pi.probs = self.pi.probs.round()

In [180]:
class DQN_PR(Agent):
    def __init__(self):
        super(DQN_PR, self).__init__()
        self.memory = Memory()
        
    def init_cum_grad(self):
        return [torch.zeros_like(w) for w in self.Q.parameters()]
    
    def update_cum_grad(self, cum_grad, w_i):
        for i, w in enumerate(self.Q.parameters()):
            cum_grad[i] += w_i * w.grad

    def update_Q(self, cum_grad, lr):
        self.freeze_weights(self.Q)
        for i, w in enumerate(self.Q.parameters()):
            w.copy_(w - lr * cum_grad[i])
        self.unfreeze_weights(self.Q)
        
    def learn(self, T = 1000, K_freeze = 100, K = 10, k = 30, lr = 0.01):
        """
        Arguments
        ---------
        T : exploration steps
        K_freeze : update period for Q_target
        K : memory replay period
        k : batchsize
        lr : learning rate
        """
        self.freeze_weights(self.Q_target)
        for t in range(T):
            e = self.take_action()
            p = torch.tensor([1.])
            self.memory.store(t, e, p)
            
            if t % K == 0:
                self.memory.update_dist()
                cum_grad = self.init_cum_grad()
                for j in range(k):
                    e, j = self.memory.sample()
                    loss = self.learn_step(j, e)
                    self.memory.update(j, p = loss.detach().sqrt())
                    
                    loss.backward()
                    clip_grad_norm_(self.Q.parameters(), 1)
                    self.update_cum_grad(cum_grad, w_i = self.memory.w(j)/k)
    
                self.update_Q(cum_grad, lr)
            
            if t % K_freeze == 0:
                self.update_Q_target() 
                
        self.pi.probs = self.pi.probs.round()

In [99]:
class Memory():
    def __init__(self, N = 100, alpha = 0.5, beta = 0.5):
        self.N = N
        
        self.alpha = alpha
        self.beta = beta
        
        self.H = defaultdict(dict)
        self.p_max = torch.tensor([1.])
        
        self.P = None
        
    @property
    def p(self):
        return torch.cat([H_t['p'] for t, H_t in self.H.items()])
    
    def w(self, j):
        P = self.P_dist.probs
        return (P.min() / P[j]) ** self.beta
    
    def store(self, t, e, p):
        self.H[t]['e'] = e
        if p > self.p_max:
            self.p_max = p
        self.H[t]['p'] = self.p_max
    
    def update_dist(self):
        P = self.p ** self.alpha
        P = P/P.sum()
        self.P_dist = td.Categorical(probs = P)
        
    def sample(self):
        j = self.P_dist.sample().item()
        return self.H[j]['e'], j
    
    def update(self, j, p):
        self.H[j]['p'] = p

In [147]:
world = MDPWorld(10,10)

In [181]:
agent = DoubleDQN()
agent.place(world)
agent.learn() # gute nummer: 3000-100

In [182]:
world.V(pi = agent.pi.probs).sum()

tensor(152.1549)

In [152]:
world.V_opt.sum()

tensor(159.2016)