# Import

In [None]:
import gym
from gym import spaces
import numpy as np
import matplotlib.pyplot as plt

from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env

# Génération des patients (dataset fictif)

In [None]:
N_PATIENTS = 1000

# Seed NumPy's global RNG so the synthetic cohort (and every later draw from
# np.random in this notebook) is reproducible.
np.random.seed(42)

# randint's upper bound is exclusive: ages fall in [18, 89] and weights in
# [40, 119] (presumably kg).
ages = np.random.randint(low=18, high=90, size=N_PATIENTS)
weights = np.random.randint(low=40, high=120, size=N_PATIENTS)

# One row per patient: column 0 holds the age, column 1 the weight.
patients = np.stack((ages, weights), axis=1)

# Dose cible

In [None]:
def target_dose(age, weight):
    """Return the reference dose for a patient.

    The rule is linear: ``0.5 * weight + 0.2 * age``. Arithmetic is plain
    operators, so it applies element-wise to NumPy arrays as well as scalars.
    """
    weight_part = 0.5 * weight
    age_part = 0.2 * age
    return weight_part + age_part


In [None]:
# Reference dose for every synthetic patient, computed over the whole arrays.
target_doses = target_dose(ages, weights)

# Histogram of the reference doses across the cohort.
fig, ax = plt.subplots()
ax.hist(target_doses, bins=30)
ax.set_xlabel("Dose cible (mg)")
ax.set_ylabel("Nombre de patients")
ax.set_title("Distribution des doses cibles")
plt.show()


# Environnement Gym

In [None]:
class DrugDoseEnv(gym.Env):
    """Single-step dosing environment over a fixed cohort of patients.

    Each episode draws one patient row ``(age, weight)`` from ``patients``.
    The agent proposes a dose (mg) and is rewarded with the negative absolute
    gap to ``target_dose(age, weight)``; the episode ends after that one step.

    NOTE(review): this follows the pre-0.26 Gym API (``reset() -> obs`` and a
    4-tuple ``step``). Stable-Baselines3 >= 2.0 expects the Gymnasium API
    instead — confirm the installed library versions match.

    Args:
        patients: indexable sequence of ``(age, weight)`` rows — presumably an
            (N, 2) integer array. Ages are scaled by 1/100 and weights by 1/150,
            so values above those maxima would fall outside the observation box.
    """

    def __init__(self, patients):
        super().__init__()

        self.patients = patients
        self.n_patients = len(patients)

        # Observation: age and weight, each scaled into [0, 1].
        self.observation_space = spaces.Box(
            low=0.0, high=1.0, shape=(2,), dtype=np.float32
        )

        # Action: dose between 0 and 100 mg.
        self.action_space = spaces.Box(
            low=0.0, high=100.0, shape=(1,), dtype=np.float32
        )

        # Set by reset(); step() refuses to run until a patient is selected.
        self.current_patient = None

    def _get_obs(self):
        """Return the scaled ``[age / 100, weight / 150]`` observation."""
        age, weight = self.current_patient
        return np.array([age / 100, weight / 150], dtype=np.float32)

    def reset(self):
        """Pick a random patient (NumPy global RNG) and return its observation."""
        idx = np.random.randint(0, self.n_patients)
        self.current_patient = self.patients[idx]
        return self._get_obs()

    def step(self, action):
        """Score ``action`` (a dose in mg) for the current patient.

        Returns:
            ``(obs, reward, done, info)`` where ``reward = -|dose - optimal|``,
            ``done`` is always True and ``info`` holds ``dose``,
            ``optimal_dose`` and ``error``.

        Raises:
            RuntimeError: if called before ``reset()``.
        """
        if self.current_patient is None:
            raise RuntimeError("step() called before reset()")

        age, weight = self.current_patient
        # Accept a scalar, list or array and clamp to the declared action
        # bounds, so direct callers cannot score doses the action space forbids.
        dose = float(
            np.clip(
                np.ravel(action)[0],
                self.action_space.low[0],
                self.action_space.high[0],
            )
        )

        optimal = target_dose(age, weight)
        error = float(abs(dose - optimal))
        reward = -error

        # Every episode is a single dosing decision.
        done = True

        info = {
            "dose": dose,
            "optimal_dose": optimal,
            "error": error,
        }
        return self._get_obs(), reward, done, info


