In [27]:
from zmax_datasets.datasets.usleep import USleepDataset
from pathlib import Path
import numpy as np

# Location of the preprocessed EEG datasets and the one inspected in this notebook
DATASETS_DIR = Path("/project/4180000.46/zmax_datasets/eeg_processed/")
DATASET_NAME = "wearanize_plus"

# Open the dataset and summarise which recordings it contains
dataset_dir = DATASETS_DIR.joinpath(DATASET_NAME)
dataset = USleepDataset(data_dir=dataset_dir)
print(dataset.n_recordings)
print(dataset.recording_ids)


97
['Sub001', 'Sub002', 'Sub004', 'Sub005', 'Sub007', 'Sub009', 'Sub010', 'Sub011', 'Sub012', 'Sub013', 'Sub014', 'Sub016', 'Sub017', 'Sub019', 'Sub020', 'Sub022', 'Sub023', 'Sub024', 'Sub025', 'Sub027', 'Sub028', 'Sub029', 'Sub031', 'Sub032', 'Sub033', 'Sub034', 'Sub035', 'Sub036', 'Sub038', 'Sub039', 'Sub040', 'Sub041', 'Sub043', 'Sub044', 'Sub045', 'Sub046', 'Sub047', 'Sub048', 'Sub050', 'Sub051', 'Sub053', 'Sub054', 'Sub056', 'Sub057', 'Sub059', 'Sub060', 'Sub061', 'Sub062', 'Sub063', 'Sub064', 'Sub065', 'Sub066', 'Sub067', 'Sub068', 'Sub071', 'Sub072', 'Sub073', 'Sub074', 'Sub075', 'Sub076', 'Sub079', 'Sub080', 'Sub081', 'Sub082', 'Sub084', 'Sub085', 'Sub086', 'Sub087', 'Sub088', 'Sub089', 'Sub090', 'Sub091', 'Sub092', 'Sub093', 'Sub094', 'Sub095', 'Sub096', 'Sub097', 'Sub099', 'Sub100', 'Sub101', 'Sub102', 'Sub103', 'Sub104', 'Sub105', 'Sub106', 'Sub107', 'Sub108', 'Sub109', 'Sub110', 'Sub111', 'Sub112', 'Sub118', 'Sub121', 'Sub124', 'Sub129', 'Sub130']


In [28]:
# Pick one recording to inspect and list the channels (data types) it provides
sample_recording = dataset.get_recording("Sub102")
print("Available data types:", sample_recording.data_types, sep="\n")


Available data types:
{'EEG_LEFT': DataType(channel='EEG_LEFT', sampling_rate=128.0), 'EEG_LEFT_artifact': DataType(channel='EEG_LEFT_artifact', sampling_rate=0.1), 'EEG_RIGHT': DataType(channel='EEG_RIGHT', sampling_rate=128.0), 'EEG_RIGHT_artifact': DataType(channel='EEG_RIGHT_artifact', sampling_rate=0.1)}


In [29]:
# Load both EEG channels and their per-channel artifact annotations
eeg_left = sample_recording.read_data_type("EEG_LEFT")
eeg_right = sample_recording.read_data_type("EEG_RIGHT")
eeg_left_artifact = sample_recording.read_data_type("EEG_LEFT_artifact")
eeg_right_artifact = sample_recording.read_data_type("EEG_RIGHT_artifact")

# Time axis in seconds for the EEG signals: sample i sits at i / fs.
# (np.linspace(0, duration, n) would include the endpoint and stretch the
# spacing to duration / (n - 1), so the axis would drift from the true rate.)
n_eeg_samples = len(eeg_left.array.squeeze())
duration_seconds = n_eeg_samples / eeg_left.sample_rate
time = np.arange(n_eeg_samples) / eeg_left.sample_rate

print(f"\nSignal duration: {duration_seconds:.1f} seconds ({duration_seconds/60:.1f} minutes)")

# Report sampling rate and array shape for every loaded signal
signals_by_label = {
    "EEG LEFT": eeg_left,
    "EEG RIGHT": eeg_right,
    "EEG LEFT artifact": eeg_left_artifact,
    "EEG RIGHT artifact": eeg_right_artifact,
}
for signal_label, signal in signals_by_label.items():
    print(f"{signal_label} sampling rate: {signal.sample_rate} Hz")
print()
for signal_label, signal in signals_by_label.items():
    print(f"{signal_label} shape: {signal.array.shape}")


