In [10]:
from collections import deque, namedtuple
import os
from snnUtils import *
from IPython.display import clear_output

from tqdm import tqdm
import pandas as pd
import random, imageio, time, copy
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt

import torch.nn as nn
import torch
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

In [11]:
# Define the super parameters
projectName = "snnRL"

# Save/Get weights from presistent storage. Pass empty string for not saving. 
# Pass derive for using google derive (If code is running in colab). If local, 
# pass the location of your desire
savePath = "./Data"
backUpNetworks = False 
saveLen = 30 # Number of latest checkpoints to save

# Handle save path
if savePath != None:
    if savePath == "derive":
        # Mount gdrive if we want to interact with cloab
        from google.colab import drive
        drive.mount('/content/gdrive')
        savePath = "gdrive/MyDrive/Colab Notebooks/Data/"
    backUpNetworks = True

# Making the environment
env = gym.make("LunarLander-v3") # Use render_mode = "human" to render each episode
state, info = env.reset() # Get a sample state of the environment
stateSize = env.observation_space.shape # Number of variables to define current step
nActions = env.action_space.n # Number of actions
nObs = len(state) # Number of features


# Set pytorch parameters: The device (CPU or GPU) and data types
__device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
__dtype = torch.float

In [None]:
class qNetwork_SNN(nn.Module):
    def __init__(self, inputSize, L1Size, L2Size, L3Size, L4Size, outputSize, **kwargs):
        super().__init__()

        # Model super parameters
        self.beta = kwargs["beta"]
        self.tSteps = kwargs["tSteps"]

        # Defining the layers
        self.layer1 = nn.Linear(inputSize, L1Size)
        self.L1LIF = snn.Leaky(beta = self.beta)
        self.layer2 = nn.Linear(L1Size, L2Size)
        self.L2LIF = snn.Leaky(beta = self.beta)
        self.layer3 = nn.Linear(L2Size, L3Size)
        self.L3LIF = snn.Leaky(beta = self.beta)
        self.layer4 = nn.Linear(L3Size, L4Size)
        self.L4LIF = snn.Leaky(beta = self.beta)
        self.output = nn.Linear(L4Size, outputSize)
        self.outputLIF = snn.Leaky(beta = self.beta)


    def forward(self, x):

        # Set initial potentials to be zero
        potential1 = self.L1LIF.reset_mem()
        potential2 = self.L2LIF.reset_mem()
        potential3 = self.L3LIF.reset_mem()
        potential4 = self.L4LIF.reset_mem()
        potential5 = self.outputLIF.reset_mem()

        # Save the state of the output layer
        outSpikes = []
        outPotentials = []

        # Iterate through time steps
        for t in range(self.tSteps):
            # First layer
            current1 = self.layer1(x)
            spk1, potential1 = self.L1LIF(current1, potential1)

            # Second layer
            current2 = self.layer2(spk1)
            spk2, potential2 = self.L2LIF(current2, potential2)

            # Third layer
            current3 = self.layer3(spk2)
            spk3, potential3 = self.L3LIF(current3, potential3)

            # Fourth layer
            current4 = self.layer4(spk3)
            spk4, potential4 = self.L4LIF(current4, potential4)

            #Output
            current5 = self.output(spk4)
            spk5, potential5 = self.outputLIF(current5, potential5)

            # Save output
            outSpikes.append(spk5)
            outPotentials.append(potential5)

        return torch.stack(outSpikes, dim = 0).sum(dim = 0)

# Model parameters
nL1, nL2, nL3, nL4 = 128, 64, 32, 16
learningRate = .00001
timeSteps = 150
snnBeta = .95
eDecay = 0.999
miniBatchSize = 1000 # The length of minibatch that is used for training
gamma = .995 # The discount factor
extraInfo = ""
modelDetails = f"{nL1}_{nL2}_{nL3}_{nL4}_{learningRate}_{timeSteps}_{snnBeta}_{eDecay}_{miniBatchSize}_{gamma}_{extraInfo}"

# Make the model objects
qNetwork_model = qNetwork_SNN(stateSize[0], nL1, nL2, nL3, nL4, nActions, beta = snnBeta, tSteps = timeSteps).to(__device, dtype = __dtype)
targetQNetwork_model = qNetwork_SNN(stateSize[0], nL1, nL2, nL3, nL4, nActions, beta = snnBeta, tSteps = timeSteps).to(__device, dtype = __dtype)

# Two models should have identical weights initially
targetQNetwork_model.load_state_dict(qNetwork_model.state_dict())

# TODO: Add gradient clipping to the optimizer for avoiding exploding gradients
# Suitable optimizer for gradient descent
optimizer_main = torch.optim.Adam(qNetwork_model.parameters(), lr=learningRate)
optimizer_target = torch.optim.Adam(targetQNetwork_model.parameters(), lr=learningRate)

# Starting episode and ebsilon
startEpisode = 0
startEbsilon = None
lstHistory = None

