In [None]:
# Install the needed libraries (Python 3.8)
%pip install gym -qqq
%pip install pygame -qqq
%pip install gym[accept-rom-license] -qqq
%pip install spikingjelly==0.0.0.0.4 -qqq
%pip install cupy -qqq
%pip install ale-py -qqq
%pip install tensorflow -qqq

# Makes heavy use of the SpikingJelly repo, although some of the library's code had to be modified to work with images: by default it was meant to work with states that describe the game.

In [None]:
import gym
import math
import random
import numpy as np
from collections import namedtuple
from itertools import count
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from spikingjelly.clock_driven import neuron, functional
import os

from torch.utils.tensorboard import SummaryWriter

In [None]:
# Name of the game; generate_gif() appends it to the filenames of the GIFs it writes.
game = "Breakout-v0"
import cv2
import time
import imageio
import gym
from tqdm import tqdm
import numpy as np
from skimage.transform import resize
from collections import deque
import random
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import tensorflow.compat.v1 as tf

In [None]:
def generate_gif(frames_for_gif,name):
    """Upscale a sequence of frames and save them as an animated GIF.

    Every frame is resized to 420x320x3 with ``order=0`` interpolation,
    cast to ``uint8`` and written to ``<name><game>.gif`` in the current
    working directory, where ``game`` is the module-level game name.

    Args:
        frames_for_gif: sized sequence of image arrays. It is no longer
            modified in place; the resized frames are collected in a new list.
        name: prefix of the output filename.

    Returns:
        None. Nothing is written when ``frames_for_gif`` is empty.
    """
    # An empty sequence would otherwise be passed straight to imageio, which
    # has no frame to encode.
    if len(frames_for_gif) == 0:
        return

    resized_frames = [
        resize(frame, (420, 320, 3), preserve_range=True, order=0).astype(np.uint8)
        for frame in frames_for_gif
    ]

    imageio.mimsave(name+str(game)+".gif", resized_frames, duration=1/30)

