In [None]:
import pandas as pd
import numpy as np
import mne
from sklearn.preprocessing import scale
from scipy.signal import hilbert
import pywt
import matplotlib.pyplot as plt

In [None]:
file_path = 'Original_Data.xlsx'
sfreq = 512  # sampling frequency of the recordings, Hz
ecbl_df = pd.read_excel(file_path, sheet_name='ECBL')
eobl_df = pd.read_excel(file_path, sheet_name='EOBL')
channel_names = ecbl_df.columns.tolist()

# Both sheets are later labelled with `channel_names` (taken from ECBL), so EOBL
# must contain exactly the same channels. Reordering EOBL to ECBL's column order
# means a differently ordered sheet can no longer silently mislabel electrodes.
mismatched_channels = set(channel_names) ^ set(eobl_df.columns)
assert not mismatched_channels, (
    f"ECBL/EOBL channel mismatch: {sorted(map(str, mismatched_channels))}"
)
eobl_df = eobl_df[channel_names]

# Problem Statements

#### 1. Alpha Spindle Count Comparison: Count the total number of alpha spindles in the sheets labelled ECBL and EOBL, then compare the counts to determine the difference.

In [None]:
def df_to_raw(df, sfreq, ch_names):
    """Build an MNE RawArray from a samples-by-channels DataFrame.

    Each channel (column) is z-scored over time before being handed to MNE.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows are time samples, columns are channels, in the same order as
        ``ch_names``.
    sfreq : float
        Sampling frequency in Hz.
    ch_names : list of str
        Channel names, one per column of ``df``.

    Returns
    -------
    mne.io.RawArray
        EEG data laid out as (n_channels, n_times).
    """
    # sklearn's ``scale`` standardises along axis 0 (per column). Scale the
    # (n_times, n_channels) matrix first so every *channel* gets zero mean and
    # unit variance over time, then transpose to MNE's (n_channels, n_times).
    # Transposing first would instead normalise each time sample across
    # channels, wiping out the common signal and making the data rank-deficient.
    data_scaled = scale(df.to_numpy(dtype=float), axis=0).T
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
    return mne.io.RawArray(data_scaled, info)

# One RawArray per sheet, both labelled with the ECBL channel list.
raw_ecbl, raw_eobl = (df_to_raw(sheet, sfreq, channel_names) for sheet in (ecbl_df, eobl_df))

In [None]:
def apply_ica(raw, exclude=None, n_components=15, fit_l_freq=1.0):
    """Fit ICA on ``raw`` and return a copy with the chosen components removed.

    Parameters
    ----------
    raw : mne.io.Raw
        Continuous recording to clean; it is not modified.
    exclude : sequence of int, optional
        Indices of ICA components to remove. Pick them by inspecting the fitted
        decomposition, e.g. ``ica.plot_components()``. With the default ``None``
        no component is removed, so the returned data is only a reconstruction
        of the input and no artifact is actually rejected.
    n_components : int
        Requested number of ICA components. It is capped at the channel count so
        montages with fewer electrodes do not fail.
    fit_l_freq : float or None
        High-pass cutoff in Hz for a temporary copy used only to *fit* ICA.
        Slow drifts degrade the decomposition; MNE suggests about 1 Hz. ``None``
        fits on ``raw`` as given. The returned data itself is never high-passed.

    Returns
    -------
    mne.io.Raw
        A copy of ``raw`` with the ``exclude`` components projected out.
    """
    n_components = min(n_components, len(raw.ch_names))
    ica = mne.preprocessing.ICA(n_components=n_components, random_state=97, max_iter=500)

    fit_raw = raw if fit_l_freq is None else raw.copy().filter(l_freq=fit_l_freq, h_freq=None)
    ica.fit(fit_raw)

    ica.exclude = [] if exclude is None else list(exclude)
    return ica.apply(raw.copy())
raw_ecbl_ica, raw_eobl_ica = map(apply_ica, (raw_ecbl, raw_eobl))

In [None]:
def bandpass(raw, l_freq=8, h_freq=13):
    """Return a band-pass filtered copy of ``raw``; the input is left untouched.

    The defaults (8-13 Hz) select the alpha band.
    """
    filtered = raw.copy()
    filtered.filter(l_freq=l_freq, h_freq=h_freq)
    return filtered

raw_ecbl_alpha, raw_eobl_alpha = [bandpass(rec) for rec in (raw_ecbl_ica, raw_eobl_ica)]

