In [None]:
# default_exp bridge

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#hide
# ensure the core library is reloaded automatically whenever it is changed
%load_ext autoreload
%autoreload 2

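# Bridge

`Bridge` connects an `Agent` to an ignite `Engine`: it draws batches from the agent's replay buffer and performs one Q-learning optimisation step per batch against the agent's target network.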

In [None]:
#export
from bfh_mt_hs2020_rl_basics.agent import Agent

from typing import Iterable, Tuple, List
import numpy as np

from ignite.engine import Engine

from ptan.experience import ExperienceFirstLast

import torch
import torch.nn as nn
from torch.optim import Optimizer, Adam
from torch import device


class Bridge:
    """Glue between an `Agent` and an ignite `Engine` for DQN-style training.

    Draws mini-batches from the agent's replay buffer, computes the one-step
    Bellman MSE loss against the agent's target network and applies one
    optimizer step per batch.
    """

    def __init__(self, agent: "Agent", device: device, optimizer: Optimizer = None,
                 learning_rate: float = 0.0001,
                 gamma: float = 0.9,
                 initial_population: int = 1000,
                 batch_size: int = 32):
        """
        Args:
            agent: agent providing `net`, `tgt_net`, `buffer`, `selector` and
                `iteration_completed`.
            device: torch device the batch tensors are moved to.
            optimizer: optimizer for `agent.net`. When None, an Adam optimizer
                over `agent.net.parameters()` with `learning_rate` is created.
            learning_rate: learning rate of the default Adam optimizer
                (ignored when `optimizer` is given).
            gamma: discount factor used in the Bellman target.
            initial_population: number of experiences added to the buffer
                before the first batch is sampled.
            batch_size: number of experiences per sampled batch.
        """
        self.gamma = gamma
        self.initial_population = initial_population
        self.batch_size = batch_size

        self.device = device
        self.agent = agent

        # Use the caller's optimizer when given, otherwise default to Adam.
        if optimizer is not None:
            self.optimizer = optimizer
        else:
            self.optimizer = Adam(self.agent.net.parameters(), lr=learning_rate)

    def batch_generator(self):
        """Yield replay-buffer batches forever.

        First fills the buffer with `initial_population` experiences, then
        adds one fresh experience before sampling each `batch_size` batch.
        """
        self.agent.buffer.populate(self.initial_population)
        while True:
            self.agent.buffer.populate(1)
            yield self.agent.buffer.sample(self.batch_size)

    def process_batch(self, engine: "Engine", batch: List["ExperienceFirstLast"]):
        """Run one optimisation step on `batch` (ignite process function).

        Returns:
            dict with the scalar `"loss"` of this step and the current
            `"epsilon"` of the agent's action selector.
        """
        self.optimizer.zero_grad()
        loss_v = self._calc_loss(batch)

        loss_v.backward()
        self.optimizer.step()

        # Let the agent react to progress (e.g. target sync / epsilon decay).
        self.agent.iteration_completed(engine.state.iteration)

        return {
            "loss": loss_v.item(),
            "epsilon": self.agent.selector.epsilon,
        }

    def _calc_loss(self, batch: List["ExperienceFirstLast"]):
        """Return the MSE between Q(s, a) and the Bellman target
        r + gamma * max_a' Q_target(s', a'), with the future term zeroed for
        terminal transitions."""
        states, actions, rewards, dones, next_states = self._unpack_batch(batch)

        states_v = torch.as_tensor(states, device=self.device)
        next_states_v = torch.as_tensor(next_states, device=self.device)
        actions_v = torch.as_tensor(actions, device=self.device)
        rewards_v = torch.as_tensor(rewards, device=self.device)
        done_mask = torch.as_tensor(dones, dtype=torch.bool, device=self.device)

        # gather() picks, per row, the Q-value of the action actually taken;
        # it requires an int64 index, which _unpack_batch guarantees.
        state_action_vals = self.agent.net(states_v).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)

        with torch.no_grad():
            next_state_vals = self.agent.tgt_net.target_model(next_states_v).max(1)[0]
            # Terminal transitions have no successor state: no future value.
            next_state_vals[done_mask] = 0.0

        bellman_vals = next_state_vals.detach() * self.gamma + rewards_v
        return nn.MSELoss()(state_action_vals, bellman_vals)

    def _unpack_batch(self, batch: List["ExperienceFirstLast"]):
        """Split experiences into aligned numpy arrays.

        Returns:
            (states, actions[int64], rewards[float32], dones[uint8], last_states).
            For terminal experiences (`last_state is None`) the start state is
            used as a placeholder last state; it is masked out in the loss.
        """
        states, actions, rewards, dones, last_states = [], [], [], [], []

        for exp in batch:
            state = np.array(exp.state)
            states.append(state)
            actions.append(exp.action)
            rewards.append(exp.reward)
            is_done = exp.last_state is None
            dones.append(is_done)
            last_states.append(state if is_done else np.array(exp.last_state))

        # np.asarray instead of np.array(..., copy=False): the latter raises
        # under NumPy >= 2 whenever a copy is needed (always, for a list).
        return (np.asarray(states),
                np.array(actions, dtype=np.int64),
                np.array(rewards, dtype=np.float32),
                np.array(dones, dtype=np.uint8),
                np.asarray(last_states))

In [None]:
from bfh_mt_hs2020_rl_basics.agent import Agent
from bfh_mt_hs2020_rl_basics.env import CarEnv

def basic_init() -> Bridge:
    """Build a `Bridge` for the tests below: a fresh `CarEnv` agent on the CPU
    with gamma 0.9 and a replay buffer of size 1000."""
    agent = Agent(CarEnv(), gamma=0.9, buffer_size=1000)
    return Bridge(agent, torch.device("cpu"), gamma=0.9)

In [None]:
def simple_experiences() -> List[ExperienceFirstLast]:
    """Return two hand-made transitions for the tests: a non-terminal one and
    a terminal one (`last_state=None`)."""
    non_terminal = ExperienceFirstLast(
        state=np.array([0.0, 0.0, 0.0], dtype=np.float32),
        action=np.int64(0),
        reward=1.0,
        last_state=np.array([0.5, 0.5, 0.5], dtype=np.float32),
    )
    terminal = ExperienceFirstLast(
        state=np.array([1.0, 1.0, 1.0], dtype=np.float32),
        action=np.int64(1),
        reward=2.0,
        last_state=None,
    )
    return [non_terminal, terminal]

In [None]:
def test_init():
    """Smoke test: the bridge can be built and is fully configured."""
    bridge = basic_init()
    assert bridge is not None
    # __init__ always sets an optimizer (Adam when none is passed).
    assert bridge.optimizer is not None
    assert bridge.gamma == 0.9

In [None]:
def test_unpack():
    """_unpack_batch splits experiences into aligned arrays and substitutes the
    start state for a missing (terminal) last state."""
    bridge = basic_init()
    batch = simple_experiences()
    states, actions, rewards, dones, last_states = bridge._unpack_batch(batch)

    assert states.shape == (2, 3)
    assert actions.tolist() == [0, 1]
    assert rewards.dtype == np.float32
    assert rewards.tolist() == [1.0, 2.0]
    assert dones.tolist() == [0, 1]
    np.testing.assert_array_equal(last_states[0], [0.5, 0.5, 0.5])
    # terminal transition: last state falls back to the start state
    np.testing.assert_array_equal(last_states[1], states[1])

In [None]:
def test_calc_loss():
    """_calc_loss returns a finite, non-negative, differentiable scalar."""
    bridge = basic_init()
    batch = simple_experiences()
    loss = bridge._calc_loss(batch)

    assert loss.dim() == 0
    assert torch.isfinite(loss).item()
    assert loss.item() >= 0.0
    # process_batch calls backward() on this value
    assert loss.requires_grad

In [None]:
from ignite.engine import Engine

def test_process_batch():
    """process_batch runs one optimisation step and reports loss and epsilon."""
    bridge = basic_init()
    batch = simple_experiences()
    result = bridge.process_batch(Engine(bridge.process_batch), batch)

    assert set(result) == {"loss", "epsilon"}
    assert isinstance(result["loss"], float)
    assert result["epsilon"] == bridge.agent.selector.epsilon

In [None]:
# Basic tests, executed in order
for _test in (test_init, test_unpack, test_calc_loss, test_process_batch):
    _test()