# Vérification de l'environnement

In [None]:
# Build the environment on the synthetic cohort, then let Stable-Baselines3
# check that it follows the environment API the library expects.
env = DrugDoseEnv(patients=patients)
check_env(env, warn=True)

# Stratégies de référence (baselines)

In [None]:
# Baseline : pas de traitement
def baseline_no_treatment(env, n=500):
    rewards = []
    for _ in range(n):
        obs = env.reset()
        _, reward, _, _ = env.step([0.0])
        rewards.append(reward)
    return np.mean(rewards)

In [None]:
# Baseline : dose constante de 50 mg
def baseline_constant(env, dose=50.0, n=500):
    rewards = []
    for _ in range(n):
        obs = env.reset()
        _, reward, _, _ = env.step([dose])
        rewards.append(reward)
    return np.mean(rewards)

In [None]:
# Mean reward of both reference strategies (evaluated left to right, so the
# no-treatment baseline still consumes the RNG draws first).
baseline_none, baseline_const = baseline_no_treatment(env), baseline_constant(env)

baseline_none, baseline_const

# Entraînement PPO

In [None]:
# Train PPO on the dosing environment.
# seed=42 matches the seed used for the synthetic cohort so training runs are
# reproducible; note that SB3 also reseeds NumPy's global RNG with this value.
model = PPO(
    "MlpPolicy",
    env,
    verbose=1,
    learning_rate=3e-4,
    batch_size=64,
    gamma=0.99,  # no practical effect: every episode ends after one step
    seed=42,
)

model.learn(total_timesteps=20_000)

# Évaluation de l'agent

In [None]:
def evaluate_agent(model, env, n=500, *, deterministic=True):
    """Roll out ``model`` for ``n`` single-step episodes and collect results.

    Each episode resets ``env``, asks the model for one action and takes one
    step.

    Args:
        model: object with an SB3-style ``predict(obs, deterministic=...)``
            returning ``(action, state)``.
        env: environment whose ``reset()`` returns the observation and whose
            ``step()`` returns ``(obs, reward, done, info)``, with ``info``
            holding ``"dose"`` and ``"optimal_dose"``.
        n: number of episodes.
        deterministic: use the policy's deterministic action. Defaults to True
            so the evaluation measures the learned dose rather than sampled
            exploration noise; pass False for stochastic actions.

    Returns:
        Tuple ``(rewards, doses, optimal)`` of three lists of length ``n``.
    """
    rewards, doses, optimal = [], [], []

    for _ in range(n):
        obs = env.reset()
        action, _ = model.predict(obs, deterministic=deterministic)
        _, reward, _, info = env.step(action)

        rewards.append(reward)
        doses.append(info["dose"])
        optimal.append(info["optimal_dose"])

    return rewards, doses, optimal

In [None]:
# Roll out the trained policy; keep per-patient rewards, given doses and reference doses.
ppo_rewards, ppo_doses, optimal_doses = evaluate_agent(model, env=env)

# Comparaison des rewards

In [None]:
# Mean reward per strategy. Rewards are negative dose errors, so values closer
# to 0 are better.
labels = ["No treatment", "Constant dose", "PPO"]
values = [baseline_none, baseline_const, np.mean(ppo_rewards)]

fig, ax = plt.subplots()
ax.bar(labels, values)
ax.set_ylabel("Reward moyen")
ax.set_title("Comparaison des stratégies")
plt.show()

# Erreur de dosage

In [None]:
# Absolute gap between the PPO dose and the reference dose, per evaluated patient.
errors = np.abs(np.subtract(ppo_doses, optimal_doses))

fig, ax = plt.subplots()
ax.hist(errors, bins=30)
ax.set_xlabel("Erreur de dosage (mg)")
ax.set_ylabel("Nombre de patients")
ax.set_title("Distribution des erreurs PPO")
plt.show()

# Dose prédite vs dose optimale

In [None]:
# One point per evaluated patient: reference dose on x, PPO dose on y.
fig, ax = plt.subplots()
ax.scatter(optimal_doses, ppo_doses, alpha=0.5)
# Dashed identity line: points on it received exactly the reference dose.
ax.plot([0, 100], [0, 100], "--")
ax.set_xlabel("Dose optimale")
ax.set_ylabel("Dose PPO")
ax.set_title("PPO vs vérité terrain")
plt.show()