# In [106]:
import numpy as np # used for array operations
from scipy.stats import poisson, binom # used for stats (mean, s.d.)
import brian2 as b2 # used for neural simulation
from brian2 import prefs
prefs.codegen.target = "numpy"
import matplotlib.pyplot as plt # used for plotting

# Simulates multiple independent Poisson processes for a group of neurons
def independent_poisson_processes(num_neurons, rate, time, num_samples):
    """Simulate ``num_samples`` independent batches of Poisson spike trains.

    Each sample runs a fresh Brian2 ``PoissonGroup`` of ``num_neurons``
    neurons firing at ``rate`` Hz for ``time`` seconds. The spikes are then
    binned into 1 ms bins.

    Parameters
    ----------
    num_neurons : int
        Number of neurons per sample.
    rate : float
        Firing rate in Hz.
    time : float
        Simulation duration in seconds.
    num_samples : int
        Number of independent simulations.

    Returns
    -------
    tuple
        ``(processes, summing_variable, counts, vector_set_of_counts,
        characterization_of_counts, least_count, least_count_indices,
        least_count_integers)``. ``processes`` has shape
        ``(num_samples, num_neurons, n_bins)`` and holds a 1 in every 1 ms
        bin that contains at least one spike.
    """
    simulation_time = time * b2.second
    # Number of 1 ms bins. round() guards against float truncation,
    # e.g. int(0.57 * 1000) == 569.
    n_bins = int(round(time * 1000))

    # 3-D array (sample, neuron, bin) holding the binned spike trains.
    processes = np.zeros((num_samples, num_neurons, n_bins))

    for sample in range(num_samples):
        poisson_group = b2.PoissonGroup(num_neurons, rate * b2.Hz)
        spike_monitor = b2.SpikeMonitor(poisson_group)
        # An explicit Network isolates each sample's run instead of relying
        # on Brian2's magic network collecting objects from this frame.
        net = b2.Network(poisson_group, spike_monitor)
        net.run(simulation_time)

        # SpikeMonitor.i holds 0-based neuron indices, so they index the
        # neuron axis directly (the old `neuron_idx - 1` sent neuron 0 to the
        # last row).
        neuron_ids = np.asarray(spike_monitor.i[:], dtype=int)
        # 1 ms bins: the bin of a spike is floor(t / ms). The old code
        # multiplied by 10, producing 0.1 ms indices that overflowed n_bins.
        bin_ids = np.floor(np.asarray(spike_monitor.t[:] / b2.ms)).astype(int)
        in_window = (bin_ids >= 0) & (bin_ids < n_bins)
        processes[sample, neuron_ids[in_window], bin_ids[in_window]] = 1

    # Analysis of spike patterns:
    summing_variable = np.sum(processes, axis=1)  # population activity per bin
    counts = np.sum(processes, axis=2)  # spike count per neuron per sample
    vector_set_of_counts = np.sum(counts, axis=1)  # total spikes per sample
    characterization_of_counts = np.mean(vector_set_of_counts)  # mean over samples

    # Neurons with the minimum spike count.
    least_count = np.min(counts)
    least_count_indices = np.where(counts == least_count)
    least_count_integers = np.arange(1, least_count + 1)

    return processes, summing_variable, counts, vector_set_of_counts, characterization_of_counts, least_count, least_count_indices, least_count_integers