[32m2025-10-10 11:49:51.260[0m | [1mINFO    [0m | [36mzmax_datasets.datasets.base[0m:[36mread_data_type[0m:[36m36[0m - [1mReading data type: EEG_LEFT[0m
[32m2025-10-10 11:49:51.813[0m | [1mINFO    [0m | [36mzmax_datasets.datasets.base[0m:[36mread_data_type[0m:[36m36[0m - [1mReading data type: EEG_RIGHT[0m
[32m2025-10-10 11:49:52.133[0m | [1mINFO    [0m | [36mzmax_datasets.datasets.base[0m:[36mread_data_type[0m:[36m36[0m - [1mReading data type: EEG_LEFT_artifact[0m
[32m2025-10-10 11:49:52.139[0m | [1mINFO    [0m | [36mzmax_datasets.datasets.base[0m:[36mread_data_type[0m:[36m36[0m - [1mReading data type: EEG_RIGHT_artifact[0m



Signal duration: 30690.0 seconds (511.5 minutes)
EEG LEFT sampling rate: 128.0 Hz
EEG RIGHT sampling rate: 128.0 Hz
EEG LEFT artifact sampling rate: 0.1 Hz
EEG RIGHT artifact sampling rate: 0.1 Hz

EEG LEFT shape: (3928320, 1)
EEG RIGHT shape: (3928320, 1)
EEG LEFT artifact shape: (3069, 1)
EEG RIGHT artifact shape: (3069, 1)


In [30]:
# Distribution of artifact labels per channel: (unique values, their counts)
for artifact_channel in (eeg_left_artifact, eeg_right_artifact):
    flat_labels = artifact_channel.array.squeeze()
    print(np.unique(flat_labels, return_counts=True))

(array([0, 1]), array([ 805, 2264]))
(array([0, 1]), array([2449,  620]))


In [31]:
import ipywidgets as widgets
from IPython.display import display
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Controls: a slider for the window start and a dropdown for the window length
start_slider = widgets.FloatSlider(
    description='Start Time (s):',
    value=0,
    min=0,
    max=duration_seconds - 60,  # Leave room for the window
    step=10,
    layout={'width': '500px'},
    style={'description_width': 'initial'},
)

window_options = [
    ('10 seconds', 10),
    ('30 seconds', 30),
    ('1 minute', 60),
    ('5 minutes', 300),
]
window_size = widgets.Dropdown(
    description='Window Size:',
    options=window_options,
    value=60,
    style={'description_width': 'initial'},
)

# Two stacked panels sharing one time axis, one per EEG channel
subplot_grid = make_subplots(
    rows=2,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.1,
    row_heights=[0.5, 0.5],
    subplot_titles=("EEG LEFT", "EEG RIGHT"),
)
fig = go.FigureWidget(subplot_grid)

# Placeholder traces, filled in by update_plot.
# Order matters: fig.data[0] is EEG LEFT (top), fig.data[1] is EEG RIGHT (bottom).
fig.add_trace(go.Scatter(name="EEG LEFT", line={"color": 'blue'}), row=1, col=1)
fig.add_trace(go.Scatter(name="EEG RIGHT", line={"color": 'red'}), row=2, col=1)

# Overall figure appearance
fig.update_layout(
    title=f"EEG Signal Analysis - Recording {sample_recording!s}",
    height=800,
    showlegend=True,
    template="plotly_white",
)

# Axis labels: amplitude on both panels, time only under the bottom one
for subplot_row in (1, 2):
    fig.update_yaxes(title_text="Amplitude (µV)", row=subplot_row, col=1)
fig.update_xaxes(title_text="Time (seconds)", row=2, col=1)

def _artifact_rectangles(artifact, origin_timestamp, window_start, window_end, xref, y0, y1):
    """Build shaded-rectangle shapes for artifact epochs overlapping a window.

    Args:
        artifact: Artifact channel exposing ``timestamps``, ``array`` and
            ``sample_rate``; an epoch is treated as an artifact when its
            label equals 1.
        origin_timestamp: Timestamp that maps to 0 s on the plot's x axis.
        window_start: Window start in seconds relative to ``origin_timestamp``.
        window_end: Window end in seconds relative to ``origin_timestamp``.
        xref: Plotly x-axis reference of the target subplot ("x" or "x2").
        y0: Lower edge of the rectangle in paper coordinates.
        y1: Upper edge of the rectangle in paper coordinates.

    Returns:
        A list of Plotly shape dicts, one per run of consecutive artifact
        epochs, clipped to ``[window_start, window_end]``.
    """
    # Each label covers one epoch of 1 / sample_rate seconds starting at its timestamp.
    epoch_seconds = 1 / artifact.sample_rate
    # NOTE(review): timestamps are assumed to be in nanoseconds (hence / 1e9) - confirm.
    onsets = (artifact.timestamps - origin_timestamp) / 1e9
    labels = np.ravel(artifact.array)

    # True interval overlap: the epoch starts before the window ends AND ends
    # after the window starts. Testing only the onset would drop an epoch that
    # begins before the window but still covers its first seconds.
    overlaps = (onsets < window_end) & (onsets + epoch_seconds > window_start)
    flagged = np.flatnonzero((labels == 1) & overlaps)
    if flagged.size == 0:
        return []

    # Split the flagged epoch indices wherever they stop being consecutive,
    # so each run becomes a single rectangle.
    run_breaks = np.flatnonzero(np.diff(flagged) != 1) + 1
    rectangles = []
    for run in np.split(flagged, run_breaks):
        rectangles.append(
            dict(
                type="rect",
                xref=xref,
                yref="paper",
                # Clip to the window so partially overlapping epochs do not
                # extend past the plotted signal.
                x0=max(float(onsets[run[0]]), window_start),
                x1=min(float(onsets[run[-1]]) + epoch_seconds, window_end),
                y0=y0,
                y1=y1,
                fillcolor="rgba(255, 0, 0, 0.2)",
                line=dict(width=0),
                layer="below",
            )
        )
    return rectangles


def update_plot(start_time, window_duration):
    """Redraw both EEG panels and their artifact shading for one time window.

    Reads the notebook-level ``eeg_left``, ``eeg_right``, ``eeg_left_artifact``
    and ``eeg_right_artifact`` signals and updates the notebook-level ``fig``
    in place.

    Args:
        start_time: Window start in seconds from the first EEG LEFT sample.
        window_duration: Window length in seconds.
    """
    # Sample indices of the window; both channels are sliced with the
    # EEG LEFT sampling rate and share the EEG LEFT timestamps.
    sample_rate = eeg_left.sample_rate
    first_idx = int(start_time * sample_rate)
    last_idx = int((start_time + window_duration) * sample_rate)

    # x axis in seconds relative to the first EEG LEFT timestamp
    origin_timestamp = eeg_left.timestamps[0]
    time_window = (eeg_left.timestamps[first_idx:last_idx] - origin_timestamp) / 1e9

    left_signal = eeg_left.array.squeeze()[first_idx:last_idx]
    right_signal = eeg_right.array.squeeze()[first_idx:last_idx]

    # Artifact shading, computed in seconds relative to the EEG start rather
    # than in absolute timestamps to avoid float rounding on large values.
    window_end = start_time + window_duration
    shapes = _artifact_rectangles(
        eeg_left_artifact, origin_timestamp, start_time, window_end,
        xref="x", y0=0.55, y1=1.0,  # Top subplot
    )
    shapes += _artifact_rectangles(
        eeg_right_artifact, origin_timestamp, start_time, window_end,
        xref="x2", y0=0.0, y1=0.45,  # Bottom subplot
    )

    # Push everything to the widget in one batch so it re-renders once.
    with fig.batch_update():
        fig.data[0].x = time_window
        fig.data[0].y = left_signal

        fig.data[1].x = time_window
        fig.data[1].y = right_signal

        # Assigning the full list replaces any shapes from the previous window.
        fig.layout.shapes = shapes

# Create the interactive plot
def on_change(change):
    """Redraw the plot when an observed widget reports a change of its ``value``.

    Notifications of any other type, or for any other trait name, are ignored.
    """
    if change['type'] != 'change' or change['name'] != 'value':
        return
    update_plot(start_slider.value, window_size.value)

# Redraw whenever either control changes
for control in (start_slider, window_size):
    control.observe(on_change)

# Show the controls side by side, followed by the figure
controls_row = widgets.HBox([start_slider, window_size])
display(controls_row)
display(fig)

# Draw the first window using the controls' initial values
update_plot(start_slider.value, window_size.value)


HBox(children=(FloatSlider(value=0.0, description='Start Time (s):', layout=Layout(width='500px'), max=30630.0…

FigureWidget({
    'data': [{'line': {'color': 'blue'},
              'name': 'EEG LEFT',
              'type': 'scatter',
              'uid': '46b8da9e-975f-4d39-8af0-125e86f84e90',
              'xaxis': 'x',
              'yaxis': 'y'},
             {'line': {'color': 'red'},
              'name': 'EEG RIGHT',
              'type': 'scatter',
              'uid': 'ab3a8c6d-0737-4b30-9131-91624ab1ddd5',
              'xaxis': 'x2',
              'yaxis': 'y2'}],
    'layout': {'annotations': [{'font': {'size': 16},
                                'showarrow': False,
                                'text': 'EEG LEFT',
                                'x': 0.5,
                                'xanchor': 'center',
                                'xref': 'paper',
                                'y': 1.0,
                                'yanchor': 'bottom',
                                'yref': 'paper'},
                               {'font': {'size': 16},
                             