In [1]:
import os
# Switch the working directory to the parent folder; later cells open files via
# relative paths such as 'run/settings.py', presumably resolved from there.
# NOTE(review): the path is relative to the *current* directory, so re-running
# this cell without restarting the kernel moves up one more level each time.
os.chdir("../") # path to vital folder
# Not read by any code visible in this notebook; presumably consumed by the
# scripts exec'd in later cells — TODO confirm.
overwrite = False


# ------------------------------------------------------------------------------------------------
# helper functions
# ------------------------------------------------------------------------------------------------
import pandas as pd
import numpy as np
def build_level_map(df_train: pd.DataFrame, col: str, sort_levels=True):
    """Map every distinct non-null value of ``df_train[col]`` to an integer index.

    Parameters
    ----------
    df_train : DataFrame holding the column to index.
    col : name of the column whose levels are enumerated.
    sort_levels : when true the levels are sorted before being numbered;
        otherwise they keep their order of first appearance.

    Returns
    -------
    dict
        ``{level: index}`` with indices running from 0 to ``n_levels - 1``.
    """
    observed = list(df_train[col].dropna().unique())
    if sort_levels:
        observed.sort()
    return dict(zip(observed, range(len(observed))))

def encode_column(df: pd.DataFrame, col: str, level2idx: dict) -> np.ndarray:
    """Translate ``df[col]`` into integer codes using ``level2idx``.

    Values that are missing, or that have no entry in ``level2idx``, are
    encoded as ``-1``.

    Returns
    -------
    np.ndarray
        One integer code per row of ``df``.
    """
    # Look each value up in the mapping; anything without an entry becomes NaN.
    codes = df[col].map(level2idx)
    # Unseen levels and missing values share the sentinel -1.
    codes = codes.fillna(-1)
    return codes.astype(int).to_numpy()

def _get_arrays(df, config_dict, level_maps):
    """Extract the time-series matrix and the encoded attribute matrix from ``df``.

    The time-series columns are the ones named ``"1"`` .. ``str(seq_length)``,
    with ``seq_length`` read from ``config_dict``. Each column listed in
    ``config_dict['txt2ts_y_cols']`` is integer-encoded with its entry in
    ``level_maps``.

    Returns
    -------
    tuple
        ``(ts, attrs_idx)``: the raw values of the time-step columns, and the
        per-attribute code vectors stacked along axis 1.
    """
    step_cols = [str(step) for step in range(1, config_dict['seq_length'] + 1)]
    ts = df[step_cols].values
    encoded = [
        encode_column(df, name, level_maps[name])
        for name in config_dict['txt2ts_y_cols']
    ]
    return ts, np.stack(encoded, axis=1)

def prepare_tedit_data(df_train, df_test, df_left, config_dict, save_path):
    """Build, normalise and save the arrays for the three data splits.

    Attribute levels and the normalisation statistics are derived from
    ``df_train`` only and then applied to all three frames.

    Note the naming of the outputs: ``df_test`` is written out as the
    ``valid_*`` arrays and ``df_left`` as the ``test_*`` arrays.

    Parameters
    ----------
    df_train, df_test, df_left : DataFrames containing the time-step columns
        ``"1"`` .. ``str(seq_length)`` and the attribute columns.
    config_dict : dict providing ``'seq_length'`` and ``'txt2ts_y_cols'``.
    save_path : directory the ``.npy`` files are written to (created if
        missing).

    Returns
    -------
    tuple
        ``(train_ts, train_attrs_idx, valid_ts, valid_attrs_idx, test_ts,
        test_attrs_idx, train_mean_std)``. ``train_mean_std`` is a dict with
        the scalar ``'mean'`` and ``'std'`` that were actually used for
        normalisation.
    """
    level_maps = {c: build_level_map(df_train, c) for c in config_dict["txt2ts_y_cols"]}

    # Scalar statistics over every training time-step value, ignoring NaNs.
    ts_cols = [str(i) for i in range(1, config_dict['seq_length'] + 1)]
    train_values = df_train[ts_cols].values
    mean = np.nanmean(train_values)
    std = np.nanstd(train_values, ddof=0)
    # Guard against division by zero (also catches a NaN std). The guarded
    # value is what gets stored, so the saved statistics always match the
    # scaling applied to the arrays and can be used to undo it.
    std = std if std > 0 else 1e-8
    train_mean_std = {'mean': mean, 'std': std}

    train_ts, train_attrs_idx = _get_arrays(df_train, config_dict, level_maps)
    valid_ts, valid_attrs_idx = _get_arrays(df_test, config_dict, level_maps)
    test_ts, test_attrs_idx = _get_arrays(df_left, config_dict, level_maps)

    # Normalise with the training statistics. Each expression creates a new
    # array, so the values held by the input DataFrames are left untouched.
    train_ts = (train_ts - mean) / std
    valid_ts = (valid_ts - mean) / std
    test_ts = (test_ts - mean) / std

    # ------------------------------------------------------------------ #
    # save everything under save_path
    # ------------------------------------------------------------------ #
    os.makedirs(save_path, exist_ok=True)
    np.save(os.path.join(save_path, "train_ts.npy"),         train_ts)
    np.save(os.path.join(save_path, "train_attrs_idx.npy"),  train_attrs_idx)
    np.save(os.path.join(save_path, "valid_ts.npy"),         valid_ts)
    np.save(os.path.join(save_path, "valid_attrs_idx.npy"),  valid_attrs_idx)
    np.save(os.path.join(save_path, "test_ts.npy"),          test_ts)
    np.save(os.path.join(save_path, "test_attrs_idx.npy"),   test_attrs_idx)
    # The statistics are stored as a single dict, which numpy can only save
    # as a pickled object; readers must load it with allow_pickle=True.
    np.save(
        os.path.join(save_path, "train_mean_std.npy"),
        train_mean_std,
        allow_pickle=True,
    )
    return (
        train_ts,
        train_attrs_idx,
        valid_ts,
        valid_attrs_idx,
        test_ts,
        test_attrs_idx,
        train_mean_std,
    )

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [None]:
# loader = TEditDataset(df_train, df_test, df_left, config_dict, split="test").get_loader(batch_size=3)
# for batch in loader:
#     print(batch["x"].shape, batch["attrs"].shape, batch["tp"].shape)
#     break

