# In [None]:
# Maze Navigation with SNN + STDP (Enhanced)

import numpy as np
import gym
import matplotlib.pyplot as plt
from IPython.display import clear_output

# 1) Maze layout: 10x10 grid where 0 = free cell and 1 = wall.
#    The walls form a cross (column 5 and row 5, indices 1..8), so the outer
#    ring of cells (row 0, row 9, column 0, column 9) stays free.
maze_grid = np.zeros((10, 10), dtype=int)
maze_grid[5, 1:9] = 1
maze_grid[1:9, 5] = 1

# 2) Cardinal motions as (d_row, d_col) offsets; list position is the action id:
#    0 = North, 1 = East, 2 = South, 3 = West.
_CARDINAL_OFFSETS = ((-1, 0), (0, 1), (1, 0), (0, -1))
motions = [np.array(offset) for offset in _CARDINAL_OFFSETS]

# 3) Custom gym-style env (uses the module-level maze_grid directly)
class CustomMazeEnv(gym.Env):
    """Grid-world maze over the module-level ``maze_grid``.

    Cells equal to 0 are walkable; any other value blocks movement. The agent
    starts at (0, 0) and the episode ends when it stands on the goal (9, 9).
    Uses the classic gym API: ``reset()`` returns only the observation and
    ``step()`` returns ``(obs, reward, done, info)``.
    """

    def __init__(self):
        super().__init__()
        # Observation: grid copy with 0 = free, 1 = wall, 2 = agent, 3 = goal.
        self.observation_space = gym.spaces.Box(0, 3, shape=maze_grid.shape, dtype=np.uint8)
        # Actions index into ``motions`` (0 = N, 1 = E, 2 = S, 3 = W).
        self.action_space = gym.spaces.Discrete(len(motions))
        self.agent_pos = np.array([0, 0])
        self.goal = np.array([9, 9])

    def reset(self):
        """Move the agent back to the start cell and return the initial observation."""
        self.agent_pos = np.array([0, 0])
        return self._get_obs()

    def _is_free(self, cell):
        """Return True if ``cell`` (row, col) lies inside the grid and is walkable."""
        rows, cols = maze_grid.shape  # follow the maze size instead of a hard-coded 10
        r, c = cell
        return 0 <= r < rows and 0 <= c < cols and maze_grid[r, c] == 0

    def step(self, action):
        """Apply ``action``; moves into walls or off the grid leave the agent in place.

        Reward is 1.0 on reaching the goal and -0.01 otherwise, plus a shaping
        term of 0.1 * (decrease in Euclidean distance to the goal).

        Returns:
            (obs, reward, done, info) where ``info`` is always an empty dict.
        """
        old_dist = np.linalg.norm(self.agent_pos - self.goal)
        new = self.agent_pos + motions[action]
        if self._is_free(new):
            self.agent_pos = new
        done = np.array_equal(self.agent_pos, self.goal)
        # base reward at goal
        reward = 1.0 if done else -0.01
        # shaping: reduction in distance
        new_dist = np.linalg.norm(self.agent_pos - self.goal)
        reward += 0.1 * (old_dist - new_dist)
        return self._get_obs(), reward, done, {}

    def _get_obs(self):
        """Return the grid with the agent (2) and goal (3) marked, as uint8.

        The dtype matches ``observation_space``. A plain copy of the grid would
        keep the platform ``int`` dtype, which Box's dtype check can reject.
        """
        obs = maze_grid.astype(np.uint8)  # astype copies, so maze_grid is untouched
        obs[tuple(self.agent_pos)] = 2
        obs[tuple(self.goal)] = 3
        return obs

    def render(self):
        """Print the current observation grid."""
        print(self._get_obs())

# Single environment instance shared by the sensor function, the teacher and the training loop.
env = CustomMazeEnv()

# 4) SNN + STDP definitions
class LIFNeuron:
    """Leaky integrate-and-fire neuron integrated with a forward-Euler step.

    Args:
        tau: membrane time constant (leak rate is v / tau).
        v_thresh: potential at or above which the neuron fires.
        v_reset: potential the membrane is set to right after a spike.
    """

    def __init__(self, tau=20., v_thresh=1., v_reset=0.):
        self.tau = tau
        self.v_thresh = v_thresh
        self.v_reset = v_reset
        self.v = 0.

    def step(self, I, dt=1.):
        """Integrate input current ``I`` over ``dt``; return 1 if it spiked, else 0."""
        leak = -self.v / self.tau
        self.v += dt * (leak + I)
        fired = self.v >= self.v_thresh
        if fired:
            self.v = self.v_reset
        return 1 if fired else 0

