Merge pull request #225 from Huizerd/reward
Implementation of reward prediction error
djsaunde committed Apr 22, 2019
2 parents bd5f836 + 4a32996 commit 8ebf7b0
Showing 5 changed files with 118 additions and 22 deletions.
8 changes: 4 additions & 4 deletions bindsnet/learning/__init__.py
@@ -393,7 +393,7 @@ def _connection_update(self, **kwargs) -> None:
Keyword arguments:
:param float reward: Reward signal from reinforcement learning task.
:param Union[float, torch.Tensor] reward: Reward signal from reinforcement learning task.
:param float a_plus: Learning rate (post-synaptic).
:param float a_minus: Learning rate (pre-synaptic).
"""
@@ -436,7 +436,7 @@ def _conv2d_connection_update(self, **kwargs) -> None:
Keyword arguments:
:param float reward: Reward signal from reinforcement learning task.
:param Union[float, torch.Tensor] reward: Reward signal from reinforcement learning task.
:param float a_plus: Learning rate (post-synaptic).
:param float a_minus: Learning rate (pre-synaptic).
"""
@@ -531,7 +531,7 @@ def _connection_update(self, **kwargs) -> None:
Keyword arguments:
:param float reward: Reward signal from reinforcement learning task.
:param Union[float, torch.Tensor] reward: Reward signal from reinforcement learning task.
:param float a_plus: Learning rate (post-synaptic).
:param float a_minus: Learning rate (pre-synaptic).
"""
@@ -580,7 +580,7 @@ def _conv2d_connection_update(self, **kwargs) -> None:
Keyword arguments:
:param float reward: Reward signal from reinforcement learning task.
:param Union[float, torch.Tensor] reward: Reward signal from reinforcement learning task.
:param float a_plus: Learning rate (post-synaptic).
:param float a_minus: Learning rate (pre-synaptic).
"""
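The type-hint change above means the reward signal handed to the reward-modulated update methods may be either a scalar or a tensor. A minimal sketch of what that allows, assuming a generic eligibility-trace update rather than BindsNET's actual `_connection_update` body:

```python
from typing import Union

import torch


def reward_modulated_update(w: torch.Tensor, eligibility: torch.Tensor,
                            reward: Union[float, torch.Tensor],
                            nu: float = 1e-2) -> torch.Tensor:
    # Scalar rewards broadcast over all synapses; tensor rewards can weight
    # synapses individually, matching the widened type hint.
    reward = torch.as_tensor(reward, dtype=w.dtype)
    return w + nu * reward * eligibility


w = torch.zeros(3, 2)
e = torch.rand(3, 2)
w = reward_modulated_update(w, e, 1.0)                      # scalar reward
w = reward_modulated_update(w, e, torch.full((3, 2), 0.5))  # per-synapse reward
```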
83 changes: 83 additions & 0 deletions bindsnet/learning/reward.py
@@ -0,0 +1,83 @@
from abc import ABC, abstractmethod
from typing import Union

import torch


class AbstractReward(ABC):
# language=rst
"""
Abstract base class for reward computation.
"""

@abstractmethod
def compute(self, **kwargs) -> None:
# language=rst
"""
Computes/modifies reward.
"""
pass

@abstractmethod
def update(self, **kwargs) -> None:
# language=rst
"""
Updates internal variables needed to modify reward. Usually called once per episode.
"""
pass


class MovingAvgRPE(AbstractReward):
# language=rst
"""
Calculates reward prediction error (RPE) based on an exponential moving average (EMA) of past rewards.
"""

def __init__(self) -> None:
# language=rst
"""
Constructor for EMA reward prediction error.
"""
self.reward_predict = torch.tensor(0.0) # Predicted reward (per step).
self.reward_predict_episode = torch.tensor(0.0) # Predicted reward per episode.
self.rewards_predict_episode = [] # List of predicted rewards per episode (used for plotting).

def compute(self, **kwargs) -> torch.Tensor:
# language=rst
"""
Computes the reward prediction error using EMA.
Keyword arguments:
:param Union[float, torch.Tensor] reward: Current reward.
:return: Reward prediction error.
"""
# Get keyword arguments.
reward = kwargs['reward']

return reward - self.reward_predict

def update(self, **kwargs) -> None:
# language=rst
"""
Updates the EMAs. Called once per episode.
Keyword arguments:
:param Union[float, torch.Tensor] accumulated_reward: Reward accumulated over one episode.
:param int steps: Steps in that episode.
:param float ema_window: Width of the averaging window.
"""
# Get keyword arguments.
accumulated_reward = kwargs['accumulated_reward']
steps = torch.tensor(kwargs['steps']).float()
ema_window = torch.tensor(kwargs.get('ema_window', 10.0))

# Compute average reward per step.
reward = accumulated_reward / steps

# Update EMAs.
self.reward_predict = (1 - 1 / ema_window) * self.reward_predict + 1 / ema_window * reward
self.reward_predict_episode = (1 - 1 / ema_window) * self.reward_predict_episode + \
1 / ema_window * accumulated_reward
self.rewards_predict_episode.append(self.reward_predict_episode.item())
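A short usage sketch of the new class, with toy numbers chosen only to illustrate the API shown above:

```python
from bindsnet.learning.reward import MovingAvgRPE

rpe = MovingAvgRPE()

# Per simulation step: prediction error = raw reward - EMA prediction.
error = rpe.compute(reward=1.0)   # tensor(1.) on the first call

# Once per episode: fold the episode totals into the moving averages.
rpe.update(accumulated_reward=80.0, steps=100, ema_window=10.0)
print(rpe.reward_predict)         # tensor(0.0800): 0.9 * 0.0 + 0.1 * (80 / 100)
```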
22 changes: 17 additions & 5 deletions bindsnet/network/__init__.py
@@ -1,10 +1,12 @@
import torch
import tempfile
from typing import Dict
from typing import Dict, Optional

import torch

from .monitors import AbstractMonitor
from .nodes import AbstractInput, Nodes
from .topology import AbstractConnection
from .monitors import AbstractMonitor
from ..learning.reward import AbstractReward

__all__ = [
'load', 'Network', 'nodes', 'monitors', 'topology'
@@ -82,19 +84,25 @@ class Network:
plt.tight_layout(); plt.show()
"""

def __init__(self, dt: float = 1.0, learning: bool = True) -> None:
def __init__(self, dt: float = 1.0, learning: bool = True,
reward_fn: Optional[AbstractReward] = None) -> None:
# language=rst
"""
Initializes network object.
:param dt: Simulation timestep.
:param learning: Whether to allow connection updates. True by default.
:param reward_fn: Optional class allowing for modification of reward in case of reward-modulated learning.
"""
self.dt = dt
self.layers = {}
self.connections = {}
self.monitors = {}
self.learning = learning
if reward_fn is not None:
self.reward_fn = reward_fn()
else:
self.reward_fn = None

def add_layer(self, layer: Nodes, name: str) -> None:
# language=rst
@@ -220,7 +228,7 @@ def run(self, inpts: Dict[str, torch.Tensor], time: int, **kwargs) -> None:
to not spiking. The ``Tensor``s should have shape ``[n_neurons]``.
:param Dict[str, torch.Tensor] injects_v: Mapping of layer names to boolean masks if neurons should be added
voltage. The ``Tensor``s should have shape ``[n_neurons]``.
:param float reward: Scalar value used in reward-modulated learning.
:param Union[float, torch.Tensor] reward: Scalar value used in reward-modulated learning.
:param Dict[Tuple[str], torch.Tensor] masks: Mapping of connection names to boolean masks determining which
weights to clamp to zero.
@@ -260,6 +268,10 @@ def run(self, inpts: Dict[str, torch.Tensor], time: int, **kwargs) -> None:
masks = kwargs.get('masks', {})
injects_v = kwargs.get('injects_v', {})

# Compute reward.
if self.reward_fn is not None:
kwargs['reward'] = self.reward_fn.compute(**kwargs)

# Effective number of timesteps.
timesteps = int(time / self.dt)

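A sketch of how the new hook is wired, with layer and connection setup elided. Note that the class itself (not an instance) is passed, since the constructor calls `reward_fn()`; inside `run`, the raw `reward` keyword argument is replaced by the computed prediction error before any connection update sees it. The layer name `'X'` and `spikes` below are placeholders, not part of the API.

```python
from bindsnet.network import Network
from bindsnet.learning.reward import MovingAvgRPE

# Pass the reward class; Network instantiates it internally.
network = Network(dt=1.0, learning=True, reward_fn=MovingAvgRPE)

# ... add input/output layers and a reward-modulated connection here ...

# During simulation, pass the raw reward; run() substitutes
# network.reward_fn.compute(reward=...) for it.
# network.run(inpts={'X': spikes}, time=250, reward=raw_reward)
```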
10 changes: 5 additions & 5 deletions bindsnet/network/topology.py
@@ -302,7 +302,7 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
Compute convolutional pre-activations given spikes using layer weights.
:param s: Incoming spikes.
:return: Spikes multiplied by synapse weights.
:return: Incoming spikes multiplied by synaptic weights (with or without decaying spike activation).
"""
return F.conv2d(s.float(), self.w, self.b, stride=self.stride, padding=self.padding, dilation=self.dilation)

@@ -376,7 +376,7 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
Compute max-pool pre-activations given spikes using online firing rate estimates.
:param s: Incoming spikes.
:return: Spikes multiplied by synapse weights.
:return: Incoming spikes multiplied by synaptic weights (with or without decaying spike activation).
"""
self.firing_rates -= self.decay * self.firing_rates
self.firing_rates += s.float()
@@ -514,7 +514,7 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
Compute pre-activations given spikes using layer weights.
:param s: Incoming spikes.
:return: Incoming spikes multiplied by synaptic weights (with or with decaying spike activation).
:return: Incoming spikes multiplied by synaptic weights (with or without decaying spike activation).
"""
# Compute multiplication of pre-activations by connection weights.
if self.w.shape[0] == self.source.n and self.w.shape[1] == self.target.n:
@@ -599,7 +599,7 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
Compute pre-activations given spikes using layer weights.
:param s: Incoming spikes.
:return: Incoming spikes multiplied by synaptic weights (with or with decaying spike activation).
:return: Incoming spikes multiplied by synaptic weights (with or without decaying spike activation).
"""
# Compute multiplication of mean-field pre-activation by connection weights.
return s.float().mean() * self.w
@@ -681,7 +681,7 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
Compute convolutional pre-activations given spikes using layer weights.
:param s: Incoming spikes.
:return: Spikes multiplied by synapse weights.
:return: Incoming spikes multiplied by synaptic weights (with or without decaying spike activation).
"""
return torch.mm(self.w, s.unsqueeze(-1).float()).squeeze(-1)

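For reference, the contract the corrected docstrings describe — incoming spikes multiplied by synaptic weights — amounts to the following for a dense connection. Shapes and names here are illustrative, not the `Connection` class itself:

```python
import torch

n_source, n_target = 4, 3
w = torch.rand(n_source, n_target)   # synaptic weights
s = torch.tensor([1, 0, 1, 0])       # one timestep of pre-synaptic spikes

pre_activation = s.float() @ w       # shape [n_target]
```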
17 changes: 9 additions & 8 deletions bindsnet/pipeline/__init__.py
@@ -68,7 +68,7 @@ def __init__(self, network: Network, environment: Environment, encoding: Callabl
self.v_ims, self.v_axes = None, None
self.obs_im, self.obs_ax = None, None
self.reward_im, self.reward_ax = None, None
self.accumulated_reward = 0
self.accumulated_reward = 0.0
self.reward_list = []

# Setting kwargs.
@@ -115,6 +115,7 @@ def __init__(self, network: Network, environment: Environment, encoding: Callabl
name: torch.Tensor() for name, layer in network.layers.items() if isinstance(layer, AbstractInput)
}

self.action = None
self.obs = None
self.reward = None
self.done = None
@@ -155,14 +156,12 @@ def step(self, **kwargs) -> None:

# Choose action based on output neuron spiking.
if self.action_function is not None:
a = self.action_function(self, output=self.output)
else:
a = None
self.action = self.action_function(self, output=self.output)

# Run a step of the environment.
self.obs, reward, self.done, info = self.env.step(a)
self.obs, reward, self.done, info = self.env.step(self.action)

# Set reward in case of delay
# Set reward in case of delay.
if self.reward_delay is not None:
self.rewards = torch.tensor([reward, *self.rewards[1:]]).float()
self.reward = self.rewards[-1]
@@ -194,10 +193,12 @@ def step(self, **kwargs) -> None:
self.iteration += 1

if self.done:
if self.network.reward_fn is not None:
self.network.reward_fn.update(**kwargs)
self.iteration = 0
self.episode += 1
self.reward_list.append(self.accumulated_reward)
self.accumulated_reward = 0
self.accumulated_reward = 0.0
self.plot_reward()

def plot_obs(self) -> None:
@@ -331,5 +332,5 @@ def reset_(self) -> None:
self.env.reset()
self.network.reset_()
self.iteration = 0
self.accumulated_reward = 0
self.accumulated_reward = 0.0
self.history = {i: torch.Tensor() for i in self.history}
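Taken together, the pipeline changes follow the pattern below: compute the prediction error on every step and refresh the moving averages when an episode ends. This is a self-contained sketch with a toy reward source; the keyword arguments match `MovingAvgRPE.update`, but how a concrete pipeline forwards its `**kwargs` to the hook is an assumption.

```python
import random

from bindsnet.learning.reward import MovingAvgRPE

reward_fn = MovingAvgRPE()

for episode in range(3):
    accumulated_reward, steps, done = 0.0, 0, False
    while not done:
        raw_reward = random.random()                # toy environment reward
        rpe = reward_fn.compute(reward=raw_reward)  # what run() would pass on
        accumulated_reward += raw_reward
        steps += 1
        done = steps >= 50                          # toy episode boundary
    # Mirrors Pipeline.step's new episode-end hook: reward_fn.update(**kwargs).
    reward_fn.update(accumulated_reward=accumulated_reward,
                     steps=steps, ema_window=10.0)
```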