In [None]:
def detect_spindles(raw, channel='Pz', threshold_factor=2.5, min_duration=0.5):
    """Count spindle events on one channel of a Raw recording.

    The channel's Hilbert amplitude envelope is thresholded at
    ``threshold_factor * std(envelope)``. Every contiguous supra-threshold run
    lasting at least ``min_duration`` seconds counts as one spindle. This
    includes a run that is still ongoing at the last sample.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to analyse. Presumably already band-passed to the alpha band
        (the callers here pass ``bandpass`` output). Its ``info['sfreq']``
        supplies the sampling rate.
    channel : str
        Channel name to analyse.
    threshold_factor : float
        Multiplier on the envelope's standard deviation.
    min_duration : float
        Minimum run length in seconds.

    Returns
    -------
    int
        Number of detected spindles.
    """
    fs = raw.info['sfreq']
    signal = raw.get_data(picks=channel)[0]

    envelope = np.abs(hilbert(signal))
    threshold = threshold_factor * np.std(envelope)
    above = envelope > threshold

    # Pad with False on both sides so every run has a rising (+1) and a falling
    # (-1) edge, including runs touching the first or last sample.
    edges = np.diff(np.concatenate(([False], above, [False])).astype(np.int8))
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1)  # exclusive run ends
    durations = (ends - starts) / fs
    return int(np.count_nonzero(durations >= min_duration))

In [None]:
channel = 'Pz'
# Identical detector settings for both recordings so the counts are comparable.
ecbl_count, eobl_count = (detect_spindles(recording, channel) for recording in (raw_ecbl_alpha, raw_eobl_alpha))

In [None]:
print("\n".join([
    f"Alpha Spindle Count in ECBL: {ecbl_count}",
    f"Alpha Spindle Count in EOBL: {eobl_count}",
    f"Difference: {abs(ecbl_count - eobl_count)}",
]))

In [None]:
def plot_wavelet(signal, title='Wavelet Scaleogram', scales=np.arange(1, 128), wavelet='morl', fs=None):
    """Plot a continuous-wavelet scaleogram (|coefficients| over scale and time).

    Parameters
    ----------
    signal : 1-D array
        Single-channel time series.
    title : str
        Figure title.
    scales : 1-D array
        Wavelet scales passed to ``pywt.cwt``. The y-axis extent assumes they are
        evenly spaced and increasing, as with the default ``1..127``.
    wavelet : str
        Wavelet name understood by ``pywt``.
    fs : float, optional
        Sampling frequency in Hz. When omitted, the module-level ``sfreq`` is
        used, so existing calls keep working.
    """
    if fs is None:
        fs = sfreq
    scales = np.asarray(scales)
    coefs, _ = pywt.cwt(signal, scales, wavelet, 1 / fs)

    fig, ax = plt.subplots(figsize=(12, 6))
    img = ax.imshow(np.abs(coefs), extent=[0, len(signal) / fs, scales[-1], scales[0]],
                    cmap='jet', aspect='auto')
    ax.set_title(title)
    ax.set_ylabel("Scale")
    ax.set_xlabel("Time (s)")
    fig.colorbar(img, ax=ax, label="Amplitude")
    plt.show()

In [None]:
for label, recording in (('ECBL', raw_ecbl_alpha), ('EOBL', raw_eobl_alpha)):
    plot_wavelet(recording.get_data(picks=channel)[0], title=f'{label} - Alpha Band - Pz')

#### 2. Electrode-wise Alpha Spindle Count Comparison: For each electrode, count the number of alpha spindles in both ECBL and EOBL sheets. Compare the counts side by side for each electrode to highlight the differences.

In [None]:
def electrode_wise_spindles(raw_ecbl, raw_eobl, threshold_factor=2.5, min_duration=0.5):
    """Tabulate per-electrode spindle counts for two recordings side by side.

    Channels are taken from ``raw_ecbl``; each must also exist in ``raw_eobl``.
    ``Difference`` is ECBL minus EOBL (signed).

    Returns
    -------
    pandas.DataFrame
        Columns ``Electrode``, ``ECBL Count``, ``EOBL Count``, ``Difference``;
        one row per channel, in ``raw_ecbl.ch_names`` order.
    """
    records = []
    for electrode in raw_ecbl.ch_names:
        n_ecbl = detect_spindles(raw_ecbl, electrode, threshold_factor, min_duration)
        n_eobl = detect_spindles(raw_eobl, electrode, threshold_factor, min_duration)
        records.append({
            "Electrode": electrode,
            "ECBL Count": n_ecbl,
            "EOBL Count": n_eobl,
            "Difference": n_ecbl - n_eobl,
        })
    return pd.DataFrame(records, columns=["Electrode", "ECBL Count", "EOBL Count", "Difference"])

