## Exercise 1 Solution

In [None]:
import brian2 as b2
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

def run_lif_simulation(input_current, tau_m, refractory_period, duration=100*b2.ms):
    """Simulate one leaky integrate-and-fire neuron and visualise its activity.

    A single neuron is driven by a constant current for ``duration``. The
    membrane trace (with the threshold and resting levels marked) and a spike
    raster are plotted; the spike count, mean firing rate and spike times are
    then printed.

    Args:
        input_current: Constant input current assigned to the neuron's ``I``
            parameter (declared as ``amp`` in the model).
        tau_m: Membrane time constant used in the model equation.
        refractory_period: Refractory period passed to the neuron group.
        duration: Simulated time; defaults to 100 ms.
    """
    b2.start_scope()

    # The model strings below refer to these constants (and to the ``tau_m``
    # argument) by name, so they have to stay local variables with exactly
    # these names.
    V_th = -50 * b2.mV
    V_rest = -65 * b2.mV
    V_reset = -65 * b2.mV
    R_m = 100 * b2.Mohm

    lif_eqs = '''
    dv/dt = (-(v - V_rest) + R_m * I) / tau_m : volt (unless refractory)
    I : amp
    '''

    neuron = b2.NeuronGroup(
        1,
        lif_eqs,
        threshold='v > V_th',
        reset='v = V_reset',
        refractory=refractory_period,
        method='exact',
    )
    neuron.v = V_rest
    neuron.I = input_current

    voltage_trace = b2.StateMonitor(neuron, 'v', record=0)
    spikes = b2.SpikeMonitor(neuron)

    b2.run(duration)

    plt.figure(figsize=(12, 6))

    # Upper panel: membrane potential with reference levels.
    plt.subplot(2, 1, 1)
    plt.plot(voltage_trace.t / b2.ms, voltage_trace.v[0] / b2.mV, label='Membrane Potential')
    reference_lines = (
        (V_th, 'red', '--', 'Threshold V_th'),
        (V_rest, 'gray', ':', 'Resting V_rest'),
    )
    for level, colour, dash_style, caption in reference_lines:
        plt.axhline(level / b2.mV, color=colour, linestyle=dash_style, label=caption)
    plt.xlabel('Time (ms)')
    plt.ylabel('Potential (mV)')
    plt.title(f'LIF Neuron (I={input_current}, tau={tau_m}, refractory={refractory_period})')
    plt.legend()
    plt.grid(True)

    # Lower panel: spike raster (an empty series is drawn when nothing fired).
    plt.subplot(2, 1, 2)
    fired = spikes.num_spikes > 0
    raster_x = spikes.t / b2.ms if fired else []
    raster_y = spikes.i if fired else []
    plt.plot(raster_x, raster_y, '.k', label='Spikes')
    plt.xlabel('Time (ms)')
    plt.ylabel('Neuron Index')
    plt.yticks([])
    plt.title('Spike Output')
    plt.grid(True)
    plt.ylim(-0.5, 0.5)
    plt.xlim(0, duration / b2.ms)

    plt.tight_layout()
    plt.show()

    # Text summary.
    num_spikes = spikes.num_spikes
    firing_rate = num_spikes / (duration / b2.second)
    print(f"Number of spikes: {num_spikes}")
    print(f"Approximate firing rate: {firing_rate:.2f} Hz")
    if num_spikes > 0:
        print(f"Spike times: {spikes.t / b2.ms} ms")

# Exercise 1 experiments: (banner, input current, membrane time constant,
# refractory period). Each row varies one parameter.
_lif_cases = (
    ("--- Testing Input Current = 150 pA ---", 150 * b2.pA, 10*b2.ms, 5*b2.ms),
    ("\n--- Testing tau_m = 20 ms ---", 200 * b2.pA, 20*b2.ms, 5*b2.ms),
    ("\n--- Testing refractory = 0 ms ---", 200 * b2.pA, 10*b2.ms, 0*b2.ms),
)
for _banner, _current, _tau, _refractory in _lif_cases:
    print(_banner)
    run_lif_simulation(input_current=_current, tau_m=_tau, refractory_period=_refractory)

## Exercise 2 Solution

In [None]:
import brian2 as b2
import numpy as np
import matplotlib.pyplot as plt
from random import sample

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

