In [1]:
import torch
import pandas as pd
from itertools import groupby, islice
import numpy as np

# bit hacky but passes checks and I don't have time to implement a neater solution
lab_indices = [0, 1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 18, 21, 22, 23, 24, 29, 32, 33, 34, 39, 40, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 60, 62, 63, 67, 68, 69, 70, 71, 72, 75, 83, 84, 86]
labs_to_keep = [0] + [(i + 1) for i in lab_indices] + [(i + 88) for i in lab_indices] + [-1]
no_lab_indices = list(range(87))
no_lab_indices = [x for x in no_lab_indices if x not in lab_indices]
no_labs_to_keep = [0] + [(i + 1) for i in no_lab_indices] + [(i + 88) for i in no_lab_indices] + [-1]


class eICUReader(object):

    def __init__(self, data_path, device=None, labs_only=False, no_labs=False):
        self._diagnoses_path = data_path + '/diagnoses.csv'
        self._labels_path = data_path + '/labels.csv'
        self._flat_path = data_path + '/flat.csv'
        self._timeseries_path = data_path + '/timeseries.csv'
        self._device = device
        self.labs_only = labs_only
        self.no_labs = no_labs
        self._dtype = torch.cuda.FloatTensor if device.type == 'cuda' else torch.FloatTensor

        self.labels = pd.read_csv(self._labels_path, index_col='patient')
        self.flat = pd.read_csv(self._flat_path, index_col='patient')
        self.diagnoses = pd.read_csv(self._diagnoses_path, index_col='patient')

        # we minus 2 to calculate F because hour and time are not features for convolution
        self.F = (pd.read_csv(self._timeseries_path, index_col='patient', nrows=1).shape[1] - 2)//2
        self.D = self.diagnoses.shape[1]
        self.no_flat_features = self.flat.shape[1]

        self.patients = list(self.labels.index)
        self.no_patients = len(self.patients)

    def line_split(self, line):
        return [float(x) for x in line.split(',')]

    def pad_sequences(self, ts_batch):
        seq_lengths = [len(x) for x in ts_batch]
        max_len = max(seq_lengths)
        padded = [patient + [[0] * (self.F * 2 + 2)] * (max_len - len(patient)) for patient in ts_batch]
        if self.labs_only:
            padded = np.array(padded)
            padded = padded[:, :, labs_to_keep]
        if self.no_labs:
            padded = np.array(padded)
            padded = padded[:, :, no_labs_to_keep]
        padded = torch.tensor(padded, device=self._device).type(self._dtype).permute(0, 2, 1)  # B * (2F + 2) * T
        padded[:, 0, :] /= 24  # scale the time into days instead of hours
        mask = torch.zeros(padded[:, 0, :].shape, device=self._device).type(self._dtype)
        for p, l in enumerate(seq_lengths):
            mask[p, :l] = 1
        return padded, mask, torch.tensor(seq_lengths).type(self._dtype)

    def get_los_labels(self, labels, times, mask):
        times = labels.unsqueeze(1).repeat(1, times.shape[1]) - times
        # clamp any labels that are less than 30 mins otherwise it becomes too small when the log is taken
        # make sure where there is no data the label is 0
        return (times.clamp(min=1/48) * mask)

    def get_mort_labels(self, labels, length):
        repeated_labels = labels.unsqueeze(1).repeat(1, length)
        return repeated_labels

    def batch_gen(self, batch_size=8, time_before_pred=5):

        # note that once the generator is finished, the file will be closed automatically
        with open(self._timeseries_path, 'r') as timeseries_file:
            # the first line is the feature names; we have to skip over this
            self.timeseries_header = next(timeseries_file).strip().split(',')
            # this produces a generator that returns a list of batch_size patient identifiers
            patient_batches = (self.patients[pos:pos + batch_size] for pos in range(0, len(self.patients), batch_size))
            # create a generator to capture a single patient timeseries
            ts_patient = groupby(map(self.line_split, timeseries_file), key=lambda line: line[0])
            # we loop through these batches, tracking the index because we need it to index the pandas dataframes
            for i, batch in enumerate(patient_batches):
                ts_batch = [[line[1:] for line in ts] for _, ts in islice(ts_patient, batch_size)]
                padded, mask, seq_lengths = self.pad_sequences(ts_batch)
                los_labels = self.get_los_labels(torch.tensor(self.labels.iloc[i*batch_size:(i+1)*batch_size,7].values, device=self._device).type(self._dtype), padded[:,0,:], mask)
                mort_labels = self.get_mort_labels(torch.tensor(self.labels.iloc[i*batch_size:(i+1)*batch_size,5].values, device=self._device).type(self._dtype), length=mask.shape[1])

                # we must avoid taking data before time_before_pred hours to avoid diagnoses and apache variable from the future
                yield (padded,  # B * (2F + 2) * T
                       mask[:, time_before_pred:],  # B * (T - time_before_pred)
                       torch.tensor(self.diagnoses.iloc[i*batch_size:(i+1)*batch_size].values, device=self._device).type(self._dtype),  # B * D
                       torch.tensor(self.flat.iloc[i*batch_size:(i+1)*batch_size].values.astype(float), device=self._device).type(self._dtype),  # B * no_flat_features
                       los_labels[:, time_before_pred:],
                       mort_labels[:, time_before_pred:],
                       seq_lengths - time_before_pred)

In [2]:
datareader = eICUReader

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:

import os

# Set the path to the directory
data_path = '/content/drive/MyDrive/ads/eicu/train'

# List all files in the directory
files = os.listdir(data_path)
print(files)

['labels.csv', 'stays.txt', 'flat.csv', 'diagnoses.csv', 'timeseries.csv']


In [5]:
import torch
data_path = '/content/drive/MyDrive/ads/eicu/'
# import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
train_datareader = datareader(
                    data_path + 'train',
                    device=device,
                    labs_only=False,
                    no_labs=False
                    )

Using device: cuda


In [6]:
val_datareader = datareader(
                  data_path + 'val',
                  device=device,
                  labs_only=False,
                  no_labs=False
                  )
test_datareader = datareader(
                  data_path + 'test',
                  device=device,
                  labs_only=False,
                  no_labs=False
                  )

In [7]:
train_datareader.no_patients

102577

In [8]:
val_datareader.no_patients

21990

In [9]:
test_datareader.no_patients

22106

In [10]:
def best_global(c):
    c['alpha'] = 100
    if c['dataset'] == 'eICU':
        c['main_dropout_rate'] = 0.45
        c['last_linear_size'] = 17
        c['diagnosis_size'] = 64
        c['batch_norm'] = 'mybatchnorm'
    elif c['dataset'] == 'MIMIC':
        # diagnosis size does not apply for MIMIC since we don't have diagnoses
        c['main_dropout_rate'] = 0
        c['last_linear_size'] = 36
        c['batch_norm'] = 'mybatchnorm'
    return c

def best_tpc(c):
    c = best_global(c)
    c['mode'] = 'test'
    c['model_type'] = 'tpc'
    if c['dataset'] == 'eICU':
        if c['percentage_data'] == 6.25:
            c['n_epochs'] = 8
        elif c['task'] == 'mortality':
            c['n_epochs'] = 6
        else:
            c['n_epochs'] = 15
        c['batch_size'] = 32
        c['n_layers'] = 9
        c['kernel_size'] = 4
        c['no_temp_kernels'] = 12
        c['point_size'] = 13
        c['learning_rate'] = 0.00226
        c['temp_dropout_rate'] = 0.05
        c['temp_kernels'] = [12] * 9 if not c['share_weights'] else [32] * 9
        c['point_sizes'] = [13] * 9
    elif c['dataset'] == 'MIMIC':
        c['no_diag'] = True
        c['n_epochs'] = 10 if c['task'] != 'mortality' else 6
        c['batch_size'] = 8
        c['batch_size_test'] = 8  # purely to keep experiment size small so I can run many in parallel
        c['n_layers'] = 8
        c['kernel_size'] = 5
        c['no_temp_kernels'] = 11
        c['point_size'] = 5
        c['learning_rate'] = 0.00221
        c['temp_dropout_rate'] = 0.05
        c['temp_kernels'] = [11] * 8
        c['point_sizes'] = [5] * 8
    return c