def correlated_poisson_processes(num_neurons, rate, time, num_common, num_samples):
    """Build per-sample spike trains from independent Brian2 Poisson sources.

    ``num_samples * (num_neurons + num_common)`` independent Poisson trains
    at ``rate`` Hz are simulated for ``time`` milliseconds and binned into
    1 ms bins.

    - The first ``num_common`` output neurons of each sample are derived from
      independent trains ``0 .. num_common-1`` via a cumulative sum.
    - The remaining output neurons copy independent trains
      ``num_common .. num_neurons-1`` unchanged.

    Note: ``time`` is in milliseconds here, whereas
    ``independent_poisson_processes`` takes seconds.

    Returns
    -------
    numpy.ndarray
        Boolean array of shape ``(num_samples, num_neurons, int(time))``.
    """
    per_sample = num_neurons + num_common
    n_bins = int(time)

    # One PoissonGroup with one neuron per independent train is statistically
    # equivalent to one single-neuron group per train, and far cheaper to build.
    poisson_group = b2.PoissonGroup(per_sample * num_samples, rate * b2.Hz)
    spike_monitor = b2.SpikeMonitor(poisson_group)
    net = b2.Network(poisson_group, spike_monitor)
    net.run(time * b2.ms)

    independent_procs = np.zeros((num_samples, per_sample, n_bins), dtype=bool)
    source_ids = np.asarray(spike_monitor.i[:], dtype=int)
    bin_ids = np.floor(np.asarray(spike_monitor.t[:] / b2.ms)).astype(int)
    in_window = (bin_ids >= 0) & (bin_ids < n_bins)
    # Flat source index -> (sample, neuron-within-sample).
    sample_ids, neuron_ids = np.divmod(source_ids[in_window], per_sample)
    # Only real spikes are marked. The previous zero-padding marked bin 0
    # for every train that had fewer spikes than the longest one.
    independent_procs[sample_ids, neuron_ids, bin_ids[in_window]] = True

    correlated_procs = np.zeros((num_samples, num_neurons, n_bins), dtype=bool)
    # NOTE(review): the cumulative sum is cast to bool, so each of these
    # neurons becomes a step (True from its first spike onward), not a spike
    # train. Kept for compatibility; confirm the intended correlation scheme.
    correlated_procs[:, :num_common, :] = np.cumsum(independent_procs[:, :num_common, :], axis=2)
    # Remaining neurons keep their independent firing patterns.
    correlated_procs[:, num_common:, :] = independent_procs[:, num_common:num_neurons, :]

    return correlated_procs

def small_network(num_neurons, rate, time):
    """Simulate a leaky integrate-and-fire group driven by Poisson input.

    Neurons follow ``dv/dt = -v/(10*ms)``, spike when ``v > 1`` and reset to
    0. They are driven by a ``PoissonInput`` on ``v`` with ``N=num_neurons``,
    rate ``rate`` Hz and weight 1. The network runs for ``time``
    milliseconds.

    Returns the ``spike_trains()`` mapping of the group's SpikeMonitor.
    """
    lif_group = b2.NeuronGroup(
        num_neurons,
        'dv/dt = -v/(10*ms) : 1',
        threshold='v>1',
        reset='v=0',
    )
    drive = b2.PoissonInput(
        target=lif_group,
        target_var='v',
        N=num_neurons,
        rate=rate * b2.Hz,
        weight=1,
    )
    recorder = b2.SpikeMonitor(lif_group)
    duration = time * b2.ms
    # Explicit Network rather than the magic b2.run.
    b2.Network(lif_group, drive, recorder).run(duration)
    return recorder.spike_trains()

# start - counting functions

# The counting functions turn spike trains (or per-step spike counts) into running spike totals over time.
# count_at_time returns the cumulative sum of the leading counts whose time stamps fall before a given time t (times are assumed sorted).
# count1 returns the cumulative sum of a single neuron's process, truncated to the first `time` steps,
# and countall applies the same computation to every process in a list.
# countall_vectorized computes the cumulative counts for every neuron of the first sample,
# and counting_process_nd builds the counting process for the first neurons of the first samples as nested lists.

def count_at_time(counts, times, t):
    """Return the running total of the ``counts`` entries whose time is before ``t``.

    Parameters
    ----------
    counts : array_like
        Spike counts, aligned element-wise with ``times``.
    times : array_like
        Time stamp of each entry. Assumed sorted ascending, because the
        entries before ``t`` are taken as a leading prefix of ``counts``.
    t : float
        Cut-off time (exclusive).

    Returns
    -------
    numpy.ndarray
        Cumulative sum of the leading ``counts`` entries whose time is < ``t``.
    """
    # np.asarray lets plain lists work: `list < t` raises TypeError in Python 3.
    n_before = np.count_nonzero(np.asarray(times) < t)
    return np.cumsum(counts[:n_before])