def run_network_simulation(num_inputs=50, num_outputs=10, input_rate=20*b2.Hz,
                           min_weight_val=0.5, max_weight_val=1.5,
                           connection_rule='synapses.connect()',
                           duration=200*b2.ms):
    """Run and plot a feedforward Poisson -> LIF network simulation.

    A Poisson input layer projects onto a layer of leaky integrate-and-fire
    neurons. Each presynaptic spike adds the synaptic weight to the
    postsynaptic membrane potential. Raster plots of both layers are shown
    and the total spike counts are printed.

    Args:
        num_inputs: Number of Poisson input neurons.
        num_outputs: Number of LIF output neurons.
        input_rate: Firing rate of the Poisson inputs.
        min_weight_val: Lower bound of the uniform weight distribution,
            as a plain number interpreted in mV.
        max_weight_val: Upper bound of the uniform weight distribution,
            as a plain number interpreted in mV.
        connection_rule: How to create the synapses. Either a string of
            Python code executed with the local name ``synapses`` bound to
            the Synapses object (e.g. ``'synapses.connect(p=0.1)'``), or a
            callable that receives the Synapses object.
        duration: Simulated time.

    Returns:
        None. If the connection rule fails, an error is printed and the
        function returns before running the simulation.
    """
    b2.start_scope()

    # These constants are referenced by name from the model strings below,
    # so they must stay local variables with exactly these names.
    tau_m = 10 * b2.ms
    V_rest = -65 * b2.mV
    V_reset = -65 * b2.mV
    V_th = -50 * b2.mV
    lif_eqs = '''
    dv/dt = -(v - V_rest) / tau_m : volt (unless refractory)
    '''

    input_group = b2.PoissonGroup(num_inputs, rates=input_rate)

    output_group = b2.NeuronGroup(num_outputs, lif_eqs, threshold='v > V_th',
                                  reset='v = V_reset', refractory=5*b2.ms, method='exact')
    output_group.v = V_rest

    synapse_model = 'w : volt'
    on_pre_eq = 'v_post += w'
    synapses = b2.Synapses(input_group, output_group, model=synapse_model, on_pre=on_pre_eq)

    try:
        if callable(connection_rule):
            # Preferred form: no code evaluation involved.
            connection_rule(synapses)
        else:
            # SECURITY: exec() runs arbitrary Python. Only pass trusted,
            # hard-coded strings here; use a callable for anything else.
            exec(connection_rule)
    except Exception as e:
        # Deliberately broad: a bad rule should abort this experiment with a
        # message rather than crash the whole notebook cell.
        print(f"Error executing connection rule '{connection_rule}': {e}")
        return

    if len(synapses) > 0:
        synapses.w = np.random.uniform(min_weight_val, max_weight_val, size=len(synapses)) * b2.mV
    else:
        print("Warning: No synapses were created based on the connection rule.")

    input_spike_mon = b2.SpikeMonitor(input_group, name='InputSpikes')
    output_spike_mon = b2.SpikeMonitor(output_group, name='OutputSpikes')

    print(f"Running simulation (duration={duration}, N_in={num_inputs}, N_out={num_outputs}, rate={input_rate}, weights=[{min_weight_val},{max_weight_val}], connect='{connection_rule}')")
    b2.run(duration)
    print("Simulation complete.")

    duration_ms = duration / b2.ms

    plt.figure(figsize=(12, 8))

    # Upper panel: input layer raster.
    plt.subplot(2, 1, 1)
    if input_spike_mon.num_spikes > 0:
        plt.plot(input_spike_mon.t / b2.ms, input_spike_mon.i, '.k', markersize=2)
    plt.xlabel('Time (ms)')
    plt.ylabel('Input Neuron Index')
    plt.title(f'Input Layer Spikes ({num_inputs} Poisson Neurons @ {input_rate})')
    plt.grid(True, alpha=0.3)
    plt.xlim(0, duration_ms)
    plt.ylim(-1, num_inputs)

    # Lower panel: output layer raster.
    plt.subplot(2, 1, 2)
    if output_spike_mon.num_spikes > 0:
        plt.plot(output_spike_mon.t / b2.ms, output_spike_mon.i, '.r', markersize=4)
    plt.xlabel('Time (ms)')
    plt.ylabel('Output Neuron Index')
    plt.title(f'Output Layer Spikes ({num_outputs} LIF Neurons)')
    plt.grid(True, alpha=0.3)
    plt.xlim(0, duration_ms)
    plt.ylim(-1, num_outputs)

    plt.tight_layout()
    plt.show()

    print(f"Total input spikes: {input_spike_mon.num_spikes}")
    print(f"Total output spikes: {output_spike_mon.num_spikes}")

