In [None]:
import torch
import random
import gym
from gym import wrappers
import torch
import torch.nn as nn
import time
ti = time.time()
import imageio
import numpy as np
from skimage.transform import resize
import torch.nn.functional as F
import random
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import tensorflow.compat.v1 as tf

In [None]:
def generate_gif(frames_for_gif, filename="ATARI_PONG.gif"):
    """
    Write a sequence of frames to an animated GIF.

    Every frame is resized to 420x320x3 with nearest-neighbour interpolation
    (``order=0``), keeping the original value range, and cast to ``uint8``
    before being handed to ``imageio.mimsave``.

    The input sequence is left untouched: resized copies are collected in a
    new list instead of being written back into ``frames_for_gif``.

    :param frames_for_gif: Iterable of image frames accepted by
        ``skimage.transform.resize``.
    :param filename: Path of the GIF to write. Defaults to the file name this
        function has always used.
    """
    resized_frames = [
        resize(frame, (420, 320, 3), preserve_range=True, order=0).astype(np.uint8)
        for frame in frames_for_gif
    ]

    imageio.mimsave(filename, resized_frames, duration=1/30)

In [None]:
def frameprocess(frame,frame_height=84, frame_width=84):
    """
    Convert an RGB frame to a cropped, down-sized grayscale image.

    The frame is converted to grayscale, cropped with
    ``tf.image.crop_to_bounding_box(..., 34, 0, 160, 160)`` and then resized to
    ``frame_height`` x ``frame_width`` using nearest-neighbour interpolation.

    :param frame: RGB image accepted by ``tf.image.rgb_to_grayscale``.
    :param frame_height: Height of the returned image.
    :param frame_width: Width of the returned image.
    :return: The processed image as returned by ``tf.image.resize_images``.
    """
    gray = tf.image.rgb_to_grayscale(frame)
    cropped = tf.image.crop_to_bounding_box(gray, 34, 0, 160, 160)
    return tf.image.resize_images(
        cropped,
        [frame_height, frame_width],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )

In [None]:
from abc import ABC, abstractmethod
from typing import Tuple, Dict, Any

import gym
import numpy as np
import torch

from bindsnet.datasets.preprocess import subsample, gray_scale, binary_image, crop
from bindsnet.encoding import Encoder, NullEncoder


class Environment(ABC):
    # language=rst
    """
    Interface that every environment wrapper must implement.

    All five methods are abstract, so the class cannot be instantiated
    directly; a subclass has to override each of them.
    """

    @abstractmethod
    def step(self, a: int) -> Tuple[Any, ...]:
        # language=rst
        """
        Advance the environment by one action.

        :param a: Integer action to take in environment.
        :return: Tuple describing the outcome of the step; its contents are
            defined by the implementing subclass.
        """

    @abstractmethod
    def reset(self) -> None:
        # language=rst
        """
        Put the environment back into its initial state.
        """

    @abstractmethod
    def render(self) -> None:
        # language=rst
        """
        Display the current state of the environment.
        """

    @abstractmethod
    def close(self) -> None:
        # language=rst
        """
        Release the resources held by the environment.
        """

    @abstractmethod
    def preprocess(self) -> None:
        # language=rst
        """
        Pre-process the most recent observation.
        """


class GymEnvironment(Environment):
    # language=rst
    """
    A wrapper around the OpenAI ``gym`` environments.

    Besides forwarding ``step()``, ``reset()`` and ``close()`` to the wrapped
    environment, the wrapper pre-processes observations into ``torch`` tensors,
    optionally keeps a history of observations for frame differencing, and
    records every observation returned by ``env.step()`` in ``self.frames`` so
    that ``generate_gif()`` can write them to disk.
    """

    def __init__(self, name: str, encoder: Encoder = NullEncoder(), **kwargs) -> None:
        # language=rst
        """
        Initializes the environment wrapper. This class makes the
        assumption that the OpenAI ``gym`` environment will provide an image
        of format HxW or CxHxW as an observation (we will add the C
        dimension to HxW tensors) or a 1D observation in which case no
        dimensions will be added.

        :param name: The name of an OpenAI ``gym`` environment.
        :param encoder: Function to encode observations into spike trains.

        Keyword arguments:

        :param float max_prob: Maximum spiking probability.
        :param bool clip_rewards: Whether or not to use ``np.sign`` of rewards.

        :param int history_length: Number of observations to keep track of.
            When omitted, no history is kept and no differencing is done.
        :param int delta: Step size to save observations in history.
        :param bool add_channel_dim: Allows for the adding of the channel dimension in
            2D inputs.
        """
        self.name = name
        # Observations returned by ``env.step()``, in call order. ``reset()``
        # does not clear this list, so it spans every episode played.
        self.frames = []
        self.env = gym.make(name)
        self.action_space = self.env.action_space
        # The next three attributes are not read by any method of this class.
        self.no_op_steps = 10
        self.agent_history_length = 4
        self.state = None
        self.encoder = encoder

        # Keyword arguments.
        self.max_prob = kwargs.get("max_prob", 1.0)
        self.clip_rewards = kwargs.get("clip_rewards", True)

        self.history_length = kwargs.get("history_length", None)
        self.delta = kwargs.get("delta", 1)
        self.add_channel_dim = kwargs.get("add_channel_dim", True)

        if self.history_length is not None and self.delta is not None:
            # One empty slot per tracked observation, keyed 1, 1 + delta, ...
            self.history = {
                i: torch.Tensor()
                for i in range(1, self.history_length * self.delta + 1, self.delta)
            }
        else:
            self.history = {}

        self.episode_step_count = 0
        self.history_index = 1

        self.obs = None
        self.reward = None
        # Defined up front so the attribute exists before the first ``step()``.
        self.done = False

        assert (
            0.0 < self.max_prob <= 1.0
        ), "Maximum spiking probability must be in (0, 1]."

    def get_cart_location(self,screen_width):
        # language=rst
        """
        Maps the first entry of ``self.env.state`` to a horizontal pixel position.

        Requires the wrapped environment to expose ``x_threshold`` and
        ``state`` (presumably a CartPole-style environment -- confirm before
        using with anything else).

        :param screen_width: Width of the rendered screen in pixels.
        :return: Integer pixel column of the middle of the cart.
        """
        world_width = self.env.x_threshold * 2
        scale = screen_width / world_width
        return int(self.env.state[0] * scale + screen_width / 2.0)  # MIDDLE OF CART

    def get_screen(self):
        # language=rst
        """
        Renders the environment and returns a cropped, rescaled view around the cart.

        :return: Result of ``resize(screen, (80, 80))`` on the cropped screen.
        """
        # Returned screen requested by gym is 400x600x3, but is sometimes larger
        # such as 800x1200x3. Transpose it into torch order (CHW).
        screen = self.env.render(mode='rgb_array').transpose((2, 0, 1))
        # Cart is in the lower half, so strip off the top and bottom of the screen
        _, screen_height, screen_width = screen.shape
        screen = screen[:, int(screen_height*0.4):int(screen_height * 0.8)]
        view_width = int(screen_width * 0.6)
        cart_location = self.get_cart_location(screen_width)
        if cart_location < view_width // 2:
            # Cart is near the left edge: take the left-most window.
            slice_range = slice(view_width)
        elif cart_location > (screen_width - view_width // 2):
            # Cart is near the right edge: take the right-most window.
            slice_range = slice(-view_width, None)
        else:
            slice_range = slice(cart_location - view_width // 2,
                                cart_location + view_width // 2)
        # Strip off the edges, so that we have a square image centered on a cart
        screen = screen[:, :, slice_range]
        # Convert to float, rescale, convert to torch tensor
        # (this doesn't require a copy)
        screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
        screen = torch.from_numpy(screen)
        # NOTE(review): ``resize`` is called on a CHW tensor with a 2-tuple
        # output shape and no batch dimension is added; confirm the resulting
        # layout is what ``preprocess()`` expects.
        return resize(screen,(80,80))

    def generate_gif(self,name):
        # language=rst
        """
        Writes the recorded frames to ``<name>ATARI_PONG.gif``.

        Each frame is resized to 420x320x3 with nearest-neighbour interpolation
        and cast to ``uint8``. Resized copies go into a new list, so
        ``self.frames`` keeps the observations exactly as they were recorded.

        :param name: Prefix of the output file name; converted with ``str()``.
        """
        resized_frames = [
            resize(frame, (420, 320, 3), preserve_range=True, order=0).astype(np.uint8)
            for frame in self.frames
        ]

        imageio.mimsave(str(name)+"ATARI_PONG.gif", resized_frames, duration=1/30)

    def step(self, a: int) -> Tuple[torch.Tensor, float, bool, Dict[Any, Any]]:
        # language=rst
        """
        Wrapper around the OpenAI ``gym`` environment ``step()`` function.

        :param a: Action to take in the environment.
        :return: Observation, reward, done flag, and information dictionary.
        """
        # Call gym's environment step function.
        self.obs, self.reward, self.done, info = self.env.step(a)
        # Keep the observation as returned by gym for ``generate_gif()``.
        self.frames.append(self.obs)

        if self.clip_rewards:
            self.reward = np.sign(self.reward)

        self.preprocess()

        # Expose the pre-processed observation (before history differencing
        # and encoding) in the info dictionary for debugging and display.
        info["gym_obs"] = self.obs

        # Store frame of history and encode the inputs.
        if len(self.history) > 0:
            self.update_history()
            self.update_index()
            # Add the delta observation into the info for debugging and display.
            info["delta_obs"] = self.obs

        # The new standard for images is BxTxCxHxW.
        # The gym environment doesn't follow exactly the same protocol.
        #
        # 1D observations will be left as is before the encoder and will become BxTxL.
        # 2D observations are assumed to be mono images will become BxTx1xHxW
        # 3D observations will become BxTxCxHxW
        if self.obs.dim() == 2 and self.add_channel_dim:
            # We want CxHxW, it is currently HxW.
            self.obs = self.obs.unsqueeze(0)

        # The encoder will add time - now Tx...
        if self.encoder is not None:
            self.obs = self.encoder(self.obs)

        # Add the batch - now BxTx...
        self.obs = self.obs.unsqueeze(0)

        self.episode_step_count += 1

        # Return converted observations and other information.
        return self.obs, self.reward, self.done, info

    def reset(self) -> torch.Tensor:
        # language=rst
        """
        Wrapper around the OpenAI ``gym`` environment ``reset()`` function.

        Empties every history slot and zeroes the episode step counter.
        ``self.frames`` and ``self.history_index`` are left as they are.

        :return: Observation from the environment.
        """
        # Call gym's environment reset function.
        self.obs = self.env.reset()
        self.preprocess()

        self.history = {i: torch.Tensor() for i in self.history}

        self.episode_step_count = 0

        return self.obs

    def render(self) -> None:
        # language=rst
        """
        Wrapper around the OpenAI ``gym`` environment ``render()`` function.

        Intentionally a no-op: this wrapper does not forward the call to
        ``self.env.render()``.
        """
        pass

    def close(self) -> None:
        # language=rst
        """
        Wrapper around the OpenAI ``gym`` environment ``close()`` function.
        """
        self.env.close()

    def preprocess(self) -> None:
        # language=rst
        """
        Pre-processing step for an observation from a ``gym`` environment.

        The pipeline is selected by ``self.name``. The default branch discards
        the current ``self.obs`` and starts from ``get_screen()`` instead.
        Every branch ends by converting ``self.obs`` to a float ``torch``
        tensor with ``torch.from_numpy``.
        """
        if self.name == "SpaceInvaders-v0":
            self.obs = subsample(gray_scale(self.obs), 84, 110)
            self.obs = self.obs[26:104, :]
            self.obs = binary_image(self.obs)
        elif self.name == "BreakoutDeterministic-v4":
            self.obs = subsample(gray_scale(crop(self.obs, 34, 194, 0, 160)), 80, 80)
            self.obs = binary_image(self.obs)
        else:  # Default pre-processing step.
            self.obs = self.get_screen()
            self.obs = subsample(crop(self.obs, 34, 194, 0, 160), 80, 80)
            self.obs = binary_image(self.obs)

        self.obs = torch.from_numpy(self.obs).float()

    def update_history(self) -> None:
        # language=rst
        """
        Updates the observations inside history by performing subtraction from most
        recent observation and the sum of previous observations. If there are not enough
        observations to take a difference from, simply store the observation without any
        differencing.
        """
        # Recording initial observations.
        if self.episode_step_count < len(self.history) * self.delta:
            # Store observation based on delta value.
            if self.episode_step_count % self.delta == 0:
                self.history[self.history_index] = self.obs
        else:
            # Take difference between stored frames and current frame.
            temp = torch.clamp(self.obs - sum(self.history.values()), 0, 1)

            # Store observation based on delta value. The undifferenced
            # observation goes into history; the difference becomes ``self.obs``.
            if self.episode_step_count % self.delta == 0:
                self.history[self.history_index] = self.obs

            assert (
                len(self.history) == self.history_length
            ), "History size is out of bounds"
            self.obs = temp

    def update_index(self) -> None:
        # language=rst
        """
        Updates the index to keep track of history. For example: ``history = 4``,
        ``delta = 3`` will produce ``self.history = {1, 4, 7, 10}`` and
        ``self.history_index`` will be updated according to ``self.delta`` and will wrap
        around the history dictionary.
        """
        if self.episode_step_count % self.delta == 0:
            if self.history_index != max(self.history.keys()):
                self.history_index += self.delta
            else:
                # Wrap around the history.
                self.history_index = (self.history_index % max(self.history.keys())) + 1

In [None]:
if False:
    from abc import ABC, abstractmethod
    from typing import Tuple, Dict, Any
    import torchvision.transforms as transforms
    import gym
    import numpy as np
    import torch

    from bindsnet.datasets.preprocess import subsample, gray_scale, binary_image, crop
    from bindsnet.encoding import Encoder, NullEncoder


    class Environment(ABC):
        # language=rst
        """
        Interface that every environment wrapper must implement.

        All five methods are abstract, so the class cannot be instantiated
        directly; a subclass has to override each of them.
        """

        @abstractmethod
        def step(self, a: int) -> Tuple[Any, ...]:
            # language=rst
            """
            Advance the environment by one action.

            :param a: Integer action to take in environment.
            :return: Tuple describing the outcome of the step; its contents are
                defined by the implementing subclass.
            """

        @abstractmethod
        def reset(self) -> None:
            # language=rst
            """
            Put the environment back into its initial state.
            """

        @abstractmethod
        def render(self) -> None:
            # language=rst
            """
            Display the current state of the environment.
            """

        @abstractmethod
        def close(self) -> None:
            # language=rst
            """
            Release the resources held by the environment.
            """

        @abstractmethod
        def preprocess(self) -> None:
            # language=rst
            """
            Pre-process the most recent observation.
            """


    class GymEnvironment(Environment):
        # language=rst
        """
        A wrapper around the OpenAI ``gym`` environments.

        This variant keeps a stack of processed frames in ``self.state`` (built
        with ``frameprocess``) and derives every observation from that stack.
        """

        def __init__(self, name: str, encoder: Encoder = NullEncoder(), **kwargs) -> None:
            # language=rst
            """
            Initializes the environment wrapper. This class makes the
            assumption that the OpenAI ``gym`` environment will provide an image
            of format HxW or CxHxW as an observation (we will add the C
            dimension to HxW tensors) or a 1D observation in which case no
            dimensions will be added.

            :param name: The name of an OpenAI ``gym`` environment.
            :param encoder: Function to encode observations into spike trains.

            Keyword arguments:

            :param float max_prob: Maximum spiking probability.
            :param bool clip_rewards: Whether or not to use ``np.sign`` of rewards.

            :param int history_length: Number of observations to keep track of.
                When omitted, no history is kept and no differencing is done.
            :param int delta: Step size to save observations in history.
            :param bool add_channel_dim: Allows for the adding of the channel dimension in
                2D inputs.
            """
            self.name = name
            # Observations returned by ``env.step()``; cleared by ``reset()``.
            self.frames = []
            self.env = gym.make(name)
            self.action_space = self.env.action_space
            # Upper bound on the number of warm-up steps taken in ``reset()``.
            self.no_op_steps = 10
            # Number of times the first processed frame is repeated in ``reset()``.
            self.agent_history_length = 4
            # Frame stack; replaced in ``reset()`` and rolled forward in ``step()``.
            self.state = np.zeros((84,84,4))
            self.encoder = encoder

            # Keyword arguments.
            self.max_prob = kwargs.get("max_prob", 1.0)
            self.clip_rewards = kwargs.get("clip_rewards", True)

            self.history_length = kwargs.get("history_length", None)
            self.delta = kwargs.get("delta", 1)
            self.add_channel_dim = kwargs.get("add_channel_dim", True)

            if self.history_length is not None and self.delta is not None:
                # One empty slot per tracked observation, keyed 1, 1 + delta, ...
                self.history = {
                    i: torch.Tensor()
                    for i in range(1, self.history_length * self.delta + 1, self.delta)
                }
            else:
                self.history = {}

            self.episode_step_count = 0
            self.history_index = 1

            self.obs = None
            self.reward = None
            # Defined up front so the attribute exists before the first ``step()``.
            self.done = False

            assert (
                0.0 < self.max_prob <= 1.0
            ), "Maximum spiking probability must be in (0, 1]."

        def generate_gif(self,name):
            # language=rst
            """
            Writes the recorded frames to ``<name>ATARI_PONG.gif``.

            Each frame is resized to 420x320x3 with nearest-neighbour
            interpolation and cast to ``uint8``. Resized copies go into a new
            list, so ``self.frames`` keeps the observations as recorded.

            :param name: Prefix of the output file name; converted with ``str()``.
            """
            resized_frames = [
                resize(frame, (420, 320, 3), preserve_range=True, order=0).astype(np.uint8)
                for frame in self.frames
            ]

            imageio.mimsave(str(name)+"ATARI_PONG.gif", resized_frames, duration=1/30)

        def step(self, a: int) -> Tuple[torch.Tensor, float, bool, Dict[Any, Any]]:
            # language=rst
            """
            Wrapper around the OpenAI ``gym`` environment ``step()`` function.

            :param a: Action to take in the environment.
            :return: Observation, reward, done flag, and information dictionary.
            """
            # Call gym's environment step function.
            self.obs, self.reward, self.done, info = self.env.step(a)
            self.frames.append(self.obs)

            # Roll the frame stack: drop the first slice along the last axis
            # and append the newly processed frame.
            processed_new_frame = frameprocess(self.obs)
            self.state = np.append(self.state[:, :, 1:], processed_new_frame, axis=2)

            if self.clip_rewards:
                self.reward = np.sign(self.reward)

            self.preprocess()

            # Expose the pre-processed observation (before history differencing
            # and encoding) in the info dictionary for debugging and display.
            info["gym_obs"] = self.obs

            # Store frame of history and encode the inputs.
            if len(self.history) > 0:
                self.update_history()
                self.update_index()
                # Add the delta observation into the info for debugging and display.
                info["delta_obs"] = self.obs

            # The new standard for images is BxTxCxHxW.
            # The gym environment doesn't follow exactly the same protocol.
            #
            # 1D observations will be left as is before the encoder and will become BxTxL.
            # 2D observations are assumed to be mono images will become BxTx1xHxW
            # 3D observations will become BxTxCxHxW
            if self.obs.dim() == 2 and self.add_channel_dim:
                # We want CxHxW, it is currently HxW.
                self.obs = self.obs.unsqueeze(0)

            # The encoder will add time - now Tx...
            if self.encoder is not None:
                self.obs = self.encoder(self.obs)

            # Add the batch - now BxTx...
            self.obs = self.obs.unsqueeze(0)

            self.episode_step_count += 1

            # Return converted observations and other information.
            return self.obs, self.reward, self.done, info

        def reset(self) -> torch.Tensor:
            # language=rst
            """
            Wrapper around the OpenAI ``gym`` environment ``reset()`` function.

            Clears the recorded frames, empties the history, zeroes the step
            counter, takes between 1 and ``self.no_op_steps`` warm-up steps with
            action ``1`` and rebuilds ``self.state`` from the last warm-up frame.

            :return: Observation from the environment.
            """
            self.frames = []

            # Call gym's environment reset function.
            self.obs = self.env.reset()
            # NOTE(review): ``preprocess()`` reads ``self.state``, which is only
            # rebuilt further down, so the returned observation is derived from
            # the frame stack as it was before this reset -- confirm intended.
            self.preprocess()

            self.history = {i: torch.Tensor() for i in self.history}

            self.episode_step_count = 0

            for _ in range(random.randint(1, self.no_op_steps)):
                frame, _, _, _ = self.env.step(1) # Action 'Fire'
            processed_frame = frameprocess(frame)
            self.state = np.repeat(processed_frame, self.agent_history_length, axis=2)

            return self.obs

        def render(self) -> None:
            # language=rst
            """
            Wrapper around the OpenAI ``gym`` environment ``render()`` function.

            Intentionally a no-op: this wrapper does not forward the call to
            ``self.env.render()``.
            """

        def close(self) -> None:
            # language=rst
            """
            Wrapper around the OpenAI ``gym`` environment ``close()`` function.
            """
            self.env.close()

        def preprocess(self) -> None:
            # language=rst
            """
            Pre-processing step for an observation from a ``gym`` environment.

            Regardless of ``self.name``, the observation is the current frame
            stack ``self.state`` converted with ``transforms.ToTensor()``.
            """
            self.obs = transforms.ToTensor()(self.state)

        def update_history(self) -> None:
            # language=rst
            """
            Updates the observations inside history by performing subtraction from most
            recent observation and the sum of previous observations. If there are not enough
            observations to take a difference from, simply store the observation without any
            differencing.
            """
            # Recording initial observations.
            if self.episode_step_count < len(self.history) * self.delta:
                # Store observation based on delta value.
                if self.episode_step_count % self.delta == 0:
                    self.history[self.history_index] = self.obs
            else:
                # Take difference between stored frames and current frame.
                temp = torch.clamp(self.obs - sum(self.history.values()), 0, 1)

                # Store observation based on delta value. The undifferenced
                # observation goes into history; the difference becomes ``self.obs``.
                if self.episode_step_count % self.delta == 0:
                    self.history[self.history_index] = self.obs

                assert (
                    len(self.history) == self.history_length
                ), "History size is out of bounds"
                self.obs = temp

        def update_index(self) -> None:
            # language=rst
            """
            Updates the index to keep track of history. For example: ``history = 4``,
            ``delta = 3`` will produce ``self.history = {1, 4, 7, 10}`` and
            ``self.history_index`` will be updated according to ``self.delta`` and will wrap
            around the history dictionary.
            """
            if self.episode_step_count % self.delta == 0:
                if self.history_index != max(self.history.keys()):
                    self.history_index += self.delta
                else:
                    # Wrap around the history.
                    self.history_index = (self.history_index % max(self.history.keys())) + 1



In [None]:
from bindsnet.encoding import bernoulli,poisson
#from bindsnet.environment import GymEnvironment
from bindsnet.learning import MSTDP
from bindsnet.network import Network
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection
from bindsnet.pipeline import EnvironmentPipeline
from bindsnet.pipeline.action import select_softmax
from collections import Counter
# Build network.
# Three-layer spiking network (input -> hidden -> output) for the CartPole run.
network = Network(dt=1.0)

# Layers of neurons.
# The input layer holds 80 * 80 units, laid out as [1, 1, 1, 80, 80].
inpt = Input(n=80 * 80, shape=[1, 1, 1, 80, 80], traces=True)
middle = LIFNodes(n=100, traces=True)
# Two output neurons -- presumably one per action; confirm against
# environment.action_space.n.
out = LIFNodes(n=2, refrac=0, traces=True)

# Connections between layers.
# Only the hidden -> output connection is given a learning rule (MSTDP).
inpt_middle = Connection(source=inpt, target=middle, wmin=-1, wmax=1e-1)
middle_out = Connection(
    source=middle,
    target=out,
    wmin=-1,
    wmax=1,
    update_rule=MSTDP,
    nu=1e-1,
    norm=0.5 * middle.n,
)

# Add all layers and connections to the network.
network.add_layer(inpt, name="Input Layer")
network.add_layer(middle, name="Hidden Layer")
network.add_layer(out, name="Output Layer")
network.add_connection(inpt_middle, source="Input Layer", target="Hidden Layer")
network.add_connection(middle_out, source="Hidden Layer", target="Output Layer")

# Load the CartPole environment (uses the GymEnvironment class defined in this
# notebook, not the one shipped with bindsnet).
environment = GymEnvironment("CartPole-v0")
environment.reset()

# Build pipeline from specified components.
# The pipeline reads actions from the layer registered as "Output Layer".
environment_pipeline = EnvironmentPipeline(
    network,
    environment,
    encoding=poisson,
    action_function=select_softmax,
    output="Output Layer",
    time=100,
    history_length=4,
    delta=1,
    plot_interval=1,
    render_interval=1,
)


def run_pipeline(pipeline, episode_count):
    """
    Play ``episode_count`` episodes with ``pipeline`` and collect their returns.

    For every episode the pipeline's state variables are reset, then
    ``env_step()`` / ``step()`` are called until the step result reports the
    episode as done. Element 1 of each step result is taken as the reward and
    element 2 as the done flag. A summary line is printed per episode.

    :param pipeline: Object providing ``reset_state_variables()``,
        ``env_step()`` and ``step(result)``.
    :param episode_count: Number of episodes to play.
    :return: List with the total reward of each episode, truncated to ``int``.
    """
    episode_returns = []
    for i in range(episode_count):
        pipeline.reset_state_variables()
        total_reward = 0
        while True:
            outcome = pipeline.env_step()
            pipeline.step(outcome)

            total_reward += outcome[1]
            if outcome[2]:
                break
        episode_returns.append(int(total_reward))
        print(f"Episode {i} total reward:{total_reward}")
    return episode_returns


# Number of training episodes. The same value labels the saved GIF, so the
# file name always matches the run that produced it.
N_TRAIN_EPISODES = 10

print("Training: ")
tr = run_pipeline(environment_pipeline, episode_count=N_TRAIN_EPISODES)
print(tr)

# Stop MSTDP weight updates before the evaluation episode.
environment_pipeline.network.learning = False

print("Testing: ")
run_pipeline(environment_pipeline, episode_count=1)
# Written after the test episode; the frame list is not cleared between
# episodes, so the GIF contains every frame recorded so far.
environment.generate_gif(f"{N_TRAIN_EPISODES} Training Episodes")

In [None]:
from bindsnet.encoding import bernoulli,poisson
#from bindsnet.environment import GymEnvironment
from bindsnet.learning import MSTDP
from bindsnet.network import Network
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection
from bindsnet.pipeline import EnvironmentPipeline
from bindsnet.pipeline.action import select_softmax
from collections import Counter
# Build network.
# Three-layer spiking network (input -> hidden -> output) for the Breakout run.
network = Network(dt=1.0)

# Layers of neurons.
# n equals the product of the declared shape (4 * 4 * 84 * 84).
# NOTE(review): confirm this shape matches what GymEnvironment.step() emits for
# "BreakoutDeterministic-v4" after pre-processing and encoding.
inpt = Input(n=84 * 84 * 4 *4 , shape=[4, 1, 1, 4, 84, 84], traces=True)
# Earlier input-layer variants, kept for reference:
#inpt = Input(n=84 * 84 * 4, shape=[4,84,84], traces=True)
#inpt = Input(n=80 * 80, shape=[1,1,4,84,84], traces=True)
middle = LIFNodes(n=100, traces=True)
# Four output neurons -- presumably one per action; confirm against
# environment.action_space.n.
out = LIFNodes(n=4, refrac=0, traces=True)

# Connections between layers.
# Unlike the CartPole cell, both connections use MSTDP and non-negative weights.
inpt_middle = Connection(source=inpt,update_rule=MSTDP, target=middle, wmin=0, wmax=1e-1)
middle_out = Connection(
    source=middle,
    target=out,
    wmin=0,
    wmax=1,
    update_rule=MSTDP,
    nu=1e-1,
    norm=0.5 * middle.n,
)

# Add all layers and connections to the network.
network.add_layer(inpt, name="Input Layer")
network.add_layer(middle, name="Hidden Layer")
network.add_layer(out, name="Output Layer")
network.add_connection(inpt_middle, source="Input Layer", target="Hidden Layer")
network.add_connection(middle_out, source="Hidden Layer", target="Output Layer")

# Load the Breakout environment (uses the GymEnvironment class defined in this
# notebook, not the one shipped with bindsnet).
environment = GymEnvironment("BreakoutDeterministic-v4")
environment.reset()

# Build pipeline from specified components.
# The pipeline reads actions from the layer registered as "Output Layer".
environment_pipeline = EnvironmentPipeline(
    network,
    environment,
    encoding=poisson,
    action_function=select_softmax,
    output="Output Layer",
    time=100,
    history_length=4,
    delta=1,
    plot_interval=1,
    render_interval=1,
)


def run_pipeline(pipeline, episode_count):
    """
    Play ``episode_count`` episodes with ``pipeline`` and collect their returns.

    For every episode the pipeline's state variables are reset, then
    ``env_step()`` / ``step()`` are called until the step result reports the
    episode as done. Element 1 of each step result is taken as the reward and
    element 2 as the done flag. A summary line is printed per episode.

    :param pipeline: Object providing ``reset_state_variables()``,
        ``env_step()`` and ``step(result)``.
    :param episode_count: Number of episodes to play.
    :return: List with the total reward of each episode, truncated to ``int``.
    """
    episode_returns = []
    for i in range(episode_count):
        pipeline.reset_state_variables()
        total_reward = 0
        while True:
            outcome = pipeline.env_step()
            pipeline.step(outcome)
            total_reward += outcome[1]
            if outcome[2]:
                break
        episode_returns.append(int(total_reward))
        print(f"Episode {i} total reward:{total_reward}")
    return episode_returns


# Number of training episodes. The same value labels the saved GIF, so the
# file name always matches the run that produced it.
N_TRAIN_EPISODES = 10

print("Training: ")
tr = run_pipeline(environment_pipeline, episode_count=N_TRAIN_EPISODES)
print(tr)

# Stop MSTDP weight updates before the evaluation episode.
environment_pipeline.network.learning = False

print("Testing: ")
run_pipeline(environment_pipeline, episode_count=1)
# Written after the test episode; the frame list is not cleared between
# episodes, so the GIF contains every frame recorded so far.
environment.generate_gif(f"{N_TRAIN_EPISODES} Training Episodes")

In [None]:
# Histogram of the training returns: one bar per distinct episode total.
# The counts are computed once and reused for both axes.
reward_counts = Counter(tr)
plt.bar(list(reward_counts.keys()), list(reward_counts.values()))
plt.xlabel("Total reward per episode")
plt.ylabel("Number of episodes")
plt.title("Distribution of training episode rewards")
# Seconds elapsed since `ti` was taken in the first cell.
print(time.time()-ti)

In [None]:
# Total reward of each training episode, plotted against the episode index.
plt.plot(list(range(len(tr))), tr)

In [None]:
# Inspect the assembled network: its layers, its connections, and the weight
# tensor of every connection.
print(network.layers)
print(network.connections)
for connection in network.connections.values():
    print(connection.w)
# Last expression of the cell, so the notebook displays the number of actions.
environment.action_space.n

In [None]:
# Scratch check of in-place binarisation: every non-zero entry is set to 1.
# The tensor starts out all zeros (float64), so nothing changes here.
x = torch.zeros(4, 80, 80, dtype=torch.float64)

x.masked_fill_(x.ne(0), 1)