In [1]:
# --- Data-generation settings ---
NSIMSAMPS = int(1e5)        # trajectories drawn from the simulator per replication
# NOTE(review): epsilon > 1 looks deliberate. Suppose the behavior policy is
# (1 - eps) * greedy + eps * uniform over the nA = 8 actions. Then eps = 8/7 gives the
# greedy action probability 0 and each of the other 7 actions probability 1/7.
# epsilon is not used in this section -- confirm against the simulator code.
epsilon = 8 / 7
runs = [*range(10)]         # replication indices 0..9

output_dir = '../datagen/suboptimal-100k/'

# Features

In [2]:
import numpy as np
import pandas as pd
from pandas import DataFrame
from tqdm import tqdm
from collections import defaultdict
import pickle
import itertools
import copy
import random
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

In [3]:
from sklearn.linear_model import LinearRegression
from sklearn.neural_network import MLPRegressor
from sklearn import metrics
import joblib
from joblib import Parallel, delayed

In [4]:
import scipy.sparse

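## 21-dimensional state features

Each state is described by d = 21 features. A state–action feature vector has nA × d = 8 × 21 = 168 entries. The state's features fill the block for the chosen action, and every other block is zero.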

In [5]:
nA = 8      # number of discrete actions = number of blocks in a state-action vector
d = 21      # length of the per-state feature vector (CSV columns 7..27)
nS = 1442   # presumably the number of discrete simulator states; unused in this section

In [6]:
def get_state_action_feature(x_s, a):
    """Embed state features ``x_s`` into a block-sparse state-action vector.

    The result has ``nA * d`` entries. Block ``a`` (entries ``a*d`` through
    ``a*d + d - 1``) holds ``x_s``; every other block is zero.

    Args:
        x_s: state feature vector of length ``d``.
        a: action index selecting one of the ``nA`` blocks (negative values count
            from the end, as with normal numpy indexing).

    Returns:
        1-D float array of length ``nA * d``.
    """
    feature = np.zeros(nA * d)
    # reshape of a fresh contiguous array is a view, so this writes into ``feature``
    feature.reshape(nA, d)[a] = x_s
    return feature

In [7]:
def _state_action_features(x, actions):
    """Block one-hot state-action features for observed (state, action) pairs.

    Row ``j`` equals ``get_state_action_feature(x[j], actions[j])``. That is a length
    ``nA * d`` vector that is zero except for block ``actions[j]``, which holds ``x[j]``.

    Args:
        x: array of shape ``(m, d)``; ``m`` may be 0.
        actions: integer array of shape ``(m,)``.

    Returns:
        Float array of shape ``(m, nA * d)``.
    """
    m = len(x)
    out = np.zeros((m, nA, d))
    out[np.arange(m), actions, :] = x
    return out.reshape(m, nA * d)


def _all_action_features(x):
    """State-action features of every state in ``x`` paired with every action.

    Row ``j * nA + a`` equals ``get_state_action_feature(x[j], a)``.

    Args:
        x: array of shape ``(m, d)``; ``m`` may be 0.

    Returns:
        Float array of shape ``(m * nA, nA * d)``.
    """
    m = len(x)
    out = np.zeros((m, nA, nA, d))
    diag = np.arange(nA)
    # out[j, a, a, :] = x[j] for every action a: action a's copy of the state lives in block a.
    out[:, diag, diag, :] = x[:, np.newaxis, :]
    return out.reshape(m * nA, nA * d)


def make_features_single_trajectory(df_i):
    """Convert one patient episode into transition-level feature arrays.

    Columns ``7 .. 7 + d - 1`` of ``df_i`` hold the state features. The ``'Action'`` and
    ``'Reward'`` columns hold the action taken and the reward received at each row.

    Every row except the last yields the transition (state_t, Action_t, Reward_t,
    state_{t+1}). The last row is handled separately:

    * Reward of -1 or 1 (death/discharge): the last state is emitted once per action,
      each copy with that reward and an all-zero next state.
    * Any other reward (episode cut off at max length): the last row is dropped,
      because it has no next state.

    A single-row episode therefore yields either nA terminal transitions or none. When it
    yields none, the transition arrays are empty but still have the correct widths.

    Returns:
        x_s_init: ``(d,)`` initial state features.
        xa_s_init_all: ``(nA, nA*d)`` initial state paired with each action.
        x_out, a_out, xa_out, r_out: per-transition states ``(n, d)``, actions ``(n,)``,
            state-action features ``(n, nA*d)`` and rewards ``(n,)``.
        x_next_out: ``(n, d)`` next states (zeros after a terminal transition).
        xa_next_out: ``(n*nA, nA*d)`` next state paired with each action, nA rows per
            transition (zeros after a terminal transition).
    """
    feature_cols = slice(7, 7 + d)  # 7:28 for d == 21

    # Initial timestep, paired with every action
    x_s_init = np.array(df_i.iloc[0, feature_cols].values)
    xa_s_init_all = _all_action_features(x_s_init[np.newaxis, :])

    # Intermediate timesteps: every row except the last has an observed next state.
    # For a single-row episode these slices are empty with shapes (0, d) / (0,), so no
    # special case is needed. Note that np.array((0, d)) would build the 2-element
    # array [0, d], not an empty matrix.
    x_s = np.array(df_i.iloc[:-1, feature_cols].values)
    a = df_i.iloc[:-1]['Action'].values
    r = df_i.iloc[:-1]['Reward'].values
    x_s_next = np.array(df_i.iloc[1:, feature_cols].values)
    xa_sa = _state_action_features(x_s, a)
    xa_s_next_all = _all_action_features(x_s_next)

    # Final timestep
    r_last = df_i.iloc[-1]['Reward']
    if r_last == -1 or r_last == 1:
        # Reached death/disch states: every action leads to the terminal reward,
        # followed by an all-zero next state (nA zero next-state-action rows each).
        x_s_last = np.array(df_i.iloc[-1, feature_cols].values)
        xa_s_last_all = _all_action_features(x_s_last[np.newaxis, :])
        r_last_all = np.array(nA * [r_last])

        xa_out = np.vstack([xa_sa, xa_s_last_all])
        xa_next_out = np.vstack([xa_s_next_all, np.zeros((nA * nA, nA * d))])
        r_out = np.concatenate([r, r_last_all])

        a_out = np.concatenate([a, np.arange(nA)])
        x_out = np.vstack([x_s, np.tile(x_s_last, (nA, 1))])
        x_next_out = np.vstack([x_s_next, np.zeros((nA, d))])
    else:
        # Terminated early due to max length, so the last row has no next-state information.
        xa_out = xa_sa
        xa_next_out = xa_s_next_all
        r_out = r

        x_out = x_s
        a_out = a
        x_next_out = x_s_next

    return x_s_init, xa_s_init_all, x_out, a_out, xa_out, r_out, x_next_out, xa_next_out

In [8]:
# Keys whose values are 2-D feature matrices; these are additionally saved in CSR form.
SPARSE_KEYS = ('X_init', 'X', 'X_next', 'Xa_init', 'Xa', 'Xa_next')

for it in runs:
    # Featurize every patient episode of replication `it`.
    df_features = pd.read_csv('{}/{}-features.csv'.format(output_dir, it))
    out = [make_features_single_trajectory(df_i) for _, df_i in tqdm(df_features.groupby('pt_id'))]
    X_init, Xa_init, X, A, Xa, R, X_next, Xa_next = zip(*out)

    # Number of transitions contributed by each episode, recorded before stacking.
    lengths = [len(x_i) for x_i in X]
    # Start row of each episode in the stacked arrays; the final entry is the total row count.
    inds_init = np.cumsum([0] + lengths)

    X_init = np.vstack(X_init)
    Xa_init = np.vstack(Xa_init)
    X = np.vstack(X)
    Xa = np.vstack(Xa)
    A = np.concatenate(A)
    R = np.concatenate(R)
    X_next = np.vstack(X_next)
    Xa_next = np.vstack(Xa_next)
    print(Xa_init.shape, Xa.shape, Xa_next.shape, R.shape)
    print(X_init.shape, X.shape, A.shape, R.shape, X_next.shape)

    # Invariants:
    # - one row per transition in every per-transition array;
    # - nA next-state-action rows per transition;
    # - one initial state per episode.
    n_transitions = inds_init[-1]
    assert len(X) == len(A) == len(R) == len(Xa) == len(X_next) == n_transitions
    assert len(Xa_next) == nA * n_transitions
    assert len(X_init) == len(lengths)

    matrices = {
        'X_init': X_init, 'X': X, 'A': A, 'R': R, 'X_next': X_next,
        'Xa_init': Xa_init, 'Xa': Xa, 'Xa_next': Xa_next,
        'lengths': lengths, 'inds_init': inds_init,
    }
    joblib.dump(matrices, '{}/{}-21d-feature-matrices.joblib'.format(output_dir, it))

    # Same payload with the feature matrices stored as CSR. Each state-action row has at
    # most d of its nA*d entries nonzero, so this is much smaller than the dense version.
    joblib.dump(
        {k: scipy.sparse.csr_matrix(v) if k in SPARSE_KEYS else v for k, v in matrices.items()},
        '{}/{}-21d-feature-matrices.sparse.joblib'.format(output_dir, it),
    )

100%|██████████| 100000/100000 [03:01<00:00, 551.93it/s]


(800000, 168) (1503553, 168) (12028424, 168) (1503553,)
(100000, 21) (1503553, 21) (1503553,) (1503553,) (1503553, 21)


100%|██████████| 100000/100000 [03:03<00:00, 546.31it/s]


(800000, 168) (1505450, 168) (12043600, 168) (1505450,)
(100000, 21) (1505450, 21) (1505450,) (1505450,) (1505450, 21)


100%|██████████| 100000/100000 [02:57<00:00, 564.22it/s]


(800000, 168) (1504864, 168) (12038912, 168) (1504864,)
(100000, 21) (1504864, 21) (1504864,) (1504864,) (1504864, 21)


100%|██████████| 100000/100000 [03:01<00:00, 551.94it/s]


(800000, 168) (1505736, 168) (12045888, 168) (1505736,)
(100000, 21) (1505736, 21) (1505736,) (1505736,) (1505736, 21)


100%|██████████| 100000/100000 [03:00<00:00, 554.43it/s]


(800000, 168) (1502535, 168) (12020280, 168) (1502535,)
(100000, 21) (1502535, 21) (1502535,) (1502535,) (1502535, 21)


100%|██████████| 100000/100000 [03:00<00:00, 553.38it/s]


(800000, 168) (1503469, 168) (12027752, 168) (1503469,)
(100000, 21) (1503469, 21) (1503469,) (1503469,) (1503469, 21)


100%|██████████| 100000/100000 [02:59<00:00, 557.31it/s]


(800000, 168) (1503154, 168) (12025232, 168) (1503154,)
(100000, 21) (1503154, 21) (1503154,) (1503154,) (1503154, 21)


100%|██████████| 100000/100000 [02:59<00:00, 556.80it/s]


(800000, 168) (1503812, 168) (12030496, 168) (1503812,)
(100000, 21) (1503812, 21) (1503812,) (1503812,) (1503812, 21)


100%|██████████| 100000/100000 [02:59<00:00, 556.36it/s]


(800000, 168) (1505702, 168) (12045616, 168) (1505702,)
(100000, 21) (1505702, 21) (1505702,) (1505702,) (1505702, 21)


100%|██████████| 100000/100000 [02:59<00:00, 556.90it/s]


(800000, 168) (1501870, 168) (12014960, 168) (1501870,)
(100000, 21) (1501870, 21) (1501870,) (1501870,) (1501870, 21)
