In [None]:
import os
import pickle
import numpy as np
from scipy import signal as scisignal #renamed to avoid conflict with data['signal']
import matplotlib.pyplot as plt

# --- Configuration ---
# Everything a reader is expected to tune lives in this cell.

# Root folder of the WESAD dataset; it must contain one sub-folder per subject
# (S2, S3, ...). Replace with your own location, e.g. "/Users/yourname/datasets/WESAD".
WESAD_ROOT_DIR = "../data/raw/WESAD"

# Subject used throughout this demonstration (S2 ... S17; S1 is usually not used).
SUBJECT_ID = "S2"

# Nominal sampling rates in Hz, keyed by device and then by modality.
# Refer to the WESAD readme for the authoritative values.
SAMPLING_RATES = {
    # Every chest modality shares one rate.
    'chest': dict.fromkeys(('ACC', 'ECG', 'EDA', 'EMG', 'RESP', 'TEMP'), 700),
    # Wrist modalities each have their own rate.
    'wrist': dict(ACC=32, BVP=64, EDA=4, TEMP=4),
}

# Common rate (Hz) that signals are resampled to so they can be aligned.
# Typical choices are the BVP rate (64 Hz), the wrist ACC rate (32 Hz),
# or a lower shared rate such as 16 Hz.
TARGET_SAMPLING_RATE_HZ = 32

# Sliding-window segmentation: window length in seconds and the fraction of
# each window shared with its successor (0.5 means 50% overlap).
WINDOW_DURATION_SEC = 10
WINDOW_OVERLAP_RATIO = 0.5

# --- Helper Functions ---

In [6]:
def load_subject_data(subject_id, wesad_root_dir):
    """Load the pickled recording of one subject.

    The file is expected at ``<wesad_root_dir>/<subject_id>/<subject_id>.pkl``.

    Args:
        subject_id: Subject folder / file stem, e.g. ``"S2"``.
        wesad_root_dir: Directory that contains the per-subject folders.

    Returns:
        The unpickled object, or ``None`` (after printing an error message)
        when no regular file exists at the expected path.
    """
    file_path = os.path.join(wesad_root_dir, subject_id, f"{subject_id}.pkl")
    # isfile (not exists): a directory at this path must also take the
    # "not found" route instead of making open() raise.
    if not os.path.isfile(file_path):
        print(f"Error: Data file not found for subject {subject_id} at {file_path}")
        return None
    # SECURITY: unpickling executes arbitrary code embedded in the file.
    # Only load pickle files obtained from a source you trust.
    with open(file_path, 'rb') as f:
        # 'latin1' encoding as per WESAD common practice.
        data = pickle.load(f, encoding='latin1')
    return data


In [8]:
def plot_signal(signal_data, Fs, title="Signal", max_points=None):
    """Draw a signal against time in seconds.

    Args:
        signal_data: Array with samples along the first axis; a second axis,
            if present, is treated as channels and each is drawn as its own line.
        Fs: Sampling rate in Hz, used to convert sample index to seconds.
        title: Figure title.
        max_points: Number of leading samples to draw; ``None`` draws all.
    """
    # Slicing with a None stop keeps the whole signal, so no special case is needed.
    shown = signal_data[:max_points]
    seconds = np.arange(len(shown)) / Fs

    plt.figure(figsize=(15, 4))
    if signal_data.ndim <= 1:
        plt.plot(seconds, shown)
    else:
        # Multi-channel input (e.g. ACC): one labelled trace per column.
        for channel_idx in range(signal_data.shape[1]):
            plt.plot(seconds, shown[:, channel_idx], label=f'Channel {channel_idx+1}')
        plt.legend()

    plt.title(title)
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")
    plt.grid(True)
    plt.show()

In [12]:
def resample_signal(signal_data, original_fs, target_fs):
    """Resample a signal from ``original_fs`` to ``target_fs`` (both in Hz).

    Uses ``scipy.signal.resample`` (Fourier method) along the first axis, so a
    ``(samples, channels)`` array is resampled channel by channel in one call.

    Args:
        signal_data: Array-like with samples along axis 0.
        original_fs: Sampling rate of ``signal_data``.
        target_fs: Desired sampling rate.

    Returns:
        The resampled array with ``floor(len(signal_data) * target_fs / original_fs)``
        samples. The input object itself is returned when the two rates are
        equal, or (with a printed warning) when the target length would be zero.
    """
    if original_fs == target_fs:
        return signal_data

    num_original_samples = len(signal_data)
    # Multiply before dividing: (n / original_fs) * target_fs rounds the
    # division first and can land just below an integer (29 / 100 * 200 ->
    # 57.999...), which int() would then truncate to one sample too few.
    num_target_samples = int(num_original_samples * target_fs / original_fs)

    if num_target_samples == 0:
        print(f"Warning: Resampling from {original_fs}Hz to {target_fs}Hz resulted in 0 target samples for signal of length {num_original_samples}. Returning original.")
        return signal_data

    # axis=0 is scipy's default; stated explicitly because samples run along
    # the first axis for both 1-D and (samples, channels) inputs.
    return scisignal.resample(signal_data, num_target_samples, axis=0)

In [14]:
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5, axis=-1):
    """Apply a causal Butterworth band-pass filter.

    The filter is designed as second-order sections and applied with
    ``sosfilt``. This is numerically better conditioned than the
    transfer-function ``(b, a)`` form, which can become inaccurate or unstable
    for higher orders or pass bands that are narrow relative to ``fs``.
    Filtering is single-pass (forward only), so the output is delayed relative
    to the input.

    Args:
        data: Array-like signal to filter.
        lowcut: Lower cut-off frequency in Hz.
        highcut: Upper cut-off frequency in Hz.
        fs: Sampling rate of ``data`` in Hz.
        order: Butterworth order passed to ``scipy.signal.butter``.
        axis: Axis of ``data`` that holds time. The default (-1) matches the
            scipy default; pass ``axis=0`` for ``(samples, channels)`` arrays,
            otherwise the filter runs across the channels.

    Returns:
        The filtered array, or ``data`` itself (after printing an error) when
        the cut-offs do not describe a valid band.
    """
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    # Clamp the normalized cut-offs into the open interval (0, 1) required by butter().
    if low <= 0:
        low = 1e-5  # avoid zero
    if high >= 1:
        high = 1 - 1e-5  # avoid Nyquist

    if low >= high:
        print(f"Error: Lowcut ({lowcut}Hz) must be less than Highcut ({highcut}Hz). Filter not applied.")
        return data

    sos = scisignal.butter(order, [low, high], btype='band', output='sos')
    return scisignal.sosfilt(sos, data, axis=axis)

In [16]:
def segment_signal(signal_data, window_size_samples, overlap_samples):
    """Cut a signal into fixed-length, optionally overlapping windows.

    Windows start every ``window_size_samples - overlap_samples`` samples; a
    trailing remainder shorter than one window is dropped.

    Args:
        signal_data: Array-like with samples along axis 0 (extra axes, e.g.
            channels, are carried through unchanged).
        window_size_samples: Window length in samples (positive integer).
        overlap_samples: Samples shared by consecutive windows; must be
            smaller than ``window_size_samples``.

    Returns:
        Array of shape ``(num_windows, window_size_samples, ...)``. When the
        signal is shorter than one window the result has ``num_windows == 0``
        but keeps the remaining dimensions, so ``.shape[1]`` stays valid.

    Raises:
        ValueError: If the window size is not positive or the overlap leaves
            no positive step between windows.
    """
    if window_size_samples <= 0:
        raise ValueError(
            f"window_size_samples must be positive, got {window_size_samples}"
        )
    step_size = window_size_samples - overlap_samples
    if step_size <= 0:
        # A zero step would make range() fail with an unrelated message and a
        # negative one would silently yield no windows at all.
        raise ValueError(
            f"overlap_samples ({overlap_samples}) must be smaller than "
            f"window_size_samples ({window_size_samples})"
        )

    signal_data = np.asarray(signal_data)
    starts = range(0, len(signal_data) - window_size_samples + 1, step_size)
    if len(starts) == 0:
        return np.empty(
            (0, window_size_samples) + signal_data.shape[1:],
            dtype=signal_data.dtype,
        )
    return np.array([signal_data[s:s + window_size_samples] for s in starts])

