# CartPole mit DQN

## Aufgabe 3
Löse das CartPole-v0-Environment mittels DQN (Deep Q-Network).

In [None]:
%run ../setup.ipynb

In [None]:
from lib.statistics import plot
import time
import numpy as np
from collections import deque
from contextlib import suppress

def interact_with_environment(env, agent, n_episodes=400, max_steps=200, train=True, verbose=True):
    """Run ``agent`` in ``env`` for several episodes and collect per-episode statistics.

    Args:
        env: Gym-style environment with ``reset()`` and ``step(action)``; ``step`` must
            return the 4-tuple ``(next_state, reward, done, info)``.
        agent: Object providing ``act(state)``, ``train(transition)`` and an ``epsilon``
            attribute (``epsilon`` is only read when ``verbose`` is true).
        n_episodes: Number of episodes to run.
        max_steps: Upper bound on environment steps per episode.
        train: If true, each transition ``(state, action, next_state, reward, done)``
            is passed to ``agent.train`` right after the step.
        verbose: If true, print a progress line every 10 episodes.

    Returns:
        A list with one dict per finished episode, with keys ``'episode'`` (index),
        ``'score'`` (summed reward) and ``'steps'`` (number of environment steps taken).
        A ``KeyboardInterrupt`` ends the run early; the statistics collected so far
        are still returned.
    """
    statistics = []

    # Lets the user interrupt a long (training) run from the notebook without
    # losing the statistics gathered up to that point.
    with suppress(KeyboardInterrupt):
        for episode in range(n_episodes):
            total_reward = 0
            # Count steps explicitly: the loop index is one less than the number of
            # steps taken and is unbound when max_steps == 0.
            steps = 0
            state = env.reset()
            episode_start_time = time.perf_counter()

            for _ in range(max_steps):
                action = agent.act(state)
                next_state, reward, done, _info = env.step(action)

                if train:
                    agent.train((state, action, next_state, reward, done))

                state = next_state
                total_reward += reward
                steps += 1

                if done:
                    break

            if verbose and episode % 10 == 0:
                elapsed = time.perf_counter() - episode_start_time
                # Guard against a zero interval on very short episodes.
                speed = steps / elapsed if elapsed > 0 else float('inf')
                print(f'episode: {episode}/{n_episodes}, score: {total_reward}, steps: {steps}, '
                      f'e: {agent.epsilon:.3f}, speed: {speed:.2f} steps/s')

            statistics.append({
                'episode': episode,
                'score': total_reward,
                'steps': steps
            })

    return statistics

### 3.1
Implementiere in **agent.py** einen Agenten, der in der Lage ist, das CartPole-Environment zu lösen.

In [None]:
import gym
# Create the CartPole-v0 environment used for training and evaluation below.
env = gym.make('CartPole-v0')

In [None]:
from agent import DQN

# Sizes of the discrete action space and of the observation vector, taken from the env.
state_size = env.observation_space.shape[0]
action_size = env.action_space.n

# Hyperparams (consumed by agent.DQN; their exact use is defined in agent.py)
epsilon = 1.0            # initial epsilon (presumably the exploration rate)
epsilon_min = 0.01       # lower bound for epsilon
annealing_steps = 1000  # not episodes!
# Linear decrement: `annealing_steps` decrements take epsilon from `epsilon`
# down to `epsilon_min`.
epsilon_decay = (epsilon - epsilon_min) / annealing_steps
gamma = 0.95             # presumably the discount factor
alpha = 0.01             # presumably the learning rate
batch_size = 32
memory_size = 10000      # presumably the replay-memory capacity
start_replay_step = 2000
target_model_update_interval = 1000
# NOTE(review): annealing_steps (1000) < start_replay_step (2000). If the agent
# decays epsilon from the very first step, epsilon reaches epsilon_min before
# any replay training starts — confirm against agent.py.

agent = DQN(
    state_size=state_size,
    action_size=action_size,
    gamma=gamma,
    epsilon=epsilon,
    epsilon_min=epsilon_min,
    epsilon_decay=epsilon_decay,
    alpha=alpha,
    batch_size=batch_size,
    memory_size=memory_size,
    start_replay_step=start_replay_step,
    target_model_update_interval=target_model_update_interval,
)

# Train with the default episode/step limits and plot the learning curve.
statistics = interact_with_environment(env, agent, verbose=True)
plot(statistics)

In [None]:
from gym.wrappers import Monitor
# Capture a video of every evaluation episode; force=True cleans the './video'
# folder before each run.
env = Monitor(env, './video', video_callable=lambda episode_id: True, force=True)
try:
    # Evaluation only: transitions are not passed to agent.train.
    # NOTE(review): agent.act may still pick random actions according to
    # agent.epsilon — confirm whether exploration should be disabled here.
    statistics = interact_with_environment(env, agent, n_episodes=10, train=False, verbose=False)
finally:
    # Monitor finalizes the last episode's video recorder and flushes its
    # manifest only on close(); without it the final recording can be incomplete.
    env.close()
plot(statistics, y_limits=(0,200))