In [2]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
#from PyEMD import EMD
from scipy.signal import hilbert, medfilt

from pathlib import Path
import re
from pprint import pprint

import numpy as np
from scipy import signal, ndimage
import matplotlib.pyplot as plt
from matplotlib import pyplot
import pandas as pd

import emd
import eelbrain
import mne

  from mne.io.pick import pick_types, pick_channels
  from mne.io.pick import pick_types, pick_channels


In [4]:
### (OLD FUNCTION)
# ==========================================
# 2. HELPER FUNCTIONS
# ==========================================
def get_inst_freq_amp(signal, fs):
    """
    Calculates instantaneous amplitude (envelope) and frequency 
    using the Hilbert Transform.
    """
    analytic_signal = hilbert(signal)
    amplitude = np.abs(analytic_signal)
    instantaneous_phase = np.unwrap(np.angle(analytic_signal))
    
    # Calculate Instantaneous Frequency (derivative of phase)
    # factor of 1/(2*pi) converts rad/s to Hz
    inst_freq = np.diff(instantaneous_phase) / (2.0 * np.pi) * fs
    
    # Append last sample to match array length
    inst_freq = np.append(inst_freq, inst_freq[-1])
    
    return amplitude, inst_freq

def load_trf_data(filename):
    """
    Loads TRF data from pickle. Handles simplified data structures.
    """
    with open(filename, 'rb') as f:
        data = pickle.load(f)
    
    # LOGIC to handle different pickle structures:
    if isinstance(data, np.ndarray):
        trf = data
    elif isinstance(data, dict):
        # Checks for common keys, user might need to adjust this
        keys = [k for k in data.keys() if 'trf' in k.lower() or 'data' in k.lower()]
        if keys:
            trf = data[keys[0]]
            print(f"Loaded key '{keys[0]}' from dictionary.")
        else:
            raise ValueError(f"Could not find data in dictionary keys: {data.keys()}")
    else:
        raise ValueError("Pickle file must contain a numpy array or a dictionary.")

    # Ensure 1D array (if 2D, takes the first channel/dimension)
    if trf.ndim > 1:
        print(f"Data shape is {trf.shape}. Taking the first channel/vector.")
        trf = trf.flatten() if trf.shape[0] == 1 or trf.shape[1] == 1 else trf[0,:]
        
    return trf

# ==========================================
# 3. HHSA MAIN ALGORITHM
# ==========================================
def run_hhsa(signal, fs):
    """
    Performs Two-Layer EMD (Holo-Hilbert Spectral Analysis).
    """
    emd = EMD()
    emd.emd(signal)
    imfs_layer1 = emd.get_imfs_and_residue()[0]
    
    holo_points = [] # To store (Carrier_Freq, AM_Freq, Power) triplets

    print(f"Layer 1 Decomposition: Found {imfs_layer1.shape[0]} Carrier IMFs.")
    
    # --- LAYER 1 LOOP (Carrier Frequencies) ---
    for i, imf_c in enumerate(imfs_layer1):
        # 1. Get Carrier Envelope and Frequency
        env_c, freq_c = get_inst_freq_amp(imf_c, fs)
        
        # Filter negative/noise frequencies
        valid_mask_c = (freq_c > 0) & (freq_c < fs/2)
        
        # --- LAYER 2 LOOP (Amplitude Modulation) ---
        # Decompose the ENVELOPE of the carrier
        if np.sum(np.abs(env_c)) < 1e-10: continue # Skip empty IMFs
        
        try:
            imfs_layer2 = emd(env_c)
        except:
            continue # Skip if EMD fails on noise
            
        for j, imf_am in enumerate(imfs_layer2):
            # 2. Get AM Frequency
            _, freq_am = get_inst_freq_amp(imf_am, fs)
            
            # 3. Calculate Power (Instantaneous Energy)
            # Power is the square of the AM component amplitude
            power_am = imf_am**2
            
            # 4. Collect Triplet Points [fc, fam, power]
            # We iterate through time points to map instantaneous relationships
            mask_am = (freq_am > 0) & (freq_am < fs/2)
            combined_mask = valid_mask_c & mask_am
            
            # Add valid points to list
            # We downsample slightly if data is huge to save memory, or take all points
            idx = np.where(combined_mask)[0]
            
            if len(idx) > 0:
                # Stack: Col 0 = Carrier Freq, Col 1 = AM Freq, Col 2 = Power
                points = np.vstack((freq_c[idx], freq_am[idx], power_am[idx])).T
                holo_points.append(points)

    if not holo_points:
        return None

    # Concatenate all points from all modes
    return np.vstack(holo_points)