spindle_comparison_df = electrode_wise_spindles(raw_ecbl=raw_ecbl_alpha, raw_eobl=raw_eobl_alpha)

# Show the per-electrode table, then write it to the working directory.
print(spindle_comparison_df)

spindle_comparison_df.to_csv(path_or_buf="Alpha_Spindle_Electrode_Comparison.csv", index=False)

In [None]:
counts_by_electrode = spindle_comparison_df.set_index("Electrode")[["ECBL Count", "EOBL Count"]]

fig, ax = plt.subplots(figsize=(14, 6))
counts_by_electrode.plot(kind='bar', ax=ax, title="Alpha Spindle Counts per Electrode (ECBL vs EOBL)")
ax.set_ylabel("Spindle Count")
ax.grid(True)
fig.tight_layout()
plt.show()


#### 3. Wavelet Transform (Scaleogram) and Spindle Visualization: Calculate the Scaleogram (wavelet transform coefficients) for a single electrode and for the complete dataset. Visualize the Scaleogram alongside the spindle activity on the same time scale.

In [None]:
def compute_wavelet(signal, sfreq, wavelet='morl', max_scale=128):
    """Continuous wavelet transform of ``signal`` over scales ``1 .. max_scale - 1``.

    Returns
    -------
    (coefs, freqs)
        The ``pywt.cwt`` result. ``coefs`` has one row per scale, and ``freqs``
        holds the frequency (Hz) matching each scale, given ``sfreq`` in Hz.
    """
    scale_grid = np.arange(1, max_scale)
    return pywt.cwt(signal, scale_grid, wavelet, sampling_period=1 / sfreq)

def plot_scaleogram_with_spindles(signal, spindle_mask, sfreq, title='Scaleogram + Spindles'):
    """Plot the CWT scaleogram of ``signal`` with spindle regions shaded.

    Parameters
    ----------
    signal : 1-D array
        Single-channel time series.
    spindle_mask : 1-D bool array, same length as ``signal``
        Each contiguous run of True samples is shaded white on the plot.
    sfreq : float
        Sampling frequency in Hz.
    title : str
        Figure title.
    """
    coefs, freqs = compute_wavelet(signal, sfreq)
    n_scales = coefs.shape[0]
    duration = len(signal) / sfreq

    fig, ax = plt.subplots(figsize=(14, 6))
    # CWT rows are evenly spaced in *scale*, but frequency ~ 1/scale, so a
    # linear frequency extent mislabels nearly every row. Draw in row-index
    # space (row r centred at y = r, row 0 on top) and label a few rows with
    # their true frequency instead.
    img = ax.imshow(np.abs(coefs), extent=[0, duration, n_scales - 0.5, -0.5],
                    cmap='jet', aspect='auto')
    tick_rows = np.unique(np.linspace(0, n_scales - 1, 8).round().astype(int))
    ax.set_yticks(tick_rows)
    ax.set_yticklabels([f"{freqs[row]:.1f}" for row in tick_rows])
    fig.colorbar(img, ax=ax, label='Amplitude')
    ax.set_title(title)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')

    # Shade each run of True. False padding gives every run a start and an
    # (exclusive) end edge, so the span covers the run's last sample too.
    mask = np.asarray(spindle_mask, dtype=bool)
    edges = np.diff(np.concatenate(([False], mask, [False])).astype(np.int8))
    for start, stop in zip(np.flatnonzero(edges == 1), np.flatnonzero(edges == -1)):
        ax.axvspan(start / sfreq, stop / sfreq, color='white', alpha=0.3)

    fig.tight_layout()
    plt.show()


In [None]:
def get_spindle_mask(signal, sfreq, threshold_factor=2.5):
    """Per-sample boolean mask of where the Hilbert envelope exceeds threshold.

    The threshold is ``threshold_factor`` times the envelope's standard
    deviation. No minimum-duration filter is applied, so short runs that a
    duration-based counter would ignore are still True here. ``sfreq`` is
    accepted for signature symmetry but is not used.
    """
    amplitude = np.abs(hilbert(signal))
    cutoff = np.std(amplitude) * threshold_factor
    return amplitude > cutoff


