# Explainable Recurrent PPO: Determining Proper Insulin Dosage

Parsa Youssefpour, Bo Gong, Pak Hop Chan

## Required installations

In [None]:
!pip install stable_baselines3
!pip install sb3_contrib
!pip install simglucose

## Custom Environment

In [None]:
from simglucose.simulation.scenario_gen import RandomScenario
from datetime import datetime
from simglucose.simulation.scenario import CustomScenario
from gymnasium.envs.registration import register

# Episodes use a start time of midnight (00:00) on the day this cell is run.
now = datetime.now()
start_time = datetime.combine(now.date(), datetime.min.time())

# Virtual patients handed to the environment as a list.
# NOTE(review): presumably simglucose picks one of these per episode when given a list — confirm.
patient_name = [
    "adult#001",
    "adult#002",
    "adult#003",
    "adult#004",
    "adult#005",
]

# Randomized Meal Plan (meal generator seeded with 1):
scenario = RandomScenario(start_time=start_time, seed=1)

# Custom meal Plan:
# (entries presumably (hour, carbs) pairs — confirm against simglucose's CustomScenario)
# scen = [(7, 45), (12, 70), (16, 15), (18, 80), (23, 10)]
# scenario = CustomScenario(start_time=start_time, scenario=scen)

# Register the environment under the id that gymnasium.make() uses in the training cells.
# Gymnasium truncates each episode after max_episode_steps steps.
# NOTE: "T1DSimGymnaisumEnv" (sic) must match the class name exported by
# simglucose.envs — verify against the installed version before "fixing" the spelling.
register(
    id="simglucose_attn",
    entry_point="simglucose.envs:T1DSimGymnaisumEnv",
    max_episode_steps=480,  # 24 hours at 3-min steps
    kwargs={"patient_name": patient_name,
            "custom_scenario": scenario},
)

In [None]:
import gymnasium
from gymnasium.spaces import Box as GymnasiumBox
import numpy as np

# Custom wrapper to modify observation and the action space
class CustomSimglucoseWrapper(gymnasium.Wrapper):
    """Rescale simglucose's action and observation spaces for PPO training.

    Observation (float32, 4 features, bounded by ``observation_space``):
        [CGM / 300, last_insulin / MAX_INSULIN, meal / 50, hour / 23]
    Action: one value in [0, 1], mapped linearly onto an insulin amount in
    [0, MAX_INSULIN] before it is passed to the wrapped environment.

    Episodes are additionally terminated when ``info['bg']`` leaves [54, 300].
    """

    # Full-scale insulin amount the policy can request. The observation encodes
    # last_insulin / 0.1 with an upper bound of 1, so 0.1 is the dose range the
    # observation space was designed for.
    # NOTE(review): units follow whatever the wrapped env's step() expects —
    # confirm 0.1 matches the dose range used when the saved models were trained.
    MAX_INSULIN = 0.1

    def __init__(self, env):
        super().__init__(env)
        low = np.array([0.0, 0.0, 0.0, 0.0], dtype=np.float32)  # [CGM, last_insulin, meal, hour]
        high = np.array([2, 1, 4, 23.0/23], dtype=np.float32)
        self.observation_space = GymnasiumBox(low, high)

        # The policy acts in a normalized [0, 1] space; step() rescales it.
        self.action_space = GymnasiumBox(
            low=np.array([0.0], dtype=np.float32),
            high=np.array([1], dtype=np.float32)
        )

        # Insulin amount delivered on the previous step (0 at episode start).
        self.last_insulin = 0.0

    def _build_obs(self, cgm, info):
        """Assemble the scaled 4-feature observation shared by reset() and step()."""
        # Features scaled to what worked best in training (not normalized).
        return np.array(
            [
                cgm[0] / 300,
                self.last_insulin / self.MAX_INSULIN,
                info['meal'] / 50,
                info['time'].hour / 23,
            ],
            dtype=np.float32,
        )

    def reset(self, **kwargs):
        """Reset the wrapped env and return (scaled observation, info)."""
        cgm, info = self.env.reset(**kwargs)
        self.last_insulin = 0.0
        # Forward the wrapped env's info, as the Gymnasium reset() API expects.
        return self._build_obs(cgm, info), info

    def step(self, action):
        """Rescale the [0, 1] action to an insulin amount and step the wrapped env.

        Returns:
            (obs, reward, terminated, truncated, info) per the Gymnasium API;
            terminated is forced True when info['bg'] is outside [54, 300].
        """
        # Clip the raw policy output, then map it onto [0, MAX_INSULIN].
        action = float(np.clip(action[0], 0.0, 1.0))
        scaled_action = action * self.MAX_INSULIN
        self.last_insulin = scaled_action

        cgm, reward, terminated, truncated, info = self.env.step(scaled_action)

        obs = self._build_obs(cgm, info)

        # Terminate early when blood glucose leaves the [54, 300] band.
        bg = info['bg']
        if bg < 54 or bg > 300:
            terminated = True

        return obs, reward, bool(terminated), bool(truncated), info

