### 1 - Test Prediction Metrics for PreAttnMMs_LCPN
        - do some necessary post-processing for PreAttnMMs_LCPN model results
        - at last, we will get the target_labels and pred_labels of TEST dataset with shape [sample_num, 11]
        - we will calculate some metrics based on the above target_labels and pred_labels

In [7]:
import os
import pickle as pkl
import re

import numpy as np
import torch
import torch.nn.functional as F
from addict import Dict
from sklearn.metrics import (accuracy_score, auc, average_precision_score,
                             classification_report, confusion_matrix, f1_score,
                             hamming_loss, precision_recall_curve,
                             precision_score, recall_score, roc_auc_score,
                             roc_curve, zero_one_loss)
from torch import nn

from helper.data import DataPreprocess
from helper.data_loader import data_loaders
from helper.utils import get_config, load_checkpoint, set_seed
from models.model import PreAttnMMs
from train_modules.evaluation_metrics import evaluate4test

- define the model structure and optimizer

In [8]:
config_0 = Dict(get_config(config_id="PreAttnMMs_LCPN_0"))
config_1 = Dict(get_config(config_id="PreAttnMMs_LCPN_1"))
config_2 = Dict(get_config(config_id="PreAttnMMs_LCPN_2"))
config_6 = Dict(get_config(config_id="PreAttnMMs_LCPN_6"))
config_7 = Dict(get_config(config_id="PreAttnMMs_LCPN_7"))

set_seed(seed=config_0.seed)

# Load the preprocessed data
dp = DataPreprocess(config_0)
data, label, indices = dp.load()

n_classes_0 = len(label['taxonomy'][config_0.experiment.local_task])
n_classes_1 = len(label['taxonomy'][config_1.experiment.local_task])
n_classes_2 = len(label['taxonomy'][config_2.experiment.local_task])
n_classes_6 = len(label['taxonomy'][config_6.experiment.local_task])
n_classes_7 = len(label['taxonomy'][config_7.experiment.local_task])

model_0 = PreAttnMMs(config_0, 
                   data['X_t_steps'], 
                   data['X_t_features'],
                   data['X_features'],
                   n_classes_0)
model_0.to(config_0.train.device_setting.device)
model_1 = PreAttnMMs(config_1, 
                   data['X_t_steps'], 
                   data['X_t_features'],
                   data['X_features'],
                   n_classes_1)
model_1.to(config_1.train.device_setting.device)
model_2 = PreAttnMMs(config_2, 
                   data['X_t_steps'], 
                   data['X_t_features'],
                   data['X_features'],
                   n_classes_2)
model_2.to(config_2.train.device_setting.device)
model_6 = PreAttnMMs(config_6, 
                   data['X_t_steps'], 
                   data['X_t_features'],
                   data['X_features'],
                   n_classes_6)
model_6.to(config_6.train.device_setting.device)
model_7 = PreAttnMMs(config_7, 
                   data['X_t_steps'], 
                   data['X_t_features'],
                   data['X_features'],
                   n_classes_7)
model_7.to(config_7.train.device_setting.device)

criterion = nn.NLLLoss()

optimizer_0 = torch.optim.Adam(
    params = model_0.parameters(),
    lr = config_0.train.optimizer.learning_rate,
    weight_decay=config_0.train.optimizer.weight_decay
)
optimizer_1 = torch.optim.AdamW(
    params = model_1.parameters(),
    lr = config_1.train.optimizer.learning_rate,
    weight_decay=config_1.train.optimizer.weight_decay
)
optimizer_2 = torch.optim.RMSprop(
    params = model_2.parameters(),
    lr = config_2.train.optimizer.learning_rate,
    weight_decay=config_2.train.optimizer.weight_decay
)
optimizer_6 = torch.optim.RMSprop(
    params = model_6.parameters(),
    lr = config_6.train.optimizer.learning_rate,
    weight_decay=config_6.train.optimizer.weight_decay
)
optimizer_7 = torch.optim.Adagrad(
    params = model_7.parameters(),
    lr = config_7.train.optimizer.learning_rate,
    weight_decay=config_7.train.optimizer.weight_decay
)

