In [None]:
# Libs
import logging
import numpy as np
from plotly.subplots import make_subplots
from plotly import graph_objects as go
import shap
import torch
# Custom
from cartpole import CartPoleRegulatorEnv
from NFQagents import NFQagent, ReplayMemory, Transition


In [None]:
# %% Set up Logger
# Log to "lastRun.log"; the file is truncated ('w') every time this cell runs.
logger = logging.getLogger('lastRun')
logger.setLevel(logging.INFO)
# logging.getLogger returns the same logger object on every call, so re-running
# this cell would stack another FileHandler on it and duplicate every log line.
# Detach and close the handlers of a previous run first.
for oldHandler in list(logger.handlers):
    logger.removeHandler(oldHandler)
    oldHandler.close()
fh = logging.FileHandler("lastRun.log", 'w')
fh.setFormatter(logging.Formatter(
    '[%(levelname)s] - %(asctime)s - %(message)s'))
logger.addHandler(fh)

# Test-System

    Description:
        A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track. The pendulum starts upright, and the goal is to prevent it from falling over by increasing and reducing the cart's velocity.
    Source:
        This environment corresponds to the version of the cart-pole problem described by Barto, Sutton, and Anderson
    Observation: 
        Type: Box(4)
        Num	Observation                 Min         Max
        0	Cart Position             -4.8            4.8
        1	Cart Velocity             -Inf            Inf
        2	Pole Angle                 -24 deg        24 deg
        3	Pole Velocity At Tip      -Inf            Inf
        
    Actions:
        Type: Discrete(2)
        Num	Action
        0	Push cart to the left
        1	Push cart to the right
        
        Note: The amount by which the velocity is reduced or increased is not fixed; it depends on the angle at which the pole is pointing. This is because the pole's center of gravity changes the amount of energy needed to move the cart underneath it.
    Reward:
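        Reward is 1 for every step taken, including the termination step. (Note: the CartPoleRegulatorEnv used in this notebook instead returns a cost that is to be minimized, parameterized by cMax and cFix.)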
    Starting State:
        All observations are assigned a uniform random value in [-0.05..0.05]
    Episode Termination:
        Pole Angle is more than 12 degrees
        Cart Position is more than 2.4 (center of the cart reaches the edge of the display)
        Episode length is greater than 200
        Solved Requirements
        Considered solved when the average reward is greater than or equal to 195.0 over 100 consecutive trials.

# Selbstlernender Regler mittels _NFQ-Learning_ / _DQN-Learning_
Der selbstlernende Regler wird mittels Neural Fitted Q Iteration (NFQ) realisiert. Die Implementierung erfolgt in Anlehnung an die Dissertation von Dr. Hafner [1] sowie an die Veröffentlichung von Prof. Riedmiller [2].

Das klassische NFQ-Verfahren ist ein Offline-Lernverfahren, bei dem ein Multi-Layer-Perceptron wiederholt auf derselben Datenbasis trainiert wird. Ziel dabei ist es zu lernen, welche Aktion im jeweiligen Zustand den größten langfristigen Erfolg verspricht (Q-Funktion). Die Kosten/Belohnungen, welche Teil des Trainingsdatensatzes sind, dürfen dabei nicht kumuliert werden, sondern beziehen sich immer auf die aktuell gewählte Aktion.

## Q-Funktion am _Cartpole_ Beispiel
Gäbe es für ein gegebenes System eine Zuordnung, die jedem möglichen Zustand und jeder möglichen Aktion die entstehenden Kosten (Minimierungsaufgabe) bzw. den entstehenden Gewinn (Maximierungsaufgabe) zuweist, wäre es möglich, dieses System optimal zu steuern. Wäre eine solche Zuordnung beispielsweise für das Spiel Schach bekannt, wäre es möglich, jedes Spiel optimal zu spielen, da in jedem Zug die optimale Entscheidung getroffen werden kann. Leider ist diese Zuordnung bereits für Systeme mit überschaubaren Regeln, wie im Schach-Beispiel, zu umfangreich, um sie sinnvoll zu speichern oder gar anzuwenden. Daher wurden Verfahren entwickelt, um diese Zuordnung zu approximieren. Dabei erfolgt die Approximation durch eine *Q-Funktion*. Beim NFQ-Verfahren wird die Q-Funktion durch ein vorwärtsgerichtetes künstliches neuronales Netz umgesetzt. Für das Cartpole-Beispiel sieht dies so aus:

![Neuronales Netz zur Approximation der Q-Funktion im Cartpole-Bsp.](QfunNet_Cartpole.svg)

## Konvergenzverhalten
Durch das iterative Trainingsverfahren läuft das Training nicht stabil ab. Es wird daher keine kontinuierliche Verbesserung erreicht; vielmehr kann es vorkommen, dass ein bereits guter Regler wieder schlechter wird. In [1] werden diese Effekte näher erläutert und Methoden, um sie abzumildern, beschrieben. Im vorliegenden Beispiel sind diese Methoden bereits berücksichtigt.

## Erweiterung
In Anlehnung an das neuere Deep-Q-Learning wurde für das NFQ-Verfahren eine Online-Komponente implementiert. Das heißt, der Speicher, welcher zum Training des neuronalen Netzes verwendet wird, wird zur Trainingszeit durch aktuelle Erfahrungswerte erweitert. Zusätzlich besteht die Möglichkeit, das KNN, welches die Systemsteuerung übernimmt, von dem KNN, welches die Zielwerte für das Training berechnet, zu trennen. Das Zielnetz wird dabei nur in größeren Abständen an das kontinuierlich trainierte Hauptnetz angepasst. Dies führt zu einer besseren Konvergenz des Lernalgorithmus.

## Literatur
[1] R. Hafner, „Dateneffiziente selbstlernende neuronale Regler“, Universität Osnabrück, Osnabrück, 2009.

[2] M. Riedmiller, „Neural Fitted Q Iteration – First Experiences with a Data Efficient Neural Reinforcement Learning Method“, in Machine Learning: ECML 2005, Bd. 3720, J. Gama, R. Camacho, P. B. Brazdil, A. M. Jorge, und L. Torgo, Hrsg. Berlin, Heidelberg: Springer Berlin Heidelberg, 2005, S. 317–328. doi: 10.1007/11564096_32.


## Agent mit Erfahrungsspeicher und -modell

In [None]:
render = True  # render the last evaluation run of every renderStep-th epoch
renderStep = 10
shapStep = 101  # compute SHAP values every shapStep epochs
shapSamples = 100  # number of memory transitions explained per SHAP computation
# hyper parameter
nEpoch = 800  # max. number of training epochs per training attempt

nBatch = 800  # batch size for agent.batchTrainModel; also the per-episode step limit in fillMemoryExplore
nEval = 3  # repetitions per evaluation step

memCapacity = 16000 # number of transitions to store in experience memory
eps0 = 0.3  # exploration rate (epsilon) handed to the agent

# max transition cost
cMax = 1.
cFix = 0.  # small reward for each successful transition
cStep = 0.016  # see "viewCostFunction.py"

# Separate environment instances for training data collection and evaluation
train_env = CartPoleRegulatorEnv(cMax, cFix, mode="train")
eval_env = CartPoleRegulatorEnv(cMax, cFix, mode="eval")

# init agent
# Q-value bounds handed to the agent: slightly over the max. possible costs,
# to prevent max. values in sigmoid activation of output layer
# (NOTE(review): presumably used to scale the network output — confirm in NFQagents)
nSteps = max(train_env.max_steps, eval_env.max_steps)  # longest possible episode
QbMin =  -nSteps * cStep + nSteps * cFix  # - nSteps * max cost of one step
QbMax =  cMax + nSteps * cStep + nSteps * cFix  # cost on fail + nSteps * max cost of one step
agent = NFQagent(4, 2, QbMin, QbMax, eps0)  # 4 state variables, 2 discrete actions (see Test-System)
memory = ReplayMemory(memCapacity)

# define helper functions
def evaluate(agent, env, memory=None, render=False):
    """Run one episode with the agent and return the accumulated cost.

    Args:
        agent: agent providing ``act(state)`` and an ``Epsilon`` attribute.
        env: environment providing ``reset()``, ``step(action)``,
            ``render()`` and ``max_steps``.
        memory: optional replay memory. If given, every transition
            ``(state, action, next_state, cost, done)`` is pushed into it
            (online extension of NFQ) and the agent's exploration is left
            untouched. If ``None``, exploration is switched off by setting
            ``agent.Epsilon = 0.``; this setting is not restored afterwards.
        render: if True, render the environment after every non-terminal step.

    Returns:
        tuple ``(costs, stepIdx)``: summed cost of the episode and the index of
        the last executed step (``env.max_steps - 1`` when the episode lasted
        until the step limit).
    """
    # switch off exploration, if no memory update
    if memory is None:
        agent.Epsilon = 0.
    # Compare against None explicitly: a memory object that defines __len__ is
    # falsy while empty, and `if memory:` would then silently skip all pushes.
    recordTransitions = memory is not None
    state = env.reset()
    costs = 0.

    for stepIdx in range(env.max_steps):
        if recordTransitions:
            state_old = state.copy()
        aIdx = agent.act(state)
        state, cost, done, info = env.step(aIdx)

        if recordTransitions:
            memory.push(torch.FloatTensor(state_old),
                        aIdx,
                        torch.FloatTensor(state),
                        cost, float(done))
        costs += cost
        if done:
            break
        if render:
            env.render()

    return costs, stepIdx

def fillMemoryExplore(nActions, memory, agent, train_env, nSteps=None):
    """Fill the replay memory with transitions collected by the agent.

    Episodes are run in ``train_env`` until ``memory.capacity`` transitions
    have been pushed. Actions are chosen by ``agent.act``, so the amount of
    exploration is controlled by the agent's epsilon, which the caller is
    expected to raise before calling this function.

    Args:
        nActions: number of discrete actions. Currently unused; kept so that
            existing calls keep working.
        memory: replay memory providing ``capacity`` and
            ``push(state, action, next_state, cost, done)``.
        agent: agent providing ``act(state)``.
        train_env: environment providing ``reset()`` and ``step(action)``.
        nSteps: max. number of steps per episode. Defaults to the global
            ``nBatch`` (the previously hard-coded limit).

    Raises:
        ValueError: if ``nSteps`` < 1 (no transition could ever be collected).
    """
    if nSteps is None:
        nSteps = nBatch
    if nSteps < 1:
        raise ValueError("nSteps must be at least 1, got {}".format(nSteps))

    fills = 0
    while fills < memory.capacity:
        state = train_env.reset()
        for _ in range(nSteps):
            aIdx = agent.act(state)
            next_state, cost, done, info = train_env.step(aIdx)
            # Push every transition exactly once and flag it terminal only if
            # the episode actually ended. Pushing the last transition after
            # the loop would record a bogus state == next_state transition
            # whenever an episode reaches the step limit without terminating.
            memory.push(torch.FloatTensor(state),
                        aIdx,
                        torch.FloatTensor(next_state),
                        cost, float(done))
            fills += 1
            if done:
                break
            state = np.copy(next_state)

def callNN(data):
    """Evaluate the agent's network on a numpy batch and return numpy outputs.

    SHAP's KernelExplainer calls the model with numpy arrays and expects numpy
    results, whereas ``agent.model`` works on torch tensors. This wrapper
    converts in both directions. It uses the module-level ``agent``.

    Args:
        data: numpy array of network inputs, one sample per row.

    Returns:
        numpy array holding the network outputs (detached from the autograd graph).
    """
    inputs = torch.from_numpy(data).float()
    outputs = agent.model(inputs)
    return outputs.detach().numpy()


## Training

In [None]:
# prepare Visualisation
# Buffers backing the live training plot. The training loop writes into them
# in place (slice assignment / *= 0.), so every name below must be its own array.
xEpochs = np.arange(nEpoch)
# per-epoch curves: training loss, successful evaluations, mean evaluation cost
loss, evalSucc, cost = (np.zeros(nEpoch) for _ in range(3))
# trailing mean and +-std band of the cost (plotted from epoch 50 on)
cMean, cStdU, cStdL = (np.zeros(nEpoch - 50) for _ in range(3))
# observe agents min and max Q-Values
Qmin, Qmax = (np.zeros(nEpoch) for _ in range(2))
# shap values (influence of nn inputs to outputs), one entry per SHAP epoch;
# suffix L = output "push left", suffix R = output "push right"
xShap = xEpochs[::shapStep]
(sPositionL, sVelocityL, sAngleL, sTipVelocityL,
 sPositionR, sVelocityR, sAngleR, sTipVelocityR) = (np.zeros(xShap.size)
                                                    for _ in range(8))

# Three stacked panels sharing the epoch x-axis:
#   row 1: evaluation cost (primary y) and number of successful evaluations (secondary y)
#   row 2: agent's min/max Q-values (primary y) and training loss (secondary y)
#   row 3: mean absolute SHAP values per state input and action
fig = make_subplots(rows=3, cols=1,
                    shared_xaxes=True,
                    vertical_spacing=0.02,
                    specs=[[{'secondary_y': True}], [{'secondary_y': True}], 
                           [{'secondary_y': False}]])
fig.update_xaxes(title_text="number of Epoch", row=3, col=1)
# NOTE: update_yaxes without secondary_y selects both the primary and the
# secondary y-axis of the row; the following secondary_y=True call then
# overrides the secondary axis. Keep this call order.
fig.update_yaxes(row=1, col=1, title_text="Cost", autorange=True)
fig.update_yaxes(row=1, col=1, title_text="Evaluation Success", 
                 range=[0., nEval], secondary_y=True, showgrid=False,
                 autorange=False)
# yaxis2 = secondary y-axis of row 1
fig['layout']['yaxis2']['showgrid'] = False
# NOTE(review): with autorange=True the explicit range [QbMin, QbMax] is not
# applied; drop autorange if fixed Q-value limits are wanted.
fig.update_yaxes(row=2, col=1, title_text="Q Values", autorange=True,
                 range=[QbMin, QbMax])
fig.update_yaxes(row=2, col=1, title_text="ANN Loss", autorange=True, 
                 secondary_y=True, showgrid=False)
# yaxis4 = secondary y-axis of row 2
fig['layout']['yaxis4']['showgrid'] = False
fig.update_yaxes(row=3, col=1, title_text="Shap Value", autorange=True)

# Trace definitions. The "yaxis" entries are placeholders: fig.add_trace(...,
# row=..., secondary_y=...) assigns each trace to its subplot's axis.
# Row 1: evaluation cost per epoch (faint), its trailing mean and +-std band.
lineCost = go.Scatter({"x": xEpochs,
                       "y": cost,
                       "opacity": 0.25,
                       "name": "cost",
                       "uid": "uid_lineCost",
                       "yaxis": "y1",
                       "line": {"color": "#000000",
                                "width": 1
                               }
                       })
lineCostM = go.Scatter({"x": xEpochs[50:],
                        "y": cMean,
                        "name": "cost mean",
                        "uid": "uid_lineCostM",
                        "yaxis": "y1",
                        "line": {"color": "#000000",
                                 "width": 1
                                }
                        })
lineCostSU = go.Scatter({"x": xEpochs[50:],
                         "name": "cost",
                         "y": cStdU,
                         "uid": "uid_lineCostSU",
                         "yaxis": "y1",
                         "line": {"color": "#000000",
                                  "width": 0.5,
                                  "dash": "dot"
                                 },
                         "showlegend": False
                         })
lineCostSL = go.Scatter({"x": xEpochs[50:],
                         "y": cStdL,
                         "name": "cost",
                         "uid": "uid_lineCostSL",
                         "yaxis": "y1",
                         "line": {"color": "#000000",
                                  "width": 0.5,
                                  "dash": "dot"
                                 },
                         "showlegend": False
                         })
# Row 1, secondary axis: number of successful evaluation runs (0..nEval).
lineEval = go.Scatter({"x": xEpochs,
                       "y": evalSucc,
                       "name": "evalSucc",
                       "uid": "uid_lineEval",
                       "yaxis": "y2",
                       "line": {"color": "#9CCC66",
                                "width": 2
                               }
                       })
# Row 2: min/max Q-values reported by the agent during training.
lineQmin = go.Scatter({"x": xEpochs,
                       "y": Qmin,
                       "name": "Qmin",
                       "uid": "uid_lineQmin",
                       "yaxis": "y1",
                       "line": {"color": "#CC4F4F",
                                "width": 1
                                }
                       })
lineQmax = go.Scatter({"x": xEpochs,
                       "y": Qmax,
                       "name": "Qmax",
                       "uid": "uid_lineQmax",
                       "yaxis": "y1",
                       "line": {"color": "#1E70CC",
                                "width": 1
                                }
                       })
# Row 2, secondary axis: training loss.
lineLoss = go.Scatter({"x": xEpochs,
                       "y": loss,
                       "name": "loss",
                       "uid": "uid_lineLoss",
                       "yaxis": "y2",
                       "line": {"color": "#D83C20",
                                "width": 1,
                                "dash": "dot"
                               }
                       })
# Row 3: mean absolute SHAP value of each state input, for the "push left"
# output (solid lines, suffix L) and the "push right" output (dotted, suffix R).
line_sPositionL = go.Scatter({"x": xShap,
                              "y": sPositionL,
                              "name": "Position->Left",
                              "uid": "uid_line_sPositionL",
                              "yaxis": "y3",
                              "line": {"color": "#047de7",
                                       "width": 1,
                                       }},
                              legendgroup='Shap Values',
                              legendgrouptitle_text='Shap Values'
                              )
line_sVelocityL = go.Scatter({"x": xShap,
                              "y": sVelocityL,
                              "name": "Velocity->Left",
                              "uid": "uid_line_sVelocityL",
                              "yaxis": "y3",
                              "line": {"color": "#2d3ab4",
                                       "width": 1,
                                       }},
                              legendgroup='Shap Values',
                              legendgrouptitle_text='Shap Values'
                              )
line_sAngleL = go.Scatter({"x": xShap,
                           "y": sAngleL,
                           "name": "Angle->Left",
                           "uid": "uid_line_sAngleL",
                           "yaxis": "y3",
                           "line": {"color": "#7ec71f",
                                    "width": 1,
                                    }},
                           legendgroup='Shap Values',
                           legendgrouptitle_text='Shap Values'
                           )
line_sTipVelocityL = go.Scatter({"x": xShap,
                                 "y": sTipVelocityL,
                                 "name": "TipVelocity->Left",
                                 "uid": "uid_line_sTipVelocityL",
                                 "yaxis": "y3",
                                 "line": {"color": "#ada505",
                                          "width": 1,
                                         }},
                                 legendgroup='Shap Values',
                                 legendgrouptitle_text='Shap Values'
                                 )
line_sPositionR = go.Scatter({"x": xShap,
                              "y": sPositionR,
                              "name": "Position->Right",
                              "uid": "uid_line_sPositionR",
                              "yaxis": "y3",
                              "line": {"color": "#047de7",
                                       "width": 1,
                                       "dash": "dot"
                                       }},
                              legendgroup='Shap Values',
                              legendgrouptitle_text='Shap Values'
                              )
line_sVelocityR = go.Scatter({"x": xShap,
                              "y": sVelocityR,
                              "name": "Velocity->Right",
                              "uid": "uid_line_sVelocityR",
                              "yaxis": "y3",
                              "line": {"color": "#2d3ab4",
                                       "width": 1,
                                       "dash": "dot"
                                       }},
                              legendgroup='Shap Values',
                              legendgrouptitle_text='Shap Values'
                              )
line_sAngleR = go.Scatter({"x": xShap,
                           "y": sAngleR,
                           "name": "Angle->Right",
                           "uid": "uid_line_sAngleR",
                           "yaxis": "y3",
                           "line": {"color": "#7ec71f",
                                    "width": 1,
                                    "dash": "dot"
                                    }},
                           legendgroup='Shap Values',
                           legendgrouptitle_text='Shap Values'
                           )
line_sTipVelocityR = go.Scatter({"x": xShap,
                                 "y": sTipVelocityR,
                                 "name": "TipVelocity->Right",
                                 "uid": "uid_line_sTipVelocityR",
                                 "yaxis": "y3",
                                 "line": {"color": "#ada505",
                                          "width": 1,
                                          "dash": "dot"
                                         }},
                                 legendgroup='Shap Values',
                                 legendgrouptitle_text='Shap Values'
                                 )
fig.update_layout(title='', height=800)
# Attach the traces to their subplots. The order of these calls defines the
# indices into cEndWidget.data used below, so keep both lists in sync.
fig.add_trace(lineCost, row=1, col=1)
fig.add_trace(lineCostM, row=1, col=1)
fig.add_trace(lineCostSU, row=1, col=1)
fig.add_trace(lineCostSL, row=1, col=1)
fig.add_trace(lineEval, row=1, col=1, secondary_y=True)
fig.add_trace(lineQmin, row=2, col=1)
fig.add_trace(lineQmax, row=2, col=1)
fig.add_trace(lineLoss, row=2, col=1, secondary_y=True)
fig.add_trace(line_sPositionL, row=3, col=1)
fig.add_trace(line_sVelocityL, row=3, col=1)
fig.add_trace(line_sAngleL, row=3, col=1)
fig.add_trace(line_sTipVelocityL, row=3, col=1)
fig.add_trace(line_sPositionR, row=3, col=1)
fig.add_trace(line_sVelocityR, row=3, col=1)
fig.add_trace(line_sAngleR, row=3, col=1)
fig.add_trace(line_sTipVelocityR, row=3, col=1)
# FigureWidget so that the training loop can update the plot in place
cEndWidget = go.FigureWidget(fig)
# get direct connection to line data (overwrite variables used for lines)
lineCost = cEndWidget.data[0]
lineCostM = cEndWidget.data[1]
lineCostSU = cEndWidget.data[2]
lineCostSL = cEndWidget.data[3]
lineEval = cEndWidget.data[4]
lineQmin = cEndWidget.data[5]
lineQmax = cEndWidget.data[6]
lineLoss = cEndWidget.data[7]
line_sPositionL = cEndWidget.data[8]
line_sVelocityL = cEndWidget.data[9]
line_sAngleL = cEndWidget.data[10]
line_sTipVelocityL = cEndWidget.data[11]
line_sPositionR = cEndWidget.data[12]
line_sVelocityR = cEndWidget.data[13]
line_sAngleR = cEndWidget.data[14]
line_sTipVelocityR = cEndWidget.data[15]

In [None]:
display(cEndWidget)

# RUN TRAINING
# Each pass of the outer loop is one training attempt:
#   - refill the replay memory with exploration data,
#   - reset the plot buffers,
#   - train for up to nEpoch epochs.
# An attempt is aborted (and restarted) if no convergence is visible after half
# of the epochs. Training ends once every evaluation run of an epoch reaches
# the step limit.
train = True
while train:
    # get data
    # Explore fully while filling the memory. The attribute is `Epsilon`, as in
    # evaluate() and the demo cell; the former spelling `Epsilion` only created
    # an unused attribute, so exploration was never actually raised.
    agent.Epsilon = 1.
    fillMemoryExplore(2, memory, agent, train_env)
    agent.Epsilon = eps0
    # Scaling of the 4 state variables (position, velocity, angle, tip
    # velocity); the SHAP code below divides states by it to scale them for
    # the NN. NOTE(review): agent.updateStateStats may overwrite it — confirm.
    agent.std = np.array([4.8, 10., 0.5*np.pi, 10.], dtype=np.float32)
    # Enlarge the memory for the online phase (evaluation transitions are
    # pushed into it). NOTE(review): fillMemoryExplore fills up to
    # memory.capacity, so a restarted attempt collects up to this larger amount.
    memory.updateCapacity(int(memCapacity*2))

    # reinit vis arrays (in place, the plot traces are bound to them)
    cost *= 0.
    cMean *= 0.
    cStdU *= 0.
    cStdL *= 0.
    evalSucc *= 0
    Qmin *= 0.
    Qmax *= 0.
    loss *= 0.
    sPositionL *= 0.
    sVelocityL *= 0.
    sAngleL *= 0.
    sTipVelocityL *= 0.
    sPositionR *= 0.
    sVelocityR *= 0.
    sAngleR *= 0.
    sTipVelocityR *= 0.
    # Index for shap value curves
    sIdx = 0

    for epoch in range(nEpoch):
        # one training pass; returns the loss and the (min, max) Q-values
        l, Qs = agent.batchTrainModel(memory, nBatch)
        agent.updateStateStats(memory)
        cEpoch = 0.
        nSucc = 0
        # Evaluation runs also push their transitions into memory (online
        # NFQ). Only the last repetition of every renderStep-th epoch is rendered.
        for eval_rep in range(1, nEval+1):
            c, steps = evaluate(agent, eval_env, memory,
                                render and eval_rep==nEval and
                                epoch % renderStep == 0)
            cEpoch += c
            # success: the episode lasted until the step limit
            if steps == eval_env.max_steps - 1:
                nSucc += 1

        # plot buffers are forward-filled from the current epoch to the end
        cost[epoch:] = cEpoch / nEval
        # Trailing window of 50 epochs before the current one. For epoch < 50
        # the window cost[0:50] contains the current value forward-filled.
        idx = max(epoch - 50, 0)
        idxEnd = max(epoch, 50)
        cMean[idx:] = cost[idx:idxEnd].mean()
        std = cost[idx:idxEnd].std()
        cStdU[idx:] = cMean[idx] + std
        cStdL[idx:] = cMean[idx] - std
        evalSucc[epoch:] = nSucc
        Qmin[epoch:] = Qs[0]
        Qmax[epoch:] = Qs[1]
        loss[epoch:] = l

        with cEndWidget.batch_update():
            lineCost.y = cost
            lineCostM.y = cMean
            lineCostSU.y = cStdU
            lineCostSL.y = cStdL
            lineEval.y = evalSucc
            lineQmin.y = Qmin
            lineQmax.y = Qmax
            lineLoss.y = loss
            cEndWidget.layout['title'] = "Epoch: {}".format(epoch)
            if epoch % shapStep == 0:
                # load training data from memory
                data = Transition(*zip(*memory.sample(shapSamples)))
                data = torch.cat(data.state).reshape(-1, agent.stateSize)
                # scale data for NN
                data = (data / agent.std).numpy()
                # init explainer (single all-zero background sample)
                explainer = shap.KernelExplainer(callNN, np.zeros((1, 4)),
                                                 output_names=["push left",
                                                               "push right"])
                # Calculate the mean absolute shap value per output and input.
                # The keyword is `nsamples`; the former `nsample` was silently
                # ignored. shap caps it at 2**M - 2 for M varying features.
                # NOTE(review): the indexing below assumes one
                # (samples, features) array per output (list-style result);
                # newer shap versions return (samples, features, outputs) —
                # confirm against the installed version.
                shap_values = np.abs(explainer.shap_values(
                                data, nsamples=100, silent=True)).mean(1)
                sPositionL[sIdx:] = shap_values[0][0]
                sVelocityL[sIdx:] = shap_values[0][1]
                sAngleL[sIdx:] = shap_values[0][2]
                sTipVelocityL[sIdx:] = shap_values[0][3]
                sPositionR[sIdx:] = shap_values[1][0]
                sVelocityR[sIdx:] = shap_values[1][1]
                sAngleR[sIdx:] = shap_values[1][2]
                sTipVelocityR[sIdx:] = shap_values[1][3]
                line_sPositionL.y = sPositionL
                line_sVelocityL.y = sVelocityL
                line_sAngleL.y = sAngleL
                line_sTipVelocityL.y = sTipVelocityL
                line_sPositionR.y = sPositionR
                line_sVelocityR.y = sVelocityR
                line_sAngleR.y = sAngleR
                line_sTipVelocityR.y = sTipVelocityR
                sIdx += 1

        # early training stop if currently no convergence is visible
        # (aborts this attempt; the outer loop refills memory and restarts)
        if epoch >= 0.5*nEpoch:
            if cMean[idx] > 0.5*cMax and nSucc == 0:
                # NOTE(review): clamping gradients after this epoch's training
                # step presumably does not change the weights — confirm intent.
                for param in agent.model.parameters():
                    param.grad.data.clamp_(-1, 1)
                break

        # stop training if controller can reproducible hold pole
        if nSucc == nEval:
            train = False
            break

In [None]:
# Release the training/evaluation environments and create a separate one for
# the demonstration run.
train_env.close()
eval_env.close()
# Optional: restore a previously saved model instead of using the freshly
# trained one (NOTE(review): confirm NFQagent.loadModel semantics).
# agent.loadModel()
# Pass cFix as well, consistent with train_env/eval_env, so the demo runs with
# the same cost parameters the agent was trained on.
demo_env = CartPoleRegulatorEnv(cMax, cFix, mode="demo")


In [None]:
# Demonstration: run one episode with exploration switched off and render it.
agent.Epsilon = 0.
try:
    obs = demo_env.reset()
    done = False
    while not done:
        aIdx = agent.act(obs)
        obs, _, done, _ = demo_env.step(aIdx)
        demo_env.render()
finally:
    # close the render window even if the run is interrupted or raises
    demo_env.close()

In [None]:
# Fallback: close the demo environment, e.g. after an interrupted demo cell.
demo_env.close()