In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import json

import matplotlib.pyplot as plt

from torch.utils.tensorboard import SummaryWriter

from models.nbeats import NBeats, Block
from data import OhioData
from models.lstm_seq import LSTMPredictor
from models.nbeats import NBeats

from sklearn.linear_model import LinearRegression

from torch.utils.data import DataLoader

from tqdm import tqdm

In [2]:
# --- Shared window geometry -------------------------------------------------
input_dim = 24    # time steps per input window (sequence axis when batches are reshaped)
n_features = 11   # features per time step
output_dim = 12   # forecast horizon; also part of the checkpoint file names

# --- LSTM predictor hyper-parameters ----------------------------------------
n_layers = 1
hidden_dim = 10

# --- N-BEATS hyper-parameters -----------------------------------------------
n_blocks = 12
amount_fc = 3
hidden_dim2 = 64

# Not referenced by the cells visible in this notebook section.
early_stopping_counter = 10

# Use the first GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [3]:
# Restore both trained base models in inference mode.
# map_location=device remaps the stored tensors onto the selected device, so a
# checkpoint written on a GPU can still be loaded when `device` is "cpu".
lstm_model = LSTMPredictor(input_size=n_features, hidden_size=hidden_dim, num_layers=n_layers, bidirectional=False).to(device)
lstm_model.load_state_dict(torch.load(f"checkpoints/lstm_best_{output_dim}.chkpt", map_location=device))
lstm_model = lstm_model.eval()

nbeats_model = NBeats(n_blocks=n_blocks, input_dim=input_dim, parameter_dim=n_features, output_dim=output_dim, amount_fc=amount_fc, hidden_dim=hidden_dim2).to(device)
nbeats_model.load_state_dict(torch.load(f"checkpoints/nbeats_best_{output_dim}.chkpt", map_location=device))
nbeats_model = nbeats_model.eval()

# train data is only required for scale_max["cbg"] value
train_data = OhioData()
# Validation split with a forecast horizon of `output_dim` steps; batch_size=1
# and shuffle=False give one window per iteration in dataset order.
data = OhioData(mode="validation", h=output_dim)
val_loader = DataLoader(data, batch_size=1, shuffle=False, num_workers=0)

In [4]:
def prepare_data_nbeats(batch):
    """Arrange a batch of windows in the flat layout fed to the N-BEATS model.

    The batch is viewed as (-1, input_dim, n_features), feature columns 0 and 2
    are exchanged, and the result is flattened feature-major: all `input_dim`
    steps of feature 0, then all steps of feature 1, and so on.

    The input tensor is left untouched. (Writing the swap through the reshaped
    view would silently modify the caller's batch and every tensor aliasing it.)

    Args:
        batch: tensor holding a multiple of input_dim * n_features elements.
            Relies on the module-level `input_dim` and `n_features`
            (n_features must be at least 3).

    Returns:
        New tensor of shape (-1, input_dim * n_features).
    """
    windows = batch.reshape(-1, input_dim, n_features)
    # Feature order with positions 0 and 2 exchanged; indexing with a list
    # builds a fresh tensor instead of writing into `batch`.
    feature_order = list(range(n_features))
    feature_order[0], feature_order[2] = feature_order[2], feature_order[0]
    swapped = windows[:, :, feature_order]
    return swapped.permute(0, 2, 1).reshape(-1, input_dim * n_features)

def prepare_data_lstm(batch):
    """Reshape a batch of flat windows into the layout passed to the LSTM.

    Each sample is unfolded to (input_dim, n_features) and the first two axes
    are exchanged, giving a tensor of shape (input_dim, batch_size, n_features).
    Relies on the module-level `input_dim` and `n_features`.
    """
    n_samples = batch.shape[0]
    windows = batch.reshape(n_samples, input_dim, n_features)
    # Exchanging axes 0 and 1 puts the step axis first.
    return windows.transpose(0, 1)

In [5]:
# Build the training set for the stacking regressor from the validation split:
# each row of X is the LSTM forecast followed by the N-BEATS forecast, each
# row of Y the matching target, all rescaled with scale_max["cbg"].
X = []
Y = []

# Loop-invariant scale factor, looked up once.
cbg_scale = train_data.scale_max["cbg"]

