In [60]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [61]:
import numpy as np
from cancer import CancerSim, PolicyCancer

In [62]:
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import pickle
from sklearn.neighbors import NearestNeighbors

In [68]:
def collect_data(n, ttype="PCV", penalty=1.0):
    episodes_per_policy = 10
    env = CancerSim(dose_penalty=penalty, therapy_type=ttype, env_seed=None, max_steps=30, state_dim=5, reward_type="dense")

    dataset = {'observations': np.zeros((n, env.max_steps, env.state_dim)),
               'actions': np.zeros((n, env.max_steps, 1), np.int),
               'rewards': np.zeros((n, env.max_steps, 1)),
               'sparse_rewards': np.zeros((n, env.max_steps, 1)),
               'dense_rewards': np.zeros((n, env.max_steps, 1)),
               'not_done': np.zeros((n, env.max_steps, 1)),
               'pibs': np.zeros((n, env.max_steps, env.num_actions)),
               'nn_action_dist': np.ones((n, env.max_steps, env.num_actions)) * 1e9,
               }

    for i in range(n):
        if (i % episodes_per_policy) == 0:
            policy = PolicyCancer(months_for_treatment=9, eps_behavior=0.3)

        env = CancerSim(dose_penalty=penalty, therapy_type=ttype, env_seed=None, max_steps=30, state_dim=5, reward_type="dense")
        obs = env.reset()
        t = 0
        done = False
        while not done:
            a = policy(obs, t)
            new_obs, rt, done, _ = env.step(a)

            dataset['observations'][i, t, :] = obs
            dataset['actions'][i, t, :] = a
            dataset['rewards'][i, t, :] = rt
            dataset['not_done'][i, t, :] = float(1-done)
            dataset['pibs'][i, t, :] = policy.return_probs(obs, t)
            
            obs = new_obs
            t += 1

    return dataset

In [69]:
def process_data(dataset1, dataset2=None):
    if dataset2 is None:
        dataset2 = dataset1
    n = dataset1['observations'].shape[0]
    horizon = dataset1['observations'].shape[1]
    state_dim = dataset1['observations'].shape[-1]
    num_actions = dataset1["pibs"].shape[-1]
    
    trees = dict()
    for a in range(num_actions):
        trees[a] = NearestNeighbors(n_neighbors=1)
        is_action = (dataset2['actions'][:, :, 0] == a)
        trees[a].fit(dataset2["observations"][is_action, :])
    
    X = dataset1["observations"].reshape(-1,state_dim)
    for a in range(num_actions):
        dists = (trees[a].kneighbors(X)[0]**2)/state_dim
        dataset1['nn_action_dist'][:, :, a] = dists.reshape(n, horizon)
    return dataset1

In [79]:
for i in range(1,6):
    dataset_pcv_high = collect_data(1000)
    dataset_pcv_high = process_data(dataset_pcv_high)
    
    dataset_pcv_high_val = collect_data(1000)
    dataset_pcv_high_val = process_data(dataset_pcv_high_val, dataset_pcv_high)
    
    with open(f"../../data/cancer_mdp_pcv{i}_train_episodes", 'wb') as f:
        pickle.dump(dataset_pcv_high, f)
    
    with open(f"../../data/cancer_mdp_pcv{i}_val_episodes", 'wb') as f:
        pickle.dump(dataset_pcv_high_val, f)

In [77]:
for i in range(1,6):
    dataset_pcv_high = collect_data(10000)
    dataset_pcv_high = process_data(dataset_pcv_high)
    
    dataset_pcv_high_val = collect_data(10000)
    dataset_pcv_high_val = process_data(dataset_pcv_high_val, dataset_pcv_high)
    
    with open(f"../../data/cancer_mdp_pcv{i}_train_episodes", 'wb') as f:
        pickle.dump(dataset_pcv_high, f)
    
    with open(f"../../data/cancer_mdp_pcv{i}_val_episodes", 'wb') as f:
        pickle.dump(dataset_pcv_high_val, f)

In [76]:
dataset_pcv_high["rewards"].sum(axis=1).mean()

-8.747827544294084

In [44]:
dataset_pcv_high = collect_data(1000)
dataset_pcv_high = process_data(dataset_pcv_high)

In [36]:
dataset_pcv_high["observations"].reshape(-1,3).max(axis=0)

array([ 3.18246399, 52.48354535, 29.        ])

In [57]:
dataset_pcv_high["observations"].reshape(-1,5).max(axis=0)

array([ 3.14297333,  7.83364022, 45.23961636, 30.21838682, 29.        ])

In [55]:
(dataset_pcv_high["nn_action_dist"].reshape(-1,2) > 0.05).mean()

0.05121666666666667

In [56]:
(dataset_pcv_high["nn_action_dist"].reshape(-1,2) > 0.5).mean()

0.017