In [None]:
# default_exp agent

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#hide
# ensures that the core library is reloaded again whenever it is modified
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Agent

In [None]:
from bfh_mt_hs2020_rl_basics.env import CarEnv

import gym
import ptan
import torch
import torch.nn as nn

In [None]:
#export
from bfh_mt_hs2020_rl_basics.env import CarEnv

from abc import ABC, abstractmethod
from typing import Iterable, Tuple, List

import torch

class AgentBase(ABC):
    """Common interface of the agents defined in this module.

    The base class only remembers the environment and resolves the torch
    device; everything else has to be provided by concrete subclasses.
    """

    def __init__(self, env: CarEnv, devicestr:str):
        """Store `env` and turn `devicestr` into a `torch.device`.

        Args:
            env: environment the agent acts in.
            devicestr: device string handed to `torch.device`.
        """
        self.device = torch.device(devicestr)
        self.env = env

    @abstractmethod
    def get_net(self):
        """Return the agent's network."""

    @abstractmethod
    def get_tgtnet(self):
        """Return the agent's target network."""

    @abstractmethod
    def get_buffer(self):
        """Return the agent's experience buffer."""

    @abstractmethod
    def iteration_completed(self, iteration: int):
        """Hook invoked with the number of the iteration that just finished."""
    

## Simple Agent

The SimpleAgent has no special improvements concerning training stability.

In [None]:
#export
import torch.nn as nn

