In [None]:
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple, Optional


class SpikeTrainGenerator:
    """
    Class for generating spike trains based on a homogeneous Poisson process.

    Time is discretised into ``num_bins`` bins of width ``dt``. Each bin
    independently contains a spike with probability ``rate * dt`` (a
    Bernoulli approximation of a Poisson process, accurate when
    ``rate * dt`` is small).
    """
    def __init__(self, duration: float = 1.0, dt: float = 0.001):
        """
        Initialize the SpikeTrainGenerator.
        
        Parameters:
        -----------
        duration : float
            Duration of the spike train in seconds.
        dt : float
            Time bin size in seconds.
        """
        self.duration = duration
        self.dt = dt
        # round() before int(): a quotient such as 0.3 / 0.1 evaluates to
        # 2.9999999999999996 and plain int() would drop a whole bin.
        self.num_bins = int(round(duration / dt))

    def _spike_probability(self, rate: float) -> float:
        """
        Return the per-bin spike probability ``rate * dt``.

        Parameters:
        -----------
        rate : float
            Firing rate in Hz (spikes per second).

        Returns:
        --------
        float
            Probability of a spike in a single bin.

        Raises:
        -------
        ValueError
            If ``rate`` is negative or ``rate * dt`` exceeds 1 (not a valid
            Bernoulli probability).
        """
        if rate < 0:
            raise ValueError(f"rate must be non-negative, got {rate}")
        p_spike = rate * self.dt
        if p_spike > 1:
            raise ValueError(
                f"rate * dt = {p_spike} exceeds 1; reduce the rate or the bin size dt"
            )
        return p_spike
        
    def generate_spike_train(self, rate: float) -> np.ndarray:
        """
        Generate a single spike train with the given firing rate.
        
        Parameters:
        -----------
        rate : float
            Firing rate in Hz (spikes per second).
            
        Returns:
        --------
        np.ndarray
            Binary array of length ``num_bins`` where 1 indicates a spike.

        Raises:
        -------
        ValueError
            If ``rate`` is negative or ``rate * dt`` exceeds 1.
        """
        p_spike = self._spike_probability(rate)
        # np.random.rand draws from [0, 1), so each bin spikes with prob. p_spike.
        spike_train = np.random.rand(self.num_bins) < p_spike
        return spike_train.astype(int)
    
    def generate_multiple_spike_trains(self, rate: float, num_trials: int) -> np.ndarray:
        """
        Generate multiple spike trains with the given firing rate.
        
        Parameters:
        -----------
        rate : float
            Firing rate in Hz (spikes per second).
        num_trials : int
            Number of spike trains to generate.
            
        Returns:
        --------
        np.ndarray
            2D array of shape (num_trials, num_bins) containing the spike trains.

        Raises:
        -------
        ValueError
            If ``rate`` is negative or ``rate * dt`` exceeds 1.
        """
        # Previously referenced an undefined name ``b`` here (NameError).
        p_spike = self._spike_probability(rate)
        spike_trains = np.random.rand(int(num_trials), self.num_bins) < p_spike
        return spike_trains.astype(int)