In [7]:
if __name__ == "__main__":

    ## ESLs ##
    #STIMULI = [str(i) for i in range(1, 13)]
    #DATA_ROOT = Path("/Volumes/Neurolang_1/Master Program/New_Thesis_topic/Experiments_Results")  #Path("~").expanduser() / 'Data' / 'Alice'
    DATA_ROOT = Path("/Users/neuroling/Downloads/DINGHSIN_Results/Alice_Experiments_Results")
    #PREDICTOR_audio_DIR = DATA_ROOT / 'TRFs_pridictors/audio_predictors'
    #PREDICTOR_word_DIR = DATA_ROOT / 'TRFs_pridictors/word_predictors'
    EEG_DIR = DATA_ROOT / 'EEG_ESLs' / 'Alice_ESL_ICAed_fif'
    IMF_DIR = DATA_ROOT/ "TRFs_pridictors/IF_predictors"
    F0_DIR = DATA_ROOT/ "TRFs_pridictors/F0_predictors"
    IMFsLIST = [path.name for path in IMF_DIR.iterdir() if re.match(r'Alice_IF_IMF_*', path.name)] 
    ESL_SUBJECTS = [path.name for path in EEG_DIR.iterdir() if re.match(r'n_2_S\d*', path.name)]  #S01_alice-raw.fif
    
    # Define a target directory for TRF estimates and make sure the directory is created
    TRF_DIR = DATA_ROOT / 'TRFs_ESLs'
    TRF_DIR.mkdir(exist_ok=True)
    print(ESL_SUBJECTS)
    print(len(ESL_SUBJECTS))  # 26
    DST = TRF_DIR / 'ESLs_figures'
    DST.mkdir(exist_ok=True)

['n_2_S030_ICAed_raw.fif', 'n_2_S027_ICAed_raw.fif', 'n_2_S023_ICAed_raw.fif', 'n_2_S034_ICAed_raw.fif', 'n_2_S024_ICAed_raw.fif', 'n_2_S019_ICAed_raw.fif', 'n_2_S020_ICAed_raw.fif', 'n_2_S013_ICAed_raw.fif', 'n_2_S017_ICAed_raw.fif', 'n_2_S039_ICAed_raw.fif', 'n_2_S010_ICAed_raw.fif', 'n_2_S029_ICAed_raw.fif', 'n_2_S015_ICAed_raw.fif', 'n_2_S028_ICAed_raw.fif', 'n_2_S011_ICAed_raw.fif', 'n_2_S038_ICAed_raw.fif', 'n_2_S016_ICAed_raw.fif', 'n_2_S012_ICAed_raw.fif', 'n_2_S021_ICAed_raw.fif', 'n_2_S036_ICAed_raw.fif', 'n_2_S032_ICAed_raw.fif', 'n_2_S025_ICAed_raw.fif', 'n_2_S035_ICAed_raw.fif', 'n_2_S022_ICAed_raw.fif', 'n_2_S026_ICAed_raw.fif', 'n_2_S031_ICAed_raw.fif']
26


In [9]:
# Load in the TRF .pickle file (I choose F0 & En & EnOnset)

# Compare the TRFs corresponding to F0 and Envelopes.
# Build the model_data['model']
models = ['Fzero', 'Fzero+envelope', 'Fzero+envelope+env_onset'] # Change the description
rows = []
for model in models:
    print(model)
    for subject in ESL_SUBJECTS:
        trf = eelbrain.load.unpickle(TRF_DIR / subject[4:8] / f'{subject[4:8]} {model}.pickle')
        #print(trf)
        trf.x = ['f0', 'f0env', 'f0envenvon'] #rename the model (no punctuation)
        rows.append([n_subj, model, *trf.h_scaled])
        #rows.append([subject, model, trf.proportion_explained])

model_data = eelbrain.Dataset.from_caselist(['subject', 'model', 'det'], rows)
#pprint(rows)
print(model_data)

Fzero
Fzero+envelope
Fzero+envelope+env_onset
subject                  model                   
-------------------------------------------------
n_2_S030_ICAed_raw.fif   Fzero                   
n_2_S027_ICAed_raw.fif   Fzero                   
n_2_S023_ICAed_raw.fif   Fzero                   
n_2_S034_ICAed_raw.fif   Fzero                   
n_2_S024_ICAed_raw.fif   Fzero                   
n_2_S019_ICAed_raw.fif   Fzero                   
n_2_S020_ICAed_raw.fif   Fzero                   
n_2_S013_ICAed_raw.fif   Fzero                   
n_2_S017_ICAed_raw.fif   Fzero                   
n_2_S039_ICAed_raw.fif   Fzero                   
n_2_S010_ICAed_raw.fif   Fzero                   
n_2_S029_ICAed_raw.fif   Fzero                   
n_2_S015_ICAed_raw.fif   Fzero                   
n_2_S028_ICAed_raw.fif   Fzero                   
n_2_S011_ICAed_raw.fif   Fzero                   
n_2_S038_ICAed_raw.fif   Fzero                   
n_2_S016_ICAed_raw.fif   Fzero                   
n_2_