In [None]:
# Select the recording by label so the plot title always matches the data shown.
recordings_alpha = {'ECBL': raw_ecbl_alpha, 'EOBL': raw_eobl_alpha}
condition = 'ECBL'  # set to 'EOBL' to inspect the other sheet
raw = recordings_alpha[condition]
channel = 'Pz'

signal = raw.get_data(picks=channel)[0]
spindle_mask = get_spindle_mask(signal, sfreq)

plot_scaleogram_with_spindles(signal, spindle_mask, sfreq,
                              title=f"{condition}: {channel} - Scaleogram with Spindles")


In [None]:
def plot_avg_wavelet_power(raw, sfreq, title="Average Wavelet Power Across Electrodes"):
    """Plot CWT power averaged over scales and over all channels, versus time.

    For each channel, the CWT is computed with ``compute_wavelet``, its squared
    magnitude is averaged over scales, and the per-channel traces are then
    averaged across channels.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording; every channel in ``raw.ch_names`` contributes.
    sfreq : float
        Sampling frequency in Hz, used for the CWT and the time axis.
    title : str
        Figure title. Pass a distinguishing title when plotting several
        recordings.

    Raises
    ------
    ValueError
        If ``raw`` has no channels.
    """
    if not raw.ch_names:
        raise ValueError("raw has no channels to average")

    power_sum = None
    for ch in raw.ch_names:
        signal = raw.get_data(picks=ch)[0]
        coefs, _ = compute_wavelet(signal, sfreq)
        # Power is the squared magnitude (|coefs| alone would be amplitude);
        # average over the scale axis to get one trace per channel.
        channel_power = np.mean(np.abs(coefs) ** 2, axis=0)
        # A running sum keeps only one trace in memory instead of one per channel.
        power_sum = channel_power if power_sum is None else power_sum + channel_power

    avg_power = power_sum / len(raw.ch_names)
    time = np.arange(len(avg_power)) / sfreq

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(time, avg_power)
    ax.set_title(title)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Power")
    ax.grid(True)
    fig.tight_layout()
    plt.show()


In [None]:
for recording in (raw_ecbl_alpha, raw_eobl_alpha):  # ECBL first, then EOBL
    plot_avg_wavelet_power(recording, sfreq)

#### 4. Epoch-based Visualization and Spindle Counting: Divide the complete dataset into 10-second epochs. For each epoch, visualize the spindles and the corresponding Scaleogram, while also counting the total number of spindles within each epoch.

In [None]:
def epoch_signal(signal, epoch_length_sec, sfreq):
    """Split a 1-D signal into consecutive, non-overlapping equal-length epochs.

    Parameters
    ----------
    signal : 1-D array
        Samples to split.
    epoch_length_sec : float
        Epoch length in seconds.
    sfreq : float
        Sampling frequency in Hz.

    Returns
    -------
    list of numpy.ndarray
        ``len(signal) // samples_per_epoch`` arrays, each
        ``int(epoch_length_sec * sfreq)`` samples long. Trailing samples that do
        not fill a whole epoch are dropped, and a signal shorter than one epoch
        yields an empty list.

    Raises
    ------
    ValueError
        If one epoch would be shorter than a single sample.
    """
    samples_per_epoch = int(epoch_length_sec * sfreq)
    if samples_per_epoch < 1:
        raise ValueError(
            f"epoch of {epoch_length_sec} s at {sfreq} Hz is shorter than one sample"
        )
    signal = np.asarray(signal)
    num_epochs = len(signal) // samples_per_epoch
    # reshape handles num_epochs == 0 cleanly, whereas np.array_split(x, 0) raises.
    usable = signal[:num_epochs * samples_per_epoch]
    return list(usable.reshape(num_epochs, samples_per_epoch))

In [None]:
def count_spindles_in_epoch(epoch, sfreq, threshold_factor=2.5, min_duration=0.5):
    """Count spindle-like envelope events inside a single epoch.

    The epoch's Hilbert amplitude envelope is compared against
    ``threshold_factor`` times its own standard deviation, so the threshold is
    relative to this epoch rather than the whole recording. Each contiguous
    supra-threshold run at least ``int(min_duration * sfreq)`` samples long
    counts once.

    Returns
    -------
    (int, numpy.ndarray of bool)
        The spindle count, and the per-sample supra-threshold mask. The mask
        still includes runs too short to be counted.
    """
    amplitude = np.abs(hilbert(epoch))
    above_thresh = amplitude > np.std(amplitude) * threshold_factor
    min_samples = int(min_duration * sfreq)

    # Positions where the padded mask flips alternate between run starts and
    # (exclusive) run ends. The False padding closes runs touching either edge.
    padded = np.concatenate(([False], above_thresh, [False]))
    flips = np.flatnonzero(padded[1:] != padded[:-1])
    run_lengths = flips[1::2] - flips[0::2]

    return int(np.count_nonzero(run_lengths >= min_samples)), above_thresh