Loaded test split – 369 samples.
torch.Size([3, 168, 1]) torch.Size([3, 3]) torch.Size([3, 168])


In [3]:
# Run the shared settings, the air-quality config and the air-quality dataset
# preparation script inside this notebook's namespace: exec() without an
# explicit namespace evaluates the file contents in the cell's own scope, so
# any names those scripts define stay available to later cells.
with open('run/settings.py', 'r') as file:
    exec(file.read())
with open('run/configs/air_quality.py', 'r') as file:
    exec(file.read())
with open('run/prepare_datasets/air_quality.py', 'r') as file:
    exec(file.read())

# Disabled export of the prepared splits.
# NOTE(review): the prepare_tedit_data defined in the first cell returns 7
# values (no `meta`), so this 8-name unpacking would raise ValueError unless a
# different version of the function is in scope — confirm before re-enabling.
# save_path = "../../data/tedit_datasets/air_quality"
# (train_ts, train_attrs_idx,
#  valid_ts, valid_attrs_idx,
#  test_ts, test_attrs_idx,
#  train_mean_std, meta) = prepare_tedit_data(df_train, df_test, df_left, config_dict, save_path)


Random seed set to 333
using device:  cpu
After downsampling:
city_str
This is air quality in Beijing.    1852
This is air quality in London.      726
Name: count, dtype: int64
After downsampling:
city_str
This is air quality in Beijing.    529
This is air quality in London.     208
Name: count, dtype: int64
After downsampling:
city_str
This is air quality in Beijing.    265
This is air quality in London.     104
Name: count, dtype: int64


final distribution of text prediction
text
This is air quality in Beijing. The season is winter.    575
This is air quality in Beijing. The season is spring.    446
This is air quality in Beijing. The season is fall.      433
This is air quality in Beijing. The season is summer.    398
This is air quality in London. The season is winter.     244
This is air quality in London. The season is spring.     213
This is air quality in London. The season is fall.       143
This is air quality in London. The season is summer.     126
Name: count, dtype: int6

In [4]:
# Run the shared settings, the synthetic_w_gt config and the synthetic dataset
# preparation script inside this notebook's namespace: exec() without an
# explicit namespace evaluates the file contents in the cell's own scope, so
# any names those scripts define stay available to later cells.
with open('run/settings.py', 'r') as file:
    exec(file.read())
with open('run/configs/synthetic_w_gt.py', 'r') as file:
    exec(file.read())
with open('run/prepare_datasets/synthetic.py', 'r') as file:
    exec(file.read())

# Disabled export of the prepared splits.
# NOTE(review): the prepare_tedit_data defined in the first cell returns 7
# values (no `meta`), so this 8-name unpacking would raise ValueError unless a
# different version of the function is in scope — confirm before re-enabling.
# save_path = "../../data/tedit_datasets/synthetic_w_gt"
# (train_ts, train_attrs_idx,
#  valid_ts, valid_attrs_idx,
#  test_ts, test_attrs_idx,
#  train_mean_std, meta) = prepare_tedit_data(df_train, df_test, df_left, config_dict, save_path)


using device:  cpu
{'No trend.': (210, 60, 30), 'The time series shows upward linear trend.': (210, 60, 30), 'The time series shows downward linear trend.': (210, 60, 30), 'The time series shows upward quadratic trend.': (210, 60, 30), 'The time series shows downward quadratic trend.': (210, 60, 30), 'No seasonal pattern.': (210, 60, 30), 'The time series exhibits a seasonal pattern.': (210, 60, 30), 'No sharp shifts.': (210, 60, 30), 'The mean of the time series shifts upwards.': (210, 60, 30), 'The mean of the time series shifts downwards.': (210, 60, 30), 'The time series exhibits low variability.': (210, 60, 30), 'The time series exhibits high variability.': (210, 60, 30)}