class SpikeTrainAnalyzer:
    """
    Class for analyzing spike trains.

    Spike trains are expected as binary arrays in which a value of 1 marks a
    spike in the corresponding time bin of width ``dt``.
    """
    def __init__(self, dt: float = 0.001):
        """
        Initialize the SpikeTrainAnalyzer.
        
        Parameters:
        -----------
        dt : float
            Time bin size in seconds.
        """
        self.dt = dt
        
    def count_spikes(self, spike_train: np.ndarray) -> int:
        """
        Count the number of spikes in a spike train.
        
        Parameters:
        -----------
        spike_train : np.ndarray
            Binary array representing the spike train.
            
        Returns:
        --------
        int
            Number of spikes in the spike train (a Python ``int``, as annotated).
        """
        return int(np.sum(spike_train))
    
    def count_spikes_multiple(self, spike_trains: np.ndarray) -> np.ndarray:
        """
        Count the number of spikes in multiple spike trains.
        
        Parameters:
        -----------
        spike_trains : np.ndarray
            2D array containing multiple spike trains (one per row).
            
        Returns:
        --------
        np.ndarray
            Array containing the spike counts for each spike train.
        """
        return np.sum(spike_trains, axis=1)
    
    def calculate_isi(self, spike_train: np.ndarray) -> np.ndarray:
        """
        Calculate the inter-spike intervals (ISIs) for a spike train.
        
        Parameters:
        -----------
        spike_train : np.ndarray
            Binary array representing the spike train.
            
        Returns:
        --------
        np.ndarray
            Array containing the ISIs in seconds; empty when the train has
            fewer than two spikes.
        """
        spike_times = np.where(spike_train == 1)[0] * self.dt
        if len(spike_times) > 1:
            return np.diff(spike_times)
        return np.array([])
    
    def calculate_isi_multiple(self, spike_trains: np.ndarray) -> np.ndarray:
        """
        Calculate the ISIs for multiple spike trains.
        
        ISIs are computed within each train only; intervals are never taken
        across the boundary between two trials.

        Parameters:
        -----------
        spike_trains : np.ndarray
            2D array containing multiple spike trains (one per row).
            
        Returns:
        --------
        np.ndarray
            Array containing all ISIs across all spike trains, in row order.
        """
        per_trial = [self.calculate_isi(train) for train in spike_trains]
        # np.concatenate rejects an empty list, so handle zero trials explicitly.
        if not per_trial:
            return np.array([])
        return np.concatenate(per_trial)
    
    def calculate_cv(self, isis: np.ndarray) -> float:
        """
        Calculate the Coefficient of Variation (CV) of ISIs.
        
        Parameters:
        -----------
        isis : np.ndarray
            Array containing ISIs.
            
        Returns:
        --------
        float
            Coefficient of Variation (population std / mean), or 0.0 when
            fewer than two ISIs are available.
        """
        if len(isis) > 1:
            return float(np.std(isis) / np.mean(isis))
        return 0.0


