In [None]:
import pandas as pd
import numpy as np

In [None]:
import os

# Base directory of the preprocessing run. Override it with the OPSUM_PREPRO_DIR
# environment variable instead of editing absolute local paths in the notebook.
prepro_dir = os.environ.get(
    'OPSUM_PREPRO_DIR',
    '/Users/jk1/temp/opsum_end/preprocessing/gsu_Extraction_20220815_prepro_09052025_220520',
)
# Timestamp suffix shared by all output files of this preprocessing run
prepro_timestamp = '09052025_220520'
features_path = os.path.join(prepro_dir, f'preprocessed_features_{prepro_timestamp}.csv')
short_term_outcomes_path = os.path.join(prepro_dir, f'preprocessed_outcomes_short_term_{prepro_timestamp}.csv')

In [None]:
# Load the preprocessed features and the short-term outcome table (features first, then outcomes)
features, short_term_outcomes = (
    pd.read_csv(csv_path) for csv_path in (features_path, short_term_outcomes_path)
)

In [None]:
# Small subset for quick exploration: case admission ids of the first 5 outcome rows
selection_cid = short_term_outcomes['case_admission_id'].values[:5]

In [None]:
# Restrict features and outcomes to the selected case admissions
x_df = features.loc[features['case_admission_id'].isin(selection_cid)]
y_df = short_term_outcomes.loc[short_term_outcomes['case_admission_id'].isin(selection_cid)]

In [None]:
from prediction.outcome_prediction.data_loading.data_formatting import features_to_numpy

# Convert the long-format feature frame into a numpy array (column order matters)
x_np = features_to_numpy(
    x_df,
    ['case_admission_id', 'relative_sample_date_hourly_cat', 'sample_label', 'value'],
)


In [None]:
# Sanity check: size of the feature array vs. number of outcome rows for the selected subset
x_np.shape, y_df.shape

In [None]:
# Outcome rows of the selected case admissions
y_df

In [None]:
# First rows of the full feature table
features.head()

## testing final functions

In [None]:
from prediction.short_term_outcome_prediction.timeseries_decomposition import decompose_and_label_timeseries


# NOTE: `map` shadows the Python builtin; the name is kept because later cells use it
map, flat_labels = decompose_and_label_timeseries(
    x_np,
    y_df,
    target_time_to_outcome=6,
    target_interval=True,
    restrict_to_first_event=False,
)

In [None]:
# Labels of samples 288..359, i.e. the 5th block of 72 samples.
# NOTE(review): this is one patient's labels only if every patient contributes exactly
# 72 samples; patients with an event may contribute fewer — TODO confirm via `map`
flat_labels[72*4:72*5]

In [None]:
from torch import tensor
from prediction.short_term_outcome_prediction.timeseries_decomposition import StrokeUnitBucketDataset


# Keep only the last channel of x_np ('value', per the column order given to features_to_numpy)
stroke_unit_dataset = StrokeUnitBucketDataset(
    tensor(x_np[:, :, :, -1].astype('float32')),
    tensor(flat_labels),
    map,
)

In [None]:
# Target of sample 72*4+5. NOTE(review): this is the 6th timestep of the 5th patient only
# if every patient contributes exactly 72 samples — TODO confirm via `map`
stroke_unit_dataset[72*4+5][1]

## initial exploration

In [None]:
delta = 6