In [11]:
config = {
    # ── general ─────────────────────────────
    "dataset": "eICU",
    "disable_cuda": False,
    "intermediate_reporting": False,
    "batch_size_test": 32,
    "shuffle_train": False,
    "save_results_csv": False,
    "percentage_data": 100.0,
    "task": "LoS",
    "mode": "train",

    # ── loss ────────────────────────────────
    "loss": "hdloss",
    "sum_losses": True,

    # ── ablations / feature flags ───────────
    "labs_only": False,
    "no_mask": False,
    "no_diag": False,    # stays False because dataset is eICU
    "no_labs": False,
    "no_exp": False,

    # ── shared hyper‑parameters ─────────────
    "alpha": 100,
    "main_dropout_rate": 0.45,
    "L2_regularisation": 0.0,
    "last_linear_size": 17,
    "diagnosis_size": 64,
    "batchnorm": "mybatchnorm",

    # ── TPC‑specific hyper‑parameters ───────
    "n_epochs": 15,
    "batch_size": 32,
    "n_layers": 9,
    "kernel_size": 4,
    "no_temp_kernels": 12,
    "point_size": 13,
    "learning_rate": 0.00226,
    "temp_dropout_rate": 0.05,
    "share_weights": False,
    "no_skip_connections": False,

    # ── derived lists (one entry per layer) ─
    "temp_kernels": [12, 12, 12, 12, 12, 12, 12, 12, 12],
    "point_sizes": [13, 13, 13, 13, 13, 13, 13, 13, 13]
}
config['dataset'] = 'eICU'

In [12]:
config = best_tpc(config)

In [13]:
no_train_batches = len(train_datareader.patients) / config["batch_size"]

In [14]:
checkpoint_counter = 0

In [15]:
model = None
optimiser = None

In [16]:
train_datareader.D

293

In [17]:
import torch
import torch.nn as nn
from torch import cat, exp
import torch.nn.functional as F
from torch.nn.functional import pad
from torch.nn.modules.batchnorm import _BatchNorm
from types import SimpleNamespace


###============== The main defining function of the TPC model is temp_pointwise() on line 403 ==============###

import torch
import torch.nn as nn

# Hybrid Loss Function combining MSE, MAE, and Regularization
class HybridLoss(nn.Module):
    def __init__(self, alpha=1.0, beta=1.0, lambda_reg=0.01):
        super(HybridLoss, self).__init__()
        self.mse_loss = nn.MSELoss(reduction='none')  # Mean Squared Error
        self.mae_loss = nn.L1Loss(reduction='none')  # Mean Absolute Error
        self.lambda_reg = lambda_reg  # Regularization term (L2 regularization weight)
        self.alpha = alpha  # Weight for MSE
        self.beta = beta  # Weight for MAE

    def forward(self, y_hat, y, mask, seq_length, sum_losses=False):
        # Masking the predictions and labels where there is no data
        y_hat = y_hat.where(mask, torch.zeros_like(y))
        y = y.where(mask, torch.zeros_like(y))

        # Calculating MSE and MAE
        mse = self.mse_loss(y_hat, y)
        mae = self.mae_loss(y_hat, y)

        # Summing the MSE and MAE losses across the batch
        loss = self.alpha * torch.sum(mse, dim=1) + self.beta * torch.sum(mae, dim=1)

        # Regularization term (L2 penalty on the weights)
        l2_reg = self.lambda_reg * torch.sum(torch.square(y_hat))  # Example of L2 regularization

        # Adding the regularization term to the loss
        loss = loss + l2_reg

        # Normalizing by sequence length to ensure the loss is proportional to sequence length
        if not sum_losses:
            loss = loss / seq_length.clamp(min=1)

        return loss.mean()


# Mean Squared Logarithmic Error (MSLE) loss
class MSLELoss(nn.Module):
    def __init__(self):
        super(MSLELoss, self).__init__()
        self.squared_error = nn.MSELoss(reduction='none')

    def forward(self, y_hat, y, mask, seq_length, sum_losses=False):
        # the log(predictions) corresponding to no data should be set to 0
        log_y_hat = y_hat.log().where(mask, torch.zeros_like(y))
        # the we set the log(labels) that correspond to no data to be 0 as well
        log_y = y.log().where(mask, torch.zeros_like(y))
        # where there is no data log_y_hat = log_y = 0, so the squared error will be 0 in these places
        loss = self.squared_error(log_y_hat, log_y)
        loss = torch.sum(loss, dim=1)
        if not sum_losses:
            loss = loss / seq_length.clamp(min=1)
        return loss.mean()


# Mean Squared Error (MSE) loss
class MSELoss(nn.Module):
    def __init__(self):
        super(MSELoss, self).__init__()
        self.squared_error = nn.MSELoss(reduction='none')

    def forward(self, y_hat, y, mask, seq_length, sum_losses=False):
        # the predictions corresponding to no data should be set to 0
        y_hat = y_hat.where(mask, torch.zeros_like(y))
        # the we set the labels that correspond to no data to be 0 as well
        y = y.where(mask, torch.zeros_like(y))
        # where there is no data log_y_hat = log_y = 0, so the squared error will be 0 in these places
        loss = self.squared_error(y_hat, y)
        loss = torch.sum(loss, dim=1)
        if not sum_losses:
            loss = loss / seq_length.clamp(min=1)
        return loss.mean()


class MyBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(MyBatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input):
        self._check_input_dim(input)

        # hack to work around model.eval() issue
        if not self.training:
            self.eval_momentum = 0  # set the momentum to zero when the model is validating

        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum if self.training else self.eval_momentum

        if self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum if self.training else self.eval_momentum

        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            training=True, momentum=exponential_average_factor, eps=self.eps)  # set training to True so it calculates the norm of the batch


class MyBatchNorm1d(MyBatchNorm):
    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'.format(input.dim()))


class EmptyModule(nn.Module):
    def forward(self, X):
        return X


class TempPointConv(nn.Module):
    def __init__(self, config, F=None, D=None, no_flat_features=None):
        # --------------------------------------------------------------
        # NEW: make both dict‑ and attribute‑style configs work
        # --------------------------------------------------------------
        if isinstance(config, dict):               # user passed a dict
            config = SimpleNamespace(**config)     # wrap for dot access
        # afterwards we can keep calling config.xxx everywhere.
        # --------------------------------------------------------------

        super(TempPointConv, self).__init__()

        # --- copy all hyper‑parameters --------------------------------
        self.task              = config.task
        self.n_layers          = config.n_layers
        self.model_type        = config.model_type
        self.share_weights     = getattr(config, 'share_weights', False)
        self.diagnosis_size    = config.diagnosis_size
        self.main_dropout_rate = config.main_dropout_rate
        self.temp_dropout_rate = config.temp_dropout_rate
        self.kernel_size       = config.kernel_size
        self.temp_kernels      = config.temp_kernels
        self.point_sizes       = config.point_sizes
        self.batchnorm         = config.batchnorm
        self.last_linear_size  = config.last_linear_size
        self.F                 = F
        self.D                 = D
        self.no_flat_features  = no_flat_features
        self.no_diag           = config.no_diag
        self.no_mask           = config.no_mask
        self.no_exp            = config.no_exp
        self.no_skip_connections = config.no_skip_connections
        self.alpha             = config.alpha
        self.momentum          = 0.01 if self.batchnorm == 'low_momentum' else 0.1

        # --- layers & helpers ----------------------------------------
        self.relu      = nn.ReLU()
        self.sigmoid   = nn.Sigmoid()
        self.hardtanh  = nn.Hardtanh(min_val=1/48, max_val=100)
        self.msle_loss = MSLELoss()
        self.mse_loss  = MSELoss()
        self.bce_loss  = nn.BCELoss()
        self.hb_loss   = HybridLoss(alpha=self.alpha, beta=1.0, lambda_reg=0.01)

        self.main_dropout = nn.Dropout(p=self.main_dropout_rate)
        self.temp_dropout = nn.Dropout(p=self.temp_dropout_rate)

        self.remove_none   = lambda x: tuple(xi for xi in x if xi is not None)
        self.empty_module  = EmptyModule()

        if self.batchnorm in ['mybatchnorm', 'pointonly', 'temponly', 'low_momentum']:
            self.batchnormclass = MyBatchNorm1d
        elif self.batchnorm == 'default':
            self.batchnormclass = nn.BatchNorm1d

        # input:  B × D   →  B × diagnosis_size
        self.diagnosis_encoder   = nn.Linear(self.D, self.diagnosis_size)

        if self.batchnorm in ['mybatchnorm', 'pointonly', 'low_momentum', 'default']:
            self.bn_diagnosis_encoder = self.batchnormclass(self.diagnosis_size, momentum=self.momentum)
            self.bn_point_last_los    = self.batchnormclass(self.last_linear_size, momentum=self.momentum)
            self.bn_point_last_mort   = self.batchnormclass(self.last_linear_size, momentum=self.momentum)
        else:
            self.bn_diagnosis_encoder = self.empty_module
            self.bn_point_last_los    = self.empty_module
            self.bn_point_last_mort   = self.empty_module

        # input:  (B × T) × last_linear_size   →   (B × T) × 1
        self.point_final_los  = nn.Linear(self.last_linear_size, 1)
        self.point_final_mort = nn.Linear(self.last_linear_size, 1)

        # --- backbone selection --------------------------------------
        if self.model_type == 'tpc':
            self.init_tpc()
        elif self.model_type == 'temp_only':
            self.init_temp()
        elif self.model_type == 'pointwise_only':
            self.init_pointwise()
        else:
            raise NotImplementedError(
                'model_type must be one of {tpc, temp_only, pointwise_only}'
            )

    def init_tpc(self):

        # non-module layer attributes
        self.layers = []
        for i in range(self.n_layers):
            dilation = i * (self.kernel_size - 1) if i > 0 else 1  # dilation = 1 for the first layer, after that it captures all the information gathered by previous layers
            temp_k = self.temp_kernels[i]
            point_size = self.point_sizes[i]
            self.update_layer_info(layer=i, temp_k=temp_k, point_size=point_size, dilation=dilation, stride=1)

        # module layer attributes
        self.create_temp_pointwise_layers()

        # input shape: (B * T) * ((F + Zt) * (1 + Y) + diagnosis_size + no_flat_features)
        # output shape: (B * T) * last_linear_size
        input_size = (self.F + self.Zt) * (1 + self.Y) + self.diagnosis_size + self.no_flat_features
        if self.no_diag:
            input_size = input_size - self.diagnosis_size
        if self.no_skip_connections:
            input_size = self.F * self.Y + self.Z + self.diagnosis_size + self.no_flat_features
        self.point_last_los = nn.Linear(in_features=input_size, out_features=self.last_linear_size)
        self.point_last_mort = nn.Linear(in_features=input_size, out_features=self.last_linear_size)

        return


    def init_temp(self):

        # non-module layer attributes
        self.layers = []
        for i in range(self.n_layers):
            dilation = i * (self.kernel_size - 1) if i > 0 else 1  # dilation = 1 for the first layer, after that it captures all the information gathered by previous layers
            temp_k = self.temp_kernels[i]
            self.update_layer_info(layer=i, temp_k=temp_k, dilation=dilation, stride=1)

        # module layer attributes
        self.create_temp_only_layers()

        # input shape: (B * T) * (F * (1 + Y) + diagnosis_size + no_flat_features)
        # output shape: (B * T) * last_linear_size
        input_size = self.F * (1 + self.Y) + self.diagnosis_size + self.no_flat_features
        self.point_last_los = nn.Linear(in_features=input_size, out_features=self.last_linear_size)
        self.point_last_mort = nn.Linear(in_features=input_size, out_features=self.last_linear_size)
        return


    def init_pointwise(self):

        # non-module layer attributes
        self.layers = []
        for i in range(self.n_layers):
            point_size = self.point_sizes[i]
            self.update_layer_info(layer=i, point_size=point_size)

        # module layer attributes
        self.create_pointwise_only_layers()

        # input shape: (B * T) * (Zt + 2F + 2 + no_flat_features + diagnosis_size)
        # output shape: (B * T) * last_linear_size
        if self.no_mask:
            input_size = self.Zt + self.F + 2 + self.no_flat_features + self.diagnosis_size
        else:
            input_size = self.Zt + 2 * self.F + 2 + self.no_flat_features + self.diagnosis_size
        self.point_last_los = nn.Linear(in_features=input_size, out_features=self.last_linear_size)
        self.point_last_mort = nn.Linear(in_features=input_size, out_features=self.last_linear_size)

        return


    def update_layer_info(self, layer=None, temp_k=None, point_size=None, dilation=None, stride=None):

        self.layers.append({})
        if point_size is not None:
            self.layers[layer]['point_size'] = point_size
        if temp_k is not None:
            padding = [(self.kernel_size - 1) * dilation, 0]  # [padding_left, padding_right]
            self.layers[layer]['temp_kernels'] = temp_k
            self.layers[layer]['dilation'] = dilation
            self.layers[layer]['padding'] = padding
            self.layers[layer]['stride'] = stride

        return


    def create_temp_pointwise_layers(self):

        ### Notation used for tracking the tensor shapes ###

        # Z is the number of extra features added by the previous pointwise layer (could be 0 if this is the first layer)
        # Zt is the cumulative number of extra features that have been added by all previous pointwise layers
        # Zt-1 = Zt - Z (cumulative number of extra features minus the most recent pointwise layer)
        # Y is the number of channels in the previous temporal layer (could be 0 if this is the first layer)

        self.layer_modules = nn.ModuleDict()

        self.Y = 0
        self.Z = 0
        self.Zt = 0

        for i in range(self.n_layers):

            temp_in_channels = (self.F + self.Zt) * (1 + self.Y) if i > 0 else 2 * self.F  # (F + Zt) * (Y + 1)
            temp_out_channels = (self.F + self.Zt) * self.layers[i]['temp_kernels']  # (F + Zt) * temp_kernels
            linear_input_dim = (self.F + self.Zt - self.Z) * self.Y + self.Z + 2 * self.F + 2 + self.no_flat_features  # (F + Zt-1) * Y + Z + 2F + 2 + no_flat_features
            linear_output_dim = self.layers[i]['point_size']  # point_size
            # correct if no_mask
            if self.no_mask:
                if i == 0:
                    temp_in_channels = self.F
                linear_input_dim = (self.F + self.Zt - self.Z) * self.Y + self.Z + self.F + 2 + self.no_flat_features  # (F + Zt-1) * Y + Z + F + 2 + no_flat_features

            temp = nn.Conv1d(in_channels=temp_in_channels,  # (F + Zt) * (Y + 1)
                             out_channels=temp_out_channels,  # (F + Zt) * Y
                             kernel_size=self.kernel_size,
                             stride=self.layers[i]['stride'],
                             dilation=self.layers[i]['dilation'],
                             groups=self.F + self.Zt)

            point = nn.Linear(in_features=linear_input_dim, out_features=linear_output_dim)

            # correct if no_skip_connections
            if self.no_skip_connections:
                temp_in_channels = self.F * self.Y if i > 0 else 2 * self.F  # F * Y
                temp_out_channels = self.F * self.layers[i]['temp_kernels']  # F * temp_kernels
                #linear_input_dim = self.F * self.Y + self.Z if i > 0 else 2 * self.F + 2 + self.no_flat_features  # (F * Y) + Z
                linear_input_dim = self.Z if i > 0 else 2 * self.F + 2 + self.no_flat_features  # Z
                temp = nn.Conv1d(in_channels=temp_in_channels,
                                 out_channels=temp_out_channels,
                                 kernel_size=self.kernel_size,
                                 stride=self.layers[i]['stride'],
                                 dilation=self.layers[i]['dilation'],
                                 groups=self.F)

                point = nn.Linear(in_features=linear_input_dim, out_features=linear_output_dim)

            if self.batchnorm in ['default', 'mybatchnorm', 'low_momentum']:
                bn_temp = self.batchnormclass(num_features=temp_out_channels, momentum=self.momentum)
                bn_point = self.batchnormclass(num_features=linear_output_dim, momentum=self.momentum)
            elif self.batchnorm == 'temponly':
                bn_temp = self.batchnormclass(num_features=temp_out_channels)
                bn_point = self.empty_module
            elif self.batchnorm == 'pointonly':
                bn_temp = self.empty_module
                bn_point = self.batchnormclass(num_features=linear_output_dim)
            else:
                bn_temp = bn_point = self.empty_module  # linear module; does nothing

            self.layer_modules[str(i)] = nn.ModuleDict({
                'temp': temp,
                'bn_temp': bn_temp,
                'point': point,
                'bn_point': bn_point})

            self.Y = self.layers[i]['temp_kernels']
            self.Z = linear_output_dim
            self.Zt += self.Z

        return


    def create_temp_only_layers(self):

        # Y is the number of channels in the previous temporal layer (could be 0 if this is the first layer)
        self.layer_modules = nn.ModuleDict()
        self.Y = 0

        for i in range(self.n_layers):

            if self.share_weights:
                temp_in_channels = (1 + self.Y) if i > 0 else 2  # (Y + 1)
                temp_out_channels = self.layers[i]['temp_kernels']
                groups = 1
            else:
                temp_in_channels = self.F * (1 + self.Y) if i > 0 else 2 * self.F  # F * (Y + 1)
                temp_out_channels = self.F * self.layers[i]['temp_kernels']  # F * temp_kernels
                groups = self.F

            temp = nn.Conv1d(in_channels=temp_in_channels,
                             out_channels=temp_out_channels,
                             kernel_size=self.kernel_size,
                             stride=self.layers[i]['stride'],
                             dilation=self.layers[i]['dilation'],
                             groups=groups)

            if self.batchnorm in ['default', 'mybatchnorm', 'low_momentum', 'temponly']:
                bn_temp = self.batchnormclass(num_features=temp_out_channels, momentum=self.momentum)
            else:
                bn_temp = self.empty_module  # linear module; does nothing

            self.layer_modules[str(i)] = nn.ModuleDict({
                'temp': temp,
                'bn_temp': bn_temp})

            self.Y = self.layers[i]['temp_kernels']

        return


    def create_pointwise_only_layers(self):

        # Zt is the cumulative number of extra features that have been added by previous pointwise layers
        self.layer_modules = nn.ModuleDict()
        self.Zt = 0

        for i in range(self.n_layers):

            linear_input_dim = self.Zt + 2 * self.F + 2 + self.no_flat_features  # Zt + 2F + 2 + no_flat_features
            linear_output_dim = self.layers[i]['point_size']  # point_size

            if self.no_mask:
                linear_input_dim = self.Zt + self.F + 2 + self.no_flat_features  # Zt + 2F + 2 + no_flat_features

            point = nn.Linear(in_features=linear_input_dim, out_features=linear_output_dim)

            if self.batchnorm in ['default', 'mybatchnorm', 'low_momentum', 'pointonly']:
                bn_point = self.batchnormclass(num_features=linear_output_dim, momentum=self.momentum)
            else:
                bn_point = self.empty_module  # linear module; does nothing

            self.layer_modules[str(i)] = nn.ModuleDict({
                'point': point,
                'bn_point': bn_point})

            self.Zt += linear_output_dim

        return


    # This is really where the crux of TPC is defined. This function defines one TPC layer, as in Figure 3 in the paper:
    # https://arxiv.org/pdf/2007.09483.pdf
    def temp_pointwise(self, B=None, T=None, X=None, repeat_flat=None, X_orig=None, temp=None, bn_temp=None, point=None,
                       bn_point=None, temp_kernels=None, point_size=None, padding=None, prev_temp=None, prev_point=None,
                       point_skip=None):

        ### Notation used for tracking the tensor shapes ###

        # Z is the number of extra features added by the previous pointwise layer (could be 0 if this is the first layer)
        # Zt is the cumulative number of extra features that have been added by all previous pointwise layers
        # Zt-1 = Zt - Z (cumulative number of extra features minus the most recent pointwise layer)
        # Y is the number of channels in the previous temporal layer (could be 0 if this is the first layer)
        # X shape: B * ((F + Zt) * (Y + 1)) * T; N.B exception in the first layer where there are also mask features, in this case it is B * 2F * T
        # repeat_flat shape: (B * T) * no_flat_features
        # X_orig shape: (B * T) * (2F + 2)
        # prev_temp shape: (B * T) * ((F + Zt-1) * (Y + 1))
        # prev_point shape: (B * T) * Z

        Z = prev_point.shape[1] if prev_point is not None else 0

        X_padded = pad(X, padding, 'constant', 0)  # B * ((F + Zt) * (Y + 1)) * (T + padding)
        X_temp = self.temp_dropout(bn_temp(temp(X_padded)))  # B * ((F + Zt) * temp_kernels) * T

        X_concat = cat(self.remove_none((prev_temp,  # (B * T) * ((F + Zt-1) * Y)
                                         prev_point,  # (B * T) * Z
                                         X_orig,  # (B * T) * (2F + 2)
                                         repeat_flat)),  # (B * T) * no_flat_features
                       dim=1)  # (B * T) * (((F + Zt-1) * Y) + Z + 2F + 2 + no_flat_features)

        point_output = self.main_dropout(bn_point(point(X_concat)))  # (B * T) * point_size

        # point_skip input: B * (F + Zt-1) * T
        # prev_point: B * Z * T
        # point_skip output: B * (F + Zt) * T
        point_skip = cat((point_skip, prev_point.view(B, T, Z).permute(0, 2, 1)), dim=1) if prev_point is not None else point_skip

        temp_skip = cat((point_skip.unsqueeze(2),  # B * (F + Zt) * 1 * T
                         X_temp.view(B, point_skip.shape[1], temp_kernels, T)),  # B * (F + Zt) * temp_kernels * T
                        dim=2)  # B * (F + Zt) * (1 + temp_kernels) * T

        X_point_rep = point_output.view(B, T, point_size, 1).permute(0, 2, 3, 1).repeat(1, 1, (1 + temp_kernels), 1)  # B * point_size * (1 + temp_kernels) * T
        X_combined = self.relu(cat((temp_skip, X_point_rep), dim=1))  # B * (F + Zt) * (1 + temp_kernels) * T
        next_X = X_combined.view(B, (point_skip.shape[1] + point_size) * (1 + temp_kernels), T)  # B * ((F + Zt + point_size) * (1 + temp_kernels)) * T

        temp_output = X_temp.permute(0, 2, 1).contiguous().view(B * T, point_skip.shape[1] * temp_kernels)  # (B * T) * ((F + Zt) * temp_kernels)

        return (temp_output,  # (B * T) * ((F + Zt) * temp_kernels)
                point_output,  # (B * T) * point_size
                next_X,  # B * ((F + Zt) * (1 + temp_kernels)) * T
                point_skip)  # for keeping track of the point skip connections; B * (F + Zt) * T


    def temp(self, B=None, T=None, X=None, X_temp_orig=None, temp=None, bn_temp=None, temp_kernels=None, padding=None):

        ### Notation used for tracking the tensor shapes ###

        # Y is the number of channels in the previous temporal layer (could be 0 if this is the first layer)
        # X shape: B * (F * (Y + 1)) * T; N.B exception in the first layer where there are also mask features, in this case it is B * 2F * T
        # X_temp_orig shape: B * F * T

        X_padded = pad(X, padding, 'constant', 0)  # B * (F * (Y + 1)) * (T + padding)

        if self.share_weights:
            _, C, padded_length = X_padded.shape
            chans = int(C / self.F)
            X_temp = self.temp_dropout(bn_temp(temp(X_padded.view(B * self.F, chans, padded_length)))).view(B, (self.F * temp_kernels), T)  # B * (F * temp_kernels) * T
        else:
            X_temp = self.temp_dropout(bn_temp(temp(X_padded)))  # B * (F * temp_kernels) * T

        temp_skip = self.relu(cat((X_temp_orig.unsqueeze(2),  # B * F * 1 * T
                                   X_temp.view(B, self.F, temp_kernels, T)),  # B * F * temp_kernels * T
                                   dim=2))  # B * F * (1 + temp_kernels) * T

        next_X = temp_skip.view(B, (self.F * (1 + temp_kernels)), T)  # B * (F * (1 + temp_kernels)) * T

        return next_X  # B * (F * temp_kernels) * T


    def point(self, B=None, T=None, X=None, repeat_flat=None, X_orig=None, point=None, bn_point=None, point_skip=None):

        ### Notation used for tracking the tensor shapes ###

        # Z is the number of extra features added by the previous pointwise layer (could be 0 if this is the first layer)
        # Zt is the cumulative number of extra features that have been added by all previous pointwise layers
        # Zt-1 = Zt - Z (cumulative number of extra features minus the most recent pointwise layer)
        # X shape: B * (F + Zt) * T; N.B exception in the first layer where there are also mask features, in this case it is B * 2F * T
        # repeat_flat shape: (B * T) * no_flat_features
        # X_orig shape: (B * T) * (2F + 2)
        # prev_point shape: (B * T) * Z

        X_combined = cat((X, repeat_flat), dim=1)

        X_point = self.main_dropout(bn_point(point(X_combined)))  # (B * T) * point_size

        # point_skip input: B * Zt-1 * T
        # prev_point: B * Z * T
        # point_skip output: B * Zt * T
        point_skip = cat(self.remove_none((point_skip, X_point.view(B, T, -1).permute(0, 2, 1))), dim=1)

        # point_skip: B * Zt * T
        # X_orig: (B * T) * (2F + 2)
        # repeat_flat: (B * T) * no_flat_features
        # next_X: (B * T) * (Zt + 2F + 2 + no_flat_features)
        next_X = self.relu(cat((point_skip.permute(0, 2, 1).contiguous().view(B * T, -1), X_orig), dim=1))

        return (next_X,  # (B * T) * (Zt + 2F + 2 + no_flat_features)
                point_skip)  # for keeping track of the pointwise skip connections; B * Zt * T


    def temp_pointwise_no_skip(self, B=None, T=None, temp=None, bn_temp=None, point=None, bn_point=None, padding=None, prev_temp=None,
                               prev_point=None, temp_kernels=None, X_orig=None, repeat_flat=None):

        ### Temporal component ###

        # Y is the number of channels in the previous temporal layer (could be 0 if this is the first layer)
        # prev_temp shape: B * (F * Y) * T; N.B exception in the first layer where there are also mask features, in this case it is B * 2F * T

        X_padded = pad(prev_temp, padding, 'constant', 0)  # B * (F * Y) * (T + padding)
        temp_output = self.relu(self.temp_dropout(bn_temp(temp(X_padded))))  # B * (F * temp_kernels) * T

        ### Pointwise component ###

        # prev_point shape: (B * T) * ((F * Y) + Z)
        point_output = self.relu(self.main_dropout(bn_point(point(prev_point))))  # (B * T) * point_size

        return (temp_output,  # B * (F * temp_kernels) * T
                point_output)  # (B * T) * point_size


    def forward(self, X, diagnoses, flat, time_before_pred=5):
        # flat is B * no_flat_features
        # diagnoses is B * D
        # X is B * (2F + 2) * T
        # X_mask is B * T
        # (the batch is padded to the longest sequence, the + 2 is the time and the hour which are not for temporal convolution)

        # Split the input X into features
        X_separated = torch.split(X[:, 1:-1, :], self.F, dim=1)  # tuple ((B * F * T), (B * F * T))

        # Get batch size, features, and time dimension
        B, _, T = X_separated[0].shape

        # Debug print the shapes of tensors
        # print(f"Shape of X_separated[0]: {X_separated[0].shape}")
        # print(f"Shape of flat: {flat.shape}")
        # print(f"Shape of diagnoses: {diagnoses.shape}")

        if self.model_type in ['pointwise_only', 'tpc']:
            repeat_flat = flat.repeat_interleave(T, dim=0)  # (B * T) * no_flat_features

            if self.no_mask:
                # For no mask case, include time and hour in X_orig (skip first and last time columns)
                X_orig = cat((X_separated[0],
                              X[:, 0, :].unsqueeze(1),
                              X[:, -1, :].unsqueeze(1)), dim=1).permute(0, 2, 1).contiguous().view(B * T, self.F + 2)  # (B * T) * (F + 2)
            else:
                X_orig = X.permute(0, 2, 1).contiguous().view(B * T, 2 * self.F + 2)  # (B * T) * (2F + 2)

            repeat_args = {'repeat_flat': repeat_flat, 'X_orig': X_orig, 'B': B, 'T': T}

            if self.model_type == 'tpc':
                if self.no_mask:
                    next_X = X_separated[0]
                else:
                    next_X = torch.stack(X_separated, dim=2).reshape(B, 2 * self.F, T)  # B * 2F * T
                point_skip = X_separated[0]  # Keeps track of skip connections generated from linear layers; B * F * T
                temp_output = None
                point_output = None
            else:  # pointwise only
                next_X = X_orig
                point_skip = None

        elif self.model_type == 'temp_only':
            next_X = torch.stack(X_separated, dim=2).view(B, 2 * self.F, T)  # B * 2F * T
            X_temp_orig = X_separated[0]  # Skip connections for temp only model
            repeat_args = {'X_temp_orig': X_temp_orig, 'B': B, 'T': T}

        if self.no_skip_connections:
            temp_output = next_X
            point_output = cat((X_orig,  # (B * T) * (2F + 2)
                                repeat_flat),  # (B * T) * no_flat_features
                              dim=1)  # (B * T) * (2F + 2 + no_flat_features)
            self.layer1 = True

        for i in range(self.n_layers):
            kwargs = dict(self.layer_modules[str(i)], **repeat_args)

            if self.model_type == 'tpc':
                if self.no_skip_connections:
                    temp_output, point_output = self.temp_pointwise_no_skip(
                        prev_point=point_output, prev_temp=temp_output,
                        temp_kernels=self.layers[i]['temp_kernels'],
                        padding=self.layers[i]['padding'], **kwargs)
                else:
                    temp_output, point_output, next_X, point_skip = self.temp_pointwise(
                        X=next_X, point_skip=point_skip,
                        prev_temp=temp_output, prev_point=point_output,
                        temp_kernels=self.layers[i]['temp_kernels'],
                        padding=self.layers[i]['padding'],
                        point_size=self.layers[i]['point_size'], **kwargs)
            elif self.model_type == 'temp_only':
                next_X = self.temp(X=next_X, temp_kernels=self.layers[i]['temp_kernels'],
                                  padding=self.layers[i]['padding'], **kwargs)
            elif self.model_type == 'pointwise_only':
                next_X, point_skip = self.point(X=next_X, point_skip=point_skip, **kwargs)

        # Tidy up the outputs
        if self.model_type == 'pointwise_only':
            next_X = next_X.view(B, T, -1).permute(0, 2, 1)
        elif self.no_skip_connections:
            # Combine the final layer
            next_X = cat((point_output,
                          temp_output.permute(0, 2, 1).contiguous().view(B * T, self.F * self.layers[-1]['temp_kernels'])),
                        dim=1)
            next_X = next_X.view(B, T, -1).permute(0, 2, 1)

        # Note: We cut off at time_before_pred hours here because the model is only valid from time_before_pred hours onwards
        if self.no_diag:
            combined_features = cat((flat.repeat_interleave(T - time_before_pred, dim=0),  # (B * (T - time_before_pred)) * no_flat_features
                                    next_X[:, :, time_before_pred:].permute(0, 2, 1).contiguous().view(B * (T - time_before_pred), -1)), dim=1)  # (B * (T - time_before_pred)) * (((F + Zt) * (1 + Y)) + no_flat_features) for tpc
        else:
            diagnoses_enc = self.relu(self.main_dropout(self.bn_diagnosis_encoder(self.diagnosis_encoder(diagnoses))))  # B * diagnosis_size
            combined_features = cat((flat.repeat_interleave(T - time_before_pred, dim=0),  # (B * (T - time_before_pred)) * no_flat_features
                                    diagnoses_enc.repeat_interleave(T - time_before_pred, dim=0),  # (B * (T - time_before_pred)) * diagnosis_size
                                    next_X[:, :, time_before_pred:].permute(0, 2, 1).contiguous().view(B * (T - time_before_pred), -1)), dim=1)  # (B * (T - time_before_pred)) * (((F + Zt) * (1 + Y)) + diagnosis_size + no_flat_features) for tpc

        last_point_los = self.relu(self.main_dropout(self.bn_point_last_los(self.point_last_los(combined_features))))
        last_point_mort = self.relu(self.main_dropout(self.bn_point_last_mort(self.point_last_mort(combined_features))))

        if self.no_exp:
            los_predictions = self.hardtanh(self.point_final_los(last_point_los).view(B, T - time_before_pred))  # B * (T - time_before_pred)
        else:
            los_predictions = self.hardtanh(exp(self.point_final_los(last_point_los).view(B, T - time_before_pred)))  # B * (T - time_before_pred)
        mort_predictions = self.sigmoid(self.point_final_mort(last_point_mort).view(B, T - time_before_pred))  # B * (T - time_before_pred)

        return los_predictions, mort_predictions



    def temp_pointwise_no_skip_old(self, B=None, T=None, temp=None, bn_temp=None, point=None, bn_point=None, padding=None, prev_temp=None,
                               prev_point=None, temp_kernels=None, X_orig=None, repeat_flat=None):

        ### Temporal component ###

        # Y is the number of channels in the previous temporal layer (could be 0 if this is the first layer)
        # prev_temp shape: B * (F * Y) * T; N.B exception in the first layer where there are also mask features, in this case it is B * 2F * T

        X_padded = pad(prev_temp, padding, 'constant', 0)  # B * (F * Y) * (T + padding)
        temp_output = self.relu(self.temp_dropout(bn_temp(temp(X_padded))))  # B * (F * temp_kernels) * T

        ### Pointwise component ###

        # prev_point shape: (B * T) * ((F * Y) + Z)

        # if this is not layer 1:
        if self.layer1:
            X_concat = prev_point
            self.layer1 = False
        else:
            X_concat = cat((prev_point,
                            prev_temp.permute(0, 2, 1).contiguous().view(B * T, self.F * temp_kernels)),
                           dim=1)

        point_output = self.relu(self.main_dropout(bn_point(point(X_concat))))  # (B * T) * point_size

        return (temp_output,  # B * (F * temp_kernels) * T
                point_output)  # (B * T) * point_size


    def loss(self, y_hat_los, y_hat_mort, y_los, y_mort, mask, seq_lengths, device, sum_losses, loss_type):
        # mort loss
        if self.task == 'mortality':
            loss = self.bce_loss(y_hat_mort, y_mort) * self.alpha
        # los loss
        else:
            bool_type = torch.cuda.BoolTensor if device == torch.device('cuda') else torch.BoolTensor
            if loss_type == 'msle':
                los_loss = self.msle_loss(y_hat_los, y_los, mask.type(bool_type), seq_lengths, sum_losses)
            elif loss_type == 'mse':
                los_loss = self.mse_loss(y_hat_los, y_los, mask.type(bool_type), seq_lengths, sum_losses)
            elif loss_type == 'hdloss':  # Add this condition to select the HybridLoss
              los_loss = self.hb_loss(y_hat_los, y_los, mask.type(bool_type), seq_lengths, sum_losses)
            if self.task == 'LoS':
                loss = los_loss
            # multitask loss
            if self.task == 'multitask':
                loss = los_loss + self.bce_loss(y_hat_mort, y_mort) * self.alpha
        return loss

