In [1]:

import minari
import numpy as np
import gymnasium as gym
from PIL import Image
from minari import DataCollector
import torch
import torch.nn as nn
import tqdm
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
# from torchvision.models import resnet50 # No longer using resnet50
from torchvision.models import resnet18, ResNet18_Weights # Import ResNet-18
from sklearn.model_selection import train_test_split
from copy import deepcopy
import d3rlpy

import numpy as np
import pandas as pd
from copy import deepcopy

from simulation import cancer
from generate import simulate_blackwell_glynn
from nsmm import nsmm_lag1, nsmm_lag1_cate
from msm import MarginalStructuralModel


Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
  from .autonotebook import tqdm as notebook_tqdm


In [2]:

np.random.seed(100)

num_time_steps = 60  # 6 month followup
num_patients = 1000

simulation_params = cancer.get_confounding_params(num_patients, chemo_coeff=10.0, radio_coeff=10.0)
simulation_params['window_size'] = 15

outputs = cancer.simulate(simulation_params, num_time_steps)

  if recovery_rvs[i, t] < np.exp(-cancer_volume[i, t] * tumour_cell_density):


In [3]:
# Assuming df is your dataframe
def prepare_data_for_outcome(df, drug_half_life = 1):
    df = df.copy()
    df['chemo_dosage'] = 0.0
    df['previous_cancer_volume'] = df['cancer_volume']
    for pid, group in df.groupby('Patient_ID'):
        group = group.sort_values('Time_Point')
        chemo_instant_dosage = group['chemo_instant_dosage']
        previous_chemo_dose = group['chemo_instant_dosage'].shift(1)
        previous_cancer_volume = group['cancer_volume'].shift(1)
        chemo_dosages = cancer.get_chemo_dosage(chemo_instant_dosage, previous_chemo_dose, drug_half_life)
        

        df.loc[group.index, 'chemo_dosage'] = chemo_dosages
        df.loc[group.index, 'previous_cancer_volume'] = previous_cancer_volume

        df['termination'] = 0
        df.loc[group.index[-1], 'termination'] = 1

    return df.dropna().reset_index(drop=True)

# Example usage
df = prepare_data_for_outcome(outputs)

# drop some row since it need lag data
n_time = int(len(df)/num_patients)

In [13]:

state = df[['previous_cancer_volume']].to_numpy()
action = df[['radio_dosage', 'chemo_instant_dosage']].to_numpy()
reward = df[['previous_cancer_volume']].to_numpy()-df[['cancer_volume']].to_numpy()
next_state = df[['cancer_volume']].to_numpy()
terminations = df[['termination']].to_numpy()

# Assuming all arrays are already defined and have the same length
dataset = list(zip(state, action, reward, next_state, terminations))

In [48]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random

class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim=2, hidden_dim=256):
        super(QNetwork, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # Q-value scalar
        )

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.net(x)  # returns Q-value for (s, a)


