# Example: predicting sleep state with the trained model

In [4]:
%load_ext autoreload
%autoreload 2

from sleepwellbaby.model import load_model, process_prediction
from sleepwellbaby.preprocess import convert_to_features, dict_to_df, ref24h_correction

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
import joblib
from pathlib import Path

# Directory holding the trained artifacts; relative to the notebook's working directory.
MODEL_DIR = Path("../output/models")

# SECURITY: joblib.load unpickles its input, which can execute arbitrary code.
# Only load these files from a trusted source.
with open(MODEL_DIR / "classifier.bz2", mode="rb") as f:
    model = joblib.load(f)  # estimator used for predict_proba further down
with open(MODEL_DIR / "trained_support_obj.pkl", mode="rb") as f:
    model_support = joblib.load(f)  # support object; preprocess() reads its "Xcol" entry

In [6]:
# Display the loaded estimator to confirm which model and hyperparameters were loaded
model

CalibratedClassifierCV(base_estimator=RandomForestClassifier(class_weight='balanced',
                                                             max_depth=5,
                                                             max_features=0.1,
                                                             min_samples_leaf=0.03,
                                                             min_samples_split=0.333,
                                                             n_estimators=250,
                                                             random_state=42),
                       cv=3, method='isotonic', n_jobs=12)

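## Example payload

The payload holds the patient's birth date and gestation period. For each vital parameter it also holds the raw `values` together with their ref2h and ref24h summary statistics (mean, median, std).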

In [7]:
# NOTE: keep the vital-parameter keys exactly as they are (param_HR, param_RR,
# param_OS). Do not rename them (e.g. RR -> RESP or OS -> SpO2). The features
# are matched by name against the column list stored in trained_support_obj
# (model_support["Xcol"]; see preprocess).
# NOTE(review): the sample values look synthetic (e.g. several OS readings exceed
# 100) — replace with realistic data before using this as a clinical demo.
# Each parameter below carries 192 raw samples plus its ref2h / ref24h summary stats.
payload = {
    "birth_date": "2025-05-08",
    "gestation_period": 210,
    "param_HR": {
        "ref2h_mean": 100,
        "ref2h_median": 100,
        "ref2h_std": 10,
        "ref24h_mean": 100,
        "ref24h_median": 100,
        "ref24h_std": 10,
        "values": [
            88, 108, 92, 91, 89, 76, 101, 107, 98, 101,
            110, 121, 89, 82, 105, 102, 94, 96, 110, 112,
            98, 100, 95, 86, 88, 94, 99, 99, 120, 101,
            104, 89, 86, 96, 105, 116, 114, 93, 91, 98,
            96, 110, 95, 89, 103, 88, 91, 99, 94, 84,
            102, 108, 108, 106, 103, 98, 109, 104, 94, 91,
            106, 94, 119, 106, 102, 91, 91, 100, 85, 103,
            93, 94, 106, 114, 97, 98, 95, 102, 99, 97,
            100, 95, 99, 96, 106, 102, 85, 109, 84, 93,
            119, 113, 112, 98, 89, 100, 94, 112, 93, 86,
            105, 83, 90, 109, 93, 95, 99, 95, 93, 108,
            104, 103, 106, 101, 92, 120, 95, 107, 86, 111,
            120, 95, 98, 79, 113, 102, 100, 90, 101, 101,
            100, 108, 115, 88, 94, 91, 102, 93, 102, 95,
            91, 99, 103, 95, 99, 96, 97, 99, 113, 102,
            119, 95, 99, 110, 91, 95, 105, 85, 96, 107,
            101, 103, 106, 101, 93, 123, 107, 100, 122, 106,
            110, 109, 80, 88, 102, 101, 107, 88, 108, 94,
            92, 105, 101, 99, 98, 102, 91, 118, 92, 117,
            97, 91,
        ],
    },
    "param_RR": {
        "ref2h_mean": 90,
        "ref2h_median": 90,
        "ref2h_std": 10,
        "ref24h_mean": 90,
        "ref24h_median": 90,
        "ref24h_std": 10,
        "values": [
            106, 93, 88, 89, 106, 92, 99, 94, 99, 83,
            115, 99, 93, 92, 104, 97, 117, 113, 87, 74,
            103, 86, 92, 104, 97, 99, 100, 97, 91, 116,
            98, 94, 113, 98, 113, 95, 117, 99, 93, 96,
            87, 96, 101, 102, 126, 122, 96, 103, 91, 114,
            93, 95, 95, 94, 104, 104, 95, 95, 100, 104,
            100, 93, 97, 98, 93, 106, 103, 113, 113, 101,
            92, 93, 100, 108, 110, 102, 89, 120, 102, 107,
            99, 94, 107, 98, 100, 106, 116, 93, 115, 89,
            123, 108, 109, 94, 103, 92, 95, 108, 85, 100,
            89, 89, 85, 109, 95, 101, 99, 92, 102, 108,
            102, 110, 101, 99, 124, 111, 95, 92, 120, 95,
            105, 94, 105, 94, 94, 99, 101, 87, 94, 90,
            96, 85, 85, 108, 102, 100, 104, 96, 111, 108,
            89, 105, 91, 83, 95, 93, 121, 100, 103, 106,
            96, 115, 97, 107, 88, 101, 102, 100, 95, 103,
            95, 92, 102, 101, 101, 87, 119, 96, 107, 112,
            102, 88, 116, 105, 121, 88, 107, 110, 105, 101,
            115, 112, 99, 115, 92, 91, 82, 92, 98, 112,
            98, 109,
        ],
    },
    "param_OS": {
        "ref2h_mean": 100,
        "ref2h_median": 100,
        "ref2h_std": 10,
        "ref24h_mean": 100,
        "ref24h_median": 100,
        "ref24h_std": 10,
        "values": [
            103, 108, 104, 101, 102, 104, 105, 95, 83, 90,
            98, 83, 100, 84, 104, 95, 82, 105, 102, 95,
            113, 91, 111, 94, 97, 99, 111, 96, 103, 104,
            118, 121, 83, 94, 108, 99, 98, 101, 101, 78,
            91, 107, 99, 94, 81, 104, 106, 103, 107, 86,
            86, 101, 102, 101, 92, 103, 108, 96, 84, 116,
            110, 94, 91, 89, 90, 100, 102, 89, 92, 96,
            97, 95, 93, 84, 97, 92, 99, 81, 104, 112,
            100, 104, 118, 94, 98, 109, 108, 107, 95, 81,
            102, 92, 102, 96, 84, 89, 100, 97, 105, 97,
            99, 101, 109, 78, 105, 95, 109, 86, 104, 97,
            82, 83, 99, 105, 94, 97, 85, 91, 80, 78,
            117, 94, 105, 93, 82, 103, 96, 101, 94, 87,
            105, 109, 105, 119, 106, 116, 89, 110, 109, 89,
            90, 103, 109, 94, 103, 94, 95, 114, 105, 103,
            78, 107, 91, 103, 85, 92, 86, 93, 103, 100,
            121, 109, 110, 87, 109, 101, 108, 119, 74, 112,
            94, 105, 108, 104, 104, 92, 104, 88, 100, 71,
            89, 98, 103, 115, 101, 95, 108, 103, 103, 96,
            91, 111,
        ],
    },
}

## Preprocess and predict

In [8]:
# preprocess() consumes the payload as a plain dict; guard that explicitly instead
# of only eyeballing the type, then still display it as before.
assert isinstance(payload, dict), f"payload must be a dict, got {type(payload).__name__}"
type(payload)

dict

In [9]:
def preprocess(data, feature_columns=None):
    """Convert a raw payload dict into the feature frame the model scores.

    Args:
        data (dict): payload containing parameter values, ref2h metrics and
            ref24h metrics (same structure as the example ``payload``).
        feature_columns (list-like, optional): column order expected by the
            model. Defaults to ``model_support["Xcol"]`` from the loaded
            support object.

    Returns:
        pd.DataFrame: containing features, reindexed to ``feature_columns``.
        Any expected column not produced by ``convert_to_features`` is filled
        with NaN by ``reindex``.
    """
    if feature_columns is None:
        feature_columns = model_support["Xcol"]
    return (
        dict_to_df(data)
        # Pass this call's `data` (not the global example `payload`) so the
        # ref24h correction uses the metrics of the input being preprocessed.
        .pipe(ref24h_correction, data)
        .pipe(convert_to_features)
        .reindex(columns=feature_columns)
    )

In [11]:
# Build the feature frame for the example payload
df = preprocess(payload)
# Class probabilities from the calibrated classifier, ordered as model.classes_
pred_proba = model.predict_proba(df)
# Map probabilities to class labels: presumably `pred` is the predicted label;
# `proba_dict` is a {class: probability} mapping (see its output below)
pred, proba_dict = process_prediction(pred_proba, model.classes_)


Feature Extraction: 100%|██████████| 12/12 [00:00<00:00, 1224.73it/s]




In [13]:
# Predicted probability for each class in model.classes_, as built by process_prediction
proba_dict

{'active_sleep': 0.5615941683425135,
 'quiet_sleep': 0.20849384661464576,
 'wake': 0.22991198504284074}

#