class SpikeTrainVisualizer:
    """
    Class for visualizing spike trains and their statistics.
    """
    def __init__(self, dt: float = 0.001):
        """
        Initialize the SpikeTrainVisualizer.
        
        Parameters:
        -----------
        dt : float
            Time bin size in seconds.
        """
        self.dt = dt

    @staticmethod
    def _count_bin_edges(spike_counts: np.ndarray) -> np.ndarray:
        """
        Return unit-width histogram edges centred on every integer count.

        Edges sit at half-integers, so the bar for count ``k`` spans
        ``[k - 0.5, k + 0.5)`` and is drawn centred on ``k``. This makes the
        bars line up with the mean marker drawn on the same axis.

        Raises:
        -------
        ValueError
            If ``spike_counts`` is empty.
        """
        counts = np.asarray(spike_counts)
        if counts.size == 0:
            raise ValueError("spike_counts must contain at least one value")
        lowest, highest = int(counts.min()), int(counts.max())
        return np.arange(lowest - 0.5, highest + 1.5, 1.0)

    @staticmethod
    def _isi_bin_edges(isis: np.ndarray, bin_width: float,
                       max_cap: float = 0.1) -> Tuple[np.ndarray, float]:
        """
        Return ``(edges, max_isi)`` for the ISI histogram.

        ``max_isi`` is 110% of the largest ISI, capped at ``max_cap`` seconds
        (or ``max_cap`` itself when ``isis`` is empty). Edges are spaced
        ``bin_width`` apart and offset by half a bin, so each bin is centred on
        a multiple of ``bin_width``. ISIs from binned trains (as produced by
        ``SpikeTrainAnalyzer.calculate_isi``) are such multiples, so none of
        them falls exactly on an edge, where float rounding would decide the
        bin. At least two edges are always returned.
        """
        if len(isis) > 0:
            max_isi = min(max_cap, float(np.max(isis)) * 1.1)
        else:
            max_isi = max_cap
        n_bins = int(np.ceil(max_isi / bin_width)) + 1
        edges = (np.arange(n_bins + 1) - 0.5) * bin_width
        return edges, max_isi
        
    def plot_spike_train(self, spike_train: np.ndarray, rate: float, duration: float) -> None:
        """
        Plot a single spike train.
        
        Parameters:
        -----------
        spike_train : np.ndarray
            Binary array representing the spike train.
        rate : float
            Firing rate used to generate the spike train.
        duration : float
            Duration of the spike train in seconds.
        """
        plt.figure(figsize=(12, 3))
        plt.eventplot([np.where(spike_train == 1)[0] * self.dt], lineoffsets=[0], 
                      linelengths=[0.5], colors=['black'])
        plt.xlabel('Time (s)')
        plt.ylabel('Spike')
        plt.title(f'Spike Train (Rate: {rate} Hz)')
        plt.xlim(0, duration)
        plt.yticks([])
        plt.grid(True, alpha=0.3)
        plt.show()
        
    def plot_raster(self, spike_trains: np.ndarray, rate: float, duration: float, 
                   num_trials_to_plot: Optional[int] = None) -> None:
        """
        Create a raster plot of multiple spike trains.
        
        Parameters:
        -----------
        spike_trains : np.ndarray
            2D array containing multiple spike trains (one per row).
        rate : float
            Firing rate used to generate the spike trains.
        duration : float
            Duration of each spike train in seconds.
        num_trials_to_plot : int, optional
            Number of trials to include in the plot. If None, all trials are plotted.
        """
        if num_trials_to_plot is None:
            num_trials_to_plot = spike_trains.shape[0]
        else:
            num_trials_to_plot = min(num_trials_to_plot, spike_trains.shape[0])
            
        plt.figure(figsize=(12, 6))
        for trial, train in enumerate(spike_trains[:num_trials_to_plot]):
            spike_times = np.where(train == 1)[0] * self.dt
            plt.plot(spike_times, np.full_like(spike_times, trial), '|',
                     color='black', markersize=5)
            
        plt.xlabel('Time (s)')
        plt.ylabel('Trial')
        plt.title(f'Raster Plot (First {num_trials_to_plot} Trials, Rate: {rate} Hz)')
        plt.xlim(0, duration)
        plt.ylim(-0.5, num_trials_to_plot - 0.5)
        plt.grid(True, alpha=0.3)
        plt.show()
        
    def plot_spike_count_histogram(self, spike_counts: np.ndarray, rate: float) -> None:
        """
        Plot a histogram of spike counts.
        
        Parameters:
        -----------
        spike_counts : np.ndarray
            Array containing spike counts for multiple spike trains.
        rate : float
            Firing rate used to generate the spike trains.

        Raises:
        -------
        ValueError
            If ``spike_counts`` is empty.
        """
        plt.figure(figsize=(10, 6))
        # Integer-centred bins so each bar sits over its count value.
        plt.hist(spike_counts, bins=self._count_bin_edges(spike_counts),
                 alpha=0.7, color='skyblue', edgecolor='black')
        plt.xlabel('Number of Spikes')
        plt.ylabel('Frequency')
        plt.title(f'Histogram of Spike Counts (Rate: {rate} Hz)')
        plt.grid(True, alpha=0.3)
        
        mean_count = np.mean(spike_counts)
        plt.axvline(mean_count, color='red', linestyle='dashed', linewidth=2, 
                    label=f'Mean: {mean_count:.2f}')
        plt.legend()
        plt.show()
        
    def plot_isi_histogram(self, isis: np.ndarray, rate: float) -> None:
        """
        Plot a histogram of inter-spike intervals.
        
        One histogram bin is used per time step ``dt``. The range is capped at
        0.1 s. An exponential density scaled to the histogram's bin width and
        ISI count is overlaid for comparison.

        Parameters:
        -----------
        isis : np.ndarray
            Array containing ISIs.
        rate : float
            Firing rate used to generate the spike trains.
        """
        plt.figure(figsize=(10, 6))
        
        # Bin width follows self.dt instead of a hard-coded 1 ms.
        bin_width = self.dt
        hist_bins, max_isi = self._isi_bin_edges(isis, bin_width)
        
        plt.hist(isis, bins=hist_bins, alpha=0.7, color='skyblue', edgecolor='black')
        plt.xlabel('Inter-spike Interval (s)')
        plt.ylabel('Frequency')
        plt.title(f'Inter-spike Interval Distribution (Rate: {rate} Hz)')
        plt.grid(True, alpha=0.3)
        
        # Calculate mean and CV of ISIs
        if len(isis) > 0:
            mean_isi = np.mean(isis)
            cv = np.std(isis) / mean_isi
            
            # Theoretical exponential density, converted to expected counts
            # per bin (density * bin width * number of ISIs).
            x = np.linspace(0, max_isi, 1000)
            y = rate * np.exp(-rate * x) * bin_width * len(isis)
            plt.plot(x, y, 'r-', linewidth=2, label='Theoretical (Exponential)')
            
            plt.axvline(mean_isi, color='green', linestyle='dashed', linewidth=2, 
                        label=f'Mean ISI: {mean_isi:.4f}s')
            plt.legend()
            
            # Add CV info as text
            plt.text(0.7 * max_isi, 0.9 * plt.ylim()[1], f'CV: {cv:.4f}',
                     bbox=dict(facecolor='white', alpha=0.7))
        
        plt.show()





