In [1]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from PyEMD import EMD
from scipy.signal import hilbert, medfilt

ModuleNotFoundError: No module named 'PyEMD'

In [None]:
# ==========================================
# 1. USER CONFIGURATION
# ==========================================
# Pickled TRF file to analyse (replace with your own path).
FILENAME: str = 'your_trf_data.pickle'
# Sampling rate in Hz. CRITICAL: must match the recording, because every
# instantaneous-frequency estimate below is scaled by this value.
FS: int = 1000
# Axis limits (Hz) of the holo-spectrum plot.
FREQ_LIMIT_CARRIER: int = 100  # x-axis: carrier frequency
FREQ_LIMIT_AM: int = 20        # y-axis: amplitude-modulation frequency

# ==========================================
# 2. HELPER FUNCTIONS
# ==========================================
def get_inst_freq_amp(signal, fs):
    """
    Compute instantaneous amplitude (envelope) and frequency of a 1-D signal
    using the Hilbert transform.

    Parameters
    ----------
    signal : array_like
        Real-valued 1-D signal (must be non-empty; ``hilbert`` rejects
        empty input).
    fs : float
        Sampling rate in Hz.

    Returns
    -------
    amplitude : numpy.ndarray
        Magnitude of the analytic signal, same length as ``signal``.
    inst_freq : numpy.ndarray
        Instantaneous frequency in Hz, same length as ``signal``. It is
        estimated as the forward difference of the unwrapped phase, so the
        final sample repeats the previous estimate. A length-1 input has no
        phase increment and yields ``[0.0]``.
    """
    analytic_signal = hilbert(np.asarray(signal))
    amplitude = np.abs(analytic_signal)
    instantaneous_phase = np.unwrap(np.angle(analytic_signal))

    if len(instantaneous_phase) < 2:
        # Frequency is undefined for a single sample; report 0 Hz instead of
        # failing with IndexError on the empty difference below.
        return amplitude, np.zeros_like(amplitude)

    # np.diff of the phase is in rad/sample; multiplying by fs gives rad/s,
    # and dividing by 2*pi converts rad/s to Hz.
    inst_freq = np.diff(instantaneous_phase) / (2.0 * np.pi) * fs

    # The difference is one sample shorter than the input; repeat the last
    # estimate so the output aligns sample-for-sample with ``amplitude``.
    inst_freq = np.append(inst_freq, inst_freq[-1])

    return amplitude, inst_freq

def load_trf_data(filename):
    """
    Load TRF data from a pickle file and reduce it to a 1-D NumPy array.

    Accepted pickle contents:
      * a NumPy array, or
      * a dict; the first string key (in the dict's iteration order) whose
        lowercase form contains 'trf' or 'data' is used. Non-string keys are
        ignored. The value may be any array-like (e.g. a list).

    Multi-dimensional data is reduced to a single vector: singleton axes are
    squeezed out, then index 0 is taken along each remaining leading axis
    (for a 2-D (rows, cols) array this is row 0).

    Parameters
    ----------
    filename : str or path-like
        Path of the pickle file.

    Returns
    -------
    numpy.ndarray
        At-least-1-D array holding the selected vector.

    Raises
    ------
    ValueError
        If the pickle holds neither an array nor a dict, or the dict has no
        matching key.

    Security: ``pickle.load`` can execute arbitrary code; only load files
    from a trusted source.
    """
    with open(filename, 'rb') as f:
        data = pickle.load(f)  # SECURITY: unpickling untrusted files is unsafe.

    # LOGIC to handle different pickle structures:
    if isinstance(data, np.ndarray):
        trf = data
    elif isinstance(data, dict):
        # isinstance guard: keys such as ints have no .lower() and previously crashed.
        keys = [
            k for k in data
            if isinstance(k, str) and ('trf' in k.lower() or 'data' in k.lower())
        ]
        if not keys:
            raise ValueError(f"Could not find data in dictionary keys: {data.keys()}")
        trf = data[keys[0]]
        print(f"Loaded key '{keys[0]}' from dictionary.")
    else:
        raise ValueError("Pickle file must contain a numpy array or a dictionary.")

    # Dict values may be lists or other array-likes that lack .ndim.
    trf = np.asarray(trf)

    # Ensure 1-D: drop singleton axes first so e.g. (2, 1, N) is not flattened
    # into interleaved channels, then keep the first vector along leading axes.
    if trf.ndim > 1:
        print(f"Data shape is {trf.shape}. Taking the first channel/vector.")
        trf = np.squeeze(trf)
        while trf.ndim > 1:
            trf = trf[0]

    # An all-singleton input squeezes to 0-d; promote so len() works for callers.
    return np.atleast_1d(trf)