final distribution of text prediction
text
No trend. No seasonal pattern. No sharp shifts. The time series exhibits low variability.                                                                                                210
No trend. No seasonal pattern. No sharp shifts. The time series exhibits high va

In [5]:
# Run the shared settings, the synthetic config and the synthetic dataset
# preparation script (the same preparation script as the synthetic_w_gt cell,
# with a different config) inside this notebook's namespace: exec() without an
# explicit namespace evaluates the file contents in the cell's own scope, so
# any names those scripts define stay available to later cells.
with open('run/settings.py', 'r') as file:
    exec(file.read())
with open('run/configs/synthetic.py', 'r') as file:
    exec(file.read())
with open('run/prepare_datasets/synthetic.py', 'r') as file:
    exec(file.read())

# Disabled export of the prepared splits.
# NOTE(review): the prepare_tedit_data defined in the first cell returns 7
# values (no `meta`), so this 8-name unpacking would raise ValueError unless a
# different version of the function is in scope — confirm before re-enabling.
# save_path = "../../data/tedit_datasets/synthetic"
# (train_ts, train_attrs_idx,
#  valid_ts, valid_attrs_idx,
#  test_ts, test_attrs_idx,
#  train_mean_std, meta) = prepare_tedit_data(df_train, df_test, df_left, config_dict, save_path)


using device:  cpu
text
No trend. No seasonal pattern. No sharp shifts. The time series exhibits low variability.                                                                                                10000
No trend. No seasonal pattern. No sharp shifts. The time series exhibits high variability.                                                                                               10000
The time series shows downward linear trend. The time series exhibits a seasonal pattern. The mean of the time series shifts upwards. The time series exhibits low variability.          10000
The time series shows downward linear trend. The time series exhibits a seasonal pattern. The mean of the time series shifts upwards. The time series exhibits high variability.         10000
The time series shows downward linear trend. The time series exhibits a seasonal pattern. The mean of the time series shifts downwards. The time series exhibits low variability.        10000
The time series shows

In [2]:
# Run the shared settings, the NICU config and the NICU dataset preparation
# script inside this notebook's namespace: exec() without an explicit
# namespace evaluates the file contents in the cell's own scope, so any names
# those scripts define stay available to later cells.
with open('run/settings.py', 'r') as file:
    exec(file.read())
with open('run/configs/nicu.py', 'r') as file:
    exec(file.read())
with open('run/prepare_datasets/nicu.py', 'r') as file:
    exec(file.read())

# Disabled export of the prepared splits.
# NOTE(review): the prepare_tedit_data defined in the first cell returns 7
# values (no `meta`), so this 8-name unpacking would raise ValueError unless a
# different version of the function is in scope — confirm before re-enabling.
# save_path = "../../data/tedit_datasets/nicu"
# (train_ts, train_attrs_idx,
#  valid_ts, valid_attrs_idx,
#  test_ts, test_attrs_idx,
#  train_mean_std, meta) = prepare_tedit_data(df_train, df_test, df_left, config_dict, save_path)


Random seed set to 333
using device:  cpu

Sample of patients with positive labels:
VitalID
1018    8
5170    8
1835    8
2361    8
2791    8
dtype: int64


Processing descriptions: 100%|██████████| 131/131 [00:24<00:00,  5.45it/s]


text
This infant will survive. This infant has gestational age 25 weeks. Birth weight is 640 grams. Moderate variability. Moderate amount of consecutive increases. No events.    254
This infant will survive. This infant has gestational age 23 weeks. Birth weight is 640 grams. Moderate variability. Moderate amount of consecutive increases. No events.    177
This infant will survive. This infant has gestational age 24 weeks. Birth weight is 620 grams. Moderate variability. Moderate amount of consecutive increases. No events.    164
This infant will survive. This infant has gestational age 25 weeks. Birth weight is 870 grams. Moderate variability. Moderate amount of consecutive increases. No events.    160
This infant will survive. This infant has gestational age 24 weeks. Birth weight is 650 grams. Moderate variability. Moderate amount of consecutive increases. No events.    154
Name: count, dtype: int64

Sample of patients with positive labels:
TestID
508     8
707     8
1903    8
817  

Processing descriptions: 100%|██████████| 123/123 [00:19<00:00,  6.39it/s]


text
This infant will survive. This infant has gestational age 27 weeks. Birth weight is 1000 grams. Moderate variability. Moderate amount of consecutive increases. No events.    163
This infant will survive. This infant has gestational age 27 weeks. Birth weight is 890 grams. Moderate variability. Moderate amount of consecutive increases. No events.     149
This infant will survive. This infant has gestational age 24 weeks. Birth weight is 580 grams. Moderate variability. Moderate amount of consecutive increases. No events.     145
This infant will survive. This infant has gestational age 24 weeks. Birth weight is 760 grams. Moderate variability. Moderate amount of consecutive increases. No events.     142
This infant will survive. This infant has gestational age 26 weeks. Birth weight is 510 grams. Moderate variability. Moderate amount of consecutive increases. No events.     125
Name: count, dtype: int64
After downsampling:
description_succ_inc
Moderate amount of consecutive increas