In [10]:
# --- 1. HHSA HELPER FUNCTIONS (Same as before) ---
def get_inst_freq_amp(signal, fs):
    analytic_signal = hilbert(signal)
    amplitude = np.abs(analytic_signal)
    instantaneous_phase = np.unwrap(np.angle(analytic_signal))
    inst_freq = np.diff(instantaneous_phase) / (2.0 * np.pi) * fs
    inst_freq = np.append(inst_freq, inst_freq[-1])
    return amplitude, inst_freq

def run_hhsa(signal, fs):
    # Standardize signal shape
    if signal.ndim > 1: signal = signal.flatten()
    
    emd = EMD()
    # EMD can fail on very short/flat signals, so we use try/except
    try:
        imfs_layer1 = emd(signal)
    except:
        return None
        
    holo_points = []
    
    # Layer 1: Carrier
    for imf_c in imfs_layer1:
        env_c, freq_c = get_inst_freq_amp(imf_c, fs)
        if np.sum(np.abs(env_c)) < 1e-10: continue

        # Layer 2: Amplitude Modulation
        try:
            imfs_layer2 = emd(env_c)
        except:
            continue
            
        for imf_am in imfs_layer2:
            _, freq_am = get_inst_freq_amp(imf_am, fs)
            power_am = imf_am**2
            
            # Filter valid frequencies
            mask = (freq_c > 0) & (freq_c < fs/2) & (freq_am > 0) & (freq_am < fs/2)
            idx = np.where(mask)[0]
            
            if len(idx) > 0:
                points = np.vstack((freq_c[idx], freq_am[idx], power_am[idx])).T
                holo_points.append(points)
                
    if not holo_points: return None
    return np.vstack(holo_points)

In [11]:
# --- 2. YOUR LOADING & PROCESSING LOOP ---

# CONFIGURATION
FS = 100 # Check your TRF sampling rate! (Eelbrain usually downsamples, e.g., 100Hz or 128Hz)
TARGET_MODEL = 'Fzero+envelope' # Choose ONE model to analyze for now
TARGET_PREDICTOR = 'envelope'   # Which part of the model do you want? 'f0' or 'envelope'?

# Accumulator for Group Average
group_spectrum_sum = None
n_subjects = 0
nbins = 100
LIMIT_CARRIER = 20 # Hz (TRFs usually don't have high gamma, mostly delta/theta/alpha)
LIMIT_AM = 10      # Hz

print(f"Starting HHSA on {TARGET_MODEL} -> {TARGET_PREDICTOR}...")

for subject in ESL_SUBJECTS:
    subject_id = subject[4:8]
    file_path = TRF_DIR / subject_id / f'{subject_id} {TARGET_MODEL}.pickle'
    
    try:
        # A. Load the Eelbrain object
        # This loads the BoostingResult object
        trf_obj = eelbrain.load.unpickle(file_path)
        
        # B. Extract the Time Series (The critical fix!)
        # trf_obj.h is the NDVar containing the response
        # We assume trf_obj.h structure is [Predictor, Sensor, Time] or [Predictor, Time]
        
        # 1. Select the specific predictor (e.g., 'envelope')
        # We find the index of the predictor name
        if hasattr(trf_obj.h, 'x'): 
            # Check if dimensions usually include Sensor
            # If you want to Average across all sensors (common for TRF analysis):
            trf_avg_sensors = trf_obj.h.mean('sensor') 
            
            # Select the predictor
            trf_predictor = trf_avg_sensors.sub(name=TARGET_PREDICTOR)
            
            # Convert to pure Numpy array for analysis
            trf_signal = trf_predictor.x 
            
            # Ensure it's 1D (Time only)
            if trf_signal.ndim > 1:
                trf_signal = trf_signal.flatten()
                
        else:
            print(f"Skipping {subject}: Unexpected TRF structure.")
            continue

        # C. Run HHSA on this subject's TRF
        print(f"  Processing {subject}...")
        holo_data = run_hhsa(trf_signal, FS)
        
        if holo_data is None: continue

        # D. Bin into Histogram (Image)
        fc = holo_data[:, 0]
        fam = holo_data[:, 1]
        power = holo_data[:, 2]
        
        H_subj, xedges, yedges = np.histogram2d(fc, fam, bins=nbins, weights=power,
                                           range=[[0, LIMIT_CARRIER], [0, LIMIT_AM]])
        
        # Accumulate
        if group_spectrum_sum is None:
            group_spectrum_sum = H_subj
        else:
            group_spectrum_sum += H_subj
        n_subjects += 1
        
    except Exception as e:
        print(f"  Error loading {subject}: {e}")
        continue