INFO:  Loading previously preprocessed data...


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))


- read the checkpoint file and load the parameters

In [9]:
checkpoint_base = config_0.train.checkpoint.dir
checkpoint_dir_0 = os.path.join(checkpoint_base, 'hp_tuning', 'PreAttnMMs_LCPN', 'node-0', 'Standardization', 'auc')
checkpoint_dir_1 = os.path.join(checkpoint_base, 'hp_tuning', 'PreAttnMMs_LCPN', 'node-1', 'Standardization', 'macro-auc')
checkpoint_dir_2 = os.path.join(checkpoint_base, 'hp_tuning', 'PreAttnMMs_LCPN', 'node-2', 'Standardization', 'auc')
checkpoint_dir_6 = os.path.join(checkpoint_base, 'hp_tuning', 'PreAttnMMs_LCPN', 'node-6', 'Standardization', 'auc')
checkpoint_dir_7 = os.path.join(checkpoint_base, 'hp_tuning', 'PreAttnMMs_LCPN', 'node-7', 'Standardization', 'auc')

# get the best checkpoint .pt file
checkpoints = []
idx_checkpoint = np.array([])
for i in os.listdir(checkpoint_dir_0):
    if "best_best_checkpoint" in i:
        checkpoints.append(i)
for i in checkpoints:
    seachobj = re.search(r"\d+(?=\).pt)", i)
    idx_checkpoint = np.append(idx_checkpoint, int(seachobj.group()))
target_model_0 = checkpoints[np.argmax(idx_checkpoint)]

print("Loading the checkpoint --> {} for node-{} task".format(target_model_0, 0))

checkpoints = []
idx_checkpoint = np.array([])
for i in os.listdir(checkpoint_dir_1):
    if "best_best_checkpoint" in i:
        checkpoints.append(i)
for i in checkpoints:
    seachobj = re.search(r"\d+(?=\).pt)", i)
    idx_checkpoint = np.append(idx_checkpoint, int(seachobj.group()))
target_model_1 = checkpoints[np.argmax(idx_checkpoint)]

print("Loading the checkpoint --> {} for node-{} task".format(target_model_1, 1))

checkpoints = []
idx_checkpoint = np.array([])
for i in os.listdir(checkpoint_dir_2):
    if "best_best_checkpoint" in i:
        checkpoints.append(i)
for i in checkpoints:
    seachobj = re.search(r"\d+(?=\).pt)", i)
    idx_checkpoint = np.append(idx_checkpoint, int(seachobj.group()))
target_model_2 = checkpoints[np.argmax(idx_checkpoint)]

print("Loading the checkpoint --> {} for node-{} task".format(target_model_2, 2))

checkpoints = []
idx_checkpoint = np.array([])
for i in os.listdir(checkpoint_dir_6):
    if "best_best_checkpoint" in i:
        checkpoints.append(i)
for i in checkpoints:
    seachobj = re.search(r"\d+(?=\).pt)", i)
    idx_checkpoint = np.append(idx_checkpoint, int(seachobj.group()))
target_model_6 = checkpoints[np.argmax(idx_checkpoint)]

print("Loading the checkpoint --> {} for node-{} task".format(target_model_6, 6))

checkpoints = []
idx_checkpoint = np.array([])
for i in os.listdir(checkpoint_dir_7):
    if "best_best_checkpoint" in i:
        checkpoints.append(i)
for i in checkpoints:
    seachobj = re.search(r"\d+(?=\).pt)", i)
    idx_checkpoint = np.append(idx_checkpoint, int(seachobj.group()))
target_model_7 = checkpoints[np.argmax(idx_checkpoint)]

print("Loading the checkpoint --> {} for node-{} task".format(target_model_7, 7))


# reload the checkpoint file and run on test Dataset
best_epoch_model_file_0 = os.path.join(checkpoint_dir_0, target_model_0)
if os.path.isfile(best_epoch_model_file_0):
    best_performance_0, config_0 = load_checkpoint(best_epoch_model_file_0, 
                                                model=model_0,
                                                config=config_0,
                                                optimizer=optimizer_0)
