In [1]:
import gym
import numpy as np

from controller import (
    ControllerState,
    get_action_space,
    execute_action,
    get_observation,
    compute_reward
)

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.


In [2]:
from torchvision import datasets, transforms
import random

# Convert PIL images to float tensors; no normalisation is applied.
transform = transforms.Compose([transforms.ToTensor()])

# CIFAR-10 test split, read from the local ./data cache (downloading is disabled).
test_dataset = datasets.CIFAR10(
    root="./data", train=False, download=False, transform=transform
)

def load_random_image_to_spikemem(
    spike_mem,
    dataset,
    ttfs_encoder,
    Tmax,
    input_layer=0
):
    """Pick a random sample from ``dataset``, TTFS-encode it and load its spikes.

    The input layer of ``spike_mem`` is cleared first. Then ``put_spike`` is called
    once per non-zero entry of the encoded sequence, in (t, channel, row, col)
    row-major order.

    Args:
        spike_mem: object exposing ``reset_layer(layer)`` and
            ``put_spike(layer=, t=, ch=, row=, col=)``.
        dataset: indexable of ``(image_tensor, label)`` pairs; images are
            presumably [C, H, W] tensors (TODO confirm for other datasets).
        ttfs_encoder: callable mapping a [1, C, H, W] batch to a [1, T, C, H, W]
            spike tensor.
        Tmax: number of leading time steps to load. If the encoder produced
            fewer than ``Tmax`` steps, all of them are loaded.
        input_layer: layer index the spikes are written to.

    Returns:
        The label of the sampled image (useful for logging / debugging).
    """
    # Sample a random image.
    idx = random.randint(0, len(dataset) - 1)
    img, label = dataset[idx]

    # The encoder works on batches: [C, H, W] -> [1, C, H, W] -> [1, T, C, H, W].
    spike_seq = ttfs_encoder(img.unsqueeze(0))
    # Drop the batch dim; detach() so .numpy() also works if the output tracks grads.
    spike_seq = spike_seq.squeeze(0).detach().cpu().numpy()  # [T, C, H, W]

    # Start from an empty input layer.
    spike_mem.reset_layer(input_layer)

    # np.argwhere lists non-zero positions in row-major order, i.e. the same
    # (t, ch, row, col) order a nested t/ch/row/col loop would visit them.
    # Slicing (instead of indexing up to Tmax) tolerates an encoder with T < Tmax.
    for t, ch, r, c in np.argwhere(spike_seq[:Tmax] != 0.0):
        spike_mem.put_spike(
            layer=input_layer,
            t=int(t),
            ch=int(ch),
            row=int(r),
            col=int(c)
        )

    return label   # useful for logging / debugging

