# Data Preparation

This notebook prepares the data for machine learning:

1. Annotate the dataset (sleep = 0, awake = 1)
2. Signal preparation (scaling, missing data, outliers, smoothing)
3. Subset generation (light, medium, heavy)

## Imports


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import polars as pl
import gc
import joblib

## Data Preparation

**Convert timestamp to datetime**


In [None]:
timestamp = [
    pl.col("timestamp").str.to_datetime("%Y-%m-%dT%H:%M:%S%z")
]

**Min-max normalization**
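For reference, min-max scaling maps each signal column $x$ (here `anglez` and `enmo`) onto $[0, 1]$:

$$
x' = \frac{x - \min(x)}{\max(x) - \min(x)}
$$

As written, the minimum and maximum are taken over the whole column, so all series share one scale rather than each series being normalized on its own.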

In [None]:
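# Min-max scaling to [0, 1], using the min and max of the whole column (all series).
# Native expressions avoid calling a Python function through map_batches.
def min_max_normalization(col: str) -> pl.Expr:
    return (pl.col(col) - pl.col(col).min()) / (pl.col(col).max() - pl.col(col).min())

normalization = [
    min_max_normalization("anglez").cast(pl.Float32),
    min_max_normalization("enmo").cast(pl.Float32),
    pl.col("step").cast(pl.UInt32),
]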


**Data import**

In [None]:
df_signals = pl.scan_parquet("data/train_series.parquet").with_columns(
    timestamp + normalization
).collect(streaming=True)

In [None]:
df_events = pl.scan_csv("data/train_events.csv").with_columns(
    timestamp + [pl.col("step").cast(pl.UInt32)]
).drop_nulls().collect()

**Data cleaning**

In [None]:
"""
# Removing null events and nights with mismatched counts from series_events
mismatches = df_events.group_by(['series_id', 'night']).agg(
    (pl.col('event') == 'onset').sum().alias('onset'),
    (pl.col('event') == 'wakeup').sum().alias('wakeup')
    ).sort(by=['series_id', 'night']).filter(pl.col('onset') != pl.col('wakeup')).select(pl.all().exclude('onset', 'wakeup'))
print(f"The mismatch Onset and Wakeup are : \n {mismatches}")
df_events = df_events.join(mismatches, on=['series_id', 'night'], how='anti')
"""

In [None]:
# Count the onset and wakeup events of each series_id
df_events_problem = df_events.group_by(['series_id']).agg(
    (pl.col('event') == 'onset').sum().alias('onset'),
    (pl.col('event') == 'wakeup').sum().alias('wakeup')
    ).sort(by=['series_id'])

In [None]:
# Display the series_id values whose onset and wakeup counts differ
mismatches = df_events_problem.filter(pl.col('onset') != pl.col('wakeup')).select('series_id')
print(f"Series with mismatched onset/wakeup counts:\n{mismatches}")

**Merge data**

In [None]:
df = df_signals.join(df_events, on=['series_id', 'timestamp', 'step'], how='left')

In [None]:
df

**Sleep/Awake annotation**

In [None]:
# For each series_id, state is False (0) while the subject sleeps and True (1) while awake:
# an 'onset' event starts a sleep period and a 'wakeup' event ends it.
df = (
    df.sort(["series_id", "timestamp"])
    .with_columns(
        pl.when(pl.col("event") == "onset").then(pl.lit(False))
        .when(pl.col("event") == "wakeup").then(pl.lit(True))
        .otherwise(pl.lit(None, dtype=pl.Boolean))
        .alias("state")
    )
    # Carry the last known state forward within each series;
    # rows before the first event of a series are assumed awake
    .with_columns(pl.col("state").forward_fill().over("series_id").fill_null(True))
)

# Store one Parquet file per series (write_parquet writes real Parquet; joblib.dump would pickle)
for series_id in df["series_id"].unique().to_list():
    df.filter(pl.col("series_id") == series_id).write_parquet(
        f"data/train_series_{series_id}.parquet"
    )



**Missing Data**

When an annotation is missing, remove the signal 6 hours after waking up and 6 hours before falling asleep.

In [None]:
# For each Parquet file (one time series):
# - sort the rows by timestamp
# - treat any period of 20 hours without sleep as missing annotations
# - remove 16 hours from such a period

**Smoothing**

In [None]:
# Your code here ...

**Stratified Export**

In [None]:
# Your code here ...