# ==========================================
# 3. HHSA MAIN ALGORITHM
# ==========================================
def run_hhsa(signal, fs):
    """
    Perform two-layer EMD Holo-Hilbert Spectral Analysis (HHSA).

    Layer 1 decomposes ``signal`` into carrier IMFs (the residue is
    excluded). For each carrier IMF, its Hilbert envelope is decomposed again
    (layer 2) into amplitude-modulation (AM) IMFs, again excluding the
    residue. Each time sample where both the carrier and the AM instantaneous
    frequency lie strictly between 0 and fs/2 contributes one point.

    Parameters
    ----------
    signal : array_like
        1-D input signal.
    fs : float
        Sampling rate in Hz.

    Returns
    -------
    numpy.ndarray or None
        Array of shape (n_points, 3) with columns
        [carrier_freq_Hz, am_freq_Hz, power]. ``power`` is the squared
        instantaneous amplitude (Hilbert envelope) of the AM IMF.
        Returns ``None`` when no valid point is found.
    """
    nyquist = fs / 2

    # --- LAYER 1 (carrier frequencies) ---
    emd = EMD()
    emd.emd(signal)
    imfs_layer1 = emd.get_imfs_and_residue()[0]

    # Separate instance for layer 2, so layer-1 decomposition state is never
    # touched while iterating over its IMFs.
    emd_am = EMD()

    holo_points = []  # one (k, 3) array of [fc, fam, power] per (carrier, AM) pair

    print(f"Layer 1 Decomposition: Found {imfs_layer1.shape[0]} Carrier IMFs.")

    for i, imf_c in enumerate(imfs_layer1):
        env_c, freq_c = get_inst_freq_amp(imf_c, fs)

        # Keep only physically meaningful carrier frequencies (0, Nyquist).
        valid_mask_c = (freq_c > 0) & (freq_c < nyquist)

        if np.sum(np.abs(env_c)) < 1e-10:
            continue  # Skip empty IMFs: nothing to demodulate

        # --- LAYER 2 (amplitude modulation of this carrier) ---
        # Read IMFs via get_imfs_and_residue(), exactly as in layer 1. The
        # array returned by EMD.emd()/__call__ can include the residue as a
        # final row (NOTE(review): PyEMD behavior, confirm for installed
        # version). The envelope's residue is its large, slowly varying mean,
        # which would otherwise dominate the holo-spectrum.
        try:
            emd_am.emd(env_c)
        except Exception as err:  # best-effort: EMD may fail on noise-like envelopes
            print(f"Layer 2 EMD failed for carrier IMF {i} ({err}); skipping.")
            continue
        imfs_layer2 = emd_am.get_imfs_and_residue()[0]

        for imf_am in imfs_layer2:
            env_am, freq_am = get_inst_freq_amp(imf_am, fs)

            # Power is the squared instantaneous AMPLITUDE (Hilbert envelope)
            # of the AM component, not the squared oscillating IMF value.
            power_am = env_am ** 2

            # Sample-wise pairing of carrier and AM instantaneous frequencies.
            mask_am = (freq_am > 0) & (freq_am < nyquist)
            idx = np.flatnonzero(valid_mask_c & mask_am)

            if idx.size:
                # Col 0 = Carrier Freq, Col 1 = AM Freq, Col 2 = Power
                holo_points.append(
                    np.column_stack((freq_c[idx], freq_am[idx], power_am[idx]))
                )

    if not holo_points:
        return None

    # Concatenate the points from every (carrier, AM) mode pair.
    return np.vstack(holo_points)

# ==========================================
# 4. EXECUTION AND PLOTTING
# ==========================================

# A. Load Data
try:
    trf_signal = load_trf_data(FILENAME)
except Exception as e:
    # Top-level boundary: report the failure and fall back to synthetic data
    # so the rest of the pipeline can still be demonstrated.
    print(f"Error loading file: {e}")
    print("Generating dummy TRF data for demonstration...")
    # Exactly 2 s of samples spaced 1/FS apart. np.linspace(0, 2, 2*FS)
    # would include the endpoint and give a spacing of 2/(2*FS - 1) instead,
    # slightly mis-scaling every frequency relative to FS.
    t = np.arange(2 * FS) / FS
    # Carrier 40Hz (Gamma) modulated by 4Hz (Theta)
    trf_signal = np.sin(2*np.pi*40*t) * (1 + 0.8*np.sin(2*np.pi*4*t))
else:
    print(f"Data Loaded. Length: {len(trf_signal)} samples.")

# B. Run Analysis
print("Running HHSA (this may take a moment)...")

holo_data = run_hhsa(trf_signal, FS)

# C. Plotting the Holo-Spectrum
if holo_data is None:
    print("Analysis failed to extract valid frequency components.")
else:
    # Columns: carrier frequency, AM frequency, power (see run_hhsa).
    fc, fam, power = holo_data.T

    # Bin the scattered (fc, fam) points into a power-weighted grid (the
    # "Holo-Spectrum" image). Points outside the plot limits are dropped.
    nbins = 100
    H, xedges, yedges = np.histogram2d(
        fc, fam, bins=nbins, weights=power,
        range=[[0, FREQ_LIMIT_CARRIER], [0, FREQ_LIMIT_AM]],
    )

    # histogram2d indexes H[x_bin, y_bin]; imshow draws rows along y, so
    # transpose to put AM frequency on the vertical axis.
    H = H.T

    plt.figure(figsize=(10, 8))
    plt.imshow(H, interpolation='gaussian', origin='lower', cmap='jet', aspect='auto',
               extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]])

    plt.colorbar(label='Modulation Power (Energy)')
    plt.xlabel('Carrier Frequency (Hz)')
    plt.ylabel('Amplitude Modulation Frequency (Hz)')
    plt.title('Holo-Hilbert Spectrum (TRF Analysis)')

    # Optional reference line where AM frequency equals carrier frequency;
    # it can only span up to the smaller of the two axis limits.
    diag_max = min(FREQ_LIMIT_CARRIER, FREQ_LIMIT_AM)
    plt.plot([0, diag_max], [0, diag_max], 'w--', alpha=0.5, label='1:1 Line')
    plt.legend()
    plt.show()