In [None]:
#%pip install opencv-python
#%pip install torchvision
#%pip install scikit-image
#%pip install gym[Atari]
#%pip install gym[accept-rom-license]
#%pip install tensorflow
#%pip install pyglet

In [None]:
import gym
from gym import wrappers
import tensorflow.compat.v1 as tf
from skimage.transform import resize
import matplotlib.pyplot as plt
from IPython import display
import random
import time
import imageio
st = time.time()
import cv2
import pyNN.spiNNaker as p
p.setup(timestep=1)
import math
import torchvision
import numpy as np 
import torchvision.transforms as transforms

In [None]:
# Training switch: when True, the network-construction cell creates the
# per-action reward and punishment spike sources.
TRAIN = True

In [None]:
def inputToSpikeRateArray(frame):
    """Convert a frame into the ``rate`` value given to ``SpikeSourcePoisson``.

    Parameters
    ----------
    frame : sequence or array-like
        Pixel data; it must support ``len()``.

    Returns
    -------
    int or numpy.ndarray
        The scalar ``0`` when ``frame`` has more than one entry along its
        first axis, otherwise a 1-D array holding every value of ``frame``.
    """
    # NOTE(review): because of this guard every multi-row frame is mapped to a
    # rate of 0, so only single-row input ever produces per-element rates.
    # Kept as-is because the intent is not visible here -- confirm.
    if len(frame) > 1:
        return 0
    # The debug print of the flattened array was removed: it ran on every call
    # and flooded the notebook output.
    return np.asarray(frame).flatten()

In [None]:
def frameprocess(frame,frame_height=84, frame_width=65):
    """Reduce a raw RGB frame to a small single-channel image.

    The frame is converted to grayscale, cropped with
    ``tf.image.crop_to_bounding_box(..., 34, 0, 160, 140)`` and finally
    resized to ``frame_height`` x ``frame_width`` using nearest-neighbour
    sampling.

    Args:
        frame: RGB image accepted by ``tf.image.rgb_to_grayscale``.
        frame_height: Height of the returned image (default 84).
        frame_width: Width of the returned image (default 65).

    Returns:
        Whatever ``tf.image.resize_images`` returns for the cropped
        grayscale frame.
    """
    target_size = [frame_height, frame_width]
    gray = tf.image.rgb_to_grayscale(frame)
    cropped = tf.image.crop_to_bounding_box(gray, 34, 0, 160, 140)
    return tf.image.resize_images(
        cropped,
        target_size,
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )

In [None]:
class Atari(object):
    """Stateful wrapper around a gym Atari environment.

    ``self.state`` holds the stack of the most recent processed frames and
    ``self.last_lives`` the life counter reported by the previous step.
    """
    def __init__(self, envName, no_op_steps=10, agent_history_length=4):
        # Hyper-parameters first, then the environment itself.
        self.no_op_steps = no_op_steps
        self.agent_history_length = agent_history_length
        self.last_lives = 0
        self.state = None  # filled in by reset()
        self.env = gym.make(envName)

    def reset(self,evaluation=False):
        """
        Restart the environment and build the initial state by repeating the
        first processed frame ``agent_history_length`` times along axis 2.

        When ``evaluation`` is true, a random number (1..``no_op_steps``) of
        action-1 steps is played first and the last of those frames is used.

        Always returns True (the "life lost" flag for the start of a game).
        """
        frame = self.env.reset()
        self.last_lives = 0
        if evaluation:
            warmup_steps = random.randint(1, self.no_op_steps)
            for _ in range(warmup_steps):
                # Action 'Fire'
                frame, _reward, _terminal, _info = self.env.step(1)
        first_processed = frameprocess(frame)
        self.state = np.repeat(first_processed, self.agent_history_length, axis=2)

        # Reported as True so that the agent starts with a 'FIRE' action
        # when evaluating.
        return True

    def step(self,action):
        """
        Apply ``action`` to the environment and roll the frame stack forward.

        Returns a 5-tuple: processed new frame, reward, terminal flag,
        terminal-or-life-lost flag, raw new frame.
        """
        new_frame, reward, terminal, info = self.env.step(action)

        lives_now = info['lives']
        # A dropped life counter counts as terminal for this flag only.
        terminal_life_lost = True if lives_now < self.last_lives else terminal
        self.last_lives = lives_now

        processed_new_frame = frameprocess(new_frame)
        # Drop the oldest channel and append the newest one.
        older_frames = self.state[:, :, 1:]
        self.state = np.append(older_frames, processed_new_frame, axis=2)

        return processed_new_frame, reward, terminal, terminal_life_lost, new_frame