# Making the memory buffer object
memorySize = 100_000 # The length of the entire memory
mem = ReplayMemory(memorySize, __dtype, __device)

# If given access to drive, try to load the latest saved weights
qNetworkSaveHistory = deque(maxlen = saveLen)
targetQNetworkSaveHistory = deque(maxlen = saveLen)
if backUpNetworks:
    if os.path.isdir(savePath):
        _lst = os.listdir(savePath)
        for _file in _lst:
            if f"{projectName}_{modelDetails}.pth" == _file:
                qNetworkSaveHistory = torch.load(os.path.join(savePath, _file))
                qNetworkSaveHistory = qNetworkSaveHistory if isinstance(qNetworkSaveHistory, list) else [qNetworkSaveHistory]
                _chekcPoint = qNetworkSaveHistory[0] # Take the most recent chekcpoint

                # Load Q-Network
                qNetwork_model.load_state_dict(_chekcPoint["qNetwork_state_dict"]) # Model weights
                optimizer_main.load_state_dict(_chekcPoint["qNetwork_optimizer_state_dict"]) # Optimizer

                # Load target Q-Network
                targetQNetwork_model.load_state_dict(_chekcPoint["targetQNetwork_state_dict"]) # Model weights
                
                # Load process parameters
                startEpisode = int(_chekcPoint["episode"]) # Starting episode number
                startEbsilon = float(_chekcPoint["hyperparameters"]["ebsilon"]) # Starting ebsilon
                lstHistory = _chekcPoint["train_history"]
                eDecay = _chekcPoint["hyperparameters"]["eDecay"] if "eDecay" in _chekcPoint["hyperparameters"].keys() else None

                if "experiences" in _chekcPoint.keys():
                    mem.loadExperiences(
                        _chekcPoint["experiences"]["state"],
                        _chekcPoint["experiences"]["action"],
                        _chekcPoint["experiences"]["reward"],
                        _chekcPoint["experiences"]["nextState"],
                        _chekcPoint["experiences"]["done"],
                    )

                # Backup the current file to avoide data loss in future read/writes (if reading is successful)
                import shutil
                shutil.copyfile(os.path.join(savePath, _file), os.path.join(savePath, _file.replace(".pth", "_Backup.pth")))
                print(f"Loaded network weights for episode {startEpisode}")
    else:
        print("Save path doesn't exist. Making it.")
        os.makedirs(savePath)

beginning_qNetwork = [qNetwork_model.layer1.weight, qNetwork_model.layer2.weight, qNetwork_model.output.weight]
beginning_targeQNetwork = [targetQNetwork_model.layer1.weight, targetQNetwork_model.layer2.weight, targetQNetwork_model.output.weight]

  qNetworkSaveHistory = torch.load(os.path.join(savePath, _file))


NameError: name '_chekcPoint' is not defined

In [None]:
print(f"Device is: {__device}")

# Start the timer
tstart = time.time()

# The experience of the agent is saved as a named tuple containing various variables
agentExp = namedtuple("exp", ["state", "action", "reward", "nextState", "done"])

# Parameters
nEpisodes = 6000 # Number of learning episodes
maxNumTimeSteps = 1000 # The number of time step in each episode
ebsilon = 1 if startEbsilon == None else startEbsilon # The starting  value of ebsilon
ebsilonEnd   = .1 # The finishing value of ebsilon
eDecay = eDecay # The rate at which ebsilon decays
numUpdateTS = 4 # Frequency of time steps to update the NNs
numP_Average = 100 # The number of previous episodes for calculating the average episode reward

# Variables for saving the required data for later analysis
episodePointHist = [] # For saving each episode's point for later demonstration
episodeTimeHist = [] # For saving the time it took for episode to end
actionString = "" # A string containing consecutive actions taken in an episode (dellimited by comma, i.e. 1,2,4,2,1 etc.)
episodeHistDf = None
lstHistory = [] if lstHistory == None else lstHistory
initialCond = None # initial condition (state) of the episode
epPointAvg = -999999 if len(lstHistory) == 0 else pd.DataFrame(lstHistory).iloc[-numP_Average:]["points"].mean()
latestChekpoint = 0