# --- 3. PLOT GRAND AVERAGE ---
if n_subjects > 0:
    group_avg = group_spectrum_sum / n_subjects
    
    plt.figure(figsize=(10, 8))
    plt.imshow(group_avg.T, origin='lower', cmap='jet', aspect='auto',
               extent=[0, LIMIT_CARRIER, 0, LIMIT_AM], interpolation='gaussian')
    plt.colorbar(label='Modulation Power')
    plt.xlabel('Carrier Frequency (Hz)')
    plt.ylabel('AM Frequency (Hz)')
    plt.title(f'Group HHSA of TRF ({TARGET_PREDICTOR}) - N={n_subjects}')
    plt.plot([0, LIMIT_AM], [0, LIMIT_AM], 'w--', alpha=0.5) # Diagonal
    plt.show()
else:
    print("No subjects processed.")

Starting HHSA on Fzero+envelope -> envelope...
Skipping n_2_S030_ICAed_raw.fif: Unexpected TRF structure.
Skipping n_2_S027_ICAed_raw.fif: Unexpected TRF structure.
Skipping n_2_S023_ICAed_raw.fif: Unexpected TRF structure.
Skipping n_2_S034_ICAed_raw.fif: Unexpected TRF structure.
Skipping n_2_S024_ICAed_raw.fif: Unexpected TRF structure.
Skipping n_2_S019_ICAed_raw.fif: Unexpected TRF structure.
Skipping n_2_S020_ICAed_raw.fif: Unexpected TRF structure.
Skipping n_2_S013_ICAed_raw.fif: Unexpected TRF structure.
Skipping n_2_S017_ICAed_raw.fif: Unexpected TRF structure.
Skipping n_2_S039_ICAed_raw.fif: Unexpected TRF structure.
Skipping n_2_S010_ICAed_raw.fif: Unexpected TRF structure.
Skipping n_2_S029_ICAed_raw.fif: Unexpected TRF structure.
Skipping n_2_S015_ICAed_raw.fif: Unexpected TRF structure.
Skipping n_2_S028_ICAed_raw.fif: Unexpected TRF structure.
Skipping n_2_S011_ICAed_raw.fif: Unexpected TRF structure.
Skipping n_2_S038_ICAed_raw.fif: Unexpected TRF structure.
Skipping 

In [None]:
# ==========================================
# 4. EXECUTION AND PLOTTING
# ==========================================

# A. Load Data
try:
    trf_signal = load_trf_data(FILENAME)
    print(f"Data Loaded. Length: {len(trf_signal)} samples.")
except Exception as e:
    print(f"Error loading file: {e}")
    # CREATE DUMMY DATA FOR DEMO if file load fails
    print("Generating dummy TRF data for demonstration...")
    t = np.linspace(0, 2, 2*FS)
    # Carrier 40Hz (Gamma) modulated by 4Hz (Theta)
    trf_signal = np.sin(2*np.pi*40*t) * (1 + 0.8*np.sin(2*np.pi*4*t))

# B. Run Analysis
print("Running HHSA (this may take a moment)...")

holo_data = run_hhsa(trf_signal, FS)

# C. Plotting the Holo-Spectrum
if holo_data is not None:
    fc = holo_data[:, 0]
    fam = holo_data[:, 1]
    power = holo_data[:, 2]

    # Create 2D Histogram (The "Holo-Spectrum" Image)
    # We bin the scattered points into a grid to create a heatmap
    nbins = 100
    H, xedges, yedges = np.histogram2d(fc, fam, bins=nbins, weights=power,
                                       range=[[0, FREQ_LIMIT_CARRIER], [0, FREQ_LIMIT_AM]])

    # Transpose H for correct orientation (freq_am on Y-axis)
    H = H.T

    plt.figure(figsize=(10, 8))
    plt.imshow(H, interpolation='gaussian', origin='lower', cmap='jet', aspect='auto',
               extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]])
    
    plt.colorbar(label='Modulation Power (Energy)')
    plt.xlabel('Carrier Frequency (Hz)')
    plt.ylabel('Amplitude Modulation Frequency (Hz)')
    plt.title('Holo-Hilbert Spectrum (TRF Analysis)')
    
    # Add a diagonal line (optional, often helpful in HHSA)
    plt.plot([0, FREQ_LIMIT_AM], [0, FREQ_LIMIT_AM], 'w--', alpha=0.5, label='1:1 Line')
    plt.legend()
    plt.show()

else:
    print("Analysis failed to extract valid frequency components."