# Exercise 2 experiments: each entry pairs a banner with the keyword
# arguments that differ from run_network_simulation's defaults.
_network_cases = (
    ("--- Testing Increased Synaptic Strength ---",
     {'min_weight_val': 2.0, 'max_weight_val': 3.0}),
    ("\n--- Testing Inhibitory Synaptic Strength ---",
     {'min_weight_val': -1.5, 'max_weight_val': -0.5}),
    ("\n--- Testing Sparse Connectivity (p=0.1) ---",
     {'connection_rule': 'synapses.connect(p=0.1)'}),
    ("\n--- Testing Higher Input Rate (40 Hz) ---",
     {'input_rate': 40*b2.Hz}),
)
for _banner, _overrides in _network_cases:
    print(_banner)
    run_network_simulation(**_overrides)


## Exercise 3 Solution

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import time

class GridWorldEnv:
    """Square grid world with one goal cell and one hole cell.

    Positions are ``(row, col)`` tuples. The agent starts at ``(0, 0)``, the
    goal is the bottom-right corner and the hole is at ``(1, 1)``.
    Actions: 0 = up, 1 = down, 2 = left, 3 = right.
    """

    def __init__(self, size=4):
        self.size = size
        self.actions = [0, 1, 2, 3]
        # (row delta, column delta) applied by each action id.
        self.action_delta = dict(zip(self.actions, [(-1, 0), (1, 0), (0, -1), (0, 1)]))
        self.agent_pos = (0, 0)
        self.hole_pos = (1, 1)
        self.goal_pos = (size - 1, size - 1)

    def reset(self):
        """Put the agent back on the start cell and return the new state."""
        self.agent_pos = (0, 0)
        return self.get_state()

    def get_state(self):
        """Return the agent's current ``(row, col)`` position."""
        return self.agent_pos

    def step(self, action):
        """Apply ``action`` and return ``(next_state, reward, done)``.

        Reaching the goal yields +10 and ends the episode, falling into the
        hole yields -10 and ends the episode, any other move costs 0.1.

        Raises:
            ValueError: If ``action`` is not one of ``self.actions``.
        """
        if action not in self.actions:
            raise ValueError("Invalid action")

        d_row, d_col = self.action_delta[action]
        row, col = self.agent_pos
        target = (row + d_row, col + d_col)

        # A move that would leave the grid keeps the agent where it is.
        on_grid = all(0 <= coord < self.size for coord in target)
        self.agent_pos = target if on_grid else (row, col)

        if self.agent_pos == self.goal_pos:
            return self.get_state(), 10.0, True
        if self.agent_pos == self.hole_pos:
            return self.get_state(), -10.0, True
        return self.get_state(), -0.1, False

    def render(self):
        """Print the grid: ``G`` goal, ``H`` hole, ``A`` agent, ``_`` empty."""
        grid = np.full((self.size, self.size), '_', dtype=str)
        # Later markers overwrite earlier ones, so the agent is drawn last.
        for cell, marker in ((self.goal_pos, 'G'), (self.hole_pos, 'H'), (self.agent_pos, 'A')):
            grid[cell] = marker
        rows = [" ".join(row) for row in grid]
        print("\n".join(rows))
        print("-" * (self.size * 2 - 1))