## Baseline Model (PPO) - Training

In [None]:
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
import os
from stable_baselines3.common.callbacks import EvalCallback
from google.colab import files

# Load the environment.
# Training env: Monitor records per-episode stats for the training reward curve
# (SB3 appends ".monitor.csv", so this writes baseline_PPO.monitor.csv).
env = CustomSimglucoseWrapper(gymnasium.make("simglucose_attn"))
env = Monitor(env, filename="baseline_PPO")

# Separate evaluation env: EvalCallback resets and steps the env it is given,
# so reusing the training env would cut training episodes short mid-rollout and
# mix evaluation episodes into the training Monitor log above.
eval_env = Monitor(CustomSimglucoseWrapper(gymnasium.make("simglucose_attn")))

# Eval callback: every 5000 training steps, run 10 deterministic episodes and
# keep the best-scoring model.
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path="./best_model_PPO2/",
    log_path="./logs/PPO/",  # per-model dir so other runs don't overwrite evaluations.npz
    eval_freq=5000,
    n_eval_episodes=10,
    deterministic=True,
    render=False
)

# Model architecture - separate hidden layers for actor (pi) and critic (vf).
policy_kwargs = dict(
    net_arch=dict(pi=[256, 64, 64], vf=[256, 64, 64]),
)

model = PPO(
    policy="MlpPolicy",
    env=env,
    policy_kwargs=policy_kwargs,
    verbose=1,
    gamma=0.99,
    n_steps=512,
    batch_size=64,
    learning_rate=3e-4,
    ent_coef=0.3,
)

# Train the model
model.learn(total_timesteps=500_000, progress_bar=True, callback=eval_callback)

# Save the model (SB3 appends ".zip")
model.save("baseline_PPO")

## Recurrent PPO Model

In [None]:
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.monitor import Monitor
import os
from stable_baselines3.common.callbacks import EvalCallback

# Load the environment.
# Training env: Monitor records per-episode stats for the training reward curve.
# NOTE: SB3 appends ".monitor.csv", so this file is "Recurrent_PPO.csv.monitor.csv"
# (the name is kept unchanged for any downstream analysis that reads it).
env = CustomSimglucoseWrapper(gymnasium.make("simglucose_attn"))
env = Monitor(env, filename="Recurrent_PPO.csv")

# Separate evaluation env: EvalCallback resets and steps the env it is given,
# so reusing the training env would cut training episodes short mid-rollout and
# mix evaluation episodes into the training Monitor log above.
eval_env = Monitor(CustomSimglucoseWrapper(gymnasium.make("simglucose_attn")))

# Callback to save the best model (10 deterministic episodes every 5000 steps).
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path="./best_model_og/",
    log_path="./logs/RPPO/",  # per-model dir so other runs don't overwrite evaluations.npz
    eval_freq=5000,
    n_eval_episodes=10,
    deterministic=True,
    render=False
)

# Model architecture: default MlpLstmPolicy with a conservative update
# (small learning rate, clip_range=0.1).
model = RecurrentPPO(
    policy="MlpLstmPolicy",
    env=env,
    n_steps=512,
    batch_size=64,
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
    learning_rate=5e-5,
    ent_coef=0.3,
    verbose=1,
    seed=1,
    device="cuda",
    clip_range=0.1,
    max_grad_norm=0.5
)


