In [172]:
%load_ext tuna

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import matplotlib.pyplot as plt
import numpy as np
import dill
import os
import seaborn as sns
import torch
import pandas as pd

from collections import namedtuple
from hmmlearn.hmm import MultinomialHMM
from scipy.special import logsumexp
from scipy.stats import multivariate_normal
from spn.algorithms.Inference import log_likelihood

The tuna extension is already loaded. To reload it, use:
  %reload_ext tuna


In [173]:
import pickle
with open('all_vaso_vars.pickle', 'rb') as file:  
    rspn_vaso_models = pickle.load(file)  
    test_ds_list = pickle.load(file)

NUM_SEEDS = 5
SEEDS = range(42, 42 + NUM_SEEDS)

len(rspn_vaso_models), len(test_ds_list)

(5, 5)

In [174]:
# test_0 = test_ds_list[1]

# print(np.unique(test_1[:,5], return_counts=True))

In [175]:
%%capture
from spn.algorithms.Marginalization import marginalize
from copy import deepcopy
from spn.algorithms.MPE import mpe
from spn.structure.Base import assign_ids, rebuild_scopes_bottom_up
from spn.io.Text import str_to_spn, spn_to_str_equation
from sklearn.metrics import mean_squared_error as sk_mse, f1_score, roc_auc_score, brier_score_loss

num_time_steps = 6
n_dim = 2


def next_error(model, data, num_time_steps=num_time_steps, num_dims=n_dim):
    unrolled_model = model.get_unrolled_rspn(model.get_len_sequence())
    preds_arr = np.zeros((data.shape[0], (num_time_steps)*(n_dim-1))) # assuming 1 control variable

    preds_arr[:,0] = data[:, 1]

    p = 1 #index to insert for preds array
    # for k, t in enumerate(range(3, num_time_steps*n_dim+1, 4)):
    for action in range(3, num_time_steps*n_dim+1, 2):


        vars_to_include = list(range(action-1)) + [action]
        spn_marg = marginalize(unrolled_model, vars_to_include)
        
        data_t = deepcopy(data)
        data_t[:, action] = np.nan
        data_t = data_t[:, :action+1]
        print(f"data_t: {data_t}")
       
        spn_mpe = mpe(spn_marg, data_t)
        curr_spn_mpe = spn_mpe[:, action]
        

        preds_arr[:, p] = curr_spn_mpe
        p += 1
    return preds_arr



def metric_over_all_vars(models, test_ds_list, metric_func):
    all_metric_vars = []
    for i, (model, test_ds) in enumerate(zip(models, test_ds_list)):
        print(test_ds)
        test_ds = test_ds.astype(float)
        print(f"seed:__________{SEEDS[i]}___________")
        pred_arr = next_error(model, test_ds)
        rmse_vars = metric_func(pred_arr, test_ds)
        all_metric_vars.append(rmse_vars)
    return all_metric_vars
  

In [176]:
def rmse_over_vars(preds, real_data):
    rmses = []
    n_dim=2
    real_actions_sl = real_data[:, 1::2]
    rmse_action = sk_mse(preds.flatten(), real_actions_sl.flatten(),multioutput="raw_values", squared=False)
    rmses.append(rmse_action)
    
    return rmses

def f1_over_vars(preds, real_data):
    f1s = []
    n_dim=2
    real_actions_sl = real_data[:, 1::2]
    f1_action = f1_score(preds.flatten(), real_actions_sl.flatten())
    f1s.append(f1_action)
    
    return f1_action

def roc_auc_over_vars(preds, real_data):
    n_dim=2
    real_actions_sl = real_data[:, 1::2]
    # print(f"real_actions_sl counts: {np.unique(real_actions_sl, return_counts=True)}")

    try:
        roc_auc_action = roc_auc_score(real_actions_sl.flatten(), preds.flatten())
    
    except ValueError:
        roc_auc_action = np.nan

    return roc_auc_action


def brier_over_vars(preds, real_data):
    n_dim=2
    real_actions_sl = real_data[:, 1::2]
    brier_action = brier_score_loss(preds.flatten(), real_actions_sl.flatten())
    return brier_action

In [177]:

all_f1_vars = metric_over_all_vars(rspn_vaso_models, test_ds_list, f1_over_vars)
all_roc_auc_vars = metric_over_all_vars(rspn_vaso_models, test_ds_list, roc_auc_over_vars)
all_brier_vars = metric_over_all_vars(rspn_vaso_models, test_ds_list, brier_over_vars)

[[9 0 9 ... 0 9 0]
 [0 0 0 ... 0 4 0]
 [4 0 4 ... 0 0 0]
 ...
 [4 0 0 ... 0 0 0]
 [0 0 0 ... 0 4 0]
 [4 0 4 ... 0 0 0]]
seed:__________42___________
==>> child: In_LatentNode_1
==>> child: In_LatentNode_2
==>> child: In_LatentNode_59
==>> child: In_LatentNode_60
==>> child: In_LatentNode_61
==>> child: In_LatentNode_62
==>> child: In_LatentNode_63
==>> child: In_LatentNode_64
==>> child: In_LatentNode_65
==>> child: In_LatentNode_66
data_t: [[ 9.  0.  9. nan]
 [ 0.  0.  0. nan]
 [ 4.  0.  4. nan]
 ...
 [ 4.  0.  0. nan]
 [ 0.  0.  0. nan]
 [ 4.  0.  4. nan]]
data_t: [[ 9.  0.  9.  0.  9. nan]
 [ 0.  0.  0.  0.  4. nan]
 [ 4.  0.  4.  0.  4. nan]
 ...
 [ 4.  0.  0.  0.  8. nan]
 [ 0.  0.  0.  0.  0. nan]
 [ 4.  0.  4.  0.  4. nan]]
data_t: [[ 9.  0.  9. ...  0.  9. nan]
 [ 0.  0.  0. ...  0.  0. nan]
 [ 4.  0.  4. ...  0.  8. nan]
 ...
 [ 4.  0.  0. ...  0.  8. nan]
 [ 0.  0.  0. ...  0.  4. nan]
 [ 4.  0.  4. ...  0.  0. nan]]
data_t: [[ 9.  0.  9. ...  0.  9. nan]
 [ 0.  0.  0. ...  0

In [178]:
all_f1_vars = np.array(all_f1_vars)
all_roc_auc_vars = np.array(all_roc_auc_vars)
all_brier_vars = np.array(all_brier_vars) 


In [179]:
metric_names = ["f1", "roc_auc", "brier"]
metric_lists = [all_f1_vars, all_roc_auc_vars, all_brier_vars]

for name, ls in zip(metric_names, metric_lists):

    print(f"_____{name}_____")
    print(f"overall mean across the 5 seeds: {np.nanmean(ls)}")
    print(f"overall sd across the 5 seeds: {np.nanstd(ls)}")
    print(f"full array for each seed:{ls}\n\n")

_____f1_____
overall mean across the 5 seeds: 0.0
overall sd across the 5 seeds: 0.0
full array for each seed:[0. 0. 0. 0. 0.]


_____roc_auc_____
overall mean across the 5 seeds: 0.5
overall sd across the 5 seeds: 0.0
full array for each seed:[0.5 0.5 0.5 0.5 0.5]


_____brier_____
overall mean across the 5 seeds: 0.04752365052865888
overall sd across the 5 seeds: 0.029588940934614363
full array for each seed:[0.02462437 0.03130217 0.03728436 0.03853645 0.1058709 ]


