In [13]:
from zmax_datasets.datasets.usleep import USleepDataset
from pathlib import Path
import numpy as np

DATASETS_DIR = Path("/project/4180000.46/sleep_datasets/processed")
DATASET_NAME = "mesa"

dataset = USleepDataset(data_dir=DATASETS_DIR / DATASET_NAME)
print(dataset.n_recordings)
print(dataset.recording_ids)

2056
['mesa-sleep-0001', 'mesa-sleep-0002', 'mesa-sleep-0006', 'mesa-sleep-0010', 'mesa-sleep-0012', 'mesa-sleep-0014', 'mesa-sleep-0016', 'mesa-sleep-0021', 'mesa-sleep-0027', 'mesa-sleep-0028', 'mesa-sleep-0033', 'mesa-sleep-0035', 'mesa-sleep-0036', 'mesa-sleep-0038', 'mesa-sleep-0046', 'mesa-sleep-0048', 'mesa-sleep-0050', 'mesa-sleep-0052', 'mesa-sleep-0054', 'mesa-sleep-0056', 'mesa-sleep-0064', 'mesa-sleep-0070', 'mesa-sleep-0074', 'mesa-sleep-0077', 'mesa-sleep-0079', 'mesa-sleep-0081', 'mesa-sleep-0084', 'mesa-sleep-0085', 'mesa-sleep-0087', 'mesa-sleep-0099', 'mesa-sleep-0101', 'mesa-sleep-0105', 'mesa-sleep-0149', 'mesa-sleep-0107', 'mesa-sleep-0109', 'mesa-sleep-0110', 'mesa-sleep-0111', 'mesa-sleep-0113', 'mesa-sleep-0118', 'mesa-sleep-0120', 'mesa-sleep-0121', 'mesa-sleep-0125', 'mesa-sleep-0132', 'mesa-sleep-0133', 'mesa-sleep-0138', 'mesa-sleep-0140', 'mesa-sleep-0144', 'mesa-sleep-0152', 'mesa-sleep-0155', 'mesa-sleep-0159', 'mesa-sleep-0167', 'mesa-sleep-0169', 'mesa-

In [14]:
sample_recording = dataset.get_recording("mesa-sleep-0001")
print(sample_recording.data_types)

{'PPG_artifact': DataType(channel='PPG_artifact', sampling_rate=0.1), 'PPG_filtered': DataType(channel='PPG_filtered', sampling_rate=128.0), 'PPG_ibi': DataType(channel='PPG_ibi', sampling_rate=128.0), 'PPG_peaks': DataType(channel='PPG_peaks', sampling_rate=128.0), 'PPG_quality': DataType(channel='PPG_quality', sampling_rate=128.0), 'PPG_rate': DataType(channel='PPG_rate', sampling_rate=128.0)}


In [15]:
import numpy as np

WITH_QUALITY = True

# Load all PPG signals
ppg_filtered = sample_recording.read_data_type("PPG_filtered")
ppg_peaks = sample_recording.read_data_type("PPG_peaks")
ppg_rate = sample_recording.read_data_type("PPG_rate")
ppg_ibi = sample_recording.read_data_type("PPG_ibi")

if WITH_QUALITY:
    ppg_quality = sample_recording.read_data_type("PPG_quality")
    ppg_artifact = sample_recording.read_data_type("PPG_artifact")

# Create time axis in seconds for the signals
duration_seconds = len(ppg_rate.array.squeeze()) / ppg_rate.sample_rate
time = np.linspace(0, duration_seconds, len(ppg_rate.array.squeeze()))
print(f"Signal duration: {duration_seconds:.1f} seconds ({duration_seconds/60:.1f} minutes)")

[32m2025-10-14 11:14:17.129[0m | [1mINFO    [0m | [36mzmax_datasets.datasets.base[0m:[36mread_data_type[0m:[36m36[0m - [1mReading data type: PPG_filtered[0m
[32m2025-10-14 11:14:17.451[0m | [1mINFO    [0m | [36mzmax_datasets.datasets.base[0m:[36mread_data_type[0m:[36m36[0m - [1mReading data type: PPG_peaks[0m
[32m2025-10-14 11:14:17.700[0m | [1mINFO    [0m | [36mzmax_datasets.datasets.base[0m:[36mread_data_type[0m:[36m36[0m - [1mReading data type: PPG_rate[0m
[32m2025-10-14 11:14:17.998[0m | [1mINFO    [0m | [36mzmax_datasets.datasets.base[0m:[36mread_data_type[0m:[36m36[0m - [1mReading data type: PPG_ibi[0m
[32m2025-10-14 11:14:18.291[0m | [1mINFO    [0m | [36mzmax_datasets.datasets.base[0m:[36mread_data_type[0m:[36m36[0m - [1mReading data type: PPG_quality[0m
[32m2025-10-14 11:14:18.437[0m | [1mINFO    [0m | [36mzmax_datasets.datasets.base[0m:[36mread_data_type[0m:[36m36[0m - [1mReading data type: PPG_artifact[0m

Signal duration: 43190.0 seconds (719.8 minutes)


In [16]:
import ipywidgets as widgets
from IPython.display import display
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Create widgets for controlling the window
start_slider = widgets.FloatSlider(
    value=0,
    min=0,
    max=duration_seconds - 60,  # Leave room for the window
    step=10,
    description='Start Time (s):',
    style={'description_width': 'initial'},
    layout={'width': '500px'}
)

window_size = widgets.Dropdown(
    options=[('10 seconds', 10), ('30 seconds', 30), ('1 minute', 60), ('5 minutes', 300)],
    value=60,
    description='Window Size:',
    style={'description_width': 'initial'}
)

# Configure subplot layout based on WITH_QUALITY
if WITH_QUALITY:
    n_rows = 4
    row_heights = [0.4, 0.2, 0.2, 0.2]
    subplot_titles = ("PPG Signal and Peaks", "Signal Quality", "Heart Rate", "Inter-Beat Intervals")
    hr_row = 3
    ibi_row = 4
else:
    n_rows = 3
    row_heights = [0.5, 0.25, 0.25]
    subplot_titles = ("PPG Signal and Peaks", "Heart Rate", "Inter-Beat Intervals")
    hr_row = 2
    ibi_row = 3

# Create initial figure
fig = go.FigureWidget(make_subplots(
    rows=n_rows, cols=1,
    shared_xaxes=True,
    vertical_spacing=0.05,
    row_heights=row_heights,
    subplot_titles=subplot_titles
))

# Initialize with empty traces
fig.add_trace(go.Scatter(name="PPG Filtered", line=dict(color='blue')), row=1, col=1)
fig.add_trace(go.Scatter(name="Peaks", mode='markers', marker=dict(color='red', size=8, symbol='circle')), row=1, col=1)

if WITH_QUALITY:
    fig.add_trace(go.Scatter(name="Signal Quality", line=dict(color='purple'), fill='tozeroy'), row=2, col=1)

fig.add_trace(go.Scatter(name="Heart Rate", line=dict(color='orange')), row=hr_row, col=1)
fig.add_trace(go.Scatter(name="Mean HR", line=dict(color='red', dash='dash')), row=hr_row, col=1)
# Add IBI trace and threshold lines
fig.add_trace(go.Scatter(name="IBI", line=dict(color='green')), row=ibi_row, col=1)
fig.add_trace(go.Scatter(name="Min Threshold", line=dict(color='red', dash='dash')), row=ibi_row, col=1)
fig.add_trace(go.Scatter(name="Max Threshold", line=dict(color='red', dash='dash')), row=ibi_row, col=1)

# Update layout
fig.update_layout(
    height=900,
    title=f"PPG Signal Analysis - Recording {str(sample_recording)}",
    showlegend=True,
    template="plotly_white"
)

# Update axes labels
fig.update_yaxes(title_text="Amplitude", row=1, col=1)
if WITH_QUALITY:
    fig.update_yaxes(title_text="Quality Score", range=[0, 1], row=2, col=1)
fig.update_yaxes(title_text="Heart Rate (BPM)", range=[0, 200], row=hr_row, col=1)
fig.update_yaxes(title_text="IBI (ms)", range=[0, 2500], row=ibi_row, col=1)  # Set range slightly above max threshold
fig.update_xaxes(title_text="Time (seconds)", row=n_rows, col=1)

def update_plot(start_time, window_duration):
    # Calculate indices for the window
    start_idx = int(start_time * ppg_rate.sample_rate)
    end_idx = int((start_time + window_duration) * ppg_rate.sample_rate)
    
    # Create time array for the window
    time_window = np.linspace(start_time, start_time + window_duration, end_idx - start_idx)
    
    # Get signal segments
    filtered_signal = ppg_filtered.array.squeeze()[start_idx:end_idx]
    peaks_signal = ppg_peaks.array.squeeze()[start_idx:end_idx]
    if WITH_QUALITY:
        quality_signal = ppg_quality.array.squeeze()[start_idx:end_idx]
    rate_signal = ppg_rate.array.squeeze()[start_idx:end_idx]
    ibi_signal = ppg_ibi.array.squeeze()[start_idx:end_idx]
    
    # Determine trace indices based on WITH_QUALITY
    if WITH_QUALITY:
        quality_trace_idx = 2
        hr_trace_idx = 3
        mean_hr_trace_idx = 4
        ibi_trace_idx = 5
        min_threshold_trace_idx = 6
        max_threshold_trace_idx = 7
    else:
        hr_trace_idx = 2
        mean_hr_trace_idx = 3
        ibi_trace_idx = 4
        min_threshold_trace_idx = 5
        max_threshold_trace_idx = 6
    
    # Update traces with new data
    with fig.batch_update():
        # Update PPG signal
        fig.data[0].x = time_window
        fig.data[0].y = filtered_signal
        
        # Update peaks
        peaks_idx = np.where(peaks_signal == 1)[0]
        if len(peaks_idx) > 0:
            fig.data[1].x = time_window[peaks_idx]
            fig.data[1].y = filtered_signal[peaks_idx]
        else:
            fig.data[1].x = []
            fig.data[1].y = []
        
        # Clear existing shapes (artifact markers)
        fig.layout.shapes = []
        
        # Add artifact regions as shaded rectangles (only if WITH_QUALITY is True)
        if WITH_QUALITY:
            # For artifacts, we need to find which artifact labels overlap with the time window
            # Convert start_time to absolute timestamp
            start_timestamp = ppg_filtered.timestamps[0] + start_time * 1e9
            end_timestamp = start_timestamp + window_duration * 1e9
            
            # Find artifact indices that overlap with this time window
            artifact_mask = (ppg_artifact.timestamps >= start_timestamp) & (ppg_artifact.timestamps < end_timestamp)
            artifact_values = ppg_artifact.array.squeeze()[artifact_mask]
            
            # Get the actual timestamps for the artifacts in this window (in seconds relative to PPG start)
            artifact_time = (ppg_artifact.timestamps[artifact_mask] - ppg_filtered.timestamps[0]) / 1e9
            
            # Mark artifacts
            shapes = []
            artifact_indices = np.where(artifact_values == 1)[0]
            if len(artifact_indices) > 0:
                # Group consecutive artifact indices
                artifact_regions = []
                start_idx = artifact_indices[0]
                for i in range(1, len(artifact_indices)):
                    if artifact_indices[i] != artifact_indices[i-1] + 1:
                        artifact_regions.append((start_idx, artifact_indices[i-1]))
                        start_idx = artifact_indices[i]
                artifact_regions.append((start_idx, artifact_indices[-1]))
                
                for start_idx, end_idx in artifact_regions:
                    shapes.append(
                        dict(
                            type="rect",
                            xref="x",
                            yref="paper",
                            x0=artifact_time[start_idx],
                            x1=artifact_time[end_idx] + (1 / ppg_artifact.sample_rate),  # Extend to end of 10-sec window
                            y0=0.6,  # Top subplot (PPG signal)
                            y1=1.0,
                            fillcolor="rgba(255, 0, 0, 0.2)",
                            line=dict(width=0),
                            layer="below"
                        )
                    )
            
            fig.layout.shapes = shapes
        
        # Update signal quality (only if WITH_QUALITY is True)
        if WITH_QUALITY:
            fig.data[quality_trace_idx].x = time_window
            fig.data[quality_trace_idx].y = quality_signal
        
        # Update heart rate
        fig.data[hr_trace_idx].x = time_window
        fig.data[hr_trace_idx].y = rate_signal
        
        # Update mean heart rate
        mean_hr = np.nanmean(rate_signal)
        fig.data[mean_hr_trace_idx].x = [time_window[0], time_window[-1]]
        fig.data[mean_hr_trace_idx].y = [mean_hr, mean_hr]
        fig.data[mean_hr_trace_idx].name = f"Mean HR ({mean_hr:.1f} BPM)"
        
        # Update IBI plot
        # Only plot non-zero IBI values
        valid_ibi_mask = ibi_signal > 0
        valid_times = time_window[valid_ibi_mask]
        valid_ibis = ibi_signal[valid_ibi_mask]
        
        fig.data[ibi_trace_idx].x = valid_times
        fig.data[ibi_trace_idx].y = valid_ibis
        
        # Update threshold lines
        fig.data[min_threshold_trace_idx].x = [time_window[0], time_window[-1]]  # Min threshold
        fig.data[min_threshold_trace_idx].y = [300, 300]  # 300ms threshold
        
        fig.data[max_threshold_trace_idx].x = [time_window[0], time_window[-1]]  # Max threshold
        fig.data[max_threshold_trace_idx].y = [2000, 2000]  # 2000ms threshold

# Create the interactive plot
def on_change(change):
    if change['type'] == 'change' and change['name'] == 'value':
        update_plot(start_slider.value, window_size.value)

# Link the widgets to the update function
start_slider.observe(on_change)
window_size.observe(on_change)

# Display widgets and initial plot
display(widgets.HBox([start_slider, window_size]))
display(fig)

# Initialize the plot
update_plot(start_slider.value, window_size.value)


HBox(children=(FloatSlider(value=0.0, description='Start Time (s):', layout=Layout(width='500px'), max=43130.0…

FigureWidget({
    'data': [{'line': {'color': 'blue'},
              'name': 'PPG Filtered',
              'type': 'scatter',
              'uid': 'd16e30e5-0b41-470c-b8b8-bc80e2877ec3',
              'xaxis': 'x',
              'yaxis': 'y'},
             {'marker': {'color': 'red', 'size': 8, 'symbol': 'circle'},
              'mode': 'markers',
              'name': 'Peaks',
              'type': 'scatter',
              'uid': '2172b313-0825-4f08-b6d4-72908be1cc3b',
              'xaxis': 'x',
              'yaxis': 'y'},
             {'fill': 'tozeroy',
              'line': {'color': 'purple'},
              'name': 'Signal Quality',
              'type': 'scatter',
              'uid': '2b3ae48d-e128-4b41-b51f-a997a851bccd',
              'xaxis': 'x2',
              'yaxis': 'y2'},
             {'line': {'color': 'orange'},
              'name': 'Heart Rate',
              'type': 'scatter',
              'uid': '46b34ad9-9992-4d9b-9c18-9bf87a08ddfd',
              'xaxis': 'x