In [None]:
import mne
import numpy as np
import MARBLE
from MARBLE import postprocessing, plotting
import matplotlib.pyplot as plt

# Parameters
segment_length = 100000  # Samples per segment; adjust based on your memory capabilities
min_segment_length = 100  # Segments shorter than this are skipped
data_dir = './preprocessed/bipolar/'  # Folder holding the bipolar-referenced .fif files
results = []  # One dict per processed segment: file, (start, end), embedding, model


# Selected files
selected_files = [
    "24_08_13-D_14_36_48_bipolar.fif",
    "24_08_13-D_14_49_45_bipolar.fif",
    # "24_08_13-D_15_00_00_bipolar.fif",
    # "24_08_13-D_16_57_44_bipolar.fif",
    # "24_08_13-D_18_43_55_bipolar.fif",
    # "24_08_13-D_20_00_00_bipolar.fif",
    # "24_08_14-D_11_17_03_bipolar.fif",
    # "24_08_14-D_12_00_00_bipolar.fif",
    # "24_08_14-D_13_00_00_bipolar.fif",
    # "24_08_14-D_15_00_00_bipolar.fif",
]

# MARBLE hyperparameters. They are the same for every segment, so the dict is
# built once here instead of on every loop iteration.
params = {
    "epochs": 50,  # Reduce epochs for faster processing
    "order": 1,
    "hidden_channels": [32],  # Smaller network
    "batch_size": 256,
    "lr": 1e-4,
    "out_channels": 3,
    "inner_product_features": False,
    "emb_norm": True,
    "diffusion": True,
}


def _zscore_columns(samples):
    """Standardise every column of ``samples`` to zero mean and unit variance.

    Columns with zero standard deviation (e.g. a flat channel) are only
    centred: dividing them by zero would fill the column with NaN/inf.

    Parameters
    ----------
    samples : np.ndarray
        2-D array; statistics are computed along axis 0.

    Returns
    -------
    np.ndarray
        Array of the same shape as ``samples``.
    """
    centred = samples - samples.mean(axis=0)
    scale = samples.std(axis=0)
    scale = np.where(scale == 0, 1.0, scale)
    return centred / scale


# Process each file in segments
for file in selected_files:
    # Load the recording and put time on axis 0 (get_data() is transposed so
    # that each row is one sample), then z-score every channel.
    raw = mne.io.read_raw_fif(data_dir + file, preload=True, verbose=False)
    data = _zscore_columns(raw.get_data().T)
    n_samples = len(data)

    # Walk the recording in consecutive, non-overlapping windows. The slice end
    # is exclusive, so n_samples (not n_samples - 1) is the correct upper bound;
    # otherwise the final sample of each file is silently dropped.
    for start_idx in range(0, n_samples, segment_length):
        end_idx = min(start_idx + segment_length, n_samples)
        segment_data = data[start_idx:end_idx]

        # Skip segments that are too small
        if len(segment_data) < min_segment_length:
            continue

        print(f"Processing segment {start_idx}:{end_idx} from file {file}")

        # Anchor points are the samples themselves; the vector attached to each
        # anchor is the first difference to the next sample, hence the [:-1].
        pos_list = [segment_data[:-1, :]]
        x_list = [np.diff(segment_data, axis=0)]

        # Construct dataset with appropriate spacing
        # Use larger spacing for memory efficiency
        segment_dataset = MARBLE.construct_dataset(
            anchor=pos_list,
            graph_type="cknn",
            vector=x_list,
            spacing=0.1,  # Increase this value to sample fewer points
            memory_efficient=True,
        )

        # Create and train a fresh model for this segment
        model = MARBLE.net(segment_dataset, params=params)
        model.fit(segment_dataset)

        # Embed, reduce to 2D and cluster
        transformed_data = model.transform(segment_dataset)
        transformed_data = postprocessing.embed_in_2D(transformed_data)
        transformed_data = postprocessing.cluster(transformed_data, n_clusters=5)

        # Store results (note: keeping every model alive costs memory)
        results.append({
            'file': file,
            'segment': (start_idx, end_idx),
            'data': transformed_data,
            'model': model
        })

        # Save a visualization of this segment's embedding
        plt.figure(figsize=(10, 8))
        plotting.state_space(transformed_data)
        plt.savefig(f"marble_segment_{file}_{start_idx}_{end_idx}.png")
        plt.close()


  from .autonotebook import tqdm as notebook_tqdm
  raw = mne.io.read_raw_fif('./preprocessed/bipolar/'+file, preload=True, verbose=False)


Processing segment 0:100000 from file 24_08_13-D_14_36_48_bipolar.fif


In [4]:
# Re-embed the most recently processed segment and cluster it more finely.
#
# `model` and `segment_dataset` are the values left over from the last
# iteration of the processing loop, so they belong to the same segment.
# The previous version passed `data` here, but after the processing cell
# `data` is the z-scored NumPy array of the last file, not the MARBLE dataset
# that `model.transform` is given during processing. It also overwrote `data`,
# so the cell could not be re-run.
# NOTE(review): if MARBLE's postprocessing mutates its argument in place, this
# also changes the clustering stored in results[-1]['data'] — confirm.
segment_embedding = model.transform(segment_dataset)
# segment_embedding = postprocessing.distribution_distances(segment_embedding)
segment_embedding = postprocessing.embed_in_2D(segment_embedding)
segment_embedding = postprocessing.cluster(segment_embedding, n_clusters=10)


 No umap embedding performed. Embedding seems to be               already in 2D.


  kmeans = KMeans(n_clusters=n_clusters, random_state=seed).fit(x)