In [3]:
class SpikeSchedulingEnv(gym.Env):
    """Gym environment in which an agent picks the next spike-processing action.

    ``reset()`` performs the following steps:

    * clears the controller state and both memories;
    * loads a random image from ``dataset`` into layer 0 of ``spike_mem`` through
      the TTFS encoder;
    * returns the first observation.

    Each ``step(action_idx)`` executes ``actions[action_idx]``, where ``actions``
    is the current list from ``get_action_space``. The episode ends in one of
    two ways:

    * ``spike_mem`` reports a spike on ``output_layer``;
    * no legal action remains (deadlock: reward -100, ``info["deadlock"]``).

    Classic Gym API: ``reset() -> obs`` and
    ``step(a) -> (obs, reward, done, info)``.

    Args:
        encoder: TTFS encoder applied to each sampled image.
        processor: spike processor used to list and execute actions.
        spike_mem: spike memory, reset at the start of every episode.
        neuron_mem: neuron memory, reset at the start of every episode.
        dataset: indexable source of ``(image, label)`` pairs.
        output_layer: layer whose spikes terminate the episode.
        Tmax: time steps loaded per image; also passed to ``ControllerState``.
            Presumably both use the same horizon (TODO confirm). Defaults to 8,
            the value previously hard-coded in both places.
    """

    def __init__(self, encoder, processor, spike_mem, neuron_mem, dataset, output_layer, Tmax=8):
        super().__init__()
        self.processor = processor
        self.spike_mem = spike_mem
        self.neuron_mem = neuron_mem
        self.output_layer = output_layer
        self.dataset = dataset
        self.ttfs = encoder
        self.Tmax = Tmax

        self.state = ControllerState(processor, self.Tmax)

        self.observation_space = gym.spaces.Box(
            low=0, high=np.inf, shape=(5,), dtype=np.float32
        )

        # action_space is dynamic -> legal actions are indexed manually in step()
        self.last_action = None
        self.current_label = None
        self.step_count = 0
        # No episode is active until reset() is called; step() refuses to run before that.
        self.done = True

    def reset(self):
        """Start a new episode and return its first observation."""
        # reset internal state
        self.state.reset()
        self.neuron_mem.reset_all()
        self.spike_mem.reset_all()

        self.step_count = 0
        self.done = False

        # load a random image from the dataset into the input layer of SpikeMemory
        self.current_label = load_random_image_to_spikemem(
            spike_mem=self.spike_mem,
            dataset=self.dataset,
            ttfs_encoder=self.ttfs,
            Tmax=self.Tmax,
            input_layer=0
        )

        # compute initial legal actions & observation
        actions = get_action_space(
            self.state,
            self.processor
        )

        obs = get_observation(
            self.spike_mem,
            self.neuron_mem,
            self.state,
            actions
        )

        self.last_action = None

        return obs

    def step(self, action_idx):
        """Execute the ``action_idx``-th legal action.

        Returns:
            ``(obs, reward, done, info)``. ``obs`` describes the state *after*
            the action, including its freshly computed legal-action set.

        Raises:
            RuntimeError: if no episode is active (``reset()`` not called, or
                the episode already ended).
            IndexError: if ``action_idx`` is not a valid index into the current
                legal action list.
        """
        if self.done:
            raise RuntimeError("Episode already done")

        actions = get_action_space(self.state, self.processor)

        if not actions:
            # deadlock: nothing can run, yet the output layer has not spiked
            reward = -100.0
            self.done = True
            obs = get_observation(self.spike_mem, self.neuron_mem, self.state, actions)
            return obs, reward, True, {"deadlock": True}

        # Reject out-of-range indices explicitly: a negative index would otherwise
        # silently select an action from the end of the list.
        if not 0 <= action_idx < len(actions):
            raise IndexError(
                f"action_idx {action_idx} out of range for {len(actions)} legal actions"
            )

        prev_spikes = self.spike_mem.count_total_spikes()
        prev_neurons = self.neuron_mem.total_active()

        # Execute chosen action
        action = actions[action_idx]

        action_repeated = (
            self.last_action is not None
            and action == self.last_action
        )
        self.last_action = action

        execute_action(action, self.state, self.processor,
                       self.spike_mem, self.neuron_mem)

        self.step_count += 1

        # Check termination
        done = self.spike_mem.has_output_spike(self.output_layer)
        deadlock = False

        reward = compute_reward(
            prev_spikes,
            prev_neurons,
            self.spike_mem,
            self.neuron_mem,
            done,
            deadlock,
            action_repeated=action_repeated
        )

        # The action changed the controller state, so the legal set used above is
        # stale; observe the new state with its own legal actions.
        next_actions = get_action_space(self.state, self.processor)
        obs = get_observation(self.spike_mem, self.neuron_mem, self.state, next_actions)
        self.done = done

        return obs, reward, done, {}


In [4]:
# --- Action / observation sizes --------------------------------------------
# Width of the policy head; legal actions are masked into the first slots.
MAX_ACTIONS = 256
# Observation layout: [spikes, neurons, n_actions, min_t, max_t]
STATE_DIM = 5

# --- PPO hyper-parameters ----------------------------------------------------
GAMMA, LAMBDA = 0.99, 0.95            # discount factor, GAE lambda
CLIP_EPS = 0.2                        # PPO probability-ratio clip range
LR = 3e-4                             # Adam learning rate
ENTROPY_COEF, VALUE_COEF = 0.01, 0.5  # loss weights: entropy bonus, critic MSE

# --- Spiking network ---------------------------------------------------------
Tmax = 8          # TTFS time steps (encoder T, SpikeMemory Tmax)
NUM_LAYERS = 5    # layers allocated in SpikeMemory

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