In [18]:
model = TempPointConv(
    config=config,
    F=train_datareader.F,
    D=train_datareader.D,
    no_flat_features= train_datareader.no_flat_features
    ).to(device=device)

In [19]:
model.train()

TempPointConv(
  (relu): ReLU()
  (sigmoid): Sigmoid()
  (hardtanh): Hardtanh(min_val=0.020833333333333332, max_val=100)
  (msle_loss): MSLELoss(
    (squared_error): MSELoss()
  )
  (mse_loss): MSELoss(
    (squared_error): MSELoss()
  )
  (bce_loss): BCELoss()
  (hb_loss): HybridLoss(
    (mse_loss): MSELoss()
    (mae_loss): L1Loss()
  )
  (main_dropout): Dropout(p=0.45, inplace=False)
  (temp_dropout): Dropout(p=0.05, inplace=False)
  (empty_module): EmptyModule()
  (diagnosis_encoder): Linear(in_features=293, out_features=64, bias=True)
  (bn_diagnosis_encoder): MyBatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn_point_last_los): MyBatchNorm1d(17, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn_point_last_mort): MyBatchNorm1d(17, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (point_final_los): Linear(in_features=17, out_features=1, bias=True)
  (point_final_mort): Linear(in_features=17, out_features=1, b

In [20]:
train_batches = train_datareader.batch_gen(batch_size=config["batch_size"])

In [21]:
import torch
import torch.nn as nn

# Mean Squared Error for Length of Stay (LoS) prediction (Regression Task)
criterion_los = nn.MSELoss()

# Binary Cross-Entropy with Logits for Mortality prediction (Binary Classification Task)
criterion_mort = nn.BCEWithLogitsLoss()

import torch.optim as optim

# Define Adam optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

import numpy as np

def train(model, train_datareader, optimizer, criterion_los, criterion_mort, device):
    model.train()
    running_loss_los = 0.0
    running_loss_mort = 0.0
    num_batches = 0  # To count the number of batches

    # Iterate directly over the generator produced by batch_gen
    for batch in train_datareader.batch_gen(batch_size=config['batch_size']):
        try:
            padded, mask, diagnoses, flat, los_labels, mort_labels, seq_lengths = batch  # Unpack correctly

            # Move data to the device (GPU/CPU)
            padded, los_labels, mort_labels = padded.to(device), los_labels.to(device), mort_labels.to(device)
            flat, diagnoses = flat.to(device), diagnoses.to(device)

            optimizer.zero_grad()

            # Forward pass
            los_pred, mort_pred = model(padded, diagnoses, flat)

            # Calculate loss
            loss_los = criterion_los(los_pred, los_labels)
            loss_mort = criterion_mort(mort_pred, mort_labels)

            # Total loss
            loss = loss_los + loss_mort

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            running_loss_los += loss_los.item()
            running_loss_mort += loss_mort.item()
            num_batches += 1  # Increment the batch counter
            # print(f'Batch {num_batches}: LoS Loss: {loss_los.item():.4f}, Mortality Loss: {loss_mort.item():.4f}')

        except Exception as e:
            # print(f"Skipping batch due to error: {e}")
            continue  # Skip this batch and proceed to the next one

    # Calculate the average loss
    avg_loss_los = running_loss_los / num_batches if num_batches > 0 else 0
    avg_loss_mort = running_loss_mort / num_batches if num_batches > 0 else 0

    return avg_loss_los, avg_loss_mort


# Example of training for one epoch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# for epoch in range(10):  # Run for 10 epochs as an example
train_loss_los, train_loss_mort = train(model, train_datareader, optimizer, criterion_los, criterion_mort, device)
print(f' LoS Loss: {train_loss_los:.4f}, Mortality Loss: {train_loss_mort:.4f}')


 LoS Loss: 7.1211, Mortality Loss: 0.7035


In [22]:
import torch
import numpy as np
from sklearn.metrics import cohen_kappa_score
import pandas as pd
from sklearn.metrics import mean_absolute_percentage_error, mean_squared_error, r2_score

# Store per-batch metrics
batch_metrics = []

# Iterate through the validation data
for batch_idx, batch in enumerate(val_datareader.batch_gen(batch_size=config['batch_size'])):
    padded, mask, diagnoses, flat, los_labels, mort_labels, seq_lengths = batch

    # Move data to device (GPU/CPU)
    padded, los_labels, mort_labels = padded.to(device), los_labels.to(device), mort_labels.to(device)
    flat, diagnoses = flat.to(device), diagnoses.to(device)

    try:
        # Forward pass: Get model predictions
        los_pred, mort_pred = model(padded, diagnoses, flat)

        # Ensure binary predictions (0 or 1) by thresholding the predicted values
        mort_pred_binary = (mort_pred.detach().cpu().numpy() > 0.5).astype(int)  # Apply threshold to get binary (0 or 1)

        # Ensure labels are binary (0 or 1)
        mort_labels_binary = mort_labels.detach().cpu().numpy().astype(int)

        # Flatten if the predictions or labels are not 1D arrays
        if len(mort_labels_binary.shape) > 1:
            mort_labels_binary = mort_labels_binary.flatten()

        if len(mort_pred_binary.shape) > 1:
            mort_pred_binary = mort_pred_binary.flatten()

        # Calculate batch-wise metrics for Mortality
        mort_accuracy = np.mean((mort_labels_binary == mort_pred_binary))  # Binary accuracy
        kappa_mort = cohen_kappa_score(mort_labels_binary, mort_pred_binary)  # Cohen's Kappa

        # Calculate batch-wise metrics for LoS (Length of Stay) regression
        mad_los = np.mean(np.abs(los_labels.detach().cpu().numpy() - los_pred.detach().cpu().numpy()))  # Mean Absolute Deviation (MAD)
        mape_los = mean_absolute_percentage_error(los_labels.detach().cpu().numpy(), los_pred.detach().cpu().numpy())  # Mean Absolute Percentage Error (MAPE)
        mse_los = mean_squared_error(los_labels.detach().cpu().numpy(), los_pred.detach().cpu().numpy())  # Mean Squared Error (MSE)
        msle_los = mean_squared_error(np.log1p(los_labels.detach().cpu().numpy()), np.log1p(los_pred.detach().cpu().numpy()))  # Mean Squared Logarithmic Error (MSLE)
        r2_los = r2_score(los_labels.detach().cpu().numpy(), los_pred.detach().cpu().numpy())  # R-squared (R2)

        # Save metrics for the current batch
        batch_metrics.append({
            'Batch': batch_idx,
            'MAD (LoS)': mad_los,
            'MAPE (LoS)': mape_los,
            'MSE (LoS)': mse_los,
            'MSLE (LoS)': msle_los,
            'R2 (LoS)': r2_los,
            'Accuracy (Mortality)': mort_accuracy,
            'Kappa (Mortality)': kappa_mort
        })

        # # Print batch-wise metrics
        # print(f"Batch {batch_idx + 1}:")
        # print(f"  LoS MAD: {mad_los:.4f}, MAPE: {mape_los:.4f}, MSE: {mse_los:.4f}, MSLE: {msle_los:.4f}, R2: {r2_los:.4f}")
        # print(f"  Mortality Accuracy: {mort_accuracy:.4f}, Kappa: {kappa_mort:.4f}\n")

    except RuntimeError as e:
        # print(f"Skipping batch {batch_idx + 1} due to error: {e}")
        continue  # Skip this batch if there's an error (e.g., mismatched tensor shapes)

# Convert batch-wise metrics to DataFrame
batch_metrics_df = pd.DataFrame(batch_metrics)

# Aggregate metrics across all batches (e.g., taking the mean)
aggregated_metrics = {
    'MAD (LoS)': batch_metrics_df['MAD (LoS)'].mean(),
    'MSE (LoS)': batch_metrics_df['MSE (LoS)'].mean(),
    'MSLE (LoS)': batch_metrics_df['MSLE (LoS)'].mean(),
    'R2 (LoS)': batch_metrics_df['R2 (LoS)'].mean(),
    'Accuracy (Mortality)': batch_metrics_df['Accuracy (Mortality)'].mean(),
    'Kappa (Mortality)': batch_metrics_df['Kappa (Mortality)'].mean()
}


# Optionally, display batch-wise metrics as well
print(aggregated_metrics)

  k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)
  k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)
  k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)
  k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)
  k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)
  k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)
  k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)
  k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)
  k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)
  k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)
  k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)
  k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)
  k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)
  k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)
  k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)
  k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)
  k = np.sum(w_mat * confusion) / np.sum(w_mat * expecte

{'MAD (LoS)': np.float32(0.64354044), 'MSE (LoS)': np.float64(10.929902297026382), 'MSLE (LoS)': np.float64(0.12085423469097314), 'R2 (LoS)': np.float64(0.05843460618039), 'Accuracy (Mortality)': np.float64(0.9389165419161677), 'Kappa (Mortality)': np.float64(0.3386)}


USE NEW LOSS FUNCTION

In [23]:
config["loss"] = "hdloss"

In [24]:
model = TempPointConv(
    config=config,
    F=train_datareader.F,
    D=train_datareader.D,
    no_flat_features= train_datareader.no_flat_features
    ).to(device=device)

In [25]:
model.train()

TempPointConv(
  (relu): ReLU()
  (sigmoid): Sigmoid()
  (hardtanh): Hardtanh(min_val=0.020833333333333332, max_val=100)
  (msle_loss): MSLELoss(
    (squared_error): MSELoss()
  )
  (mse_loss): MSELoss(
    (squared_error): MSELoss()
  )
  (bce_loss): BCELoss()
  (hb_loss): HybridLoss(
    (mse_loss): MSELoss()
    (mae_loss): L1Loss()
  )
  (main_dropout): Dropout(p=0.45, inplace=False)
  (temp_dropout): Dropout(p=0.05, inplace=False)
  (empty_module): EmptyModule()
  (diagnosis_encoder): Linear(in_features=293, out_features=64, bias=True)
  (bn_diagnosis_encoder): MyBatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn_point_last_los): MyBatchNorm1d(17, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn_point_last_mort): MyBatchNorm1d(17, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (point_final_los): Linear(in_features=17, out_features=1, bias=True)
  (point_final_mort): Linear(in_features=17, out_features=1, b

In [26]:
train_batches = train_datareader.batch_gen(batch_size=config["batch_size"])

In [27]:
import torch
import torch.nn as nn

# Mean Squared Error for Length of Stay (LoS) prediction (Regression Task)
criterion_los = nn.MSELoss()

# Binary Cross-Entropy with Logits for Mortality prediction (Binary Classification Task)
criterion_mort = nn.BCEWithLogitsLoss()

import torch.optim as optim

# Define Adam optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

import numpy as np

def train(model, train_datareader, optimizer, criterion_los, criterion_mort, device):
    model.train()
    running_loss_los = 0.0
    running_loss_mort = 0.0
    num_batches = 0  # To count the number of batches

    # Iterate directly over the generator produced by batch_gen
    for batch in train_datareader.batch_gen(batch_size=config['batch_size']):
        try:
            padded, mask, diagnoses, flat, los_labels, mort_labels, seq_lengths = batch  # Unpack correctly

            # Move data to the device (GPU/CPU)
            padded, los_labels, mort_labels = padded.to(device), los_labels.to(device), mort_labels.to(device)
            flat, diagnoses = flat.to(device), diagnoses.to(device)

            optimizer.zero_grad()

            # Forward pass
            los_pred, mort_pred = model(padded, diagnoses, flat)

            # Calculate loss
            loss_los = criterion_los(los_pred, los_labels)
            loss_mort = criterion_mort(mort_pred, mort_labels)

            # Total loss
            loss = loss_los + loss_mort

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            running_loss_los += loss_los.item()
            running_loss_mort += loss_mort.item()
            num_batches += 1  # Increment the batch counter
            # print(f'Batch {num_batches}: LoS Loss: {loss_los.item():.4f}, Mortality Loss: {loss_mort.item():.4f}')

        except Exception as e:
            # print(f"Skipping batch due to error: {e}")
            continue  # Skip this batch and proceed to the next one

    # Calculate the average loss
    avg_loss_los = running_loss_los / num_batches if num_batches > 0 else 0
    avg_loss_mort = running_loss_mort / num_batches if num_batches > 0 else 0

    return avg_loss_los, avg_loss_mort


# Example of training for one epoch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# for epoch in range(10):  # Run for 10 epochs as an example
train_loss_los, train_loss_mort = train(model, train_datareader, optimizer, criterion_los, criterion_mort, device)
print(f' LoS Loss: {train_loss_los:.4f}, Mortality Loss: {train_loss_mort:.4f}')


 LoS Loss: 7.1269, Mortality Loss: 0.7014


In [28]:
import torch
import numpy as np
from sklearn.metrics import cohen_kappa_score
import pandas as pd
from sklearn.metrics import mean_absolute_percentage_error, mean_squared_error, r2_score

# Store per-batch metrics
batch_metrics = []

# Iterate through the validation data
for batch_idx, batch in enumerate(val_datareader.batch_gen(batch_size=config['batch_size'])):
    padded, mask, diagnoses, flat, los_labels, mort_labels, seq_lengths = batch

    # Move data to device (GPU/CPU)
    padded, los_labels, mort_labels = padded.to(device), los_labels.to(device), mort_labels.to(device)
    flat, diagnoses = flat.to(device), diagnoses.to(device)

    try:
        # Forward pass: Get model predictions
        los_pred, mort_pred = model(padded, diagnoses, flat)

        # Ensure binary predictions (0 or 1) by thresholding the predicted values
        mort_pred_binary = (mort_pred.detach().cpu().numpy() > 0.5).astype(int)  # Apply threshold to get binary (0 or 1)

        # Ensure labels are binary (0 or 1)
        mort_labels_binary = mort_labels.detach().cpu().numpy().astype(int)

        # Flatten if the predictions or labels are not 1D arrays
        if len(mort_labels_binary.shape) > 1:
            mort_labels_binary = mort_labels_binary.flatten()

        if len(mort_pred_binary.shape) > 1:
            mort_pred_binary = mort_pred_binary.flatten()

        # Calculate batch-wise metrics for Mortality
        mort_accuracy = np.mean((mort_labels_binary == mort_pred_binary))  # Binary accuracy
        kappa_mort = cohen_kappa_score(mort_labels_binary, mort_pred_binary)  # Cohen's Kappa

        # Calculate batch-wise metrics for LoS (Length of Stay) regression
        mad_los = np.mean(np.abs(los_labels.detach().cpu().numpy() - los_pred.detach().cpu().numpy()))  # Mean Absolute Deviation (MAD)
        mape_los = mean_absolute_percentage_error(los_labels.detach().cpu().numpy(), los_pred.detach().cpu().numpy())  # Mean Absolute Percentage Error (MAPE)
        mse_los = mean_squared_error(los_labels.detach().cpu().numpy(), los_pred.detach().cpu().numpy())  # Mean Squared Error (MSE)
        msle_los = mean_squared_error(np.log1p(los_labels.detach().cpu().numpy()), np.log1p(los_pred.detach().cpu().numpy()))  # Mean Squared Logarithmic Error (MSLE)
        r2_los = r2_score(los_labels.detach().cpu().numpy(), los_pred.detach().cpu().numpy())  # R-squared (R2)

        # Save metrics for the current batch
        batch_metrics.append({
            'Batch': batch_idx,
            'MAD (LoS)': mad_los,
            'MAPE (LoS)': mape_los,
            'MSE (LoS)': mse_los,
            'MSLE (LoS)': msle_los,
            'R2 (LoS)': r2_los,
            'Accuracy (Mortality)': mort_accuracy,
            'Kappa (Mortality)': kappa_mort
        })

        # # Print batch-wise metrics
        # print(f"Batch {batch_idx + 1}:")
        # print(f"  LoS MAD: {mad_los:.4f}, MAPE: {mape_los:.4f}, MSE: {mse_los:.4f}, MSLE: {msle_los:.4f}, R2: {r2_los:.4f}")
        # print(f"  Mortality Accuracy: {mort_accuracy:.4f}, Kappa: {kappa_mort:.4f}\n")

    except RuntimeError as e:
        # print(f"Skipping batch {batch_idx + 1} due to error: {e}")
        continue  # Skip this batch if there's an error (e.g., mismatched tensor shapes)

# Convert batch-wise metrics to DataFrame
batch_metrics_df = pd.DataFrame(batch_metrics)

# Aggregate metrics across all batches (e.g., taking the mean)
aggregated_metrics = {
    'MAD (LoS)': batch_metrics_df['MAD (LoS)'].mean(),
    'MSE (LoS)': batch_metrics_df['MSE (LoS)'].mean(),
    'MSLE (LoS)': batch_metrics_df['MSLE (LoS)'].mean(),
    'R2 (LoS)': batch_metrics_df['R2 (LoS)'].mean(),
    'Accuracy (Mortality)': batch_metrics_df['Accuracy (Mortality)'].mean(),
    'Kappa (Mortality)': batch_metrics_df['Kappa (Mortality)'].mean()
}


# Optionally, display batch-wise metrics as well
print(aggregated_metrics)

{'MAD (LoS)': np.float32(0.6363913), 'MSE (LoS)': np.float64(10.93475916633706), 'MSLE (LoS)': np.float64(0.11792143061757088), 'R2 (LoS)': np.float64(0.058985815296215), 'Accuracy (Mortality)': np.float64(0.9175082806602836), 'Kappa (Mortality)': np.float64(0.38759203387716344)}