# --- Main Processing Steps ---

In [None]:
# End-to-end walkthrough for a single subject:
#   load -> inspect -> extract -> resample -> band-pass filter -> window -> normalize.
# Every step prints a short summary and most also draw a figure, so the cell
# is meant to be read top to bottom alongside its output.

# 1. Load Data for One Subject
print(f"Loading data for subject: {SUBJECT_ID}")
subject_data = load_subject_data(SUBJECT_ID, WESAD_ROOT_DIR)

# load_subject_data returns None when the pickle is missing, which sends
# execution to the `else` branch at the bottom of this cell.
if subject_data:
    print(f"Data loaded successfully for {SUBJECT_ID}.")
    
    # 2. Explore Data Structure
    print("\n--- Data Structure ---")
    print(f"Keys in subject_data: {subject_data.keys()}")
    print(f"Keys in 'signal' data: {subject_data['signal'].keys()}")
    print(f"Keys in 'chest' signals: {subject_data['signal']['chest'].keys()}")
    print(f"Keys in 'wrist' signals: {subject_data['signal']['wrist'].keys()}")
    print(f"Shape of chest ACC: {subject_data['signal']['chest']['ACC'].shape}") # (samples, 3 channels: x,y,z)
    print(f"Shape of chest ECG: {subject_data['signal']['chest']['ECG'].shape}") # (samples, 1 channel)
    print(f"Shape of wrist BVP: {subject_data['signal']['wrist']['BVP'].shape}") # (samples, 1 channel)
    print(f"Labels available: {np.unique(subject_data['label'])}") # These are for downstream tasks

    # 3. Extract a Few Example Signals
    print("\n--- Extracting Signals ---")
    chest_ecg_raw = subject_data['signal']['chest']['ECG'].flatten() # Flatten if it's (N,1)
    wrist_bvp_raw = subject_data['signal']['wrist']['BVP'].flatten()
    wrist_acc_raw = subject_data['signal']['wrist']['ACC'] # This is 3-channel

    # Visualize raw signals (plotting a small portion for brevity)
    # max_points = rate * 10 limits each plot to the first 10 seconds at that signal's own rate.
    plot_signal(chest_ecg_raw, SAMPLING_RATES['chest']['ECG'], title=f"{SUBJECT_ID} - Raw Chest ECG (first 10s)", max_points=SAMPLING_RATES['chest']['ECG']*10)
    plot_signal(wrist_bvp_raw, SAMPLING_RATES['wrist']['BVP'], title=f"{SUBJECT_ID} - Raw Wrist BVP (first 10s)", max_points=SAMPLING_RATES['wrist']['BVP']*10)
    plot_signal(wrist_acc_raw, SAMPLING_RATES['wrist']['ACC'], title=f"{SUBJECT_ID} - Raw Wrist ACC (first 10s)", max_points=SAMPLING_RATES['wrist']['ACC']*10)

    # 4. Demonstrate Synchronization (Resampling)
    # Bringing every modality to TARGET_SAMPLING_RATE_HZ gives them a shared
    # time base, so that sample i refers to the same instant in each signal.
    print(f"\n--- Resampling to {TARGET_SAMPLING_RATE_HZ} Hz ---")
    
    # Resample ECG
    ecg_resampled = resample_signal(chest_ecg_raw, 
                                    SAMPLING_RATES['chest']['ECG'], 
                                    TARGET_SAMPLING_RATE_HZ)
    print(f"ECG: Original length @{SAMPLING_RATES['chest']['ECG']}Hz: {len(chest_ecg_raw)}, Resampled length @{TARGET_SAMPLING_RATE_HZ}Hz: {len(ecg_resampled)}")
    plot_signal(ecg_resampled, TARGET_SAMPLING_RATE_HZ, title=f"{SUBJECT_ID} - Resampled Chest ECG to {TARGET_SAMPLING_RATE_HZ}Hz (first 10s resampled duration)", max_points=TARGET_SAMPLING_RATE_HZ*10)

    # Resample BVP (if its original rate is different from target)
    # resample_signal already returns its input unchanged for equal rates;
    # the explicit check here only selects which summary line is printed.
    if SAMPLING_RATES['wrist']['BVP'] != TARGET_SAMPLING_RATE_HZ:
        bvp_resampled = resample_signal(wrist_bvp_raw,
                                        SAMPLING_RATES['wrist']['BVP'],
                                        TARGET_SAMPLING_RATE_HZ)
        print(f"BVP: Original length @{SAMPLING_RATES['wrist']['BVP']}Hz: {len(wrist_bvp_raw)}, Resampled length @{TARGET_SAMPLING_RATE_HZ}Hz: {len(bvp_resampled)}")
    else:
        bvp_resampled = wrist_bvp_raw # No resampling needed
        print(f"BVP: Already at target sampling rate of {TARGET_SAMPLING_RATE_HZ}Hz. Length: {len(bvp_resampled)}")
    
    # Resample Wrist ACC (multi-channel)
    # For multi-channel, resample each channel independently
    acc_resampled_channels = []
    if SAMPLING_RATES['wrist']['ACC'] != TARGET_SAMPLING_RATE_HZ:
        for i in range(wrist_acc_raw.shape[1]):
            channel_resampled = resample_signal(wrist_acc_raw[:, i],
                                                SAMPLING_RATES['wrist']['ACC'],
                                                TARGET_SAMPLING_RATE_HZ)
            acc_resampled_channels.append(channel_resampled)
        # Combine channels back. Ensure all resampled channels have the same length.
        # NOTE: the channels are columns of one array and are resampled with
        # identical rates, so their lengths are equal here; trimming to the
        # shortest is purely defensive. It would matter when stacking signals
        # that come from different sensors with different start/stop times.
        min_len = min(len(ch) for ch in acc_resampled_channels)
        # Stack as (channels, samples) and transpose back to (samples, channels).
        acc_resampled = np.array([ch[:min_len] for ch in acc_resampled_channels]).T

        print(f"Wrist ACC: Original length @{SAMPLING_RATES['wrist']['ACC']}Hz: {wrist_acc_raw.shape[0]}, Resampled length @{TARGET_SAMPLING_RATE_HZ}Hz: {acc_resampled.shape[0]}")
    else:
        acc_resampled = wrist_acc_raw
        print(f"Wrist ACC: Already at target sampling rate of {TARGET_SAMPLING_RATE_HZ}Hz. Shape: {acc_resampled.shape}")


    # 5. Demonstrate Filtering (on one resampled signal)
    print("\n--- Filtering Example (BVP) ---")
    # BVP is often bandpass filtered to focus on heart rate frequencies (e.g., 0.5 Hz to 4 Hz)
    # (0.5 Hz = 30 bpm, 4 Hz = 240 bpm)
    # The filter is applied in a single forward pass, so the filtered trace is
    # delayed relative to the unfiltered one in the two plots below.
    bvp_filtered = butter_bandpass_filter(bvp_resampled, 
                                          lowcut=0.5, highcut=4.0, 
                                          fs=TARGET_SAMPLING_RATE_HZ, order=3)
    plot_signal(bvp_resampled, TARGET_SAMPLING_RATE_HZ, title=f"{SUBJECT_ID} - Resampled BVP (first 30s)", max_points=TARGET_SAMPLING_RATE_HZ*30)
    plot_signal(bvp_filtered, TARGET_SAMPLING_RATE_HZ, title=f"{SUBJECT_ID} - Filtered BVP (0.5-4.0 Hz) (first 30s)", max_points=TARGET_SAMPLING_RATE_HZ*30)

    # 6. Demonstrate Segmentation (on the filtered BVP)
    print("\n--- Segmentation Example (Filtered BVP) ---")
    # Convert the window length from seconds to samples at the common rate,
    # then derive the overlap in samples from the configured ratio.
    window_size_samples = int(WINDOW_DURATION_SEC * TARGET_SAMPLING_RATE_HZ)
    overlap_samples = int(window_size_samples * WINDOW_OVERLAP_RATIO)
    
    bvp_segments = segment_signal(bvp_filtered, window_size_samples, overlap_samples)
    # An empty result means the recording is shorter than a single window.
    if bvp_segments.size > 0:
        print(f"Segmented BVP into {bvp_segments.shape[0]} windows of size {bvp_segments.shape[1]} samples.")
        # Plot the first segment
        plot_signal(bvp_segments[0], TARGET_SAMPLING_RATE_HZ, title=f"{SUBJECT_ID} - First BVP Segment")
    else:
        print("Not enough data to create BVP segments with current parameters.")

    # 7. Demonstrate Normalization (on BVP segments)
    print("\n--- Normalization Example (BVP Segments) ---")
    if bvp_segments.size > 0:
        # For proper SEL, normalization parameters (mean, std) should be derived
        # from the *training set subjects only* and then applied to all.
        # Here, for demonstration, we'll normalize based on the current subject's segments.
        
        # Option 1: Normalize each segment independently (local normalization)
        # bvp_segments_normalized_local = [(seg - np.mean(seg)) / (np.std(seg) + 1e-6) for seg in bvp_segments]
        
        # Option 2: Normalize based on global mean/std of all segments from this subject
        # NOTE: with overlapping windows, samples inside the overlaps are counted
        # more than once, so these statistics can differ slightly from those of
        # the unsegmented bvp_filtered signal.
        mean_bvp_for_subject = np.mean(bvp_segments) # Or np.mean(bvp_filtered) before segmentation
        std_bvp_for_subject = np.std(bvp_segments)   # Or np.std(bvp_filtered) before segmentation
        
        bvp_segments_normalized_global = (bvp_segments - mean_bvp_for_subject) / (std_bvp_for_subject + 1e-6) # Add epsilon to avoid division by zero
        
        # Under global normalization a single window is not expected to have
        # exactly zero mean / unit std; only the full set of windows does.
        print(f"Normalized BVP segments. Mean of first normalized segment: {np.mean(bvp_segments_normalized_global[0]):.2f}, Std: {np.std(bvp_segments_normalized_global[0]):.2f}")
        plot_signal(bvp_segments_normalized_global[0], TARGET_SAMPLING_RATE_HZ, title=f"{SUBJECT_ID} - First Normalized BVP Segment")
    else:
        print("No BVP segments to normalize.")

    # Closing checklist of what a full pipeline would still need to do.
    print("\n--- Further Steps for SEL ---")
    print("1. Process ALL desired signals for ALL training subjects using similar steps: load, resample, filter, segment, normalize.")
    print("2. Combine segments from different modalities (e.g., concatenate ECG, EDA, ACC windows for the same time period). Ensure they are perfectly aligned after resampling and segmentation.")
    print("3. Prepare these multimodal segments for your chosen SEL pretext task (e.g., creating augmented pairs for contrastive learning, defining input/output for predictive coding).")
    print("4. Split subjects into train/validation/test sets for evaluating the downstream task after SEL pre-training.")

else:
    # Loading failed: point the reader at the most likely cause (wrong dataset path).
    print(f"Could not proceed as data for {SUBJECT_ID} was not loaded.")
    print(f"Please ensure WESAD_ROOT_DIR ('{WESAD_ROOT_DIR}') is correctly set and contains the WESAD subject folders (S2, S3, etc.).")