Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

M-STDP-ET fix + code cleanup. #141

Merged
merged 3 commits into from Oct 13, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion bindsnet/datasets/__init__.py
Expand Up @@ -14,7 +14,7 @@
from typing import Tuple, List, Iterable, Any

__all__ = [
'Dataset', 'MNIST', 'SpokenMNIST', 'CIFAR10', 'CIFAR100', 'preprocess'
'Dataset', 'MNIST', 'FashionMNIST', 'SpokenMNIST', 'CIFAR10', 'CIFAR100', 'preprocess'
]


Expand Down
1 change: 0 additions & 1 deletion bindsnet/evaluation/__init__.py
@@ -1,5 +1,4 @@
import torch
import numpy as np

from itertools import product
from typing import Optional, Tuple, Dict
Expand Down
87 changes: 38 additions & 49 deletions bindsnet/learning/__init__.py
@@ -1,4 +1,5 @@
import torch
import numpy as np

from abc import ABC
from typing import Union, Tuple, Optional
Expand Down Expand Up @@ -119,7 +120,7 @@ def __init__(self, connection: AbstractConnection, nu: Optional[Union[float, Tup
'This learning rule is not supported for this Connection type.'
)

def _connection_update(self, **kwargs) -> None:
def _connection_update(self, dt, **kwargs) -> None:
# language=rst
"""
Post-pre learning rule for ``Connection`` subclass of ``AbstractConnection`` class.
Expand All @@ -145,7 +146,7 @@ def _connection_update(self, **kwargs) -> None:

self.connection.w = self.connection.w.view(*shape)

def _conv2d_connection_update(self, **kwargs) -> None:
def _conv2d_connection_update(self, dt, **kwargs) -> None:
# language=rst
"""
Post-pre learning rule for ``Conv2dConnection`` subclass of ``AbstractConnection`` class.
Expand Down Expand Up @@ -206,7 +207,7 @@ def __init__(self, connection: AbstractConnection, nu: Optional[Union[float, Tup
'This learning rule is not supported for this Connection type.'
)

def _connection_update(self, **kwargs) -> None:
def _connection_update(self, dt, **kwargs) -> None:
# language=rst
"""
Hebbian learning rule for ``Connection`` subclass of ``AbstractConnection`` class.
Expand All @@ -232,7 +233,7 @@ def _connection_update(self, **kwargs) -> None:

self.connection.w = self.connection.w.view(*shape)

def _conv2d_connection_update(self, **kwargs) -> None:
def _conv2d_connection_update(self, dt, **kwargs) -> None:
# language=rst
"""
Hebbian learning rule for ``Conv2dConnection`` subclass of ``AbstractConnection`` class.
Expand Down Expand Up @@ -291,7 +292,7 @@ def __init__(self, connection: AbstractConnection, nu: Optional[Union[float, Tup
'This learning rule is not supported for this Connection type.'
)

def _connection_update(self, **kwargs) -> None:
def _connection_update(self, dt, **kwargs) -> None:
# language=rst
"""
M-STDP learning rule for ``Connection`` subclass of ``AbstractConnection`` class.
Expand Down Expand Up @@ -328,7 +329,7 @@ def _connection_update(self, **kwargs) -> None:
self.connection.w += self.nu[0] * reward * eligibility
self.connection.w = self.connection.w.view(*shape)

def _conv2d_connection_update(self, **kwargs) -> None:
def _conv2d_connection_update(self, dt, **kwargs) -> None:
# language=rst
"""
M-STDP learning rule for ``Conv2dConnection`` subclass of ``AbstractConnection`` class.
Expand All @@ -349,25 +350,23 @@ def _conv2d_connection_update(self, **kwargs) -> None:
out_channels, _, kernel_height, kernel_width = self.connection.w.size()
padding, stride = self.connection.padding, self.connection.stride

# Get P^+ and P^- values (function of firing traces), and reshape source and target spikes.
p_plus = a_plus * im2col_indices(
# Reshaping spike traces and spike occurrences.
x_source = im2col_indices(
self.source.x, kernel_height, kernel_width, padding=padding, stride=stride
)
p_minus = a_minus * self.target.x.permute(1, 2, 3, 0).reshape(out_channels, -1)
pre_fire = im2col_indices(
x_target = self.target.x.permute(1, 2, 3, 0).reshape(out_channels, -1)
s_source = im2col_indices(
self.source.s.float(), kernel_height, kernel_width, padding=padding, stride=stride
)
post_fire = self.target.s.permute(1, 2, 3, 0).reshape(out_channels, -1).float()
s_target = self.target.s.permute(1, 2, 3, 0).reshape(out_channels, -1).float()

# Post-synaptic.
post = (p_plus @ post_fire.t()).view(self.connection.w.size())
if post.max() > 0:
post = post / post.max()
# Get P^+ and P^- values (function of firing traces), and reshape source and target spikes.
p_plus = a_plus * x_source
p_minus = a_minus * x_target

# Pre-synaptic.
pre = (pre_fire @ p_minus.t()).view(self.connection.w.size())
if pre.max() > 0:
pre = pre / pre.max()
# Pre- and post-synaptic updates.
pre = (s_source @ p_minus.t()).view(self.connection.w.size())
post = (p_plus @ s_target.t()).view(self.connection.w.size())

# Calculate point eligibility value.
eligibility = post + pre
Expand Down Expand Up @@ -408,14 +407,14 @@ def __init__(self, connection: AbstractConnection, nu: Optional[Union[float, Tup
'This learning rule is not supported for this Connection type.'
)

self.e_trace = 0
self.e_trace = torch.zeros(self.source.n, self.target.n)
self.tc_e_trace = 0.04
self.p_plus = 0
self.p_plus = torch.zeros(self.source.n)
self.tc_plus = 0.05
self.p_minus = 0
self.p_minus = torch.zeros(self.target.n)
self.tc_minus = 0.05

def _connection_update(self, **kwargs) -> None:
def _connection_update(self, dt, **kwargs) -> None:
# language=rst
"""
M-STDP-ET learning rule for ``Connection`` subclass of ``AbstractConnection`` class.
Expand All @@ -433,27 +432,22 @@ def _connection_update(self, **kwargs) -> None:
target_s = self.target.s.view(-1).float()
target_x = self.target.x.view(-1)

shape = self.connection.w.shape
self.connection.w = self.connection.w.view(self.source.n, self.target.n)

# Parse keyword arguments.
reward = kwargs['reward']
a_plus = kwargs.get('a_plus', 1)
a_minus = kwargs.get('a_plus', -1)

# Get P^+ and P^- values (function of firing traces).
self.p_plus = -(self.tc_plus * self.p_plus) + a_plus * source_x
self.p_minus = -(self.tc_minus * self.p_minus) + a_minus * target_x
self.p_plus = self.p_plus * np.exp(-dt * self.tc_plus) + a_plus * source_x
self.p_minus = self.p_minus * np.exp(-dt * self.tc_minus) + a_minus * target_x

# Calculate value of eligibility trace.
self.e_trace -= self.tc_e_trace * self.e_trace
self.e_trace += torch.ger(self.p_plus, target_s) + torch.ger(source_s, self.p_minus)

# Compute weight update.
self.connection.w += self.nu[0] * reward * self.e_trace
self.connection.w = self.connection.w.view(*shape)

def _conv2d_connection_update(self, **kwargs) -> None:
def _conv2d_connection_update(self, dt, **kwargs) -> None:
# language=rst
"""
M-STDP-ET learning rule for ``Conv2dConnection`` subclass of ``AbstractConnection`` class.
Expand All @@ -474,31 +468,26 @@ def _conv2d_connection_update(self, **kwargs) -> None:
out_channels, _, kernel_height, kernel_width = self.connection.w.size()
padding, stride = self.connection.padding, self.connection.stride

# Get P^+ and P^- values (function of firing traces).
self.p_plus = -(self.tc_plus * self.p_plus) + a_plus * im2col_indices(
# Reshaping spike traces and spike occurrences.
x_source = im2col_indices(
self.source.x, kernel_height, kernel_width, padding=padding, stride=stride
)
self.p_minus = -(self.tc_minus * self.p_minus) + a_minus * \
self.target.x.permute(1, 2, 3, 0).reshape(out_channels, -1)

# Get pre- and post-synaptic spiking neurons.
pre_fire = im2col_indices(
x_target = self.target.x.permute(1, 2, 3, 0).reshape(out_channels, -1)
s_source = im2col_indices(
self.source.s.float(), kernel_height, kernel_width, padding=padding, stride=stride
)
post_fire = self.target.s.permute(1, 2, 3, 0).reshape(out_channels, -1).float()
s_target = self.target.s.permute(1, 2, 3, 0).reshape(out_channels, -1).float()

# Post-synaptic.
post = (self.p_plus @ post_fire.t()).view(self.connection.w.size())
if post.max() > 0:
post = post / post.max()
# Get P^+ and P^- values (function of firing traces).
self.p_plus = self.p_plus * np.exp(-dt / self.tc_plus) + a_plus * x_source
self.p_minus = self.p_minus * np.exp(-dt / self.tc_minus) + a_minus * x_target

# Pre-synaptic.
pre = (pre_fire @ self.p_minus.t()).view(self.connection.w.size())
if pre.max() > 0:
pre = pre / pre.max()
# Post-synaptic and pre-synaptic updates.
post = (self.p_plus @ s_target.t()).view(self.connection.w.size())
pre = (s_source @ self.p_minus.t()).view(self.connection.w.size())

# Calculate point eligibility value.
self.e_trace += -(self.tc_e_trace * self.e_trace) + (post + pre)
# Calculate value of eligibility trace.
self.e_trace = post + pre

# Compute weight update.
self.connection.w += self.nu[0] * reward * self.e_trace
5 changes: 2 additions & 3 deletions bindsnet/network/__init__.py
Expand Up @@ -156,7 +156,7 @@ def save(self, fname: str) -> None:
network.add_connection(connection=C, source='X', target='Y')

# Save the network to disk.
network.save(str(Path.home()) + '/network.p')
network.save(str(Path.home()) + '/network.pt')
"""
torch.save(self, open(fname, 'wb'))

Expand Down Expand Up @@ -232,7 +232,6 @@ def run(self, inpts: Dict[str, torch.Tensor], time: int, **kwargs) -> None:
"""
# Parse keyword arguments.
clamps = kwargs.get('clamp', {})
clamps_v = kwargs.get('clamp_v', {})
reward = kwargs.get('reward', None)
masks = kwargs.get('masks', {})

Expand All @@ -259,7 +258,7 @@ def run(self, inpts: Dict[str, torch.Tensor], time: int, **kwargs) -> None:
# Run synapse updates.
for c in self.connections:
self.connections[c].update(
reward=reward, mask=masks.get(c, None), learning=self.learning
dt=self.dt, reward=reward, mask=masks.get(c, None), learning=self.learning
)

# Get input to all layers.
Expand Down
25 changes: 13 additions & 12 deletions bindsnet/network/topology.py
Expand Up @@ -74,15 +74,16 @@ def compute(self, s: torch.Tensor) -> None:
pass

@abstractmethod
def update(self, **kwargs) -> None:
def update(self, dt, **kwargs) -> None:
# language=rst
"""
Compute connection's update rule.
"""
learning = kwargs.get('learning', True)
reward = kwargs.get('reward', None)

if learning:
self.update_rule.update(reward=reward)
self.update_rule.update(dt=dt, reward=reward)

mask = kwargs.get('mask', None)
if mask is not None:
Expand Down Expand Up @@ -135,9 +136,9 @@ def __init__(self, source: Nodes, target: Nodes, nu: Optional[Union[float, Tuple
self.w = kwargs.get('w', None)

if self.w is None:
if self.wmin is None and self.wmax is None:
if self.wmin is None or self.wmax is None:
self.w = torch.rand(source.n, target.n)
else:
elif self.wmin is not None and self.wmax is not None:
self.w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin)
else:
if self.wmin is not None and self.wmax is not None:
Expand All @@ -158,12 +159,12 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
a_post = self.a_pre @ self.w
return a_post.view(*self.target.shape)

def update(self, **kwargs) -> None:
def update(self, dt, **kwargs) -> None:
# language=rst
"""
Compute connection's update rule.
"""
super().update(**kwargs)
super().update(dt=dt, **kwargs)

def normalize(self) -> None:
# language=rst
Expand Down Expand Up @@ -249,12 +250,12 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
"""
return F.conv2d(s.float(), self.w, stride=self.stride, padding=self.padding, dilation=self.dilation)

def update(self, **kwargs) -> None:
def update(self, dt, **kwargs) -> None:
# language=rst
"""
Compute connection's update rule.
"""
super().update(**kwargs)
super().update(dt=dt, **kwargs)

def normalize(self) -> None:
# language=rst
Expand Down Expand Up @@ -384,15 +385,15 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
a_post = self.a_pre @ self.w.view(self.source.n, self.target.n)
return a_post.view(*self.target.shape)

def update(self, **kwargs) -> None:
def update(self, dt, **kwargs) -> None:
# language=rst
"""
Compute connection's update rule.
"""
if kwargs['mask'] is None:
kwargs['mask'] = self.mask

super().update(**kwargs)
super().update(dt=dt, **kwargs)

def normalize(self) -> None:
# language=rst
Expand Down Expand Up @@ -462,12 +463,12 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
# Compute multiplication of mean-field pre-activation by connection weights.
return self.a_pre * self.w

def update(self, **kwargs) -> None:
def update(self, dt, **kwargs) -> None:
# language=rst
"""
Compute connection's update rule.
"""
super().update(**kwargs)
super().update(dt=dt, **kwargs)

def normalize(self) -> None:
# language=rst
Expand Down
2 changes: 1 addition & 1 deletion bindsnet/pipeline/__init__.py
Expand Up @@ -63,7 +63,7 @@ def __init__(self, network: Network, environment: Environment, encoding: Callabl
self.time = kwargs.get('time', 1)
self.delta = kwargs.get('delta', 1)
self.output = kwargs.get('output', None)
self.save_dir = kwargs.get('save_dir', 'network.p')
self.save_dir = kwargs.get('save_dir', 'network.pt')
self.plot_interval = kwargs.get('plot_interval', None)
self.save_interval = kwargs.get('save_interval', None)
self.print_interval = kwargs.get('print_interval', None)
Expand Down
2 changes: 1 addition & 1 deletion bindsnet/preprocessing/__init__.py
Expand Up @@ -20,7 +20,7 @@ class AbstractPreprocessor(ABC):
Abstract base class for Preprocessor.
"""

def process(self, csvfile: str, use_cache: bool = True, cachedfile: str = './processed/data.p') -> torch.tensor:
def process(self, csvfile: str, use_cache: bool = True, cachedfile: str = './processed/data.pt') -> torch.tensor:
# cache dictionary for storing encodings if previously encoded
cache = {'verify': '', 'data': None}

Expand Down
15 changes: 10 additions & 5 deletions examples/space_invaders/et_space_invaders.py
Expand Up @@ -11,6 +11,7 @@
from bindsnet.network.topology import Connection
from bindsnet.pipeline import Pipeline
from bindsnet.pipeline.action import select_multinomial
from bindsnet.analysis.plotting import plot_weights


parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -54,12 +55,15 @@

# Connections between layers.
# Input -> excitatory.
input_exc_conn = Connection(source=layers['X'], target=layers['E'],
w=torch.rand(layers['X'].n, layers['E'].n), wmax=1e-2)
input_exc_conn = Connection(
source=layers['X'], target=layers['E'], w=torch.rand(layers['X'].n, layers['E'].n), wmax=1e-2
)

# Excitatory -> readout.
exc_readout_conn = Connection(source=layers['E'], target=layers['R'], w=torch.rand(layers['E'].n, layers['R'].n),
wmax=0.5, update_rule=MSTDPET, nu=2e-2, norm=0.15 * layers['E'].n)
exc_readout_conn = Connection(
source=layers['E'], target=layers['R'], w=torch.rand(layers['E'].n, layers['R'].n), wmin=-0.5,
wmax=0.5, update_rule=MSTDPET, nu=1e-4, norm=0.15 * layers['E'].n
)

# Spike recordings for all layers.
spikes = {}
Expand Down Expand Up @@ -93,12 +97,13 @@
plot_interval=plot_interval, print_interval=print_interval, render_interval=render_interval,
action_function=select_multinomial, output='R')

weights_im = None
try:
while True:
pipeline.step()

if pipeline.done:
pipeline.reset_()

except KeyboardInterrupt:
plt.close("all")
environment.close()
2 changes: 1 addition & 1 deletion examples/space_invaders/random_baseline.py
Expand Up @@ -68,4 +68,4 @@
i += 1

save = (total, rewards, avg_rewards, lengths, avg_lengths)
p.dump(save, open(os.path.join('..', '..', 'results', 'SI_random_baseline_%d.p' % n), 'wb'))
p.dump(save, open(os.path.join('..', '..', 'results', 'SI_random_baseline_%d.pt' % n), 'wb'))