best_epoch_model_file_1 = os.path.join(checkpoint_dir_1, target_model_1)
if os.path.isfile(best_epoch_model_file_1):
    best_performance_1, config_1 = load_checkpoint(best_epoch_model_file_1, 
                                                model=model_1,
                                                config=config_1,
                                                optimizer=optimizer_1)
best_epoch_model_file_2 = os.path.join(checkpoint_dir_2, target_model_2)
if os.path.isfile(best_epoch_model_file_2):
    best_performance_2, config_2 = load_checkpoint(best_epoch_model_file_2, 
                                                model=model_2,
                                                config=config_2,
                                                optimizer=optimizer_2)
best_epoch_model_file_6 = os.path.join(checkpoint_dir_6, target_model_6)
if os.path.isfile(best_epoch_model_file_6):
    best_performance_6, config_6 = load_checkpoint(best_epoch_model_file_6, 
                                                model=model_6,
                                                config=config_6,
                                                optimizer=optimizer_6)
best_epoch_model_file_7 = os.path.join(checkpoint_dir_7, target_model_7)
if os.path.isfile(best_epoch_model_file_7):
    best_performance_7, config_7 = load_checkpoint(best_epoch_model_file_7, 
                                                model=model_7,
                                                config=config_7,
                                                optimizer=optimizer_7)

Loading the checkpoint --> best_best_checkpoint(i.e.trial90).pt for node-0 task
Loading the checkpoint --> best_best_checkpoint(i.e.trial65).pt for node-1 task
Loading the checkpoint --> best_best_checkpoint(i.e.trial94).pt for node-2 task
Loading the checkpoint --> best_best_checkpoint(i.e.trial77).pt for node-6 task
Loading the checkpoint --> best_best_checkpoint(i.e.trial21).pt for node-7 task


- load the data and statistics of each local task

In [10]:
indice_0 = indices['folds_idx_with_txy']['parent-node-0'][0,:][2]
indice_1 = indices['folds_idx_with_txy']['parent-node-1'][0,:][2]
indice_2 = indices['folds_idx_with_txy']['parent-node-2'][0,:][2]
indice_6 = indices['folds_idx_with_txy']['parent-node-6'][0,:][2]
indice_7 = indices['folds_idx_with_txy']['parent-node-7'][0,:][2]
sample_size = indice_0.shape[0]
data_0 = {
    'X_t': data['X_t'][indice_0], 
    'T_t': data['T_t_rel'][indice_0],
    'X_t_mask': data['X_t_mask'][indice_0],
    'deltaT_t': data['deltaT_t'][indice_0],
    'X_val': data['static_data_val'][indice_0],
    'X_cat': data['static_data_cat_onehot'][indice_0],
    'y_classes_unique': np.array(label['y_classes_unique'], dtype=object)[indice_0].tolist()
}
data_1 = {
    'X_t': data['X_t'][indice_1], 
    'T_t': data['T_t_rel'][indice_1],
    'X_t_mask': data['X_t_mask'][indice_1],
    'deltaT_t': data['deltaT_t'][indice_1],
    'X_val': data['static_data_val'][indice_1],
    'X_cat': data['static_data_cat_onehot'][indice_1],
    'y_classes_unique': np.array(label['y_classes_unique'], dtype=object)[indice_1].tolist()
}
data_2 = {
    'X_t': data['X_t'][indice_2], 
    'T_t': data['T_t_rel'][indice_2],
    'X_t_mask': data['X_t_mask'][indice_2],
    'deltaT_t': data['deltaT_t'][indice_2],
    'X_val': data['static_data_val'][indice_2],
    'X_cat': data['static_data_cat_onehot'][indice_2],
    'y_classes_unique': np.array(label['y_classes_unique'], dtype=object)[indice_2].tolist()
}
data_6 = {
    'X_t': data['X_t'][indice_6], 
    'T_t': data['T_t_rel'][indice_6],
    'X_t_mask': data['X_t_mask'][indice_6],
    'deltaT_t': data['deltaT_t'][indice_6],
    'X_val': data['static_data_val'][indice_6],
    'X_cat': data['static_data_cat_onehot'][indice_6],
    'y_classes_unique': np.array(label['y_classes_unique'], dtype=object)[indice_6].tolist()
}
data_7 = {
    'X_t': data['X_t'][indice_7], 
    'T_t': data['T_t_rel'][indice_7],
    'X_t_mask': data['X_t_mask'][indice_7],
    'deltaT_t': data['deltaT_t'][indice_7],
    'X_val': data['static_data_val'][indice_7],
    'X_cat': data['static_data_cat_onehot'][indice_7],
    'y_classes_unique': np.array(label['y_classes_unique'], dtype=object)[indice_7].tolist()
}

