**Do some preprocessing**

In [1]:
import pandas as pd
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
import os, pickle
import matplotlib.pyplot as plt

from sklearn.cluster import KMeans
from sklearn.preprocessing import OneHotEncoder

from itertools import product

import sys, os

import trajectory as T                      # trajectory generation
import optimizer as O                       # stochastic gradient descent optimizer
import solver as S                          # MDP solver (value-iteration)
import plot as P


num_data = 355504


np.random.seed(66)

def to_interval(istr):
    c_left = istr[0]=='['
    c_right = istr[-1]==']'
    closed = {(True, False): 'left',
              (False, True): 'right',
              (True, True): 'both',
              (False, False): 'neither'
              }[c_left, c_right]
    left, right = map(pd.to_datetime, istr[1:-1].split(','))
    return pd.Interval(left, right, closed)

re_split = False
frac = [0.4,0.2,0.4]
assert np.sum(frac) == 1
frac = np.cumsum(frac)
print (frac)
data_save_path= 'data/'

def sliding(gs, window_size = 6):
    npr_l = []
    for g in gs:
        npr = np.concatenate([np.zeros([window_size-1, g.shape[1]]),g])
        npr_l.append(sliding_window_view(npr, (window_size, g.shape[1])).squeeze(1))
    return np.vstack(npr_l)

[0.4 0.6 1. ]


In [2]:
# if re_split:

aggr_df = pd.read_csv('mimic_iv_hypotensive_cut2.csv',sep = ',', header = 0,converters={1:to_interval}).set_index(['stay_id','time']).sort_index()
# create action bins (four actions in total)
aggr_df['action'] = aggr_df['bolus(binary)']*2 + aggr_df['vaso(binary)']
all_idx = np.random.permutation(aggr_df.index.get_level_values(0).unique())
train_df = aggr_df.loc[all_idx[:int(len(all_idx)*frac[0])]].sort_index()
test_df = aggr_df.loc[all_idx[int(len(all_idx)*frac[0]):int(len(all_idx)*frac[1])]].sort_index()
valid_df = aggr_df.loc[all_idx[int(len(all_idx)*frac[1]):]].sort_index()
# print (np.unique(train_df['action'],return_counts=True)[1]*1./len(train_df))
# pickle.dump([train_df, test_df, valid_df], open(data_save_path+'processed_mimic_hyp_2.pkl','wb'))
drop_columns = ['vaso(amount)','bolus(amount)',\
            'any_treatment(binary)','vaso(binary)','bolus(binary)']

In [3]:
# for now drop indicators about bolus and vaso
train_df = train_df.drop(columns=drop_columns)
test_df = test_df.drop(columns=drop_columns)
valid_df = valid_df.drop(columns=drop_columns)

#### imputation
impute_table = pd.read_csv('mimic_iv_hypotensive_cut2_impute_table.csv',sep=',',header=0).set_index(['feature'])
train_df = train_df.fillna(method='ffill')
test_df = test_df.fillna(method='ffill')
valid_df = valid_df.fillna(method='ffill')




for f in impute_table.index:
    train_df[f] = train_df[f].fillna(value = impute_table.loc[f].values[0])
    test_df[f] = test_df[f].fillna(value = impute_table.loc[f].values[0])
    valid_df[f] = valid_df[f].fillna(value = impute_table.loc[f].values[0])


data_non_normalized_df = pd.concat([train_df, valid_df, test_df], axis=0, ignore_index=False).head(num_data).copy()


#### standard normalization ####
normalize_features = ['creatinine', 'fraction_inspired_oxygen', 'lactate', 'urine_output',
                  'alanine_aminotransferase', 'asparate_aminotransferase',
                  'mean_blood_pressure', 'diastolic_blood_pressure',
                  'systolic_blood_pressure', 'gcs', 'partial_pressure_of_oxygen']
mu, std = (train_df[normalize_features]).mean().values,(train_df[normalize_features]).std().values
train_df[normalize_features] = (train_df[normalize_features] - mu)/std
test_df[normalize_features] = (test_df[normalize_features] - mu)/std
valid_df[normalize_features] = (valid_df[normalize_features] - mu)/std




### create data matrix ####
X_train = train_df.loc[:,train_df.columns!='action']
y_train = train_df['action']

X_test = test_df.loc[:,test_df.columns!='action']
y_test = test_df['action']

X_valid = valid_df.loc[:, valid_df.columns!='action']
y_valid = valid_df['action']

In [4]:
X_df = pd.concat([X_train, X_valid, X_test], axis=0, ignore_index=True).copy()
y_df = pd.concat([y_train, y_valid, y_test], axis=0, ignore_index=True).copy()

In [5]:
data_df = pd.concat([train_df, valid_df, test_df], axis=0, ignore_index=False).copy()

In [6]:
X_non_normailzed = data_non_normalized_df.copy()
del X_non_normailzed['action']

In [7]:
X_non_normailzed

Unnamed: 0_level_0,Unnamed: 1_level_0,creatinine,fraction_inspired_oxygen,lactate,urine_output,alanine_aminotransferase,asparate_aminotransferase,mean_blood_pressure,diastolic_blood_pressure,systolic_blood_pressure,gcs,partial_pressure_of_oxygen,heart_rate,temperature,respiratory_rate
stay_id,time,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
30004811,"[2139-10-06 10:40:29, 2139-10-06 11:40:29)",1.0,0.21,1.8,80.0,34.0,40.0,77.0,59.0,118.0,11.0,112.0,86.0,37.000000,19.0
30004811,"[2139-10-06 11:40:29, 2139-10-06 12:40:29)",1.0,0.21,1.8,80.0,34.0,40.0,77.0,59.0,118.0,11.0,112.0,86.0,37.000000,19.0
30004811,"[2139-10-06 12:40:29, 2139-10-06 13:40:29)",1.0,0.21,3.0,80.0,34.0,40.0,77.0,59.0,118.0,11.0,272.0,86.0,37.000000,19.0
30004811,"[2139-10-06 13:40:29, 2139-10-06 14:40:29)",1.0,0.21,3.0,80.0,34.0,40.0,77.0,59.0,118.0,11.0,272.0,86.0,37.000000,19.0
30004811,"[2139-10-06 14:40:29, 2139-10-06 15:40:29)",1.0,0.21,3.0,80.0,34.0,40.0,77.0,59.0,118.0,11.0,272.0,86.0,37.000000,19.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
39986775,"[2123-10-10 12:18:46, 2123-10-10 13:18:46)",2.7,0.40,1.6,125.0,164.0,240.0,53.0,39.0,81.0,15.0,90.0,81.0,36.833333,18.0
39986775,"[2123-10-10 13:18:46, 2123-10-10 14:18:46)",2.7,0.40,1.6,60.0,164.0,240.0,67.0,47.0,117.0,15.0,90.0,69.0,36.833333,16.0
39986775,"[2123-10-10 14:18:46, 2123-10-10 15:18:46)",2.7,0.40,1.1,40.0,164.0,240.0,61.0,48.0,97.0,15.0,45.0,88.0,36.833333,22.0
39986775,"[2123-10-10 15:18:46, 2123-10-10 16:18:46)",2.7,0.50,1.0,40.0,164.0,240.0,67.0,55.0,101.0,15.0,50.0,95.0,36.777778,29.0


