In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm import tqdm
import itertools
import pickle
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, precision_score, recall_score, brier_score_loss 


In [2]:
alpha = 0.5
gamma = 1
batch_size = 256
seq_len = 1
learning_rate = 1e-4

max_epoch = 100
experiment_time = 5
limit_early_stop_count = 5

show_shap_flag = True
select_feature_flag = False
use_upsample = False
use_mini_feature = False
only_Weaning = False

task_name_list = ['Weaning_successful']

device = torch.device("cuda")

In [3]:
class MLP_MTL(nn.Module):
    def __init__(self, input_dim, task_name_list, dropout_ratio=0.0):
        super(MLP_MTL, self).__init__()

        self.dropout = nn.Dropout(dropout_ratio)
        self.relu = nn.ReLU()  # Activation function for hidden layers
        self.sigmoid = nn.Sigmoid()
        self.task_name_list = task_name_list
        self.num_tasks = len(task_name_list)
        hidden_dim = [256, 128, 64, 32]
        output_size = 1

        # Bottom
        self.bt_fc1 = nn.Linear(input_dim, hidden_dim[0])
        self.bt_fc2 = nn.Linear(hidden_dim[0], hidden_dim[1])
        self.bt_fc3 = nn.Linear(hidden_dim[1], hidden_dim[2])

        # Towers
        self.task_fc0 = nn.ModuleList([nn.Linear(hidden_dim[2], hidden_dim[3]) for _ in range(self.num_tasks)])
        self.task_fc1 = nn.ModuleList([nn.Linear(hidden_dim[3], output_size) for _ in range(self.num_tasks)])
    
    def data_check(self,x):
        if isinstance(x, np.ndarray):
            x = torch.tensor(x, dtype=torch.float32)
        if x.ndim == 3:
            x = x.reshape(x.shape[0], x.shape[1] * x.shape[2])  # Flatten 
            
        x = x.to(device)
        return x
    
    def forward(self, x):
        x = self.data_check(x)

        # Bottom
        x = self.bt_fc1(x)
        x = self.relu(x)
        x = self.bt_fc2(x)
        x = self.relu(x)
        x = self.bt_fc3(x)
        h = self.relu(x)
        h = self.dropout(h)  

        # Towers
        task_out = {}
        for task_index in range(self.num_tasks):
            task_name = self.task_name_list[task_index]
            hi = self.task_fc0[task_index](h)
            hi = self.relu(hi)
            hi = self.dropout(hi)
            hi = self.task_fc1[task_index](hi)
            hi = self.sigmoid(hi)
            task_out[task_name] = hi    
            
        if len(self.task_name_list) == 1:
            return task_out[self.task_name_list[0]]
        else:
            return task_out
    
    def predict_prob(self, x):
        self.eval()
        prob_dict = self.forward(x)
        
        if len(self.task_name_list) == 1:
            prob_dict_true = {}
            prob_dict_true[self.task_name_list[0]] = prob_dict
            return prob_dict_true
        return prob_dict

    def predict_proba(self, x):
        self.eval()
        prob_dict = self.forward(x)
        
        if len(self.task_name_list) == 1:
            prob_dict_true = {}
            prob_dict_true[self.task_name_list[0]] = prob_dict
            return prob_dict_true
        
        return prob_dict
    
    def predict(self, x, threshold = 0.5):
        self.eval()
        prob_dict = self.predict_prob(x)
        pred_dict = {}
        
        for key, value in prob_dict.items():
            #tensor轉numpy
            value = value.cpu().detach().numpy()
            pred_class = [1 if x > threshold else 0 for x in value]
            pred_dict[key] = np.array(pred_class) 
        return pred_dict
    
    def evaluate(self,X,label,task_name,criterion):
        with torch.no_grad():
            prob = self.predict_prob(X)[task_name].cpu().detach().numpy() #tensor=>numpy
            pred = self.predict(X)[task_name] 
            score = compute_scores(label,pred,prob)
            score['task'] = task_name
            loss = criterion(torch.from_numpy(prob).to(device),torch.from_numpy(label).to(device)).item()
            score['loss'] = loss/len(label)
            return score
    

In [4]:
def compute_scores(y_true, y_pred,y_prob):
    if np.any(np.isnan(y_prob)):
        print(y_prob)
        input()
        
    scores = {}
    try:
        scores['task'] = 'Null'
        scores['auroc'] = round(roc_auc_score(y_true, y_prob), 3)
        scores['acc'] = round(accuracy_score(y_true, y_pred), 3)
        scores['f1'] = round(f1_score(y_true, y_pred), 3)
        scores['pre'] = round(precision_score(y_true, y_pred), 3)
        scores['recall'] = round(recall_score(y_true, y_pred), 3)
        scores['brier_score'] = round(brier_score_loss(y_true, y_prob), 3)
    except Exception as e:
        print("An error occurred:", str(e))
    return scores

In [5]:
"""
Input:
    model
    dict: Mydataset
    loss_function
Output:
    score: dict + dict
    result: dict => ['total_auc','total_loss']
"""
def test(model, dataset_dict, criterion, is_show = True , only_Weaning = False):
    model.eval()

    task_name_list = list(dataset_dict.keys())
    score = {}
    result = {'total_auc': 0, 'total_loss': 0}
    for task_name in task_name_list:  # 循環每個任務
        X = dataset_dict[task_name].inputs.numpy()
        Y = dataset_dict[task_name].labels.unsqueeze(1).numpy()
    
        score[task_name] = model.evaluate(X,Y,task_name,criterion)
        
        if only_Weaning == True and 'Weaning_succecssful' in task_name_list:
            if task_name == 'Weaning_succecssful':
                result['total_auc'] = result['total_auc'] + score[task_name]['auroc']
                result['total_loss'] = result['total_loss'] + score[task_name]['loss']
        else:
            result['total_auc'] = result['total_auc'] + score[task_name]['auroc']
            result['total_loss'] = result['total_loss'] + score[task_name]['loss']
            
        if is_show:
            print(score[task_name])
    
    return score,result

