In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import polars as pl

from tqdm import tqdm

In [2]:
# Extended signals data: three series (A, B, C) sharing one hourly timeline,
# 2022-01-01 00:00 through 2022-01-02 12:00 inclusive.
timestamps = ([f"2022-01-01 {hour:02d}:00:00" for hour in range(24)] +
              [f"2022-01-02 {hour:02d}:00:00" for hour in range(13)])
SERIES_IDS = ["A", "B", "C"]
N_STEPS = len(timestamps)  # 37 hourly readings per series; derived so the columns always agree

df_signals = pl.DataFrame({
    "series_id": [sid for sid in SERIES_IDS for _ in range(N_STEPS)],
    "timestamp": timestamps * len(SERIES_IDS),
    # 1-based step index: step k corresponds to timestamps[k - 1]
    "step": list(range(1, N_STEPS + 1)) * len(SERIES_IDS),
})

# Extended events data. Each event's `step` must point at the df_signals row with
# the same timestamp (asserted below).
# NOTE(review): A's second episode has its onset (2022-01-01 22:00) labelled night 1
# but its wakeup (2022-01-02 08:00) labelled night 2, and C has an onset with no
# wakeup — presumably deliberate edge cases for the mismatch check; confirm.
df_events = pl.DataFrame({
    "series_id": ["A", "A", "A", "A", "B", "B", "C"],
    "night": [1, 1, 1, 2, 1, 1, 1],
    "event": ["onset", "wakeup", "onset", "wakeup", "onset", "wakeup", "onset"],
    "timestamp": ["2022-01-01 02:00:00", "2022-01-01 14:00:00", "2022-01-01 22:00:00",
                  "2022-01-02 08:00:00", "2022-01-01 03:00:00", "2022-01-01 15:00:00",
                  "2022-01-01 04:00:00"],
    "step": [3, 15, 23, 33, 4, 16, 5]
})

# Invariant: event step and event timestamp describe the same signal row.
assert all(
    timestamps[step - 1] == ts
    for step, ts in zip(df_events["step"].to_list(), df_events["timestamp"].to_list())
), "event step does not match its timestamp"


In [3]:
# Persist the fixtures as CSV. Writing the signals fixture is opt-in; with the
# default False only events.csv is written.
WRITE_SIGNALS_CSV = False
if WRITE_SIGNALS_CSV:
    df_signals.write_csv("signals.csv")
df_events.write_csv("events.csv")

In [4]:
# Detect and remove (series_id, night) groups whose onset and wakeup counts differ.
# Set DROP_MISMATCHED_NIGHTS = True to apply; with the default False this cell is a
# no-op. With the fixture above, enabling it would keep only series B (A's nights
# 1 and 2 and C's night 1 are all unbalanced).
DROP_MISMATCHED_NIGHTS = False
if DROP_MISMATCHED_NIGHTS:
    mismatches = (
        df_events
        .group_by(['series_id', 'night'])
        .agg(
            (pl.col('event') == 'onset').sum().alias('onset'),
            (pl.col('event') == 'wakeup').sum().alias('wakeup'),
        )
        .filter(pl.col('onset') != pl.col('wakeup'))
        .select('series_id', 'night')
    )
    # Anti join keeps only events whose (series_id, night) is NOT in `mismatches`.
    df_events = df_events.join(mismatches, on=['series_id', 'night'], how='anti')


In [5]:
# Show only the timing columns of the events fixture.
df_events.select('timestamp', 'step')

timestamp,step
str,i64
"""2022-01-01 02:…",3
"""2022-01-01 14:…",15
"""2022-01-01 22:…",23
"""2022-01-02 08:…",33
"""2022-01-01 03:…",4
"""2022-01-01 15:…",16
"""2022-01-01 04:…",5


In [6]:
# Display the whole events fixture (7 rows) via the cell's rich repr.
df_events

series_id,night,event,timestamp,step
str,i64,str,str,i64
"""A""",1,"""onset""","""2022-01-01 02:…",3
"""A""",1,"""wakeup""","""2022-01-01 14:…",15
"""A""",1,"""onset""","""2022-01-01 22:…",23
"""A""",2,"""wakeup""","""2022-01-02 08:…",33
"""B""",1,"""onset""","""2022-01-01 03:…",4
"""B""",1,"""wakeup""","""2022-01-01 15:…",16
"""C""",1,"""onset""","""2022-01-01 04:…",5