class QLearningAgent:
    """Tabular Q-learning agent with an epsilon-greedy policy.

    Q-values live in ``q_table`` as ``{state: {action: value}}``; entries that
    have never been written read as 0.0.
    """

    def __init__(self, env, alpha=0.1, gamma=0.99, epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01):
        self.env = env
        # Fall back to four actions when the environment does not list any.
        self.actions = getattr(env, 'actions', [0, 1, 2, 3])
        self.q_table = {}
        self.alpha = alpha                  # learning rate
        self.gamma = gamma                  # discount factor
        self.epsilon = epsilon              # current exploration probability
        self.epsilon_decay = epsilon_decay  # multiplicative decay per update
        self.epsilon_min = epsilon_min      # floor for epsilon

    def get_q_value(self, state, action):
        """Return Q(state, action), or 0.0 if it has not been learned yet."""
        try:
            return self.q_table[state][action]
        except KeyError:
            return 0.0

    def choose_action(self, state):
        """Pick an action epsilon-greedily, breaking ties at random."""
        if np.random.rand() < self.epsilon:
            return np.random.choice(self.actions)

        q_values = [self.get_q_value(state, a) for a in self.actions]
        best_q = np.max(q_values)
        # When every value is equal this list contains all actions.
        candidates = [a for a, q in zip(self.actions, q_values) if q == best_q]
        return np.random.choice(candidates)

    def update_q_table(self, state, action, reward, next_state, done):
        """Apply one Q-learning update for the observed transition."""
        current = self.get_q_value(state, action)
        if done:
            # Terminal transitions have no future value to bootstrap from.
            target = reward
        else:
            best_next = np.max([self.get_q_value(next_state, a) for a in self.actions])
            target = reward + self.gamma * best_next
        updated = current + self.alpha * (target - current)
        row = self.q_table.setdefault(state, {act: 0.0 for act in self.actions})
        row[action] = updated

    def update_epsilon(self):
        """Decay epsilon by one step without going below ``epsilon_min``."""
        decayed = self.epsilon * self.epsilon_decay
        self.epsilon = decayed if decayed > self.epsilon_min else self.epsilon_min