"""
local_best_model_dict: #dict{'task_name':{'model','performance(target_score)','id'}}
model
"""
def test2(local_best_model_dict, modelr, dataset_dict, criterion, is_show = True):
    score = {}
    result = {'total_auc': 0, 'total_loss': 0}
    task_name_list = list(dataset_dict.keys())
    
    for task_name in task_name_list:
        print(f"task: {task_name} ")
        print(f"{local_best_model_dict[task_name]['performance']}")
        modelr.load_state_dict(local_best_model_dict[task_name]['model'])
        modelr.eval()
        X = dataset_dict[task_name].inputs.numpy()
        Y = dataset_dict[task_name].labels.unsqueeze(1).numpy()
        score[task_name] = modelr.evaluate(X,Y,task_name,criterion)
        result['total_auc'] = result['total_auc'] + score[task_name]['auroc']
        result['total_loss'] = result['total_loss'] + score[task_name]['loss']
        if is_show:
            print(score[task_name])
            
    return score,result

In [6]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class MyDataset(Dataset):
    def __init__(self, np_X_scalar,np_X_original, np_Y):
        self.inputs = torch.from_numpy(np_X_scalar).float()
        self.inputs_original = torch.from_numpy(np_X_original).float()
        self.labels = torch.from_numpy(np_Y).float()

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.labels[idx]
    
    def remove_samples(self, feature_index, threshold, condition_type):
        """
        Remove samples based on a specified condition on a specific feature.

        Parameters:
        - feature_index (int): Index of the feature.
        - threshold (float): Threshold value for the condition.
        - condition_type (str): Type of condition ('type1' for '<' or 'type2' for '>=').
        """
        if condition_type == 'type1':
            indices_to_remove = torch.nonzero(self.inputs[:, feature_index] < threshold).squeeze()
        elif condition_type == 'type2':
            indices_to_remove = torch.nonzero(self.inputs[:, feature_index] >= threshold).squeeze()
        else:
            raise ValueError("Invalid condition_type. Use 'type1' for '<' or 'type2' for '>='.")

        # Remove samples
        self.inputs = torch.index_select(self.inputs, 0, indices_to_remove)
        self.inputs_original = torch.index_select(self.inputs_original, 0, indices_to_remove)
        self.labels = torch.index_select(self.labels, 0, indices_to_remove)
    