In [7]:
# Wide view: one row per series with its onset/wakeup timestamps side by side,
# then label each signal row as inside [onset, wakeup] (inclusive). Timestamps are
# still ISO strings at this point, so the comparison is lexicographic.
# FIXME: aggregate_function='first' keeps only the FIRST onset/wakeup per series.
#   - Series A has two episodes (steps 3->15 and 23->33). The second is silently
#     dropped, so `result.state` is False for A's steps 23-33.
#   - C has no wakeup (null), so `result.state` is never True for C's rows.
#   The join_asof cell further down handles multiple episodes per series.
# NOTE(review): `columns=` was renamed to `on=` in polars 1.0 — keep `columns=` only
# if this notebook targets polars < 1.0.
events_wide = df_events.pivot(index=['series_id'], columns='event', values='timestamp', aggregate_function='first')
result = df_signals.join(events_wide, on='series_id').with_columns(
    state=pl.col('timestamp').is_between('onset', 'wakeup')
)
events_wide

shape: (3, 3)
┌───────────┬─────────────────────┬─────────────────────┐
│ series_id ┆ onset               ┆ wakeup              │
│ ---       ┆ ---                 ┆ ---                 │
│ str       ┆ str                 ┆ str                 │
╞═══════════╪═════════════════════╪═════════════════════╡
│ A         ┆ 2022-01-01 02:00:00 ┆ 2022-01-01 14:00:00 │
│ B         ┆ 2022-01-01 03:00:00 ┆ 2022-01-01 15:00:00 │
│ C         ┆ 2022-01-01 04:00:00 ┆ null                │
└───────────┴─────────────────────┴─────────────────────┘


In [8]:
TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"


def parse_timestamps(df: pl.DataFrame) -> pl.DataFrame:
    """Return ``df`` with its string ``timestamp`` column parsed to ``pl.Datetime``.

    Idempotent: if ``timestamp`` is no longer a string column (e.g. this cell has
    already been run), ``df`` is returned unchanged instead of failing in
    ``str.strptime`` on an already-parsed column.
    """
    if df.schema["timestamp"] != pl.Utf8:
        return df
    return df.with_columns(
        timestamp=pl.col("timestamp").str.strptime(pl.Datetime, TIMESTAMP_FORMAT)
    )


df_signals = parse_timestamps(df_signals)
df_events = parse_timestamps(df_events)

In [13]:
# Label every signal row with a binary sleep state:
#   state = 1 if the most recent event at or before this step (same series) is an
#             onset,
#   state = 0 if it is a wakeup or no event has happened yet.
# Rows that exactly match an event also carry its name in `event` (null elsewhere).
# join_asof requires both sides sorted by `step` within each `series_id`, so sort
# explicitly rather than relying on construction order.
df_states = (
    df_signals
    .sort(['series_id', 'step'])
    .join_asof(
        df_events.drop('timestamp').sort(['series_id', 'step']),
        on='step',
        by='series_id',
        strategy='backward',
    )
    .with_columns(
        state=pl.when(pl.col('event') == 'onset').then(1).otherwise(0),
    )
    # The asof-carried event/night apply to a whole span; drop them here.
    .select(pl.all().exclude('event', 'night'))
    # Re-attach the event name only on rows that exactly match an event.
    .join(
        df_events.drop('night'),
        on=['series_id', 'timestamp', 'step'],
        how='left',
    )
)
# Keep the frame itself (write_csv returns None) and persist it separately.
df_states.write_csv('test.csv')
df_states.head()