**Run clustering**

In [171]:
# K-Means
num_clusters = 100
kmeans = KMeans(n_clusters=num_clusters, random_state=0)
kmeans.fit(X_df)



In [172]:
# DBSCAN
#db = DBSCAN(eps=0.3, min_samples=10).fit(X)
#core_samples_mask = np.zeros_like(db.labels_, dtype=bool)
#core_samples_mask[db.core_sample_indices_] = True
#labels = db.labels_
 
# Number of clusters in labels, ignoring noise if present
# num_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)

In [173]:
np.unique(kmeans.labels_, return_counts=True)[1]

array([14726,   469,   929,   275,   179,   401,  3514, 12641,    94,
         983,  1070,   810,   374,   649,  9997,  1423,  3834,  9217,
         302,   131,  6673,  7645, 11232,  3985,  1834, 18439,  2072,
          18,    47,    82,  7912,   863,   146,  3929,  2567,   877,
        3947,  1329,  2434,  1729,  9512,    48,  4349,  8987,  7392,
       15238,   474,  8941,  3456,   596,   501,   664,   310,  5975,
       10030,   304, 14225,   163,   322,   931,   308,    48,  1363,
        5240, 11085,   309,   373,  6874,   587,   345,   187,   288,
         459,  2022, 18926,  1003,  4219,  4252,  8135,   229,   418,
        8573,  1803,  4790, 10018,  3207,  2571,  1447,   579,   116,
        3484,   465,  2750,   114,  5453,   159,   433,  1002,   928,
        8746])

In [174]:
# Assigning each data point to a cluster
X_df['cluster'] = kmeans.labels_.copy()
data_df['cluster'] = kmeans.labels_.copy()
data_non_normalized_df['cluster'] = kmeans.labels_.copy()

**Create trajectories**

In [175]:
unique_stay_ids = data_df.index.get_level_values('stay_id').unique()

trajectories = []


for stay_id in unique_stay_ids:


  states, actions = data_df.loc[stay_id]['cluster'], data_df.loc[stay_id]['action']

  trajectory = []
  for i in range(len(states) - 1):
    trajectory.append((states[i], int(actions[i]), states[i+1] ))

  trajectories.append(T.Trajectory(trajectory))

In [176]:
terminal_states = []

for traj in trajectories:
  terminal_states.append(traj._t[-1][-1])

terminal_states = list(set(terminal_states))

**Calculate Transition Probabilities**

In [177]:
smoothing_value = 1

p_transition = np.zeros((num_clusters, num_clusters, 4)) + smoothing_value


for traj in trajectories:

  for tran in traj._t:
                 #     s,      s',      a
    p_transition[tran[0], tran[2], tran[1]] +=1

p_transition = p_transition/ p_transition.sum(axis = 1)[:, np.newaxis, :]

In [178]:
# our existing p_transition matrix is shaped [s][s'][a]
# We need to swap the axes to get [s][a][s'] for the FIRL algorithm input

# Swap the axes of p_transition to match the expected order
sa_p = np.swapaxes(p_transition, 1, 2)

# Now sa_p is in the shape [state][action][target state] as requiredb

In [179]:
sa_p.shape

(100, 4, 100)

In [180]:
def build_mdp_data(states, actions, discount, sa_p):
    # Number of states (clusters) and actions are given

    # Initialize sa_s as a 3D array where sa_s[s, a, :] contains all possible next states
    # for taking action 'a' in state 's'
    sa_s = np.zeros((states, actions, states), dtype=int)

    # Here, we'll fill in sa_s with the indices of potential successor states.
    # This is a simplification for demonstration and should be tailored to your actual environment.
    for s in range(states):
        for a in range(actions):
            # The successor states are assumed to be all other states, including the current state.
            # This means from any state 's', any action 'a' can potentially lead to any state.
            sa_s[s, a, :] = np.arange(states)

    # Create MDP data structure
    mdp_data = {
        'states': states,
        'actions': actions,
        'discount': discount,
        'sa_s': sa_s,
        'sa_p': sa_p
    }

    return mdp_data

# Example usage:
num_states = num_clusters  # Number of states
num_actions = sa_p.shape[1]  # In our case 4 possible actions (vaso, etc.)
discount = 0.9      # Discount factor for MDP

mdp_data = build_mdp_data(num_states, num_actions, discount, sa_p)

In [181]:
print(mdp_data)