for episode in range(startEpisode, nEpisodes):
    initialSeed = random.randint(1,1_000_000_000) # The random seed that determines the episode's I.C.
    state, info = env.reset(seed = initialSeed)
    points = 0
    actionString = ""
    initialCond = state

    tempTime = time.time()
    _lastPrinttime = tempTime # For printing the training progress 
    for t in range(maxNumTimeSteps):

        qValueForActions = qNetwork_model(torch.tensor(state, device = __device, dtype = __dtype))

        # use ebsilon-Greedy algorithm to take the new step
        action = getAction(qValueForActions, ebsilon)

        # Take a step
        observation, reward, terminated, truncated, info = env.step(action)

        # Store the experience of the current step in an experience deque.
        mem.addNew(
            agentExp(
                state, # Current state
                action,
                reward, # Current state's reward
                observation, # Next state
                True if terminated or truncated else False
            )
        )

        # Check to see if we have to update the networks in the current step
        update = updateNetworks(t, mem, miniBatchSize, numUpdateTS)

        if update:
            initial_weights = {name: param.clone() for name, param in qNetwork_model.named_parameters()}
            # Update the NNs
            experience = mem.sample(miniBatchSize)

            # Update the Q-Network and the target Q-Network
            # Bear in mind that we do not update the target Q-network with direct gradient descent.
            # so there is no optimizer needed for it
            fitQNetworks(experience, gamma, [qNetwork_model, optimizer_main], [targetQNetwork_model, None])

        # Save the necessary data
        points += reward
        state = observation.copy()
        actionString += f"{action},"

        # Print the training status. Print only once each second to avoid jitters.
        if 1 < (time.time() - _lastPrinttime):
            clear_output(wait=True)
            _lastPrinttime = time.time()
            print(f"ElapsedTime: {int(time.time() - tstart): <5}s | Episode: {episode: <5} | Timestep: {t: <5} | The average of the {numP_Average: <5} episodes is: {int(epPointAvg): <5}")
            print(f"Latest chekpoint: {latestChekpoint} | Speed {t/(time.time()-tempTime):.1f} tps | ebsilon: {ebsilon:.3f}")

            # fig= plt.figure(figsize=(12,6))
            # plt.plot(pd.DataFrame(lstHistory)["episode"], pd.DataFrame(lstHistory)["points"])
            # plt.show()

        # Handle episode ending
        if terminated or truncated:
            # Save the episode history in dataframe
            if (episode+1) % 3 == 0:
                # only save every 10 episodes
                lstHistory.append({
                    "episode": episode,
                    "seed": initialSeed,
                    "points": points,
                    "timesteps": t,
                    "duration": time.time() - tempTime
                })
                
            break

    # Saving the current episode's points and time
    episodePointHist.append(points)
    episodeTimeHist.append(time.time()-tempTime)

    # Getting the average of {numP_Average} episodes
    epPointAvg = np.mean(episodePointHist[-numP_Average:])

    # Decay ebsilon
    ebsilon = decayEbsilon(ebsilon, eDecay, ebsilonEnd)

    # Save model weights and parameters periodically (For later use)
    if backUpNetworks:
        if (episode + 1) % 20 == 0:
            _exp = mem.exportExpereince()
            _chekcPoint = {
                "episode": episode,
                'qNetwork_state_dict': qNetwork_model.state_dict(),
                'qNetwork_optimizer_state_dict': optimizer_main.state_dict(),
                'targetQNetwork_state_dict': targetQNetwork_model.state_dict(),
                'targetQNetwork_optimizer_state_dict': optimizer_target.state_dict(),
                'hyperparameters': {"ebsilon": ebsilon, "eDecay":eDecay},
                "train_history": lstHistory,
                "experiences": {
                    "state": _exp["state"],
                    "action": _exp["action"],
                    "reward": _exp["reward"],
                    "nextState": _exp["nextState"],
                    "done": _exp["done"]
                }
            }
            qNetworkSaveHistory.appendleft(_chekcPoint)
            torch.save(qNetworkSaveHistory, os.path.join(savePath, f"{projectName}_{modelDetails}.pth"))

            # Save the episode number
            latestChekpoint = episode

    # Stop the learning process if suitable average point is reacheds
    if 100 < epPointAvg:
        Tend = time.time()
        print(f"\nThe learning ended. Elapsed time for learning: {Tend-tstart:.2f}s. \nAVG of latest 100 episodes: {epPointAvg}")
        
        _exp = mem.exportExpereince()
        _chekcPoint = {
            "episode": episode,
            'qNetwork_state_dict': qNetwork_model.state_dict(),
            'qNetwork_optimizer_state_dict': optimizer_main.state_dict(),
            'targetQNetwork_state_dict': targetQNetwork_model.state_dict(),
            'targetQNetwork_optimizer_state_dict': optimizer_target.state_dict(),
            'hyperparameters': {"ebsilon": ebsilon, "eDecay":eDecay},
            "train_history": lstHistory,
            "experiences": {
                "state": _exp["state"],
                "action": _exp["action"],
                "reward": _exp["reward"],
                "nextState": _exp["nextState"],
                "done": _exp["done"]
            }
        }
        qNetworkSaveHistory.appendleft(_chekcPoint)
        torch.save(qNetworkSaveHistory, os.path.join(savePath, f"{projectName}_{modelDetails}.pth"))

        # Save the episode number
        latestChekpoint = episode
        
        break

# Reset the index
episodeHistDf = pd.DataFrame(lstHistory)
episodeHistDf.reset_index(drop=True, inplace=True)

env.close()

In [None]:
import matplotlib.pyplot as plt
fig= plt.figure(figsize=(12,6))
plt.plot(pd.DataFrame(lstHistory)["episode"], pd.DataFrame(lstHistory)["points"])