In [1]:

!pip install mne scipy numpy matplotlib torch --quiet

import torch
import numpy as np
import mne
from scipy.signal import butter, filtfilt, spectrogram, stft
import matplotlib.pyplot as plt

# Path of the serialized dataset file that the rest of the script processes.
pth_path = '/content/eeg_signals_raw_with_mean_std.pth'
# NOTE(review): torch.load deserializes via pickle — only load files from a
# trusted source. The code below reads the loaded object's 'dataset' entry.
eeg_data = torch.load(pth_path)
print(" loaded.")


# Pipeline configuration.
FS = 1000  # sampling rate handed as `fs` to the filter and plotting helpers
FREQ_RANGE = (5, 95)  # (low, high) band-pass cutoffs, same unit as FS
TRIM_MS = 0  # NOTE(review): not referenced anywhere in the visible code
CHANNELS = 128  # target channel count passed to pad_channels
TOKEN_GROUP = 4  # number of consecutive samples averaged into one token
PROJ_DIM = 1024  # output width of the linear projection


# Collect the sample count (second axis) of every segment that has an 'eeg' entry.
# NOTE(review): assumes eeg_data['dataset'] is a mapping whose values are
# iterables of segment dicts — confirm against the actual file layout.
raw_lengths = [
    seg['eeg'].shape[1]
    for subj in eeg_data['dataset'].values()
    for seg in subj
    if isinstance(seg, dict) and 'eeg' in seg
]
if not raw_lengths:
    # np.median([]) is NaN and int(NaN) fails with an unhelpful message,
    # so report the real problem instead.
    raise ValueError("No EEG segments found in eeg_data['dataset']")

median_len = int(np.median(raw_lengths))
# Round down to a multiple of the token group size so that temporal
# tokenization does not have to discard trailing samples.
median_len -= median_len % TOKEN_GROUP
print(f" Median EEG segment length: {median_len}")


def bandpass_filter(signal, low, high, fs, order=5):
    """Apply a zero-phase Butterworth band-pass filter along the last axis.

    The filter is designed as second-order sections and applied forward and
    backward, so the output has no phase shift relative to the input.

    Args:
        signal: Array-like whose last axis is time.
        low: Lower cutoff frequency, in the same unit as ``fs``.
        high: Upper cutoff frequency, in the same unit as ``fs``.
        fs: Sampling rate.
        order: Butterworth order of the prototype (the band-pass filter has
            twice this order).

    Returns:
        Filtered array with the same shape as ``signal``.

    Raises:
        ValueError: If the band does not satisfy ``0 < low < high < fs / 2``,
            or if the signal is too short for the edge padding used by the
            forward-backward filter.
    """
    # Local import: the file-level import list does not bring this name in.
    from scipy.signal import sosfiltfilt

    nyquist = fs / 2
    if not 0 < low < high < nyquist:
        raise ValueError(
            f"Band edges must satisfy 0 < low < high < fs/2; "
            f"got low={low}, high={high}, fs={fs}"
        )

    # Second-order sections stay numerically well-conditioned for higher
    # orders and narrow/low bands, where the (b, a) polynomial form loses
    # precision.
    sos = butter(order, [low / nyquist, high / nyquist], btype='band', output='sos')
    return sosfiltfilt(sos, signal, axis=-1)

def trim_and_pad(signal, target_len):
    """Force the second axis of a 2-D array to exactly ``target_len`` samples.

    Shorter inputs are extended with trailing zeros; longer inputs are cut
    off after ``target_len`` samples.

    Args:
        signal: 2-D array; the second axis is the one being resized.
        target_len: Desired length of the second axis.

    Returns:
        Array whose second axis has length ``target_len``.
    """
    shortfall = target_len - signal.shape[1]
    if shortfall > 0:
        # Zero-fill only at the end of the second axis; leave the first axis alone.
        signal = np.pad(signal, ((0, 0), (0, shortfall)), mode='constant')
    return signal[:, :target_len]

def normalize(signal):
    """Standardize each row of ``signal`` to zero mean and unit variance.

    Statistics are taken along axis 1. A small constant is added to the
    standard deviation so that constant rows do not divide by zero.

    Args:
        signal: Array with at least two axes; axis 1 is reduced.

    Returns:
        Array of the same shape holding the standardized values.
    """
    row_mean = signal.mean(axis=1, keepdims=True)
    row_scale = signal.std(axis=1, keepdims=True) + 1e-8
    centered = signal - row_mean
    return centered / row_scale

