# Data Preparation

## installing the necessary libraries.

In [1]:
!pip install scikit-learn==1.0.2
!pip install scikit-survival

Collecting scikit-learn==1.0.2
  Using cached scikit_learn-1.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Using cached scikit_learn-1.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.5 MB)
Installing collected packages: scikit-learn
  Attempting uninstall: scikit-learn
    Found existing installation: scikit-learn 1.5.2
    Uninstalling scikit-learn-1.5.2:
      Successfully uninstalled scikit-learn-1.5.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 1.17.0 requires scikit-learn>=1.2.2, but you have scikit-learn 1.0.2 which is incompatible.
scikit-survival 0.23.0 requires scikit-learn<1.6,>=1.4.0, but you have scikit-learn 1.0.2 which is incompatible.[0m[31m
[0mSuccessfully installed scikit-learn-1.0.2
Collecting scikit-learn<1.6,>=1.4.0 (from scikit-survival)
  Using cached scikit_learn-1

In [28]:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sksurv.util import Surv
from sksurv.metrics import concordance_index_censored
from sklearn.metrics import accuracy_score, roc_curve, auc, confusion_matrix
from torch.utils.data import DataLoader, TensorDataset
import torch
import torch.nn as nn
import torch.optim as optim



## Loading the dataset from an Excel file and performing necessary preprocessing

In [29]:
# data = pd.read_excel('4A_TB_cleaned.xlsx')
data = pd.read_csv('/content/TB_Analysis (1).csv')
data.head()

Unnamed: 0,UNIT,IP NUMBER,SEX,AGE,RELIGION,DISTRICT OF RESIDENCE,WORKING DIAGNOSIS,OUT COME,NUMBER OF PREVIOUS ADMISSIONS,TB_type,MARITAL STATUS,MONTH OF ADMISSION,DAY OF ADMISSION,YEAR OF ADMISSION,DURATION,REGION,ADDITIONAL WORKING DIAGNOSIS,NUMBER OF ADDITIONAL WORKING DIAGNOSIS
0,GI,1861416,M,30,COU,Nakasongola,"ISS,Abdominal TB,,",IMPROVED,5,Abdominal_TB,2,November,Tuesday,2010,6,CENTRAL,"ISS,,",1
1,GI,1868011,F,46,COU,Kampala,"ISS Stage I&II,Disseminated TB,,",IMPROVED,3,Disseminated TB,2,December,Monday,2010,4,CENTRAL,"ISS Stage I&II,,",1
2,GI,1868840,M,80,ISLAM,Wakiso,"myeloproliferat,Malaria,INTESTINAL TB,",IMPROVED,7,Abdominal_TB,4,December,Wednesday,2010,8,CENTRAL,"myeloproliferat,Malaria,",2
3,GI,1851080,M,22,CATHOLIC,KAMPALA,"Abd TB,,,",IMPROVED,8,Abdominal_TB,9,December,Monday,2010,9,CENTRAL,",,",0
4,GI,1870478,F,18,COU,KAMPALA,"PTB,,,",IMPROVED,4,PTB,1,January,Thursday,2011,5,CENTRAL,",,",0


In [30]:
# function to calculate the survival status
def calculate_survival_status(row):
    if row['OUT COME'] == 'DIED':
        return 1
    else:
        return 0


data['SURVIVAL'] = data.apply(calculate_survival_status, axis=1)

In [31]:

# Filter the data
data = data[~((data['OUT COME'] != 'DIED') & (data['DURATION'] < 7))]


In [32]:


filtered_data = data[((data['OUT COME'] != 'DIED') & (data['DURATION'] < 7))]
if len(filtered_data) > 0:
  print("There are rows with 'OUT COME' != 'DIED' and duration less than 28 days.")
else:
  print("No rows found with 'OUT COME' != 'DIED' and duration less than 28 days.")


No rows found with 'OUT COME' != 'DIED' and duration less than 28 days.


In [33]:


data.drop(columns=['UNIT','DISTRICT OF RESIDENCE','WORKING DIAGNOSIS','ADDITIONAL WORKING DIAGNOSIS',
                   'MARITAL STATUS', 'IP NUMBER', 'OUT COME'], inplace=True)

new_col_order = ['SEX', 'AGE', 'RELIGION', 'REGION', 'NUMBER OF ADDITIONAL WORKING DIAGNOSIS', 'TB_type',
                 'NUMBER OF PREVIOUS ADMISSIONS', 'DAY OF ADMISSION', 'MONTH OF ADMISSION',
                 'YEAR OF ADMISSION', 'DURATION', 'SURVIVAL']
data = data[new_col_order]

data.dropna(subset=['SURVIVAL'], inplace=True)
data.reset_index(drop=True, inplace=True)
data['SURVIVAL'] = data['SURVIVAL'].astype(int)

data.shape

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data.drop(columns=['UNIT','DISTRICT OF RESIDENCE','WORKING DIAGNOSIS','ADDITIONAL WORKING DIAGNOSIS',


(2831, 12)

In [34]:
data['SURVIVAL'].value_counts()

Unnamed: 0_level_0,count
SURVIVAL,Unnamed: 1_level_1
0,1520
1,1311


In [35]:
data.head()

Unnamed: 0,SEX,AGE,RELIGION,REGION,NUMBER OF ADDITIONAL WORKING DIAGNOSIS,TB_type,NUMBER OF PREVIOUS ADMISSIONS,DAY OF ADMISSION,MONTH OF ADMISSION,YEAR OF ADMISSION,DURATION,SURVIVAL
0,M,80,ISLAM,CENTRAL,2,Abdominal_TB,7,Wednesday,December,2010,8,0
1,M,22,CATHOLIC,CENTRAL,0,Abdominal_TB,8,Monday,December,2010,9,0
2,F,24,CATHOLIC,CENTRAL,1,Disseminated TB,2,Wednesday,January,2011,3,1
3,F,18,CATHOLIC,CENTRAL,2,PTB,4,Sunday,January,2011,5,1
4,M,57,COU,CENTRAL,2,Abdominal_TB,6,Thursday,January,2011,7,0


In [36]:
data.isnull().sum()


Unnamed: 0,0
SEX,0
AGE,0
RELIGION,0
REGION,0
NUMBER OF ADDITIONAL WORKING DIAGNOSIS,0
TB_type,0
NUMBER OF PREVIOUS ADMISSIONS,0
DAY OF ADMISSION,0
MONTH OF ADMISSION,0
YEAR OF ADMISSION,0


## Encoding Categorical Features and Standardizing Continuous Features

In [37]:


# dictionary for month encoding
month_mapping = {'January': 1, 'February': 2, 'March': 3, 'April': 4, 'May': 5, 'June': 6,
                 'July': 7, 'August': 8, 'September': 9, 'October': 10, 'November': 11, 'December': 12}

# dictionary for day of week encoding
day_mapping = {'Sunday': 1, 'Monday': 2, 'Tuesday': 3, 'Wednesday': 4, 'Thursday': 5, 'Friday': 6, 'Saturday':76 }

data['Month of Admission'] = data['MONTH OF ADMISSION'].map(month_mapping)
data['Day of Admission'] = data['DAY OF ADMISSION'].map(day_mapping)

data.drop(columns=['MONTH OF ADMISSION', 'DAY OF ADMISSION'], inplace=True)



categorical_features = ['SEX', 'RELIGION', 'REGION', 'TB_type']

dataset = pd.get_dummies(data, columns=categorical_features)
dataset = dataset.astype(int)
dataset.head()

#Standardizing continuous features
standardize_features = ['AGE','NUMBER OF PREVIOUS ADMISSIONS','NUMBER OF ADDITIONAL WORKING DIAGNOSIS', 'DURATION', 'Month of Admission', 'Day of Admission']
scaler = StandardScaler()
dataset[standardize_features] = scaler.fit_transform(dataset[standardize_features])
dataset.head()

Unnamed: 0,AGE,NUMBER OF ADDITIONAL WORKING DIAGNOSIS,NUMBER OF PREVIOUS ADMISSIONS,YEAR OF ADMISSION,DURATION,SURVIVAL,Month of Admission,Day of Admission,SEX_F,SEX_M,...,REGION_EASTERN,REGION_NORTHERN,REGION_UNKNOWN,REGION_WESTERN,TB_type_Abdominal_TB,TB_type_CNS_TB,TB_type_Disseminated TB,TB_type_Other TBs,TB_type_PTB,TB_type_TB iris
0,4.120556,0.544308,1.054167,2010,-0.257499,0,1.599346,-0.307344,0,1,...,0,0,0,0,1,0,0,0,0,0
1,-1.221118,-2.034664,1.258547,2010,-0.142409,0,1.599346,-0.401397,0,1,...,0,0,0,0,1,0,0,0,0,0
2,-1.036922,-0.745178,0.03227,2011,-0.832947,1,-1.596575,-0.307344,1,0,...,0,0,0,0,0,0,1,0,0,0
3,-1.589509,0.544308,0.441029,2011,-0.602768,1,-1.596575,-0.448424,1,0,...,0,0,0,0,0,0,0,0,1,0
4,2.002306,0.544308,0.849788,2011,-0.372588,0,-1.596575,-0.260317,0,1,...,0,0,0,0,1,0,0,0,0,0


In [38]:
dataset['SURVIVAL'].value_counts()

Unnamed: 0_level_0,count
SURVIVAL,Unnamed: 1_level_1
0,1520
1,1311


In [39]:
#dataset.to_csv('TB_Analysis.csv', index=False)

## split

In [40]:

from sklearn.model_selection import train_test_split

# Split the data into train and test sets
train_data, test_data = train_test_split(dataset, test_size=0.2, random_state=42)

# Further split the train data into train and validation sets
train_data, val_data = train_test_split(train_data, test_size=0.3, random_state=42)

# Define X, T, and E for each set
X_train = train_data.drop(['SURVIVAL', 'DURATION'], axis=1)
T_train = train_data['DURATION']
E_train = train_data['SURVIVAL']

X_val = val_data.drop(['SURVIVAL', 'DURATION'], axis=1)
T_val = val_data['DURATION']
E_val = val_data['SURVIVAL']

X_test = test_data.drop(['SURVIVAL', 'DURATION'], axis=1)
T_test = test_data['DURATION']
E_test = test_data['SURVIVAL']

In [41]:
# from imblearn.over_sampling import SMOTE
# from collections import Counter

# # Create a SMOTE instance
# smote = SMOTE(random_state=123)

# # Fit and transform the training data
# X_smote, y_smote = smote.fit_resample(X_train1, E_train)
# X_train = X_smote.drop(['DURATION'], axis=1)
# T_train = X_smote['DURATION']
# print('Original dataset shape:', Counter(E_train))
# print('Resampled dataset shape:', Counter(y_smote))

In [42]:
# E_train = y_smote
E_train.value_counts()

Unnamed: 0_level_0,count
SURVIVAL,Unnamed: 1_level_1
0,843
1,741


In [43]:
print(X_train)


           AGE  NUMBER OF ADDITIONAL WORKING DIAGNOSIS  \
2075 -0.484335                                0.544308   
2467  0.160350                               -0.745178   
2429 -0.576433                                0.544308   
1757  0.897132                               -0.745178   
502  -0.668531                                0.544308   
...        ...                                     ...   
2702  1.910209                               -2.034664   
2115 -0.208042                               -0.745178   
1667  0.897132                               -0.745178   
1284  0.252448                                0.544308   
15   -1.773705                               -0.745178   

      NUMBER OF PREVIOUS ADMISSIONS  YEAR OF ADMISSION  Month of Admission  \
2075                      -0.376488               2017           -0.724960   
2467                      -0.376488               2018            0.146655   
2429                      -0.376488               2018           -0.1

In [44]:
train_data.shape

(1584, 26)

In [45]:
test_data.shape

(567, 26)

In [46]:
val_data.shape

(680, 26)

# Train

In [21]:
!pip install --upgrade lifelines



In [47]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, roc_curve, auc, confusion_matrix
from lifelines.utils import concordance_index


In [52]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd

# Define the model
class NeuralNet(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_layers, dropout=0.2, initializer='glorot_uniform', output_activation=None):
        super(NeuralNet, self).__init__()
        self.initializer = initializer
        self.input_dim = input_dim
        self.output_dim = output_dim
        layers = []
        in_dim = input_dim

        # Constructing hidden layers
        for layer in hidden_layers:
            linear_layer = nn.Linear(in_dim, layer['num_units'])
            self.init_weights(linear_layer)  # Initialize the weights
            layers.append(linear_layer)

            # Add activation
            if layer['activation'] == 'ReLU':
                layers.append(nn.ReLU())
            elif layer['activation'] == 'Tanh':
                layers.append(nn.Tanh())

            layers.append(nn.Dropout(dropout))  # Adding dropout to avoid overfitting
            in_dim = layer['num_units']

        final_layer = nn.Linear(in_dim, output_dim)  # Output layer
        self.init_weights(final_layer)  # Initialize the final layer
        layers.append(final_layer)

        # Add output activation function
        if output_activation == 'sigmoid':
            layers.append(nn.Sigmoid())
        elif output_activation == 'softmax':
            layers.append(nn.Softmax(dim=1))
        elif output_activation == 'relu':
            layers.append(nn.ReLU())

        self.network = nn.Sequential(*layers)

    def init_weights(self, layer):
        """Initialize layer weights based on the specified initializer."""
        if self.initializer == 'glorot_uniform':
            nn.init.xavier_uniform_(layer.weight)
        elif self.initializer == 'glorot_normal':
            nn.init.xavier_normal_(layer.weight)
        elif self.initializer == 'kaiming_uniform':
            nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')
        elif self.initializer == 'kaiming_normal':
            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
        else:
            raise ValueError(f"Initializer '{self.initializer}' not supported.")

    def forward(self, x):
        return self.network(x)

    def predict_survival(self, x):
        """
        Predict survival probabilities for input data.
        """
        with torch.no_grad():  # Ensure no gradient computation for inference
            logits = self.forward(x)
            survival_prob = torch.exp(-torch.cumsum(logits, dim=1))  # Compute survival probabilities
        return survival_prob

# Loss function with the survival constraint
class BaseMultiTaskModel:
    def __init__(self, structure, bins=100, auto_scaler=True):
        self.loss_values = []
        self.bins = bins
        self.structure = structure
        self.auto_scaler = auto_scaler
        self.scaler = StandardScaler() if auto_scaler else None

    # Function to get time discretization (bins)
    def get_times(self, T, is_min_time_zero=True, extra_pct_time=0.1):
        max_time = max(T)
        min_time = 0 if is_min_time_zero else min(T)
        p = extra_pct_time
        self.times = np.linspace(min_time, max_time * (1. + p), self.bins)
        self.num_times = len(self.times) - 1
        self.time_buckets = [(self.times[i], self.times[i+1]) for i in range(len(self.times)-1)]

    # Function to compute X and Y matrices
    def compute_XY(self, X, T, E, is_min_time_zero=True, extra_pct_time=0.1):
        self.get_times(T, is_min_time_zero, extra_pct_time)
        n_units = T.shape[0]
        Y_cens, Y_uncens = [], []
        X_cens, X_uncens = [], []

        # For each sample, compute censored and uncensored data
        for i, (t, e) in enumerate(zip(T, E)):
            y = np.zeros(self.num_times + 1)
            index = np.argmin([abs(a_j_1 - t) for (a_j_1, a_j) in self.time_buckets])

            if e == 1:  # Uncensored
                y[index] = 1.
                X_uncens.append(X[i, :].tolist())
                Y_uncens.append(y.tolist())
            else:  # Censored
                y[(index):] = 1.
                X_cens.append(X[i, :].tolist())
                Y_cens.append(y.tolist())

        X_cens, X_uncens = torch.FloatTensor(X_cens), torch.FloatTensor(X_uncens)
        Y_cens, Y_uncens = torch.FloatTensor(Y_cens), torch.FloatTensor(Y_uncens)

        X_cens = X_cens.float()
X_uncens = X_uncens.float()
Y_cens = Y_cens.float()
Y_uncens = Y_uncens.float()
Triangle = Triangle.float()
        return X_cens, X_uncens, Y_cens, Y_uncens

    # Norm difference for MTLR last layer smoothing
    def norm_diff(self, W):
        dims = len(W.shape)
        if dims == 1:
            diff = W[1:] - W[:-1]
        elif dims == 2:
            diff = W[1:, :] - W[:-1, :]
        return torch.sum(diff * diff)

    # Loss function with data and constraint loss
    def loss_function(self, model, X_cens, X_uncens, Y_cens, Y_uncens, Triangle, l2_reg, l2_smooth):
        score_uncens = model(X_uncens)
        phi_uncens = torch.exp(torch.mm(score_uncens, Triangle))
        reduc_phi_uncens = torch.sum(phi_uncens * Y_uncens, dim=1)

        score_cens = model(X_cens)
        phi_cens = torch.exp(torch.mm(score_cens, Triangle))
        reduc_phi_cens = torch.sum(phi_cens * Y_cens, dim=1)

        z_uncens = torch.exp(torch.mm(score_uncens, Triangle))
        reduc_z_uncens = torch.sum(z_uncens, dim=1)

        z_cens = torch.exp(torch.mm(score_cens, Triangle))
        reduc_z_cens = torch.sum(z_cens, dim=1)

        loss = - (
            torch.sum(torch.log(reduc_phi_uncens)) +
            torch.sum(torch.log(reduc_phi_cens)) -
            torch.sum(torch.log(reduc_z_uncens)) -
            torch.sum(torch.log(reduc_z_cens))
        )

        # Adding L2 regularization and smoothing for the MTLR model
        nb_set_parameters = len(list(model.parameters()))
        for i, w in enumerate(model.parameters()):
            loss += l2_reg * torch.sum(w * w)
            if i >= nb_set_parameters - 2:
                loss += l2_smooth * self.norm_diff(w)

        # ------------- Survival Constraint --------------
        # Add constraint to enforce non-increasing survival function
        survival_constraint_loss = self.survival_constraint(model, X_uncens)
        loss += survival_constraint_loss
        # -----------------------------------------------

        return loss

    # Survival constraint ensures survival function is non-increasing
    def survival_constraint(self, model, X_uncens):
        """
        Constraint to ensure the survival function is non-increasing.
        """
        loss = 0
        for i in range(X_uncens.shape[0]):
            # Get survival probabilities for all bins for this individual
            S_pred = model.predict_survival(X_uncens[[i]])

            # Create a tensor for time points (bins)
            t = torch.tensor(np.arange(self.bins), requires_grad=True, dtype=torch.float32) # Ensure t is part of the computation graph

            # Compute the gradient (time derivative) of the survival function
            S_t = torch.autograd.grad(S_pred, t, grad_outputs=torch.ones_like(S_pred), create_graph=True)[0]

            # Penalize positive values (non-decreasing survival function)
            loss += torch.sum(torch.clamp(S_t, min=0))

        return loss


    def fit(self, X_train, T_train, E_train, X_val, T_val, E_val,
            init_method='glorot_uniform', optimizer='adam', lr=1e-5,
            num_epochs=1000, dropout=0.5, l2_reg=1e-2, l2_smooth=1e-2,
            batch_normalization=False, bn_and_dropout=True, verbose=True,
            extra_pct_time=0.1, is_min_time_zero=True, initializer='glorot_uniform', output_activation='sigmoid'):

        """
        Fit the estimator with both training and validation sets.
        """

        # Initialize attributes
        self.loss_values = []
        self.val_loss_values = []

        # Extracting data parameters
        nb_units, self.num_vars = X_train.shape
        input_shape = self.num_vars

        # Scaling data
        if self.auto_scaler:
            X_train = self.scaler.fit_transform(X_train)
            X_val = self.scaler.transform(X_val)

        # Build the time axis, time buckets, and output Y for both training and validation
        X_cens_train, X_uncens_train, Y_cens_train, Y_uncens_train = self.compute_XY(X_train, T_train, E_train, is_min_time_zero, extra_pct_time)
        X_cens_val, X_uncens_val, Y_cens_val, Y_uncens_val = self.compute_XY(X_val, T_val, E_val, is_min_time_zero, extra_pct_time)

        # Initialize the model
        model = NeuralNet(input_shape, self.num_times, self.structure, dropout, initializer, output_activation)

        # Create the Triangular matrix
        Triangle = np.tri(self.num_times, self.num_times + 1, dtype=np.float32)
        Triangle = torch.FloatTensor(Triangle)

        # Initialize optimizer (changed here)
        if optimizer == 'adam':
            optimizer = optim.Adam(model.parameters(), lr=lr)
        elif optimizer == 'sgd':
            optimizer = optim.SGD(model.parameters(), lr=lr)
        else:
            raise ValueError(f"Unknown optimizer: {optimizer}")

        # Training loop
        for epoch in range(num_epochs):
            model.train()  # Set model to training mode

            # Compute training loss
            loss_train = self.loss_function(model, X_cens_train, X_uncens_train, Y_cens_train, Y_uncens_train, Triangle, l2_reg, l2_smooth)
            self.loss_values.append(loss_train.item())

            # Perform a forward pass on the validation data
            model.eval()  # Set model to evaluation mode
            with torch.no_grad():
                loss_val = self.loss_function(model, X_cens_val, X_uncens_val, Y_cens_val, Y_uncens_val, Triangle, l2_reg, l2_smooth)
            self.val_loss_values.append(loss_val.item())

            # Gradient update
            optimizer.zero_grad()
            loss_train.backward()
            optimizer.step()

            # Logging
            if verbose and (epoch % 100 == 0 or epoch == num_epochs - 1):
                print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {loss_train.item():.4f}, Validation Loss: {loss_val.item():.4f}')

        self.model = model.eval()
        return self.model

    def plot_loss(self):
        plt.figure(figsize=(10, 6))
        plt.plot(self.loss_values, label='Training Loss')
        plt.plot(self.val_loss_values, label='Validation Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.title('Training and Validation Loss over Epochs')
        plt.legend()
        plt.grid(True)
        plt.show()

    def predict(self, x, t=None):
        """ Predicting the hazard, density and survival functions """

        # Transform to PyTorch tensor
        x = torch.FloatTensor(x)

        # Predict using the trained model
        score_torch = self.model(x)
        score = score_torch.data.numpy()

        # Creating time triangles
        Triangle1 = np.tri(self.num_times, self.num_times + 1)
        Triangle2 = np.tri(self.num_times + 1, self.num_times + 1)

        # Calculate score, density, hazard, and survival
        phi = np.exp(np.dot(score, Triangle1))
        div = np.repeat(np.sum(phi, 1).reshape(-1, 1), phi.shape[1], axis=1)
        density = phi / div
        Survival = np.dot(density, Triangle2)
        hazard = density[:, :-1] / Survival[:, 1:]

        # Return full functions or predictions at time t
        if t is None:
            return hazard, density, Survival
        else:
            min_abs_value = [abs(a_j_1 - t) for (a_j_1, a_j) in self.time_buckets]
            index = np.argmin(min_abs_value)
            return hazard[:, index], density[:, index], Survival[:, index]


    def evaluate(self, X, T, E):
        """
        Evaluate the model using the concordance index and integrated Brier score.

        Parameters:
        X (np.ndarray): Feature matrix.
        T (np.ndarray): Array of times.
        E (np.ndarray): Array of event indicators.

        Returns:
        float: Concordance index, Integrated Brier score.
        """
        # Scale the features
        if self.auto_scaler:
            X = self.scaler.transform(X)

        # Get predictions
        _, _, Survival = self.predict(X)

        # Calculate concordance index
        c_index = concordance_index_censored(E.astype(bool), T, -Survival[:, -1])[0]

        # Calculate the Brier score
        time_grid = np.linspace(0, T.max(), 100)
        brier_scores = []
        for t in time_grid:
            pred_risk = np.mean(Survival[:, :-1] >= t, axis=1)
            brier_score = np.mean((E - pred_risk) ** 2)
            brier_scores.append(brier_score)
        integrated_brier_score = np.trapz(brier_scores, time_grid) / T.max()

        # Binary classification at a chosen threshold (e.g., 0.5)
        risk_threshold = 0.5
        predicted_events = (Survival[:, -1] > risk_threshold).astype(int)

        # Accuracy
        accuracy = accuracy_score(E, predicted_events)

        # ROC Curve and AUC
        fpr, tpr, thresholds = roc_curve(E, Survival[:, -1])
        roc_auc = auc(fpr, tpr)

        # Plot ROC curve
        self.plot_roc_curve(fpr, tpr, roc_auc)

        # Confusion Matrix
        conf_matrix = confusion_matrix(E, predicted_events)
        self.plot_confusion_matrix(conf_matrix)

        # Plot Brier Score and Actual vs. Predicted Events
        self.plot_brier_score(time_grid, brier_scores, integrated_brier_score)
        self.plot_actual_vs_predicted(T, E, Survival)

        return c_index, integrated_brier_score, accuracy, roc_auc

    def plot_brier_score(self, time_grid, brier_scores, integrated_brier_score):
        """Plot the Brier score curve."""
        plt.figure(figsize=(10, 6))
        plt.plot(time_grid, brier_scores, label=f'IBS: {integrated_brier_score:.2f}', color='blue')
        plt.axhline(y=0.25, color='red', linestyle='--', label='0.25 limit')
        plt.title('Prediction Error Curve with Integrated Brier Score')
        plt.xlabel('Time')
        plt.ylabel('Brier Score')
        plt.legend()
        plt.grid(True)
        plt.show()

    def plot_actual_vs_predicted(self, T, E, Survival):
        """Plot the actual vs. predicted number of events over time."""
        time_grid = np.linspace(0, T.max(), 100)
        actual_events = np.array([(T >= t).sum() for t in time_grid])
        predicted_events = np.array([(Survival[:, :-1] >= t).sum() for t in time_grid])

        plt.figure(figsize=(10, 6))
        plt.plot(time_grid, actual_events, label='Actual', color='red')
        plt.plot(time_grid, predicted_events, label='Predicted', color='blue')
        plt.fill_between(time_grid, predicted_events - 1.96 * np.sqrt(predicted_events),
                         predicted_events + 1.96 * np.sqrt(predicted_events), color='blue', alpha=0.2,
                         label='Confidence Intervals')
        plt.title('Actual vs Predicted Number of Events Over Time')
        plt.xlabel('Time')
        plt.ylabel('Number of Events')
        plt.legend()
        plt.grid(True)
        plt.show()

    def plot_roc_curve(self, fpr, tpr, roc_auc):
        """
        Plot the ROC curve.

        Parameters:
        fpr (np.ndarray): False positive rates.
        tpr (np.ndarray): True positive rates.
        roc_auc (float): Area under the ROC curve.
        """
        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, color='blue', label=f'ROC curve (AUC = {roc_auc:.2f})')
        plt.plot([0, 1], [0, 1], color='red', linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic (ROC)')
        plt.legend(loc="lower right")
        plt.grid(True)
        plt.show()

    def plot_confusion_matrix(self, cm):
        """
        Plot the confusion matrix.

        Parameters:
        cm (np.ndarray): Confusion matrix.
        """
        plt.figure(figsize=(6, 6))
        plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        plt.title('Confusion Matrix')
        plt.colorbar()
        tick_marks = np.arange(2)
        plt.xticks(tick_marks, ['No Event', 'Event'], rotation=45)
        plt.yticks(tick_marks, ['No Event', 'Event'])

        thresh = cm.max() / 2.0
        for i, j in np.ndindex(cm.shape):
            plt.text(j, i, format(cm[i, j], 'd'), horizontalalignment="center",
                    color="white" if cm[i, j] > thresh else "black")

        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        plt.tight_layout()
        plt.show()



In [53]:
# Create MTLR object with the necessary parameters
structure = [{'activation': 'tanh', 'num_units': 8}, {'activation': 'relu', 'num_units': 8}]
model = BaseMultiTaskModel(structure=structure, bins=28)

# Fit the model with training and validation data
model.fit(
    X_train=X_train,
    T_train=T_train,
    E_train=E_train,
    X_val=X_val,
    T_val=T_val,
    E_val=E_val,
    num_epochs=10000,
    optimizer='adam',
    lr=1e-1,
    init_method='glorot_normal',
    dropout=0.0,
    l2_reg=1e-2,
    l2_smooth=1e-1,
    batch_normalization=True,
    bn_and_dropout=True,
    verbose=True
)

# Plot the training and validation loss curves
model.plot_loss()


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn