In [None]:
import pandas as pd
import torch as ch
import numpy as np
from os import path

In [None]:
# NOTE(review): machine-specific absolute path; point it at your local copy of the splits file.
data_splits_path = '/Users/jk1/Downloads/train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth'
# torch.load unpickles the file and can execute arbitrary code: only load splits from a trusted source.
splits = ch.load(data_splits_path)

In [None]:
# First split only: element 0 holds the training inputs, element 2 the training outcomes.
x_train, y_train = splits[0][0], splits[0][2]

In [None]:
# Sanity check: shapes of the first split's training inputs and outcomes.
x_train.shape, y_train.shape

In [None]:
# Display the training outcomes of the first split.
y_train

In [None]:
from prediction.utils.utils import aggregate_features_over_time

# Keep only the last channel, as float32. Slicing from `splits` instead of from
# `x_train` itself keeps this cell idempotent: re-running it no longer indexes an
# already-reduced array with four indices.
x_train = splits[0][0][:, :, :, -1].astype('float32')



In [None]:
# Collapse the time axis so that every timepoint becomes one sample.
fold_X_train, fold_y_train = aggregate_features_over_time(
    x_train,
    np.array([None]),  # presumably a placeholder for the label argument; confirm against the helper
    moving_average=False,
)

In [None]:
# Shapes of the two arrays returned by aggregate_features_over_time.
fold_X_train.shape, fold_y_train.shape

In [None]:
from prediction.short_term_outcome_prediction.timeseries_decomposition import decompose_and_label_timeseries

# NOTE(review): `map` shadows the Python builtin for the rest of the session;
# a more descriptive name would be safer (the following cell displays it).
map, flat_labels = decompose_and_label_timeseries(x_train, y_train)

In [None]:
# Inspect the first value returned by decompose_and_label_timeseries.
map

In [None]:
# Raw training inputs and outcome frame of the first split, used to prototype the labelling below.
timeseries, y_df = splits[0][0], splits[0][2]

In [None]:
def aggregate_and_label_timeseries(timeseries, y_df, target_time_to_outcome=6, mask_after_first_positive=True):
    """
    Build per-timepoint samples and binary labels for every subject.

    Timepoint ``t`` is labelled 1 when
    ``event_ts - target_time_to_outcome <= t < event_ts`` and 0 otherwise.
    Subjects whose id does not appear in ``y_df`` get all-zero labels. The positive
    window is clipped to the available timepoints, so an event timestep outside
    ``[0, n_timepoints]`` no longer raises.

    Args:
        timeseries (np.ndarray): 4-D array indexed as [subject, timepoint, third axis, channel]
            (third axis presumably features). ``timeseries[:, 0, 0, 0]`` holds the case
            admission id; the last channel holds the values, which are cast to float32.
        y_df (pd.DataFrame): frame with columns ``case_admission_id`` and
            ``relative_sample_date_hourly_cat`` (event timestep). Only the first row
            matching an id is used.
        target_time_to_outcome (int): length of the positive window preceding the event.
        mask_after_first_positive (bool): if True, subjects with an event keep only
            timepoints ``0..n_pos_start`` (data and labels), i.e. everything after the
            first timepoint of the positive window is dropped.

    Returns:
        tuple(list, list): per-subject results of ``aggregate_features_over_time`` and
        per-subject label arrays, in subject order.
    """
    n_timepoints = timeseries.shape[1]
    # Loop-invariant: the id column is the same for every subject.
    event_case_ids = y_df.case_admission_id.values

    all_subj_labels = []
    all_subj_data = []
    for idx, cid in enumerate(timeseries[:, 0, 0, 0]):
        # Leading None keeps a singleton subject axis for the aggregation helper.
        x_data = timeseries[None, idx, :, :, -1].astype('float32')
        labels = np.zeros(n_timepoints)

        if cid in event_case_ids:
            event_ts = int(y_df[y_df.case_admission_id == cid].relative_sample_date_hourly_cat.values[0])
            # Clip the window to the series so out-of-range event times cannot
            # produce negative array sizes.
            n_pos_start = min(max(0, event_ts - target_time_to_outcome), n_timepoints)
            n_pos_end = min(max(0, event_ts), n_timepoints)
            # NOTE(review): an event at timestep 0 yields an empty positive window
            # (and, with masking, a single zero-labelled sample); confirm this is intended.
            labels[n_pos_start:n_pos_end] = 1

            if mask_after_first_positive:
                labels = labels[:n_pos_start + 1]
                x_data = x_data[:, :n_pos_start + 1, :]

        x_data, _ = aggregate_features_over_time(x_data, np.array([None]), moving_average=False)
        all_subj_labels.append(labels)
        all_subj_data.append(x_data)

    return all_subj_data, all_subj_labels

In [None]:
from sklearn.preprocessing import StandardScaler


def prepare_aggregate_dataset(scenario, rescale=True, target_time_to_outcome=6, mask_after_first_positive=True):
    """
    Turn one train/validation scenario into flat per-timepoint arrays.

    Both halves of the scenario are passed through ``aggregate_and_label_timeseries``
    and the per-subject results are concatenated along the first axis.

    Args:
        scenario (tuple): tuple of (X_train, X_val, y_train, y_val)
        rescale (bool): if True, standardise the data with a scaler fitted on the
            training data only and applied to both sets
        target_time_to_outcome (int): forwarded to ``aggregate_and_label_timeseries``
        mask_after_first_positive (bool): forwarded to ``aggregate_and_label_timeseries``

    Returns:
        tuple: (train_data, val_data, train_labels, val_labels)
    """
    X_train, X_val, y_train, y_val = scenario

    # Aggregate both partitions first, then stack the per-subject pieces.
    per_partition = [
        aggregate_and_label_timeseries(inputs, outcomes, target_time_to_outcome, mask_after_first_positive)
        for inputs, outcomes in ((X_train, y_train), (X_val, y_val))
    ]
    (train_data, train_labels), (val_data, val_labels) = (
        (np.concatenate(subj_data), np.concatenate(subj_labels))
        for subj_data, subj_labels in per_partition
    )

    scaler = StandardScaler()
    if rescale:
        # Fit on the training data only so no validation statistics leak in.
        train_data = scaler.fit_transform(train_data)
        val_data = scaler.transform(val_data)

    return train_data, val_data, train_labels, val_labels
 

In [None]:
# One (train_data, val_data, train_labels, val_labels) tuple per split.
all_datasets = [
    prepare_aggregate_dataset(
        scenario,
        rescale=True,
        target_time_to_outcome=6,
        mask_after_first_positive=True,
    )
    for scenario in splits
]

In [None]:
# Shapes of the raw training inputs and outcomes of the first split, for comparison.
splits[0][0].shape, splits[0][2].shape

In [None]:
# Shapes of the aggregated training data and training labels of the first split.
all_datasets[0][0].shape, all_datasets[0][2].shape