class STDP_Synapse:
    """Plastic synapse with trace-based STDP and reward-scaled learning rates.

    Each ``update`` call is one time step (dt = 1). Both traces decay by
    exp(-1 / tau) and are incremented by the current spikes. The weight then
    moves by ``A_plus * pre_trace * post_spike - A_minus * post_trace * pre_spike``,
    clipped to ``[w_min, w_max]``.

    ``A_max`` caps the learning rates that ``modulate_reward`` can produce.
    """

    def __init__(self, w_init=0.5, A_plus=0.1, A_minus=0.12, tau_plus=20., tau_minus=20.,
                 w_min=0., w_max=1., A_max=1.0):
        self.w = w_init
        self.A_plus, self.A_minus = A_plus, A_minus
        self.tau_plus, self.tau_minus = tau_plus, tau_minus
        self.w_min, self.w_max = w_min, w_max
        self.A_max = A_max
        self.pre_trace = 0.
        self.post_trace = 0.

    def update(self, pre_spike, post_spike):
        """Advance both traces by one step and apply the STDP weight change.

        Traces are incremented before ``dw`` is computed, so a pre and a post
        spike in the same step produce both potentiation and depression.
        """
        self.pre_trace *= np.exp(-1 / self.tau_plus)
        self.post_trace *= np.exp(-1 / self.tau_minus)
        if pre_spike:
            self.pre_trace += 1.
        if post_spike:
            self.post_trace += 1.
        dw = (self.A_plus * self.pre_trace * post_spike
              - self.A_minus * self.post_trace * pre_spike)
        self.w = np.clip(self.w + dw, self.w_min, self.w_max)

    def modulate_reward(self, reward):
        """Scale learning rates by reward: A_plus by (1 + reward), A_minus by (1 - reward).

        Both scale factors are floored at 0 and the resulting rates are capped
        at ``A_max``. Without these bounds:
        - a reward above 1 flips the sign of ``A_minus``, turning depression
          into potentiation;
        - repeated calls compound ``A_plus`` geometrically until it overflows
          to inf, after which ``update`` yields NaN weights (inf * 0).
        """
        self.A_plus = min(self.A_plus * max(0., 1. + reward), self.A_max)
        self.A_minus = min(self.A_minus * max(0., 1. - reward), self.A_max)

# 5) Build network: 4 free-cell sensors + 2 goal-direction sensors -> 4 motor (action) neurons
def build_network(n_sensors=6, n_motors=4):
    """Create the sensor layer, the motor layer and all-to-all plastic synapses.

    Returns:
        (sensors, motors, synapses), where ``synapses[i][j]`` connects sensor
        ``i`` to motor ``j``.
    """
    def make_layer(size):
        return [LIFNeuron(tau=15.) for _ in range(size)]

    sensors = make_layer(n_sensors)
    motors = make_layer(n_motors)
    synapses = []
    for _ in range(n_sensors):
        synapses.append([STDP_Synapse() for _ in range(n_motors)])
    return sensors, motors, synapses

# Global network state with build_network's defaults: 6 sensor neurons, 4 motor neurons,
# and a 6x4 synapse grid indexed synapses[sensor][motor].
sensors, motors, synapses = build_network()

# 6) Sensor function
def get_sensor_signals(env):
    """Return the 6 sensor input values for the agent's current cell.

    Elements 0-3 are 1.0 if the neighbouring cell in direction ``motions[k]``
    (N, E, S, W) is inside the grid and free (0 in ``maze_grid``), else 0.0.
    Elements 4-5 are the (row, col) components of the unit vector from the
    agent to the goal. They can be negative and are ~0 on the goal itself.
    """
    rows, cols = maze_grid.shape  # bounds follow the maze instead of a hard-coded 10
    pos = env.agent_pos
    sig = []
    for m in motions:
        r, c = pos + m
        free = 0 <= r < rows and 0 <= c < cols and maze_grid[r, c] == 0
        sig.append(float(free))
    delta = env.goal - pos
    dist = np.linalg.norm(delta) + 1e-6  # epsilon avoids 0/0 when standing on the goal
    sig.append(delta[0] / dist)
    sig.append(delta[1] / dist)
    return sig

# 7) Supervised R‑STDP Training Loop

# --- Teacher: action that follows a shortest path (BFS) to the goal ---
_TEACHER_DIST_CACHE = {}


def _teacher_distance_map(grid, moves, goal):
    """Return BFS step counts from every cell to ``goal`` (np.inf where unreachable).

    Results are cached per (grid, moves, goal) and returned read-only.
    """
    from collections import deque

    goal = (int(goal[0]), int(goal[1]))
    offsets = tuple((int(m[0]), int(m[1])) for m in moves)
    key = (grid.shape, grid.dtype.str, grid.tobytes(), goal, offsets)
    cached = _TEACHER_DIST_CACHE.get(key)
    if cached is not None:
        return cached

    rows, cols = grid.shape
    dist = np.full(grid.shape, np.inf)
    if grid[goal] == 0:
        dist[goal] = 0.0
        queue = deque([goal])
        while queue:
            r, c = queue.popleft()
            for dr, dc in offsets:
                # Predecessor cell from which moving by (dr, dc) lands on (r, c).
                pr, pc = r - dr, c - dc
                if (0 <= pr < rows and 0 <= pc < cols
                        and grid[pr, pc] == 0 and np.isinf(dist[pr, pc])):
                    dist[pr, pc] = dist[r, c] + 1
                    queue.append((pr, pc))

    dist.flags.writeable = False  # shared through the cache; callers must not mutate it
    if len(_TEACHER_DIST_CACHE) >= 64:
        _TEACHER_DIST_CACHE.clear()  # keep the cache bounded
    _TEACHER_DIST_CACHE[key] = dist
    return dist


def teacher_action(env, grid=None, moves=None):
    """Return the index of the move that follows a shortest path to ``env.goal``.

    A straight-line (Euclidean) greedy choice gets trapped in this maze's
    dead ends. For example, from (0, 0) it oscillates between (3, 4) and
    (4, 4) inside the walled pocket. Moves are therefore ranked by true path
    length from a breadth-first search over free cells. Ties go to the lowest
    action index. Blocked moves cost inf, so if no neighbour can reach the
    goal the result is 0.

    Args:
        env: object with ``agent_pos`` and ``goal`` (row, col) arrays.
        grid: 2-D maze, 0 = free. Defaults to the module-level ``maze_grid``.
        moves: list of (d_row, d_col) offsets. Defaults to ``motions``.
    """
    grid = maze_grid if grid is None else np.asarray(grid)
    moves = motions if moves is None else moves
    dist = _teacher_distance_map(grid, moves, env.goal)
    rows, cols = grid.shape
    costs = []
    for m in moves:
        r, c = env.agent_pos + m
        if 0 <= r < rows and 0 <= c < cols and grid[r, c] == 0:
            costs.append(dist[r, c])
        else:
            costs.append(np.inf)
    return int(np.argmin(costs))

# 7) Training loop with resets & input scaling

history = []
n_episodes = 1000
max_steps  = 200

# Input gains
SENS_GAIN  = 3.0
MOTOR_GAIN = 2.0

for ep in range(n_episodes):
    obs = env.reset()
    success = 0

    for t in range(max_steps):
        # ---- 0) Reset all membrane potentials ----
        # Each decision step starts from rest, so every neuron integrates
        # exactly one Euler step of its input.
        for n in sensors + motors:
            n.v = 0.0

        # 1) Sensor -> input currents (scaled) -> spikes
        I_sens   = get_sensor_signals(env)
        spikes_s = [n.step(I=SENS_GAIN * i) for n, i in zip(sensors, I_sens)]

        # 2) Motor drive: weighted sum of sensor spikes -> scaled -> spikes
        I_motors = [
            MOTOR_GAIN * sum(synapses[i][j].w * s for i, s in enumerate(spikes_s))
            for j in range(len(motors))
        ]
        spikes_m = [n.step(I=Ij) for n, Ij in zip(motors, I_motors)]

        # 3) Choose action. Spikes are binary (0/1), so several motors often
        #    fire together. Break ties uniformly at random: np.argmax would
        #    always pick the lowest index (North).
        firing = np.flatnonzero(spikes_m)
        if firing.size == 0:
            action = np.random.randint(len(motors))
        else:
            action = int(np.random.choice(firing))

        # 4) Teacher's action for the current state (computed before moving)
        teach = teacher_action(env)

        # 5) Execute
        obs, reward, done, _ = env.step(action)

        # 6) STDP + supervised STDP + reward
        for i in range(len(sensors)):
            for j in range(len(motors)):
                syn = synapses[i][j]
                # (a) unsupervised STDP on the actual motor spikes
                syn.update(pre_spike=spikes_s[i], post_spike=spikes_m[j])
                # (b) supervised "teacher" STDP: the teacher's motor acts as the post spike.
                # NOTE(review): calling update() twice per step decays both traces twice,
                # adds each pre spike to pre_trace twice, and (a)/(b) share one post_trace.
                syn.update(pre_spike=spikes_s[i], post_spike=(1 if j == teach else 0))
                # (c) reward-modulated STDP on success only. The shaped reward is
                #     positive on many steps that approach the goal, so gating on
                #     reward > 0 would rescale the learning rates (multiplicatively,
                #     compounding) on most steps instead of only on success.
                if done:
                    syn.modulate_reward(reward)

        if done:
            success = 1
            break

    history.append(success)
    if (ep + 1) % 20 == 0:
        clear_output(wait=True)
        print(f"Episode {ep+1}/{n_episodes}, success (last‑20): {np.mean(history[-20:])*100:.1f}%")

# 8) Plot: moving-average success rate over (up to) 10 episodes
window = min(10, len(history))
if window == 0:
    print("No episodes recorded; nothing to plot.")
else:
    smoothed = np.convolve(history, np.ones(window) / window, mode='valid')
    # smoothed[k] averages episodes k+1 .. k+window (1-based), so plot each
    # value at the last episode of its window rather than at index k.
    episodes = np.arange(window, len(history) + 1)
    plt.plot(episodes, smoothed)
    plt.xlabel('Episode')
    plt.ylabel(f'Success rate ({window}‑episode avg)')
    plt.title('SNN + R‑STDP (with resets & gains)')
    plt.show()


# Last printed output: Episode 320/1000, success (last‑20): 0.0%
