# Objective 3: STDP + Q-learning Integration

This notebook combines a spiking STDP learning rule with Q-learning updates.


In [36]:
# 1. Imports and setup
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gymnasium as gym
import random

# Prefer the GPU when CUDA is usable; otherwise everything stays on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [37]:
# 2. Hyperparameters (tuned for better accuracy)
# --- Episode budget ---------------------------------------------------------
num_episodes = 10_000   # number of training episodes
max_steps = 100         # hard cap on steps per training episode

# --- Learning rates and discount --------------------------------------------
gamma = 0.95            # discount factor in the TD target
alpha_td = 0.005        # step size for the TD-error term
alpha_stdp = 0.002      # step size for the STDP / eligibility-trace term

# --- Synapse model ----------------------------------------------------------
tau_e = 20.0            # eligibility-trace decay constant (ms)
w_min = -1.0            # lower synaptic weight bound
w_max = 1.0             # upper synaptic weight bound

# --- Epsilon-greedy exploration schedule ------------------------------------
epsilon = 1.0           # starting exploration probability
min_epsilon = 0.01      # floor for epsilon
epsilon_decay = 0.998   # multiplicative decay applied to epsilon

# --- Optional softmax exploration -------------------------------------------
temperature = 1.0       # softmax temperature (optional alternative to epsilon-greedy)
tau_decay = 0.995       # presumably the decay factor for `temperature` -- confirm

# Build the Taxi environment and read the sizes of its discrete
# observation and action spaces.
env = gym.make('Taxi-v3')
num_states, num_actions = env.observation_space.n, env.action_space.n

# SNN/network sizing: hidden units raised from 32 to 64,
# simulation steps raised from 5 to 10.
n_hidden, snn_steps = 64, 10



In [38]:
# 3. Network: Spiking synapses as Q-table
#    Represent Q-values as trainable synaptic weights

# One-hot input layer, no hidden layer
class SNN_Q(nn.Module):
    """Single-layer state -> action network whose weight matrix acts as a Q-table.

    The only trainable parameter is ``w`` with shape ``(num_states, num_actions)``,
    initialised to zeros. A forward pass one-hot encodes the state index and
    multiplies it by ``w``, so the output is row ``state_idx`` of ``w``.

    Args:
        num_states: number of discrete states (rows of ``w``).
        num_actions: number of discrete actions (columns of ``w``).
        beta: stored on the instance; not read by ``forward``.
        n_hidden: stored on the instance; not read by ``forward``.
        snn_steps: stored on the instance; not read by ``forward``.
    """

    def __init__(self, num_states, num_actions, beta=0.9, n_hidden=64, snn_steps=10):
        super().__init__()
        # weights from state neurons to action neurons
        self.w = nn.Parameter(torch.zeros(num_states, num_actions))
        # Kept as attributes only; the forward pass below has no hidden layer
        # and performs no multi-step simulation.
        self.beta = beta
        self.n_hidden = n_hidden
        self.snn_steps = snn_steps

    def forward(self, state_idx):
        """Return the action values for ``state_idx`` as a ``(num_actions,)`` tensor."""
        # Size, device and dtype come from the module's own weights rather than
        # notebook globals, so the model works for any size it was built with
        # and after being moved to another device.
        x = torch.zeros(
            1, self.w.shape[0], device=self.w.device, dtype=self.w.dtype
        )
        x[0, state_idx] = 1.0  # one-hot encoding of the active state neuron
        # instantaneous "membrane" potentials = weighted sum
        mem = x @ self.w
        return mem.squeeze(0)

# Instantiate the Q-network and place it on the selected device.
model = SNN_Q(num_states=num_states, num_actions=num_actions)
model = model.to(device)

# Plain SGD over the network's parameters, using the TD learning rate.
optimizer = optim.SGD(params=model.parameters(), lr=alpha_td)