def train_agent(env, agent_params, num_episodes=5000, max_steps_per_episode=100, verbose=True):
    """Train a fresh Q-learning agent on ``env`` and report the outcome.

    After training, the (smoothed) reward curve is plotted and the greedy
    policy is printed as a grid of arrows.

    Args:
        env: Environment exposing ``reset()``, ``step(action)``, ``size``,
            ``goal_pos`` and ``hole_pos``.
        agent_params: Keyword arguments forwarded to ``QLearningAgent``.
        num_episodes: Number of training episodes.
        max_steps_per_episode: Step limit after which an episode is cut off.
        verbose: When true, print progress roughly ten times during training.
    """
    agent = QLearningAgent(env, **agent_params)
    rewards_per_episode = []
    # Report ten times per run. The interval is clamped to 1 so that runs
    # shorter than 10 episodes do not take a modulo by zero.
    report_every = max(1, num_episodes // 10)
    start_time = time.time()

    print(f"\n--- Training with params: {agent_params} ---")

    for episode in range(num_episodes):
        state = env.reset()
        total_reward = 0
        for _ in range(max_steps_per_episode):
            action = agent.choose_action(state)
            next_state, reward, done = env.step(action)
            agent.update_q_table(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward
            if done:
                break
        # Exploration decays once per episode, not once per step.
        agent.update_epsilon()
        rewards_per_episode.append(total_reward)

        if verbose and (episode + 1) % report_every == 0:
            avg_reward = np.mean(rewards_per_episode[-report_every:])
            print(f"Episode {episode + 1}/{num_episodes} | Avg Reward (last {report_every}): {avg_reward:.2f} | Epsilon: {agent.epsilon:.3f}")

    end_time = time.time()
    print(f"Training finished in {end_time - start_time:.2f} seconds.")

    plt.figure(figsize=(10, 5))
    window_size = 100
    if len(rewards_per_episode) >= window_size:
        # Moving average; 'valid' drops the first window_size - 1 points.
        smoothed_rewards = np.convolve(rewards_per_episode, np.ones(window_size)/window_size, mode='valid')
        plt.plot(smoothed_rewards, label=f"Params: {agent_params}")
        plt.xlabel(f'Episode (Moving Average over {window_size} episodes)')
    else:
        plt.plot(rewards_per_episode, label=f"Params: {agent_params}")
        plt.xlabel('Episode')
    plt.title('Episode Rewards over Time')
    plt.ylabel('Total Reward')
    plt.grid(True)
    plt.legend(fontsize='small')
    plt.show()

    print("Learned Policy:")
    action_arrows = {0: '^', 1: 'v', 2: '<', 3: '>'}
    policy_grid = np.full((env.size, env.size), ' ', dtype=str)
    policy_grid[env.goal_pos] = 'G'
    policy_grid[env.hole_pos] = 'H'
    for r in range(env.size):
        for c in range(env.size):
            state = (r, c)
            if state in (env.goal_pos, env.hole_pos):
                continue
            if state in agent.q_table:
                q_values = [agent.get_q_value(state, a) for a in agent.actions]
                best_action = agent.actions[np.argmax(q_values)]
                policy_grid[r, c] = action_arrows[best_action]
            else:
                # State was never updated during training.
                policy_grid[r, c] = '.'
    print("\n".join(" ".join(row) for row in policy_grid))
    print("-" * (env.size * 2 - 1))

# Exercise 3 experiments: a baseline run, then three runs that each change a
# single hyperparameter relative to the baseline.
grid_env = GridWorldEnv(size=4)

base_params = dict(alpha=0.1, gamma=0.99, epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01)
train_agent(grid_env, base_params)

high_alpha_params = {**base_params, 'alpha': 0.9}
train_agent(grid_env, high_alpha_params)

low_gamma_params = {**base_params, 'gamma': 0.1}
train_agent(grid_env, low_gamma_params)

fast_decay_params = {**base_params, 'epsilon_decay': 0.9}
train_agent(grid_env, fast_decay_params)


## Exercise 4 Solution

This solution demonstrates **Part 1** of Exercise 4: changing `num_outputs` in the SNN setup.
Because the size of the output layer is fixed when the network is built, each new value requires re-running the SNN setup, environment definition, agent creation, and training loop with the modified parameter.

In [None]:
import brian2 as b2
import numpy as np
import matplotlib.pyplot as plt
import time

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

class QLearningAgent:
    """Tabular Q-learning agent with an epsilon-greedy policy.

    Q-values live in ``q_table`` as ``{state: {action: value}}``; entries that
    have never been written read as 0.0. The action set is taken from
    ``env.actions``; if that is missing or empty a warning is printed and the
    agent falls back to the two actions ``[0, 1]``.
    """

    def __init__(self, env, alpha=0.1, gamma=0.99, epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01):
        self.env = env
        self.q_table = {}
        self.alpha = alpha                  # learning rate
        self.gamma = gamma                  # discount factor
        self.epsilon = epsilon              # current exploration probability
        self.epsilon_decay = epsilon_decay  # multiplicative decay per update
        self.epsilon_min = epsilon_min      # floor for epsilon
        env_actions = getattr(env, 'actions', [])
        if env_actions:
            self.actions = env_actions
        else:
            print("Warning: Actions not provided by env, defaulting potentially incorrectly.")
            self.actions = [0, 1]

    def get_q_value(self, state, action):
        """Return Q(state, action), or 0.0 if it has not been learned yet."""
        try:
            return self.q_table[state][action]
        except KeyError:
            return 0.0

    def choose_action(self, state):
        """Pick an action epsilon-greedily, breaking ties at random."""
        if np.random.rand() < self.epsilon:
            return np.random.choice(self.actions)

        q_values = [self.get_q_value(state, a) for a in self.actions]
        best_q = np.max(q_values)
        # When every value is equal this list contains all actions.
        greedy = [a for a, q in zip(self.actions, q_values) if q == best_q]
        # Defensive fallback: draw from the full action set if nothing
        # compared equal to the maximum.
        pool = greedy if greedy else self.actions
        return np.random.choice(pool)

    def update_q_table(self, state, action, reward, next_state, done):
        """Apply one Q-learning update for the observed transition."""
        current = self.get_q_value(state, action)
        if done:
            # Terminal transitions have no future value to bootstrap from.
            target = reward
        else:
            if self.actions:
                best_next = np.max([self.get_q_value(next_state, a) for a in self.actions])
            else:
                best_next = 0.0
            target = reward + self.gamma * best_next

        updated = current + self.alpha * (target - current)
        row = self.q_table.setdefault(state, {act: 0.0 for act in self.actions})
        row[action] = updated

    def update_epsilon(self):
        """Decay epsilon by one step without going below ``epsilon_min``."""
        decayed = self.epsilon * self.epsilon_decay
        self.epsilon = decayed if decayed > self.epsilon_min else self.epsilon_min

def setup_snn(num_inputs_snn, num_outputs_snn, conn_prob, weight_min, weight_max):
    """Build the Poisson -> LIF network and snapshot its initial state.

    The returned network contains the objects named ``'input_layer'``,
    ``'output_layer'``, ``'synapses'`` and ``'output_spikes'``, and has a
    stored state called ``'initial_snn'`` that can be restored before each
    pattern presentation.

    Args:
        num_inputs_snn: Number of Poisson input neurons (created with rate 0).
        num_outputs_snn: Number of LIF output neurons.
        conn_prob: Connection probability passed to ``synapses.connect``.
        weight_min: Lower bound of the uniform weight distribution, as a
            plain number interpreted in mV.
        weight_max: Upper bound of the uniform weight distribution, as a
            plain number interpreted in mV.

    Returns:
        The assembled ``brian2.Network``.
    """
    b2.start_scope()
    tau_m = 10 * b2.ms
    V_rest = -65 * b2.mV
    V_reset = -65 * b2.mV
    V_th = -50 * b2.mV
    lif_eqs = 'dv/dt = -(v - V_rest) / tau_m : volt (unless refractory)'

    # The network returned here is run later from other scopes (for example
    # from the environment's reset method), where these local constants are
    # no longer visible. Attaching them to the group as an explicit namespace
    # keeps the model strings resolvable wherever run() is called.
    lif_namespace = {'tau_m': tau_m, 'V_rest': V_rest, 'V_reset': V_reset, 'V_th': V_th}

    input_group = b2.PoissonGroup(num_inputs_snn, rates=0*b2.Hz, name='input_layer')
    output_group = b2.NeuronGroup(num_outputs_snn, lif_eqs, threshold='v > V_th',
                                  reset='v = V_reset', refractory=5*b2.ms, method='exact',
                                  namespace=lif_namespace, name='output_layer')
    output_group.v = V_rest
    synapses = b2.Synapses(input_group, output_group, 'w : volt', on_pre='v_post += w',
                             name='synapses')
    synapses.connect(p=conn_prob)
    if len(synapses) > 0:
        synapses.w = np.random.uniform(weight_min, weight_max, size=len(synapses)) * b2.mV

    output_spike_mon = b2.SpikeMonitor(output_group, name='output_spikes')
    # collect() gathers the Brian objects held in this function's local
    # variables, so every object above must stay bound to a local name.
    snn_net = b2.Network(b2.collect())
    snn_net.store('initial_snn')
    print(f"SNN Setup: Input={num_inputs_snn}, Output={num_outputs_snn}, Synapses={len(synapses)}")
    return snn_net

def define_patterns(num_inputs_pat, base_rate, high_rate):
    """Build the two input rate patterns and their class labels.

    Pattern ``'A'`` drives the first half of the inputs at ``high_rate`` and
    the rest at ``base_rate``; pattern ``'B'`` is the mirror image. Both rates
    are plain numbers that get multiplied by Hz here.

    Returns:
        ``(patterns, pattern_labels)``: dicts keyed by ``'A'`` / ``'B'``
        holding the per-input rate arrays and the integer labels
        (A -> 0, B -> 1).
    """
    half = num_inputs_pat // 2

    def build(first_half_rate, second_half_rate):
        # One rate vector: a flat level for each half of the input layer.
        rates = np.zeros(num_inputs_pat) * b2.Hz
        rates[:half] = first_half_rate * b2.Hz
        rates[half:] = second_half_rate * b2.Hz
        return rates

    patterns = {
        'A': build(high_rate, base_rate),
        'B': build(base_rate, high_rate),
    }
    pattern_labels = dict(A=0, B=1)
    print("Spike patterns defined.")
    return patterns, pattern_labels

class SNNPatternEnv:
    """One-step classification environment backed by a spiking network.

    Each episode presents one randomly chosen input pattern to the network.
    The observation is the tuple of discretised output firing rates; the
    agent is rewarded +1 for choosing the pattern's label and -1 otherwise.

    The network is expected to contain objects named ``'input_layer'``,
    ``'output_layer'`` and ``'output_spikes'`` and to have a stored state
    called ``'initial_snn'``.
    """

    def __init__(self, snn_network, patterns, pattern_labels, duration):
        self.snn_network = snn_network
        self.patterns = patterns              # pattern name -> input rates
        self.pattern_labels = pattern_labels  # pattern name -> correct action
        self.duration = duration              # presentation time per episode
        self.current_pattern_name = None
        self.actions = list(set(pattern_labels.values()))

    def reset(self):
        """Present a random pattern to a freshly restored network.

        Returns:
            The discretised output-rate state produced by the presentation.
        """
        self.current_pattern_name = np.random.choice(list(self.patterns.keys()))
        input_rates = self.patterns[self.current_pattern_name]
        # Start every episode from the same stored network state.
        self.snn_network.restore('initial_snn')
        try:
            input_layer = self.snn_network['input_layer']
        except KeyError:
            print("Error: 'input_layer' not found in network objects during reset.")
        else:
            input_layer.rates = input_rates

        # report=None disables progress output for these short runs.
        self.snn_network.run(self.duration, report=None)
        return self._get_snn_state()

    def _get_snn_state(self):
        """Return per-neuron output rates, binned into 1 / 2 / 3.

        Rates below 10 Hz map to 1, rates from 10 Hz up to (but excluding)
        30 Hz map to 2, and rates of 30 Hz or more map to 3. An empty tuple
        is returned if the output layer cannot be found.
        """
        spike_monitor = self.snn_network['output_spikes']
        try:
            num_output_neurons = self.snn_network['output_layer'].N
        except KeyError:
            print("Error: 'output_layer' not found in network objects for state calculation.")
            return tuple()

        rates = np.zeros(num_output_neurons)
        duration_sec = self.duration / b2.second
        if duration_sec > 0 and spike_monitor.num_spikes > 0:
            neuron_indices, counts = np.unique(spike_monitor.i, return_counts=True)
            # Ignore any index outside the output layer.
            valid_indices = neuron_indices < num_output_neurons
            rates[neuron_indices[valid_indices]] = counts[valid_indices] / duration_sec

        bins = [-np.inf, 10, 30, np.inf]
        return tuple(np.digitize(rates, bins))

    def step(self, action):
        """Score ``action`` against the current pattern's label.

        Returns:
            ``(next_state, reward, done)``. ``done`` is always True because
            every episode is a single decision.
        """
        correct_action = self.pattern_labels[self.current_pattern_name]
        reward = 1.0 if action == correct_action else -1.0
        done = True
        next_state = self._get_snn_state()
        return next_state, reward, done

def train_snn_agent(snn_env, agent_params, num_episodes=3000, verbose=True):
    """Train a Q-learning agent on the SNN pattern environment.

    Every episode is a single decision: the environment presents a pattern,
    the agent picks a label and receives a reward. Afterwards the overall
    accuracy is printed and the (smoothed) accuracy curve is plotted.

    Args:
        snn_env: Environment exposing ``reset()``, ``step(action)``,
            ``pattern_labels`` and ``current_pattern_name``.
        agent_params: Keyword arguments forwarded to ``QLearningAgent``.
        num_episodes: Number of training episodes.
        verbose: When true, print progress roughly ten times during training.
    """
    agent = QLearningAgent(snn_env, **agent_params)
    history = []
    # Report ten times per run. The interval is clamped to 1 so that runs
    # shorter than 10 episodes do not take a modulo by zero.
    report_every = max(1, num_episodes // 10)
    print(f"\n--- Starting SNN+RL training ({num_episodes} episodes) ---")
    print(f"Agent Params: {agent_params}")
    start_time_snn = time.time()

    for episode in range(num_episodes):
        state = snn_env.reset()
        if not isinstance(state, tuple):
            # The Q-table is keyed by state, so stop on anything unexpected.
            print(f"Error: Invalid state received from env.reset() at episode {episode+1}: {state}")
            break

        action = agent.choose_action(state)
        next_state, reward, done = snn_env.step(action)

        correct_action = snn_env.pattern_labels[snn_env.current_pattern_name]
        history.append({'pattern': snn_env.current_pattern_name, 'state': state, 'chosen': action, 'correct': correct_action, 'reward': reward})

        agent.update_q_table(state, action, reward, next_state, done)
        agent.update_epsilon()

        if verbose and (episode + 1) % report_every == 0:
            recent_history = history[-report_every:]
            recent_accuracy = sum(1 for h in recent_history if h['chosen'] == h['correct']) / len(recent_history)
            print(f"Episode {episode + 1}/{num_episodes} | Recent Acc: {recent_accuracy:.3f} | Epsilon: {agent.epsilon:.3f} | Q-States: {len(agent.q_table)}")

    end_time_snn = time.time()
    # 1 for a correct choice, 0 otherwise; used for both the summary and plot.
    accuracy_history = [1 if h['chosen'] == h['correct'] else 0 for h in history]
    if accuracy_history:
        total_accuracy = sum(accuracy_history) / len(accuracy_history)
    else:
        total_accuracy = 0.0
    print(f"SNN+RL Training finished in {end_time_snn - start_time_snn:.2f} seconds.")
    print(f"Overall Accuracy: {total_accuracy:.3f}")
    print(f"Final Q-Table size: {len(agent.q_table)} states encountered.")

    # Plotting Accuracy
    plt.figure(figsize=(10, 5))
    window_size = 100
    if len(accuracy_history) >= window_size:
        smoothed_accuracy = np.convolve(accuracy_history, np.ones(window_size)/window_size, mode='valid')
        # Place each average at the last episode of its window.
        plot_indices = np.arange(window_size - 1, len(accuracy_history))
        plt.plot(plot_indices, smoothed_accuracy)
        plt.xlabel(f'Episode (Smoothed over {window_size})')
    elif len(accuracy_history) > 0:
        plt.plot(accuracy_history)
        plt.xlabel('Episode')

    plt.title(f'Accuracy over Time (SNN+RL) - Params: {agent_params}')
    plt.ylabel('Accuracy')
    plt.ylim(-0.05, 1.05)
    plt.grid(True)
    plt.show()

# Shared configuration for the Exercise 4 experiments.
NUM_INPUTS = 20                   # input neurons; also the length of each rate pattern
PATTERN_DURATION = 100 * b2.ms    # simulated time per pattern presentation
BASE_RATE = 10                    # plain number; define_patterns multiplies it by Hz
HIGH_RATE = 50                    # plain number; define_patterns multiplies it by Hz
CONN_PROB = 0.5                   # connection probability given to synapses.connect
WEIGHT_MIN = 0.5                  # lower weight bound; setup_snn multiplies it by mV
WEIGHT_MAX = 2.0                  # upper weight bound; setup_snn multiplies it by mV
NUM_EPISODES_SNN = 3000           # training episodes per experiment

patterns_dict, patterns_labels_dict = define_patterns(NUM_INPUTS, BASE_RATE, HIGH_RATE)

# The same agent hyperparameters are used for every output-layer size.
rl_params = {'alpha': 0.1, 'gamma': 0.9, 'epsilon': 1.0, 'epsilon_decay': 0.99, 'epsilon_min': 0.05}

# Each experiment builds its own network, wraps it in an environment and
# trains a fresh agent; only the number of output neurons changes.
print("\n===== Testing with num_outputs = 4 =====")
snn_net_4 = setup_snn(NUM_INPUTS, 4, CONN_PROB, WEIGHT_MIN, WEIGHT_MAX)
env_snn_4 = SNNPatternEnv(snn_net_4, patterns_dict, patterns_labels_dict, PATTERN_DURATION)
train_snn_agent(env_snn_4, rl_params, num_episodes=NUM_EPISODES_SNN)

print("\n===== Testing with num_outputs = 2 =====")
snn_net_2 = setup_snn(NUM_INPUTS, 2, CONN_PROB, WEIGHT_MIN, WEIGHT_MAX)
env_snn_2 = SNNPatternEnv(snn_net_2, patterns_dict, patterns_labels_dict, PATTERN_DURATION)
train_snn_agent(env_snn_2, rl_params, num_episodes=NUM_EPISODES_SNN)

print("\n===== Testing with num_outputs = 8 =====")
snn_net_8 = setup_snn(NUM_INPUTS, 8, CONN_PROB, WEIGHT_MIN, WEIGHT_MAX)
env_snn_8 = SNNPatternEnv(snn_net_8, patterns_dict, patterns_labels_dict, PATTERN_DURATION)
train_snn_agent(env_snn_8, rl_params, num_episodes=NUM_EPISODES_SNN)