def count1(process, time):
    """Return the cumulative spike count of one process, keeping the first ``time`` entries.

    The running total is computed over the whole (flattened) process and
    then truncated.
    """
    running_total = np.cumsum(process)
    truncated = running_total[:time]
    return truncated

def countall(processes, time):
    """Return truncated cumulative spike counts for every process.

    Each element is ``np.cumsum(process)[:time]``, the same computation as
    ``count1``. The result is a list with one array per process.
    """
    return [np.cumsum(single)[:time] for single in processes]

def countall_vectorized(processes, time):
    """Return cumulative spike counts for every neuron of the first sample in one call.

    This is the vectorized counterpart of ``countall(processes[0], time)``.
    A single ``np.cumsum`` along the time axis replaces the per-neuron loop,
    and the result is truncated to the first ``time`` steps like ``count1``.
    Previously ``time`` was silently ignored.

    Parameters
    ----------
    processes : array_like
        Indexed as ``processes[sample][neuron][step]``; only sample 0 is used.
    time : int
        Number of leading time steps to keep.

    Returns
    -------
    list of numpy.ndarray
        One cumulative-count array per neuron of ``processes[0]``.
    """
    first_sample = np.asarray(processes[0])
    # Cumsum over the last (time) axis for all neurons at once, then truncate.
    cumulative = np.cumsum(first_sample, axis=-1)[..., :time]
    # Return a list of per-neuron arrays, as before.
    return list(cumulative)

def counting_process_nd(independent_processes, num_samples, time, num_neurons_to_plot):
    """Return the counting process N(t) for the first neurons of the first samples.

    Entry ``[i][k][j]``, for sample ``i < num_samples``, neuron
    ``k < num_neurons_to_plot`` and step ``j < time``, is the number of steps
    ``0..j`` at which ``independent_processes[i][k][step] == 1``.

    Returns
    -------
    list
        Nested Python lists of ints with shape
        ``(num_samples, num_neurons_to_plot, time)``.

    Raises
    ------
    ValueError
        If the input is not 3-dimensional (sample, neuron, step).
    IndexError
        If the input has fewer samples, neurons or steps than requested.
    """
    shape = (max(num_samples, 0), max(num_neurons_to_plot, 0), max(time, 0))
    if 0 in shape:
        # Nothing to count; the input does not need to be inspected.
        return np.zeros(shape, dtype=int).tolist()

    spikes = np.asarray(independent_processes)
    if spikes.ndim != 3:
        raise ValueError(
            f"expected a 3-D (sample, neuron, step) input, got shape {spikes.shape}"
        )
    if any(req > have for req, have in zip(shape, spikes.shape)):
        raise IndexError(
            f"requested {shape} (samples, neurons, steps) but input has shape {spikes.shape}"
        )

    # A step counts only when the entry equals 1, matching the original loop.
    window = spikes[: shape[0], : shape[1], : shape[2]] == 1
    return np.cumsum(window, axis=2).tolist()
# end - counting functions

# start - plot functions

# The plotting functions visualize spike trains and counting processes.
# plot_neurons_spiking overlays the traces of every neuron in the first sample,
# while plot_neuron_spiking_standard_dev plots neuron 1 of the first sample
# together with its mean and a band of plus/minus one standard deviation.
# plot_two_neurons_against_time draws a 3D curve of neuron 1's trace (x) against neuron 2's trace (y)
# and the cumulative sum of neuron 3's trace (z), all taken from the first sample.

def plot_neurons_spiking(processes, time):
    """Show some examples of neurons spiking.

    Overlays the trace of every neuron in the first sample
    (``processes[0]``) on one figure, each labelled by its index.
    ``time`` is accepted for interface compatibility but is not used.
    """
    first_sample = processes[0]
    plt.figure(figsize=(10,6))
    for neuron_number, trace in enumerate(first_sample):
        plt.plot(trace, label=f'Neuron {neuron_number}')
    plt.title('Neurons Over Time')
    plt.xlabel('Time')
    plt.ylabel('Spike')
    plt.legend()
    plt.show()

