In [1]:
import numpy as np
from scipy import signal
import matplotlib.pyplot as plt
from scipy.signal import hilbert
from scipy.integrate import simps
import pickle
# Import function for Morlet Wavelets
from neurodsp.timefrequency.wavelets import compute_wavelet_transform

def butter_bandpass(lowcut, highcut, fs, order=5):
    """
    Design a Butterworth band-pass filter in transfer-function form.

    Parameters:
    - lowcut: float, lower -3 dB edge in Hz
    - highcut: float, upper -3 dB edge in Hz
    - fs: float, sampling rate in Hz
    - order: int, order of the Butterworth prototype (default 5)

    Returns:
    - (b, a): numerator and denominator polynomial coefficients
    """
    # scipy.signal.butter expects edges as fractions of the Nyquist frequency (fs / 2)
    half_rate = fs / 2.0
    band_edges = [lowcut / half_rate, highcut / half_rate]
    numerator, denominator = signal.butter(order, band_edges, btype='band')
    return numerator, denominator

# Function to apply the bandpass filter
def bandpass_filter(data, fs, lowcut, highcut, order=5):
    """
    Zero-phase Butterworth band-pass filter.

    Designs the same Butterworth band-pass as ``butter_bandpass`` but in
    second-order-sections (SOS) form, and applies it forward and backward with
    ``scipy.signal.sosfiltfilt`` (zero phase shift, squared magnitude response).

    Parameters:
    - data: array, signal to filter (filtered along the last axis)
    - fs: float, sampling rate in Hz
    - lowcut: float, lower -3 dB edge in Hz
    - highcut: float, upper -3 dB edge in Hz
    - order: int, order of the Butterworth prototype (default 5)

    Returns:
    - y: array with the same shape as ``data``, the band-passed signal
    """
    # SOS form avoids the numerical problems that (b, a) polynomial coefficients
    # have for higher-order / narrow-band designs (see scipy.signal.butter notes);
    # narrow beta bands at fs=1000 are exactly that case.
    sos = signal.butter(order, [lowcut, highcut], btype='band', fs=fs, output='sos')
    y = signal.sosfiltfilt(sos, data)
    return y


def harmonics_removal(signal, fs, harmonics, dftbandwidth=1, dftneighbourwidth=2):
    """
    Removes beta harmonics in a signal via spectrum interpolation.

    For each harmonic frequency f, every FFT bin with frequency in
    [f - dftbandwidth, f + dftbandwidth] or [-f - dftbandwidth, -f + dftbandwidth]
    keeps its phase but has its magnitude replaced by the mean magnitude of the
    neighbouring bins. The neighbouring bins are those in
    [f - dftbandwidth - dftneighbourwidth, f + dftbandwidth + dftneighbourwidth]
    that lie outside [f - dftbandwidth, f + dftbandwidth]. A harmonic is left
    untouched when it has fewer than two neighbouring bins.

    Parameters:
    - signal: 1D numpy array, the input signal (this name shadows the
      module-level ``scipy.signal`` inside this function only)
    - fs: float, sampling rate in Hz
    - harmonics: iterable of floats, harmonic frequencies in Hz (e.g., [40, 60, 80])
    - dftbandwidth: float, half bandwidth of harmonics frequency bands in Hz (default 1)
    - dftneighbourwidth: float, width of neighbouring frequencies in Hz (default 2)

    Returns:
    - cleaned_signal: 1D numpy array, real part of the inverse FFT of the edited
      spectrum, i.e. the signal without the indicated harmonics
    """
    n_samples = len(signal)
    freqs = np.fft.fftfreq(n_samples, 1 / fs)
    signal_fft = np.fft.fft(signal)

    def get_freq_indices(freq, bandwidth):
        # Indices of the FFT bins whose frequency lies in [freq - bandwidth, freq + bandwidth]
        return np.where((freqs >= (freq - bandwidth)) & (freqs <= (freq + bandwidth)))[0]

    for f in harmonics:
        # Bins to replace: the band around +f and its mirror around -f
        noise_indices = np.concatenate((get_freq_indices(f, dftbandwidth),
                                        get_freq_indices(-f, dftbandwidth)))

        # The neighbouring bins depend only on f, so they and their mean magnitude
        # are computed once per harmonic. The spectrum used is the one before
        # this harmonic's bins are edited.
        lower_bound = f - dftneighbourwidth - dftbandwidth
        upper_bound = f + dftneighbourwidth + dftbandwidth
        neighbours = np.where((freqs >= lower_bound) & (freqs <= upper_bound) &
                              ((freqs < (f - dftbandwidth)) | (freqs > (f + dftbandwidth))))[0]
        if len(neighbours) <= 1 or len(noise_indices) == 0:
            continue

        interpolated_amplitude = np.mean(np.abs(signal_fft[neighbours]))

        # Keep each bin's phase and replace its magnitude. The mirrored (-f) bins
        # get the same magnitude, which preserves conjugate symmetry for real input.
        original_phase = np.angle(signal_fft[noise_indices])
        signal_fft[noise_indices] = interpolated_amplitude * np.exp(1j * original_phase)

    # Inverse FFT back to the time domain
    cleaned_signal = np.fft.ifft(signal_fft).real
    return cleaned_signal

#Center of gravity method for computing central frequency
def cog(f, pxx, f1, f2):
    """
    Spectral centre of gravity (power-weighted mean frequency) in the open band (f1, f2).

    cog = | integral(f * pxx) / integral(pxx) |, where both integrals use
    Simpson's rule over the frequencies strictly between f1 and f2.

    Parameters:
    - f: 1D array, frequency axis in Hz (e.g. from signal.welch)
    - pxx: 1D array, power spectral density sampled at f
    - f1, f2: float, band edges in Hz (both exclusive)

    Returns:
    - float, centre-of-gravity frequency in Hz (undefined if the in-band power
      integrates to zero)
    """
    # `simps` was deprecated and then removed in SciPy 1.14. `simpson` is the
    # maintained name (and `simps` was an alias of it), so results are unchanged.
    from scipy.integrate import simpson

    f = np.asarray(f)
    pxx = np.asarray(pxx)
    in_band = (f1 < f) & (f < f2)
    f_band = f[in_band]
    weighted_power = simpson(f_band * pxx[in_band], x=f_band)
    total_power = simpson(pxx[in_band], x=f_band)
    return abs(weighted_power / total_power)


## GPe-TI

In [None]:
# Beta-phase-resolved 60-150 Hz wavelet amplitude for the GPe-TI population,
# keeping only values above a rate-matched binomial surrogate threshold.

# Frequencies explored by the wavelet transform (Hz)
freqs = np.linspace(60, 150, 200)

# Sampling frequency of the data (Hz)
fs = 1000

# Population size used for the equivalent binomial ("Poisson") surrogate
n_pop = 780

# Seeded generator so the surrogate, and hence the significance mask, is reproducible
rng = np.random.default_rng(0)

# Load precomputed data.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open("isolated_gp.p", "rb") as fh:
    result = pickle.load(fh)

# Clean the first 5 activity traces and record each trace's beta centre frequency
s1 = []
f_betas = []
for i in range(5):
    # Each entry unpacks into two items; only the first (activity) is used here
    data_D2, isol = result['1.0'][i]

    # Welch PSD: 1000-sample Hamming windows, 50 % overlap, zero-padded to 30000 points
    f, pxx = signal.welch(data_D2, fs, nperseg=1000, noverlap=500, nfft=30000,
                          scaling='density', window='hamming')

    # Beta centre frequency = spectral centre of gravity in 10-30 Hz
    f_beta = cog(f, pxx, 10, 30)
    f_betas.append(f_beta)

    # Remove the 2nd-4th beta harmonics (+/-5 Hz bands, 3 Hz neighbourhoods)
    harmonics = np.arange(2, 5) * f_beta
    data_D2 = harmonics_removal(data_D2, fs, harmonics, 5, 3)

    s1.append(data_D2)

# Concatenate the cleaned traces into a single 1D signal
s1 = np.concatenate(s1)

# NOTE(review): the phase band is centred on the *last* trace's f_beta (this is
# what the previous loop-variable reuse did); np.mean(f_betas) may be intended -- confirm.
f_beta = f_betas[-1]

# Binomial surrogate with the same length and mean as s1
pois = rng.binomial(n_pop, np.mean(s1) / n_pop, len(s1))

# Beta phase: band-pass f_beta +/- 5 Hz, then take the angle of the analytic signal.
# NOTE(review): the traces are filtered as one concatenated signal, so each
# junction between traces introduces filter/Hilbert edge effects.
s2 = bandpass_filter(s1, fs, f_beta - 5, f_beta + 5)
s2_pois = bandpass_filter(pois, fs, f_beta - 5, f_beta + 5)

analytic_signal2 = hilbert(s2)
phase2 = np.angle(analytic_signal2)

analytic_signal_pois = hilbert(s2_pois)
phase_pois = np.angle(analytic_signal_pois)

# Phase bin edges. np.digitize returns i with phase_bins[i-1] <= x < phase_bins[i];
# np.angle never returns values below -pi, so bin 0 is never populated.
num_bins = 200
phase_bins = np.linspace(-np.pi, np.pi, num_bins)
bin_indices = np.digitize(phase2, phase_bins)
bin_indices_pois = np.digitize(phase_pois, phase_bins)

# Wavelet amplitude. The per-bin indexing below treats axis 0 as frequency and
# axis 1 as time -- TODO confirm against the installed neurodsp version.
mwt = abs(compute_wavelet_transform(s1, fs=fs, n_cycles=5, freqs=freqs))
mwt_pois = abs(compute_wavelet_transform(pois, fs=fs, n_cycles=5, freqs=freqs))

# Mean amplitude per phase bin -> shape (num_bins, n_freqs). Emptiness is tested
# on the boolean mask, because len() of the 2D slice counts frequency rows, not
# samples. Empty bins get a NaN row.
mean_mwt = np.array([
    mwt[:, bin_indices == i].mean(axis=1) if np.any(bin_indices == i)
    else np.full(mwt.shape[0], np.nan)
    for i in range(num_bins)
])
# Bin 0 is always empty; -pi and pi are the same phase, so reuse the last bin
mean_mwt[0, :] = mean_mwt[-1, :]

mean_mwt_pois = np.array([
    mwt_pois[:, bin_indices_pois == i].mean(axis=1) if np.any(bin_indices_pois == i)
    else np.full(mwt_pois.shape[0], np.nan)
    for i in range(num_bins)
])
mean_mwt_pois[0, :] = mean_mwt_pois[-1, :]

# Significance threshold per phase bin: the 95th percentile (across frequencies)
# of the surrogate's mean amplitude, repeated along the frequency axis
pois_threshold = np.percentile(mean_mwt_pois, 95, axis=1)
mean_mwt_pois = np.tile(pois_threshold[:, np.newaxis], (1, mean_mwt_pois.shape[1]))

# Plot styling (set before the figure is created)
plt.rcParams.update({'font.size': 30})
plt.rc('axes', labelsize=30)
plt.rc('xtick', labelsize=25)
plt.rc('ytick', labelsize=25)
plt.figure(figsize=(12, 12))

# Zero out amplitudes that do not exceed the surrogate threshold
mean_mwt[mean_mwt < mean_mwt_pois] = 0

# jet colormap with its lowest colour set to white, so the zeroed
# (non-significant) cells, which fall below vmin, render white
from matplotlib.colors import ListedColormap
cmap = plt.cm.jet
cmap_colors = cmap(np.arange(cmap.N))
cmap_colors[0] = [1, 1, 1, 1]
custom_cmap = ListedColormap(cmap_colors)

# Colour scale spans the surviving (non-zero) values; fall back to autoscaling
# if nothing survives the threshold
significant = mean_mwt[mean_mwt != 0]
vmin = np.min(significant) if significant.size else None
plt.imshow(mean_mwt.T, aspect='auto', origin='lower', cmap=custom_cmap,
           extent=[phase_bins[0], phase_bins[-1], freqs[0], freqs[-1]],
           vmax=np.max(mean_mwt), vmin=vmin)

# Frequency of maximal (thresholded) amplitude in each phase bin
max_indices = np.argmax(mean_mwt, axis=1)
mean_frequency = freqs[max_indices]
plt.plot(phase_bins, mean_frequency, color='black')

# Axis labels and colorbar
plt.xticks([-np.pi, -np.pi/2, 0, np.pi/2, np.pi],
           labels=[r'$-\pi$', r'$-\pi/2$', r'$0$', r'$\pi/2$', r'$\pi$'])
cbar = plt.colorbar()
cbar.set_label('Wavelet Amplitude', labelpad=15)
plt.xlabel("Beta phase [rad]", labelpad=12)
plt.ylabel("Frequency [Hz]", labelpad=12)
plt.show()





## D2

In [None]:
# Beta-phase-resolved 60-150 Hz wavelet amplitude for the D2 population,
# keeping only values above a rate-matched binomial surrogate threshold.

# Frequencies explored by the wavelet transform (Hz)
freqs = np.linspace(60, 150, 200)

# Sampling frequency of the data (Hz)
fs = 1000

# Population size used for the equivalent binomial ("Poisson") surrogate
n_pop = 6000

# Seeded generator so the surrogate, and hence the significance mask, is reproducible
rng = np.random.default_rng(0)

# Load precomputed data.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open("isolated_d2.p", "rb") as fh:
    result = pickle.load(fh)

# Clean the first 5 activity traces and record each trace's beta centre frequency
s1 = []
f_betas = []
for i in range(5):
    # Each entry unpacks into two items; only the first (activity) is used here
    data_D2, isol = result['1.0'][i]

    # Welch PSD: 1000-sample Hamming windows, 50 % overlap, zero-padded to 30000 points
    f, pxx = signal.welch(data_D2, fs, nperseg=1000, noverlap=500, nfft=30000,
                          scaling='density', window='hamming')

    # Beta centre frequency = spectral centre of gravity in 10-30 Hz
    f_beta = cog(f, pxx, 10, 30)
    f_betas.append(f_beta)

    # Remove the 2nd-4th beta harmonics (+/-5 Hz bands, 3 Hz neighbourhoods)
    harmonics = np.arange(2, 5) * f_beta
    data_D2 = harmonics_removal(data_D2, fs, harmonics, 5, 3)

    s1.append(data_D2)

# Concatenate the cleaned traces into a single 1D signal
s1 = np.concatenate(s1)

# NOTE(review): the phase band is centred on the *last* trace's f_beta (this is
# what the previous loop-variable reuse did); np.mean(f_betas) may be intended -- confirm.
f_beta = f_betas[-1]

# Binomial surrogate with the same length and mean as s1
pois = rng.binomial(n_pop, np.mean(s1) / n_pop, len(s1))

# Beta phase: band-pass f_beta +/- 5 Hz, then take the angle of the analytic signal.
# NOTE(review): the traces are filtered as one concatenated signal, so each
# junction between traces introduces filter/Hilbert edge effects.
s2 = bandpass_filter(s1, fs, f_beta - 5, f_beta + 5)
s2_pois = bandpass_filter(pois, fs, f_beta - 5, f_beta + 5)

analytic_signal2 = hilbert(s2)
phase2 = np.angle(analytic_signal2)

analytic_signal_pois = hilbert(s2_pois)
phase_pois = np.angle(analytic_signal_pois)