In [39]:
# 4. STDP learning rule (simplified):
def stdp_update(w, pre_spike, post_spike, lr=None, tau_plus=20, tau_minus=20):
    """Return a simplified STDP weight change with the same shape as ``w``.

    The rule only looks at whether each side spiked (truthy / falsy flags):

    * pre and post both spiked   -> potentiation: ``+lr * exp(-1 / tau_plus)``
    * pre spiked, post did not   -> depression:   ``-lr * exp(-1 / tau_minus)``
    * pre did not spike          -> no change (zeros)

    The exponent is the constant ``-1 / tau``; no actual spike-time difference
    is passed in.

    Args:
        w: weight tensor; used for the shape, dtype and device of the result.
            It is not modified.
        pre_spike: 0/1 flag for the presynaptic spike.
        post_spike: 0/1 flag for the postsynaptic spike.
        lr: STDP learning rate. ``None`` (the default) means "use the
            notebook-level ``alpha_stdp``", looked up when the function is
            called.
        tau_plus: time constant of the potentiation branch.
        tau_minus: time constant of the depression branch.

    Returns:
        A new tensor ``delta_w`` shaped like ``w``.
    """
    if lr is None:
        # The old default referenced an undefined name `alpha`, which raised
        # NameError as soon as this cell ran on a fresh kernel.
        lr = alpha_stdp
    # if pre before post: potentiation, else depression
    delta_w = torch.zeros_like(w)
    # here pre_spike, post_spike are 0/1
    if pre_spike and post_spike:
        delta_w += lr * w.new_tensor(np.exp(-1 / tau_plus))
    elif pre_spike and not post_spike:
        delta_w -= lr * w.new_tensor(np.exp(-1 / tau_minus))
    return delta_w



In [40]:
# 5. Training loop with Q-learning + STDP
env = gym.make('Taxi-v3')
table = []  # never filled or read in this cell; kept in case another cell uses it
for episode in range(num_episodes):
    state, _ = env.reset()
    for step in range(max_steps):
        # Weights are updated by hand below, so no autograd graph is needed
        # for the forward passes.
        with torch.no_grad():
            # forward pass: get Q-values
            q_vals = model(state)
        greedy_action = q_vals.argmax().item()

        # epsilon-greedy
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            action = greedy_action
        next_state, reward, done, truncated, _ = env.step(action)

        # estimate target using next state's max Q
        with torch.no_grad():
            next_q = model(next_state)
        # A terminated episode has no future return, so do not bootstrap from
        # it. A truncated episode (time limit) still bootstraps.
        bootstrap = 0.0 if done else gamma * next_q.max().item()
        target = reward + bootstrap

        # compute classical Q-error
        pred = q_vals[action]
        error = target - pred.item()

        with torch.no_grad():
            # STDP-like weight update for synapse (state->action):
            # the state neuron always fires; the action neuron counts as
            # having fired only when the taken action is the greedy one.
            pre_spike = 1
            post_spike = 1 if action == greedy_action else 0
            dw = stdp_update(
                model.w[state, action], pre_spike, post_spike, lr=alpha_stdp
            )

            # directly adjust the weight: TD step (alpha_td) plus STDP term.
            # The old code scaled the TD error by an undefined name `alpha`.
            model.w[state, action] += alpha_td * error + dw

        state = next_state
        if done or truncated:
            break
    # decay exploration once per episode, never below min_epsilon
    epsilon = max(min_epsilon, epsilon * epsilon_decay)



In [42]:
# 6. Evaluate policy
env_eval = gym.make('Taxi-v3')

for ep in range(5):
    s, _ = env_eval.reset()
    print(f"Episode {ep}")
    # Follow the greedy policy until the environment reports termination or
    # truncation.
    while True:
        # Pure inference: no gradients are needed to pick the greedy action.
        with torch.no_grad():
            q_vals = model(s)
        a = q_vals.argmax().item()
        s, r, done, truncated, _ = env_eval.step(a)
        # render() is no longer called here: env_eval is created without a
        # render_mode, so the call displayed nothing.
        if done or truncated:
            # `r` is the reward of the final step only, not the episode return.
            print(f"Reward: {r}\n")
            break


Episode 0
Reward: 20

Episode 1
Reward: 20

Episode 2
Reward: -1

Episode 3
Reward: -1

Episode 4
Reward: -1