stat_0 = {
    'X_val_mean': indices['folds_stats']['parent-node-{}'.format(config_0.experiment.local_task)][config_0.data.kfold][0,0],
    'X_val_std': indices['folds_stats']['parent-node-{}'.format(config_0.experiment.local_task)][config_0.data.kfold][0,1],
    'X_val_max': indices['folds_stats']['parent-node-{}'.format(config_0.experiment.local_task)][config_0.data.kfold][0,2],
    'X_val_min': indices['folds_stats']['parent-node-{}'.format(config_0.experiment.local_task)][config_0.data.kfold][0,3],
    'X_t_mean': indices['folds_stats']['parent-node-{}'.format(config_0.experiment.local_task)][config_0.data.kfold][1,0],
    'X_t_std': indices['folds_stats']['parent-node-{}'.format(config_0.experiment.local_task)][config_0.data.kfold][1,1],
    'X_t_max': indices['folds_stats']['parent-node-{}'.format(config_0.experiment.local_task)][config_0.data.kfold][1,2],
    'X_t_min': indices['folds_stats']['parent-node-{}'.format(config_0.experiment.local_task)][config_0.data.kfold][1,3]
}
stat_1 = {
    'X_val_mean': indices['folds_stats']['parent-node-{}'.format(config_1.experiment.local_task)][config_1.data.kfold][0,0],
    'X_val_std': indices['folds_stats']['parent-node-{}'.format(config_1.experiment.local_task)][config_1.data.kfold][0,1],
    'X_val_max': indices['folds_stats']['parent-node-{}'.format(config_1.experiment.local_task)][config_1.data.kfold][0,2],
    'X_val_min': indices['folds_stats']['parent-node-{}'.format(config_1.experiment.local_task)][config_1.data.kfold][0,3],
    'X_t_mean': indices['folds_stats']['parent-node-{}'.format(config_1.experiment.local_task)][config_1.data.kfold][1,0],
    'X_t_std': indices['folds_stats']['parent-node-{}'.format(config_1.experiment.local_task)][config_1.data.kfold][1,1],
    'X_t_max': indices['folds_stats']['parent-node-{}'.format(config_1.experiment.local_task)][config_1.data.kfold][1,2],
    'X_t_min': indices['folds_stats']['parent-node-{}'.format(config_1.experiment.local_task)][config_1.data.kfold][1,3]
}
stat_2 = {
    'X_val_mean': indices['folds_stats']['parent-node-{}'.format(config_2.experiment.local_task)][config_2.data.kfold][0,0],
    'X_val_std': indices['folds_stats']['parent-node-{}'.format(config_2.experiment.local_task)][config_2.data.kfold][0,1],
    'X_val_max': indices['folds_stats']['parent-node-{}'.format(config_2.experiment.local_task)][config_2.data.kfold][0,2],
    'X_val_min': indices['folds_stats']['parent-node-{}'.format(config_2.experiment.local_task)][config_2.data.kfold][0,3],
    'X_t_mean': indices['folds_stats']['parent-node-{}'.format(config_2.experiment.local_task)][config_2.data.kfold][1,0],
    'X_t_std': indices['folds_stats']['parent-node-{}'.format(config_2.experiment.local_task)][config_2.data.kfold][1,1],
    'X_t_max': indices['folds_stats']['parent-node-{}'.format(config_2.experiment.local_task)][config_2.data.kfold][1,2],
    'X_t_min': indices['folds_stats']['parent-node-{}'.format(config_2.experiment.local_task)][config_2.data.kfold][1,3]
}
stat_6 = {
    'X_val_mean': indices['folds_stats']['parent-node-{}'.format(config_6.experiment.local_task)][config_6.data.kfold][0,0],
    'X_val_std': indices['folds_stats']['parent-node-{}'.format(config_6.experiment.local_task)][config_6.data.kfold][0,1],
    'X_val_max': indices['folds_stats']['parent-node-{}'.format(config_6.experiment.local_task)][config_6.data.kfold][0,2],
    'X_val_min': indices['folds_stats']['parent-node-{}'.format(config_6.experiment.local_task)][config_6.data.kfold][0,3],
    'X_t_mean': indices['folds_stats']['parent-node-{}'.format(config_6.experiment.local_task)][config_6.data.kfold][1,0],
    'X_t_std': indices['folds_stats']['parent-node-{}'.format(config_6.experiment.local_task)][config_6.data.kfold][1,1],
    'X_t_max': indices['folds_stats']['parent-node-{}'.format(config_6.experiment.local_task)][config_6.data.kfold][1,2],
    'X_t_min': indices['folds_stats']['parent-node-{}'.format(config_6.experiment.local_task)][config_6.data.kfold][1,3]
}
stat_7 = {
    'X_val_mean': indices['folds_stats']['parent-node-{}'.format(config_7.experiment.local_task)][config_7.data.kfold][0,0],
    'X_val_std': indices['folds_stats']['parent-node-{}'.format(config_7.experiment.local_task)][config_7.data.kfold][0,1],
    'X_val_max': indices['folds_stats']['parent-node-{}'.format(config_7.experiment.local_task)][config_7.data.kfold][0,2],
    'X_val_min': indices['folds_stats']['parent-node-{}'.format(config_7.experiment.local_task)][config_7.data.kfold][0,3],
    'X_t_mean': indices['folds_stats']['parent-node-{}'.format(config_7.experiment.local_task)][config_7.data.kfold][1,0],
    'X_t_std': indices['folds_stats']['parent-node-{}'.format(config_7.experiment.local_task)][config_7.data.kfold][1,1],
    'X_t_max': indices['folds_stats']['parent-node-{}'.format(config_7.experiment.local_task)][config_7.data.kfold][1,2],
    'X_t_min': indices['folds_stats']['parent-node-{}'.format(config_7.experiment.local_task)][config_7.data.kfold][1,3]
}


