In [1]:
from scipy.io import loadmat
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import spkit
from scipy.fft import fft
from scipy.signal import welch
import logging

In [2]:
# Configuration constants for the loading / aggregation cells below.
DATA_FOLDER = 'data'  # root folder holding one sub-folder per participant (SBJ01, SBJ02, ...)
EVENTS = 8            # rows of each participant's aggregated (EVENTS, BLOCKS) array
BLOCKS = 20           # columns of each participant's aggregated (EVENTS, BLOCKS) array
RUNS = 10             # consecutive target epochs averaged together per group


In [3]:
def remove_artifacts_with_atar(signals):
    """
    Apply spkit's ATAR (Automatic and Tunable Artifact Removal) to ``signals``.

    Args:
        signals: 2D numpy array for one trial/block. NOTE(review): spkit
            documents ATAR's input as (n_samples, n_channels), while callers in
            this notebook describe their arrays as (channels, timepoints). The
            array is passed through untransposed, so confirm the orientation.

    Returns:
        Whatever ``spkit.eeg.ATAR`` returns for ``signals`` (the cleaned array).
    """
    # NOTE(review): confirm that the installed spkit version accepts the
    # `method` keyword. The documented ATAR options are e.g. `OptMode` and
    # `thr_method`.
    return spkit.eeg.ATAR(signals, method='adaptive')


In [4]:
def process_session_signals(session_path, runs=None):
    """
    Load one session's training data and average its target epochs.

    Reads ``trainData.mat`` (variable ``trainData``) and ``trainTargets.txt``
    from ``session_path``. Keeps only the epochs (last axis) whose target label
    equals 1, averages each group of ``runs`` consecutive target epochs, and
    finally averages over axis 1 of the data.

    Args:
        session_path: folder containing ``trainData.mat`` and ``trainTargets.txt``.
        runs: number of consecutive target epochs averaged together. Defaults to
            the module-level ``RUNS`` constant.

    Returns:
        2D numpy array of shape (trainData.shape[0], n_target_epochs // runs).

    Raises:
        ValueError: if there are no target epochs, or their count is not a
            multiple of ``runs``.
    """
    if runs is None:
        runs = RUNS

    # Load signals and targets
    train_data_signals = loadmat(os.path.join(session_path, 'trainData.mat'))['trainData']
    train_data_targets = np.loadtxt(os.path.join(session_path, 'trainTargets.txt'))

    # Keep only the target epochs (label == 1) along the last axis.
    # No artifact removal happens here; ATAR is applied later, per participant.
    target_signals = train_data_signals[:, :, train_data_targets == 1]

    n_targets = target_signals.shape[2]
    if n_targets == 0 or n_targets % runs != 0:
        raise ValueError(
            f"{session_path}: {n_targets} target epochs cannot be split into "
            f"groups of {runs} runs"
        )

    # Average each group of `runs` consecutive target epochs.
    grouped = target_signals.reshape(target_signals.shape[0], target_signals.shape[1], -1, runs)
    run_averaged = grouped.mean(axis=3)

    # Collapse axis 1. NOTE(review): this is presumably the time-sample axis of
    # trainData (so the ERP waveform is averaged away). Confirm this is intended.
    return run_averaged.mean(axis=1)


In [5]:
# Aggregate each participant's training sessions into one (EVENTS, BLOCKS) array.
participants = [f'SBJ{i:02d}' for i in range(1, 16)]
session_ids = [f'S{i:02d}' for i in range(1, 8)]  # S01 ... S07

# Initialize an empty list to store data for all participants
all_participants_data = []

for participant in participants:
    print(f'Processing participant: {participant}')
    participant_folder = os.path.join(DATA_FOLDER, participant)

    # One 'Train' folder per session of the current participant
    session_folders = [
        os.path.join(participant_folder, session_id, 'Train') for session_id in session_ids
    ]

    participant_data = np.zeros((EVENTS, BLOCKS))

    for session_path in session_folders:
        session_data = process_session_signals(session_path)
        # `+=` would silently broadcast e.g. an (EVENTS, 1) result over every
        # block; require an exact shape match instead.
        if session_data.shape != participant_data.shape:
            raise ValueError(
                f'{session_path}: expected shape {participant_data.shape}, '
                f'got {session_data.shape}'
            )
        participant_data += session_data

    # Take the mean across all sessions for the participant
    participant_data /= len(session_folders)

    # Apply ATAR (Artifact Removal) algorithm to the session-averaged signals.
    participant_data = remove_artifacts_with_atar(participant_data)

    all_participants_data.append(participant_data)

final_data = np.array(all_participants_data)
print("Shape of final aggregated data: ", final_data.shape)


Processing participant: SBJ01