class BCEFocalLoss(torch.nn.Module):

    def __init__(self, gamma=2, alpha=0.25, reduction='elementwise_mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
 
    def forward(self, _input, target):
        pt = _input
        alpha = self.alpha
        loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
               (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
        if self.reduction == 'elementwise_mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss    

    
def check_label_distribution (data_Y):
    count_1 = np.count_nonzero(data_Y == 1)
    count_0 = np.count_nonzero(data_Y == 0)
    count_others = np.count_nonzero((data_Y != 1) & (data_Y != 0))
    ratio_1 = round(count_1/len(data_Y)*100,2)
    ratio_0 = round(count_0/len(data_Y)*100,2)
    ratio_others = round(count_others/len(data_Y)*100,2)
    print(f'Distribution: 1=>{count_1}({ratio_1}%),  0=>{count_0}({ratio_0}%),  others=>{count_others}({ratio_others}%)')

    
def upsampling_auto(X,X_original,Y,up_ratio):
    check_label_distribution(Y)
    zero_idx = np.where(Y == 0)[0]
    one_idx = np.where(Y == 1)[0]
    other_idx = np.where((Y != 1) & (Y != 0))[0]
    if len(other_idx > 0):
        return X,Y
    repeated_data_X = np.tile(X[one_idx], (up_ratio, 1, 1))
    repeated_data_X_original = np.tile(X_original[one_idx], (up_ratio, 1, 1))
    repeated_data_Y = np.tile(Y[one_idx], (up_ratio))

    X_upsampled = np.vstack((X[zero_idx], repeated_data_X))
    X_original_upsampled = np.vstack((X_original[zero_idx], repeated_data_X_original))

    Y_upsampled = np.concatenate((Y[zero_idx], repeated_data_Y)) 
    return X_upsampled,X_original_upsampled, Y_upsampled

In [7]:
import numpy as np

"""
Input:
    X: numpy
    feature_name_list : List
    select_feature_list : List   (必須是feature_name_list的子集)
Output
    select_feature_list data
"""
def select_features(X, feature_name_list, select_feature_list):
    invalid_features = set(select_feature_list) - set(feature_name_list)
    if invalid_features:
        raise ValueError(f"Invalid features in select_feature_list: {invalid_features}")
    selected_feature_indices = [feature_name_list.index(feature) for feature in select_feature_list]
    X_selected = X[:, :, selected_feature_indices]

    return X_selected

In [8]:
import numpy as np

def read_data(task_name_list,data_date,data_type, select_feature_list = [], batch_size = 256,use_upsample = False):
    batch_size = 256
    data_path = "data/sample/standard_data"
    
    #Feature name
    df_feature = pd.read_csv("data/sample/full_feature_name.csv")
    feature_name_list = df_feature.columns.to_list()

   
    #dataset
    dataset_dict = {}
    original_data_dict = {}
    for task_name in task_name_list:
        X_scalar = np.load(f"{data_path}/{data_type}_scalar_X_{task_name}.npy", allow_pickle=True)
        X_original = np.load(f"{data_path}/{data_type}_X_{task_name}.npy", allow_pickle=True)
        X_original_with_id = np.load(f"{data_path}/{data_type}_X_with_id_{task_name}.npy", allow_pickle=True)
        
        if len(select_feature_list)>0:
            X_scalar = select_features(X_scalar,feature_name_list,select_feature_list)
            X_original = select_features(X_original,feature_name_list,select_feature_list)
            feature_name_list = select_feature_list
    
            assert X_scalar.shape[2] == len(select_feature_list)
            assert X_original.shape[2] == len(select_feature_list)
        X_original_with_id = X_original_with_id[:,:,:1]    
        Y = np.load(f"{data_path}/20240129_{data_type}_Y_{task_name}.npy", allow_pickle=True)
        
        if use_upsample:
            if task_name == 'Weaning_successful' and data_type == 'test':
                X_scalar,X_original,Y = upsampling_auto(X_scalar,X_original,Y,2)
        dataset_dict[task_name] = MyDataset(X_scalar,X_original,Y)
        original_data_dict['X_scalar'] = X_scalar
        original_data_dict['X'] = X_original
        original_data_dict['X_with_id'] = X_original_with_id
        original_data_dict['Y'] = Y
    
    #dataloader
    loader_dict = {}
    for key, dataset in dataset_dict.items():        
        loader_dict[key] = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
    
    return dataset_dict,loader_dict,feature_name_list,original_data_dict


In [9]:
def MTL_to_STL(multi_task_model):
    single_task_models = {}

    for task_index, task_name in enumerate(multi_task_model.task_name_list):
        
        single_task_model = MLP_MTL(input_dim, [task_name])  
        single_task_model.to(device) 

        single_task_model.bt_fc1.weight.data = multi_task_model.bt_fc1.weight.data.clone()
        single_task_model.bt_fc1.bias.data = multi_task_model.bt_fc1.bias.data.clone()

        single_task_model.bt_fc2.weight.data = multi_task_model.bt_fc2.weight.data.clone()
        single_task_model.bt_fc2.bias.data = multi_task_model.bt_fc2.bias.data.clone()

        single_task_model.bt_fc3.weight.data = multi_task_model.bt_fc3.weight.data.clone()
        single_task_model.bt_fc3.bias.data = multi_task_model.bt_fc3.bias.data.clone()

        single_task_model.task_fc0[0].weight.data = multi_task_model.task_fc0[task_index].weight.data.clone()
        single_task_model.task_fc0[0].bias.data = multi_task_model.task_fc0[task_index].bias.data.clone()

        single_task_model.task_fc1[0].weight.data = multi_task_model.task_fc1[task_index].weight.data.clone()
        single_task_model.task_fc1[0].bias.data = multi_task_model.task_fc1[task_index].bias.data.clone()

        single_task_models[task_name] = single_task_model
    return single_task_models

In [10]:
import shap
import matplotlib.pyplot as plt

def calculate_feature_important(shap_value,feature_name_list):
    abs_shap_value = np.abs(shap_value)
    sum_per_feature = np.sum(abs_shap_value, axis=0)
    sorted_feature_indices = np.argsort(sum_per_feature)[::-1] #[::-1]是reversed
    sorted_feature_names = [feature_name_list[i] for i in sorted_feature_indices]
    return sorted_feature_names, sum_per_feature

def get_model_shap(model,data_X_train,data_X_test,data_X_test_original,feature_name_list,task_name,use_mini_sample = True,n_sample = 100):
    
    max_sample = 1000
    
    seq_day = data_X_train.shape[1]
    feature_count = data_X_train.shape[2]
    
    if use_mini_sample:
        background_data = torch.from_numpy(data_X_train[:max_sample]).float().to(device)
        shap_data = torch.from_numpy(data_X_test[:max_sample]).float().to(device)
        shap_data_original = torch.from_numpy(data_X_test_original[:max_sample]).float().to(device)
    else:
        background_data = torch.from_numpy(data_X_train[:]).float().to(device)
        shap_data = torch.from_numpy(data_X_test[:]).float().to(device)
        shap_data_original = torch.from_numpy(data_X_test_original[:]).float().to(device)

    model.eval()
    explainer = shap.GradientExplainer(model, background_data)
    
    shap_values = explainer.shap_values(shap_data,nsamples=n_sample)
    shap_values = np.array(shap_values)
    
    shap_value_flatten = np.zeros((len(shap_data),seq_day*feature_count))
    shap_data_flatten = np.zeros((len(shap_data),seq_day*feature_count))
    
    for i in range(0,len(shap_data)):
        count=0
        for j in range(feature_count):
            for k in range(seq_day):
                shap_value_flatten[i][count]=shap_values[i][k][j]  
                shap_data_flatten[i][count]=shap_data_original[i][k][j]  
                count += 1
    feature_important,_ = calculate_feature_important(shap_value_flatten, feature_name_list)
    return feature_important, shap_value_flatten, shap_data_flatten

"""
Input:
    shap_value_flatten (sample,feature_flatten)
    shap_data_flatten (sample,feature_flatten)
    max_display 
"""
def show_shap(shap_value_flatten, shap_data_flatten,feature_name_list, max_display = 20,task_name = ''):
    fig = shap.summary_plot(shap_value_flatten,shap_data_flatten,feature_names=feature_name_list, show=False,max_display = max_display)
    #plt.title(f"***Task:{task_name}***")
    plt.xticks(fontsize=20, fontweight='bold', fontfamily='Arial')
    plt.yticks(fontsize=20, fontweight='bold', fontfamily='Arial')
    plt.xlabel('SHAP Value',fontsize=24, fontweight='bold', fontfamily='Arial')
    
    ax = plt.gca()  
    #plt.savefig(f'./PDP/SHAP.tif', bbox_inches = 'tight', dpi=300)
    plt.show()





In [11]:
def pdp_plot(x, y, feature_name, point_color='black'):
    point_size = 4
    #plt.figure(figsize=(6,4))
    fig, ax = plt.subplots(figsize=(6, 6))
    plt.scatter(x, y, color=point_color,s=point_size)
    #plt.xlabel(feature_name)
    plt.xlabel(feature_name, fontsize=26, fontweight='bold', fontfamily='Arial')
    plt.ylabel('SHAP Value', fontsize=26, fontweight='bold', fontfamily='Arial')
    plt.tick_params(axis='both', which='both', labelsize=18)
    plt.legend()
    
    if feature_name == 'Peak Airway Pressure':
        plt.xticks([10,13,16,19,22,25], [10,13,16,19,22,25])
    if feature_name == 'RASS':
        plt.xticks([-5,-4,-3,-2,-1,0,1,2,3], [-5,-4,-3,-2,-1,0,1,2,3])
    if feature_name == 'FiO2':
        plt.xticks([20,40,60,80,100], [20,40,60,80,100])
    plt.xticks(fontweight='bold')
    plt.yticks([])
    plt.legend().set_visible(False)
    #plt.legend(loc='lower center')
    #plt.savefig(f'./PDP/{feature_name}.png', bbox_inches = 'tight', dpi=300)
    plt.show()

In [12]:
from sklearn.metrics import confusion_matrix
from sklearn.calibration import calibration_curve
from sklearn.metrics import roc_curve, auc, precision_recall_curve

""" Calibration curve """
def plot_calibration_curve(model_list,model_name_list,x,y,task_name,num_bins=40):
    if torch.is_tensor(x):
        x = x.numpy()
        y = y.numpy()
        
    assert len(model_list) == len(model_name_list)
    
    for i in range(len(model_list)):
        model = model_list[i]
        model.eval()
        if torch.is_tensor(x):
            x = x.cpu().numpy()
            #y = y.cpu().numpy()
        x = torch.from_numpy(x).float().to(device)
        y_true = y
        out = model(x)
        y_pred_prob = out[:,:]
        y_pred_prob = y_pred_prob.float().cpu().detach().numpy()
        
        prob_true, prob_pred = calibration_curve(y_true, y_pred_prob, n_bins=num_bins, strategy='quantile')
        perfect_model = np.linspace(0, 1, num_bins)
        
        hoose_colar = 'darkorange'
        if i == 1:
            choose_colar = 'blue'
        else:
            choose_colar = 'green'
        choose_colar = 'black'
        
        #plt.plot(prob_pred, prob_true, color=choose_colar, label= f'{model_name_list[i]}')
        #plt.plot(prob_pred, prob_true, marker='o', color='black', label= f'{model_name_list[i]}')
        plt.plot(prob_pred, prob_true, marker='o', color='black')
        
    font_properties = {'size': 18,  'family': 'Arial', 'fontweight':'bold'}
    plt.xticks(fontproperties='Arial', **font_properties)
    plt.yticks(fontproperties='Arial', **font_properties)
    
    #plt.plot(perfect_model, perfect_model, color='black', linestyle='--', label='Perfectly calibrated')
    plt.plot(perfect_model, perfect_model, color='black', linestyle='--')
    plt.xlabel('Average Predicted Probability', fontsize=24, fontweight='bold', fontfamily='Arial') #平均預測機率
    plt.ylabel('Ratio of Positives', fontsize=24, fontweight='bold', fontfamily='Arial')
    #plt.title(f'Calibration Chart')
    #plt.legend()
    plt.legend().set_visible(False)
    #plt.savefig(f'./FG2/calibration.png', bbox_inches = 'tight', dpi=300)
    plt.show()
    
""" AUROC curve """
def plot_roc_curve(model_list,model_name_list,x,y,task_name):
    #plt.figure(figsize=(5, 5))
    if torch.is_tensor(x):
        x = x.numpy()
        y = y.numpy()
        
    assert len(model_list) == len(model_name_list)
    for i in range(len(model_list)):
        model = model_list[i]
        model.eval()
        if torch.is_tensor(x):
            x = x.cpu().numpy()
            #y = y.cpu().numpy()
        x = torch.from_numpy(x).float().to(device)
        y_true = y
        out = model(x)
        y_pred_prob = out[:,:]
        y_pred_prob = y_pred_prob.float().cpu().detach().numpy()
        fpr, tpr, thresholds = roc_curve(y_true, y_pred_prob)
        roc_auc = auc(fpr, tpr)
        
        choose_colar = 'black'    
        #plt.plot(fpr, tpr, marker='o', lw=2, label=f'{model_name_list[i]}({roc_auc:.4f})')
        plt.plot(fpr, tpr, color=choose_colar, lw=2)
        plt.plot([0, 1], [0, 1], color=choose_colar, lw=2, linestyle='--')
        
    plt.xlabel('False Positive Rate', fontsize=24, fontweight='bold', fontfamily='Arial')
    plt.ylabel('True Positive Rate', fontsize=24, fontweight='bold', fontfamily='Arial')
    font_properties = {'size': 18,  'family': 'Arial', 'fontweight':'bold'}
    plt.xticks(fontproperties='Arial', **font_properties)
    plt.yticks(fontproperties='Arial', **font_properties)
    plt.legend().set_visible(False)
    #plt.savefig(f'./FG2/AUROC.png', bbox_inches = 'tight', dpi=300)
    plt.show()
    
""" AUPRC """
def plot_pr_curve(model_list,model_name_list, x, y, task_name):
                  
    for i in range(len(model_list)):
        model = model_list[i]
                  
        if torch.is_tensor(x):
            x = x.numpy()
            y = y.numpy()
        model.eval()
        x = torch.from_numpy(x).float().to(device)
        y_true = y
        out = model(x)
        y_pred_prob = out[:, :]
        y_pred_prob = y_pred_prob.float().cpu().detach().numpy()


        precision, recall, thresholds = precision_recall_curve(y_true, y_pred_prob)
        auprc = auc(recall, precision)

        plt.figure(figsize=(5, 5))
        plt.plot(recall, precision, color='darkorange', lw=2, label=f'{model_name_list[i]}({auprc:.2f})')
        #plt.plot(recall, precision, color='darkorange', lw=2, label=f'Precision-Recall curve (area = {auprc:.2f})')
    plt.xlabel('False Positive Rate', fontsize=24, fontweight='bold', fontfamily='Arial')
    plt.ylabel('True Positive Rate', fontsize=24, fontweight='bold', fontfamily='Arial')
    font_properties = {'size': 18,  'family': 'Arial', 'fontweight':'bold'}
    plt.xticks(fontproperties='Arial', **font_properties)
    plt.yticks(fontproperties='Arial', **font_properties)
    plt.legend().set_visible(False)
    #plt.savefig(f'./FG2/AUPRC.png', bbox_inches = 'tight', dpi=300)
    
    plt.show()

def calculate_net_benefit_model(thresh_group, y_pred_score, y_label):
    net_benefit_model = np.array([])
    for thresh in thresh_group:
        y_pred_label = y_pred_score > thresh
        tn, fp, fn, tp = confusion_matrix(y_label, y_pred_label).ravel()
        n = len(y_label)
        net_benefit = (tp / n) - (fp / n) * (thresh / (1 - thresh))
        net_benefit_model = np.append(net_benefit_model, net_benefit)
    return net_benefit_model


def calculate_net_benefit_all(thresh_group, y_label):
    net_benefit_all = np.array([])
    tn, fp, fn, tp = confusion_matrix(y_label, y_label).ravel()
    total = tp + tn
    for thresh in thresh_group:
        net_benefit = (tp / total) - (tn / total) * (thresh / (1 - thresh))
        net_benefit_all = np.append(net_benefit_all, net_benefit)
    return net_benefit_all


def plot_DCA(ax, thresh_group, net_benefit_model, net_benefit_all , model_id,model_name_list):
    
    choose_colar = 'darkorange'
    if model_id == 1:
        choose_colar = 'blue'
    elif model_id == 2:
        choose_colar = 'green'
    else:
        choose_colar = 'red'
    
    choose_colar = 'black'
    #Plot

    ax.plot(thresh_group, net_benefit_model, color = choose_colar)
    #ax.plot(thresh_group, net_benefit_model, color = choose_colar, label = f'{model_name_list[model_id]}')
    
    if model_id == 1:
        #ax.plot(thresh_group, net_benefit_all, color = 'black',label = 'Treat all')
        #ax.plot((0, 1), (0, 0), color = 'black', linestyle = ':', label = 'Treat none')
        ax.plot(thresh_group, net_benefit_all, color = 'black')
        ax.plot((0, 1), (0, 0), color = 'black', linestyle = ':')
    ax.plot(thresh_group, net_benefit_all, color = 'black')
    ax.plot((0, 1), (0, 0), color = 'black', linestyle = ':')
    
    #Fill，显示出模型较于treat all和treat none好的部分
    y2 = np.maximum(net_benefit_all, 0)
    y1 = np.maximum(net_benefit_model, y2)
    ax.fill_between(thresh_group, y1, y2, color = 'black', alpha = 0.2)
    
    font_properties = {'size': 18,  'family': 'Arial', 'fontweight':'bold'}
    plt.xticks(fontproperties='Arial', **font_properties)
    plt.yticks(fontproperties='Arial', **font_properties)
    
    #Figure Configuration， 美化一下细节
    ax.set_xlim(0,1)
    ax.set_ylim(net_benefit_model.min() - 0.15, net_benefit_model.max() + 0.15)#adjustify the y axis limitation
    ax.set_xlabel(
        xlabel = 'Threshold Probability', 
        fontdict= {'fontfamily':'Arial', 'fontsize': 24, 'fontweight':'bold'}
        )
    ax.set_ylabel(
        ylabel = 'Net Benefit', 
        fontdict= {'fontfamily':'Arial', 'fontsize': 24, 'fontweight':'bold'}
        )
    
    #ax.grid('major')
    ax.spines['right'].set_color((0.8, 0.8, 0.8))
    ax.spines['top'].set_color((0.8, 0.8, 0.8))
    #ax.legend(loc = 'upper right')
    plt.legend().set_visible(False)

    return ax

def decision_curve(model_list,model_name_list,x,y,task_name):
    if torch.is_tensor(x):
        x = x.numpy()
        y = y.numpy()
        
    assert len(model_list) == len(model_name_list)
    fig, ax = plt.subplots()
    for i in range(len(model_list)):
        model = model_list[i]
        model.eval()
        if torch.is_tensor(x):
            x = x.cpu().numpy()
            #y = y.cpu().numpy()
        x = torch.from_numpy(x).float().to(device)
        y_label = y
        out = model(x)
        y_pred_score = out[:,:]
        y_pred_score = y_pred_score.float().cpu().detach().numpy()
        y_pred = (y_pred_score > 0.5).astype(int)
        
        ########################################
        thresh_group = np.arange(0,1,0.01)
        net_benefit_model = calculate_net_benefit_model(thresh_group, y_pred_score, y_label)
        net_benefit_all = calculate_net_benefit_all(thresh_group, y_label)
        #fig, ax = plt.subplots()
        #ax = plot_DCA(ax, thresh_group, net_benefit_model, net_benefit_all)
        plot_DCA(ax, thresh_group, net_benefit_model, net_benefit_all,i,model_name_list)
    # fig.savefig('fig1.png', dpi = 300)
    #plt.savefig(f'./FG2/Decision.png', bbox_inches = 'tight', dpi=300)
    plt.show()

In [13]:
from datetime import datetime

start_time = datetime.now()

In [14]:
from openpyxl import Workbook
from openpyxl.utils.dataframe import dataframe_to_rows
from datetime import datetime

def group_result(df):
    agg_columns = {
        'acc': ['mean', 'std'],
        'pre': ['mean', 'std'],
        'f1': ['mean', 'std'],
        'recall': ['mean', 'std'],
        'auroc': ['mean', 'std'],
        'brier_score': ['mean', 'std']
    }
    df_group = df.groupby('task').agg(agg_columns)
    df_group.columns = [f"{col[0]}_{col[1]}" for col in df_group.columns]

    for metric in ['acc', 'pre', 'f1', 'recall', 'auroc','brier_score']:
        df_group[f"{metric}_combined"] = df_group.apply(
            lambda row: f"{row[f'{metric}_mean']:.4f} ± {row[f'{metric}_std']:.4f}", axis=1
        )

    df_result = df_group[[f"{metric}_combined" for metric in ['acc', 'pre', 'f1', 'recall', 'auroc','brier_score']]]

    df_result.reset_index(inplace=True)
    df_result.columns = ['task','acc', 'pre', 'f1', 'recall', 'auroc','brier_score']
    return df_result


def save_to_xlsx(df_save,file_name = 'output'):
    wb = Workbook()
    ws = wb.active
    for r_idx, row in enumerate(dataframe_to_rows(df_save, index=False, header=True), 1):
        for c_idx, value in enumerate(row, 1):
            ws.cell(row=r_idx, column=c_idx, value=value)
    wb.save(f'{file_name}.xlsx')

    

# Start

In [15]:
task_name = 'Weaning_successful'

# Feature

In [16]:
path = "./model/group_result/mtl_group/vent_group"
df_feature = pd.read_csv(f"{path}/feature_name_list.csv")
select_feature_list = df_feature['Feature'].tolist()
input_dim = len(select_feature_list)
loss_func = BCEFocalLoss(alpha=alpha, gamma=gamma)


train_dataset_dict,train_loader_dict,feature_name_list,_ = read_data([task_name],"",'train',select_feature_list,batch_size = batch_size,use_upsample = use_upsample)
val_dataset_dict,val_loader_dict,_ ,_= read_data([task_name],"",'validation',select_feature_list,batch_size = batch_size,use_upsample = use_upsample)
test_dataset_dict,test_loader_dict,_ ,original_data_dict= read_data([task_name],"",'test',select_feature_list,batch_size = batch_size,use_upsample = use_upsample)

print(f'input_dim: {input_dim}')

input_dim: 18


# MTL_model

In [17]:
mode = 'lite'

# Vent_Group

In [18]:
""" Weaning_successful (Vent_group) """
""" best model """
data_path = "./model/group_result/mtl_group/vent_group"
model_vent = MLP_MTL(input_dim, task_name_list).to(device)
model_vent.load_state_dict(torch.load(f'{data_path}/{task_name}_best_{mode}'))
result,_ = test(model_vent, test_dataset_dict, loss_func, is_show = False)
print(result)

{'Weaning_successful': {'task': 'Weaning_successful', 'auroc': 0.821, 'acc': 0.758, 'f1': 0.598, 'pre': 0.66, 'recall': 0.546, 'brier_score': 0.165, 'loss': 6.205740544593265e-05}}


# SHAP

In [19]:
feature_important, shap_value_flatten, shap_data_flatten = get_model_shap(
                                                            model_vent,
                                                            train_dataset_dict[task_name].inputs.numpy(),
                                                            train_dataset_dict[task_name].inputs.numpy(),
                                                            train_dataset_dict[task_name].inputs_original.numpy(),
                                                            select_feature_list,
                                                            task_name,
                                                            use_mini_sample = False,
                                                            n_sample = 1)


In [20]:
df_feature_important = pd.DataFrame()
df_feature_important['Feature'] = feature_important
df_feature_important.to_csv('./error_analysis/feature_important.csv',index = False)

OSError: Cannot save file into a non-existent directory: 'error_analysis'

In [None]:
new_feature_list = ['APACHE III' if item == 'apsiii' else item for item in select_feature_list]
new_feature_list = ['Fluid balance' if item == 'total' else item for item in new_feature_list]
new_feature_list = ['Urine output' if item == 'Urine_value' else item for item in new_feature_list]
new_feature_list = ['Enteral feeding' if item == 'Nutrition_Enteral_value' else item for item in new_feature_list]
new_feature_list = ['Fluid input' if item == 'Fluid_intake_value' else item for item in new_feature_list]

show_shap(shap_value_flatten, shap_data_flatten,new_feature_list,task_name = task_name)

In [None]:
from collections import Counter

def data_select(feature_name,shap_data,shap_value,indedx_of_feature):
    if feature_name == 'Nutrition_Enteral_value':
        select_indices1 = np.where(shap_value_flatten[:, indedx_of_feature] <= 0.7)[0]
        select_indices2 = np.where(shap_data_flatten[:, indedx_of_feature].astype(int) != 865)[0]
        select_indices3 = np.where(shap_data_flatten[:, indedx_of_feature].astype(int) <= 1500)[0]
        select_indices4 = np.where(shap_value_flatten[:, indedx_of_feature] > -0.7)[0]
        select_indices5 = np.where(shap_data_flatten[:, indedx_of_feature].astype(int) != 1000)[0]
        select_indices6 = np.where(shap_data_flatten[:, indedx_of_feature].astype(int) != 999)[0]

        select_indices = list(set(select_indices1) & set(select_indices2))
        select_indices = list(set(select_indices) & set(select_indices3))
        select_indices = list(set(select_indices) & set(select_indices4))
        select_indices = list(set(select_indices) & set(select_indices5))
        select_indices = list(set(select_indices) & set(select_indices6))

        return shap_data[select_indices,indedx_of_feature], shap_value[select_indices,indedx_of_feature]
    elif feature_name == 'apsiii':
        select_indices1 = np.where(shap_value_flatten[:, indedx_of_feature] <= 0.4)[0]
        select_indices2 = np.where(shap_value_flatten[:, indedx_of_feature] >= -0.4)[0]
        select_indices = list(set(select_indices1) & set(select_indices2))
        return shap_data[select_indices,indedx_of_feature], shap_value[select_indices,indedx_of_feature]
        
    elif feature_name == 'Peak Airway Pressure':
        select_indices1 = np.where(shap_data_flatten[:, indedx_of_feature].astype(int) <= 25)[0]
        select_indices2 = np.where(shap_data_flatten[:, indedx_of_feature].astype(int) >= 10)[0]
        select_indices3 = np.where(shap_value_flatten[:, indedx_of_feature] <= 0.7)[0]
        #select_indices4 = np.where(shap_value_flatten[:, indedx_of_feature] >= -0.5)[0]

        select_indices = list(set(select_indices1) & set(select_indices2))
        select_indices = list(set(select_indices) & set(select_indices3))
        #select_indices = list(set(select_indices) & set(select_indices4))
        return shap_data[select_indices,indedx_of_feature], shap_value[select_indices,indedx_of_feature]
    elif feature_name == 'RASS':
        select_indices1 = np.where(shap_value_flatten[:, indedx_of_feature] >= -0.4)[0]

        select_indices = list(set(select_indices1) & set(select_indices1))
        return shap_data[select_indices,indedx_of_feature], shap_value[select_indices,indedx_of_feature]
    elif feature_name == 'Urine_value':
        select_indices = np.where(shap_data_flatten[:, indedx_of_feature].astype(int) <= 2500)[0]
        return shap_data[select_indices,indedx_of_feature], shap_value[select_indices,indedx_of_feature]
    elif feature_name == 'total':
        select_indices1 = np.where(shap_value_flatten[:, indedx_of_feature] >= -0.2)[0]
        select_indices2 = np.where(shap_value_flatten[:, indedx_of_feature] <= 0.25)[0]

        select_indices = list(set(select_indices1) & set(select_indices2))
        return shap_data[select_indices,indedx_of_feature], shap_value[select_indices,indedx_of_feature]
   
    elif feature_name == 'FiO2':

        return shap_data[:,indedx_of_feature], shap_value[:,indedx_of_feature]

    else:
        select_indices = np.where(shap_data_flatten[:, indedx_of_feature].astype(int) !=  -99999)[0]
        return shap_data[select_indices,indedx_of_feature], shap_value[select_indices,indedx_of_feature]

In [None]:
for feature_name in feature_important[:7]:
    indedx_of_feature = feature_name_list.index(feature_name)
    
    shap_data,shap_value = data_select(feature_name,shap_data_flatten,shap_value_flatten,indedx_of_feature)

    print(feature_name)

    if feature_name == 'apsiii':
        name = 'APACHE III'
    elif feature_name == 'total':
        name = 'Fluid balance'
    elif feature_name == 'Nutrition_Enteral_value':
        name = 'Enteral feeding'
    else:
        name = feature_name
        
    pdp_plot(
        shap_data, 
        shap_value, 
        name
    )



In [None]:
model_list = [model_vent]
model_name_list = ['MTL with Vent'] 

In [None]:
plot_calibration_curve(model_list,model_name_list,test_dataset_dict['Weaning_successful'].inputs.cpu(),test_dataset_dict['Weaning_successful'].labels.cpu(),'Weaning_successful')

plot_roc_curve(model_list,model_name_list,test_dataset_dict['Weaning_successful'].inputs.cpu(),test_dataset_dict['Weaning_successful'].labels.cpu(),'Weaning_successful')

decision_curve(model_list,model_name_list,train_dataset_dict['Weaning_successful'].inputs,train_dataset_dict['Weaning_successful'].labels,'Weaning_successful') 

# Error Analysis

In [None]:
from interpret_community.common.constants import ShapValuesOutput, ModelTask
from interpret.ext.blackbox import MimicExplainer
from interpret.ext.glassbox import LGBMExplainableModel

In [None]:
from sklearn import svm
import pandas as pd
import zipfile
from lightgbm import LGBMClassifier

# Explainer Used: Mimic Explainer
from interpret.ext.blackbox import MimicExplainer
from interpret.ext.glassbox import LinearExplainableModel
from interpret.ext.glassbox import LGBMExplainableModel

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer

from raiwidgets import ErrorAnalysisDashboard

In [None]:
def predict_func(model, X):
    task_name = model.task_name_list[0]
    return model.predict(X)[task_name]

def predict_proba_func(model, X):
    task_name = model.task_name_list[0]
    result = model.predict_proba(X)[task_name]
    result = result.cpu().detach().numpy()
    return result 

def create_model_pipeline(model):
    model_pipeline = Pipeline([
        ('model', model)
    ])

    model_pipeline.predict = lambda X: predict_func(model_pipeline.named_steps['model'], X)
    model_pipeline.predict_proba = lambda X: predict_proba_func(model_pipeline.named_steps['model'], X)

    return model_pipeline





In [None]:
model_pipeline = create_model_pipeline(model_vent)

In [None]:
X_train_original = train_dataset_dict['Weaning_successful'].inputs.numpy()
X_train_original = np.squeeze(X_train_original)

X_test_original_full = test_dataset_dict['Weaning_successful'].inputs.numpy()
X_test_original_full = np.squeeze(X_test_original_full)


In [None]:
sample_count = X_test_original_full.shape[0]
y_test_full = test_dataset_dict['Weaning_successful'].labels.numpy()
X_test_original = X_test_original_full[:int(sample_count*0.9),:]
y_test = y_test_full[:int(sample_count*0.9)]

In [None]:
from interpret_community.common.constants import ShapValuesOutput, ModelTask
# 1. Using SHAP TabularExplainer
model_task = ModelTask.Classification
explainer = MimicExplainer(model_pipeline, X_train_original, LGBMExplainableModel,
                           augment_data=True, max_num_of_augmentations=10,
                           features=select_feature_list, classes=[0,1], model_task=model_task)

In [None]:
# Passing in test dataset for evaluation examples - note it must be a representative sample of the original data
# X_train can be passed as well, but with more examples explanations will take longer although they may be more accurate
global_explanation = explainer.explain_global(X_test_original)

In [None]:
dashboard_pipeline = create_model_pipeline(model_vent)

In [None]:
ErrorAnalysisDashboard(global_explanation, dashboard_pipeline, dataset=X_test_original_full,
                       true_y=y_test, categorical_features = [],
                       true_y_dataset=y_test_full)

In [None]:
import joblib
scaler = joblib.load('./data/scaler_model.joblib')

df_feature = pd.read_csv(f"./data/sample/full_feature_name.csv")

feature_name_list = df_feature.columns.to_list()
feature_count = len(feature_name_list)
print(feature_count)

In [None]:
def choose_feature(select_feature_list):
    for i in range(len(select_feature_list)):
        print(f'[{i+1}]...{select_feature_list[i]}')
    
    select_id = int(input('feature id = '))
    value = float(input('value = '))
    
    return select_feature_list[select_id-1], value
    

In [None]:
select_feature_dict = {}

In [None]:
#feature_name1, value1 = choose_feature(select_feature_list)
feature_name1 = 'Mean Airway Pressure'
value1 = 0.5

index_of_feature1_full = feature_name_list.index(feature_name1)
index_of_feature1_lite = select_feature_list.index(feature_name1)

data = np.full(feature_count, value1)
data = data.reshape(1,feature_count)

original_value = scaler.inverse_transform(data)[0,index_of_feature1_full]


In [None]:
#feature_name2, value2 = choose_feature(select_feature_list)
feature_name2 = 'Nutrition_Enteral_value'
value2 = 0.22

index_of_feature2_full = feature_name_list.index(feature_name2)
index_of_feature2_lite = select_feature_list.index(feature_name2)

data = np.full(feature_count, value2)
data = data.reshape(1,feature_count)

original_value = scaler.inverse_transform(data)[0,index_of_feature2_full]


In [None]:
#feature_name3, value3 = choose_feature(select_feature_list)
feature_name3 = 'Nutrition_Enteral_value'
value3 = 0.22

index_of_feature3_full = feature_name_list.index(feature_name3)
index_of_feature3_lite = select_feature_list.index(feature_name3)

data = np.full(feature_count, value3)
data = data.reshape(1,feature_count)

original_value = scaler.inverse_transform(data)[0,index_of_feature3_full]



In [None]:
_,_,_,data = read_data(task_name_list,'','test', select_feature_list =select_feature_list, batch_size = 256,use_upsample = False)

X_scalar = data['X_scalar']
X_original = data['X']
Y = data['Y']

dataset_dict = {}
dataset_dict[task_name] = MyDataset(X_scalar,X_original,Y)

In [None]:
def remove_samples_np(data, feature_index, threshold, condition_type):
    """
    Remove samples based on a specified condition on a specific feature.

    Parameters:
    - data (numpy.ndarray): Input data with shape [sample, 1, feature].
    - feature_index (int): Index of the feature.
    - threshold (float): Threshold value for the condition.
    - condition_type (str): Type of condition ('type1' for '<' or 'type2' for '>=').

    Returns:
    - numpy.ndarray: Updated data after removing samples.
    """
    if condition_type == 'type1':
        indices_to_remove = np.squeeze(np.argwhere(data[:, 0, feature_index] < threshold))
    elif condition_type == 'type2':
        indices_to_remove = np.squeeze(np.argwhere(data[:, 0, feature_index] <= threshold))
    elif condition_type == 'type3':
        indices_to_remove = np.squeeze(np.argwhere(data[:, 0, feature_index] > threshold))
    elif condition_type == 'type4':
        indices_to_remove = np.squeeze(np.argwhere(data[:, 0, feature_index] >= threshold))
    else:
        raise ValueError("Invalid condition_type. Use 'type1' for '<' or 'type2' for '>='.")
    
    return indices_to_remove

    ## Remove samples
    #data = np.delete(data, indices_to_remove, axis=0)
    #return data

In [None]:
#type1: <
#type2: <=
#type3: >
#type4: >=

indices_to_remove1 = remove_samples_np(X_scalar,index_of_feature1_lite,value1,'type2')
indices_to_remove2 = remove_samples_np(X_scalar,index_of_feature2_lite,value2,'type3')
indices_to_remove3 = remove_samples_np(X_scalar,index_of_feature3_lite,value3,'type3')

common_indices = np.intersect1d(np.intersect1d(indices_to_remove1, indices_to_remove2), indices_to_remove3)





In [None]:
all_indices = np.arange(X_scalar.shape[0])
not_selected_indices = np.setdiff1d(all_indices, common_indices)

In [None]:
print(data['X_with_id'][common_indices].shape)
print(data['X_with_id'][not_selected_indices].shape)

reserve_patient = data['X_with_id'][not_selected_indices,0,0].tolist()
remove_patient = data['X_with_id'][common_indices,0,0].tolist()

print(f'reserve: {len(reserve_patient)}')
print(f'remove: {len(remove_patient)}')

In [None]:
df_successful_patient = pd.DataFrame({'stay_id': reserve_patient, 'outcome': 'reserve'})
df_fail_patient = pd.DataFrame({'stay_id': remove_patient, 'outcome': 'remove'})
df_combined = pd.concat([df_successful_patient, df_fail_patient], ignore_index=True)

np.save("./error_analysis/remove_patient.npy", data['X'][common_indices])
np.save("./error_analysis/reserve_patient.npy", data['X'][not_selected_indices])
df_combined.to_csv("./error_analysis/error患者.csv",index = False)

In [None]:
print(f'found: {len(common_indices)}')

In [None]:
""" remove sample """
X_scalar_remove = X_scalar[common_indices].copy()
X_original_remove = X_original[common_indices].copy()
Y_remove = Y[common_indices].copy()

In [None]:
""" reserve sample """
full_indices = np.arange(0, Y.shape[0])
keep_indices = np.array([i for i in full_indices if i not in common_indices])

In [None]:
X_scalar_keep = X_scalar[keep_indices].copy()
X_original_keep = X_original[keep_indices].copy()
Y_keep = Y[keep_indices].copy()


In [None]:
dataset_dict_keep = {}
dataset_dict_keep[task_name] = MyDataset(X_scalar_keep,X_original_keep,Y_keep)
result_before_remove,_ = test(model_vent, test_dataset_dict, loss_func, is_show = False)
result_after_remove,_ = test(model_vent, dataset_dict_keep, loss_func, is_show = False)

In [None]:
df_result = pd.DataFrame([result_before_remove['Weaning_successful'], result_after_remove['Weaning_successful']], index=['before', 'after'])
df_result

In [None]:
arr = Y.copy()

total_sample = Y.shape[0]
count_zero = np.count_nonzero(arr == 0)
count_one = np.count_nonzero(arr == 1)

print(f'Total sample: {total_sample}')
print(f"Number of 0: {count_zero} ({round(count_zero/total_sample*100,2)}%)")
print(f"Number of 1: {count_one} ({round(count_one/total_sample*100,2)}%)")

In [None]:
import matplotlib.pyplot as plt

arr = Y_remove.copy()
total_remove_sample = Y_remove.shape[0]
count_zero = np.count_nonzero(arr == 0)
count_one = np.count_nonzero(arr == 1)

# 顯示結果
print(f'remove sample: {total_remove_sample}')
print(f"Number of 0: {count_zero} ({round(count_zero/total_remove_sample*100,2)}%)")
print(f"Number of 1: {count_one} ({round(count_one/total_remove_sample*100,2)}%)")

In [None]:
def plot_confusion_matrix(y_true, y_pred, classes, file_name, title='Confusion Matrix', cmap=plt.cm.Blues):
    cm = confusion_matrix(y_true, y_pred)
  
    cm_df = pd.DataFrame(cm, index=classes, columns=classes)
    
    new_order = ['Class 1','Class 0']
    new_order_columns = ['Class 1','Class 0']
    cm_df = cm_df.reindex(new_order)
    cm_df = cm_df.reindex(columns=new_order_columns)
    
    y_labels = list(reversed(classes))
    x_labels = list(reversed(classes))

    plt.figure(figsize=(8, 6))

    font_properties = {'size': 20,  'family': 'Arial'}
    title_font_properties = {'size': 20, 'weight': 'bold', 'family': 'Arial'}
    
    annot_kws = {'size': 20, 'weight': 'bold', 'family': 'Arial'}
    sns.heatmap(cm_df, annot=True, fmt="d", cmap=cmap, cbar=False,
                xticklabels=x_labels, yticklabels=y_labels, annot_kws=annot_kws)
    
    plt.xticks(fontproperties='Arial', **font_properties)
    plt.yticks(fontproperties='Arial', **font_properties)
    

    plt.xlabel(title, fontproperties='Arial', **title_font_properties)
    plt.ylabel(title, fontproperties='Arial', **title_font_properties)
    
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Confusion Matrix')
    plt.savefig(f'./error_analysis/confusion_{file_name}.png')
    

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix


In [None]:
""" [1]full sample """
correct_count_0 = 0
correct_count_1 = 0
count_zero = np.count_nonzero(Y == 0)
count_one = np.count_nonzero(Y == 1)
count = 0
pred_y = model_vent.predict(X_scalar)['Weaning_successful']

for i in range(pred_y.shape[0]):
    if Y[i] != pred_y[i]:
        count+=1
    if Y[i] == 1 and pred_y[i] == 1:
        correct_count_1 += 1
    if Y[i] == 0 and pred_y[i] == 0:
        correct_count_0 += 1

plot_confusion_matrix(Y,pred_y,['Class 0','Class 1'],'full')
 

In [None]:
""" [2]remove sample """
count_zero = np.count_nonzero(Y_remove == 0)
count_one = np.count_nonzero(Y_remove == 1)
correct_count_0 = 0
correct_count_1 = 0

pred_y = model_vent.predict(X_scalar_remove)['Weaning_successful']

for i in range(pred_y.shape[0]):
    if Y_remove[i] == 1 and pred_y[i] == 1:
        correct_count_1 += 1
    if Y_remove[i] == 0 and pred_y[i] == 0:
        correct_count_0 += 1

plot_confusion_matrix(Y_remove,pred_y,['Class 0','Class 1'],'remove_sample')


In [None]:
""" [3]reserve sample """
count_zero = np.count_nonzero(Y_keep == 0)
count_one = np.count_nonzero(Y_keep == 1)
correct_count_0 = 0
correct_count_1 = 0

pred_y = model_vent.predict(X_scalar_keep)['Weaning_successful']

for i in range(pred_y.shape[0]):
    if Y_keep[i] == 1 and pred_y[i] == 1:
        correct_count_1 += 1
    if Y_keep[i] == 0 and pred_y[i] == 0:
        correct_count_0 += 1

plot_confusion_matrix(Y_keep,pred_y,['Class 0','Class 1'],'keep_sample')