In [None]:
def plot_epoch_scaleogram(epoch, spindle_mask, sfreq, epoch_idx):
    """Plot one epoch's CWT scaleogram with its spindle mask shaded.

    Parameters
    ----------
    epoch : 1-D array
        Samples of a single epoch.
    spindle_mask : 1-D bool array, same length as ``epoch``
        Each contiguous run of True samples is shaded white.
    sfreq : float
        Sampling frequency in Hz.
    epoch_idx : int
        Zero-based epoch index, shown 1-based in the title.
    """
    coefs, freqs = compute_wavelet(epoch, sfreq)
    n_scales = coefs.shape[0]
    duration = len(epoch) / sfreq

    fig, ax = plt.subplots(figsize=(14, 5))
    # CWT rows are evenly spaced in *scale*, but frequency ~ 1/scale, so a
    # linear frequency extent mislabels nearly every row. Draw in row-index
    # space (row r centred at y = r, row 0 on top) and label a few rows with
    # their true frequency instead.
    img = ax.imshow(np.abs(coefs), extent=[0, duration, n_scales - 0.5, -0.5],
                    cmap='jet', aspect='auto')
    tick_rows = np.unique(np.linspace(0, n_scales - 1, 8).round().astype(int))
    ax.set_yticks(tick_rows)
    ax.set_yticklabels([f"{freqs[row]:.1f}" for row in tick_rows])
    fig.colorbar(img, ax=ax, label='Amplitude')
    ax.set_title(f"Epoch {epoch_idx+1}: Scaleogram + Spindle Overlay")
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Frequency (Hz)")

    # Overlay spindle regions. False padding gives every run a start and an
    # (exclusive) end edge, so each span covers the run's last sample too.
    mask = np.asarray(spindle_mask, dtype=bool)
    edges = np.diff(np.concatenate(([False], mask, [False])).astype(np.int8))
    for start, stop in zip(np.flatnonzero(edges == 1), np.flatnonzero(edges == -1)):
        ax.axvspan(start / sfreq, stop / sfreq, color='white', alpha=0.3)

    fig.tight_layout()
    plt.show()
    # This is called once per epoch; release the figure so memory does not grow.
    plt.close(fig)


In [None]:
def epoch_analysis(raw, channel='Pz', epoch_duration=10, plot_epochs=True):
    """Split one channel into fixed-length epochs, count spindles per epoch, and plot.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to analyse; ``raw.info['sfreq']`` sets the sampling rate.
    channel : str
        Channel to analyse.
    epoch_duration : float
        Epoch length in seconds. Trailing samples that do not fill a whole epoch
        are dropped by ``epoch_signal``.
    plot_epochs : bool
        If True (default), draw a scaleogram with spindle overlay for every
        epoch. Set False for long recordings to get only the summary bar chart.

    Returns
    -------
    list of int
        Spindle count for each epoch, in order. Thresholds are computed per
        epoch by ``count_spindles_in_epoch``.
    """
    # Keep the rate as a float: int() would truncate non-integer sampling rates
    # and shift epoch boundaries and time axes.
    fs = raw.info['sfreq']
    signal = raw.get_data(picks=channel)[0]
    epochs = epoch_signal(signal, epoch_duration, fs)

    spindle_counts = []
    for idx, epoch in enumerate(epochs):
        count, spindle_mask = count_spindles_in_epoch(epoch, fs)
        spindle_counts.append(count)
        print(f"Epoch {idx+1}: {count} spindles")
        if plot_epochs:
            plot_epoch_scaleogram(epoch, spindle_mask, fs, idx)

    # Summary bar plot
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.bar(np.arange(1, len(spindle_counts) + 1), spindle_counts)
    ax.set_xlabel("Epoch Number")
    ax.set_ylabel("Spindle Count")
    ax.set_title(f"Spindle Count per {epoch_duration}-second Epoch - Channel: {channel}")
    ax.grid(True)
    fig.tight_layout()
    plt.show()

    return spindle_counts


In [None]:
# raw_ecbl_alpha is the apply_ica output, band-passed by bandpass() above.
epoch_spindle_counts = epoch_analysis(raw=raw_ecbl_alpha, channel='Pz', epoch_duration=10)