Processing participant: SBJ02
Processing participant: SBJ03
Processing participant: SBJ04
Processing participant: SBJ05
Processing participant: SBJ06
Processing participant: SBJ07
Processing participant: SBJ08
Processing participant: SBJ09
Processing participant: SBJ10
Processing participant: SBJ11
Processing participant: SBJ12
Processing participant: SBJ13
Processing participant: SBJ14
Processing participant: SBJ15
Shape of final aggregated data:  (15, 8, 20)


In [6]:
# Define the frequency bands (Hz)
freq_bands = [(0, 4), (4, 8), (8, 13), (13, 30), (30, 50)]  # Delta, Theta, Alpha, Beta, Gamma

def extract_frequency_domain_features(signals, fs=None, bands=None):
    """
    Mean FFT magnitude per frequency band, averaged over trials.

    For every channel and trial, take |FFT| along the time axis and average
    the bins that fall in each band. Then average those per-trial values over
    trials. Note these are mean magnitudes, not power.

    Args:
        signals: numpy array of shape (channels, timepoints, trials). A 2D
            (channels, timepoints) array is treated as a single trial.
        fs: sampling rate in Hz.
            - None (default, original behaviour): band edges are used directly
              as FFT bin indices. This equals Hz only when timepoints == fs
              (1 Hz bin resolution).
            - Given: each band keeps the non-negative-frequency bins with
              f_start <= f < f_end.
        bands: list of (f_start, f_end) tuples. Defaults to ``freq_bands``.

    Returns:
        freq_features: 2D numpy array of shape (channels, len(bands)).

    Raises:
        ValueError: if a band selects no FFT bins. For example, a band lies
            beyond the signal length, which previously produced a silent NaN.
    """
    if bands is None:
        bands = freq_bands

    signals = np.asarray(signals)
    if signals.ndim == 2:
        signals = signals[:, :, np.newaxis]

    n_samples = signals.shape[1]
    # |FFT| along the time axis for every channel and trial at once.
    magnitude = np.abs(np.fft.fft(signals, axis=1))
    bin_freqs = None if fs is None else np.fft.fftfreq(n_samples, d=1.0 / fs)

    freq_features = np.zeros((signals.shape[0], len(bands)))
    for i, (f_start, f_end) in enumerate(bands):
        if bin_freqs is None:
            band = magnitude[:, int(f_start):int(f_end), :]
        else:
            in_band = (bin_freqs >= f_start) & (bin_freqs < f_end)
            band = magnitude[:, in_band, :]
        if band.shape[1] == 0:
            raise ValueError(
                f"Band ({f_start}, {f_end}) selects no FFT bins for {n_samples} samples"
            )
        # Mean over bins within each trial, then mean over trials.
        freq_features[:, i] = band.mean(axis=1).mean(axis=1)

    return freq_features


In [7]:
# NOTE(review): this cell generates *simulated* random data instead of using the
# recordings loaded earlier, so every feature below is noise. That includes the
# CSV later written from these features. Replace with real
# (channels, timepoints, trials) recordings.
SEED = 42
rng = np.random.default_rng(SEED)  # seeded so Restart & Run All reproduces the same output

N_PARTICIPANTS = 15
all_participants_features = []

for participant_id in range(N_PARTICIPANTS):
    # Simulated participant data: (8 channels, 256 timepoints, 20 trials)
    simulated_data = rng.random((8, 256, 20))

    # No try/except: silently skipping a failed participant would shift every
    # later participant's list position, and the CSV step maps list position
    # to participant ID.
    freq_features = extract_frequency_domain_features(simulated_data)
    all_participants_features.append(freq_features)
    print(f"  Frequency-domain features shape for participant {participant_id + 1}: {freq_features.shape}")

all_features_array = np.stack(all_participants_features)


  Frequency-domain features shape for participant 1: (8, 5)
  Frequency-domain features shape for participant 2: (8, 5)
  Frequency-domain features shape for participant 3: (8, 5)
  Frequency-domain features shape for participant 4: (8, 5)
  Frequency-domain features shape for participant 5: (8, 5)
  Frequency-domain features shape for participant 6: (8, 5)
  Frequency-domain features shape for participant 7: (8, 5)
  Frequency-domain features shape for participant 8: (8, 5)
  Frequency-domain features shape for participant 9: (8, 5)
  Frequency-domain features shape for participant 10: (8, 5)
  Frequency-domain features shape for participant 11: (8, 5)
  Frequency-domain features shape for participant 12: (8, 5)
  Frequency-domain features shape for participant 13: (8, 5)
  Frequency-domain features shape for participant 14: (8, 5)
  Frequency-domain features shape for participant 15: (8, 5)