class PPOPolicy(nn.Module):
    """Actor-critic MLP for PPO.

    The actor maps an observation to ``act_dim`` action logits. The critic maps
    the same observation to a single state value. The two networks share no
    weights; each has one 128-unit ReLU hidden layer.
    """

    def __init__(self, obs_dim, act_dim):
        super().__init__()

        self.actor = nn.Sequential(
            nn.Linear(obs_dim, 128),
            nn.ReLU(),
            nn.Linear(128, act_dim)
        )

        self.critic = nn.Sequential(
            nn.Linear(obs_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    # -----------------------------
    # Used during rollout
    # -----------------------------
    def act(self, obs, legal_actions=None):
        """Sample one action for a single observation.

        Args:
            obs: observation tensor for one step. ``action.item()`` below requires
                the distribution to describe exactly one sample.
            legal_actions: optional index (e.g. list of ints) of allowed actions.
                Every other logit is offset by -1e9, so its probability underflows
                to zero.

        Returns:
            ``(action index as int, log-prob, value, entropy)``.
        """
        logits = self.actor(obs)

        if legal_actions is not None:
            mask = torch.full_like(logits, -1e9)
            mask[legal_actions] = 0
            logits = logits + mask

        dist = Categorical(logits=logits)
        action = dist.sample()

        log_prob = dist.log_prob(action)
        entropy = dist.entropy()
        value = self.critic(obs)

        return action.item(), log_prob, value, entropy

    # -----------------------------
    # REQUIRED for PPO UPDATE
    # -----------------------------
    def evaluate(self, states, actions, action_masks=None):
        """Re-score stored actions under the current parameters.

        Args:
            states: [B, obs_dim]
            actions: [B] (LongTensor)
            action_masks: optional [B, act_dim]. Entries equal to 0 mark illegal
                actions, whose logits are replaced by -1e9. This is the same
                masking ``select_action`` applies at rollout time, so the
                log-probs and entropy here come from the distribution that
                actually sampled ``actions``. Without masks, the PPO ratio
                compares a masked and an unmasked distribution.

        Returns:
            ``(log_probs [B], values [B], mean entropy as a 0-d tensor)``.
        """
        logits = self.actor(states)
        if action_masks is not None:
            logits = logits.masked_fill(action_masks == 0, -1e9)
        dist = Categorical(logits=logits)

        log_probs = dist.log_prob(actions)
        entropy = dist.entropy().mean()
        values = self.critic(states).squeeze(-1)

        return log_probs, values, entropy

    def forward(self, state):
        """Return raw ``(logits, value)`` for ``state`` (e.g. [state_dim]); no masking."""
        logits = self.actor(state)
        value = self.critic(state)
        return logits, value

In [6]:
def select_action(model, state, action_mask):
    """Sample an action from ``model``'s policy, restricted to legal actions.

    Args:
        model: callable returning ``(logits, value)`` for ``state``.
        state: torch.Tensor [state_dim]
        action_mask: torch.Tensor [MAX_ACTIONS] (1 = valid, 0 = invalid)

    Returns:
        ``(action index as int, log-prob, value, entropy)`` of the masked
        distribution.
    """
    logits, value = model(state)

    # Illegal slots get a -1e9 logit so their probability underflows to zero.
    illegal = action_mask == 0
    policy = torch.distributions.Categorical(logits=logits.masked_fill(illegal, -1e9))

    sampled = policy.sample()
    return sampled.item(), policy.log_prob(sampled), value, policy.entropy()

In [7]:
class RolloutBuffer:
    """Per-rollout storage of PPO transitions, with discounted-return computation.

    All fields are parallel lists with one entry per environment step.
    ``action_masks`` is optional and stays empty unless the caller records a
    legal-action mask (1 = legal) for each step.
    """

    def __init__(self, gamma=0.99):
        self.gamma = gamma
        self.clear()

    def clear(self):
        """Drop every stored transition."""
        self.states = []
        self.actions = []
        self.log_probs = []
        self.rewards = []
        self.values = []
        self.dones = []
        self.entropies = []
        self.action_masks = []  # optional, one mask per step when recorded

    def compute_returns(self):
        """Standard discounted returns (NO GAE yet).

        Walks the rewards backwards. The running return restarts at every step
        flagged ``done``, so returns never leak across episode boundaries. A
        rollout that ends without ``done`` is not bootstrapped (future value is
        taken as 0).

        Returns:
            float32 tensor with one return per stored step.
        """
        returns = []
        G = 0.0

        for reward, done in zip(reversed(self.rewards), reversed(self.dones)):
            if done:
                G = 0.0
            G = reward + self.gamma * G
            returns.append(G)

        # Appending then reversing is O(n); insert(0, ...) per step was O(n^2).
        returns.reverse()
        return torch.tensor(returns, dtype=torch.float32)


In [8]:
def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95, last_value=0.0):
    """Generalized Advantage Estimation over one rollout.

    Args:
        rewards: per-step rewards.
        values: per-step critic values; any sequence (list, tuple, 1-D tensor).
            It is not modified.
        dones: per-step terminal flags (bool or 0/1). A done step stops both
            bootstrapping and advantage propagation from later steps.
        gamma: discount factor.
        lam: GAE lambda.
        last_value: value of the state after the final step, used to bootstrap
            a rollout that was cut off mid-episode. The default of 0.0 matches a
            rollout ending at a terminal state.

    Returns:
        ``(advantages, returns)`` as lists, where ``returns[t] = advantages[t] + values[t]``.
    """
    # list() copies, so the caller's sequence is untouched and tuples/tensors work too.
    values = list(values) + [last_value]

    advantages = []
    gae = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 1 - dones[t]
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        gae = delta + gamma * lam * not_done * gae
        advantages.append(gae)
    advantages.reverse()  # O(n), instead of insert(0, ...) per step

    # zip stops at len(advantages), so the appended last_value is never paired.
    returns = [adv + val for adv, val in zip(advantages, values)]
    return advantages, returns


In [9]:
def ppo_update(model, optimizer, buffer, epochs=4,
               clip_eps=None, value_coef=None, entropy_coef=None):
    """Run ``epochs`` clipped-PPO gradient steps on the rollout in ``buffer``.

    Advantages are the Monte-Carlo returns from ``buffer.compute_returns()``
    minus the critic values recorded during the rollout. They are normalised
    only when there is more than one sample, because the std of a single
    sample is NaN.

    Args:
        model: policy exposing ``evaluate(states, actions) -> (log_probs, values, entropy)``.
            When masks are used, it must also expose ``actor`` and ``critic`` sub-modules.
        optimizer: torch optimizer over ``model``'s parameters.
        buffer: rollout with ``states``, ``actions``, ``log_probs`` and ``values``
            lists, plus a ``compute_returns()`` method. If the buffer also
            carries an ``action_masks`` list with one mask per step (1 = legal),
            the masks are re-applied to the actor logits. The new log-probs then
            come from the same masked distribution that sampled the actions.
        epochs: number of full-batch passes over the rollout.
        clip_eps: overrides the notebook-level ``CLIP_EPS`` when not None.
        value_coef: overrides the notebook-level ``VALUE_COEF`` when not None.
        entropy_coef: overrides the notebook-level ``ENTROPY_COEF`` when not None.

    Returns:
        None. An empty buffer is a no-op.
    """
    if not buffer.actions:
        return  # nothing collected; torch.stack([]) would raise

    # The globals are read only when no override is given.
    clip_eps = CLIP_EPS if clip_eps is None else clip_eps
    value_coef = VALUE_COEF if value_coef is None else value_coef
    entropy_coef = ENTROPY_COEF if entropy_coef is None else entropy_coef

    states = torch.stack(buffer.states)
    actions = torch.tensor(buffer.actions, dtype=torch.long)
    old_log_probs = torch.stack(buffer.log_probs).reshape(-1)
    returns = buffer.compute_returns()
    advantages = returns - torch.stack(buffer.values).reshape(-1)

    # std() of one element is NaN and would turn the loss and then the weights into NaN.
    if advantages.numel() > 1:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    masks = getattr(buffer, "action_masks", None)
    if masks is not None and len(masks) == len(buffer.actions):
        masks = torch.stack(list(masks))
    else:
        masks = None

    for _ in range(epochs):
        if masks is None:
            log_probs, values, entropy = model.evaluate(states, actions)
        else:
            log_probs, values, entropy = _evaluate_masked(model, states, actions, masks)

        ratio = torch.exp(log_probs - old_log_probs)

        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages

        policy_loss = -torch.min(surr1, surr2).mean()
        # reshape(-1), not squeeze(): squeeze() makes a 1-step batch 0-d, which
        # mismatches the [1]-shaped returns in mse_loss.
        value_loss = F.mse_loss(values.reshape(-1), returns)

        loss = (
            policy_loss
            + value_coef * value_loss
            - entropy_coef * entropy.mean()
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def _evaluate_masked(model, states, actions, masks):
    """Like ``model.evaluate``, but with illegal logits (``masks == 0``) replaced by -1e9.

    This mirrors the masking ``select_action`` applies when sampling.
    """
    logits = model.actor(states).masked_fill(masks == 0, -1e9)
    dist = Categorical(logits=logits)
    values = model.critic(states).squeeze(-1)
    return dist.log_prob(actions), values, dist.entropy().mean()


In [10]:
from spike_memory import SpikeMemory
from neuron_memory import NeuronMemory
from spike_processor import SpikeProcessor
from snn_model import SCNN_CIFAR10_TTFS, TTFS_Encoder

import warnings

encoder = TTFS_Encoder(T=Tmax)

spike_mem = SpikeMemory(num_layers=NUM_LAYERS, Tmax=Tmax)
neuron_mem = NeuronMemory()

model = SCNN_CIFAR10_TTFS()
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict needs (and avoids executing arbitrary pickled code).
state_dict = torch.load(
    "ttfs_based_scnn_model_weights.pth", map_location="cpu", weights_only=True
)
# strict=False tolerates key mismatches. Surface them, so a partially loaded
# network (some layers left at their initial weights) does not go unnoticed.
load_report = model.load_state_dict(state_dict, strict=False)
if load_report.missing_keys or load_report.unexpected_keys:
    warnings.warn(
        "SCNN checkpoint mismatch: "
        f"missing={load_report.missing_keys}, unexpected={load_report.unexpected_keys}"
    )
del state_dict  # weights now live in the model's parameters
model.eval()

processor = SpikeProcessor(
    model,
    neuron_mem,
    spike_mem
)

# The output layer is the largest layer index the processor knows shapes for.
output_layer = max(processor.shapes.keys())

In [11]:
# Same tensor-only preprocessing as the test split.
transform = transforms.Compose([transforms.ToTensor()])

# CIFAR-10 training split from the local ./data cache (no download).
train_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    transform=transform,
    download=False,
)

In [15]:
env = SpikeSchedulingEnv(encoder, processor, spike_mem, neuron_mem, train_dataset, output_layer)
# NOTE(review): this rebinds `model` (previously the SCNN) to the PPO policy;
# `processor` was constructed with the SCNN object and presumably keeps its own reference.
model = PPOPolicy(STATE_DIM, MAX_ACTIONS)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

buffer = RolloutBuffer()

for episode in range(1000):
    obs = env.reset()
    # One legality mask per step, kept next to the rollout so the PPO update can
    # score actions under the same masked distribution they were sampled from.
    buffer.action_masks = []

    episode_reward = 0
    done = False

    while not done:
        # 1. Legal actions for the current controller state. The first len(actions)
        #    policy logits map onto them; the rest are masked out. If there were
        #    more than MAX_ACTIONS legal actions, only the first MAX_ACTIONS would
        #    be reachable.
        actions = get_action_space(env.state, env.processor)
        action_mask = torch.zeros(MAX_ACTIONS, dtype=torch.float32)
        action_mask[:len(actions)] = 1.0

        # 2. Observation as a float tensor.
        state = torch.tensor(obs, dtype=torch.float32)

        # 3. Sample an action. The rollout only needs sampled values, so skip
        #    building an autograd graph; ppo_update recomputes log-probs with grad.
        with torch.no_grad():
            action, log_prob, value, entropy = select_action(
                model, state, action_mask
            )

        # 4. Step the environment.
        next_obs, reward, done, info = env.step(action)
        episode_reward += reward

        # 5. Store the transition; detach() keeps the buffer graph-free even if
        #    the no_grad block above is ever removed.
        buffer.states.append(state.detach())
        buffer.actions.append(int(action))
        buffer.log_probs.append(log_prob.detach())
        buffer.values.append(value.detach().squeeze())
        buffer.entropies.append(entropy.detach())
        buffer.rewards.append(float(reward))
        buffer.dones.append(bool(done))
        buffer.action_masks.append(action_mask)

        # 6. Move to the next state.
        obs = next_obs

    ppo_update(model, optimizer, buffer)
    buffer.clear()

    print(
        f"Ep {episode} | "
        f"Reward={episode_reward:.2f} | "
    )


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Ep 0 | Reward=359.00 | 
Ep 1 | Reward=848.50 | 
Ep 2 | Reward=481.00 | 
Ep 3 | Reward=344.00 | 
Ep 4 | Reward=778.00 | 
Ep 5 | Reward=166.00 | 
Ep 6 | Reward=477.00 | 
Ep 7 | Reward=473.50 | 
Ep 8 | Reward=505.00 | 
Ep 9 | Reward=182.00 | 
Ep 10 | Reward=264.00 | 
Ep 11 | Reward=195.00 | 
Ep 12 | Reward=363.00 | 
Ep 13 | Reward=600.00 | 
Ep 14 | Reward=587.00 | 
Ep 15 | Reward=173.00 | 
Ep 16 | Reward=591.00 | 
Ep 17 | Reward=643.00 | 
Ep 18 | Reward=284.00 | 
Ep 19 | Reward=568.00 | 


KeyboardInterrupt: 