{'states': 100, 'actions': 4, 'discount': 0.9, 'sa_s': array([[[ 0,  1,  2, ..., 97, 98, 99],
        [ 0,  1,  2, ..., 97, 98, 99],
        [ 0,  1,  2, ..., 97, 98, 99],
        [ 0,  1,  2, ..., 97, 98, 99]],

       [[ 0,  1,  2, ..., 97, 98, 99],
        [ 0,  1,  2, ..., 97, 98, 99],
        [ 0,  1,  2, ..., 97, 98, 99],
        [ 0,  1,  2, ..., 97, 98, 99]],

       [[ 0,  1,  2, ..., 97, 98, 99],
        [ 0,  1,  2, ..., 97, 98, 99],
        [ 0,  1,  2, ..., 97, 98, 99],
        [ 0,  1,  2, ..., 97, 98, 99]],

       ...,

       [[ 0,  1,  2, ..., 97, 98, 99],
        [ 0,  1,  2, ..., 97, 98, 99],
        [ 0,  1,  2, ..., 97, 98, 99],
        [ 0,  1,  2, ..., 97, 98, 99]],

       [[ 0,  1,  2, ..., 97, 98, 99],
        [ 0,  1,  2, ..., 97, 98, 99],
        [ 0,  1,  2, ..., 97, 98, 99],
        [ 0,  1,  2, ..., 97, 98, 99]],

       [[ 0,  1,  2, ..., 97, 98, 99],
        [ 0,  1,  2, ..., 97, 98, 99],
        [ 0,  1,  2, ..., 97, 98, 99],
        [ 0,  1,  2, ...,

In [182]:
X_df

Unnamed: 0,creatinine,fraction_inspired_oxygen,lactate,urine_output,alanine_aminotransferase,asparate_aminotransferase,mean_blood_pressure,diastolic_blood_pressure,systolic_blood_pressure,gcs,...,asparate_aminotransferase_binned_binned_binned,mean_blood_pressure_binned_binned_binned,diastolic_blood_pressure_binned_binned_binned,systolic_blood_pressure_binned_binned_binned,gcs_binned_binned_binned,partial_pressure_of_oxygen_binned_binned_binned,heart_rate_binned_binned_binned,temperature_binned_binned_binned,respiratory_rate_binned_binned_binned,cluster_binned_binned_binned
0,-0.422008,-1.760743,-0.182521,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.368850,-2.479374,...,7,170,166,185,0,48,136,571,83,8
1,-0.422008,-1.760743,-0.182521,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.368850,-2.479374,...,7,170,166,185,0,48,136,571,83,8
2,-0.422008,-1.760743,0.360532,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.368850,-2.479374,...,7,170,166,185,0,147,136,571,83,34
3,-0.422008,-1.760743,0.360532,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.368850,-2.479374,...,7,170,166,185,0,147,136,571,83,34
4,-0.422008,-1.760743,0.360532,-0.225783,-0.288689,-0.265706,0.404836,0.391566,0.368850,-2.479374,...,7,170,166,185,0,147,136,571,83,34
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
355499,0.739524,-0.606278,-0.273030,0.138982,-0.080898,-0.123953,-1.457422,-1.418182,-1.615710,0.229732,...,63,105,109,126,0,35,128,562,77,71
355500,0.739524,-0.606278,-0.273030,-0.387901,-0.080898,-0.123953,-0.371105,-0.694283,0.315214,0.229732,...,63,142,132,184,0,35,110,562,68,71
355501,0.739524,-0.606278,-0.499303,-0.550019,-0.080898,-0.123953,-0.836669,-0.603795,-0.757522,0.229732,...,63,126,134,153,0,7,139,562,96,71
355502,0.739524,0.001335,-0.544557,-0.550019,-0.080898,-0.123953,-0.371105,0.029617,-0.542975,0.229732,...,63,142,153,159,0,10,152,559,126,71


## We now discretize each column separately
We need to do this since FIRL expects binary encoded features (for each
state, or in our case: cluster

We experimented with two approaches:
- binning (i.e. like in a histogram)
and more specifically using the Freedman-Diaconis rule for bin width
- 1D clustering using the Fisher-Jenks algorithm

Our intuition (+ hopefully experiments will show) is that the latter approach performs better?

In [183]:
# We make use of the package which comes with an efficient C implementation of the algo
#!pip install -q jenkspy
#import jenkspy

<b>Binning approach</b>

In [184]:
bin_edges_dict = {}

# Discretize each column separately
for column in X_df.columns:
    # REVIEW:
    # Calculate the bin width using the Freedman-Diaconis rule
    q75, q25 = np.percentile(X_df[column].dropna(), [75, 25])
    iqr = q75 - q25
    bin_width = 2 * iqr * (len(X_df[column]) ** (-1/3))

    # Determine the range of the data
    data_min, data_max = X_df[column].min(), X_df[column].max()

    # Use numpy.histogram_bin_edges to get the bin edges
    bin_edges = np.histogram_bin_edges(X_df[column].dropna(), bins='fd', range=(data_min, data_max))

    bin_edges_dict[column + '_binned'] = bin_edges

    # Discretize the column using the cut function and the bin edges
    X_df[column + '_binned'] = pd.cut(X_df[column], bins=bin_edges, labels=False, include_lowest=True)

# Now X_df has additional columns with the suffix '_binned' representing discretized values

In [185]:
# Create a new DataFrame with only the columns that have '_binned' suffix
X_df_binned = X_df.filter(like='_binned')

In [186]:
# Examine the number of bins created per each column
print(X_df_binned.max())

creatinine_binned                                          701
fraction_inspired_oxygen_binned                            279
lactate_binned                                             844
urine_output_binned                                        438
alanine_aminotransferase_binned                           4617
asparate_aminotransferase_binned                          5655
mean_blood_pressure_binned                                 860
diastolic_blood_pressure_binned                            742
systolic_blood_pressure_binned                             432
gcs_binned                                                   0
partial_pressure_of_oxygen_binned                          409
heart_rate_binned                                          450
temperature_binned                                         826
respiratory_rate_binned                                   1328
cluster_binned                                              68
creatinine_binned_binned                               

In [187]:
# For further processing, it will also be useful to know the specific bin boundaries
# for each column

<b> 1D Clustering approach </b>

   The result of the code is a DataFrame where each continuous variable is replaced by a discretized version, with the discretization determined by the natural breaks found by the Fisher-Jenks algorithm

In [188]:
"""
This (or at least this implementation) turns out to be way too slow!

for column in X_df.columns:
    print("Processing column", column)
    # Apply Fisher-Jenks algorithm to find natural breaks
    # The number of bins is still a hyperparameter based on domain knowledge
    num_bins = 2
    breaks = jenkspy.jenks_breaks(X_df[column].dropna(), n_classes=num_bins)

    # Create a new column for the binned data
    X_df[column + '_clustered'] = pd.cut(X_df[column], bins=breaks, labels=range(num_bins), include_lowest=True)

# Filter out the original columns to create a new DataFrame with only binned data
binned_columns_X_df = X_df.filter(like='_clustered')
"""

'\nThis (or at least this implementation) turns out to be way too slow!\n\nfor column in X_df.columns:\n    print("Processing column", column)\n    # Apply Fisher-Jenks algorithm to find natural breaks\n    # The number of bins is still a hyperparameter based on domain knowledge\n    num_bins = 2\n    breaks = jenkspy.jenks_breaks(X_df[column].dropna(), n_classes=num_bins)\n\n    # Create a new column for the binned data\n    X_df[column + \'_clustered\'] = pd.cut(X_df[column], bins=breaks, labels=range(num_bins), include_lowest=True)\n\n# Filter out the original columns to create a new DataFrame with only binned data\nbinned_columns_X_df = X_df.filter(like=\'_clustered\')\n'

In [189]:
# Individual cluster statistics
df_clusters = pd.DataFrame()

# We only examine the 5-95 percentiles for each cluster for the purposes of binary encoding
# This is clumsy, need to clean this up
# def percentile_5th(series):
#     return np.percentile(series, 5)

# def percentile_95th(series):
#     return np.percentile(series, 95)

TOP_PERCENTILE = 90     # 95
BOTTOM_PERCENTILE = 10  # 5

# For feature in the DataFrame (excluding the cluster_id)
for feature in X_df.columns.difference(['cluster']):
    # Group by 'cluster_id' and calculate the 10th and 90th percentiles for the feature
    grouped = X_df.groupby('cluster')[feature].agg([lambda series: np.percentile(series, BOTTOM_PERCENTILE), lambda series: np.percentile(series, TOP_PERCENTILE)]).reset_index()

    # Rename the columns appropriately
    grouped.columns = ['cluster', f'{feature}_{BOTTOM_PERCENTILE}th_percentile', f'{feature}_{TOP_PERCENTILE}th_percentile']

    # Merge the statistics back into the df_clusters DataFrame
    if df_clusters.empty:
        df_clusters = grouped
    else:
        df_clusters = pd.merge(df_clusters, grouped, on='cluster', how='outer')

### TODO: remove
# Previous code looked at the absolute [min, max] range
# Iterate over each feature in the original DataFrame (excluding the cluster_id)
# for feature in X_df.columns.difference(['cluster']):
#     # Group by 'cluster_id' and calculate the mean, min, and max for the feature
#     grouped = X_df.groupby('cluster')[feature].agg([np.mean, np.min, np.max]).reset_index()

#     # rename
#     grouped.columns = ['cluster', f'{feature}_mean', f'{feature}_min', f'{feature}_max']

#     # Merge the stats back
#     if df_clusters.empty:
#         df_clusters = grouped
#     else:
#         df_clusters = pd.merge(df_clusters, grouped, on='cluster', how='outer')

# Set the index to the cluster_id
df_clusters.set_index('cluster', inplace=True)

In [190]:
df_clusters

Unnamed: 0_level_0,alanine_aminotransferase_10th_percentile,alanine_aminotransferase_90th_percentile,alanine_aminotransferase_binned_10th_percentile,alanine_aminotransferase_binned_90th_percentile,alanine_aminotransferase_binned_binned_10th_percentile,alanine_aminotransferase_binned_binned_90th_percentile,alanine_aminotransferase_binned_binned_binned_10th_percentile,alanine_aminotransferase_binned_binned_binned_90th_percentile,alanine_aminotransferase_binned_binned_binned_binned_10th_percentile,alanine_aminotransferase_binned_binned_binned_binned_90th_percentile,...,urine_output_10th_percentile,urine_output_90th_percentile,urine_output_binned_10th_percentile,urine_output_binned_90th_percentile,urine_output_binned_binned_10th_percentile,urine_output_binned_binned_90th_percentile,urine_output_binned_binned_binned_10th_percentile,urine_output_binned_binned_binned_90th_percentile,urine_output_binned_binned_binned_binned_10th_percentile,urine_output_binned_binned_binned_binned_90th_percentile
cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,-0.329449,-0.229549,2.5,33.0,1.5,32.0,0.5,31.0,0.0,30.0,...,-0.631078,0.746923,10.0,73.0,10.0,73.0,10.0,73.0,10.0,73.0
1,1.973042,2.436578,692.0,831.0,680.0,817.0,669.0,803.0,658.0,790.0,...,-0.712137,0.746923,7.0,73.0,7.0,73.0,7.0,73.0,7.0,73.0
2,1.581435,2.145669,575.0,744.0,565.0,732.0,555.0,720.0,546.0,708.0,...,-0.833725,1.152217,1.0,91.0,1.0,92.0,1.0,93.0,1.0,94.0
3,4.643964,5.344062,1492.0,1702.0,1468.0,1674.0,1444.0,1647.0,1421.0,1620.0,...,-0.874254,1.071159,0.0,87.4,0.0,88.4,0.0,89.4,0.0,90.4
4,7.616982,10.357274,2383.0,3203.2,2344.0,3151.0,2306.0,3099.8,2269.0,3049.6,...,-0.833725,0.584805,1.0,65.0,1.0,65.0,1.0,65.0,1.0,65.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,2.559654,3.410001,868.0,1123.0,854.0,1105.0,840.0,1087.0,826.0,1069.0,...,-0.874254,-0.387901,0.0,21.0,0.0,21.0,0.0,21.0,0.0,21.0
96,4.354654,4.597610,1406.0,1478.0,1383.0,1454.0,1360.0,1430.0,1338.0,1407.0,...,-0.687819,0.949570,8.0,82.0,8.0,83.0,8.0,84.0,8.0,85.0
97,0.772646,1.559058,333.0,568.0,327.0,558.0,321.0,549.0,315.0,540.0,...,-0.833725,0.949570,1.0,82.0,1.0,83.0,1.0,84.0,1.0,85.0
98,0.218002,0.751867,167.0,326.0,164.0,320.0,161.0,314.0,158.0,309.0,...,-0.833725,0.746923,1.0,73.0,1.0,73.0,1.0,73.0,1.0,73.0


In [191]:
num_bins_per_column = dict([(column, X_df_binned[column].max()) for column in X_df_binned])
print(num_bins_per_column)

{'creatinine_binned': 701, 'fraction_inspired_oxygen_binned': 279, 'lactate_binned': 844, 'urine_output_binned': 438, 'alanine_aminotransferase_binned': 4617, 'asparate_aminotransferase_binned': 5655, 'mean_blood_pressure_binned': 860, 'diastolic_blood_pressure_binned': 742, 'systolic_blood_pressure_binned': 432, 'gcs_binned': 0, 'partial_pressure_of_oxygen_binned': 409, 'heart_rate_binned': 450, 'temperature_binned': 826, 'respiratory_rate_binned': 1328, 'cluster_binned': 68, 'creatinine_binned_binned': 709, 'fraction_inspired_oxygen_binned_binned': 282, 'lactate_binned_binned': 830, 'urine_output_binned_binned': 443, 'alanine_aminotransferase_binned_binned': 4542, 'asparate_aminotransferase_binned_binned': 5563, 'mean_blood_pressure_binned_binned': 870, 'diastolic_blood_pressure_binned_binned': 730, 'systolic_blood_pressure_binned_binned': 437, 'gcs_binned_binned': 0, 'partial_pressure_of_oxygen_binned_binned': 402, 'heart_rate_binned_binned': 442, 'temperature_binned_binned': 835, '

<b> We now create the *splittable* matrix </b>

In [192]:
"""
splittable = []

# Iterate over each cluster
for index, row in df_clusters.iterrows():
    cluster_vector = []

    # Iterate over each feature
    for feature, bin_edges in bin_edges_dict.items():
        # The corresponding binned column in df_clusters
        percentile_col_25th = f'{feature}_25th_percentile'
        percentile_col_75th = f'{feature}_75th_percentile'

        # Check if the percentiles for the feature are in the dataframe
        if percentile_col_25th in df_clusters.columns and percentile_col_75th in df_clusters.columns:
            # Get the 5th and 95th percentile values for this cluster and feature
            percentile_25th_val = row[percentile_col_25th]
            percentile_75th_val = row[percentile_col_75th]

            # Determine the range of bins that the 5th to 95th percentile values fall into
            min_bin_index = np.searchsorted(bin_edges, percentile_25th_val, side='right') - 1
            max_bin_index = np.searchsorted(bin_edges, percentile_75th_val, side='left')

            # Create a binary vector for this feature in this cluster
            feature_vector = np.zeros(len(bin_edges) - 1)
            feature_vector[min_bin_index:max_bin_index] = 1
            cluster_vector.extend(feature_vector.tolist())

    # Add the binary vector for this cluster to the splittable list
    splittable.append(cluster_vector)

# Convert the splittable list to a DataFrame
splittable_df = pd.DataFrame(splittable)
"""
pass

In [193]:
# Main input matrix to FIRL
splittable = []

# Iterate over each cluster
for index, row in df_clusters.iterrows():
    cluster_vector = []

    # Iterate over each feature
    for feature, bin_edges in bin_edges_dict.items():
        # The corresponding binned column in df_clusters
        percentile_col_45th = f'{feature}_{BOTTOM_PERCENTILE}th_percentile'
        percentile_col_55th = f'{feature}_{TOP_PERCENTILE}th_percentile'

        # Check if the percentiles for the feature are in the dataframe
        if percentile_col_45th in df_clusters.columns and percentile_col_55th in df_clusters.columns:
            # Get the 5th and 95th percentile values for this cluster and feature
            percentile_45th_val = row[percentile_col_45th]
            percentile_55th_val = row[percentile_col_55th]

            # Determine the range of bins that the 5th to 95th percentile values fall into
            min_bin_index = np.searchsorted(bin_edges, percentile_45th_val, side='right') - 1
            max_bin_index = np.searchsorted(bin_edges, percentile_55th_val, side='left')

            # Create a binary vector for this feature in this cluster
            feature_vector = np.zeros(len(bin_edges) - 1)
            feature_vector[min_bin_index:max_bin_index] = 1
            cluster_vector.extend(feature_vector.tolist())

    # Add the binary vector for this cluster to the splittable list
    splittable.append(cluster_vector)

# Convert the splittable list to a DataFrame
splittable_df = pd.DataFrame(splittable)

In [194]:
splittable_df.values

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])

In [195]:
# For testing purposes, we check that the clusters do differ (on a small subset of clusters)
print(f"Total number of binary encoded component features: {len(splittable_df.iloc[0])}")
#n = num_clusters
n = 10
for i in range(3):
  for j in range(i, 3):
    # Skip comparing a cluster to itself
    if i == j:
      continue

    contributions = [a==b for a, b in zip(splittable_df.iloc[i], splittable_df.iloc[j])]
    print(f"Cluster {i} and {j} differ by {len([t for t in contributions if not t])} bits")

Total number of binary encoded component features: 69651
Cluster 0 and 1 differ by 11284 bits
Cluster 0 and 2 differ by 12253 bits
Cluster 1 and 2 differ by 3639 bits


In [196]:
# Now that we have the splittable matrix, we can finally finish creating the
# feature_data dict for FIRL

# sa_s is a matrix we created previously
# that defines whether we *can* transition to state _s from state s by
# taking action a. Now, we create the stateadjacency matrix,
# which only considers whether two states _s and s are connected by *any*
# action, regardless of what that action is

# We'll use the csr_matrix from scipy learn
# for the sparse matrices we are dealing with here
# An alternative would be the lil_matrix class
from scipy.sparse import csr_matrix

sa_s = mdp_data['sa_s']

# Create a 2D matrix for state adjacency
stateadjacency_matrix = np.zeros((num_states, num_states), dtype=int)

# Fill in the adjacency matrix
for s in range(num_states):
    for a in range(num_actions):
        # Get the possible next states for state 's' when action 'a' is taken
        next_states = sa_s[s, a, :]
        # Mark the adjacency matrix for each possible next state
        for next_state in next_states:
            stateadjacency_matrix[s, next_state] = 1

# Since stateadjacency_matrix is typically sparse, convert it to a sparse matrix
stateadjacency = csr_matrix(stateadjacency_matrix)
# ^ The original matlab code does it, but does the same assumption hold for MIMIC?

In [197]:
# Finally, the FIRL feature_data:
feature_data = {
    'stateadjacency': stateadjacency,
    'splittable': splittable_df.values
}

# FIRL
At this point, we are finally ready to run
our FIRL algorithm. As described in the midway check-in
report the necessary inputs are:

- algorithm_params - The parameters of the algorithm. This specifies the hyperparameters for the FIRL algorithm like the number of iterations (of the optimization and fitting steps)

- mdp_data - A specification of the example domain. mdp_data includes the number of states, the number of actions, the transition function (specified by sa_s and sa_p). sa_s specifies the what states we can reach from the current state while sa_p specifies the probability of transitioning into every state from the current state. Unlike the simple gridworld case, in the discreticized state spaces we have, the entries of the sa_s matrix are all ones since we assume we can transition from one state to any other state.

- feature_data - Information about the features. The FIRL implementation in the gridworld uses this to perform rectangular partitioning of the state space. For example, the state (1, 2) in a gridworld with dimensions 4 X 5, has the component feature [0, 1, 1, 1] and [0, 0, 1, 1, 1]. However, this doesn't immediately apply to our MIMIC states. Hence, after careful thought and considering other options, we reached decidded to split each feature into bins. For example, if blood pressure has values ranging from 0 to 100, we split it into e.g. 5 groups each consisting of 20 values. Then for each for these states, we have the binary encoding. We do these for all the 16 features we have. Thus, each state has 16 * 5 = 80 entries in its component feature.

- example_samples - the $\pi^*$ samples, i.e. trajectories dataset.

- true_features - The true features that form a linear basis for the reward. This is the 16 features for each state (if they exist). The feature *extraction* algorithm will ignore these, but these could be useful for e.g. debugging

In [198]:
import numpy as np

def stdvalueiteration(mdp_data, r, vinit=None):
    """
    Run value iteration to solve a standard MDP.

    Parameters:
        mdp_data (dict): Contains MDP related data.
        r (numpy array): Reward function.
        vinit (numpy array, optional): Initial value function.

    Returns:
        numpy array: Computed value function.
    """
    
    # Allocate initial value function & variables.
    diff = 1.0
    if vinit is not None:
        vn = vinit
    else:
        vn = np.zeros(mdp_data['states'])

    # Perform value iteration.
    while diff >= 1e-8:  # Using 1e-8 as the convergence threshold
        vp = vn
        vn = np.max(r + np.sum(mdp_data['sa_p'] * vp[mdp_data['sa_s']], axis=2) * mdp_data['discount'], axis=1)
        diff = np.max(np.abs(vn - vp))

    # Return value function.
    return vn


def stdpolicy(mdp_data, r, v):
    """
    Given reward and value functions, solve for q function and policy.

    Parameters:
        mdp_data (dict): Contains MDP related data.
        r (numpy array): Reward function.
        v (numpy array): Value function.

    Returns:
        tuple: q function and policy.
    """

    # Compute Q function.
    q = r + np.sum(mdp_data['sa_p'] * v[mdp_data['sa_s']], axis=2) * mdp_data['discount']

    # Compute policy.
    p = np.argmax(q, axis=1)

    return q, p


In [199]:
from time import time as get_time
import numpy as np
import cvxpy as cp
from scipy.sparse import csr_matrix
# from mdp import stdpolicy, stdvalueiteration

class TreeNode:
    def __init__(self, type_val, index, test, mean_val, cells=None, ltTree=None, gtTree=None):
        self.type = type_val
        self.index = index
        self.test = test
        self.cells = cells
        self.mean = mean_val
        self.ltTree = ltTree
        self.gtTree = gtTree

def firlmatchdepth(tree, l1, l2) -> int:
    # Check if both leaves match
    if tree.type == 0:
        if tree.index == l1:
            return -1
        elif tree.index == l2:
            return -2
        else:
            return 0
    else:
        mLeft = firlmatchdepth(tree.ltTree, l1, l2)
        mRight = firlmatchdepth(tree.gtTree, l1, l2)
        
        if (mLeft == -1 or mLeft == -2) and mRight == 0:
            return mLeft
        elif (mRight == -1 or mRight == -2) and mLeft == 0:
            return mRight
        elif (mRight == -1 and mLeft == -2) or (mRight == -2 and mLeft == -1):
            return 1
        else:
            matchDepth = max(mLeft, mRight)
            if matchDepth > 0:
                matchDepth += 1
            return matchDepth

# Return index of the leaf that contains state s in tree
def firlcheckleaf(tree, s, feature_data):

    # Check if this is a leaf
    if tree.type == 0:
        # Return result
        return tree.index, tree.mean
    else:
        # Recurse
        if feature_data['splittable'][s, tree.test] == 0:
            branch = tree.ltTree
        else:
            branch = tree.gtTree
        
        return firlcheckleaf(branch, s, feature_data)


def firlaveragereward(tree, R, actions):
    """
    Compute the closest reward function that can be represented by the given tree.

    Args:
    - tree: the tree structure with attributes `type`, `cells`, and `mean`.
    - R: the reward function matrix.
    - actions: the number of actions.

    Returns:
    - Rout: the updated reward function.
    """
    if tree.type == 0:
        count = len(tree.cells)

        # Replace the relevant section of the reward function.
        for i in range(count):
            s = tree.cells[i]
            for a in range(actions):
                R[s][a] = tree.mean[a]
        Rout = R
    else:
        R = firlaveragereward(tree.ltTree, R, actions)
        R = firlaveragereward(tree.gtTree, R, actions)
        Rout = R
    return Rout

def firldefaultparams(algorithm_params={}):
    """
    Fill in default parameters for the FIRL algorithm.

    Args:
    - algorithm_params: dictionary containing provided parameters.

    Returns:
    - algorithm_params: dictionary containing all parameters with defaults filled in.
    """

    # Create default parameters
    default_params = {
        'seed': 0,
        'iterations': 10,
        'depth_step': 1,
        'init_depth': 0
    }
    
    # Set parameters with defaults if not provided
    for key, value in default_params.items():
        algorithm_params.setdefault(key, value)
    
    return algorithm_params

def firlregressiontree(st_states, depth, leavesIn, Eo, R, V, split_thresh, max_depth, mdp_data, feature_data):
    """
    Construct decision subtree.
    """
    leaves = leavesIn
    test = 1
    G = float('inf')

    if depth > max_depth:
        makeLeaf = False
        fMean = R[st_states, :].mean(axis=0)
    else:
        # Step over all possible splitting moves
        for tTest in range(feature_data['splittable'].shape[1]):
            # Split the examples
            st_splits = feature_data['splittable'][st_states, tTest]
            lt_states = st_states[st_splits == 0]
            gt_states = st_states[st_splits == 1]

            # Compute mean
            ltMean = R[lt_states, :].mean()
            gtMean = R[gt_states, :].mean()
            ltVar = ((R[lt_states, :] - ltMean) ** 2).sum()
            gtVar = ((R[gt_states, :] - gtMean) ** 2).sum()
            value = ltVar + gtVar

            if len(lt_states) > 0 and len(gt_states) > 0 and value < G:
                G = value
                test = tTest

        # Construct the partitions
        st_splits = feature_data['splittable'][st_states, test]
        lt_states = st_states[st_splits == 0]
        gt_states = st_states[st_splits == 1]
        fMean = R[st_states, :].mean(axis=0)
        fullMean = fMean.mean()
        maxDeviation = ((R[st_states, :] - fullMean) ** 2).max(axis=0).max()


        if maxDeviation > (split_thresh ** 2) and len(st_states) > 1:
            # Test if this node should be prunable
            Rnew = R.copy()
            Rnew[st_states, :] = fMean
            Vnew = stdvalueiteration(mdp_data, Rnew, V)
            _, P = stdpolicy(mdp_data, Rnew, Vnew)

            # Test if P matches all non-zero values of Eo
            mismatches = Eo * (P != Eo)
            makeLeaf = len(np.nonzero(mismatches)[0]) != 0
        else:
            makeLeaf = False

    if makeLeaf and len(st_states) > 1 and G != float('inf'):
        # Create node with the best split
        rightTree, leaves, R, V = firlregressiontree(gt_states, depth+1, leaves, Eo, R, V, split_thresh, max_depth, mdp_data, feature_data)
        leftTree, leaves, R, V = firlregressiontree(lt_states, depth+1, leaves, Eo, R, V, split_thresh, max_depth, mdp_data, feature_data)
        
        # Create node. TreeNode constructor parameters:
        # type_val, index, test, mean_val, cells=None, ltTree=None, gtTree=None
        tree = TreeNode(1, None, test, fMean, st_states.tolist(), leftTree, rightTree)
        Rout = R
        Vout = V
    else:
        # Create leaf node
        # tree = {'type': 0, 'index': leaves + 1, 'mean': fMean, 'cells': st_states.tolist()}
        tree = TreeNode(0, leaves, None, fMean, st_states.tolist())
        leaves += 1
        Rout = R
        Vout = V

    return tree, leaves, Rout, Vout


def firloptimization(Eo, Rold, ProjToLeaf, LeafToProj, FeatureMatch, mdp_data, verbosity):
    """
    Runs the optimization phase to compute a reward function that is close to 
    the current feature hypothesis 
    """

    # Smoothing term (relative to reward objective)
    # SMOOTH_WEIGHT = 0.02
    SMOOTH_WEIGHT = 0.001

    # Total size
    states = mdp_data['states']
    actions = mdp_data['actions']
    msize = states * actions
    results = mdp_data['sa_s'].shape[2]

    ### Constraint construction ###
    cols = np.nonzero(Eo)[0]
    examples = len(cols)

    sN = np.zeros(msize - examples * actions, dtype=int)            # start state idxs
    rN = np.zeros(msize - examples * actions, dtype=int)            # state-action idxs
    eN = np.zeros((msize - examples * actions, results), dtype=int) # resultant state idxs
    pN = np.zeros((msize - examples * actions, results))            # resultant state coeffs

    sM = np.zeros(examples * (actions - 1), dtype=int)
    rM = np.zeros(examples * (actions - 1), dtype=int)
    eM = np.zeros((examples * (actions - 1), results), dtype=int)
    pM = np.zeros((examples * (actions - 1), results))

    sE = np.zeros(examples, dtype=int)
    rE = np.zeros(examples, dtype=int)
    eE = np.zeros((examples, results), dtype=int)
    pE = np.zeros((examples, results))

    Nrow = 0
    Mrow = 0
    Erow = 0
    for startstate in range(states):
        if Eo[startstate] != 0:
            # We generate destination state and reward under the optimal action
            optaction = Eo[startstate]
            reward = actions * startstate + optaction

            sE[Erow] = startstate
            rE[Erow] = reward
            eE[Erow, :] = mdp_data['sa_s'][startstate, optaction, :]
            pE[Erow, :] = mdp_data['sa_p'][startstate, optaction, :] * mdp_data['discount']
            Erow += 1

            for action in range(actions):
                if action != optaction:
                    reward = actions * startstate + action

                    sM[Mrow] = startstate
                    rM[Mrow] = reward
                    eM[Mrow, :] = mdp_data['sa_s'][startstate, action, :]
                    pM[Mrow, :] = mdp_data['sa_p'][startstate, action, :] * mdp_data['discount']
                    Mrow += 1
        else:
            for action in range(actions):
                # Generate destination state and reward indices
                reward = actions * startstate + action

                sN[Nrow] = startstate
                rN[Nrow] = reward
                eN[Nrow, :] = mdp_data['sa_s'][startstate, action, :]
                pN[Nrow, :] = mdp_data['sa_p'][startstate, action, :] * mdp_data['discount']
                Nrow += 1

    # Determine number of leaves
    _, msize = ProjToLeaf.shape
    leafEntries, leaves = FeatureMatch.shape

    # Margin by which examples should be optimal
    MARGIN = 0.01
    margins = np.ones(examples * (actions - 1)) * MARGIN

    EPSILON = 2.22e-16
    r = cp.Variable(msize)
    v = cp.Variable(states)
    f = cp.Variable(leaves)

    objective = cp.Minimize(cp.norm(LeafToProj @ f - r) ** 2 / msize +
        cp.norm(FeatureMatch @ f, 1) * (SMOOTH_WEIGHT / (leafEntries * 500)))

    constraints = [
        f == ProjToLeaf @ r,
        #v[sN] >= r[rN] + cp.sum(cp.multiply(v[eN], pN), axis=1),
        #v[sM] >= r[rM] + cp.sum(cp.multiply(v[eM],pM), axis=1) + margins,
        #v[sE] == r[rE] + cp.sum(cp.multiply(v[eE], pE), axis=1)
    ]

    # NOTE: In CVXPY, we can't index a variable directly with a list or array
    # of indices like we can in MATLAB. Hence:

    # Add constraints for sN, rN, eN, and pN
    for idx in range(len(sN)):
        constraints.append(v[sN[idx]] >= r[rN[idx]] +
                           cp.sum(cp.multiply(v[eN[idx, :]], pN[idx, :])))

    # Add constraints for sM, rM, eM, and pM with margins
    for idx in range(len(sM)):
        constraints.append(v[sM[idx]] >= r[rM[idx]] +
                           cp.sum(cp.multiply(v[eM[idx, :]], pM[idx, :])) + margins[idx])

    # Add constraints for sE, rE, eE, and pE
    for idx in range(len(sE)):
        constraints.append(v[sE[idx]] == r[rE[idx]] +
                           cp.sum(cp.multiply(v[eE[idx, :]], pE[idx, :])))

    prob = cp.Problem(objective, constraints)
    prob.solve(verbose=verbosity == 2)

    if Rold.shape[0] > 1 and np.isnan(prob.value):
        if verbosity != 0:
            print('WARNING: Failed to obtain solution, reverting to old reward!')
        R = Rold
    else:
        # Recover the reward function
        R = r.value.reshape(actions, states).T

    return R, MARGIN



def firlprojectionfromtree(tree, leaves, states, actions, feature_data):
    
    DEPTH_WEIGHT = 1
    
    # Matrix of adjacencies.
    adjleaves = csr_matrix((leaves, leaves), dtype=np.int32)
    stateleaves = np.zeros(states, dtype=np.int32)
    
    # Count number of elements in each leaf and assign leaf to each state.
    elements = np.zeros(leaves, dtype=np.int32)
    for s in range(states):
        leaf, _mean = firlcheckleaf(tree, s, feature_data)
        elements[leaf] += 1
        stateleaves[s] = leaf
        
    # Count pairs and build adjacency matrix.
    pairs = 0
    for s in range(states):
        leaf = stateleaves[s]
        adj = np.nonzero(feature_data['stateadjacency'][s, :])[0]
        numadj = len(adj)
        
        # Write out adjacencies
        for i in adj:
            lother = stateleaves[i]
            if lother != leaf:
                # Found adjacency
                if adjleaves[lother, leaf] == 0 and adjleaves[leaf, lother] == 0:
                    pairs += 1
                adjleaves[lother, leaf] = 1
                adjleaves[leaf, lother] = 1
                
    # Construct feature match matrix
    FeatureMatch = csr_matrix((pairs, leaves), dtype=np.float64)
    idx = 0
    maxPair = 0
    for l1 in range(leaves):
        for l2 in range(l1 + 1, leaves):
            adjacent = adjleaves[l1, l2]
            if adjacent > 0:
                matchDepth = (firlmatchdepth(tree, l1, l2) - 1)
                FeatureMatch[idx, l1] = adjacent + matchDepth * DEPTH_WEIGHT
                FeatureMatch[idx, l2] = -adjacent - matchDepth * DEPTH_WEIGHT
                if FeatureMatch[idx, l1] > maxPair:
                    maxPair = FeatureMatch[idx, l1]
                idx += 1
    
    if pairs <= 0:
        # Handle degeneracy
        FeatureMatch = csr_matrix((1, leaves), dtype=np.float64)
    else:
        FeatureMatch = FeatureMatch / maxPair
    
    # Construct projection matrix
    ProjToLeaf = csr_matrix((leaves, states * actions), dtype=np.float64)
    LeafToProj = csr_matrix((states * actions, leaves), dtype=np.float64)
    
    for s in range(states):
        leaf = stateleaves[s]
        for a in range(actions):
            pos = s * actions + a
            ProjToLeaf[leaf, pos] = 1.0 / (elements[leaf] * actions)
            LeafToProj[pos, leaf] = 1.0

    # Convert to CSR for efficient operations in future usage
    return ProjToLeaf.tocsr(), LeafToProj.tocsr(), FeatureMatch.tocsr()


def firlrun(algorithm_params, mdp_data, mdp_model, feature_data, example_samples, _, verbosity):

    # Fill in default parameters
    algorithm_params = firldefaultparams(algorithm_params)

    np.random.seed(algorithm_params['seed'])

    # Initialize variables
    states = mdp_data['states']
    actions = mdp_data['actions']
    iterations = algorithm_params['iterations']
    depth_step = algorithm_params['depth_step']
    init_depth = algorithm_params['init_depth']

    # Construct mapping from states to example actions
    Eo = np.zeros(states, dtype=int)
    for i in range(len(example_samples)):
        for t in range(len(example_samples[i])):
            Eo[example_samples[i][t][0]] = example_samples[i][t][1]

    # Construct initial tree
    leaves = 1
    # Note: In python should be zero indexed
    # tree = {'type': 0, 'index': 0, 'mean': np.zeros(actions)}
    tree = TreeNode(0, 0, None, np.zeros(actions))
    ProjToLeaf, LeafToProj, FeatureMatch = firlprojectionfromtree(tree, leaves, states, actions, feature_data)

    # Prepare timing variables.
    optTime, fitTime, vitTime, matTime = [np.zeros(iterations) for _ in range(4)]
    
    # Prepare intermediate output variables
    opt_acc_itr = [None] * (iterations)
    r_itr = [None] * (iterations)
    p_itr = [None] * (iterations)
    model_itr = [None] * (iterations)
    model_r_itr = [None] * (iterations)
    model_p_itr = [None] * (iterations)
    
    # Run firl.
    Rold = np.random.normal(size=(states, actions))
    itr = 0
    while True:
        if verbosity != 0:
            print(f'Beginning FIRL iteration {itr+1}')

        # Run optimization phase
        start_time = get_time()
        R, margin = firloptimization(Eo, Rold, ProjToLeaf, LeafToProj, FeatureMatch, mdp_data, verbosity)
        Rold = R
        threshold = margin * 0.2 * mdp_data['discount']
        optTime[itr] = get_time() - start_time

        # Generate policy
        start_time = get_time()
        V = stdvalueiteration(mdp_data, R)
        _, P = stdpolicy(mdp_data, R, V)
        vitTime[itr] = get_time() - start_time

        # Construct tree
        start_time = get_time()
        # Adjust Eo to exclude violated examples
        # In an exact optimization, there should be no violated examples
        # However, an approximation might violate some examples
        Eadjusted = Eo * (P == Eo)
        totalExamples = np.sum(Eadjusted > 0)
        #opt_acc_itr.append(totalExamples / np.sum(Eo > 0))
        opt_acc_itr[itr] = totalExamples / np.sum(Eo > 0)
        max_depth = init_depth + itr * depth_step
        tree, leaves, _, _ = firlregressiontree(
            np.arange(states),      # Start with all states
            0,                      # Current depth
            0,                      # First leaf index
            Eadjusted,              # Pass in part of policy we want to match
            R,                      # Pass in reward function
            V,                      # Pass in value function
            threshold,              # Pass in termination threshold
            max_depth,              # Pass in maximum depth
            mdp_data,               # Pass in MDP data
            feature_data            # Pass in feature data
        )
        fitTime[itr] = get_time() - start_time

        # Construct projection matrices
        start_time = get_time()
        ProjToLeaf, LeafToProj, FeatureMatch = firlprojectionfromtree(tree, leaves, states, actions, feature_data)
        matTime[itr] = get_time() - start_time

        # Record policy at this iteration
        #r_itr.append(R)
        #p_itr.append(P)
        #model_itr.append(tree)
        r_itr[itr] = R
        p_itr[itr] = P
        model_itr[itr] = tree

        # Increment iteration
        itr += 1
        if itr >= iterations:
            break

    # Compute final policy
    Rout = firlaveragereward(tree, R, actions)
    Vout = stdvalueiteration(mdp_data, Rout)
    Qout, Pout = stdpolicy(mdp_data, Rout, Vout)

    # Compute all intermediate policies
    for i in range(iterations):
        model_r_itr[i] = firlaveragereward(model_itr[i], r_itr[i], actions)
        v = stdvalueiteration(mdp_data, model_r_itr[i])
        _, model_p_itr[i] = stdpolicy(mdp_data, model_r_itr[i], v)

    if verbosity != 0:
        # Report timing
        for itr in range(iterations):
            print(f'Iteration {itr + 1} optimization: {optTime[itr]:.6f}s')
            print(f'Iteration {itr + 1} value iteration: {vitTime[itr]:.6f}s')
            print(f'Iteration {itr + 1} fitting: {fitTime[itr]:.6f}s')
            print(f'Iteration {itr + 1} objective construction: {matTime[itr]:.6f}s')

    total = sum(optTime) + sum(vitTime) + sum(fitTime) + sum(matTime)
    if verbosity != 0:
        print(f'Total time: {total:.6f}s\n')

    time = total
    mean_opt_time = np.mean(optTime)
    mean_fit_time = np.mean(fitTime)

    # Build output structure
    irl_result = {
        'r': Rout,
        'v': Vout,
        'q': Qout,
        'p': Pout,
        'opt_acc_itr': opt_acc_itr,
        'r_itr': r_itr,
        'model_itr': model_itr,
        'model_r_itr': model_r_itr,
        'p_itr': p_itr,
        'model_p_itr': model_p_itr,
        'time': time,
        'mean_opt_time': mean_opt_time,
        'mean_fit_time': mean_fit_time
    }

    return irl_result


In [200]:
trajectories_tmp = [t.transitions() for t in trajectories]
type(trajectories_tmp[0])

list

In [None]:
# firlrun() takes in the arguments as specified previously

# For the unused arguments, we use the defaults
algorithm_params = {'iterations': 25}
mdp_model = None
example_samples = trajectories_tmp
verbosity = 1

# Run FIRL
res = firlrun(algorithm_params, mdp_data, mdp_model, feature_data, example_samples, None, verbosity)

  self._set_intXint(row, col, x.flat[0])


Beginning FIRL iteration 1


  gtMean = R[gt_states, :].mean()
  ret = ret.dtype.type(ret / rcount)
  ltMean = R[lt_states, :].mean()


Beginning FIRL iteration 2
Beginning FIRL iteration 3


In [None]:
res.keys()

In [None]:
print(res['r'])

In [None]:
print(res['model_r_itr'])

# Visualizing best and worst clusters
We sort based on the value function 

In [None]:
res['model_itr'][-1]

In [None]:
print(res['v'])

In [None]:
sorted_clusters = np.argsort(res['v'])
sorted_clusters

In [None]:
top_5 = sorted_clusters[-5::][::-1]
bottom_5 = sorted_clusters[:5][::-1]

In [None]:
selected_clusters = list(top_5) + list(bottom_5)
selected_clusters

In [None]:
res['v'][selected_clusters]

In [None]:
features = data_non_normalized_df.columns[:-2] # no cluster / action cols
means = pd.DataFrame(columns=features, index=selected_clusters)

for cluster in selected_clusters:
    subset = data_non_normalized_df[data_non_normalized_df['cluster'] == cluster]
    means.loc[cluster, features] = subset[features].mean()
    
means

In [None]:
normalized_means = (means - means.min()) / (means.max() - means.min())
normalized_means = normalized_means.astype(float)
normalized_means

In [None]:
import seaborn as sns
custom_labels = ['Best Cluster', 'Second Best Cluster', 'Third Best Cluster', 'Fourth Best Cluster', 'Fifth Best Cluster', 
                 'Fifth Worst Cluster', 'Fourth Worst Cluster', 'Third Worst Cluster', 'Second Worst Cluster', 'Worst Cluster']

plt.figure(figsize=(10,10))
sns.heatmap(normalized_means.T, cmap="YlGnBu", annot=means.T, fmt=".2f")
plt.xticks(ticks=np.arange(len(custom_labels)), labels=custom_labels, rotation=45)
plt.show()

In [None]:
data_df

# Accuracy: 
Compare accuracy to the training set actions (y_df)

In [None]:
firl_policy = res['p']
accuracy = sum((firl_policy[data_df['cluster']]) == y_df) / len(y_df)
print(accuracy)

In [None]:
len(firl_policy)