In [11]:
def _rescale_stdize(x, mean, std):
    """
    standardize the non-time-series data and time-series data
    :param x: A np.array witn shape (t_i, d)
            mean: A np.array with shape (d,)
            std: A np.array with shape (d,)
    :return A np.array with same shape as x with rescaled values
    """
    if x.ndim == 1:
        return (x - mean) / std
    elif x.ndim == 2:
        return (x - mean[np.newaxis, :]) / std[np.newaxis, :]
    elif x.ndim == 3:
        return np.asarray([(xx - mean[np.newaxis, :]) / std[np.newaxis, :] for xx in x])

def _fillnan(x, mean):
        """
        fill the nan value in non-time-series data
        :param x: A np.array of static variables with shape (d,)
               mean: A np.array of mean value of each variable with shape (d,)
        :return A np.array without nan value
        """
        x[np.isnan(x)] = mean[np.isnan(x)]

        return x  

def _rescale_stdize(x, mean, std):
    """
    standardize the non-time-series data and time-series data
    :param x: A np.array witn shape (t_i, d)
            mean: A np.array with shape (d,)
            std: A np.array with shape (d,)
    :return A np.array with same shape as x with rescaled values
    """
    if x.ndim == 1:
        return (x - mean) / std
    elif x.ndim == 2:
        return (x - mean[np.newaxis, :]) / std[np.newaxis, :]
    elif x.ndim == 3:
        return np.asarray([(xx - mean[np.newaxis, :]) / std[np.newaxis, :] for xx in x])

