# Example
This notebook illustrates how the sleepwellbaby package can be used to get sleep state predictions using mock data.

In [25]:
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd

from sleepwellbaby.model import load_model
from sleepwellbaby.data import (
    compute_reference_values,
    generate_mock_signalbase_data,
)
from sleepwellbaby.utils import get_swb_predictions, shap
from sleepwellbaby.preprocess import pipeline


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Prepare data
Generate some mock data (same format as SignalBase with 1Hz). If you want to use 0.4Hz set `freq=2s500ms`.

In [6]:
df = generate_mock_signalbase_data(freq='S').sort_values(by='datetime').set_index('datetime')

Compute reference values using rolling windows

In [7]:
df = compute_reference_values(df, freq=1)

## Get predictions 
We make a prediction every one minute in the below example.

*Note*: we start 8 minutes later than our dataframe otherwise we get a ValueError in `get_swb_predictions`.

In [8]:
# Get indices of times at which you want to get an SWB value (we only compute it at whole minutes)
t_range_swb = pd.date_range(
    start=(df.index.min().round(freq='min') + pd.Timedelta(minutes=8)),
    end=df.index.max().round(freq='min'),
    freq='1min'
)

Return predictions (takes a couple of minutes for a single day with one prediction every minute)

In [14]:
payloads = shap(df, t_range_swb, birth_date="2000-01-01", gestation_period=210, freq='S')

Calculating SWB: 100%|██████████| 2873/2873 [00:08<00:00, 330.82it/s]


Drop rows without a prediction

In [30]:
payloads = np.array(payloads)
mask = [x is not None for x in payloads]
payloads[mask]

model, model_support_dict = load_model()
dfs = [pipeline(p, model_support_dict) for p in payloads[mask][:100]]
df = pd.concat(dfs)


In [31]:
df

Unnamed: 0,HR__0_120__median,HR__0_120__mean,HR__0_120__variance,HR__0_120__maximum,HR__0_120__minimum,"HR__0_120__linear_trend__attr_""pvalue""","HR__0_120__linear_trend__attr_""rvalue""","HR__0_120__linear_trend__attr_""intercept""","HR__0_120__linear_trend__attr_""slope""",HR__0_240__median,...,"RR__0_480__linear_trend__attr_""slope""",RR__0_60__median,RR__0_60__mean,RR__0_60__variance,RR__0_60__maximum,RR__0_60__minimum,"RR__0_60__linear_trend__attr_""pvalue""","RR__0_60__linear_trend__attr_""rvalue""","RR__0_60__linear_trend__attr_""intercept""","RR__0_60__linear_trend__attr_""slope"""
0,0.120394,0.235572,0.749809,2.205438,-1.376022,0.664039,-0.064323,0.330054,-0.004021,-0.053821,...,0.000589,0.011137,0.036788,0.893154,1.657312,-2.280842,0.258116,0.240259,-0.340435,0.032802
0,0.173177,0.137855,0.780982,1.898651,-1.543657,0.152735,-0.209618,0.452095,-0.013372,0.047710,...,-0.001536,-0.448352,-0.372965,0.888254,1.430271,-2.300612,0.619434,0.106789,-0.540170,0.014540
0,0.112393,0.025206,0.897659,1.897466,-1.692563,0.852549,-0.027549,0.069482,-0.001884,0.116683,...,-0.003230,-0.666427,-0.488510,1.023420,1.461642,-1.965262,0.012059,-0.503875,0.358335,-0.073639
0,-0.169907,0.064088,0.901172,2.077222,-1.695403,0.509057,-0.097653,0.221342,-0.006692,0.120252,...,-0.002429,0.053261,-0.100262,1.050770,1.152903,-2.652325,0.117784,-0.327883,0.458115,-0.048554
0,0.137341,0.124580,0.810317,2.077134,-1.933647,0.473429,0.105979,-0.037250,0.006886,0.111383,...,0.000248,0.527417,0.342590,0.572308,1.822344,-1.108488,0.755020,-0.067208,0.427058,-0.007345
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,0.366417,0.126996,1.151989,1.646811,-2.715009,0.772240,-0.042889,0.205084,-0.003323,0.454402,...,0.000808,-0.016547,-0.014579,1.051328,1.460169,-1.952650,0.363587,0.194045,-0.345120,0.028743
0,-0.189784,-0.097743,1.009706,1.640034,-1.937780,0.867195,0.024786,-0.139992,0.001798,0.277223,...,0.001314,0.413318,0.157265,0.884381,1.880411,-2.033231,0.391416,-0.183239,0.443546,-0.024894
0,-0.216787,-0.076386,0.776068,1.737867,-1.590518,0.183941,0.195078,-0.367906,0.012405,0.023030,...,0.001639,0.178168,0.210424,0.601548,1.836385,-1.520900,0.789603,-0.057490,0.284500,-0.006441
0,-0.228110,-0.175331,0.735894,1.780446,-2.146330,0.199590,-0.188448,0.098897,-0.011669,-0.222421,...,0.000509,-0.312645,-0.250993,0.835339,1.545283,-2.168287,0.129361,-0.318456,0.232550,-0.042047


In [None]:
df.tail(10)

#