## Visualize

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
plt.rcParams['figure.dpi'] = 200

def visualize(samples: torch.Tensor, figsize=(20, 10), linewidth=1, markersize=5, marker='s', feature_names: list=None) -> None:
    """Draw a grid of per-patient vital-sign panels on a new pyplot figure.

    Each grid row is one patient (taken in order from ``samples``) and has five
    panels: heart rate, respiration, SpO2, mean arterial pressure and the
    mortality label. The figure is left as the current pyplot figure so the
    caller can ``plt.savefig(...)`` / ``plt.close()`` it.

    Args:
        samples: tensor indexed as ``samples[patient, channel, time]``.
            Channels 0, 2, 4, 6 are plotted as measurements; for each of them,
            channel ``idx + 1`` flags time steps to hide (value == 1). Channel 8
            is plotted as mortality. Presumably the flag channels are
            missingness masks -- TODO confirm against the extraction code.
        figsize, linewidth, markersize, marker, feature_names: accepted for
            backward compatibility but unused; the figure size is derived from
            the grid shape (11 x 8 inches per panel).

    Returns:
        None. The plot is drawn on the current pyplot figure.

    Raises:
        ValueError: if ``samples`` contains no patients.
    """
    # .detach().cpu() lets tensors that require grad or live on an accelerator be
    # converted; a bare .numpy() raises on both.
    samples = samples.detach().cpu().numpy()
    n_samples = samples.shape[0]
    if n_samples == 0:
        raise ValueError("visualize() needs at least one sample to plot")

    # One row per patient, but only the first n_samples // 3 patients are drawn
    # (16 of the 50 passed in by this notebook). Clamp to one row so that 1-2
    # samples still produce a figure instead of a zero-row subplot error.
    rows = max(1, n_samples // 3)
    cols = 5
    # squeeze=False keeps ``axs`` two-dimensional even when rows == 1, so the
    # ``axs[i, j]`` indexing below always works.
    _, axs = plt.subplots(rows, cols, figsize=(cols*11, rows*8), squeeze=False)

    # Time axis in hours, assuming one time step every 5 minutes -- TODO confirm.
    x = 5*np.arange(1, samples.shape[-1]+1) / 60
    plot_dict = {0: 'Heart Rate (Beats/Min)', 2: 'Respiration (Breaths/Min)', 4: 'SPO2 (%)', 6: 'Mean Arterial Pressure (mmHg)', 8: 'Mortality (1/0)'}
    mortality_idx = 8

    for i, timeseries in zip(range(rows), samples):
        # A fresh palette per row gives each column the same colour in every row.
        color_choice = iter(sns.color_palette())
        for j in range(cols):
            idx = 2*j
            ax = axs[i, j]
            # Consume one colour per column, even for skipped panels, so later
            # columns keep their colours.
            color = next(color_choice)
            if idx == mortality_idx:
                sns.lineplot(x=x, y=timeseries[idx], ax=ax, linewidth=5, label=plot_dict[idx], color=color)
                ax.set_ylim(-0.3, 1.3)
            else:
                # Steps flagged with 1 in channel idx+1 become NaN so they are not
                # drawn as measurements.
                values = np.where(timeseries[idx+1] == 1, np.nan, timeseries[idx])
                if np.all(np.isnan(values)):
                    # Nothing left to draw: leave this panel empty and unlabeled.
                    continue
                sns.lineplot(x=x, y=values, ax=ax, linewidth=2, label=plot_dict[idx], color=color)
            ax.set_xlabel("Time (Hours)", fontsize=30)
            ax.set_ylabel("Measurement Value", fontsize=30)
            ax.legend(fontsize=30)
            ax.tick_params(axis='both', which='major', labelsize=30)

    # Deliberately no trailing plt.legend(): it would act on the current (last)
    # axes and replace its fontsize-30 legend with a default-size one.
    plt.tight_layout()

## All Patients

In [2]:
from helpers import seed_everything

# All patients, 48-hour extract. Re-seed right before the shuffle so the same
# patients are drawn on every run (assumes seed_everything seeds torch -- TODO confirm).
seed_everything(2023)
data = torch.load('data/eicu-extract/TRAIN-eicu_multiple_60_2880_564.pt')
perm = torch.randperm(data.size(0))
data = data[perm]
visualize(data[:50])
plt.savefig('img/eicu_all_48hrs.png')
plt.close()

Seed set to 2023



In [3]:
from helpers import seed_everything

# All patients, 24-hour extract. Re-seed right before the shuffle so the same
# patients are drawn on every run (assumes seed_everything seeds torch -- TODO confirm).
seed_everything(2023)
data = torch.load('data/eicu-extract/TRAIN-eicu_multiple_60_1440_276.pt')
perm = torch.randperm(data.size(0))
data = data[perm]
visualize(data[:50])
plt.savefig('img/eicu_all_24hrs.png')
plt.close()

Seed set to 2023



## Sepsis/Septicemia

In [12]:
from helpers import seed_everything

# Sepsis cohort, 48-hour extract. Re-seed right before the shuffle so the same
# patients are drawn on every run (assumes seed_everything seeds torch -- TODO confirm).
seed_everything(2023)
data = torch.load('data/eicu-extract/TRAIN-SEPSIS-eicu_multiple_60_2880_564.pt')
perm = torch.randperm(data.size(0))
data = data[perm]
visualize(data[:50])
plt.savefig('img/eicu_sepsis_48hrs.png')
plt.close()

Seed set to 2023



In [13]:
from helpers import seed_everything

# Sepsis cohort, 24-hour extract. Re-seed right before the shuffle so the same
# patients are drawn on every run (assumes seed_everything seeds torch -- TODO confirm).
seed_everything(2023)
data = torch.load('data/eicu-extract/TRAIN-SEPSIS-eicu_multiple_60_1440_276.pt')
perm = torch.randperm(data.size(0))
data = data[perm]
visualize(data[:50])
plt.savefig('img/eicu_sepsis_24hrs.png')
plt.close()

Seed set to 2023



## Acute Myocardial Infarction

In [6]:
from helpers import seed_everything

# Acute myocardial infarction cohort, 48-hour extract. Re-seed right before the
# shuffle so the same patients are drawn on every run (assumes seed_everything
# seeds torch -- TODO confirm).
seed_everything(2023)
data = torch.load('data/eicu-extract/TRAIN-AMI-eicu_multiple_60_2880_564.pt')
perm = torch.randperm(data.size(0))
data = data[perm]
visualize(data[:50])
plt.savefig('img/eicu_ami_48hrs.png')
plt.close()

Seed set to 2023



In [7]:
from helpers import seed_everything

# Acute myocardial infarction cohort, 24-hour extract. Re-seed right before the
# shuffle so the same patients are drawn on every run (assumes seed_everything
# seeds torch -- TODO confirm).
seed_everything(2023)
data = torch.load('data/eicu-extract/TRAIN-AMI-eicu_multiple_60_1440_276.pt')
perm = torch.randperm(data.size(0))
data = data[perm]
visualize(data[:50])
plt.savefig('img/eicu_ami_24hrs.png')
plt.close()

Seed set to 2023



## Acute Kidney Failure

In [8]:
from helpers import seed_everything

# Acute kidney failure cohort, 48-hour extract. Re-seed right before the shuffle
# so the same patients are drawn on every run (assumes seed_everything seeds
# torch -- TODO confirm).
seed_everything(2023)
data = torch.load('data/eicu-extract/TRAIN-AKF-eicu_multiple_60_2880_564.pt')
perm = torch.randperm(data.size(0))
data = data[perm]
visualize(data[:50])
plt.savefig('img/eicu_akf_48hrs.png')
plt.close()

Seed set to 2023



In [9]:
from helpers import seed_everything

# Acute kidney failure cohort, 24-hour extract. Re-seed right before the shuffle
# so the same patients are drawn on every run (assumes seed_everything seeds
# torch -- TODO confirm).
seed_everything(2023)
data = torch.load('data/eicu-extract/TRAIN-AKF-eicu_multiple_60_1440_276.pt')
perm = torch.randperm(data.size(0))
data = data[perm]
visualize(data[:50])
plt.savefig('img/eicu_akf_24hrs.png')
plt.close()

Seed set to 2023



## Heart Failure

In [10]:
from helpers import seed_everything

# Heart failure cohort, 48-hour extract. Re-seed right before the shuffle so the
# same patients are drawn on every run (assumes seed_everything seeds torch --
# TODO confirm).
seed_everything(2023)
data = torch.load('data/eicu-extract/TRAIN-HF-eicu_multiple_60_2880_564.pt')
perm = torch.randperm(data.size(0))
data = data[perm]
visualize(data[:50])
plt.savefig('img/eicu_hf_48hrs.png')
plt.close()

Seed set to 2023



In [11]:
from helpers import seed_everything

# Heart failure cohort, 24-hour extract. Re-seed right before the shuffle so the
# same patients are drawn on every run (assumes seed_everything seeds torch --
# TODO confirm).
seed_everything(2023)
data = torch.load('data/eicu-extract/TRAIN-HF-eicu_multiple_60_1440_276.pt')
perm = torch.randperm(data.size(0))
data = data[perm]
visualize(data[:50])
plt.savefig('img/eicu_hf_24hrs.png')
plt.close()

Seed set to 2023