class OfflineRL:
    def __init__(
        self,
        state_dim,
        gamma=0.99,
        lr=1e-3,
        batch_size=64,
        device="cpu"
    ):
        self.state_dim = state_dim
        self.action_dim = 2  # fixed 2D binary action
        self.gamma = gamma
        self.batch_size = batch_size
        self.device = torch.device(device)

        self.q_net = QNetwork(state_dim).to(self.device)
        self.target_q_net = QNetwork(state_dim).to(self.device)
        self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)

        self.replay_buffer = []

    def load_dataset(self, transitions):
        """
        transitions: list of (state, action_vec[2], reward, next_state, done)
        """
        self.replay_buffer = transitions

    def sample_batch(self):
        batch = random.sample(self.replay_buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.tensor(states, dtype=torch.float32, device=self.device)
        actions = torch.tensor(actions, dtype=torch.float32, device=self.device)
        rewards = torch.tensor(rewards, dtype=torch.float32, device=self.device).unsqueeze(1)
        next_states = torch.tensor(next_states, dtype=torch.float32, device=self.device)
        dones = torch.tensor(dones, dtype=torch.float32, device=self.device).unsqueeze(1)

        return states, actions, rewards, next_states, dones

    def all_binary_actions(self):
        return torch.tensor([
            [0, 0],
            [0, 1],
            [1, 0],
            [1, 1]
        ], dtype=torch.float32, device=self.device)

    def train(self, epochs=100):
        for epoch in range(epochs):
            losses = []
            for _ in range(len(self.replay_buffer) // self.batch_size):
                states, actions, rewards, next_states, dones = self.sample_batch()

                with torch.no_grad():
                    next_q_values = []
                    all_actions = self.all_binary_actions()
                    for a in all_actions:
                        repeated_a = a.unsqueeze(0).expand(next_states.size(0), -1)
                        q_val = self.target_q_net(next_states, repeated_a)
                        next_q_values.append(q_val)
                    next_q_values = torch.cat(next_q_values, dim=1)
                    max_next_q = next_q_values.max(dim=1, keepdim=True)[0]

                    target = rewards + self.gamma * (1 - dones) * max_next_q

                q_pred = self.q_net(states, actions)

                loss = nn.MSELoss()(q_pred, target)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                losses.append(loss.item())

            self.update_target_network()
            print(f"Epoch {epoch + 1}, Loss: {np.mean(losses):.4f}")

    def update_target_network(self, tau=0.005):
        for param, target_param in zip(self.q_net.parameters(), self.target_q_net.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    def predict_action(self, state):
        state = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        all_actions = self.all_binary_actions()
        q_vals = []

        with torch.no_grad():
            for a in all_actions:
                a_input = a.unsqueeze(0).expand(state.size(0), -1)
                q = self.q_net(state, a_input)
                q_vals.append(q.item())

        best_idx = np.argmax(q_vals)
        return all_actions[best_idx].cpu().numpy()


In [None]:

sac = OfflineRL(state.shape[1], action.shape[1], batch_size= 3)
sac.load_dataset(dataset)  # Load the dataset into the replay buffer
sac.train(epochs=2)


  return F.mse_loss(input, target, reduction=self.reduction)


In [None]:
a = sac.predict(X)

In [31]:
len(a)

59000

In [None]:

sac = d3rlpy.algos.DiscreteCQLConfig().create(device="cuda:0")

# train offline
sac.fit(dataset, n_steps=50000)


# --- Episode Runner Class (Interacting with Gym Environment) ---
class EpisodeRunner:
    def __init__(self, env, observation_transform=None):
        self.env = env
        self.observation_transform = observation_transform

    def run_episode(self, initial_observation_raw, decision_policy_fn, model_estimator=None, use_random_policy=False):
        episode_over = False
        truncated = False
        terminated = False
        total_reward = 0
        
        current_observation_raw = initial_observation_raw

        while not (terminated or truncated):
            current_observation_raw = np.expand_dims(current_observation_raw, axis=0)
            if use_random_policy:
                action = self.env.action_space.sample()
            else:
                assert model_estimator is not None, "A model estimator must be provided for non-random actions"
                # Apply observation transform for the model
                if self.observation_transform is not None:

                    transformed_obs = self.observation_transform(current_observation_raw)
                else: # Should not happen if model is used
                    transformed_obs = torch.tensor(current_observation_raw, dtype=torch.float32).unsqueeze(0)

                action = decision_policy_fn(transformed_obs, model_estimator)
            
            next_observation_raw, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            current_observation_raw = next_observation_raw
            
            if terminated or truncated:
                episode_over = True
        
        return total_reward

def gym_observation_transform(observation_np_array): # Raw observation from env.step() or env.reset()

    return np.transpose(observation_np_array, (0, 3, 1, 2))

def decision_function_dl(obs_tensor, model_instance): # obs_tensor is already preprocessed

    action = model_instance.predict(obs_tensor) # obs_tensor should be [1, C, H, W]

    return action[0]

# --- Environment Evaluation ---
print("\nStarting environment evaluation...")
total_reward_model_policy = 0
total_reward_random_policy = 0


env  = minari_dataset.recover_environment() 



print(f"Running {N_EVAL_EPISODES} episodes for model policy and random policy...")
for i in tqdm.tqdm(range(N_EVAL_EPISODES), desc="Evaluating Policies"):
    # Model Policy

    obs_model, info_model = env.reset()

    env_model = deepcopy(env)
    env_random = deepcopy(env)

    runner_model = EpisodeRunner(env_model, observation_transform=gym_observation_transform)
    reward_model = runner_model.run_episode(
        obs_model,
        decision_policy_fn=decision_function_dl,
        model_estimator=sac,
        use_random_policy=False
    )
    total_reward_model_policy += reward_model


    runner_random = EpisodeRunner(env_random) # No transform needed for random
    reward_random = runner_random.run_episode(
        obs_model,
        decision_policy_fn=None, # Not used for random
        use_random_policy=True
    )
    total_reward_random_policy += reward_random

avg_model_reward = total_reward_model_policy / N_EVAL_EPISODES if N_EVAL_EPISODES > 0 else 0
avg_random_reward = total_reward_random_policy / N_EVAL_EPISODES if N_EVAL_EPISODES > 0 else 0

print(f"\n--- Final Evaluation Results ({N_EVAL_EPISODES} episodes) ---")
print(f"Average Total Reward (Model Policy): {avg_model_reward:.2f}")
print(f"Average Total Reward (Random Policy): {avg_random_reward:.2f}")

env.close()

print("Evaluation complete.")