# Offline RL Pipeline
This notebook contains the full pipeline described in "Optimizing Loop Diuretic Treatment in Hospitalized Patients: A Case Study in Practical Application of Offline Reinforcement Learning to Healthcare". The code provided here is pseudocode meant to demonstrate the full pipeline; it cannot be executed on its own. Placeholders such as `X` (the number of data partitions) and `list_of_k` (the candidate numbers of clusters) must be defined before running any cell.

## 1. Data Partition
- The development data is randomly split at the trajectory level (by `pt_id`) into X training/validation partitions.

In [None]:
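import os
import random

# df_dev (development data with a pt_id column), combined_ids (list of unique
# trajectory IDs in df_dev) and X are assumed to be defined already.
N_VAL = 7289  # number of trajectories held out for validation in each split

os.makedirs('./data_splits', exist_ok=True)
for split_id in range(X):
    new_val_ids = random.sample(combined_ids, N_VAL)
    new_train_ids = list(set(combined_ids) - set(new_val_ids))

    df_tr = df_dev[df_dev.pt_id.isin(new_train_ids)]
    df_va = df_dev[df_dev.pt_id.isin(new_val_ids)]

    df_tr.to_csv('./data_splits/split_{:03d}_tr.csv'.format(split_id), index=False)
    df_va.to_csv('./data_splits/split_{:03d}_va.csv'.format(split_id), index=False)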
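Sampling IDs rather than rows keeps all time steps of a trajectory in the same split, so no patient contributes data to both training and validation. A minimal, self-contained sketch of the same idea with a dedicated, seeded random generator for reproducible partitions; the function `trajectory_level_splits` and the toy data are hypothetical:

```python
import random

import pandas as pd


def trajectory_level_splits(df, id_col, n_val, n_splits, seed=0):
    """Yield (split_id, train_df, val_df), keeping each trajectory on one side."""
    ids = sorted(df[id_col].unique())  # fixed order so the seed fully determines the splits
    rng = random.Random(seed)          # dedicated generator, independent of global state
    for split_id in range(n_splits):
        val_ids = set(rng.sample(ids, n_val))
        is_val = df[id_col].isin(val_ids)
        yield split_id, df[~is_val], df[is_val]


# Toy development data: 5 trajectories with 3 time steps each.
toy_dev = pd.DataFrame({"pt_id": [i for i in range(5) for _ in range(3)],
                        "t": [0, 1, 2] * 5})
splits = list(trajectory_level_splits(toy_dev, "pt_id", n_val=2, n_splits=3))
```

Using one seeded `random.Random` instance, rather than the module-level generator, makes it possible to regenerate exactly the same partitions later.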

## 2. Defining the State Space
- A set of candidate state definitions is generated by varying the data partition and the number of states.

### 2.1. Creating the embeddings
- For each data partition, learn an embedding model.
- Then apply every embedding model to the training, validation, and test data of every partition, so that each partition gets X embeddings (one per model).

In [None]:
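import os

import pandas as pd
import torch

for state_id in range(X):
    # The file stores a full pickled model, so recent PyTorch versions need weights_only=False.
    embedding_model = torch.load('state_{:03d}_embedding_model.pt'.format(state_id), weights_only=False)
    embedding_model.eval()
    for split_id in range(X):
        os.makedirs('./data_splits/split_{:03d}'.format(split_id), exist_ok=True)
        df_tr = pd.read_csv('./data_splits/split_{:03d}_tr.csv'.format(split_id))
        df_va = pd.read_csv('./data_splits/split_{:03d}_va.csv'.format(split_id))
        # The test file is not produced in Step 1 and must be prepared separately.
        df_te = pd.read_csv('./data_splits/split_{:03d}_te.csv'.format(split_id))

        # embedding_model is assumed to map a DataFrame to a DataFrame of embeddings.
        with torch.no_grad():
            tr_embedded = embedding_model(df_tr)
            va_embedded = embedding_model(df_va)
            te_embedded = embedding_model(df_te)

        tr_embedded.to_pickle('./data_splits/split_{:03d}/state_{:03d}_tr_embedded.p'.format(split_id, state_id))
        va_embedded.to_pickle('./data_splits/split_{:03d}/state_{:03d}_va_embedded.p'.format(split_id, state_id))
        te_embedded.to_pickle('./data_splits/split_{:03d}/state_{:03d}_te_embedded.p'.format(split_id, state_id))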

### 2.2. Creating the discrete clustering solutions and clustering all embeddings
- Cluster the embedded training data with ensemble k-means clustering. For convenience, we cluster only the embeddings where split_id = state_id, i.e., each embedding model's own training partition.
- Relevant code is in `pipeline/1_run_kmeans.py`

In [None]:
for state_id in range(X):
    tr_embedded = pd.read_pickle('./data_splits/split_{:03d}/state_{:03d}_tr_embedded.p'.format(state_id, state_id))
    for k in list_of_k:
        # Create the clustering solution for this k. save_dir must match the directory
        # the next cell loads from ('./kmeans_model/').
        run_ensemble_kmeans(tr_embedded, save_dir, k, E=150)
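The notebook does not show what `run_ensemble_kmeans` computes. As one possible ensemble scheme (an assumption, not necessarily the one in `pipeline/1_run_kmeans.py`), the sketch below fits plain k-means on E random subsamples of the embeddings, pools the resulting centroids, and runs k-means once more on the pooled centroids to obtain k consensus centers. It is NumPy-only with random initialization; `kmeans`, `ensemble_kmeans`, and the `subsample` fraction are hypothetical names and parameters.

```python
import numpy as np


def kmeans(data, k, rng, n_iter=100):
    """Plain Lloyd's k-means with random initial centers; returns a (k, d) array."""
    centers = data[rng.choice(len(data), size=k, replace=False)].copy()
    for _ in range(n_iter):
        # Squared Euclidean distance from every point to every center: shape (n, k).
        d2 = ((data[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1)
        new_centers = centers.copy()
        for j in range(k):
            members = data[labels == j]
            if len(members) > 0:  # an empty cluster keeps its previous center
                new_centers[j] = members.mean(axis=0)
        converged = np.allclose(new_centers, centers)
        centers = new_centers
        if converged:
            break
    return centers


def ensemble_kmeans(data, k, E=150, subsample=0.8, seed=0):
    """Fit k-means on E random subsamples, then cluster the pooled E * k centers into k."""
    rng = np.random.default_rng(seed)
    m = max(k, int(subsample * len(data)))
    pooled = []
    for _ in range(E):
        idx = rng.choice(len(data), size=m, replace=False)
        pooled.append(kmeans(data[idx], k, rng))
    pooled = np.vstack(pooled)  # shape (E * k, d)
    return kmeans(pooled, k, rng)


# Toy embeddings: three well-separated 2-D blobs.
rng = np.random.default_rng(1)
blobs = np.vstack([rng.normal(loc=c, scale=0.1, size=(50, 2))
                   for c in [(0.0, 0.0), (5.0, 5.0), (0.0, 5.0)]])
centers = ensemble_kmeans(blobs, k=3, E=20)
```

Combining many subsample fits makes the final centers less dependent on any single unlucky initialization, which matters because the cluster indices later become the discrete MDP states.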

In [None]:
import joblib

for state_id in range(X):
    for k in list_of_k:
        kmeans_model = joblib.load('./kmeans_model/state_{:03d}_k_{}_model.pkl'.format(state_id, k))
        for split_id in range(X):
            # Predict the clusters for the train, validation, and test embeddings of this split
            for part in ['tr', 'va', 'te']:
                embedded = pd.read_pickle('./data_splits/split_{:03d}/state_{:03d}_{}_embedded.p'.format(split_id, state_id, part))
                predict_clusters_for_embedded_data(embedded, kmeans_model, k)
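Assigning a discrete state to an embedding amounts to a nearest-centroid lookup. The sketch assumes the fitted centroids are available as a NumPy array (for a scikit-learn `KMeans` model this is the `cluster_centers_` attribute, and `KMeans.predict` performs the same lookup). The helper `assign_clusters` is hypothetical and only illustrates what `predict_clusters_for_embedded_data` is expected to do.

```python
import numpy as np
import pandas as pd


def assign_clusters(embedded, centers):
    """Map each embedded row to the index of its nearest centroid (squared Euclidean)."""
    X = embedded.to_numpy(dtype=float)
    d2 = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)  # shape (n_rows, k)
    return pd.Series(d2.argmin(axis=1), index=embedded.index, name="state")


centers = np.array([[0.0, 0.0], [10.0, 10.0]])
emb = pd.DataFrame({"z1": [0.5, 9.0, -1.0], "z2": [0.2, 11.0, 0.3]})
states = assign_clusters(emb, centers)
```

Keeping the original index on the returned Series lets the state labels be joined back to the per-step rows without reordering.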

### 2.3. Using the clustering results, create the trajectory files
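- Convert the cluster assignments into trajectory files with the same format as `data/ens.csv`; later steps read them as `train_ens.csv` and `val_ens.csv` from each `./data_splits/split_{:03d}/state_{:03d}_k_{}/` directory.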
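The exact columns of `data/ens.csv` are not described here. Assuming, hypothetically, one row per time step with `pt_id`, `t`, `state` (cluster index), `action`, and `reward`, each row can be linked to its successor within the same trajectory by a grouped shift. The function `build_trajectories` and the `next_state`/`done` columns are illustrative assumptions.

```python
import pandas as pd


def build_trajectories(df, id_col="pt_id", time_col="t"):
    """Hypothetical layout: one row per (trajectory, step) with a next_state and done flag."""
    df = df.sort_values([id_col, time_col]).reset_index(drop=True)
    nxt = df.groupby(id_col)["state"].shift(-1)     # state at the following step, NaN at the end
    df["done"] = nxt.isna()                         # last step of each trajectory
    df["next_state"] = nxt.fillna(-1).astype(int)   # -1 marks "no next state"
    return df[[id_col, time_col, "state", "action", "reward", "next_state", "done"]]


steps = pd.DataFrame({
    "pt_id":  [1, 1, 1, 2, 2],
    "t":      [0, 1, 2, 0, 1],
    "state":  [3, 0, 0, 2, 1],
    "action": [1, 0, 1, 0, 0],
    "reward": [0.0, 0.0, 1.0, 0.0, -1.0],
})
traj = build_trajectories(steps)
```

Sorting by trajectory and time before shifting matters: `shift(-1)` follows row order within each group, so unsorted input would link steps to the wrong successors.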

## 3. Estimate the Behavior Policy
- Estimate the behavior policy from each validation trajectory file with `compute_behavior_policy()` from `pipeline/OPE_utils.py`.

In [None]:
pi_b_va = {}  # behavior policy per (split_id, state_id, k), kept for later evaluation
for split_id in range(X):
    for state_id in range(X):
        for k in list_of_k:
            df_va = pd.read_csv('./data_splits/split_{:03d}/state_{:03d}_k_{}/val_ens.csv'.format(split_id, state_id, k))
            pi_b_va[(split_id, state_id, k)] = compute_behavior_policy(df_va)
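If `compute_behavior_policy()` is a count-based estimate (an assumption; check `pipeline/OPE_utils.py`), it reduces to the empirical frequency of each action within each discrete state, $\hat\pi_b(a \mid s) = N(s, a) / N(s)$. A sketch with optional additive smoothing `alpha` and a uniform fallback for unvisited states; the function name and the `state`/`action` column names are hypothetical.

```python
import numpy as np
import pandas as pd


def empirical_behavior_policy(df, n_states, n_actions, alpha=0.0):
    """pi_b[s, a] = (N(s, a) + alpha) / (N(s) + alpha * n_actions); unvisited states are uniform."""
    counts = np.zeros((n_states, n_actions))
    np.add.at(counts, (df["state"].to_numpy(), df["action"].to_numpy()), 1)
    counts += alpha
    totals = counts.sum(axis=1, keepdims=True)
    uniform = np.full_like(counts, 1.0 / n_actions)
    # Divide only where a state has data; elsewhere keep the uniform fallback.
    return np.divide(counts, totals, out=uniform, where=totals > 0)


toy_va = pd.DataFrame({"state": [0, 0, 0, 1], "action": [1, 1, 0, 2]})
pi_b = empirical_behavior_policy(toy_va, n_states=3, n_actions=3)
```

`np.add.at` is used instead of `counts[s, a] += 1` because fancy-index assignment does not accumulate repeated (state, action) pairs.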

## 4. Training the RL Policy

### 4.1. Create the environment
- Estimate the transition and reward functions from each training trajectory file (`train_ens.csv`).
- Use functions in `2_run_environment.py`

In [None]:
threshold_list = [0.1, 0.15, 0.2, 0.3]

for split_id in range(X):
    for state_id in range(X):
        for k in list_of_k:
            df_tr = pd.read_csv('./data_splits/split_{:03d}/state_{:03d}_k_{}/train_ens.csv'.format(split_id, state_id, k))
            # Functions from 2_run_environment.py. df_tr (and any other required arguments,
            # omitted here) is passed to each of them; refer to that file for details.
            create_transitions()
            create_dictionary()
            create_uncertainty_transitions()
            create_pMDP_dictionary()
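For a discrete state space, the maximum-likelihood environment model is a table of counts: $\hat T(s' \mid s, a) = N(s, a, s') / N(s, a)$, and $\hat R(s, a)$ is the mean reward observed after taking $a$ in $s$. The sketch below assumes that is what `create_transitions()` estimates and that each row has hypothetical `state`, `action`, `reward`, `next_state`, and `done` columns. Dividing by all visits of $(s, a)$, including terminal ones, leaves $1 - \sum_{s'} \hat T(s' \mid s, a)$ as the empirical probability that the trajectory ends there. The uncertainty-aware variants built by `create_uncertainty_transitions()` and `create_pMDP_dictionary()` are not reproduced.

```python
import numpy as np
import pandas as pd


def estimate_tabular_mdp(df, n_states, n_actions):
    """MLE transitions T[s, a, s'] and mean rewards R[s, a] from one row per step."""
    s = df["state"].to_numpy()
    a = df["action"].to_numpy()
    r = df["reward"].to_numpy(dtype=float)
    s_next = df["next_state"].to_numpy()
    cont = ~df["done"].to_numpy(dtype=bool)  # rows with an observed next state

    visits = np.zeros((n_states, n_actions))
    reward_sum = np.zeros((n_states, n_actions))
    np.add.at(visits, (s, a), 1)
    np.add.at(reward_sum, (s, a), r)

    counts = np.zeros((n_states, n_actions, n_states))
    np.add.at(counts, (s[cont], a[cont], s_next[cont]), 1)

    # Normalize by all visits, so 1 - T[s, a].sum() is the empirical termination rate.
    n_sa = visits[:, :, None]
    T = np.divide(counts, n_sa, out=np.zeros_like(counts), where=n_sa > 0)
    R = np.divide(reward_sum, visits, out=np.zeros_like(reward_sum), where=visits > 0)
    return T, R, visits


toy_tr = pd.DataFrame({
    "state":      [0, 0, 1, 1],
    "action":     [0, 0, 1, 1],
    "reward":     [0.0, 1.0, -1.0, 0.0],
    "next_state": [1, 0, -1, -1],
    "done":       [False, False, True, True],
})
T, R, visits = estimate_tabular_mdp(toy_tr, n_states=2, n_actions=2)
```

The returned `visits` table is also useful downstream, since rarely visited (state, action) pairs are exactly where these estimates are least reliable.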

### 4.2. Train the policies
- Use functions in `3_run_train_policy.py`

In [None]:
for split_id in range(X):
    for state_id in range(X):
        for k in list_of_k:
            # For each train_ens.csv, iterate over all the hyperparameter settings for policy training
            # Refer to the functions in 3_run_train_policy.py for more information
            generate_Q_masks()
            train_policy()
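The notebook does not spell out what `generate_Q_masks()` and `train_policy()` compute. One common way to keep a learned policy close to the data is to restrict it to actions with enough support under the estimated behavior policy and then run value iteration on the estimated MDP. The sketch assumes exactly that: the hypothetical `support_mask` allows actions whose behavior probability is at least a threshold, and the hypothetical `masked_value_iteration` ignores disallowed actions when taking the max. Whether the thresholds in `threshold_list` play this role is an assumption.

```python
import numpy as np


def support_mask(pi_b, threshold):
    """Allow actions whose estimated behavior probability is at least `threshold`."""
    mask = pi_b >= threshold
    # Always keep each state's most frequent action so no state is left without options.
    mask[np.arange(len(pi_b)), pi_b.argmax(axis=1)] = True
    return mask


def masked_value_iteration(T, R, mask, gamma=0.99, tol=1e-8, max_iter=10_000):
    """Q-value iteration that never selects actions where mask is False.

    T[s, a, s'] may sum to less than 1 over s'; the remainder is treated as
    termination with zero future value. Every state needs at least one allowed action.
    """
    Q = np.zeros_like(R, dtype=float)
    for _ in range(max_iter):
        V = np.where(mask, Q, -np.inf).max(axis=1)  # best allowed action per state
        Q_new = R + gamma * (T @ V)                 # (S, A, S) @ (S,) -> (S, A)
        converged = np.max(np.abs(Q_new - Q)) < tol
        Q = Q_new
        if converged:
            break
    policy = np.where(mask, Q, -np.inf).argmax(axis=1)
    return Q, policy


T = np.zeros((2, 2, 2))
T[0, 0, 1] = 1.0  # state 0, action 0 -> state 1
T[0, 1, 0] = 1.0  # state 0, action 1 -> stays in state 0
# State 1: both actions end the trajectory (all-zero rows).
R = np.array([[0.0, 0.1], [1.0, 5.0]])
pi_b = np.array([[0.9, 0.1], [0.95, 0.05]])

mask = support_mask(pi_b, threshold=0.15)
Q_masked, masked_policy = masked_value_iteration(T, R, mask)
_, free_policy = masked_value_iteration(T, R, np.ones_like(mask))
```

In this toy MDP the high-reward action in state 1 and the self-loop in state 0 are rarely taken under the behavior policy, so the mask rules them out; unmasked value iteration exploits both, which is the kind of poorly supported extrapolation a support constraint is meant to prevent.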

## 5. Final Policy Selection and Evaluation
- Evaluate all the policies trained on each `train_ens.csv` using the functions in `4_run_evaluation.py`.
- Aggregate the results and select the best hyperparameters (not shown here).

In [None]:
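for split_id in range(X):
    for state_id in range(X):
        for k in list_of_k:
            df_va = pd.read_csv('./data_splits/split_{:03d}/state_{:03d}_k_{}/val_ens.csv'.format(split_id, state_id, k))
            # df_te (test trajectories for this state definition) and main_dir (output
            # directory) are assumed to be prepared beforehand; see 4_run_evaluation.py.
            # For each train_ens.csv, evaluate all the policies generated using that training data
            evaluate_policy_for_each_setting(df_va, df_te, main_dir, k)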
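Off-policy evaluation estimates how a learned policy would perform using only logged trajectories. One standard estimator is weighted importance sampling (WIS), which reweights each observed return by $\prod_t \pi_e(a_t \mid s_t) / \pi_b(a_t \mid s_t)$ and normalizes by the sum of the weights. Whether `evaluate_policy_for_each_setting` uses WIS is not stated here, so treat this as an illustrative sketch; `wis_estimate` and the column names are hypothetical.

```python
import numpy as np
import pandas as pd


def wis_estimate(df, pi_e, pi_b, gamma=1.0, id_col="pt_id"):
    """Weighted importance sampling estimate of the evaluation policy's value.

    pi_e and pi_b are (n_states, n_actions) arrays of action probabilities; pi_b must be
    positive for every observed (state, action). df holds one row per step, in time order.
    """
    weights, returns = [], []
    for _, traj in df.groupby(id_col, sort=False):
        s = traj["state"].to_numpy()
        a = traj["action"].to_numpy()
        r = traj["reward"].to_numpy(dtype=float)
        ratios = pi_e[s, a] / pi_b[s, a]  # per-step importance ratios
        weights.append(np.prod(ratios))
        returns.append(np.sum(gamma ** np.arange(len(r)) * r))
    weights = np.asarray(weights)
    if weights.sum() == 0:
        return float("nan")  # no logged trajectory is consistent with pi_e
    return float(np.dot(weights, returns) / weights.sum())


toy_logged = pd.DataFrame({
    "pt_id":  [1, 1, 2, 2],
    "state":  [0, 1, 0, 1],
    "action": [0, 0, 1, 0],
    "reward": [0.0, 1.0, 0.0, -1.0],
})
pi_b = np.array([[0.5, 0.5], [1.0, 0.0]])
pi_e = np.array([[1.0, 0.0], [1.0, 0.0]])  # deterministic: always action 0
value = wis_estimate(toy_logged, pi_e, pi_b)
```

Here the second trajectory takes an action the deterministic evaluation policy never would, so its weight is zero and the estimate comes entirely from the first trajectory. Normalizing by the weight sum trades a small bias for much lower variance than ordinary importance sampling.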