# Train the model
model.learn(total_timesteps=500_000, progress_bar=True, callback=eval_callback)

# Save the model (SB3 appends ".zip")
model.save("RPPO")

## Attention Recurrent PPO

In [None]:
from typing import Any, Optional, Union, Dict, List, Tuple

import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.distributions import Distribution
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    FlattenExtractor,
    MlpExtractor,
)
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.utils import zip_strict
from torch import nn

from sb3_contrib.common.recurrent.type_aliases import RNNStates
from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy

class FeatureAttention(nn.Module):
    """
    Observation-dependent feature attention.

    A small MLP maps each input vector to a softmax distribution over its own
    features, and the input is re-weighted element-wise by that distribution.
    The weights from the most recent forward pass are kept for explainability.
    """
    def __init__(self, feature_dim: int):
        super().__init__()
        # Hidden width is half the input width, floored at 1 so a single-feature
        # input does not produce a zero-width layer.
        hidden_dim = max(1, feature_dim // 2)

        # Dense layers computing one attention weight per feature
        self.attention_layer = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, feature_dim),
            nn.Softmax(dim=-1)
        )

        # Detached weights from the last forward pass, for explainability
        self.last_attention_weights: Optional[th.Tensor] = None

    def forward(self, x: th.Tensor) -> th.Tensor:
        """
        Apply observation-specific feature attention.

        inputs:
            x: Input tensor with shape (batch_size, feature_dim)

        Returns:
            Weighted features with shape (batch_size, feature_dim)
        """
        # Attention weights for each observation in the batch (softmax over features)
        attention_weights = self.attention_layer(x)

        # Store a detached tensor so the saved diagnostics do not keep a
        # reference to the autograd graph of the last batch.
        self.last_attention_weights = attention_weights.detach()

        # Re-weight the features; gradients flow through the non-detached weights.
        return x * attention_weights

    def get_attention_weights(self, batch_idx: int = 0) -> th.Tensor:
        """
        Get the attention weights for a specific observation in the last batch.

        Input:
            batch_idx: Index of the observation in the batch

        Returns:
            Attention weights tensor (detached) of shape (feature_dim,)

        Raises:
            ValueError: If no forward pass has been run yet.
        """
        if self.last_attention_weights is None:
            raise ValueError("No attention weights calculated yet. Run a forward pass first.")

        return self.last_attention_weights[batch_idx]