def plot_neuron_spiking_standard_dev(processes, time):
    """Plot neuron 1 of sample 0 with its mean and a plus/minus one SD band.

    Prints the mean and standard deviation of the trace. It then draws the
    trace, a shaded band of ``trace +/- sd``, a dashed line at the mean and a
    text box with both values.

    Parameters
    ----------
    processes : array_like
        Indexed as ``processes[sample][neuron][step]``.
    time : int
        Kept for backward compatibility. The x-axis now comes from the trace
        length; ``range(time)`` made ``fill_between`` fail whenever
        ``time != len(trace)``.
    """
    # np.asarray so `trace - sd` also works when the trace is a plain list.
    trace = np.asarray(processes[0][0])
    fig = plt.figure(figsize=(10,6))
    plt.plot(trace, label='Neuron 1')
    sd = np.std(trace)
    mean = np.mean(trace)
    print(f"Mean: {mean:.2f}")
    print(f"Standard Deviation: {sd:.2f}")
    steps = np.arange(len(trace))
    plt.fill_between(steps, trace - sd, trace + sd, alpha=0.2, label='Standard Deviation')
    plt.axhline(y=mean, color='black', linestyle='--', label='Mean')
    plt.xlabel('Time')
    plt.ylabel('Spike')
    plt.title('Neuron 1 Over Time with Standard Deviation')
    plt.legend()
    plt.text(0.5, 0.9, f"Mean: {mean:.2f}, SD: {sd:.2f}", transform=plt.gca().transAxes)
    plt.show()

def plot_two_neurons_against_time(processes, time):
    """Draw a 3D line of neuron 1 (x) vs neuron 2 (y) vs the running total of neuron 3 (z).

    All three traces come from the first sample. ``time`` is not used.
    """
    sample = processes[0]
    axes = plt.figure(figsize=(10,6)).add_subplot(111, projection='3d')
    x_trace = sample[0]
    y_trace = sample[1]
    z_trace = np.cumsum(sample[2])
    axes.plot(x_trace, y_trace, z_trace)
    axes.set_xlabel('Neuron 1')
    axes.set_ylabel('Neuron 2')
    axes.set_zlabel('Cumulative Sum of Neuron 3')
    plt.title('Two Neurons Over Time with Cumulative Sum of Third Neuron')
    plt.show()

def plot_count_neuron1_vs_time(counting_process_nd, num_samples):
    """Overlay neuron 1's counting process for each of the first ``num_samples`` samples.

    The y-axis starts at 0. Per-sample labels are attached to the lines, but
    no legend is drawn.
    """
    plt.figure(figsize=(10,6))
    for sample_number in range(num_samples):
        neuron1_counts = counting_process_nd[sample_number][0]
        plt.plot(neuron1_counts, label=f'Sample {sample_number}')
    plt.title('Count of Neuron 1 Over Time')
    plt.xlabel('Time')
    plt.ylabel('Count of Neuron 1')
    plt.ylim(0, None)  # counts are non-negative; anchor the axis at 0
    plt.show()

def plot_count_neuron1_vs_neuron2_vs_time(counting_process_nd, num_samples):
    """Plot a 3D trajectory (count of neuron 1, count of neuron 2, step index) per sample.

    Axis limits start at 0 and extend to the largest count or length found
    among the first ``num_samples`` samples.
    """
    fig = plt.figure(figsize=(12,8))
    ax = fig.add_subplot(111, projection='3d')
    samples = [counting_process_nd[idx] for idx in range(num_samples)]
    for sample in samples:
        n1_counts, n2_counts = sample[0], sample[1]
        ax.plot(n1_counts, n2_counts, range(len(n1_counts)))
    ax.set_xlabel('Count of Neuron 1')
    ax.set_ylabel('Count of Neuron 2')
    ax.set_zlabel('Time', rotation=90)
    ax.set_title('Count of Neuron 1 vs Count of Neuron 2 vs Time')
    ax.set_xlim(0, max(max(sample[0]) for sample in samples))
    ax.set_ylim(0, max(max(sample[1]) for sample in samples))
    ax.set_zlim(0, max(len(sample[0]) for sample in samples))
    plt.show()