# def exercise_1_3(duration: float = 1.0, dt: float = 0.001, rate: float = 80.0, 
#                 num_trials: int = 300) -> None:
#     """
#     Exercise 1.3: Compute the inter-spike interval distribution and CV.
    
#     Parameters:
#     -----------
#     duration : float
#         Duration of each spike train in seconds.
#     dt : float
#         Time bin size in seconds.
#     rate : float
#         Firing rate in Hz.
#     num_trials : int
#         Number of spike trains to generate.
#     """
#     print("\n--- Exercise 1.3: Inter-spike interval distribution ---")
    
#     # Create a spike train generator
#     generator = SpikeTrainGenerator(duration=duration, dt=dt)
    
#     # Generate multiple spike trains
#     spike_trains = generator.generate_multiple_spike_trains(rate=rate, num_

In [10]:
# Exercise 1.1: Create a spike train with the given parameters.

# Simulation parameters: 1 s of activity in 1 ms bins at a nominal 250 Hz.
duration, dt, rate = 1.0, 0.001, 250.0

# Construct the generator, analyzer and visualizer with the shared settings.
generator = SpikeTrainGenerator(duration=duration, dt=dt)
analyzer = SpikeTrainAnalyzer(dt=dt)
visualizer = SpikeTrainVisualizer(dt=dt)

# Draw one train, then compare the empirical rate with the nominal one.
spike_train = generator.generate_spike_train(rate=rate)
spike_count = analyzer.count_spikes(spike_train)
actual_rate = spike_count / duration

print("\n".join([
    f"Parameters: Duration = {duration}s, Time bin = {dt}s, Rate = {rate} Hz",
    f"Number of spikes: {spike_count}",
    f"Actual firing rate: {actual_rate:.2f} Hz",
]))

# Show the spikes on a time axis.
visualizer.plot_spike_train(spike_train, rate, duration)





In [21]:
# Scratch cell: bind an integer firing rate and echo it.
rate = 80
print(f"{rate}")



In [24]:
# Exercise 1.2: Create a raster plot and compute spike counts
duration = 1.0
dt = 0.001
rate = 80.0
num_trials = 300
num_trials_to_plot = 40


# Create a spike train generator
generator = SpikeTrainGenerator(duration=duration, dt=dt)

# Generate multiple spike trains
spike_trains = generator.generate_multiple_spike_trains(rate=rate, num_trials=num_trials)

# Analyze spike trains
analyzer = SpikeTrainAnalyzer(dt=dt)
spike_counts = analyzer.count_spikes_multiple(spike_trains)
mean_count = np.mean(spike_counts)
var_count = np.var(spike_counts)

print(f"Parameters: Duration = {duration}s, Time bin = {dt}s, Rate = {rate} Hz")
print(f"Number of trials: {num_trials}")
print(f"Mean spike count: {mean_count:.2f}")
print(f"Variance of spike counts: {var_count:.2f}")
print(f"Expected mean for Poisson process: {rate * duration} spikes")
print(f"Expected variance for Poisson process: {rate * duration} spikes")

# The generator draws one Bernoulli trial per bin with p = rate * dt, so each
# count is Binomial(duration / dt, p). Its variance, rate * duration * (1 - p),
# sits below the ideal Poisson value. Report it so the measured variance
# is compared against the right reference.
p_spike = rate * dt
expected_var_binned = rate * duration * (1 - p_spike)
print(f"Expected variance for the binned (Bernoulli) simulation: {expected_var_binned:.2f}")
print(f"Fano factor (variance / mean): {var_count / mean_count:.3f}")

# Visualize the raster plot and spike count histogram
visualizer = SpikeTrainVisualizer(dt=dt)
visualizer.plot_raster(spike_trains, rate, duration, num_trials_to_plot)
visualizer.plot_spike_count_histogram(spike_counts, rate)