class AttentionFeaturesExtractor(BaseFeaturesExtractor):
    """
    SB3 features extractor: feature attention followed by a two-layer ReLU MLP.

    Subclasses BaseFeaturesExtractor so it can be plugged into
    AttentionLstmPPOPolicy (sb3_contrib RecurrentPPO) through
    ``features_extractor_class``. The attention submodule is exposed as
    ``self.attention`` so the policy can read its weights.
    """
    def __init__(
        self,
        observation_space: spaces.Space,
        features_dim: int = 64,
    ):
        super().__init__(observation_space, features_dim)
        in_dim = int(np.prod(observation_space.shape))

        # Submodule names (`attention`, `feature_layers`) are state_dict keys of
        # saved models, so they and their creation order are kept stable.
        self.attention = FeatureAttention(in_dim)
        layers = [
            nn.Linear(in_dim, features_dim),
            nn.ReLU(),
            nn.Linear(features_dim, features_dim),
            nn.ReLU(),
        ]
        self.feature_layers = nn.Sequential(*layers)

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Re-weight the raw observation features with attention, then embed them."""
        return self.feature_layers(self.attention(observations))


# taken and modified from: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/sb3_contrib/common/recurrent/policies.py
# Added the attention retrieval

class AttentionLstmPPOPolicy(RecurrentActorCriticPolicy):
    """
    Recurrent policy class for actor-critic algorithms with attention mechanism for explainability.
    To be used with RecurrentPPO.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use
        (defaults to AttentionFeaturesExtractor, which provides the attention layer).
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
    :param normalize_images: Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    :param lstm_hidden_size: Number of hidden units for each LSTM layer.
    :param n_lstm_layers: Number of LSTM layers.
    :param shared_lstm: Whether the LSTM is shared between the actor and the critic
        (in that case, only the actor gradient is used)
        By default, the actor and the critic have two separate LSTM.
    :param enable_critic_lstm: Use a separate LSTM for the critic.
    :param lstm_kwargs: Additional keyword arguments to pass to the LSTM
        constructor.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
        activation_fn: type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: type[BaseFeaturesExtractor] = AttentionFeaturesExtractor,
        features_extractor_kwargs: Optional[dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[dict[str, Any]] = None,
        lstm_hidden_size: int = 256,
        n_lstm_layers: int = 1,
        shared_lstm: bool = False,
        enable_critic_lstm: bool = True,
        lstm_kwargs: Optional[dict[str, Any]] = None,
    ):
        # Initialize with parent class (positional order matches RecurrentActorCriticPolicy)
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            use_expln,
            squash_output,
            features_extractor_class,
            features_extractor_kwargs,
            share_features_extractor,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
            lstm_hidden_size,
            n_lstm_layers,
            shared_lstm,
            enable_critic_lstm,
            lstm_kwargs,
        )

        # Expose the features extractor's attention layer. Assigning an nn.Module
        # attribute registers it as a submodule too, so its parameters also appear
        # under "attention.*" in the state_dict — keep this for checkpoint compatibility.
        if hasattr(self.features_extractor, 'attention'):
            self.attention = self.features_extractor.attention
        else:
            # Fallback if the features extractor was overridden (just in case)
            self.attention = None

    @property
    def feature_names(self) -> List[str]:
        """
        Return the feature names for explainability.
        This should be customized based on your environment; the order must match
        the observation vector built by the environment wrapper.
        Read-only (@property) so the names cannot be accidentally overwritten.
        """
        return ["CGM", "Last_Insulin", "Meal", "Hour"]

    def get_attention_weights(self) -> th.Tensor:
        """
        Get the attention weights of the first observation from the last forward pass.

        Raises:
            ValueError: If this policy has no attention mechanism.
        """
        if self.attention is None:
            raise ValueError("Attention mechanism not available in this policy.")
        return self.attention.get_attention_weights()

    def explain(self, obs: np.ndarray) -> Dict[str, float]:
        """
        Return feature importance as a dictionary.

        Runs only the features extractor (no LSTM state needed) so the attention
        layer records weights for ``obs``, then reads the weights of the first
        observation in the batch.

        Inputs:
            obs: Observation (state), a single observation or a batch of them;
                 only the first observation of a batch is explained.

        Returns:
            Dictionary mapping feature names to their importance scores (Python floats)

        Raises:
            ValueError: If this policy has no attention mechanism.
        """
        # Fail fast before doing any computation.
        if getattr(self, 'attention', None) is None:
            raise ValueError("This policy doesn't have an attention mechanism")

        # Convert observation to a float tensor on the policy's device
        obs_tensor = th.as_tensor(obs).float().to(self.device)

        # Add a batch dimension to a single observation
        if obs_tensor.dim() == 1:
            obs_tensor = obs_tensor.unsqueeze(0)

        # Only the features extractor is needed to trigger the attention computation
        with th.no_grad():
            _ = self.features_extractor(obs_tensor)

        attention_weights = self.attention.get_attention_weights().detach().cpu().numpy().flatten()

        # Matching weights with feature names
        feature_names = self.feature_names

        # Handling potential dimension mismatch
        if len(feature_names) != len(attention_weights):
            print(f"Warning: feature_names length ({len(feature_names)}) doesn't match "
                  f"attention_weights length ({len(attention_weights)})")
            # Truncate the longer one
            min_len = min(len(feature_names), len(attention_weights))
            feature_names = feature_names[:min_len]
            attention_weights = attention_weights[:min_len]

        # float() so callers get plain Python floats, as annotated, not np.float32
        return {name: float(weight) for name, weight in zip(feature_names, attention_weights)}


