<a href="https://www.kaggle.com/code/wangyuweikiwi/mimi-iii-time-series-data-preprocessing?scriptVersionId=198599877" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Load data from dataset

As a demo, we only use a subset of the records in `vitals_records`, not all of them.

In [1]:
import pickle
import copy
import numpy as np


# Load the pickled vital-sign records.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
vitals_path = '/kaggle/input/mimic-iii-time-series-filters-by-measurement-ids/vitals_records_10000.p'
with open(vitals_path, 'rb') as vitals_file:
    vitals = pickle.load(vitals_file)
print(len(vitals))

5000


In [2]:
def list_info(vitals: list):
    """Print the type and the sizes of the first three nesting levels of a dataset.

    Works for nested lists and numpy arrays. The inner sizes are read from the first
    element only, so for ragged nested lists they describe ``vitals[0]`` and
    ``vitals[0][0]`` rather than every record. Empty inputs stop early instead of
    raising IndexError.
    """
    print(type(vitals))

    print(f"The admissions: {len(vitals)}")
    if len(vitals) == 0:
        return
    print(f"the vital signs (or categories of measurements): {len(vitals[0])}")
    if len(vitals[0]) == 0:
        return
    print(f"Time-series data points: {len(vitals[0][0])}")
    
# Inspect the raw nested-list structure (admissions -> vital-sign lists -> readings).
list_info(vitals)

<class 'list'>
The admissions: 5000
the vital signs (or categories of measurements): 16
Time-series data points: 130


# Description of data

* **First dimension (5000):** The outermost list contains 5000 elements, corresponding to the number of admissions (records) we are processing.
* **Second dimension (16):** Each of the 5000 elements is a list of 16 sub-lists. These represent the different vital signs or measurement categories (e.g., SpO2, HR, SBP) recorded for each admission.
* **Third dimension (130):** Each of the 16 sub-lists holds the timestamped measurements of one vital sign. Its length varies from record to record: 130 is only the length of `vitals[0][0]` (after the length-of-stay filter below, the first record's list has 84 entries).

In [3]:
def get_list_dimensions(lst):
    """Return the size of each nesting level of a nested list, following first elements.

    Only ``list`` instances count as a level: tuples, arrays and scalars end the walk
    (a non-list input gives ``[]``). An empty list contributes ``0`` and also ends the walk.
    """
    sizes = []
    node = lst
    while isinstance(node, list):
        if not node:
            sizes.append(0)
            break
        sizes.append(len(node))
        node = node[0]
    return sizes


# Sizes along the first-element path; tuples at the innermost level are not counted.
dimensions = get_list_dimensions(vitals)
print(f"Dimensions of the list: {dimensions}")

Dimensions of the list: [5000, 16, 130]


In [4]:
# Load the admission metadata. Each record is a tuple; the sample printed below is
# (165315, 'EMERGENCY', Decimal('27'), 0), presumably (adm_id, admission type,
# length of stay, mortality label) going by the file name — confirm against the source.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
adm_info_path = '/kaggle/input/mimic-iii-time-series-filters-by-measurement-ids/adm_type_los_mortality.p'
with open(adm_info_path, 'rb') as adm_file:
    adm_info = pickle.load(adm_file)

# Keep exactly as many admission records as vitals were loaded (5000 here) so the two
# lists stay index-aligned even if the vitals file changes size.
adm_info_5000 = adm_info[:len(vitals)]
print(len(adm_info_5000))

5000


In [5]:
# Peek at one admission record to see its tuple layout.
print(adm_info_5000[0])

(165315, 'EMERGENCY', Decimal('27'), 0)


# Preprocessing data


## 1. The length of stay should be >= 48 h

`adm_info_5000[x][2] >= 48`

In [6]:
# Keep only admissions whose length of stay (adm_info record index 2) is >= 48, and
# build the labels (record index 3) from the very same records, so that vitals[i] and
# label[i] always describe the same admission.
#
# Original interp-net logic, for reference (it matched admissions by id and assumes
# `vitals` is stored in the same order as `adm_info`):
#   adm_id_needed = [record[0] for record in adm_info if record[2] >= 48]
#   label = [rec[3] for x in adm_id_needed for rec in adm_info if x == rec[0]]
#
# NOTE: this cell overwrites `vitals`; run it exactly once after loading the data.

# batch_size is the number of records within each vitals file.
# NOTE(review): the loaded file holds 5000 records although its name ends in "_10000";
# confirm batch_size before using batch_idx > 1.
batch_size = 10000
# batch_idx is the 1-based index of the vitals file. Vitals files must be loaded in order.
batch_idx = 1
min_los = 48

start_point = batch_size * (batch_idx - 1)

print(start_point)

# Admission records aligned one-to-one with the vitals of this batch. Slicing the full
# `adm_info` (rather than only its first 5000 records) keeps later batches aligned too.
batch_adm_info = adm_info[start_point:start_point + len(vitals)]
assert len(batch_adm_info) == len(vitals), "adm_info does not cover this vitals batch"

keep = [record[2] >= min_los for record in batch_adm_info]

label = [record[3] for record, kept in zip(batch_adm_info, keep) if kept]
print(len(label))

vitals_new = [sample for sample, kept in zip(vitals, keep) if kept]
print(len(vitals_new))

vitals = vitals_new

0
4531
4531


In [7]:
# After the length-of-stay filter: fewer admissions, still the raw nested lists.
list_info(vitals)

<class 'list'>
The admissions: 4531
the vital signs (or categories of measurements): 16
Time-series data points: 84


# Visualize the sample's time series length

In [8]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def interactive_visualize_top_4_samples(vitals, num_samples=4):
    """
    Interactive bar charts of the per-vital-sign series lengths for the first samples.

    One subplot is drawn per sample in a two-column grid; each bar is the length of one
    vital-sign series of that sample.

    vitals: sequence of samples; each sample is a sequence of vital-sign series
        (nested lists, or a 3D numpy array whose last axis is time).
    num_samples: number of samples to plot (default 4). Clamped to len(vitals); the grid
        grows by one row for every two samples instead of being fixed at 2x2.
    """
    num_samples = min(num_samples, len(vitals))
    num_cols = 2
    num_rows = max(1, (num_samples + num_cols - 1) // num_cols)  # ceil(num_samples / 2)

    fig = make_subplots(rows=num_rows, cols=num_cols,
                        subplot_titles=[f"Sample {i+1}" for i in range(num_samples)])

    for i in range(num_samples):
        sample_data = vitals[i]
        time_series_lengths = [len(vital_sign) for vital_sign in sample_data]

        row = (i // num_cols) + 1
        col = (i % num_cols) + 1

        fig.add_trace(
            go.Bar(
                x=[f'Vital Sign {j+1}' for j in range(len(time_series_lengths))],
                y=time_series_lengths,
                marker=dict(color='skyblue'),
                name=f'Sample {i+1}'
            ),
            row=row, col=col
        )

    fig.update_layout(
        title_text=f"Time-Series Lengths for Top {num_samples} Samples",
        showlegend=False,  # subplot titles already identify the samples
        height=400 * num_rows,  # 800 px for the default 2x2 grid, as before
        hovermode="x unified"
    )

    # Axis labels apply to every subplot.
    fig.update_xaxes(title_text='Vital Sign Index', tickangle=-45)
    fig.update_yaxes(title_text='Length of Time-Series Data (Number of Time Steps)')

    fig.show()

# Visualize the per-vital-sign series lengths of the first 4 samples (raw nested lists,
# so lengths differ between vital signs).
interactive_visualize_top_4_samples(vitals, num_samples=4)


## 2. Filter the vital signs (or categories of measurements)

```
# Item - item_id

# SpO2 - 646, 220277
# HR - 211, 220045
# RR - 618, 615, 220210, 224690
# SBP - 51,442,455,6701,220179,220050
# DBP - 8368,8440,8441,8555,220180,220051
# EtCO2 - 1817, 228640
# Temp(F) - 223761,678
# Temp(C) - 223762,676
# TGCS - 198, 226755, 227013
# CRR - 3348
# Urine Output - 43647, 43053, 43171, 43173, 43333, 43347,
# 43348, 43355, 43365, 43373, 43374, 43379, 43380, 43431,
# 43519, 43522, 43537, 43576, 43583, 43589, 43638, 43654,
# 43811, 43812, 43856, 44706, 45304, 227519,
# FiO2 - 2981, 3420, 3422, 223835,
# Glucose - 807,811,1529,3745,3744,225664,220621,226537
# pH - 780, 860, 1126, 1673, 3839, 4202, 4753, 6003, 220274, 220734, 223830, 228243,


# vitals[0] - SpO2
# vitals[1] - HR
```

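### Remove the missing data

* `num_features=12`: the 16 raw measurement lists are reduced to 12 features.
  * EtCO2 (index 5) is dropped because its values are blank (e.g. `vitals[0][5]`).
  * Temp(C) (index 7) is converted to °F and merged into Temp(F) (index 6).
  * Lists 10 and 11 are merged into list 9 (capillary refill, judging by the categorical values the code decodes).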


### 2880 for 48 hours of minute-level data

* max_length=2881

At 1-minute resolution, 48 hours span 48 hours * 60 minutes = 2880 intervals. The window test is inclusive (`<= length_of_stay`), so both t = 0 and t = 48 h are kept, which gives up to 2881 time points.

In [9]:
# Original code: https://github.com/mlds-lab/interp-net/blob/master/src/mimic_preprocessing.py#L25

import numpy as np
from concurrent.futures import ProcessPoolExecutor

def trim_los_parallel(data_chunk, length_of_stay):
    num_features = 12  # final features (excluding EtCO2) because the measurement value is blank: vitals[0][5]
    max_length = 2881  # maximum length of time stamp
    a = np.full((len(data_chunk), num_features, max_length), -100, dtype=float)  # initialize array with -100 (missing data)
    timestamps = []

    for i in range(len(data_chunk)):
        # Process temperature conversion in a vectorized way
        if data_chunk[i][7]:
            temp_array = np.array([elem[1] for elem in data_chunk[i][7] if elem[1] is not None])
            data_chunk[i][6] += [(elem[0], temp * 1.8 + 32) for elem, temp in zip(data_chunk[i][7], temp_array)]

        # Combine data[9] with data[10] and data[11]
        data_chunk[i][9].extend(data_chunk[i][10] + data_chunk[i][11])

        # Remove unwanted elements (EtCO2 data)
        del data_chunk[i][5:7]
        del data_chunk[i][8]

        # Collect unique timestamps across all features
        all_timestamps = sorted(set([elem[0] for j in range(num_features) for elem in data_chunk[i][j]]))

        # Extract first 48-hour data
        first_ts = all_timestamps[0] if all_timestamps else None
        TS = [ts for ts in all_timestamps if (ts - first_ts).total_seconds() / 3600 <= length_of_stay]

        timestamps.append(TS)

        for j in range(num_features):
            feature_data = data_chunk[i][j]
            feature_dict = {entry[0]: entry[1] for entry in feature_data}  # Convert list to dictionary for fast lookup

            for k, ts in enumerate(TS):
                if ts in feature_dict:
                    value = feature_dict[ts]
                    if value is None or value in ('Other/Remarks', 'Comment'):
                        a[i, j, k] = -100
                    elif value in ('Normal <3 secs', 'Normal <3 Seconds', 'Brisk'):
                        a[i, j, k] = 1
                    elif value in ('Abnormal >3 secs', 'Abnormal >3 Seconds', 'Delayed'):
                        a[i, j, k] = 2
                    else:
                        a[i, j, k] = value
                else:
                    a[i, j, k] = -100  # missing data

    return a, timestamps

# Function to split the data and run in parallel
def run_trim_los_in_parallel(data, length_of_stay, num_workers=4):
    # Split data into chunks for parallel processing
    chunk_size = len(data) // num_workers
    data_chunks = [data[i:i + chunk_size] for i in range(0, len(data), chunk_size)]

    results = []
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(trim_los_parallel, chunk, length_of_stay) for chunk in data_chunks]
        for future in futures:
            results.append(future.result())

    # Combine results from all chunks
    all_a = np.concatenate([result[0] for result in results], axis=0)
    all_timestamps = sum([result[1] for result in results], [])
    
    return all_a, all_timestamps

# Example usage
import os

# Keep the first 48 hours of every admission.
hours_look_ahead = 48
# Never start more worker processes than the machine reports CPUs (capped at 15).
num_workers = min(15, os.cpu_count() or 1)
# NOTE: `vitals` is overwritten with the dense array, so run this cell once only.
vitals, timestamps = run_trim_los_in_parallel(vitals, hours_look_ahead, num_workers)

In [10]:
# Both should equal the number of admissions kept by the LOS filter.
for collection in (vitals, timestamps):
    print(len(collection))

4531
4531


# Visualize the timestamps of the first 100 samples

In [11]:
import plotly.graph_objects as go

def interactive_visualize_multiple_samples(timestamps, num_samples=4):
    """
    Interactive line plot of the timestamp sequence of the first ``num_samples`` samples.

    timestamps: list with one entry per sample; each entry is that sample's list of
        timestamps (datetime-like values, or hours once `fix_input_format` has
        converted the lists in place).
    num_samples: number of samples to plot (default 4); clamped to len(timestamps).
    """
    num_samples = min(num_samples, len(timestamps))
    fig = go.Figure()

    for i in range(num_samples):
        sample_timestamps = timestamps[i]
        fig.add_trace(go.Scatter(
            x=list(range(len(sample_timestamps))),
            y=sample_timestamps,
            mode='lines+markers',
            name=f'Sample {i+1}'
        ))

    # The values are hours only after conversion to numbers; raw timestamps get a
    # neutral axis title instead of a misleading "Hours" label.
    first_values = [ts[0] for ts in timestamps[:num_samples] if len(ts) > 0]
    numeric = all(isinstance(value, (int, float)) for value in first_values)
    y_title = 'Time (Hours)' if numeric else 'Timestamp'

    fig.update_layout(
        title=f'Interactive Time-Series Data for First {num_samples} Samples',
        xaxis_title='Time Step Index',
        yaxis_title=y_title,
        hovermode='x unified'
    )

    fig.show()

# Visualize the timestamps of the first 100 samples interactively.
interactive_visualize_multiple_samples(timestamps, num_samples=100)


In [12]:
# `vitals` is now a dense numpy array (admissions, 12 features, 2881 time slots).
list_info(vitals)

<class 'numpy.ndarray'>
The admissions: 4531
the vital signs (or categories of measurements): 12
Time-series data points: 2881


In [13]:
# `interactive_visualize_top_4_samples` is already defined earlier in the notebook (In [8]);
# a second, identical definition here would only shadow it, so it is not repeated.
# `vitals` is now a padded ndarray, so every bar shows the padded length (2881 slots).
interactive_visualize_top_4_samples(vitals, num_samples=4)

In [14]:
# Observed (non-sentinel) values of the first feature of the first admission.
normal_values = list(filter(lambda value: value != -100, vitals[0][0]))

In [15]:
# Number of samples, then number of kept timestamps for the first sample.
for count in (len(timestamps), len(timestamps[0])):
    print(count)

4531
85


## 3. Fix the input format (keep at most 200 time steps)

Return the input in the proper format

* x: observed values
* M: masking, 0 indicates missing values
* delta: time points of observation


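Here we enforce a **consistent length** of time steps across all samples with `timestamp=200`. The code cuts every sample to its first 200 time steps, so all arrays have exactly 200 columns. Samples with fewer observations keep zero-valued, masked-out columns at the end. (We are not sure why the original code uses 200 time steps.)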

In [16]:
# Original https://github.com/mlds-lab/interp-net/blob/master/src/multivariate_example.py#L45
# Adapted by Micost, Aisuko, Yuwei Wang

def fix_input_format(x, T, timestamp=200):
    """
    Trim samples to a fixed number of time steps and build the (x, M, delta) inputs.

    Args:
        x: array (samples, features, time steps) of observed values; -100 marks a
            missing value. It is not modified: a cleaned, trimmed copy is returned.
        T: list with one list of timestamps per sample; ``T[j][k]`` is the time of
            column k of sample j. **Modified in place**: each list is truncated to
            ``timestamp`` entries and converted to hours since its first entry. Lists
            that are already numeric (e.g. on a re-run) are left as they are.
        timestamp: number of time steps to keep (default 200).

    Returns:
        x: trimmed float copy of the input in which outliers (> 500), negative values
            and the -100 sentinel are replaced by 0.
        M: mask of the same shape, 1 where the cleaned x is > 0 and 0 elsewhere. Note
            that an observed value of exactly 0 is therefore treated as missing.
        delta: same shape; ``delta[j, :, k]`` is ``T[j][k]`` in hours, 0 past the end
            of ``T[j]``.
    """
    import numbers

    # Trim the timestamp lists to the same horizon as the value array.
    for i in range(len(T)):
        if len(T[i]) > timestamp:
            T[i] = T[i][:timestamp]

    # Copy: slicing gives a view, and the cleanup below would otherwise write into the
    # caller's array.
    x = np.array(x[:, :, :timestamp], dtype=float)
    M = np.zeros_like(x)
    delta = np.zeros_like(x)
    print(x.shape, len(T))

    # Express each timestamp list as hours since its first entry.
    for t in T:
        if t and not isinstance(t[0], numbers.Real):
            t0 = t[0]
            t[:] = [(ts - t0).total_seconds() / 3600.0 for ts in t]

    # Outliers and negative values (including the -100 sentinel) count as missing:
    # their value becomes 0 and their mask stays 0, so -100 is no longer needed.
    x[(x > 500) | (x < 0)] = 0.0
    M[x > 0] = 1

    # The observation times are shared by all features of a sample.
    for j in range(x.shape[0]):
        n = min(len(T[j]), x.shape[2])
        delta[j, :, :n] = T[j][:n]

    return x, M, delta

# vitals: 3D array (samples, features, timestamps) containing observed values for vital signs.
# timestamps: list of timestamp lists, one per sample; fix_input_format converts these
# lists in place to hours since each sample's first observation.
x, M, delta = fix_input_format(vitals, timestamps)

(4531, 12, 200) 4531


In [17]:
# Shapes of the observed values, the mask and the observation times.
for array in (x, M, delta):
    list_info(array)

<class 'numpy.ndarray'>
The admissions: 4531
the vital signs (or categories of measurements): 12
Time-series data points: 200
<class 'numpy.ndarray'>
The admissions: 4531
the vital signs (or categories of measurements): 12
Time-series data points: 200
<class 'numpy.ndarray'>
The admissions: 4531
the vital signs (or categories of measurements): 12
Time-series data points: 200


In [18]:
# Persist the model inputs to the working directory.
for path, array in (('x.npy', x), ('M.npy', M), ('delta.npy', delta)):
    np.save(path, array)

In [19]:
# `timestamps` now holds hours since each sample's first observation, because
# fix_input_format converted the lists in place.
interactive_visualize_multiple_samples(timestamps, num_samples=100)

# Handle missing data in time-series

**NOTE: We didn't execute this function.**

Here We handle the missing data by imputing the missing values with the global mean of the observed values for each feature(vital sign).


### The inputs are `vitals` and `mask`

* vitals: A 3D numpy array (x) containing time-series data for multiple Samples, where:
  * The dimensions are [number of Samples, number of features, number of timestamps].
* mask: A 3D numpy array (M) representing the observed values mask, where:
  * M[i, j, k] = 1 indicates that the value vitals[i, j, k] is observed.
  * M[i, j, k] = 0 indicates that the value vitals[i, j, k] is missing.

In [20]:
# original https://github.com/mlds-lab/interp-net/blob/master/src/multivariate_example.py#L45

def mean_imputation(vitals, mask):
    """For the time series missing entirely, our interpolation network 
    assigns the starting point (time t=0) value of the time series to 
    the global mean before applying the two-layer interpolation network.
    In such cases, the first interpolation layer just outputs the global
    mean for that channel, but the second interpolation layer performs 
    a more meaningful interpolation using the learned correlations from
    other channels.

    Args:
        vitals: array (samples, features, time steps) of values. Modified in place.
        mask: array of the same shape, 1 = observed, 0 = missing. Modified in place.

    Returns:
        None. For every (sample, feature) series with no observation, time step 0 is
        set to the feature's mean over all observed values and marked observed.
        A feature that is never observed anywhere has no mean; its series are left
        missing instead of being filled with NaN.
    """
    counts = np.sum(np.sum(mask, axis=2), axis=0)
    totals = np.sum(np.sum(vitals * mask, axis=2), axis=0)
    has_data = counts > 0
    mean_values = np.zeros(totals.shape, dtype=float)
    mean_values[has_data] = totals[has_data] / counts[has_data]

    # (sample, feature) pairs whose whole series is missing and can be imputed.
    empty = (np.sum(mask, axis=2) == 0) & has_data[np.newaxis, :]
    sample_idx, feature_idx = np.nonzero(empty)
    mask[sample_idx, feature_idx, 0] = 1
    vitals[sample_idx, feature_idx, 0] = mean_values[feature_idx]
    return

# Simulating missing values that the autoencoder must reconstruct


This is part of the autoencoder's learning mechanism and is not used in our case.

In [21]:
def hold_out(mask, perc=0.2, rng=None):
    """To implement the autoencoder component of the loss, we introduce a set
    of masking variables mr (and mr1) for each data point. If drop_mask = 0,
    then we remove the data point as an input to the interpolation network,
    and include the predicted value at this time point when assessing
    the autoencoder loss. In practice, we randomly select 20% of the
    observed data points to hold out from every input time series.

    Args:
        mask: array (samples, features, time steps), 1 = observed, 0 = missing.
            Not modified.
        perc: fraction of each series' observed points to hold out (default 0.2).
            Nothing is held out unless int(perc * observed_count) > 1.
        rng: optional ``np.random.Generator`` for reproducible draws; by default the
            legacy global ``np.random`` state is used, as before.

    Returns:
        drop_mask: copy of ``mask`` with the held-out observed points set to 0.
    """
    if rng is None:
        rng = np.random
    drop_mask = mask.copy()
    for i in range(mask.shape[0]):
        for j in range(mask.shape[1]):
            observed = np.flatnonzero(mask[i, j] > 0)
            count = observed.size
            n_drop = int(perc * count)
            if n_drop > 1:
                dropped = rng.choice(count, n_drop, replace=False)
                drop_mask[i, j, observed[dropped]] = 0
    return drop_mask


# How are we going to handle missing values effectively?

There are two ways to do that.

###  Using GluonTS for Time-Series Modeling

GluonTS is a powerful toolkit for probabilistic time-series modeling. It can model temporal dependencies and handle missing values effectively by learning patterns from the observed time-series.


### Using Amazon Chronos (Transformer-based Forecasting Model):

Chronos is Amazon’s deep learning-based forecasting tool built on transformer architectures, specifically designed for time-series forecasting.


Since we are already familiar with transformer models, Chronos might be a better fit for our project. However, Chronos only supports univariate forecasting; see https://github.com/amazon-science/chronos-forecasting/issues/13


# Visualizing Time-Series data

Let's visualize the time series of one sample across different features.

* **x (observed values):** This array contains the actual time-series data of vital signs (with missing values and outliers handled). This is the array we would most likely want to visualize, as it represents the vital sign measurements over time. It shows how the data changes across time steps for each feature and Sample.

* **M (masking):** This is a mask that indicates whether a value is observed (1) or missing (0). Visualizing this could be useful to understand where missing data points occur, but it won't give insights into the actual values of the vital signs. It would mainly show the distribution of missing or invalid data over time.

* **delta (time points of observation):** For each sample and time step, this array holds the observation time in hours since that sample's first observation (0 for padded steps). It is the same for every feature of a sample. Visualizing delta helps us understand the gaps between observations: the rise from one step to the next is the interval between consecutive observations, which reveals irregular sampling patterns.


## Visualizing `x` of a single Sample for all 12 features


In [22]:
import plotly.graph_objects as go

def interactive_visualize_sample_time_series(x, sample_index=0, num_features=12, num_time_steps=200):
    """
    Interactive line plot of every feature of one sample over its time steps.

    x: 3D numpy array (samples, features, time steps) of observed values. In the output
        of `fix_input_format` missing values are stored as 0, so gaps show up as drops to 0.
    sample_index: index of the sample to visualize (default 0).
    num_features: number of features to plot (default 12); clamped to x.shape[1].
    num_time_steps: number of leading time steps to plot (default 200); each trace is
        sliced to this length so the x-axis always matches the plotted values.
    """
    sample = x[sample_index]
    num_features = min(num_features, sample.shape[0])

    fig = go.Figure()

    for i in range(num_features):
        values = sample[i, :num_time_steps]
        fig.add_trace(go.Scatter(
            x=list(range(len(values))),
            y=values,
            mode='lines',
            name=f'Feature {i+1}'
        ))

    fig.update_layout(
        title=f'Observed Values for Sample {sample_index + 1}',
        xaxis_title='Time Steps',
        yaxis_title='Observed Values',
        hovermode='x unified'
    )

    fig.show()

# Plot the observed values of all features of sample 0 (missing values appear as 0).
interactive_visualize_sample_time_series(x, sample_index=0)



In [23]:
import plotly.graph_objects as go

def interactive_visualize_mask_heatmap(M, sample_index=0, num_features=12, num_time_steps=200):
    """
    Interactive heatmap of the observation mask of one sample.

    M: 3D numpy array (samples, features, time steps) where 1 = observed, 0 = missing.
    sample_index: index of the sample to visualize (default 0).
    num_features: number of features (rows) to show (default 12); clamped to M.shape[1].
    num_time_steps: number of time steps (columns) to show (default 200); clamped to
        M.shape[2]. The matrix is sliced to match so the axis labels line up with it.
    """
    num_features = min(num_features, M.shape[1])
    num_time_steps = min(num_time_steps, M.shape[2])
    sample_mask = M[sample_index, :num_features, :num_time_steps]

    fig = go.Figure(data=go.Heatmap(
        z=sample_mask,
        x=list(range(num_time_steps)),
        y=[f'Feature {i+1}' for i in range(num_features)],
        colorscale='Viridis',
        colorbar=dict(title="Mask Value"),
        showscale=True
    ))

    fig.update_layout(
        title=f'Heatmap of Mask for Sample {sample_index + 1} (1 = Observed, 0 = Missing)',
        xaxis_title='Time Steps',
        yaxis_title='Features',
        hovermode='closest'
    )

    fig.show()

# Show where sample 0 has observed (1) versus missing (0) values.
interactive_visualize_mask_heatmap(M, sample_index=0)



In [24]:
import plotly.graph_objects as go

def interactive_visualize_time_intervals(delta, sample_index=0, num_features=12, num_time_steps=200):
    """
    Interactive plot of the observation times (delta) of one sample.

    As built by `fix_input_format`, ``delta[j, :, k]`` is the time of step k in hours
    since sample j's first observation (0 for padded steps). The values are cumulative
    times, not intervals: the gap between two observations is the rise from one step
    to the next.

    delta: 3D numpy array (samples, features, time steps).
    sample_index: index of the sample to visualize (default 0).
    num_features: number of features to plot (default 12); clamped to delta.shape[1].
    num_time_steps: number of leading time steps to plot (default 200); each trace is
        sliced to this length so the x-axis matches the plotted values.
    """
    sample = delta[sample_index]
    num_features = min(num_features, sample.shape[0])

    fig = go.Figure()

    for i in range(num_features):
        values = sample[i, :num_time_steps]
        fig.add_trace(go.Scatter(
            x=list(range(len(values))),
            y=values,
            mode='lines',
            name=f'Feature {i+1}'
        ))

    fig.update_layout(
        title=f'Observation Times (Hours Since First Observation) for Sample {sample_index + 1}',
        xaxis_title='Time Steps',
        yaxis_title='Time Since First Observation (Hours)',
        hovermode='x unified'
    )

    fig.show()

# Plot the observation times (hours since the first observation) of sample 0.
interactive_visualize_time_intervals(delta, sample_index=0)


In [25]:
# x = np.concatenate((x, M, delta), axis=1)
# print(x.shape)
# y= np.array(label)
# print(y.shape)

In [26]:
# np.savez('preprocessed_5000.npz', array1=x, array2=y)

# Reference

* https://github.com/amazon-science/chronos-forecasting/issues/13