def _locf_numpy(X, X_nan):
    """Numpy implementation of LOCF.

    Parameters
    ----------
    X : np.ndarray,
        Time series containing missing values (NaN) to be imputed.

    Returns
    -------
    X_imputed : array,
        Imputed time series.

    Notes
    -----
    This implementation gets inspired by the question on StackOverflow:
    https://stackoverflow.com/questions/41190852/most-efficient-way-to-forward-fill-nan-values-in-numpy-array
    """
    trans_X = X.transpose((1, 0))
    trans_X_nan = X_nan.transpose((1, 0))
    mask = np.isnan(trans_X_nan)
    n_features, n_steps  = mask.shape
    idx = np.where(~mask, np.arange(n_steps), 0)
    np.maximum.accumulate(idx, axis=1, out=idx)

    X_imputed = trans_X[np.arange(n_features)[:, None], idx]
    X_imputed = X_imputed.transpose((1, 0))

    # If there are values still missing,
    # they are missing at the beginning of the time-series sequence.
    # Impute them with self.nan
    if np.isnan(X_imputed).any():
        X_imputed = np.nan_to_num(X_imputed, nan=0)

    return X_imputed

def f_empirical_mean(data, stat):

    X_rescaled = _rescale_stdize(data['X_t'], stat['X_t_mean'], stat['X_t_std'])
    X_rescaled = X_rescaled.reshape(-1, data['X_t'].shape[-1])
    empirical_mean = np.nanmean(X_rescaled, axis=0)

    return empirical_mean

def _preprocess_sample(raw_sample, stat, empirical_mean):
    """"
    normalize the sample data
    :param: raw_sample -> List[array(), array(), array(), array(), List[List[]]]
    :return: Dict{'X_t': np.array([]),
                    'X': np.array([]),
                    'X_t_mask': np.array([]),
                    'deltaT_t': np.array([]),
                    'X_t_filledLOCF': np.array([]),
                    'y_classes': List[List[int]],
                    'empirical_mean': np.array([])}
    """
    sample = {}

    sample['X_t'] = _rescale_stdize(raw_sample[0], stat['X_t_mean'], stat['X_t_std'])
    sample['X_t'] = np.nan_to_num(sample['X_t'])
    
    # fill the nan value in X_val and normalize X_val
    raw_sample[4] = _fillnan(raw_sample[4], stat['X_val_mean'])
    raw_sample[4] = _rescale_stdize(raw_sample[4], stat['X_val_mean'], stat['X_val_std'])
    # concatenate the static variables
    sample['X'] = np.concatenate((raw_sample[4], raw_sample[5]))

    # forward fill nan value in np.array
    sample['X_t_filledLOCF'] = _locf_numpy(sample['X_t'], raw_sample[0])

    sample['empirical_mean'] = empirical_mean

    sample['X_t_mask'] = raw_sample[2]
    sample['deltaT_t'] = raw_sample[3] / 86400 # 24*60*60
    sample['y_classes'] = raw_sample[-1]

    return sample

In [12]:
def _localize_label(config, label, batch_labels):
    """
    :param batch_labels: label idx of one batch, List[List[int]], e.g. [[0, 1, 3], [0, 2, 6, 9],...]
    :return batch_local_label: np.array([int]), e.g. np.array([0, 1, ...])
    """
    label_dict = {}
    for idx, value in enumerate(label['taxonomy'][config.experiment.local_task]):
        label_dict[value] = idx
    
    print(label_dict)
    print(batch_labels)
    
    local_labels = []
    for label in batch_labels:
        for label_idx in label:
            if label_idx in label_dict:
                local_labels.append(label_dict[label_idx])
                break

    assert len(local_labels) == len(batch_labels), "The labels are missed during localization, please recheck!"

    return local_labels

def _check_input(batch):
    for key, value in batch.items():
        # convert the data type if in need
        batch[key] = value.to(config_0.train.device_setting.device)

    return batch