In [None]:
def frameprocess(frame,frame_height=84, frame_width=84):
    """Turn an RGB frame into a cropped, downscaled grayscale image.

    The frame is converted to grayscale, cropped to a window 160 rows high
    and 140 columns wide whose top-left corner is at row 34, column 0, and
    then resized to ``(frame_height, frame_width)`` with nearest-neighbour
    sampling.

    Args:
        frame: RGB image accepted by ``tf.image.rgb_to_grayscale``.
        frame_height: height of the returned image.
        frame_width: width of the returned image.

    Returns:
        The resized single-channel image produced by ``tf.image.resize_images``.
    """
    target_size = [frame_height, frame_width]
    gray = tf.image.rgb_to_grayscale(frame)
    cropped = tf.image.crop_to_bounding_box(gray, 34, 0, 160, 140)
    return tf.image.resize_images(cropped,
                                  target_size,
                                  method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

In [None]:
class Atari(object):
    """Wrapper for the environment provided by gym.

    Keeps ``self.state``: a stack of the ``agent_history_length`` most recent
    frames, each preprocessed by ``frameprocess`` and stacked along the last
    axis.

    Args:
        envName: id passed to ``gym.make``.
        no_op_steps: stored on the instance; not used by the wrapper itself.
        agent_history_length: number of preprocessed frames kept in the stack.
    """
    def __init__(self, envName, no_op_steps=10, agent_history_length=4):
        self.env = gym.make(envName)
        self.unwrapped = self.env.unwrapped
        self.state = None
        # Size of the flattened frame stack (84x84 is frameprocess' default
        # output size); follows the configured history length.
        self.totalPixels = 84 * 84 * agent_history_length
        self.last_lives = 0
        self.no_op_steps = no_op_steps
        self.metadata = self.env.metadata
        self.agent_history_length = agent_history_length
        self.spec = self.env.spec
        self.action_space = self.env.action_space
        self.render = self.env.render

    def reset(self,evaluation=False):
        """Reset the environment and fill the stack with the first frame.

        Args:
            evaluation: accepted for compatibility; currently unused.

        Returns:
            The new frame stack (also stored in ``self.state``).
        """
        frame = self.env.reset()
        # Newer gym releases return (observation, info) from reset().
        if isinstance(frame, tuple):
            frame = frame[0]
        self.last_lives = 0
        processed_frame = frameprocess(frame)
        self.state = np.repeat(processed_frame, self.agent_history_length, axis=2)
        return self.state

    def step(self,action):
        """Apply ``action`` and push the resulting frame onto the stack.

        Returns:
            Tuple ``(processed_new_frame, reward, done, new_frame)`` where
            ``new_frame`` is the raw observation returned by the environment.
        """
        result = self.env.step(action)
        if len(result) == 5:
            # Newer gym API: (obs, reward, terminated, truncated, info).
            new_frame, reward, terminated, truncated, _ = result
            done = terminated or truncated
        else:
            new_frame, reward, done, _ = result
        processed_new_frame = frameprocess(new_frame)
        # Drop the oldest frame and append the newest one.
        self.state = np.append(self.state[:, :, 1:], processed_new_frame, axis=2)
        return processed_new_frame, reward, done, new_frame

    def seed(self,seed):
        """Forward ``seed`` to the wrapped environment."""
        self.env.seed(seed)

    def close(self):
        """Close the wrapped environment."""
        self.env.close()

In [None]:
# One replay-buffer record: (state, action, next_state, reward).
Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward'])

To make the SNN output floating-point numbers, the firing threshold of the neurons in the last layer is set to infinity so that they never fire, and their final membrane potential is used to represent the Q-function.

In [None]:
class ReplayMemory(object):
    """Fixed-capacity ring buffer of ``Transition`` records.

    Attributes are ``capacity`` (maximum number of records), ``memory`` (the
    list holding the records) and ``position`` (index the next record is
    written to).
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Store ``Transition(*args)``, overwriting the oldest record once full."""
        write_index = self.position
        still_growing = len(self.memory) < self.capacity
        if still_growing:
            # Reserve the slot first; it is filled on the next line.
            self.memory.append(None)
        self.memory[write_index] = Transition(*args)
        self.position = (write_index + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct records chosen with ``random.sample``."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of records currently stored."""
        return len(self.memory)


class NonSpikingLIFNode(neuron.LIFNode):
    """LIF neuron used as a real-valued readout: it integrates but never fires.

    The membrane potential ``v`` is what the network reports as its output,
    so it must not be altered by spiking. Constructor arguments are those of
    ``neuron.LIFNode``.
    """

    def forward(self, dv: torch.Tensor):
        """Charge the membrane with ``dv`` and return the membrane potential.

        The fire and reset phases are deliberately skipped: with them enabled
        the neuron would spike at the layer's finite threshold and its
        potential would be reset, corrupting the value read out from ``v``.
        """
        self.neuronal_charge(dv)
        return self.v


In [None]:
# Spiking DQN algorithm
class DQSN(nn.Module):
    """Two-layer spiking Q-network.

    The stack is ``Linear -> IFNode -> Linear -> NonSpikingLIFNode`` and is
    stored as ``self.fc``.

    Args:
        input_size: number of input features.
        hidden_size: width of the hidden layer.
        output_size: number of outputs (one per action).
        T: number of simulation steps run per forward call.
    """

    def __init__(self, input_size, hidden_size, output_size, T=16):
        super().__init__()
        layers = [
            nn.Linear(input_size, hidden_size),
            neuron.IFNode(),
            nn.Linear(hidden_size, output_size),
            NonSpikingLIFNode(tau=2.0),
        ]
        self.fc = nn.Sequential(*layers)
        self.T = T

    def forward(self, x):
        """Feed ``x`` to the stack ``T`` times and return the last layer's ``v``.

        The neuron state is not reset here.
        """
        for _ in range(self.T):
            self.fc(x)

        readout = self.fc[-1]
        return readout.v

In [None]:
def train(use_cuda, model_dir, log_dir, env_name, hidden_size, num_episodes, seed):
    """Train a spiking DQN (``DQSN``) on an ``Atari``-wrapped environment.

    Args:
        use_cuda: run on ``cuda`` when true, otherwise on ``cpu``.
        model_dir: directory the checkpoints are written to (created if missing).
        log_dir: TensorBoard log directory.
        env_name: environment id handed to the ``Atari`` wrapper.
        hidden_size: width of the network's hidden layer.
        num_episodes: number of episodes to play.
        seed: seed for ``random``, NumPy and PyTorch.

    Side effects:
        Saves ``saved_net_<hidden_size>_max.pt`` each time an episode ends
        with a new best total reward and ``saved_net_<hidden_size>.pt`` once
        training completes; logs the per-episode reward to TensorBoard and
        prints progress.
    """
    BATCH_SIZE = 128
    GAMMA = 0.999
    EPS_START = 0.9
    EPS_END = 0.05
    EPS_DECAY = 200
    TARGET_UPDATE = 10

    T = 16

    random.seed(seed)
    np.random.seed(seed)

    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

    device = torch.device("cuda" if use_cuda else "cpu")

    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    steps_done = 0

    env = Atari(env_name)
    n_states = env.totalPixels
    n_actions = env.action_space.n

    policy_net = DQSN(n_states, hidden_size, n_actions, T).to(device)
    target_net = DQSN(n_states, hidden_size, n_actions, T).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(policy_net.parameters())
    memory = ReplayMemory(10000)

    def frames_to_tensor(stacked_frames):
        """Flatten the wrapper's frame stack into a (1, n_states) float tensor."""
        flat = np.reshape(stacked_frames, (1, n_states))
        return torch.from_numpy(flat).float().to(device)

    def select_action(state, steps_done):
        """Epsilon-greedy action as a (1, 1) long tensor on ``device``."""
        sample = random.random()
        eps_threshold = EPS_END + (EPS_START - EPS_END) * \
                        math.exp(-1. * steps_done / EPS_DECAY)
        if sample > eps_threshold:
            with torch.no_grad():
                # argmax keeps the result on the network's device with an
                # integer dtype, so it can be concatenated with the random
                # actions and used by gather() later.
                greedy = policy_net(state).argmax(dim=1).view(1, 1)
            # The network is stateful: clear it before its next use.
            functional.reset_net(policy_net)
            return greedy
        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)

    def optimize_model():
        """Run one optimisation step on a sampled batch, if enough data exists."""
        if len(memory) < BATCH_SIZE:
            return

        transitions = memory.sample(BATCH_SIZE)
        batch = Transition(*zip(*transitions))

        non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                      device=device, dtype=torch.bool)

        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        state_action_values = policy_net(state_batch).gather(1, action_batch)

        next_state_values = torch.zeros(BATCH_SIZE, device=device)
        # torch.cat() rejects an empty list, so skip the target network when
        # every sampled transition is terminal.
        if non_final_mask.any():
            non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
            with torch.no_grad():
                next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
            functional.reset_net(target_net)
        expected_state_action_values = (next_state_values * GAMMA) + reward_batch

        loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

        optimizer.zero_grad()
        loss.backward()
        # Clamp every gradient element to [-1, 1].
        nn.utils.clip_grad_value_(policy_net.parameters(), 1)
        optimizer.step()
        functional.reset_net(policy_net)

    max_reward = 0
    max_pt_path = os.path.join(model_dir, f'saved_net_{hidden_size}_max.pt')
    pt_path = os.path.join(model_dir, f'saved_net_{hidden_size}.pt')

    writer = SummaryWriter(log_dir=log_dir)
    try:
        for i_episode in range(num_episodes):
            # Start from the frame stack produced by the reset rather than
            # from an all-zero placeholder.
            env.reset()
            state = frames_to_tensor(env.state)

            total_reward = 0

            for _ in count():
                action = select_action(state, steps_done)
                steps_done += 1

                _, reward, done, _ = env.step(action.item())
                total_reward += reward
                reward = torch.tensor([reward], device=device, dtype=torch.float)

                # A terminal transition has no successor state.
                next_state = None if done else frames_to_tensor(env.state)

                memory.push(state, action, next_state, reward)
                state = next_state

                if done and total_reward > max_reward:
                    max_reward = total_reward
                    torch.save(policy_net.state_dict(), max_pt_path)
                    print(f'max_reward={max_reward}, save models')

                optimize_model()

                if done:
                    print(f'Episode: {i_episode}, Reward: {total_reward}')
                    writer.add_scalar('Spiking-DQN-state-' + env_name + '/Reward', total_reward, i_episode)
                    break

            if i_episode % TARGET_UPDATE == 0:
                target_net.load_state_dict(policy_net.state_dict())

        print('complete')
        torch.save(policy_net.state_dict(), pt_path)
        print('state_dict path is', pt_path)
    finally:
        # Flush and release the TensorBoard writer even if training fails.
        writer.close()