def plot_count1(count1_output, title='Count of Spikes Over Time (Single Neuron)'):
    """Plot one cumulative-count series, then print a confirmation line."""
    plt.figure(figsize=(10,6))
    plt.plot(count1_output)
    plt.title(title)
    plt.xlabel('Time')
    plt.ylabel('Count')
    plt.show()
    message = f"Plot of {title} generated successfully."
    print(message)

def plot_vectorized_count1(count1_vectorized_output, title='Count of Spikes Over Time (Single Neuron, Vectorized)'):
    """Plot a cumulative-count series produced by a vectorized counter.

    The figure is identical to ``plot_count1``; only the default title
    differs, so this delegates to it.
    """
    plot_count1(count1_vectorized_output, title=title)

def plot_counts(counts_all, title='Count of Spikes Over Time'):
    """Overlay one cumulative-count curve per neuron, labelled by index, with a legend."""
    plt.figure(figsize=(10,6))
    for neuron_number, series in enumerate(counts_all):
        plt.plot(series, label=f'Neuron {neuron_number}')
    plt.title(title)
    plt.xlabel('Time')
    plt.ylabel('Count')
    plt.legend()
    plt.show()

def plot_counts_vectorized(counts_all_vectorized, title='Count of Spikes Over Time (Vectorized)'):
    """Plot per-neuron cumulative counts produced by a vectorized counter.

    The figure is identical to ``plot_counts``; only the default title
    differs, so this delegates to it.
    """
    plot_counts(counts_all_vectorized, title=title)
# end - plot functions

# start - stats functions

# The statistical functions summarize spike trains.
# calculate_mean_and_covariance returns, for each neuron, the mean over all samples and time steps
# and the covariance matrix across samples with time steps as variables.
# get_slope fits a straight line to each neuron's trace in the first sample and returns the slopes.
# calculate_mean_with_std returns the mean plus the standard deviation (a single number) of neuron 1's trace in the first sample,
# and calculate_covariance_matrix returns the 2x2 covariance matrix of the first two neurons' traces in the first sample.

def calculate_mean_and_covariance(processes):
    """Return per-neuron means and time-bin covariances across samples.

    For each neuron index ``n`` (counted from ``processes[0]``), the rows
    ``sample[n]`` of every sample are gathered. The function returns their
    overall mean and their covariance with time bins as variables
    (``np.cov(..., rowvar=False)``).

    Returns
    -------
    tuple of (list, list)
        ``(means, covariances)``, one entry per neuron.
    """
    stacks = [
        [sample[neuron] for sample in processes]
        for neuron in range(len(processes[0]))
    ]
    means = [np.mean(stack) for stack in stacks]
    covariances = [np.cov(stack, rowvar=False) for stack in stacks]
    return means, covariances

def get_slope(processes):
    """Return the least-squares linear slope of each neuron's trace in the first sample."""
    first_sample = processes[0]
    return [
        np.polyfit(range(len(trace)), trace, 1)[0]
        for trace in first_sample
    ]

def calculate_mean_with_std(processes):
    """Return mean + standard deviation (one number) of neuron 1's trace in sample 0."""
    trace = processes[0][0]
    return np.mean(trace) + np.std(trace)

def calculate_covariance_matrix(processes):
    """Return the 2x2 covariance matrix of the first two neurons' traces in sample 0."""
    first_sample = processes[0]
    pair = [first_sample[0], first_sample[1]]
    return np.cov(pair)
# end - stats functions

# start -  Discretize time and sampling functions

# The discretization and sampling functions reduce the time resolution of each process by decimation.
# discretize_time keeps every `time`-th element of each process and returns NumPy arrays,
# while sample_spikes keeps every `resolution`-th element via slicing, preserving the input type.
# Neither function sums spikes within a window, so spikes between the kept steps are dropped.
# These functions are useful for analyzing the spike counts at coarser time scales.

def discretize_time(processes, time):
    """Downsample each process by keeping elements ``0, time, 2*time, ...``.

    Returns a list with one NumPy array per process. No summing is done
    within a window.
    """
    return [
        np.array([series[k] for k in range(0, len(series), time)])
        for series in processes
    ]