with torch.no_grad():
    for x, y, _ in tqdm(val_loader):
        # Follow `device` rather than hard-coding .cuda(), so the cell also
        # runs on machines without a GPU.
        x_lstm = prepare_data_lstm(x).to(device)
        # Hand prepare_data_nbeats its own copy: it may reorder feature columns
        # in place, and on CPU x_lstm can be a view of the same memory as x.
        x_nbeats = prepare_data_nbeats(x.clone()).to(device)
        # Second LSTM argument: feature 2 of the last input step, shaped (1, batch, 1).
        y_lstm = lstm_model(x_lstm, x_lstm[-1].unsqueeze(0)[:, :, 2].unsqueeze(2), teacher_force=False)
        y_nbeats = nbeats_model(x_nbeats)
        y_lstm = y_lstm.squeeze().cpu().numpy()
        y_nbeats = y_nbeats.squeeze().cpu().numpy()
        y_pred = np.concatenate((y_lstm, y_nbeats), axis=0) * cbg_scale
        X.append(y_pred)
        Y.append((y * cbg_scale).squeeze().cpu().numpy())

 14%|█▎        | 3275/23911 [00:12<01:15, 274.37it/s]

In [None]:
# Fit the linear stacking layer that maps the concatenated base-model
# forecasts (X) onto the targets (Y); fit() returns the estimator itself.
reg = LinearRegression().fit(X, Y)
# Score on the same data the regressor was fitted on, so this is not a held-out figure.
reg.score(X, Y)

0.9252092329733997

In [None]:
# Mean squared error, used by the test-set evaluation below.
mse_loss = nn.MSELoss()

# Test split, served one window per batch.
# NOTE(review): the validation set above is created with h=output_dim, but no
# horizon is passed here; confirm OhioData's default matches output_dim.
test_data = OhioData(mode="test")
test_loader = DataLoader(dataset=test_data, batch_size=1, num_workers=0)

In [None]:
# Evaluate the stacked ensemble on the whole test split and report the rMSE
# (square root of the mean per-window MSE) in the original cbg range.
total_loss = 0

# Loop-invariant scale factor, looked up once.
cbg_scale = train_data.scale_max["cbg"]

with torch.no_grad():
    with tqdm(test_loader) as t:
        # n_seen counts the windows processed so far, so the progress bar shows
        # a true running average instead of the partial sum divided by the full length.
        # NOTE(review): the mask returned by the loader is not used here;
        # confirm that masked targets are meant to count towards the loss.
        for n_seen, (x, y, _mask) in enumerate(t, start=1):
            # Follow `device` rather than hard-coding .cuda(), so the cell also
            # runs on machines without a GPU.
            x_lstm = prepare_data_lstm(x).to(device)
            # Hand prepare_data_nbeats its own copy: it may reorder feature columns
            # in place, and on CPU x_lstm can be a view of the same memory as x.
            x_nbeats = prepare_data_nbeats(x.clone()).to(device)
            y_lstm = lstm_model(x_lstm, x_lstm[-1].unsqueeze(0)[:, :, 2].unsqueeze(2), teacher_force=False)
            y_nbeats = nbeats_model(x_nbeats)
            y_lstm = y_lstm.squeeze().cpu().numpy()
            y_nbeats = y_nbeats.squeeze().cpu().numpy()
            base_preds = np.concatenate((y_lstm, y_nbeats), axis=0) * cbg_scale
            # reg.predict expects a 2-D (n_samples, n_features) input.
            y_pred = torch.tensor(reg.predict([base_preds]))
            # scale back to original range, because otherwise the results cannot be compared to others
            y = y * cbg_scale
            loss = mse_loss(y_pred, y)
            total_loss += loss.item()
            t.set_description(f"Test Loss: {np.sqrt(total_loss / n_seen):.2f}, Running Loss: {np.sqrt(loss.item()):.2f}")

print(f"Final rMSE: {np.sqrt(total_loss / len(test_loader))}")

Test Loss: 14.35, Running Loss: 3.27: 100%|██████████| 28426/28426 [02:56<00:00, 160.76it/s]  

Final rMSE: 14.345913227285662





# Evaluate results per patient

In [None]:
# Error metrics for the per-patient evaluation: mean absolute and mean squared error.
mae_loss = nn.L1Loss()
mse_loss = nn.MSELoss()

# Patient IDs that are evaluated one by one.
patient_ids = [
    559, 563, 570, 575, 588, 591,
    540, 544, 552, 567, 584, 596,
]

In [None]:
# Evaluate the stacked ensemble separately for every patient and collect the
# per-patient mean MAE / MSE; the final line averages them over patients.
maes = []
mses = []

# Loop-invariant scale factor, looked up once.
cbg_scale = train_data.scale_max["cbg"]