def collect(config, label, batch):
    batch_X_t = []
    batch_X_t_mask = []
    batch_deltaT_t = []
    batch_X_t_filledLOCF = []
    batch_empirical_mean = []
    batch_X = []
    batch_label = []

    # copy batch_size times of one sample
    batch_copy = []
    for i in range(64):
        batch_copy.append(batch)

    for sample in batch_copy:
        batch_X_t.append(sample['X_t'])
        batch_X_t_mask.append(sample['X_t_mask'])
        batch_deltaT_t.append(sample['deltaT_t'])
        batch_X_t_filledLOCF.append(sample['X_t_filledLOCF'])
        batch_empirical_mean.append(sample['empirical_mean'])
        batch_X.append(sample['X'])
        batch_label.append(sample['y_classes'])

    return {
        'X': torch.tensor(np.array(batch_X)).to(torch.float32),
        'X_t': torch.tensor(np.array(batch_X_t)).to(torch.float32),
        'X_t_mask': torch.tensor(np.array(batch_X_t_mask)).to(torch.float32),
        'deltaT_t': torch.tensor(np.array(batch_deltaT_t)).to(torch.float32),
        'X_t_filledLOCF': torch.tensor(np.array(batch_X_t_filledLOCF)).to(torch.float32),
        'empirical_mean': torch.tensor(np.array(batch_empirical_mean)).to(torch.float32)
    }

- make hierarchical predictions based on the label hierarchy

In [14]:
target_labels_array = np.zeros((indice_0[:2112].shape[0], 11))
predcit_labels_array = np.zeros((indice_0[:2112].shape[0], 11))

target_labels = data_0['y_classes_unique']

for index, sample_idx in enumerate(indice_0[:2112]):
    raw_sample = [data_0[s][index] for s in ['X_t', 'T_t', 'X_t_mask', 'deltaT_t', 'X_val' , 'X_cat', 'y_classes_unique']]
    
    # parent-node-0
    empirical_mean = f_empirical_mean(data_0, stat_0)
    processed_sample = _preprocess_sample(raw_sample, stat_0, empirical_mean)
    batch = collect(config_0, label, processed_sample)
    
    model_0.eval()
    with torch.no_grad():
        inputs = _check_input(batch)
        logits = model_0(inputs)
        predictions = F.softmax(logits, dim=1)
        pred_label_0 = predictions.max(1)[-1].cpu().tolist()[0]
        if pred_label_0 == 0:
            predcit_labels_array[index, 0] = 1
        elif pred_label_0 == 1:
            predcit_labels_array[index, 1] = 1
    
    if pred_label_0 == 0:
        # parent-node-1
        empirical_mean = f_empirical_mean(data_1, stat_1)
        processed_sample = _preprocess_sample(raw_sample, stat_1, empirical_mean)
        batch = collect(config_1, label, processed_sample)

        model_1.eval()
        with torch.no_grad():
            inputs = _check_input(batch)
            logits = model_1(inputs)
            predictions = F.softmax(logits, dim=1)
            pred_label_1 = predictions.max(1)[-1].cpu().tolist()[0]
            if pred_label_1 == 0:
                predcit_labels_array[index, 2] = 1
            elif pred_label_1 == 1:
                predcit_labels_array[index, 3] = 1
            elif pred_label_1 == 2:
                predcit_labels_array[index, 4] = 1

    elif pred_label_0 == 1:
        # parent-node-2
        empirical_mean = f_empirical_mean(data_2, stat_2)
        processed_sample = _preprocess_sample(raw_sample, stat_2, empirical_mean)
        batch = collect(config_2, label, processed_sample)

        model_2.eval()
        with torch.no_grad():
            inputs = _check_input(batch)
            logits = model_2(inputs)
            predictions = F.softmax(logits, dim=1)
            pred_label_2 = predictions.max(1)[-1].cpu().tolist()[0]
            if pred_label_2 == 0:
                predcit_labels_array[index, 5] = 1
            elif pred_label_2 == 1:
                predcit_labels_array[index, 6] = 1

        if pred_label_2 == 0:
            # parent-node-6
            empirical_mean = f_empirical_mean(data_6, stat_6)
            processed_sample = _preprocess_sample(raw_sample, stat_6, empirical_mean)
            batch = collect(config_6, label, processed_sample)

            model_6.eval()
            with torch.no_grad():
                inputs = _check_input(batch)
                logits = model_6(inputs)
                predictions = F.softmax(logits, dim=1)
                pred_label_6 = predictions.max(1)[-1].cpu().tolist()[0]
                if pred_label_6 == 0:
                    predcit_labels_array[index, 7] = 1
                elif pred_label_6 == 1:
                    predcit_labels_array[index, 8] = 1
        
        elif pred_label_2 == 1:
            # parent-node-7
            empirical_mean = f_empirical_mean(data_7, stat_7)
            processed_sample = _preprocess_sample(raw_sample, stat_7, empirical_mean)
            batch = collect(config_7, label, processed_sample)

            model_7.eval()
            with torch.no_grad():
                inputs = _check_input(batch)
                logits = model_7(inputs)
                predictions = F.softmax(logits, dim=1)
                pred_label_7 = predictions.max(1)[-1].cpu().tolist()[0]
                if pred_label_7 == 0:
                    predcit_labels_array[index, 9] = 1
                elif pred_label_7 == 1:
                    predcit_labels_array[index, 10] = 1