In [None]:
def play(use_cuda, pt_path, env_name, hidden_size, played_frames=60, save_fig_num=0, fig_dir=None, figsize=(12, 6), firing_rates_plot_type='bar', heatmap_shape=None):
    """Run a trained ``DQSN`` greedily and visualise it frame by frame.

    Each frame shows three panels: the readout voltages per action, the
    firing rates of the hidden IF neurons, and the game screen. When the run
    ends, the raw observations and the rendered screens are written as GIFs
    via ``generate_gif``.

    Args:
        use_cuda: run on ``cuda`` when true, otherwise on ``cpu``.
        pt_path: path of the saved ``state_dict`` to load.
        env_name: environment id handed to the ``Atari`` wrapper.
        hidden_size: hidden width the checkpoint was trained with.
        played_frames: minimum number of frames to play; the loop stops at
            the first frame that is ``done`` once this count is reached.
        save_fig_num: save the figure of the first ``save_fig_num`` frames.
        fig_dir: directory for the saved figures.
        figsize: matplotlib figure size.
        firing_rates_plot_type: ``'bar'`` or ``'heatmap'``.
        heatmap_shape: shape the firing rates are reshaped to for the heatmap.
    """
    T = 16
    FRAMESTORE = []
    SCREENSTORE = []
    plt.rcParams['figure.figsize'] = figsize
    plt.ion()
    device = torch.device("cuda" if use_cuda else "cpu")

    env = Atari(env_name)
    n_states = env.totalPixels
    n_actions = env.action_space.n
    action_ticks = np.arange(n_actions)
    # xticks() needs a real sequence of labels, not a generator.
    action_labels = list(env.unwrapped.get_action_meanings())

    policy_net = DQSN(n_states, hidden_size, n_actions, T).to(device)
    policy_net.load_state_dict(torch.load(pt_path, map_location=device))

    def frames_to_tensor(stacked_frames):
        """Flatten the wrapper's frame stack into a (1, n_states) float tensor."""
        flat = np.reshape(stacked_frames, (1, n_states))
        return torch.from_numpy(flat).float().to(device)

    env.reset()
    state = frames_to_tensor(env.state)

    # Arm the spike monitor before the first forward pass so the pass is
    # recorded.
    hidden_neurons = policy_net.fc[1]
    hidden_neurons.set_monitor()

    total_reward = 0
    game_over = False

    with torch.no_grad():
        for i in count():
            # Start every frame from an empty figure instead of stacking new
            # axes on top of the previous frame's.
            plt.clf()

            LIF_v = policy_net(state)
            action = int(torch.argmax(LIF_v).item())
            voltages = LIF_v.squeeze(0).cpu().numpy()

            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (1, 0), colspan=3)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (1, 0))

            plt.xticks(action_ticks, action_labels)
            plt.ylabel('Voltage')
            plt.title('Voltage of LIF neurons at last time step')
            delta_lim = (voltages.max() - voltages.min()) * 0.5
            plt.ylim(voltages.min() - delta_lim, voltages.max() + delta_lim)
            plt.yticks([])
            for k in range(n_actions):
                plt.text(k, voltages[k], str(round(float(voltages[k]), 2)), ha='center')

            # Highlight the chosen action, whatever the number of actions.
            bar_colors = ['r' if k == action else 'gray' for k in range(n_actions)]
            plt.bar(action_ticks, voltages, color=bar_colors, width=0.5)

            if voltages.min() - delta_lim < 0:
                plt.axhline(0, color='black', linewidth=0.1)

            IF_spikes = np.asarray(hidden_neurons.monitor['s'])
            firing_rates = IF_spikes.mean(axis=0).squeeze()

            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (0, 4), rowspan=2, colspan=5)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (0, 1), rowspan=2, colspan=2)

            plt.title('Firing rates of IF neurons')

            if firing_rates_plot_type == 'bar':
                plt.xlabel('Neuron index')
                plt.ylabel('Firing rate')
                plt.xlim(0, firing_rates.size)
                plt.ylim(0, 1.01)
                plt.bar(np.arange(firing_rates.size), firing_rates, width=0.5)

            elif firing_rates_plot_type == 'heatmap':
                heatmap = plt.imshow(firing_rates.reshape(heatmap_shape), vmin=0, vmax=1, cmap='ocean')
                plt.gca().invert_yaxis()
                cbar = heatmap.figure.colorbar(heatmap)
                cbar.ax.set_ylabel('Magnitude', rotation=90, va='top')

            # The environment expects a plain integer action.
            _, reward, done, obs = env.step(action)
            FRAMESTORE.append(obs)

            # Accumulate the score only until the game first ends.
            if not game_over:
                total_reward += reward

            subtitle = ""
            if done:
                game_over = True
                subtitle = f'Game over, Score={total_reward}'
            plt.suptitle(subtitle)

            state = frames_to_tensor(env.state)
            screen = env.render(mode='rgb_array').copy()
            screen[200, :, :] = 0
            SCREENSTORE.append(screen)

            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (0, 0), colspan=3)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (0, 0))

            plt.xticks([])
            plt.yticks([])
            plt.title('Game screen')
            plt.imshow(screen, interpolation='bicubic')
            plt.pause(0.001)

            if i < save_fig_num:
                plt.savefig(os.path.join(fig_dir, f'{i}.png'))

            # The network is stateful: clear the neuron state and re-arm the
            # spike monitor with an empty buffer before the next frame.
            functional.reset_net(policy_net)
            hidden_neurons.set_monitor()

            if done and i >= played_frames:
                # Close the underlying gym environment directly.
                env.env.close()
                plt.close()
                break
    generate_gif(FRAMESTORE,"GAMEFOOTAGE")
    generate_gif(SCREENSTORE,"SCREENFOOTAGE")

In [None]:
# Train on CPU; the seed is drawn at random on every run.
train(
    use_cuda=False,
    model_dir='./',
    log_dir='./log',
    env_name='Breakout-v0',
    hidden_size=256,
    num_episodes=100000,
    seed=random.randint(1, 10000),
)

In [None]:
# Replay the best checkpoint saved during training.
play(
    use_cuda=False,
    pt_path='./saved_net_256_max.pt',
    env_name='Breakout-v0',
    hidden_size=256,
    played_frames=300,
)

# Credits to
## https://github.com/fangwei123456/spikingjelly
## https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html