In [None]:
def make_a_gif(frames_for_gif, path="ATARI_PONG.gif", fps=30, size=(420, 320, 3)):
    """Upscale a sequence of frames and write them out as an animated GIF.

    Args:
        frames_for_gif: Iterable of image arrays accepted by
            ``skimage.transform.resize``. It is left unmodified (the previous
            version overwrote each entry with its resized copy).
        path: Output file name; defaults to ``"ATARI_PONG.gif"``.
        fps: Frames per second; ``1 / fps`` is passed to imageio as
            ``duration``.
        size: Output shape passed to ``resize`` for every frame.

    Returns:
        None. The GIF is written to ``path``.
    """
    # Same resize arguments as before (order=0, preserve_range=True), but the
    # results go into a new list instead of back into the caller's sequence.
    upscaled = [
        resize(frame, size, preserve_range=True, order=0).astype(np.uint8)
        for frame in frames_for_gif
    ]

    # NOTE(review): whether imageio reads `duration` as seconds or
    # milliseconds depends on the installed imageio version/plugin -- confirm
    # the resulting playback speed.
    imageio.mimsave(path, upscaled, duration=1 / fps)

In [None]:
def selectOutput(records,previous_spikes):
    """Pick the action whose spike record grew the most since the last call.

    Args:
        records: One sized spike record per action; ``len()`` of each is its
            cumulative spike count.
        previous_spikes: Counts returned by the previous call, or ``[]`` on
            the first call (treated as all zeros).

    Returns:
        ``(choice, counts)``: ``choice`` is the index with the largest strictly
        positive increase (the first such index on ties, 0 if nothing
        increased), and ``counts`` is the list of current record lengths to
        pass back in next time.
    """
    if previous_spikes == []:
        previous_spikes = [0] * len(records)

    spike_totals = [len(train) for train in records]

    winner = 0
    top_gain = 0
    for index, total in enumerate(spike_totals):
        gain = total - previous_spikes[index]
        if gain > top_gain:
            winner, top_gain = index, gain
    return winner, spike_totals

In [None]:
# Create the Pong environment and build its first frame stack right away:
# Atari.__init__ leaves `state` as None, and the following cells read
# `env.state` to size and drive the input layer.
env = Atari("PongDeterministic-v4")
env.reset()

In [None]:
# Simulation constants.
# NOTE(review): `timestep`, `duration` and `rewarded_syn_weight` are not
# referenced by any of the cells shown here -- confirm they are still needed.
timestep = 1.0
duration = 3000

# Main parameters from Izhikevich 2007 STDP paper
#t_pre = [1500, 2400]  # Pre-synaptic neuron times
#t_post = [1502]  # Post-synaptic neuron stimuli time
#t_dopamine = [1600]  # Dopaminergic neuron spike times
tau_c = 1000  # Eligibility trace decay time constant.
tau_d = 200  # Dopamine trace decay time constant.
# One dopamine step size per action (env.env.action_space.n entries each); these
# are the `weight` values of the reward / punishment Neuromodulation projections.
DA_concentration_reward = [0.1]*env.env.action_space.n  # Dopamine trace step increase size
DA_concentration_punishment = [0.1]*env.env.action_space.n
# Initial weight
rewarded_syn_weight = 0.0

In [None]:
# Number of values in the current state stack; the same expression sizes the
# input layer in the network cell. While env.state is still None (the value set
# in Atari.__init__, i.e. before env.reset() has run) np.array(None) is a 0-d
# array, so this evaluates to 1.
np.array((env.state)).size

In [None]:
# Neuromodulation constants used by the reward / punishment projections below.
tau_c = 1000  # Eligibility trace decay time constant.
tau_d = 200  # Dopamine trace decay time constant.
DA_concentration = 0.1  # Dopamine trace step increase size

# The input layer is sized and driven from env.state, which Atari.__init__
# leaves as None; make sure a first frame stack exists before reading it
# (otherwise the population gets size 1 and the rate helper fails on None).
if env.state is None:
    env.reset()

n_actions = env.env.action_space.n

##### INPUT LAYER #####
# One Poisson source per value of the stacked state.
inputLayer = p.Population(np.array((env.state)).size,p.SpikeSourcePoisson(rate=inputToSpikeRateArray(env.state)))
inputLayer.record(["spikes"])
##### Training INPUT ######
# One single-neuron reward source and one punishment source per action.
if TRAIN:
    rewardLayer = [p.Population(1,p.SpikeSourcePoisson(rate=1)) for i in range(n_actions)]
    punishmentLayer = [p.Population(1,p.SpikeSourcePoisson(rate=1)) for i in range(n_actions)]