# Modified R-PPO architecture kwargs
def create_attention_lstm_policy():
    """
    Build the policy class and ``policy_kwargs`` for the attention RecurrentPPO.

    Returns:
        Tuple ``(AttentionLstmPPOPolicy, policy_kwargs)``. ``policy_kwargs``
        selects AttentionFeaturesExtractor with a 64-dim output, plus separate
        [64, 64] MLP heads for the actor (pi) and the critic (vf).
    """
    extractor_kwargs = {"features_dim": 64}
    heads = {"pi": [64, 64], "vf": [64, 64]}
    kwargs = {
        "features_extractor_class": AttentionFeaturesExtractor,
        "features_extractor_kwargs": extractor_kwargs,
        "net_arch": heads,
    }
    return AttentionLstmPPOPolicy, kwargs

### Train the Model

In [None]:
# Decaying Entropy Coefficient - Tested but NOT USED in Final model
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback

class EntropyDecayCallback(BaseCallback):
    """
    Linearly anneal the model's entropy coefficient during training.

    On every callback step the coefficient is set to
    ``initial_value * (1 - p) + final_value * p`` with
    ``p = min(1, num_timesteps / max_timesteps)``, so it stays at
    ``final_value`` once ``max_timesteps`` is reached.

    :param initial_value: Coefficient at timestep 0.
    :param final_value: Coefficient from ``max_timesteps`` onwards.
    :param max_timesteps: Number of timesteps over which to anneal; must be > 0.
    :param verbose: Verbosity level passed to BaseCallback.
    :raises ValueError: If ``max_timesteps`` is not positive (it would divide by
        zero, or extrapolate beyond ``initial_value`` when negative).
    """
    def __init__(self, initial_value=0.5, final_value=0.01, max_timesteps=1_000_000, verbose=0):
        super().__init__(verbose)
        if max_timesteps <= 0:
            raise ValueError(f"max_timesteps must be positive, got {max_timesteps}")
        self.initial_value = initial_value
        self.final_value = final_value
        self.max_timesteps = max_timesteps

    def _on_step(self) -> bool:
        # Fraction of the annealing schedule completed, capped at 1.
        progress = min(1.0, self.num_timesteps / self.max_timesteps)
        new_ent_coef = self.initial_value * (1 - progress) + self.final_value * progress
        self.model.ent_coef = new_ent_coef
        # Returning True tells SB3 to continue training.
        return True

In [None]:
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import EvalCallback
from google.colab import files

# Create the training environment.
# Monitor records per-episode stats for the training reward curve
# (SB3 appends ".monitor.csv", so this writes Attention_RPPO.monitor.csv).
env = CustomSimglucoseWrapper(gymnasium.make("simglucose_attn"))
env = Monitor(env, filename="Attention_RPPO")

# Separate evaluation env: EvalCallback resets and steps the env it is given,
# so reusing the training env would cut training episodes short mid-rollout and
# mix evaluation episodes into the training Monitor log above.
eval_env = Monitor(CustomSimglucoseWrapper(gymnasium.make("simglucose_attn")))

# Callback to save the best model (10 deterministic episodes every 5000 steps).
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path="./best_model_A_RPPO/",
    log_path="./logs/A_RPPO/",  # per-model dir so other runs don't overwrite evaluations.npz
    eval_freq=5000,
    n_eval_episodes=10,
    deterministic=True,
    render=False
)

# Get the custom policy and kwargs
policy_class, policy_kwargs = create_attention_lstm_policy()


# Create the model with the custom attention policy
# (same hyperparameters as the plain Recurrent PPO run).
model = RecurrentPPO(
    policy=policy_class,
    env=env,
    policy_kwargs=policy_kwargs,
    n_steps=512,
    batch_size=64,
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
    learning_rate=5e-5,
    ent_coef=0.3,
    verbose=1,
    seed=1,
    device="cuda",
    clip_range=0.1,
    max_grad_norm=0.5
)

# Train the model
model.learn(total_timesteps=500_000, progress_bar=True, callback=eval_callback)

# Save the model (SB3 appends ".zip")
model.save("Attention_RPPO")

# Download the model file (Colab only)
files.download('Attention_RPPO.zip')