# In [None]:
import spkit as sp
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from copy import deepcopy
import mne
from scipy import signal
import os


def Load_Raw(path, sfreq=300, plot_raw=False, trim_samples=10000):
    """Load a CSV EEG recording into an MNE ``RawArray``.

    The first two lines of the CSV are skipped (the next line is used as the
    header). Columns 0, 1, -1, -3 and -4 are dropped, and the remaining 22
    columns are mapped, in order, to the 21 EEG channels plus one ECG channel
    listed below. What the dropped columns contain is not visible here --
    presumably timestamps / trigger channels; confirm against the export
    format of the recording software.

    Parameters
    ----------
    path : str or path-like
        CSV file to read.
    sfreq : float
        Sampling frequency in Hz.
    plot_raw : bool
        If True, open MNE's interactive browser on the loaded data.
    trim_samples : int
        Number of samples discarded from each end of the recording
        (10000 by default, i.e. ~33 s at 300 Hz). ``0`` keeps everything.

    Returns
    -------
    mne.io.RawArray
        Values are handed to MNE unscaled. MNE treats them as volts, so if
        the CSV is in microvolts (TODO confirm) downstream amplitudes remain
        numerically in microvolts.

    Raises
    ------
    ValueError
        If the number of remaining columns does not match the channel list,
        or the recording is too short to drop ``trim_samples`` from both ends.
    """
    data = pd.read_csv(path, skiprows=2).to_numpy()
    data = np.delete(data, [0, 1, -1, -3, -4], axis=1)

    channel_names = ['Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2',
                     'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'A1', 'A2', 'Fz', 'Cz', 'Pz', 'ECG']
    ch_types = ['eeg'] * 21 + ['ecg']

    if data.shape[1] != len(channel_names):
        raise ValueError(
            f"{path}: expected {len(channel_names)} data columns after dropping "
            f"non-signal columns, found {data.shape[1]}"
        )

    n_samples = data.shape[0]
    if trim_samples < 0:
        raise ValueError(f"trim_samples must be >= 0, got {trim_samples}")
    if n_samples <= 2 * trim_samples:
        raise ValueError(
            f"{path}: {n_samples} samples is too short to drop {trim_samples} "
            "samples from each end"
        )

    # An explicit end index keeps trim_samples=0 meaningful (X[0:-0] would be
    # empty); .copy() detaches the trimmed block from the parsed array.
    X = data[trim_samples:n_samples - trim_samples].copy()

    info = mne.create_info(channel_names, sfreq, ch_types=ch_types)
    # RawArray expects (n_channels, n_times); the CSV is (n_times, n_channels).
    raw = mne.io.RawArray(X.T, info)
    raw.set_montage('standard_1020')

    if plot_raw:
        raw.plot(scalings='auto')

    return raw


def Detect_Artifacts(raw, peak_threshold=100, flat_threshold=0.5):
    """Score amplitude- and spectrum-based artifacts in the EEG channels of *raw*.

    Four heuristics are combined:

    1. High-amplitude samples: ``|x| > peak_threshold``.
    2. Near-flat steps: ``|x[t+1] - x[t]| < flat_threshold``.
    3. Muscle (EMG) contamination: channels whose mean 30-100 Hz PSD exceeds
       3x the median of that value across channels.
    4. Line noise: channels whose mean PSD in 49-51 Hz or 59-61 Hz exceeds
       2x the median of that value across channels.

    Parameters
    ----------
    raw : mne.io.Raw-like
        Only ``raw.get_data(picks='eeg')`` and ``raw.info['sfreq']`` are used.
    peak_threshold, flat_threshold : float
        Amplitude thresholds in the same units that ``raw.get_data`` returns.
        The defaults were chosen as microvolts. MNE conventionally stores EEG
        in volts, so they only make sense for data loaded unscaled from a
        microvolt source -- TODO confirm for your data.

    Returns
    -------
    dict
        ``high_amplitude_percentage`` and ``flat_signal_percentage`` (0-100),
        ``emg_affected_channels`` and ``line_noise_affected_channels``
        (channel counts), and ``artifact_score``: a weighted 0-100
        combination where lower is better.

    Raises
    ------
    ValueError
        If *raw* yields no EEG channels.
    """
    raw_data = raw.get_data(picks='eeg')
    sfreq = raw.info['sfreq']
    n_channels = raw_data.shape[0]
    if n_channels == 0:
        raise ValueError("Detect_Artifacts: no EEG channels to analyse")

    def band_mean(freqs, psd, mask):
        # Mean PSD inside *mask* for each channel. NaN when no bin falls in the
        # band (sfreq too low); NaN never exceeds the median threshold, so such
        # a band simply flags no channels.
        if not np.any(mask):
            return np.full(psd.shape[0], np.nan)
        return psd[:, mask].mean(axis=1)

    # 1. High amplitude artifacts (spikes)
    high_amp_percentage = np.mean(np.abs(raw_data) > peak_threshold) * 100

    # 2. Flat signals (e.g. disconnected electrodes): sample-to-sample change
    #    below threshold
    flat_percentage = np.mean(np.abs(np.diff(raw_data, axis=1)) < flat_threshold) * 100

    # 3. Muscle artifacts: one Welch call covers every channel (rows of raw_data)
    f, psd = signal.welch(raw_data, fs=sfreq, nperseg=int(sfreq * 2), axis=-1)
    emg_power = band_mean(f, psd, (f >= 30) & (f <= 100))
    emg_affected_channels = int(np.count_nonzero(emg_power > np.median(emg_power) * 3))

    # 4. Line noise (50/60 Hz); the 4 s segments give 0.25 Hz resolution
    f, psd = signal.welch(raw_data, fs=sfreq, nperseg=int(sfreq * 4), axis=-1)
    line_mask = ((f >= 49) & (f <= 51)) | ((f >= 59) & (f <= 61))
    line_noise_power = band_mean(f, psd, line_mask)
    line_noise_affected_channels = int(
        np.count_nonzero(line_noise_power > np.median(line_noise_power) * 2)
    )

    # Weighted overall artifact score (0-100, lower is better)
    artifact_score = (
        high_amp_percentage * 0.4
        + flat_percentage * 0.3
        + (emg_affected_channels / n_channels) * 100 * 0.15
        + (line_noise_affected_channels / n_channels) * 100 * 0.15
    )

    return {
        'high_amplitude_percentage': high_amp_percentage,
        'flat_signal_percentage': flat_percentage,
        'emg_affected_channels': emg_affected_channels,
        'line_noise_affected_channels': line_noise_affected_channels,
        'artifact_score': artifact_score,  # Overall score, lower is better
    }


def _quality_snr_per_channel(raw_data, sfreq):
    """Return per-channel SNR in dB: mean 8-30 Hz PSD over mean >=100 Hz PSD.

    A channel whose noise-band power is exactly 0 gets +inf; zero signal power
    gives -inf. If the sampling rate leaves no PSD bin at or above 100 Hz
    (sfreq below 200 Hz), there is no noise reference. Every channel then gets
    NaN rather than +inf, so a low-rate recording is not awarded the maximum
    SNR score.
    """
    n_channels = raw_data.shape[0]
    f, psd = signal.welch(raw_data, fs=sfreq, nperseg=int(sfreq * 2), axis=-1)
    # Alpha (8-13 Hz) and beta (13-30 Hz) together form the "signal" band.
    signal_mask = (f >= 8) & (f <= 30)
    # High-frequency content, usually not of interest, is the noise reference.
    noise_mask = f >= 100
    if not np.any(noise_mask):
        return np.full(n_channels, np.nan)

    signal_power = psd[:, signal_mask].mean(axis=1)
    noise_power = psd[:, noise_mask].mean(axis=1)

    snr = np.full(n_channels, np.inf)
    has_noise = noise_power > 0
    with np.errstate(divide='ignore'):  # zero signal power -> -inf dB
        snr[has_noise] = 10 * np.log10(signal_power[has_noise] / noise_power[has_noise])
    return snr


def _quality_alpha_prominence_per_channel(raw_data, sfreq):
    """Return per-channel alpha-peak prominence.

    Prominence is the maximum 8-13 Hz PSD value divided by the average of the
    mean 4-7 Hz (theta) and mean 14-30 Hz (beta) PSD. Every channel gets 0
    when no PSD bin falls in 8-13 Hz.
    """
    f, psd = signal.welch(raw_data, fs=sfreq, nperseg=int(sfreq * 4), axis=-1)
    alpha_range = (f >= 8) & (f <= 13)
    if not np.any(alpha_range):
        return np.zeros(raw_data.shape[0])
    peak_power = psd[:, alpha_range].max(axis=1)
    theta_power = psd[:, (f >= 4) & (f <= 7)].mean(axis=1)
    beta_power = psd[:, (f >= 14) & (f <= 30)].mean(axis=1)
    with np.errstate(divide='ignore', invalid='ignore'):
        return peak_power / ((theta_power + beta_power) / 2)


def _quality_stability_per_channel(raw_data, sfreq):
    """Return per-channel amplitude stability over consecutive 1-second windows.

    Stability is 1 / (std(window_stds) / mean(window_stds)), i.e. the inverse
    coefficient of variation of the per-window standard deviations. Every
    channel gets 0 when the recording is shorter than one second.
    """
    n_channels, n_samples = raw_data.shape
    n_seconds = int(n_samples / sfreq)
    if n_seconds == 0:
        return np.zeros(n_channels)
    # Rows are 1-second windows, columns are channels.
    window_stds = np.array([
        np.std(raw_data[:, int(i * sfreq):int((i + 1) * sfreq)], axis=1)
        for i in range(n_seconds)
    ])
    with np.errstate(divide='ignore', invalid='ignore'):
        return 1 / (np.std(window_stds, axis=0) / np.mean(window_stds, axis=0))


def _quality_clamp_0_100(value):
    """Clamp *value* into [0, 100], mapping NaN (undefined metric) to 0."""
    if np.isnan(value):
        return 0
    return min(100, max(0, value))


def Calculate_Signal_Quality(raw):
    """Summarise EEG signal quality of *raw* as three metrics plus a 0-100 score.

    Metrics, each averaged over the EEG channels:

    * ``snr``: mean 8-30 Hz PSD over mean >=100 Hz PSD, in dB. It is NaN when
      the sampling rate is below 200 Hz, since there is then no >=100 Hz band.
    * ``alpha_peak_prominence``: peak 8-13 Hz PSD relative to the theta/beta
      level.
    * ``signal_stability``: inverse coefficient of variation of 1-second
      window standard deviations.

    ``quality_score`` weights the clamped metrics 0.4 / 0.4 / 0.2. An SNR of
    ~20 dB, a prominence of ~4 and a stability of ~10 each reach the 100-point
    cap; a NaN metric contributes 0.

    Parameters
    ----------
    raw : mne.io.Raw-like
        Only ``raw.get_data(picks='eeg')`` and ``raw.info['sfreq']`` are used.

    Returns
    -------
    dict
        Keys ``snr``, ``alpha_peak_prominence``, ``signal_stability`` and
        ``quality_score`` (higher is better).
    """
    raw_data = raw.get_data(picks='eeg')
    sfreq = raw.info['sfreq']

    avg_snr = np.mean(_quality_snr_per_channel(raw_data, sfreq))
    avg_alpha_peak = np.mean(_quality_alpha_prominence_per_channel(raw_data, sfreq))
    avg_stability = np.mean(_quality_stability_per_channel(raw_data, sfreq))

    norm_snr = _quality_clamp_0_100(avg_snr * 5)                # ~20 dB -> 100 points
    norm_alpha = _quality_clamp_0_100(avg_alpha_peak * 25)      # prominence ~4 -> 100
    norm_stability = _quality_clamp_0_100(avg_stability * 10)   # stability ~10 -> 100

    quality_score = norm_snr * 0.4 + norm_alpha * 0.4 + norm_stability * 0.2

    return {
        'snr': avg_snr,
        'alpha_peak_prominence': avg_alpha_peak,
        'signal_stability': avg_stability,
        'quality_score': quality_score,  # Overall score, higher is better
    }


def Calculate_Impedance_Estimate(raw, line_freqs=(50, 60), notch_width=4.0):
    """Estimate per-channel mains (line-noise) pickup as a rough impedance proxy.

    The EEG channels are notch-filtered at every frequency in ``line_freqs``,
    each stop band ``notch_width`` Hz wide and centred on that frequency. The
    per-channel standard deviation of the removed component is reported. No
    real impedance is measured: the figure is a heuristic, in the same units
    that ``raw.get_data`` returns, where more line pickup is read as worse
    electrode contact.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to analyse; it is copied, never modified.
    line_freqs : sequence of float
        Mains frequencies to remove, in Hz (both 50 and 60 Hz by default).
    notch_width : float
        Width in Hz of each stop band; 4.0 gives 48-52 Hz and 58-62 Hz.

    Returns
    -------
    dict
        ``channel_impedance``: per-EEG-channel std of the removed component;
        ``mean_impedance``: its mean across channels.
    """
    raw_copy = raw.copy()
    raw_copy.load_data()

    raw_noise = raw_copy.copy()
    # MNE's ``freqs`` are the *centre* frequencies to notch out; the band width
    # is given separately via ``notch_widths``. Passing band edges such as
    # [48, 52] would notch 48 and 52 Hz and leave the 50 Hz hum untouched.
    raw_noise.notch_filter(freqs=list(line_freqs), picks='eeg', notch_widths=notch_width)

    raw_data = raw_copy.get_data(picks='eeg')
    noise_data = raw_noise.get_data(picks='eeg')
    # What the notch removed is (approximately) the line-frequency component.
    line_noise = raw_data - noise_data

    impedance_estimate = np.std(line_noise, axis=1)
    mean_impedance = np.mean(impedance_estimate)

    return {
        'channel_impedance': impedance_estimate,
        'mean_impedance': mean_impedance
    }


def Evaluate_Electrode_Quality(path, electrode_type, saline_conc):
    """Run the full metric pipeline on one recording and print a short summary.

    The CSV at *path* is loaded with ``Load_Raw`` (default arguments), then
    scored by ``Detect_Artifacts``, ``Calculate_Signal_Quality`` and
    ``Calculate_Impedance_Estimate``, in that order.

    Returns
    -------
    dict
        ``electrode_type`` and ``saline_concentration`` (echoed back), the
        three metric dicts under ``artifact_metrics``, ``quality_metrics`` and
        ``impedance_metrics``, and ``combined_score``: the quality score scaled
        by ``1 - artifact_score / 100``.
    """
    print(f"\nEvaluating {electrode_type} with {saline_conc} saline concentration...")

    recording = Load_Raw(path)
    artifacts = Detect_Artifacts(recording)
    quality = Calculate_Signal_Quality(recording)
    impedance = Calculate_Impedance_Estimate(recording)

    # Discount signal quality by the share of the 0-100 artifact scale used up.
    combined = quality['quality_score'] * (1 - artifacts['artifact_score'] / 100)

    evaluation = dict(
        electrode_type=electrode_type,
        saline_concentration=saline_conc,
        artifact_metrics=artifacts,
        quality_metrics=quality,
        impedance_metrics=impedance,
        combined_score=combined,
    )

    print(f"  Artifact score: {artifacts['artifact_score']:.2f} (lower is better)")
    print(f"  Quality score: {quality['quality_score']:.2f} (higher is better)")
    print(f"  Estimated impedance: {impedance['mean_impedance']:.2f} (lower is better)")
    print(f"  Combined score: {evaluation['combined_score']:.2f} (higher is better)")

    return evaluation


def Compare_Electrodes(results_list, save_path="electrode_comparison.png"):
    """Tabulate, plot and rank a list of ``Evaluate_Electrode_Quality`` results.

    Parameters
    ----------
    results_list : list of dict
        Evaluation dicts with ``electrode_type``, ``saline_concentration``,
        ``artifact_metrics``, ``quality_metrics``, ``impedance_metrics`` and
        ``combined_score`` entries.
    save_path : str or path-like
        Where the four-panel bar chart is saved.

    Returns
    -------
    pandas.DataFrame
        One row per configuration with the summary scores and a ``config``
        label of the form ``<electrode_type>_<saline_concentration>``.

    Raises
    ------
    ValueError
        If ``results_list`` is empty.
    """
    if not results_list:
        raise ValueError("Compare_Electrodes: results_list is empty")

    df = pd.DataFrame([
        {
            'electrode_type': results['electrode_type'],
            'saline_concentration': results['saline_concentration'],
            'artifact_score': results['artifact_metrics']['artifact_score'],
            'quality_score': results['quality_metrics']['quality_score'],
            'mean_impedance': results['impedance_metrics']['mean_impedance'],
            'combined_score': results['combined_score'],
        }
        for results in results_list
    ])
    df['config'] = df['electrode_type'] + '_' + df['saline_concentration']

    # (column, title, y-label, negate). Lower-is-better metrics are negated so
    # that a taller bar always means a better configuration.
    panels = [
        ('artifact_score', 'Artifact Score (Lower is Better)', 'Negative Score', True),
        ('quality_score', 'Signal Quality Score (Higher is Better)', 'Score', False),
        ('mean_impedance', 'Estimated Impedance (Lower is Better)', 'Negative Impedance', True),
        ('combined_score', 'Combined Score (Higher is Better)', 'Score', False),
    ]

    fig, axes = plt.subplots(len(panels), 1, figsize=(12, 16))
    for ax, (column, title, ylabel, negate) in zip(axes, panels):
        values = -df[column] if negate else df[column]
        ax.bar(df['config'], values)
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.set_xlabel('Electrode Configuration')
        ax.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()
    plt.close(fig)  # release the figure; repeated calls would otherwise accumulate

    # idxmax skips NaN, but fails when every score is NaN.
    if df['combined_score'].notna().any():
        best_config = df.loc[df['combined_score'].idxmax()]
        print("\n====== BEST ELECTRODE CONFIGURATION ======")
        print(f"Type: {best_config['electrode_type']}")
        print(f"Saline Concentration: {best_config['saline_concentration']}")
        print(f"Combined Score: {best_config['combined_score']:.2f}")
        print("==========================================")
    else:
        print("\nNo configuration has a valid combined score; cannot pick a best one.")

    return df


def _write_evaluation_report(result_file, electrode_type, saline_conc, evaluation):
    """Write the plain-text summary of one evaluation to *result_file*.

    The text is built in memory first, so a formatting error cannot leave a
    half-written report behind.
    """
    lines = [
        f"Evaluation of {electrode_type} with {saline_conc} saline concentration\n",
        "=" * 60 + "\n\n",
        "ARTIFACT METRICS:\n",
    ]
    lines += [f"  {key}: {value:.3f}\n" for key, value in evaluation['artifact_metrics'].items()]
    lines.append("\nQUALITY METRICS:\n")
    lines += [f"  {key}: {value:.3f}\n" for key, value in evaluation['quality_metrics'].items()]
    lines.append("\nIMPEDANCE ESTIMATE:\n")
    lines.append(f"  Mean impedance: {evaluation['impedance_metrics']['mean_impedance']:.3f}\n")
    lines.append("\nCOMBINED SCORE: {:.3f}\n".format(evaluation['combined_score']))

    with open(result_file, 'w', encoding='utf-8') as f:
        f.writelines(lines)


def main(data_dir="./data", output_dir="results"):
    """Evaluate every electrode/saline configuration and compare them.

    For each combination of felt/sponge and low/medium/high saline, the
    function:

    * reads ``<data_dir>/<type>_<conc>.csv``;
    * writes ``<output_dir>/<type>_<conc>_evaluation.txt``.

    It then plots the comparison and saves
    ``electrode_comparison_results.csv`` in the current directory. A
    configuration that fails (missing file, malformed data, ...) is reported
    and skipped, so the remaining ones still run.
    """
    electrode_types = ['felt', 'sponge']
    saline_concentrations = ['low', 'medium', 'high']

    os.makedirs(output_dir, exist_ok=True)
    all_results = []

    for electrode_type in electrode_types:
        for saline_conc in saline_concentrations:
            config = f"{electrode_type}_{saline_conc}"
            path = os.path.join(data_dir, f"{config}.csv")

            try:
                evaluation = Evaluate_Electrode_Quality(path, electrode_type, saline_conc)
                all_results.append(evaluation)

                result_file = os.path.join(output_dir, f"{config}_evaluation.txt")
                _write_evaluation_report(result_file, electrode_type, saline_conc, evaluation)
            except Exception as e:
                # Batch boundary: report with the exception type (a bare
                # KeyError message is otherwise cryptic) and continue.
                print(f"Error processing {config}: {type(e).__name__}: {e}")

    # Compare all electrodes that were evaluated successfully
    if all_results:
        comparison_df = Compare_Electrodes(all_results)
        comparison_df.to_csv("electrode_comparison_results.csv", index=False)


# Script entry point: run main() only when this file is executed directly,
# not when it is imported for its helper functions.
if __name__ == "__main__":
    main()