class SimpleNet(nn.Module):
    """Fully connected network with a single ReLU hidden layer.

    Maps inputs of size `obs_size` to `n_actions` output values.
    """

    def __init__(self, obs_size, hidden_size, n_actions):
        """Build the Linear -> ReLU -> Linear stack.

        Args:
            obs_size: number of input features.
            hidden_size: width of the hidden layer.
            n_actions: number of output values.
        """
        super().__init__()

        # Layer order is kept so that parameter names under `net` stay the same.
        layers = [
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Cast `x` to float and run it through the layer stack."""
        as_float = x.float()
        return self.net(as_float)

In [None]:
#export
from bfh_mt_hs2020_rl_basics.env import CarEnv

import gym
import ptan
import torch
from torch import device

class SimpleAgent(AgentBase):
    """Agent built from a `SimpleNet`, an epsilon-greedy selector and a plain replay buffer.

    All reinforcement-learning building blocks (target net, action selectors,
    epsilon tracker, DQN agent, experience source and buffer) come from `ptan`.
    """

    def __init__(self, env: CarEnv,
                 devicestr:str,
                 gamma:float,
                 buffer_size:int,
                 target_net_sync:int = 1000,
                 eps_start:float = 1.0,
                 eps_final:float = 0.02,
                 eps_frames:int = 10**5):
        """Wire up network, target network, action selection and replay buffer.

        Args:
            env: environment the experience source steps through.
            devicestr: device string, e.g. "cpu" or "cuda".
            gamma: forwarded as `gamma` to the experience source.
            buffer_size: forwarded as `buffer_size` to the replay buffer.
            target_net_sync: the target net is synced whenever the iteration
                number is a multiple of this value.
            eps_start: forwarded to the epsilon tracker.
            eps_final: forwarded to the epsilon tracker.
            eps_frames: forwarded to the epsilon tracker.
        """
        super().__init__(env, devicestr)

        self.target_net_sync = target_net_sync

        # Network and the target-net wrapper around it.
        self.net = self._config_net()
        self.tgt_net = ptan.agent.TargetNet(self.net)

        # Epsilon-greedy selection on top of an argmax selector; the tracker
        # receives the eps_* settings and the selector it should act on.
        greedy = ptan.actions.ArgmaxActionSelector()
        self.selector = ptan.actions.EpsilonGreedyActionSelector(epsilon=1, selector=greedy)
        self.epsilon_tracker = ptan.actions.EpsilonTracker(
            selector=self.selector,
            eps_start=eps_start,
            eps_final=eps_final,
            eps_frames=eps_frames,
        )

        self.agent = ptan.agent.DQNAgent(self.net, self.selector, device=self.device)

        # Experience source feeding the replay buffer.
        self.exp_source = ptan.experience.ExperienceSourceFirstLast(
            self.env, self.agent, gamma=gamma
        )
        self.buffer = ptan.experience.ExperienceReplayBuffer(
            self.exp_source, buffer_size=buffer_size
        )

    def _config_net(self)-> nn.Module:
        """Create the `SimpleNet` (hidden size 128) for this environment on the agent's device."""
        obs_size = self.env.observation_space.shape[0]
        n_actions = self.env.action_space.n
        return SimpleNet(obs_size, 128, n_actions).to(self.device)

    def iteration_completed(self, iteration: int):
        """Pass `iteration` to the epsilon tracker and sync the target net when due."""
        self.epsilon_tracker.frame(iteration)

        sync_due = iteration % self.target_net_sync == 0
        if sync_due:
            self.tgt_net.sync()

    def get_net(self):
        """Return the network."""
        return self.net

    def get_tgtnet(self):
        """Return the `ptan` target-net wrapper."""
        return self.tgt_net

    def get_buffer(self):
        """Return the experience replay buffer."""
        return self.buffer
    

In [None]:
# Settings shared by the agent smoke tests in this notebook.
GAMMA = 0.9  # passed as `gamma` to the agent constructor
REPLAY_SIZE = 1000  # passed as `buffer_size` to the agent constructor

In [None]:
def test_simpleagent_cpu():
    """Smoke test: a `SimpleAgent` can be constructed on the "cpu" device."""
    print("test cpu")
    # Construction is the whole test; the instance itself is not needed afterwards.
    SimpleAgent(CarEnv(), "cpu", gamma=GAMMA, buffer_size=REPLAY_SIZE)

In [None]:
def test_simpleagent_cuda():
    """Smoke test: a `SimpleAgent` can be constructed on the "cuda" device.

    The test is skipped (with a message) on machines without a usable CUDA
    device, because the agent moves its network to the requested device during
    construction and that would raise there.
    """
    print("test cuda")
    if not torch.cuda.is_available():
        print("cuda not available - test skipped")
        return
    # Construction is the whole test; the instance itself is not needed afterwards.
    SimpleAgent(CarEnv(), "cuda", gamma=GAMMA, buffer_size=REPLAY_SIZE)

In [None]:
# Run the SimpleAgent smoke tests: first on the "cpu" device, then on "cuda".
test_simpleagent_cpu()
test_simpleagent_cuda()

test cpu
test cuda


## Rainbow Agent

The RainbowAgent combines several measures that should increase the stability of the training process.

In [None]:
#export
import torch.nn as nn

class _NoisyLinear(nn.Linear):
    """Linear layer with learnable, independent Gaussian noise on weights and bias.

    `torch.nn` does not ship a noisy linear layer, so it is implemented here.
    The effective parameters of a forward pass are
    `weight + sigma_weight * eps_w` and `bias + sigma_bias * eps_b`, where the
    `eps_*` tensors are drawn from a standard normal distribution on every call.
    """

    def __init__(self, in_features, out_features, sigma_init=0.017, bias=True):
        """Create the layer.

        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            sigma_init: initial value of every noise scale parameter.
            bias: whether the layer has a (noisy) bias.
        """
        # nn.Linear creates weight/bias and calls reset_parameters() below.
        super().__init__(in_features, out_features, bias=bias)

        self.sigma_weight = nn.Parameter(torch.full((out_features, in_features), sigma_init))
        if bias:
            self.sigma_bias = nn.Parameter(torch.full((out_features,), sigma_init))
        else:
            self.register_parameter("sigma_bias", None)

    def reset_parameters(self):
        """Initialise weight and bias uniformly in +-sqrt(3 / in_features)."""
        bound = (3 / self.in_features) ** 0.5
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Apply the linear map with freshly sampled noise."""
        # Noise is sampled out-of-place so that several forward passes may
        # precede a single backward pass without invalidating saved tensors.
        weight = self.weight + self.sigma_weight * torch.randn_like(self.sigma_weight)
        bias = self.bias
        if bias is not None:
            bias = bias + self.sigma_bias * torch.randn_like(self.sigma_bias)
        return nn.functional.linear(x, weight, bias)


class DuelingNet(nn.Module):
    """Dueling network with noisy linear layers.

    Two separate streams are evaluated on the same input: `net_adv` yields one
    advantage per action and `net_val` a single state value. They are combined
    as `val + (adv - mean(adv))`.
    """

    def __init__(self, obs_size, hidden_size, n_actions):
        """Build the advantage and value streams.

        Args:
            obs_size: number of input features.
            hidden_size: width of the hidden layer of each stream.
            n_actions: number of outputs of the advantage stream.
        """
        super(DuelingNet, self).__init__()

        self.net_adv = nn.Sequential(
            _NoisyLinear(obs_size, hidden_size),
            nn.ReLU(),
            _NoisyLinear(hidden_size, n_actions)
        )

        self.net_val = nn.Sequential(
            _NoisyLinear(obs_size, hidden_size),
            nn.ReLU(),
            _NoisyLinear(hidden_size, 1)
        )

    def forward(self, x):
        """Return the combined output for a batched input `x`.

        The advantage mean is taken over dim 1, so `x` must carry a leading
        batch dimension.
        """
        x = x.float()
        val = self.net_val(x)
        adv = self.net_adv(x)

        # Subtracting the per-sample mean advantage centres the advantage stream.
        return val + (adv - adv.mean(dim=1, keepdim=True))

In [None]:
#export
from bfh_mt_hs2020_rl_basics.env import CarEnv

import gym
import ptan
import torch
from torch import device

class RainbowAgent(AgentBase):
    """Agent built from a `DuelingNet`, multi-step experience and a prioritized replay buffer.

    Actions are selected with a plain argmax selector; no epsilon schedule is
    configured here. All reinforcement-learning building blocks (target net,
    selector, DQN agent, experience source and buffer) come from `ptan`.
    """

    def __init__(self, env: CarEnv, 
                 devicestr:str,  
                 gamma:float, 
                 buffer_size:int, 
                 target_net_sync:int = 1000,
                 steps_count:int = 3,
                 prio_replay_alpha:float = 0.6):
        """Wire up network, target network, action selection and replay buffer.

        Args:
            env: environment the experience source steps through.
            devicestr: device string, e.g. "cpu" or "cuda".
            gamma: forwarded as `gamma` to the experience source.
            buffer_size: forwarded as `buffer_size` to the replay buffer.
            target_net_sync: the target net is synced whenever the iteration
                number is a multiple of this value.
            steps_count: forwarded as `steps_count` to the experience source.
            prio_replay_alpha: forwarded as `alpha` to the prioritized buffer.
        """
        # Let the base class store env and device, as SimpleAgent does,
        # instead of duplicating that setup here.
        super().__init__(env, devicestr)

        self.steps_count = steps_count
        self.target_net_sync = target_net_sync

        self.net = self._config_net()
        self.tgt_net = ptan.agent.TargetNet(self.net)

        self.selector = ptan.actions.ArgmaxActionSelector()
        self.agent = ptan.agent.DQNAgent(self.net, self.selector, device=self.device)

        self.exp_source = ptan.experience.ExperienceSourceFirstLast(
            self.env, self.agent, gamma=gamma, steps_count=self.steps_count
        )
        self.buffer = ptan.experience.PrioritizedReplayBuffer(
            self.exp_source, buffer_size=buffer_size, alpha=prio_replay_alpha
        )

    def _config_net(self)-> nn.Module:
        """Create the `DuelingNet` (hidden size 128) for this environment on the agent's device."""
        obs_size = self.env.observation_space.shape[0]
        n_actions = self.env.action_space.n
        return DuelingNet(obs_size, 128, n_actions).to(self.device)

    def iteration_completed(self, iteration: int):
        """Sync the target net when `iteration` is a multiple of `target_net_sync`."""
        if iteration % self.target_net_sync == 0:
            self.tgt_net.sync()

    def get_net(self):
        """Return the network."""
        return self.net

    def get_tgtnet(self):
        """Return the `ptan` target-net wrapper."""
        return self.tgt_net

    def get_buffer(self):
        """Return the prioritized replay buffer."""
        return self.buffer
    