def sample_spikes(processes, resolution):
    """Return every ``resolution``-th element of each process, using plain slicing."""
    return [series[::resolution] for series in processes]
# end -  Discretize time and sampling function

# start - tagging and branding functions

# The tagging and branding functions attach random Poisson values to spike data and simulate small networks of neurons.
# tag_events_with_poisson draws a single Poisson(rate) sample for a scalar input, and otherwise an array of Poisson(rate)
# samples with the same length as each process; the input values themselves are not used.
# brand_networks simulates a leaky integrate-and-fire network with a specified number of neurons, input firing rate and duration,
# and returns the spike trains of the neurons. These functions are useful for analyzing the behavior of neural networks.


def tag_events_with_poisson(processes, rate):
    """Draw Poisson(``rate``) samples shaped like ``processes``.

    - A scalar input returns one Poisson sample. Scalars include any real
      number, NumPy scalar types such as ``np.int64`` or ``np.float32``, and
      0-d arrays.
    - Any other input is iterated, and one array of ``len(process)`` Poisson
      samples is returned per element.

    Only the sizes of the inputs are used, not their values.
    """
    import numbers

    # The old check (np.float64, float, int) missed NumPy scalars such as
    # np.int64 / np.float32, which then failed as "not iterable".
    is_scalar = isinstance(processes, numbers.Real) or (
        isinstance(processes, np.ndarray) and processes.ndim == 0
    )
    if is_scalar:
        return np.random.poisson(rate)
    return [np.random.poisson(rate, size=len(process)) for process in processes]

def brand_networks(num_neurons, firing_rate, time, num_rates=None):
    """Simulate a LIF group driven by Poisson input, optionally with per-neuron rates.

    Parameters
    ----------
    num_neurons : int
        Size of the NeuronGroup (``dv/dt = -v/(10*ms)``, threshold ``v>1``,
        reset ``v=0``).
    firing_rate : float
        Input rate in Hz. With ``num_rates=None``, the group receives a
        PoissonInput with ``N=num_neurons`` at this rate and weight 1.
        Otherwise it is the upper bound of the uniformly drawn input rates.
    time : float
        Simulation duration in milliseconds.
    num_rates : int, optional
        Number of distinct input rates drawn from ``U(0, firing_rate)``.
        Neuron ``i`` is driven by its own Poisson source at rate
        ``input_rates[i % num_rates]``, and each input spike adds 1 to ``v``.

    Returns
    -------
    The ``spike_trains()`` mapping of the neuron group's SpikeMonitor.

    Raises
    ------
    ValueError
        If ``num_rates`` is given but is not positive.
    """
    neurons = b2.NeuronGroup(num_neurons, 'dv/dt = -v/(10*ms) : 1', threshold='v>1', reset='v=0')

    if num_rates is None:
        # Use a single rate for all neurons.
        inputs = b2.PoissonInput(target=neurons, target_var='v', N=num_neurons, rate=firing_rate*b2.Hz, weight=1)
    else:
        if num_rates <= 0:
            raise ValueError(f"num_rates must be positive, got {num_rates}")
        input_rates = np.random.uniform(0, firing_rate, size=num_rates)
        # The previous loop attached every PoissonInput to the whole group,
        # so all neurons received every rate. Instead, give each neuron its
        # own Poisson source (rate input_rates[i % num_rates]) wired
        # one-to-one, with each spike adding 1 to v as weight=1 did.
        per_neuron_rates = input_rates[np.arange(num_neurons) % num_rates] * b2.Hz
        sources = b2.PoissonGroup(num_neurons, rates=per_neuron_rates)
        drive = b2.Synapses(sources, neurons, on_pre='v_post += 1')
        drive.connect(j='i')
        inputs = [sources, drive]

    spike_monitor = b2.SpikeMonitor(neurons)
    net = b2.Network(neurons, inputs, spike_monitor)
    net.run(time * b2.ms)  # Run the network instead of b2.run
    return spike_monitor.spike_trains()
# end - tagging and branding functions

