In [30]:
#!/usr/bin/env python
"""
Test script to evaluate CUDA acceleration in MARBLE.
This script processes EEG data files from a single day (24_08_13)
and uses CUDA for dataset construction and processing.
"""

import os
import time
import torch

import numpy as np
import matplotlib.pyplot as plt
import mne
import pickle
from glob import glob
from tqdm import tqdm

import MARBLE
from MARBLE import postprocessing, plotting, preprocessing,ge

# Hyperparameters handed to MARBLE.net(..., params=params) further down.
# The per-key notes follow MARBLE's documented parameter names; confirm
# against the installed MARBLE version before relying on them.
params = {
    # --- optimisation ---
    "epochs": 50,        # presumably the number of training epochs
    "order": 1,          # presumably the order of derivatives used as features
    "hidden_channels": [256],
    "batch_size": 256,
    "lr": 1e-3,          # learning rate
    # --- embedding ---
    "out_channels": 3,   # presumably the embedding dimension
    "inner_product_features": False,
    "emb_norm": True,
    "diffusion": True,
}


In [9]:
# Load the first 24_08_13 recording, resample it, and z-score one batch.
data_dir = "./preprocessed/bipolar"
files = sorted(glob(os.path.join(data_dir, "24_08_13*.fif")))
if not files:
    # Fail with an actionable message instead of an opaque IndexError below.
    raise FileNotFoundError(
        f"No files matching '24_08_13*.fif' found in {os.path.abspath(data_dir)}"
    )
file = files[0]
raw = mne.io.read_raw_fif(file, preload=True, verbose=False)
max_samples = 100000
raw = raw.resample(200)  # new sampling rate passed to MNE (Hz)

file_data = raw.get_data()  # Channels x Time

# Only one file is read in this cell, so the batch is just its first
# `max_samples` samples; `current_data` holds whatever is left after them.
batch_data = file_data[:, :max_samples]
current_data = file_data[:, max_samples:]

# Transpose to Time x Channels, then z-score every channel independently.
batch_data = batch_data.T
channel_std = batch_data.std(axis=0)
# A constant channel has zero std; dividing by it would fill the column with
# NaN/inf. Using 1.0 instead leaves such a channel at exactly zero.
channel_std[channel_std == 0] = 1.0
batch_data = (batch_data - batch_data.mean(axis=0)) / channel_std

# Neighbourhood size forwarded to the graph construction below.
k_value = 20

# Each row of batch_data except the last is used as an anchor point, and the
# difference to the following row is the vector attached to that anchor.
x_list = np.diff(batch_data, axis=0)
pos_list = batch_data[:-1]

# Build the MARBLE dataset from the anchor/vector pairs.
# NOTE(review): the exact meaning of `spacing` is defined by MARBLE, not by
# this notebook - confirm against the construct_dataset documentation.
Dataset = MARBLE.construct_dataset(
    anchor=pos_list,
    vector=x_list,
    graph_type="cknn",
    k=k_value,
    spacing=0.05,
)

# Train a MARBLE network on the dataset, then compute its embedding.
model = MARBLE.net(Dataset, params=params)

# fit() is called directly. Nothing here hooks into fit()'s epoch loop, so an
# outer tqdm bar would never advance and would only show a bar stuck at 0.
model.fit(Dataset)

# transform() output is passed through embed_in_2D; the name is reused so
# `transformed_data` ends up holding the result of the second call.
transformed_data = model.transform(Dataset)
transformed_data = postprocessing.embed_in_2D(transformed_data)

  raw = mne.io.read_raw_fif(file, preload=True, verbose=False)



---- Embedding dimension: 6
---- Signal dimension: 6
---- Computing kernels ... 
---- Computing full spectrum ...
              (if this takes too long, then run construct_dataset()
              with number_of_eigenvectors specified) 