def pad_channels(signal, target_channels=128):
    """Grow the first axis of a 2-D array to ``target_channels`` rows.

    When there are fewer rows than requested, the existing rows are repeated
    cyclically (row 0, row 1, ..., row 0, ...) until the target is reached.
    Inputs that already have at least ``target_channels`` rows are returned
    unchanged.

    Args:
        signal: 2-D array.
        target_channels: Minimum number of rows in the result.

    Returns:
        The original array, or a row-repeated array with exactly
        ``target_channels`` rows.
    """
    n_rows, _ = signal.shape
    if n_rows >= target_channels:
        return signal

    # Ceiling division: number of whole copies needed to cover the target.
    n_copies = -(-target_channels // n_rows)
    stacked = np.concatenate([signal] * n_copies, axis=0)
    return stacked[:target_channels, :]

def temporal_tokenize(signal, group_size=4):
    """Average consecutive samples of each row in groups of ``group_size``.

    Trailing samples that do not fill a complete group are dropped.

    Args:
        signal: 2-D array; the second axis is grouped.
        group_size: Number of consecutive samples merged into one token.

    Returns:
        Array of shape ``(rows, samples // group_size)`` holding group means.
    """
    n_rows, n_samples = signal.shape
    n_tokens = n_samples // group_size
    usable = n_tokens * group_size
    grouped = signal[:, :usable].reshape(n_rows, n_tokens, group_size)
    return grouped.mean(axis=2)

def linear_project(token_tensor, out_dim=1024, seed=0):
    """Project token rows to ``out_dim`` features with a fixed random matrix.

    The projection matrix is drawn from a standard normal distribution using
    a generator seeded with ``seed``. Because the seed is fixed by default,
    every call with the same input width and ``out_dim`` uses the same
    matrix, so outputs of different calls live in the same space and runs
    are reproducible. The global NumPy random state is not touched.

    Args:
        token_tensor: 2-D array of shape ``(tokens, features)``.
        out_dim: Number of output features.
        seed: Seed for the projection matrix. Pass ``None`` to draw a fresh,
            non-reproducible matrix on each call.

    Returns:
        Array of shape ``(tokens, out_dim)``.
    """
    in_dim = token_tensor.shape[1]
    rng = np.random.default_rng(seed)
    weights = rng.standard_normal((in_dim, out_dim))
    return token_tensor @ weights

def plot_psd(signal, fs=1000, subject_id='', seg_idx=0):
    """Plot the Welch power spectral density of the first channel.

    Args:
        signal: 2-D array indexed as ``(channel, sample)``; only row 0 is
            analysed and plotted.
        fs: Sampling rate passed to the PSD estimator.
        subject_id: Label used in the figure title.
        seg_idx: Segment index used in the figure title.
    """
    # Only channel 0 is drawn, so estimate the PSD for that row alone rather
    # than for every channel.
    first_channel = signal[:1]
    # The estimator rejects an FFT length longer than the signal, so clamp
    # it for short segments.
    n_fft = min(256, first_channel.shape[-1])
    psds, freqs = mne.time_frequency.psd_array_welch(
        first_channel, sfreq=fs, fmin=1, fmax=100, n_fft=n_fft
    )

    plt.figure(figsize=(10, 4))
    plt.plot(freqs, psds[0], label='Channel 0')
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("Power Spectral Density")
    plt.title(f"PSD (Subject {subject_id}, Segment {seg_idx})")
    plt.grid(True)
    plt.legend()
    plt.show()

def plot_spectrogram(signal, fs=1000, subject_id='', seg_idx=0):
    """Draw a dB-scaled STFT spectrogram of the first channel.

    Args:
        signal: Array whose first row is the channel to display.
        fs: Sampling rate passed to the STFT.
        subject_id: Label used in the figure title.
        seg_idx: Segment index used in the figure title.
    """
    first_channel = signal[0]
    freqs, times, coeffs = stft(first_channel, fs=fs, nperseg=128)
    # The small floor keeps log10 finite where the magnitude is exactly zero.
    squared_mag = np.abs(coeffs)**2
    log_power = 10 * np.log10(squared_mag + 1e-8)

    plt.figure(figsize=(10, 4))
    plt.pcolormesh(times, freqs, log_power, shading='gouraud', cmap='viridis')
    plt.xlabel('Time [sec]')
    plt.ylabel('Frequency [Hz]')
    plt.colorbar(label='Power [dB]')
    plt.title(f"Spectrogram (Subject {subject_id}, Segment {seg_idx}, Channel 0)")
    plt.tight_layout()
    plt.show()


# Maps each subject id to the list of projected token matrices of its segments.
all_subject_embeddings = {}

for subject_id, segments in eeg_data['dataset'].items():
    print(f" Processing Subject {subject_id} ")
    subject_embeddings = []

    for seg_idx, segment in enumerate(segments):
        # Guard clause: anything that is not a dict carrying an 'eeg' entry
        # is reported and skipped.
        if not isinstance(segment, dict) or 'eeg' not in segment:
            print(f"Skipping segment {seg_idx} (no 'eeg' key)")
            continue

        signal = segment['eeg'].numpy()

        # Preprocessing chain: band-pass -> fixed length -> per-row z-score
        # -> channel padding -> temporal averaging.
        # NOTE(review): zero-padding happens before normalization, so padded
        # samples take part in the mean/std — confirm this is intended.
        filtered = bandpass_filter(signal, *FREQ_RANGE, FS)
        trimmed = trim_and_pad(filtered, target_len=median_len)
        normalized = normalize(trimmed)
        padded = pad_channels(normalized, target_channels=CHANNELS)
        tokenized = temporal_tokenize(padded, group_size=TOKEN_GROUP)

        # Tokens become rows (time-major) before the linear projection.
        token_tensor = tokenized.T
        embedded = linear_project(token_tensor, out_dim=PROJ_DIM)
        subject_embeddings.append(embedded)

        # Diagnostic plots are produced only for the entry at index 0.
        if seg_idx == 0:
            plot_psd(padded, fs=FS, subject_id=subject_id, seg_idx=seg_idx)
            plot_spectrogram(padded, fs=FS, subject_id=subject_id, seg_idx=seg_idx)

    # Subjects without any usable segment are left out of the result.
    if subject_embeddings:
        all_subject_embeddings[subject_id] = subject_embeddings
        print(f" Finished Subject {subject_id} with {len(subject_embeddings)} segt")


torch.save(all_subject_embeddings, 'dreamdiffusion_eeg_embeddings.pth')
print("Embeddings saved ")



[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m88.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m33.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m29.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m25.5 MB/s[0m eta [36m0:00:00[0m
[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━[0m [32m442.3/664.8 MB[0m [31m136.3 MB/s[0m eta [36m0:00:02[0m
[?25h[31mERROR: Operation cancelled by user[0m[31m
[0m

ModuleNotFoundError: No module named 'mne'