with torch.no_grad():
    # `patient_id` avoids shadowing the builtin id(); dedicated loader names keep
    # the full-test-set `test_data` / `test_loader` from being overwritten.
    for patient_id in patient_ids:
        total_mae = 0
        total_mse = 0
        patient_data = OhioData(mode="test", patient_id=patient_id)
        patient_loader = DataLoader(patient_data, batch_size=1, num_workers=0)
        n_batches = len(patient_loader)
        with tqdm(patient_loader) as t:
            # n_seen counts the windows processed so far, so the progress bar
            # shows a true running average.
            # NOTE(review): the mask returned by the loader is not used here;
            # confirm that masked targets are meant to count towards the metrics.
            for n_seen, (x, y, _mask) in enumerate(t, start=1):
                # Follow `device` rather than hard-coding .cuda(), so the cell
                # also runs on machines without a GPU.
                x_lstm = prepare_data_lstm(x).to(device)
                # Hand prepare_data_nbeats its own copy: it may reorder feature
                # columns in place, and on CPU x_lstm can be a view of x.
                x_nbeats = prepare_data_nbeats(x.clone()).to(device)
                y_lstm = lstm_model(x_lstm, x_lstm[-1].unsqueeze(0)[:, :, 2].unsqueeze(2), teacher_force=False)
                y_nbeats = nbeats_model(x_nbeats)
                y_lstm = y_lstm.squeeze().cpu().numpy()
                y_nbeats = y_nbeats.squeeze().cpu().numpy()
                base_preds = np.concatenate((y_lstm, y_nbeats), axis=0) * cbg_scale
                # reg.predict expects a 2-D (n_samples, n_features) input.
                y_pred = torch.tensor(reg.predict([base_preds]))
                # scale back to original range, because otherwise the results cannot be compared to others
                y = y * cbg_scale
                mse = mse_loss(y_pred, y)
                mae = mae_loss(y_pred, y)
                total_mae += mae.item()
                total_mse += mse.item()
                t.set_description(f"Patient: {patient_id}, rMSE: {np.sqrt(total_mse / n_seen):.2f}, MAE: {total_mae / n_seen:.2f}")
        print(f"Results - Patient: {patient_id}, rMSE: {np.sqrt(total_mse / n_batches):.2f}, MAE: {total_mae / n_batches:.2f}")
        maes.append(total_mae / n_batches)
        mses.append(total_mse / n_batches)
print(f"Mean results - MAE: {sum(maes) / len(maes)}, rMSE: {np.sqrt(sum(mses) / len(mses))}")

Patient: 559, rMSE: 14.23, MAE: 9.22: 100%|██████████| 2142/2142 [00:13<00:00, 160.68it/s]


Results - Patient: 559, rMSE: 14.23, MAE: 9.22


Patient: 563, rMSE: 13.91, MAE: 9.00: 100%|██████████| 2446/2446 [00:15<00:00, 161.23it/s]


Results - Patient: 563, rMSE: 13.91, MAE: 9.00


Patient: 570, rMSE: 12.27, MAE: 8.27: 100%|██████████| 2435/2435 [00:15<00:00, 161.57it/s]


Results - Patient: 570, rMSE: 12.27, MAE: 8.27


Patient: 575, rMSE: 16.19, MAE: 9.88: 100%|██████████| 2249/2249 [00:13<00:00, 161.24it/s]


Results - Patient: 575, rMSE: 16.19, MAE: 9.88


Patient: 588, rMSE: 13.79, MAE: 9.15: 100%|██████████| 2698/2698 [00:16<00:00, 164.52it/s]


Results - Patient: 588, rMSE: 13.79, MAE: 9.15


Patient: 591, rMSE: 15.49, MAE: 10.22: 100%|██████████| 2605/2605 [00:16<00:00, 162.06it/s]


Results - Patient: 591, rMSE: 15.49, MAE: 10.22


Patient: 540, rMSE: 16.13, MAE: 10.83: 100%|██████████| 2617/2617 [00:16<00:00, 161.91it/s]


Results - Patient: 540, rMSE: 16.13, MAE: 10.83


Patient: 544, rMSE: 13.16, MAE: 8.90: 100%|██████████| 2499/2499 [00:15<00:00, 161.25it/s]


Results - Patient: 544, rMSE: 13.16, MAE: 8.90


Patient: 552, rMSE: 12.16, MAE: 8.40: 100%|██████████| 2023/2023 [00:12<00:00, 163.18it/s]


Results - Patient: 552, rMSE: 12.16, MAE: 8.40


Patient: 567, rMSE: 15.66, MAE: 10.30: 100%|██████████| 2017/2017 [00:12<00:00, 161.59it/s]


Results - Patient: 567, rMSE: 15.66, MAE: 10.30


Patient: 584, rMSE: 15.52, MAE: 10.33: 100%|██████████| 2169/2169 [00:13<00:00, 162.33it/s]


Results - Patient: 584, rMSE: 15.52, MAE: 10.33


Patient: 596, rMSE: 12.84, MAE: 8.45: 100%|██████████| 2526/2526 [00:15<00:00, 160.12it/s]

Results - Patient: 596, rMSE: 12.84, MAE: 8.45
Mean results - MAE: 9.412001328873915, rMSE: 14.350235417323082