In [12]:
def hjorth_parameters(signal):
    """
    Calculate Hjorth activity, mobility, and complexity for a 1D signal.

        activity   = var(x)
        mobility   = std(x') / std(x)
        complexity = mobility(x') / mobility(x)
                   = (std(x'') / std(x')) / (std(x') / std(x))

    where x' and x'' are the first and second discrete differences (np.diff).

    Returns:
        (activity, mobility, complexity). Mobility and complexity are NaN when
        a denominator is zero or undefined, e.g. for constant or very short
        signals.
    """
    x = np.asarray(signal, dtype=float)
    first_diff = np.diff(x)
    second_diff = np.diff(first_diff)

    activity = np.var(x)

    std_x = np.std(x)
    std_d1 = np.std(first_diff) if first_diff.size else np.nan
    std_d2 = np.std(second_diff) if second_diff.size else np.nan

    # A zero or undefined denominator yields NaN instead of a divide-by-zero
    # warning. (`nan > 0` is False, so NaN propagates.)
    mobility = std_d1 / std_x if std_x > 0 else np.nan
    mobility_d1 = std_d2 / std_d1 if std_d1 > 0 else np.nan
    # Complexity = mobility of the derivative over mobility of the signal.
    # The previous code had this ratio inverted.
    complexity = mobility_d1 / mobility if mobility > 0 else np.nan
    return activity, mobility, complexity

from scipy.stats import skew, kurtosis

def calculate_skewness_kurtosis(signal):
    """
    Return the third and fourth standardized moments of ``signal``.

    Uses scipy.stats defaults: biased estimators, with kurtosis reported as
    excess (Fisher) kurtosis, so a normal distribution scores 0.

    Returns:
        (skewness, kurtosis) tuple.
    """
    skewness_value = skew(signal)
    kurtosis_value = kurtosis(signal)
    return skewness_value, kurtosis_value

def calculate_peak_to_peak(signal):
    """
    Return the range of ``signal``: its maximum minus its minimum (over all elements).
    """
    return np.max(signal) - np.min(signal)


In [13]:
# Define the column headers: 3 identifier columns, then 9 summary statistics,
# then one column per band in `freq_bands`.
# The band columns hold mean |FFT| magnitudes (see extract_frequency_domain_features).
columns = [
    "participant_id", "channel", "Autistic",
    "mean", "variance", "rms", "hjorth_activity", "hjorth_mobility",
    "hjorth_complexity", "skewness", "kurtosis", "peak_to_peak",
    "delta_power", "theta_power", "alpha_power", "beta_power",
    # 5th band is (30, 50) Hz = gamma. It was previously mislabelled "spectral_entropy".
    "gamma_power",
]
NUM_ID_COLUMNS = 3  # participant_id, channel, Autistic

# Prepare data for saving
csv_data = []
# Derive IDs from the actual number of participants rather than a hard-coded 15.
participant_ids = [f"SBJ{i:02d}" for i in range(1, len(all_participants_features) + 1)]

for participant_index, participant_data in enumerate(all_participants_features):
    participant_id = participant_ids[participant_index]
    autistic_label = 1  # TODO: same label for every participant; load real labels from the dataset

    for channel_index in range(participant_data.shape[0]):  # Iterate over channels
        # NOTE(review): each row of `participant_data` is a per-band feature
        # vector (one value per frequency band), not a time series. The
        # "time-domain" statistics below therefore summarise band values.
        # Compute them on the raw signals if time-domain features are intended.
        signal = participant_data[channel_index]
        mean = np.mean(signal)
        variance = np.var(signal)
        rms = np.sqrt(np.mean(signal**2))
        hjorth_activity, hjorth_mobility, hjorth_complexity = hjorth_parameters(signal)
        skewness, kurtosis_value = calculate_skewness_kurtosis(signal)
        peak_to_peak = calculate_peak_to_peak(signal)

        time_features = [mean, variance, rms, hjorth_activity, hjorth_mobility,
                         hjorth_complexity, skewness, kurtosis_value, peak_to_peak]

        # Frequency-domain features, taken directly from all_participants_features
        freq_features = participant_data[channel_index]
        channel_features = time_features + list(freq_features)

        # Fail immediately on a schema mismatch instead of printing and continuing.
        expected_len = len(columns) - NUM_ID_COLUMNS
        if len(channel_features) != expected_len:
            raise ValueError(
                f"Feature length mismatch for participant {participant_id}, "
                f"channel {channel_index}: {len(channel_features)} != {expected_len}"
            )

        csv_data.append([participant_id, channel_index, autistic_label, *channel_features])

# Create a DataFrame and save to CSV
df = pd.DataFrame(csv_data, columns=columns)
output_path = "participant_features.csv"
df.to_csv(output_path, index=False)

print(f"Features saved to {output_path}")


Features saved to participant_features.csv