#####   STDP    #####
timing_rule = p.SpikePairRule(tau_plus=0.1, tau_minus=0.1, A_plus=0.1, A_minus=0.1)
weight_rule = p.AdditiveWeightDependence(w_max=10.0, w_min=0.01)
stdp_model_excitatory = p.STDPMechanism(timing_dependence=timing_rule, weight_dependence=weight_rule, weight=5)
##### NMSTDP ######


#### SECOND LAYER ####
# One population of 100 IF_curr_exp neurons per action.
pop=[p.Population(int(100),p.IF_curr_exp()) for action in range(n_actions)]
for action_pop in pop:
    action_pop.record(["spikes"])

#### Projections ####
# Plastic all-to-all input -> action projections.
projections01= [p.Projection(inputLayer,m,p.AllToAllConnector(), synapse_type=stdp_model_excitatory) for m in pop]
rewardproj=[]
punishmentproj = []
# The reward / punishment sources only exist when TRAIN is set, so the
# neuromodulatory projections are built under the same condition.
if TRAIN:
    for i in range(n_actions):
        rewardproj.append(p.Projection(
            rewardLayer[i], pop[i],
            p.AllToAllConnector(),
            synapse_type=p.extra_models.Neuromodulation(
                weight=DA_concentration_reward[i], tau_c=tau_c, tau_d=tau_d, w_max=20.0),
            receptor_type='reward', label='reward synapses'))

        # Previously this projection was also created with
        # receptor_type='reward' and labelled 'reward synapses', so the
        # punishment sources acted as a second reward signal.
        punishmentproj.append(p.Projection(
            punishmentLayer[i], pop[i],
            p.AllToAllConnector(),
            synapse_type=p.extra_models.Neuromodulation(
                weight=DA_concentration_punishment[i], tau_c=tau_c, tau_d=tau_d, w_max=20.0),
            receptor_type='punishment', label='punishment synapses'))


In [None]:
# Play one game: the action populations' spike counts pick the action, the
# action is repeated for FRAME_SKIP emulator steps, the new state is fed back
# into the input layer, and the simulation is advanced.
FRAME_SKIP = 10  # emulator steps taken per selected action

env = Atari("PongDeterministic-v4")
env.reset()
frames=[]
total_reward=0
done = False
count=0
previous_spikes=[]
rwrd_running_tot=[]
while not done:
    env.env.render()
    count+=1
    print(count)

    # Only the first neuron's spike train of each action population is counted.
    action,previous_spikes = selectOutput([m.get_data("spikes").segments[0].spiketrains[0] for m in pop],previous_spikes)

    # Sum the reward over the repeated steps (previously only the last step's
    # reward survived) and stop stepping as soon as the episode has ended.
    reward = 0
    for _ in range(FRAME_SKIP):
        # env.step returns: processed_new_frame, reward, terminal, terminal_life_lost, new_frame
        observation, step_reward, done, info, new = env.step(action)
        reward += step_reward
        if done:
            break

    # Count each reward once (it used to be added twice per iteration).
    total_reward += reward
    inputLayer.set(rate=inputToSpikeRateArray(env.state))

    # NOTE(review): these lists are only read when the projections are built,
    # so updating them here does not appear to reach the already-created
    # Neuromodulation synapses -- confirm.
    if reward > 0:#REWARD SIGNAL CHANGES WHEN REWARD CHANGES
        DA_concentration_reward[action] *= (reward+1)
    elif reward < 0:
        # The old factor 1/abs(reward+1) divided by zero for reward == -1.
        # NOTE(review): 1/(abs(reward)+1) keeps the "shrink" form; confirm this
        # is the intended scaling.
        DA_concentration_punishment[action] *= 1/(abs(reward)+1)

    p.run(1)

# One entry per finished game.
rwrd_running_tot.append(total_reward)
print("GAME REWARD",total_reward)

In [None]:
import matplotlib.pyplot as plt
# Plot each recorded total against its position in the list and save the figure.
plt.plot(range(len(rwrd_running_tot)), rwrd_running_tot)
plt.savefig('plot.png', dpi=300, bbox_inches='tight')

In [None]:
# Show the recorded reward totals (appended after the game loop finishes).
print(rwrd_running_tot)

In [None]:
# Wall-clock seconds elapsed since `st` was captured in the imports cell.
print(time.time()-st)