- transform the target labels of TEST dataset into form of [sample_size, 11]

In [20]:
def _all_node_label_wo_root(batch_label):
    """
    tranform y_classes_unique to all-node label without ROOT node
    :params: batch_label, List[List[], List[], ...]-->[[0,1,3], [0,1,5], [0,2,6,8], ...]
    :Return: all_node_labels without root node, List[List[], List[], ...]-->[[1,0,1,0,0,0,0,0,0,0,0], [1,0,0,0,1,0,0,0,0,0,0],...]
    """
    all_node_labels = np.zeros((len(batch_label), 11))
    for i, label in enumerate(batch_label):
        for j in label[1:]:
            all_node_labels[i][j-1] = 1

    return all_node_labels

target_labels_array = _all_node_label_wo_root(target_labels[:2112])

In [22]:
metrics = evaluate4test(config_0, target_labels_array, predcit_labels_array)

# save the pred and true labels of TEST dataset
metrics['target_labels'] = target_labels_array
metrics['pred_labels'] = predcit_labels_array

np.save(os.path.join("./results/hp_tuning/PreAttnMMs_LCPN", "test_results.npy"), metrics)
print("The process finished!")
print(metrics)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


INFO:  本实验中，完全匹配样本数为：609，占比：0.2883522727272727 
         非完全匹配样本数为：1503， 占比：0.7116477272727273 
             不完整路径预测的样本数为：0， 占比：0.0 
                 不完整路径预测中预测正确的样本数为：0， 占比：0 
                 不完整路径预测中预测错误的样本数为：0，占比：0 
             完整路径预测的样本数为：1503，占比：1.0 
                 第一层预测错误的样本数为：415，占比：0.2761144377910845 
                 第一层预测正确的样本数为：1088，占比：0.7238855622089155 
                     第二层感染细分类预测错误的样本数为：906，占比：0.8327205882352942 
                     第二层非感染细分类预测错误的样本数为：182，占比：0.16727941176470587 
                         NIID和Neo层即分类预测错误的样本数为：125，占比：0.6868131868131868 
                         NIID和Neo层预测正确，但第三层预测错误的样本数为：57，占比：0.3131868131868132 
                     第二层及后续层预测存在违反类别约束错误的样本数为：0，占比：0.0
INFO:  本实验中, 测试样本数量为: 2112, 
         完全匹配样本数为：609, 占比: 0.2883522727272727 
         非完全匹配样本数为: 1503, 占比: 0.7116477272727273 
             不完整路径预测的样本数为: 0, 占全部测试样本比例为: 0.0 
                 不完整路径预测中预测正确的样本数为: 0, 占不完整路径预测样本的比例为: 0 
             完整路径预测的样本数为: 1503, 占全部测试样本比例为: 0.71164772