# Phase bin edges. np.digitize returns i with phase_bins[i-1] <= x < phase_bins[i];
# np.angle never returns values below -pi, so bin 0 is never populated.
num_bins = 200
phase_bins = np.linspace(-np.pi, np.pi, num_bins)
bin_indices = np.digitize(phase2, phase_bins)
bin_indices_pois = np.digitize(phase_pois, phase_bins)

# Wavelet amplitude. The per-bin indexing below treats axis 0 as frequency and
# axis 1 as time -- TODO confirm against the installed neurodsp version.
mwt = abs(compute_wavelet_transform(s1, fs=fs, n_cycles=5, freqs=freqs))
mwt_pois = abs(compute_wavelet_transform(pois, fs=fs, n_cycles=5, freqs=freqs))

# Mean amplitude per phase bin -> shape (num_bins, n_freqs). Emptiness is tested
# on the boolean mask, because len() of the 2D slice counts frequency rows, not
# samples. Empty bins get a NaN row.
mean_mwt = np.array([
    mwt[:, bin_indices == i].mean(axis=1) if np.any(bin_indices == i)
    else np.full(mwt.shape[0], np.nan)
    for i in range(num_bins)
])
# Bin 0 is always empty; -pi and pi are the same phase, so reuse the last bin
mean_mwt[0, :] = mean_mwt[-1, :]

mean_mwt_pois = np.array([
    mwt_pois[:, bin_indices_pois == i].mean(axis=1) if np.any(bin_indices_pois == i)
    else np.full(mwt_pois.shape[0], np.nan)
    for i in range(num_bins)
])
mean_mwt_pois[0, :] = mean_mwt_pois[-1, :]

# Significance threshold per phase bin: the 95th percentile (across frequencies)
# of the surrogate's mean amplitude, repeated along the frequency axis
pois_threshold = np.percentile(mean_mwt_pois, 95, axis=1)
mean_mwt_pois = np.tile(pois_threshold[:, np.newaxis], (1, mean_mwt_pois.shape[1]))

# Plot styling (set before the figure is created)
plt.rcParams.update({'font.size': 30})
plt.rc('axes', labelsize=30)
plt.rc('xtick', labelsize=25)
plt.rc('ytick', labelsize=25)
plt.figure(figsize=(12, 12))

# Zero out amplitudes that do not exceed the surrogate threshold
mean_mwt[mean_mwt < mean_mwt_pois] = 0

# jet colormap with its lowest colour set to white, so the zeroed
# (non-significant) cells, which fall below vmin, render white
from matplotlib.colors import ListedColormap
cmap = plt.cm.jet
cmap_colors = cmap(np.arange(cmap.N))
cmap_colors[0] = [1, 1, 1, 1]
custom_cmap = ListedColormap(cmap_colors)

# Colour scale spans the surviving (non-zero) values; fall back to autoscaling
# if nothing survives the threshold
significant = mean_mwt[mean_mwt != 0]
vmin = np.min(significant) if significant.size else None
plt.imshow(mean_mwt.T, aspect='auto', origin='lower', cmap=custom_cmap,
           extent=[phase_bins[0], phase_bins[-1], freqs[0], freqs[-1]],
           vmax=np.max(mean_mwt), vmin=vmin)

# Frequency of maximal (thresholded) amplitude in each phase bin
max_indices = np.argmax(mean_mwt, axis=1)
mean_frequency = freqs[max_indices]
plt.plot(phase_bins, mean_frequency, color='black')

# Axis labels and colorbar
plt.xticks([-np.pi, -np.pi/2, 0, np.pi/2, np.pi],
           labels=[r'$-\pi$', r'$-\pi/2$', r'$0$', r'$\pi/2$', r'$\pi$'])
cbar = plt.colorbar()
cbar.set_label('Wavelet Amplitude', labelpad=15)
plt.xlabel("Beta phase [rad]", labelpad=12)
plt.ylabel("Frequency [Hz]", labelpad=12)
plt.show()


