
Commit

Merge pull request #610 from BindsNET/hananel
Gymnasium and MSTDP
Hananel-Hazan authored Jan 16, 2023
2 parents 8321825 + 7171592 commit 54e5dec
Showing 9 changed files with 1,070 additions and 704 deletions.
2 changes: 1 addition & 1 deletion bindsnet/environment/dot_simulator.py
@@ -6,7 +6,7 @@
import numpy as np
import pandas as pd
import torch
-from gym import spaces
+from gymnasium import spaces

# Mappings for changing direction if reflected.
# Cannot cross a row boundary moving right or left.
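The only change in this file is the import: the spaces module that gymnasium ships is the drop-in replacement for gym's, so whatever space objects the simulator builds keep the same API. A minimal check of that assumption (the concrete spaces below are illustrative, not taken from the simulator):

    from gymnasium import spaces

    # Discrete and Box behave as they did under gym; sample() and shape are unchanged.
    action_space = spaces.Discrete(4)
    observation_space = spaces.Box(low=0.0, high=1.0, shape=(28, 28))
    print(action_space.sample(), observation_space.shape)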
20 changes: 14 additions & 6 deletions bindsnet/environment/environment.py
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple

-import gym
+import gymnasium as gym
import numpy as np
import torch

@@ -59,7 +59,13 @@ class GymEnvironment(Environment):
A wrapper around the OpenAI ``gym`` environments.
"""

-def __init__(self, name: str, encoder: Encoder = NullEncoder(), **kwargs) -> None:
+def __init__(
+    self,
+    name: str,
+    render_mode: str = "rgb_array",
+    encoder: Encoder = NullEncoder(),
+    **kwargs,
+) -> None:
# language=rst
"""
Initializes the environment wrapper. This class makes the
@@ -82,7 +88,7 @@ def __init__(self, name: str, encoder: Encoder = NullEncoder(), **kwargs) -> Non
2D inputs.
"""
self.name = name
-self.env = gym.make(name)
+self.env = gym.make(name, render_mode=render_mode)
self.action_space = self.env.action_space

self.encoder = encoder
@@ -94,6 +100,7 @@ def __init__(self, name: str, encoder: Encoder = NullEncoder(), **kwargs) -> Non
self.history_length = kwargs.get("history_length", None)
self.delta = kwargs.get("delta", 1)
self.add_channel_dim = kwargs.get("add_channel_dim", True)
+self.seed = kwargs.get("seed", None)

if self.history_length is not None and self.delta is not None:
self.history = {
@@ -122,7 +129,8 @@ def step(self, a: int) -> Tuple[torch.Tensor, float, bool, Dict[Any, Any]]:
:return: Observation, reward, done flag, and information dictionary.
"""
# Call gym's environment step function.
-self.obs, self.reward, self.done, info = self.env.step(a)
+self.obs, self.reward, terminated, truncated, info = self.env.step(a)
+self.done = terminated or truncated

if self.clip_rewards:
self.reward = np.sign(self.reward)
@@ -162,15 +170,15 @@ def step(self, a: int) -> Tuple[torch.Tensor, float, bool, Dict[Any, Any]]:
# Return converted observations and other information.
return self.obs, self.reward, self.done, info

-def reset(self) -> torch.Tensor:
+def reset(self, seed=None) -> torch.Tensor:
# language=rst
"""
Wrapper around the OpenAI ``gym`` environment ``reset()`` function.
:return: Observation from the environment.
"""
# Call gym's environment reset function.
-self.obs = self.env.reset()
+self.obs, self.info = self.env.reset(seed=seed)
self.preprocess()

self.history = {i: torch.Tensor() for i in self.history}
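The wrapper keeps BindsNET's original four-value return from step() and single observation from reset(), while Gymnasium's own API returns five values from step() and an (observation, info) pair from reset(). A short sketch of the raw Gymnasium calls being folded back into the old shape (CartPole-v1 is only a stand-in environment id):

    import gymnasium as gym

    env = gym.make("CartPole-v1", render_mode="rgb_array")

    # reset() now takes the seed and returns (observation, info).
    obs, info = env.reset(seed=0)

    # step() returns five values; the two end-of-episode flags are folded back
    # into the single done flag that BindsNET pipelines expect.
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    done = terminated or truncated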
18 changes: 12 additions & 6 deletions bindsnet/learning/learning.py
@@ -1528,12 +1528,18 @@ def _connection_update(self, **kwargs) -> None:

# Parse keyword arguments.
reward = kwargs["reward"]
-a_plus = torch.tensor(
-    kwargs.get("a_plus", 1.0), device=self.connection.w.device
-)
-a_minus = torch.tensor(
-    kwargs.get("a_minus", -1.0), device=self.connection.w.device
-)
+a_plus = kwargs.get("a_plus", 1.0)
+if isinstance(a_plus, dict):
+    for k, v in a_plus.items():
+        a_plus[k] = torch.tensor(v, device=self.connection.w.device)
+else:
+    a_plus = torch.tensor(a_plus, device=self.connection.w.device)
+a_minus = kwargs.get("a_minus", -1.0)
+if isinstance(a_minus, dict):
+    for k, v in a_minus.items():
+        a_minus[k] = torch.tensor(v, device=self.connection.w.device)
+else:
+    a_minus = torch.tensor(a_minus, device=self.connection.w.device)

# Compute weight update based on the eligibility value of the past timestep.
update = reward * self.eligibility
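With this change, a_plus and a_minus may arrive either as a single number or as a dict of per-connection values, and both forms end up as tensors on the connection's device. A standalone sketch of that normalization (the helper name and keys are hypothetical, not from the repository):

    import torch

    def to_device_tensor(value, device):
        # Hypothetical helper mirroring the scalar-or-dict handling above:
        # dict values are converted key by key, scalars become 0-d tensors.
        if isinstance(value, dict):
            return {k: torch.tensor(v, device=device) for k, v in value.items()}
        return torch.tensor(value, device=device)

    a_plus = to_device_tensor({("X", "Y"): 1.0, ("Y", "Z"): 0.5}, device="cpu")
    a_minus = to_device_tensor(-1.0, device="cpu")
    print(a_plus, a_minus)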
32 changes: 32 additions & 0 deletions bindsnet/network/network.py
@@ -405,7 +405,39 @@ def run(
self.layers[l].s[:, unclamp[t]] = 0

# Run synapse updates.
if "a_minus" in kwargs:
A_Minus = kwargs["a_minus"]
kwargs.pop("a_minus")
if isinstance(A_Minus, dict):
A_MD = True
else:
A_MD = False
else:
A_Minus = None

if "a_plus" in kwargs:
A_Plus = kwargs["a_plus"]
kwargs.pop("a_plus")
if isinstance(A_Plus, dict):
A_PD = True
else:
A_PD = False
else:
A_Plus = None

for c in self.connections:
+    if A_Minus != None and ((isinstance(A_Minus, float)) or (c in A_Minus)):
+        if A_MD:
+            kwargs["a_minus"] = A_Minus[c]
+        else:
+            kwargs["a_minus"] = A_Minus
+
+    if A_Plus != None and ((isinstance(A_Plus, float)) or (c in A_Plus)):
+        if A_PD:
+            kwargs["a_plus"] = A_Plus[c]
+        else:
+            kwargs["a_plus"] = A_Plus
+
self.connections[c].update(
mask=masks.get(c, None), learning=self.learning, **kwargs
)
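Taken together with the learning-rule change above, Network.run() now accepts a_plus and a_minus either as a single float applied to every connection or as a dict keyed by (source, target) connection tuples. A hedged usage sketch, with layer sizes and names chosen for illustration rather than taken from any example in the repository:

    import torch
    from bindsnet.learning import MSTDP
    from bindsnet.network import Network
    from bindsnet.network.nodes import Input, LIFNodes
    from bindsnet.network.topology import Connection

    # Small two-layer network with a single MSTDP connection.
    network = Network(dt=1.0)
    network.add_layer(Input(n=100, traces=True), name="X")
    network.add_layer(LIFNodes(n=10, traces=True), name="Y")
    network.add_connection(
        Connection(source=network.layers["X"], target=network.layers["Y"],
                   update_rule=MSTDP, nu=[1e-2, 1e-2]),
        source="X", target="Y",
    )

    time = 250
    spikes = torch.bernoulli(0.1 * torch.ones(time, 100)).byte()

    # Per-connection amplitudes; a plain float would be applied to every connection.
    network.run(
        inputs={"X": spikes},
        time=time,
        reward=1.0,
        a_plus={("X", "Y"): 1.0},
        a_minus={("X", "Y"): -1.0},
    )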
10 changes: 6 additions & 4 deletions bindsnet/pipeline/base_pipeline.py
@@ -1,4 +1,4 @@
-import collections
+import collections.abc
import time
from typing import Any, Dict, Tuple

@@ -25,11 +25,11 @@ def recursive_to(item, device):
return item.to(device)
elif isinstance(item, (string_classes, int, float, bool)):
return item
-elif isinstance(item, collections.Mapping):
+elif isinstance(item, collections.abc.Mapping):
return {key: recursive_to(item[key], device) for key in item}
elif isinstance(item, tuple) and hasattr(item, "_fields"):
return type(item)(*(recursive_to(i, device) for i in item))
-elif isinstance(item, collections.Sequence):
+elif isinstance(item, collections.abc.Sequence):
return [recursive_to(i, device) for i in item]
else:
raise NotImplementedError(f"Target type {type(item)} not supported.")
@@ -89,6 +89,7 @@ def __init__(self, network: Network, **kwargs) -> None:

self.print_interval = kwargs.get("print_interval", None)
self.test_interval = kwargs.get("test_interval", None)
+self.plot_interval = kwargs.get("plot_interval", None)
self.step_count = 0
self.init_fn()
self.clock = time.time()
@@ -133,7 +134,8 @@ def step(self, batch: Any, **kwargs) -> Any:
)
self.clock = time.time()

-self.plots(batch, step_out)
+if self.plot_interval is not None and self.step_count % self.plot_interval == 0:
+    self.plots(batch, step_out)

if self.save_interval is not None and self.step_count % self.save_interval == 0:
self.network.save(self.save_dir)
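The collections.Mapping and collections.Sequence aliases were removed in Python 3.10, so recursive_to now goes through collections.abc; plotting in step() is also gated on the new plot_interval keyword, so pipelines that never plot skip the call entirely. A small standalone check of the abc-based test (the batch contents are made up):

    import collections.abc

    import torch

    batch = {"obs": torch.zeros(3), "meta": [torch.ones(2), 1.5]}
    print(isinstance(batch, collections.abc.Mapping))           # True
    print(isinstance(batch["meta"], collections.abc.Sequence))  # True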
6 changes: 5 additions & 1 deletion examples/breakout/breakout_stdp.py
@@ -35,7 +35,7 @@
network.add_connection(middle_out, source="Hidden Layer", target="Output Layer")

# Load the Breakout environment.
-environment = GymEnvironment("BreakoutDeterministic-v4")
+environment = GymEnvironment("BreakoutDeterministic-v4", render_mode="human")
environment.reset()

# Build pipeline from specified components.
@@ -69,6 +69,10 @@ def run_pipeline(pipeline, episode_count):
print(f"Episode {i} total reward:{total_reward}")


+# enable MSTDP
+environment_pipeline.network.learning = True


print("Training: ")
run_pipeline(environment_pipeline, episode_count=100)

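The new lines switch the network into learning mode before the training episodes so the MSTDP updates defined on the connections are actually applied. One way this could be paired with a frozen evaluation pass, reusing the names from the example above (the evaluation pass itself is an assumption, not part of this commit):

    environment_pipeline.network.learning = True   # apply MSTDP weight updates
    run_pipeline(environment_pipeline, episode_count=100)

    environment_pipeline.network.learning = False  # freeze weights for evaluation
    run_pipeline(environment_pipeline, episode_count=10)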
13 changes: 7 additions & 6 deletions examples/breakout/random_network_baseline.py
@@ -55,7 +55,7 @@
input_exc_conn = Connection(
source=layers["X"],
target=layers["E"],
w=0.01 * torch.rand(layers["X"].n, layers["E"].n),
w=0.1 * torch.rand(layers["X"].n, layers["E"].n),
wmax=0.02,
norm=0.01 * layers["X"].n,
)
@@ -64,7 +64,7 @@
exc_readout_conn = Connection(
source=layers["E"],
target=layers["R"],
w=0.01 * torch.rand(layers["E"].n, layers["R"].n),
w=0.1 * torch.rand(layers["E"].n, layers["R"].n),
update_rule=Hebbian,
nu=[1e-2, 1e-2],
norm=0.5 * layers["E"].n,
@@ -95,16 +95,16 @@
network.add_monitor(voltages[layer], name="%s_voltages" % layer)

# Load the Breakout environment.
-environment = GymEnvironment("BreakoutDeterministic-v4")
+environment = GymEnvironment("BreakoutDeterministic-v4", render_mode="human")
environment.reset()

pipeline = EnvironmentPipeline(
network,
environment,
encoding=bernoulli,
-time=1,
-history=5,
-delta=10,
+history_length=1,
+delta=1,
+time=100,
plot_interval=plot_interval,
print_interval=print_interval,
render_interval=render_interval,
@@ -119,6 +119,7 @@
avg_lengths = []

i = 0
+# pipeline.reset_state_variables()
try:
while i < n:
result = pipeline.env_step()
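The pipeline arguments now match the keyword names the wrapper actually reads (history_length and delta rather than history), and the environment is created with an explicit render_mode. A standalone construction sketch, assuming the Atari ROMs for Breakout are installed; whether EnvironmentPipeline forwards these keywords to the environment is not shown in this diff:

    from bindsnet.environment import GymEnvironment

    environment = GymEnvironment(
        "BreakoutDeterministic-v4",
        render_mode="rgb_array",   # the example uses "human" to open a viewer window
        history_length=1,          # number of past observations kept by the wrapper
        delta=1,                   # spacing between stored observations
    )
    environment.reset()
    obs, reward, done, info = environment.step(environment.action_space.sample())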