# Build the flattened sample index. Every case admission contributes one sample per usable
# timestep; sample i covers timesteps 0..ts of one case admission.
#   map[i]         = (cid_idx, ts): row of x_np and last timestep of sample i
#                    (NOTE: shadows the builtin `map`; kept because later cells use this name)
#   flat_labels[i] = 1 if the outcome event occurs within `delta` hours after ts, else 0
map = []
flat_labels = []
# Number of timesteps available in the input array. For most patients this bounds the
# sequence; for patients with an event, the sequence stops at the event (if earlier).
overall_max_ts = x_np.shape[1]
# Event time per case admission (first matching row, as in y_df order). Precomputed once
# instead of scanning y_df for every patient.
event_ts_by_cid = (
    y_df.drop_duplicates(subset='case_admission_id', keep='first')
    .set_index('case_admission_id')['relative_sample_date_hourly_cat']
    .to_dict()
)
for idx, cid in enumerate(x_np[:, 0, 0, 0]):
    event_ts = event_ts_by_cid.get(cid)
    if event_ts is None:
        # No event: all timesteps are usable and every label is negative
        n_ts = overall_max_ts
    else:
        # Stop at the event, but never go beyond the observed input window: samples past
        # x_np.shape[1] would otherwise silently reuse truncated inputs in the Dataset
        n_ts = min(int(event_ts), overall_max_ts)
    for ts in range(n_ts):
        # store idx of cid and idx of ts
        map.append((idx, ts))
        # Label against the true event time, not against the (possibly truncated) window length
        flat_labels.append(int(event_ts is not None and ts + delta >= event_ts))

In [None]:
# Total number of samples (one label per sample)
len(flat_labels)

In [None]:
# Up to the first 73 timesteps of the first case admission
x_np[0, :73]

In [None]:
# Both must have one entry per sample
len(map), len(flat_labels)

In [None]:
# Full label list (can be long — one entry per sample)
flat_labels

In [None]:
# (sample index, sequence length) per sample: a sample ending at timestep ts has ts + 1 steps
idx_to_len_map = list(enumerate(entry[1] + 1 for entry in map))

In [None]:
# Inspect the (sample index, sequence length) pairs
idx_to_len_map

In [None]:
from torch.utils.data import Sampler, Dataset
from torch import tensor
from collections import OrderedDict
from random import shuffle


class StrokeUnitBucketDataset(Dataset):
    """Dataset of variable-length prefixes of per-admission time series.

    Every sample ``i`` is the prefix (timesteps ``0..ts``, inclusive) of one case
    admission's time series, paired with ``targets[i]``. ``idx_map[i]`` tells which
    case admission and which last timestep belong to sample ``i``.
    """

    def __init__(self, inputs: "torch.Tensor", targets: "torch.Tensor | None", idx_map: list):
        """
        Args:
            inputs (torch.Tensor): indexed as ``inputs[cid_idx, ts]``. The first axis is the case
                admission and the second axis the timestep. The remaining axes presumably hold
                the features (e.g. ``x_np[:, :, :, -1]`` -> (num_cids, num_timesteps, num_features)).
            targets (torch.Tensor or None): flattened targets, one per sample index,
                so ``len(targets) == len(idx_map)``. If ``None``, ``__getitem__`` returns
                inputs only (e.g. for inference).
            idx_map (list): list of ``(cid_idx, ts)`` tuples. Position ``i`` describes sample ``i``:
                ``cid_idx`` is the row in ``inputs`` and ``ts`` the last timestep of that
                sample. This is needed because every case admission yields several samples.

        Raises:
            ValueError: if ``targets`` is given and its length differs from ``idx_map``.
        """
        if targets is not None and len(targets) != len(idx_map):
            raise ValueError(
                f'targets ({len(targets)}) and idx_map ({len(idx_map)}) must have the same length'
            )
        self.inputs = inputs
        self.targets = targets
        self.idx_map = idx_map

    def __len__(self):
        return len(self.idx_map)

    def __getitem__(self, index):
        entry = self.idx_map[index]
        cid_idx, last_ts = entry[0], entry[1]
        # Prefix of the sequence up to and including last_ts
        sequence = self.inputs[cid_idx, 0: last_ts + 1]
        if self.targets is None:
            return sequence
        return sequence, self.targets[index]

In [None]:
from torch import tensor

# Keep only the 'value' channel of X (dropping case_admission_id,
# relative_sample_date_hourly_cat and sample_label), cast to float32 for torch
stroke_unit_dataset = StrokeUnitBucketDataset(
    tensor(x_np[:, :, :, -1].astype('float32')),
    tensor(flat_labels),
    map,
)

In [None]:
# Inspect one sample: its (cid_idx, last_ts) entry and its label (same sample index)
map[1], flat_labels[1]

In [None]:
# Size of the last axis of the first sample's input (presumably the number of features)
stroke_unit_dataset[0][0].shape[-1]

In [None]:
class BucketBatchSampler(Sampler):
    """Batch sampler that groups sample indices of equal sequence length.

    Every batch contains only samples of the same length, so sequences can be stacked
    without padding. Batches are regenerated and reshuffled on every iteration.
    Pass it to ``DataLoader(..., batch_sampler=...)``.

    Ref: https://discuss.pytorch.org/t/tensorflow-esque-bucket-by-sequence-length/41284/13
    """

    def __init__(self, idx_to_len_map, batch_size, rng=None):
        """
        Args:
            idx_to_len_map: iterable of ``(idx, length)`` tuples, one per dataset sample.
                It is copied, so shuffling never reorders the caller's list.
            batch_size: maximum number of indices per batch. The last batch of each length
                bucket may be smaller.
            rng: optional ``random.Random`` instance for reproducible shuffling.
                Defaults to the module-level ``random.shuffle``.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f'batch_size must be a positive integer, got {batch_size}')
        self.batch_size = batch_size
        # Copy: shuffling must not mutate the caller's list
        self.idx_to_len_map = list(idx_to_len_map)
        self._shuffle = rng.shuffle if rng is not None else shuffle
        self.batch_list = self._generate_batch_map()
        # The number of batches depends only on the bucket sizes, not on the shuffle order
        self.num_batches = len(self.batch_list)

    def _generate_batch_map(self):
        """Shuffle the samples and split each length bucket into batches of at most batch_size."""
        # Shuffle first so that samples are put into different batches each time
        self._shuffle(self.idx_to_len_map)
        # Bucket indices by length, e.g. buckets[10] = [30, 124, 203, ...]
        buckets = OrderedDict()
        for idx, length in self.idx_to_len_map:
            buckets.setdefault(length, []).append(idx)
        # e.g. for batch_size=3: [[23, 45, 47], [49, 50, 62], [63, 65, 66], ...]
        return [
            indices[start:start + self.batch_size]
            for indices in buckets.values()
            for start in range(0, len(indices), self.batch_size)
        ]

    def batch_count(self):
        """Number of batches yielded per iteration."""
        return self.num_batches

    def __len__(self):
        # DataLoader uses len(batch_sampler) as the number of batches, not samples
        return self.num_batches

    def __iter__(self):
        self.batch_list = self._generate_batch_map()
        # Shuffle the batches so they aren't ordered by bucket size
        self._shuffle(self.batch_list)
        yield from self.batch_list

In [None]:
batch_size = 3
# Batches contain only samples whose sequences have the same length
bucket_sampler = BucketBatchSampler(idx_to_len_map, batch_size=batch_size)

In [None]:
# NOTE: calls a private helper; every call reshuffles the sampler's internal index list
bucket_sampler._generate_batch_map()

In [None]:
# Number of batches produced per iteration over the sampler
bucket_sampler.batch_count()

In [None]:
# First batch of sample indices (regenerates and reshuffles the batches)
next(iter(bucket_sampler))

In [None]:
from torch.utils.data import DataLoader

# batch_sampler is mutually exclusive with batch_size/shuffle/sampler/drop_last in DataLoader
# (passing batch_size=3 raises ValueError); batching and order come from bucket_sampler
dataloader = DataLoader(stroke_unit_dataset, batch_sampler=bucket_sampler, num_workers=0)


In [None]:
# Fetch the first batch; the explicit iterator allows pulling further batches later
dataloader_iter = iter(dataloader)
first_x, first_y = next(dataloader_iter)
x, y = first_x, first_y
# presumably (batch, seq_len, n_features): all sequences in a batch share seq_len
print(x.shape)

In [None]:
# Show shapes and targets of the first 12 batches (i = 0..11)
for i, (x, y) in zip(range(12), dataloader):
    print(x.shape)
    print(y)