In [None]:

import os
import sys
import random
import json
import warnings
from pathlib import Path
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass

# Data manipulation
import numpy as np
import pandas as pd

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Reinforcement Learning
import gym
from gym import spaces
from stable_baselines3 import PPO, A2C, DQN
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import BaseCallback

# Model interpretability
try:
    import shap
    SHAP_AVAILABLE = True
except ImportError:
    SHAP_AVAILABLE = False
    print("⚠️  SHAP not available - interpretability features disabled")


In [None]:
# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

# Set random seeds for reproducibility
SEED = 42
np.random.seed(SEED)
random.seed(SEED)

# Matplotlib styling
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Configuration dataclass
@dataclass
class Config:
    """IMPROVED configuration parameters"""
    seed: int = 42
    train_split: float = 0.70
    val_split: float = 0.15
    test_split: float = 0.15
    min_episode_length: int = 3
    training_timesteps: int = 150000  # INCREASED for better learning
    eval_episodes: int = 20
    history_window: int = 5
    
    # NEW: Training hyperparameters
    learning_rate: float = 0.0003
    gamma: float = 0.99
    batch_size: int = 64
    
config = Config()

print("✅ All libraries imported successfully!")
print(f"📊 Configuration: Train={config.train_split}, Val={config.val_split}, Test={config.test_split}")

In [None]:
# ==============================================================================
# SECTION 2: LAB ITEM METADATA DEFINITION
# PURPOSE: Define standardized metadata for key MIMIC-III lab tests
# ==============================================================================

LAB_ITEM_INFO = {
    50931: {  # Glucose
        'name': 'Glucose',
        'normal_range': (70.0, 100.0),
        'unit': 'mg/dL',
        'critical': True,
        'description': 'Blood sugar level'
    },
    50912: {  # Creatinine
        'name': 'Creatinine',
        'normal_range': (0.6, 1.2),
        'unit': 'mg/dL',
        'critical': True,
        'description': 'Kidney function marker'
    },
    50902: {  # Chloride
        'name': 'Chloride',
        'normal_range': (98.0, 106.0),
        'unit': 'mEq/L',
        'critical': False,
        'description': 'Electrolyte balance'
    },
    50882: {  # Bicarbonate
        'name': 'Bicarbonate',
        'normal_range': (22.0, 29.0),
        'unit': 'mEq/L',
        'critical': False,
        'description': 'Acid-base balance'
    },
    50971: {  # Potassium
        'name': 'Potassium',
        'normal_range': (3.5, 5.0),
        'unit': 'mEq/L',
        'critical': True,
        'description': 'Critical electrolyte affecting heart rhythm'
    },
    50983: {  # Sodium
        'name': 'Sodium',
        'normal_range': (136.0, 145.0),
        'unit': 'mEq/L',
        'critical': True,
        'description': 'Primary extracellular electrolyte'
    },
    51006: {  # Urea Nitrogen (BUN)
        'name': 'Urea Nitrogen',
        'normal_range': (7.0, 20.0),
        'unit': 'mg/dL',
        'critical': False,
        'description': 'Protein metabolism marker; indicates kidney function'
    },
    50868: {  # Anion Gap
        'name': 'Anion Gap',
        'normal_range': (8.0, 16.0),
        'unit': 'mEq/L',
        'critical': False,
        'description': 'Calculated marker for acid-base balance'
    },
    51265: {  # Platelet Count
        'name': 'Platelet Count',
        'normal_range': (150.0, 400.0),
        'unit': 'K/uL',
        'critical': False,
        'description': 'Clotting ability indicator'
    },
    51221: {  # Hematocrit
        'name': 'Hematocrit',
        'normal_range': (36.0, 46.0),
        'unit': '%',
        'critical': False,
        'description': 'Proportion of red blood cells in blood'
    },
    51301: {  # WBC
        'name': 'WBC',
        'normal_range': (4.5, 11.0),
        'unit': 'K/uL',
        'critical': True,
        'description': 'White blood cell count — infection or inflammation marker'
    },
    51222: {  # Hemoglobin
        'name': 'Hemoglobin',
        'normal_range': (13.5, 17.5),
        'unit': 'g/dL',
        'critical': True,
        'description': 'Oxygen-carrying capacity of blood'
    }
}

print(f"✅ Loaded metadata for {len(LAB_ITEM_INFO)} lab tests")
print("\n📋 Lab Tests:")
for itemid, info in LAB_ITEM_INFO.items():
    print(f"  {itemid}: {info['name']} "
          f"(Normal: {info['normal_range'][0]}–{info['normal_range'][1]} {info['unit']}) "
          f"| Critical: {info['critical']}")


In [None]:
# SECTION 3: DATA LOADING FROM MIMIC-III
# PURPOSE: Load and perform initial filtering of lab events data
# ===============================================================================

# Path to MIMIC-III data (adjust based on your setup)
DATA_PATH = "/kaggle/input/mimiciii/mimic-iii-clinical-database-demo-1.4"

print("\n" + "="*80)
print("📂 LOADING MIMIC-III DATA")
print("="*80)

# Load lab events
print("Loading LABEVENTS.csv...")
labevents = pd.read_csv(f"{DATA_PATH}/LABEVENTS.csv")
print(f"✅ Loaded {len(labevents):,} raw lab measurements")

# Select relevant columns
labevents = labevents[['subject_id', 'hadm_id', 'itemid', 'charttime', 'valuenum', 'valueuom']]

# Basic cleaning
print("\n🧹 Cleaning data...")
initial_count = len(labevents)
labevents = labevents.dropna(subset=['valuenum'])
print(f"  • Removed {initial_count - len(labevents):,} rows with missing values")

labevents['charttime'] = pd.to_datetime(labevents['charttime'], errors='coerce')
labevents = labevents.dropna(subset=['charttime'])

# Filter to top 10 most common lab items
print("\n🔬 Filtering to top 10 lab items...")
top_items = labevents['itemid'].value_counts().head(10).index.tolist()
labevents_filtered = labevents[labevents['itemid'].isin(top_items)]

print(f"✅ Filtered to {len(labevents_filtered):,} measurements")
print(f"📊 Top 10 Lab Items: {top_items}")

# Display summary statistics
print("\n" + "="*80)
print("📊 DATA SUMMARY")
print("="*80)
print(labevents_filtered.describe())

In [5]:
# SECTION 4: ADVANCED DATA PREPROCESSING
# PURPOSE: Transform raw data into structured episodes for RL
# CRITICAL FIX: Create multi-dimensional states (NOT averaging different labs!)
# ===============================================================================

class DataPreprocessor:
    """
    Advanced data preprocessing for medical RL
    
    KEY IMPROVEMENT: Maintains separate values for each lab test type
    (Original version incorrectly averaged glucose, WBC, creatinine together!)
    """
    
    def __init__(self, lab_item_info: Dict, config: Config):
        self.lab_item_info = lab_item_info
        self.config = config
        
    def create_episodes(self, labevents_df: pd.DataFrame) -> List[Dict]:
        """
        IMPROVED: Better handling of missing data and outliers
        """
        
        episodes = []
        grouped = labevents_df.groupby(['subject_id', 'hadm_id'])
        
        print(f"\n🔄 Processing {len(grouped)} patient admissions...")
        
        for (subject_id, hadm_id), group in grouped:
            # Pivot table
            pivot = group.pivot_table(
                index='charttime',
                columns='itemid',
                values='valuenum',
                aggfunc='mean'
            )
            
            # Sort by time
            pivot = pivot.sort_index()
            
            # IMPROVED: Remove outliers before forward-fill
            for col in pivot.columns:
                if col in self.lab_item_info:
                    normal_range = self.lab_item_info[col]['normal_range']
                    # Remove values > 10x normal range (likely errors)
                    upper_limit = normal_range[1] * 10
                    lower_limit = max(0, normal_range[0] * 0.1)
                    pivot[col] = pivot[col].clip(lower=lower_limit, upper=upper_limit)
            
            # Forward-fill with limit
            pivot = pivot.fillna(method='ffill', limit=3)  # Max 3 consecutive fills
            
            # Backward-fill remaining
            pivot = pivot.fillna(method='bfill', limit=1)
            
            # Drop rows still with NaN
            pivot = pivot.dropna()
            
            # IMPROVED: More lenient length requirement
            if len(pivot) >= self.config.min_episode_length:
                episode = {
                    'subject_id': int(subject_id),
                    'hadm_id': int(hadm_id),
                    'timestamps': pivot.index.tolist(),
                    'lab_values': pivot.values.astype(np.float32),
                    'lab_items': pivot.columns.tolist(),
                    'length': len(pivot),
                    'duration_hours': (pivot.index[-1] - pivot.index[0]).total_seconds() / 3600
                }
                
                # Calculate statistics
                episode['mean_values'] = pivot.mean().to_dict()
                episode['std_values'] = pivot.std().to_dict()
                
                # Calculate % in normal range
                normal_count = 0
                total_count = 0
                for itemid in pivot.columns:
                    if itemid in self.lab_item_info:
                        normal_range = self.lab_item_info[itemid]['normal_range']
                        values = pivot[itemid].values
                        in_range = ((values >= normal_range[0]) & 
                                  (values <= normal_range[1])).sum()
                        normal_count += in_range
                        total_count += len(values)
                
                episode['pct_in_normal_range'] = normal_count / total_count if total_count > 0 else 0
                
                episodes.append(episode)
        
        print(f"✅ Created {len(episodes)} valid episodes")
        return episodes
    
    def split_episodes(self, episodes: List[Dict]) -> Tuple[List, List, List]:
        """
        Split episodes into train/val/test sets
        
        Args:
            episodes: List of episode dictionaries
            
        Returns:
            Tuple of (train_episodes, val_episodes, test_episodes)
        """
        
        n = len(episodes)
        indices = np.arange(n)
        np.random.shuffle(indices)
        
        train_end = int(self.config.train_split * n)
        val_end = int((self.config.train_split + self.config.val_split) * n)
        
        train_eps = [episodes[i] for i in indices[:train_end]]
        val_eps = [episodes[i] for i in indices[train_end:val_end]]
        test_eps = [episodes[i] for i in indices[val_end:]]
        
        print(f"\\n📊 Split Summary:")
        print(f"  • Training: {len(train_eps)} episodes")
        print(f"  • Validation: {len(val_eps)} episodes")
        print(f"  • Test: {len(test_eps)} episodes")
        
        return train_eps, val_eps, test_eps
    
    def generate_statistics(self, episodes: List[Dict]) -> Dict:
        """Generate comprehensive dataset statistics"""
        
        stats = {}
        
        # Episode statistics
        lengths = [ep['length'] for ep in episodes]
        durations = [ep['duration_hours'] for ep in episodes]
        normal_pcts = [ep['pct_in_normal_range'] for ep in episodes]
        
        stats['episodes'] = {
            'count': len(episodes),
            'length_mean': float(np.mean(lengths)),
            'length_std': float(np.std(lengths)),
            'length_min': int(np.min(lengths)),
            'length_max': int(np.max(lengths)),
            'duration_mean': float(np.mean(durations)),
            'duration_std': float(np.std(durations)),
            'normal_pct_mean': float(np.mean(normal_pcts)),
            'normal_pct_std': float(np.std(normal_pcts))
        }
        
        return stats


In [None]:
# EXECUTE PREPROCESSING
# ===============================================================================

preprocessor = DataPreprocessor(LAB_ITEM_INFO, config)

# Create episodes
episodes = preprocessor.create_episodes(labevents_filtered)

# Split data
train_episodes, val_episodes, test_episodes = preprocessor.split_episodes(episodes)

# Generate statistics
stats = preprocessor.generate_statistics(episodes)

print("\\n" + "="*80)
print("📈 DATASET STATISTICS")
print("="*80)
print(f"Episode count: {stats['episodes']['count']}")
print(f"Average length: {stats['episodes']['length_mean']:.1f} ± {stats['episodes']['length_std']:.1f}")
print(f"Length range: {stats['episodes']['length_min']} - {stats['episodes']['length_max']}")
print(f"Average duration: {stats['episodes']['duration_mean']:.1f} ± {stats['episodes']['duration_std']:.1f} hours")
print(f"% in normal range: {stats['episodes']['normal_pct_mean']:.1%} ± {stats['episodes']['normal_pct_std']:.1%}")

In [None]:
# SECTION 5: REINFORCEMENT LEARNING ENVIRONMENT
# PURPOSE: Define the MDP (Markov Decision Process) for medical treatment
# KEY IMPROVEMENT: Multi-dimensional state space with temporal features
# ===============================================================================

class MedicalTreatmentEnv(gym.Env):
    """
    Advanced Medical Treatment Environment
    
    STATE SPACE (26 dimensions):
    - Lab values (10): Current measurements for each lab test
    - Normalized values (10): Z-scores relative to normal ranges
    - Short-term trends (5): Recent changes in key labs
    - Time (1): Normalized time in episode
    
    ACTION SPACE (4 discrete actions):
    - 0: No intervention (watchful waiting)
    - 1: Medication A (e.g., insulin for glucose control)
    - 2: Medication B (e.g., diuretic for fluid balance)
    - 3: Medication C (broad-spectrum stabilization)
    
    REWARD STRUCTURE:
    - Clinical reward: Based on # of abnormal labs and severity
    - Treatment penalty: Small cost for interventions (avoid over-treatment)
    - Stability reward: Bonus for maintaining stable values
    
    TRANSITION DYNAMICS:
    - Realistic treatment effects (lab-specific responses)
    - Stochastic outcomes (variability in patient response)
    - Delayed effects (treatments take time to work)
    """
    
    def __init__(self, episodes_list: List[Dict], lab_item_info: Dict):
        super().__init__()
        
        self.episodes_list = episodes_list
        self.lab_item_info = lab_item_info
        
        # State space: [10 labs + 10 normalized + 5 trends + 1 time] = 26 features
        self.observation_space = spaces.Box(
            low=-10.0,
            high=10.0,
            shape=(26,),
            dtype=np.float32
        )
        
        # Action space: 4 discrete treatment options
        self.action_space = spaces.Discrete(4)
        
        # Episode state
        self.current_episode = None
        self.current_step = 0
        self.state_history = []
        self.history_window = 5
        
    def reset(self):
        """Reset environment for new episode"""
        
        # Randomly select an episode
        self.current_episode = random.choice(self.episodes_list)
        self.current_step = 0
        self.state_history = []
        
        # Get initial observation
        obs = self._get_observation()
        return obs
    
    def _get_observation(self) -> np.ndarray:
        """
        Construct rich observation from current state
        
        Returns:
            26-dimensional observation vector
        """
        
        # Get current lab values
        current_labs = self.current_episode['lab_values'][self.current_step]
        lab_items = self.current_episode['lab_items']
        
        # Ensure we have 10 values (pad if necessary)
        if len(current_labs) < 10:
            current_labs = np.pad(current_labs, (0, 10 - len(current_labs)), 'constant')
            lab_items = lab_items + [0] * (10 - len(lab_items))
        
        # 1. Raw lab values (first 10)
        raw_values = current_labs[:10]
        
        # 2. Normalized values (z-scores)
        normalized = np.zeros(10)
        for i in range(min(10, len(lab_items))):
            itemid = lab_items[i]
            if itemid in self.lab_item_info:
                normal_range = self.lab_item_info[itemid]['normal_range']
                mean_normal = (normal_range[0] + normal_range[1]) / 2
                std_normal = (normal_range[1] - normal_range[0]) / 4
                normalized[i] = (raw_values[i] - mean_normal) / (std_normal + 1e-6)
            else:
                normalized[i] = 0.0
        
        # 3. Trend indicators (first 5 labs only)
        trends = np.zeros(5)
        if len(self.state_history) > 0:
            prev_labs = self.state_history[-1][:10]
            for i in range(5):
                if prev_labs[i] != 0:
                    trends[i] = (raw_values[i] - prev_labs[i]) / (abs(prev_labs[i]) + 1e-6)
        
        # 4. Time feature
        time_norm = self.current_step / max(len(self.current_episode['lab_values']) - 1, 1)
        
        # Combine all features
        obs = np.concatenate([
            raw_values,      # 10 features
            normalized,      # 10 features
            trends,          # 5 features
            [time_norm]      # 1 feature
        ]).astype(np.float32)
        
        # Handle edge cases
        obs = np.nan_to_num(obs, nan=0.0, posinf=10.0, neginf=-10.0)
        obs = np.clip(obs, -10.0, 10.0)
        
        return obs
    
    def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
        """
        Execute action and transition to next state
        
        Args:
            action: Treatment decision (0-3)
            
        Returns:
            Tuple of (observation, reward, done, info)
        """
        
        # Get current state
        current_labs = self.current_episode['lab_values'][self.current_step].copy()
        lab_items = self.current_episode['lab_items']
        
        # Store in history
        self.state_history.append(current_labs.copy())
        if len(self.state_history) > self.history_window:
            self.state_history.pop(0)
        
        # Apply treatment effect
        if action > 0:
            current_labs = self._apply_treatment(current_labs, action, lab_items)
        
        # Calculate reward
        reward = self._calculate_reward(current_labs, lab_items, action)
        
        # Move to next step
        self.current_step += 1
        done = self.current_step >= len(self.current_episode['lab_values']) - 1
        if done:
                final_abnormal = self._count_abnormal(current_labs, lab_items)
                
                # Big bonus for ending with few abnormal labs
                if final_abnormal == 0:
                    reward += 20.0  # Perfect outcome!
                elif final_abnormal == 1:
                    reward += 15.0  # Excellent
                elif final_abnormal == 2:
                    reward += 10.0  # Very good
                elif final_abnormal == 3:
                    reward += 5.0   # Good
                elif final_abnormal == 4:
                    reward += 2.0   # Acceptable
                # No bonus for 5+ abnormal

    
        # Get next observation
        if not done:
            obs = self._get_observation()
        else:
            obs = np.zeros(26, dtype=np.float32)
        
        # Info for logging
        info = {
            'abnormal_count': self._count_abnormal(current_labs, lab_items),
            'episode_length': len(self.current_episode['lab_values'])
        }
        
        return obs, reward, done, info
    
    def _apply_treatment(self, labs: np.ndarray, action: int, lab_items: List) -> np.ndarray:
        """
        Apply treatment effects to lab values
        
        Simulates realistic pharmacological responses
        """
        
        labs = labs.copy()
        
        # Treatment effect parameters
        if action == 1:  # Medication A (e.g., insulin)
            target_items = [50809]  # Glucose
            effect_strength = 0.15
            variability = 0.05
        elif action == 2:  # Medication B (e.g., diuretic)
            target_items = [50971, 50983]  # Potassium, Sodium
            effect_strength = 0.10
            variability = 0.08
        else:  # action == 3 - Broad treatment
            target_items = lab_items
            effect_strength = 0.05
            variability = 0.10
        
        # Apply effects
        for i, itemid in enumerate(lab_items):
            if i < len(labs) and itemid in target_items:
                if itemid in self.lab_item_info:
                    # Target = center of normal range
                    normal_range = self.lab_item_info[itemid]['normal_range']
                    target = (normal_range[0] + normal_range[1]) / 2
                    
                    # Move toward target with noise
                    delta = effect_strength * (target - labs[i])
                    noise = np.random.normal(0, variability * abs(labs[i] + 1e-6))
                    labs[i] += delta + noise
                    
                    # Ensure positive values
                    labs[i] = max(0.1, labs[i])
        
        return labs
    
    def _calculate_reward(self, labs, lab_items, action):
        """
        SIMPLIFIED & CONSISTENT reward function
        
        Key principles:
        1. Clear positive/negative signals
        2. Scale matches evaluation expectations
        3. No complex shaping that breaks transfer
        """
        
        abnormal_count = 0
        normal_count = 0
        improvement_bonus = 0.0
        
        for i, itemid in enumerate(lab_items):
            if i < len(labs) and itemid in self.lab_item_info:
                normal_range = self.lab_item_info[itemid]['normal_range']
                val = labs[i]
                
                # Count normal vs abnormal
                if normal_range[0] <= val <= normal_range[1]:
                    normal_count += 1
                else:
                    abnormal_count += 1
                    
                    # IMPROVEMENT TRACKING
                    if len(self.state_history) > 0 and i < len(self.state_history[-1]):
                        prev_val = self.state_history[-1][i]
                        
                        # Was it abnormal before?
                        was_abnormal = (prev_val < normal_range[0] or 
                                      prev_val > normal_range[1])
                        
                        if was_abnormal:
                            # Calculate improvement
                            prev_dist = min(abs(prev_val - normal_range[0]), 
                                          abs(prev_val - normal_range[1]))
                            curr_dist = min(abs(val - normal_range[0]), 
                                          abs(val - normal_range[1]))
                            
                            if curr_dist < prev_dist:
                                improvement_bonus += 2.0  # Reward improvement
        
        # SIMPLE, CLEAR REWARD COMPONENTS
        # Base reward: Strong positive for normal, moderate negative for abnormal
        base_reward = (normal_count * 3.0) - (abnormal_count * 1.0)
        
        # Treatment cost: Discourage unnecessary treatment
        if action > 0:
            # Only penalize if patient is mostly healthy
            if abnormal_count <= 2:
                treatment_cost = -2.0  # Don't treat healthy patients!
            else:
                treatment_cost = -0.3  # Small cost for necessary treatment
        else:
            # Reward for not treating when appropriate
            treatment_cost = 0.5 if abnormal_count <= 2 else 0.0
        
        # Bonus for improvements
        reward_bonus = improvement_bonus
        
        # TOTAL REWARD
        total_reward = base_reward + treatment_cost + reward_bonus
        
        return float(total_reward)

        
    def _count_abnormal(self, labs: np.ndarray, lab_items: List) -> int:
        """Count number of abnormal lab values"""
        
        count = 0
        for i, itemid in enumerate(lab_items):
            if i < len(labs) and itemid in self.lab_item_info:
                normal_range = self.lab_item_info[itemid]['normal_range']
                if labs[i] < normal_range[0] or labs[i] > normal_range[1]:
                    count += 1
        return count

print("✅ Medical Treatment Environment defined")
print(f"   • State space: {26} dimensions")
print(f"   • Action space: {4} discrete actions")
print(f"   • Reward: Multi-component (clinical + stability - cost)")


In [9]:
# ==============================================================================
# SECTION 6: REINFORCEMENT LEARNING TRAINING
# PURPOSE: Train RL agents with consistent reward structure
# ==============================================================================

import os
import json
from stable_baselines3 import PPO, A2C, DQN
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import EvalCallback, BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy

print("\n" + "="*80)
print("🚀 SECTION 6: REINFORCEMENT LEARNING TRAINING")
print("="*80)

# -----------------------------------------------------------------------------
# 6.1: Create Environments
# -----------------------------------------------------------------------------

print("\n📦 Creating training environments...")

train_env = DummyVecEnv([lambda: MedicalTreatmentEnv(train_episodes, LAB_ITEM_INFO)])
val_env = DummyVecEnv([lambda: MedicalTreatmentEnv(val_episodes, LAB_ITEM_INFO)])
test_env = MedicalTreatmentEnv(test_episodes, LAB_ITEM_INFO)

print(f"✅ Environments created:")
print(f"   • Training episodes: {len(train_episodes)}")
print(f"   • Validation episodes: {len(val_episodes)}")
print(f"   • Test episodes: {len(test_episodes)}")

# -----------------------------------------------------------------------------
# 6.2: Enhanced Callbacks
# -----------------------------------------------------------------------------

class ProgressCallback(BaseCallback):
    """Track training progress with detailed metrics"""
    
    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.episode_rewards = []
        self.episode_lengths = []
        
    def _on_step(self) -> bool:
        if len(self.model.ep_info_buffer) > 0:
            for info in self.model.ep_info_buffer:
                self.episode_rewards.append(info['r'])
                self.episode_lengths.append(info['l'])
        return True

class EarlyStoppingCallback(EvalCallback):
    """Stop training if no improvement for patience evaluations"""
    
    def __init__(self, eval_env, patience=5, min_delta=1.0, **kwargs):
        super().__init__(eval_env, **kwargs)
        self.patience = patience
        self.min_delta = min_delta
        self.best_mean = -np.inf
        self.wait_count = 0
        
    def _on_step(self) -> bool:
        result = super()._on_step()
        
        if len(self.evaluations_results) > 0:
            current_mean = np.mean(self.evaluations_results[-1])
            
            if current_mean > self.best_mean + self.min_delta:
                self.best_mean = current_mean
                self.wait_count = 0
                print(f"✨ New best reward: {current_mean:.2f}")
            else:
                self.wait_count += 1
                
            if self.wait_count >= self.patience:
                print(f"⏹️  Early stopping triggered (no improvement for {self.patience} evals)")
                return False
                
        return result

# -----------------------------------------------------------------------------
# 6.3: Training Configuration (✅ Fixed argument names)
# -----------------------------------------------------------------------------

TRAINING_CONFIG = {
    'PPO': {
        'timesteps': 150000,
        'learning_rate': 3e-4,       # ✅ renamed from 'lr'
        'gamma': 0.99,
        'n_steps': 2048,
        'batch_size': 64,
        'n_epochs': 10,
        'gae_lambda': 0.95,
        'clip_range': 0.2,
        'ent_coef': 0.01,
        'vf_coef': 0.5,
        'max_grad_norm': 0.5
    },
    'A2C': {
        'timesteps': 120000,
        'learning_rate': 2e-4,       # ✅ renamed from 'lr'
        'gamma': 0.99,
        'n_steps': 5,
        'gae_lambda': 0.95,
        'ent_coef': 0.01,
        'vf_coef': 0.5,
        'max_grad_norm': 0.5
    },
    'DQN': {
        'timesteps': 150000,
        'learning_rate': 1e-4,       # ✅ renamed from 'lr'
        'gamma': 0.99,
        'buffer_size': 100000,
        'learning_starts': 1000,
        'batch_size': 64,
        'tau': 0.005,
        'train_freq': 4,
        'gradient_steps': 1,
        'target_update_interval': 1000,
        'exploration_fraction': 0.3,
        'exploration_initial_eps': 1.0,
        'exploration_final_eps': 0.05
    }
}


# -----------------------------------------------------------------------------
# 6.4: Universal Training Function
# -----------------------------------------------------------------------------

def train_rl_agent(algo_class, algo_name, train_env, val_env, config, seed=42):
    """
    Universal training function for any SB3 algorithm
    
    Args:
        algo_class: PPO, A2C, or DQN class
        algo_name: String identifier
        train_env: Training environment
        val_env: Validation environment
        config: Hyperparameter dictionary
        seed: Random seed
    
    Returns:
        Trained model, progress callback
    """
    
    print(f"\n{'='*80}")
    print(f"🎯 TRAINING {algo_name}")
    print(f"{'='*80}")
    
    # Create save directory
    save_dir = f"/kaggle/working/models/{algo_name}_best"
    os.makedirs(save_dir, exist_ok=True)
    
    # Extract hyperparameters
    timesteps = config.pop('timesteps')
    
    # Setup callbacks
    eval_callback = EarlyStoppingCallback(
        val_env,
        eval_freq=5000,
        n_eval_episodes=10,
        best_model_save_path=save_dir,
        log_path=save_dir,
        patience=5,
        min_delta=1.0,
        deterministic=True,
        render=False,
        verbose=1
    )
    
    progress_callback = ProgressCallback(verbose=0)
    
    # Initialize model
    print(f"📊 Hyperparameters: {config}")
    
    model = algo_class(
        "MlpPolicy",
        train_env,
        verbose=1,
        seed=seed,
        tensorboard_log=f"/kaggle/working/tb/{algo_name}",
        **config
    )
    
    # Train
    print(f"🏃 Training for {timesteps:,} timesteps...")
    model.learn(
        total_timesteps=timesteps,
        callback=[eval_callback, progress_callback]
    )
    
    # Save final model
    model.save(f"/kaggle/working/models/{algo_name}_final.zip")
    print(f"✅ {algo_name} training complete!")
    print(f"   📁 Saved to: /kaggle/working/models/{algo_name}_final.zip")
    
    return model, progress_callback

# -----------------------------------------------------------------------------
# 6.5: Train All Algorithms
# -----------------------------------------------------------------------------

print("\n" + "="*80)
print("🤖 TRAINING ALL ALGORITHMS")
print("="*80)

models = {}
callbacks = {}
SEED = int(config.seed)

# Train PPO
models['PPO'], callbacks['PPO'] = train_rl_agent(
    PPO, 'PPO', 
    train_env, val_env, 
    TRAINING_CONFIG['PPO'].copy(),
    seed=SEED
)

# Train A2C
models['A2C'], callbacks['A2C'] = train_rl_agent(
    A2C, 'A2C',
    train_env, val_env,
    TRAINING_CONFIG['A2C'].copy(),
    seed=SEED
)

# Train DQN
models['DQN'], callbacks['DQN'] = train_rl_agent(
    DQN, 'DQN',
    train_env, val_env,
    TRAINING_CONFIG['DQN'].copy(),
    seed=SEED
)

print("\n" + "="*80)
print("✅ ALL MODELS TRAINED SUCCESSFULLY!")
print("="*80)


🚀 SECTION 6: REINFORCEMENT LEARNING TRAINING

📦 Creating training environments...
✅ Environments created:
   • Training episodes: 84
   • Validation episodes: 18
   • Test episodes: 18

🤖 TRAINING ALL ALGORITHMS

🎯 TRAINING PPO
📊 Hyperparameters: {'learning_rate': 0.0003, 'gamma': 0.99, 'n_steps': 2048, 'batch_size': 64, 'n_epochs': 10, 'gae_lambda': 0.95, 'clip_range': 0.2, 'ent_coef': 0.01, 'vf_coef': 0.5, 'max_grad_norm': 0.5}
Using cuda device
🏃 Training for 150,000 timesteps...
Logging to /kaggle/working/tb/PPO/PPO_1
-----------------------------
| time/              |      |
|    fps             | 531  |
|    iterations      | 1    |
|    time_elapsed    | 3    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 463         |
|    iterations           | 2           |
|    time_elapsed         | 8           |
|    total_timesteps      | 4096        |
| train/ 

In [10]:
# ==============================================================================
# SECTION 7: COMPREHENSIVE MODEL EVALUATION
# PURPOSE: Evaluate all trained models on test set with clinical metrics
# ==============================================================================

print("\n" + "="*80)
print("📊 SECTION 7: MODEL EVALUATION & COMPARISON")
print("="*80)

# -----------------------------------------------------------------------------
# 7.1: Evaluation Functions
# -----------------------------------------------------------------------------

def evaluate_model_detailed(model, env, n_episodes, model_name):
    """
    Comprehensive model evaluation with clinical metrics
    
    Returns:
        Dictionary with rewards, abnormal counts, actions, appropriateness
    """
    
    all_rewards = []
    all_episode_rewards = []
    all_actions = []
    all_abnormal_counts = []
    all_episode_lengths = []
    
    for ep_idx in range(n_episodes):
        obs = env.reset()
        done = False
        episode_reward = 0
        episode_actions = []
        episode_abnormal = []
        steps = 0
        
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done, info = env.step(action)
            
            all_rewards.append(reward)
            episode_reward += reward
            episode_actions.append(int(action))
            episode_abnormal.append(info.get('abnormal_count', 0))
            steps += 1
        
        all_episode_rewards.append(episode_reward)
        all_actions.extend(episode_actions)
        all_abnormal_counts.extend(episode_abnormal)
        all_episode_lengths.append(steps)
    
    # Calculate appropriateness
    appropriate_actions = 0
    total_actions = len(all_actions)
    
    for action, abnormal in zip(all_actions, all_abnormal_counts):
        # Appropriate if: treating when sick (≥3) OR not treating when healthy (<2)
        if (action > 0 and abnormal >= 3) or (action == 0 and abnormal < 2):
            appropriate_actions += 1
    
    appropriateness = appropriate_actions / total_actions if total_actions > 0 else 0
    
    return {
        'model_name': model_name,
        'mean_reward': np.mean(all_episode_rewards),
        'std_reward': np.std(all_episode_rewards),
        'mean_step_reward': np.mean(all_rewards),
        'mean_abnormal': np.mean(all_abnormal_counts),
        'std_abnormal': np.std(all_abnormal_counts),
        'treatment_rate': np.mean(np.array(all_actions) > 0),
        'action_distribution': np.bincount(all_actions, minlength=4),
        'appropriateness': appropriateness,
        'mean_episode_length': np.mean(all_episode_lengths),
        'all_episode_rewards': all_episode_rewards,
        'all_actions': all_actions,
        'all_abnormal_counts': all_abnormal_counts
    }

def evaluate_baseline(env, n_episodes, policy_name, policy_fn):
    """Evaluate a baseline policy"""
    
    episode_rewards = []
    
    for _ in range(n_episodes):
        obs = env.reset()
        done = False
        episode_reward = 0
        step = 0
        
        while not done:
            action = policy_fn(obs, step)
            obs, reward, done, info = env.step(action)
            episode_reward += reward
            step += 1
        
        episode_rewards.append(episode_reward)
    
    return {
        'policy_name': policy_name,
        'mean_reward': np.mean(episode_rewards),
        'std_reward': np.std(episode_rewards)
    }

# -----------------------------------------------------------------------------
# 7.2: Evaluate Trained Models
# -----------------------------------------------------------------------------

print("\n📈 Evaluating trained models on test set...")

N_TEST_EPISODES = min(50, len(test_episodes))
results = {}

for name, model in models.items():
    print(f"\n🔍 Evaluating {name}...")
    results[name] = evaluate_model_detailed(
        model, test_env, N_TEST_EPISODES, name
    )
    
    print(f"   Mean Episode Reward: {results[name]['mean_reward']:.2f} ± {results[name]['std_reward']:.2f}")
    print(f"   Mean Abnormal Labs: {results[name]['mean_abnormal']:.2f}")
    print(f"   Treatment Rate: {results[name]['treatment_rate']:.1%}")
    print(f"   Clinical Appropriateness: {results[name]['appropriateness']:.1%}")

# -----------------------------------------------------------------------------
# 7.3: Evaluate Baselines
# -----------------------------------------------------------------------------

print("\n" + "="*80)
print("📊 BASELINE COMPARISON")
print("="*80)

baselines = {}

# Random policy
baselines['Random'] = evaluate_baseline(
    test_env, N_TEST_EPISODES, 'Random',
    lambda obs, step: test_env.action_space.sample()
)

# Always treat (action 1)
baselines['Always_Treat'] = evaluate_baseline(
    test_env, N_TEST_EPISODES, 'Always Treat',
    lambda obs, step: 1
)

# Never treat (action 0)
baselines['Never_Treat'] = evaluate_baseline(
    test_env, N_TEST_EPISODES, 'Never Treat',
    lambda obs, step: 0
)

# Smart baseline: treat if many abnormal (estimated from normalized values)
def smart_policy(obs, step):
    normalized_vals = obs[10:20]  # Normalized lab values
    abnormal_est = np.sum(np.abs(normalized_vals) > 1.0)
    return 1 if abnormal_est >= 3 else 0

baselines['Smart_Baseline'] = evaluate_baseline(
    test_env, N_TEST_EPISODES, 'Smart Baseline',
    smart_policy
)

print("\nBaseline Results:")
for name, result in baselines.items():
    print(f"  {name:20s}: {result['mean_reward']:7.2f} ± {result['std_reward']:.2f}")

print("\nTrained Models:")
for name in models.keys():
    print(f"  {name:20s}: {results[name]['mean_reward']:7.2f} ± {results[name]['std_reward']:.2f}")

# -----------------------------------------------------------------------------
# 7.4: Find Best Model
# -----------------------------------------------------------------------------

best_model_name = max(results.keys(), key=lambda k: results[k]['mean_reward'])
best_model = models[best_model_name]

print("\n" + "="*80)
print(f"🏆 BEST MODEL: {best_model_name}")
print("="*80)
print(f"Mean Episode Reward: {results[best_model_name]['mean_reward']:.2f}")
print(f"Mean Abnormal Labs: {results[best_model_name]['mean_abnormal']:.2f}")
print(f"Treatment Rate: {results[best_model_name]['treatment_rate']:.1%}")
print(f"Clinical Appropriateness: {results[best_model_name]['appropriateness']:.1%}")

# Calculate improvements
best_reward = results[best_model_name]['mean_reward']
print(f"\nImprovement over baselines:")
print(f"  vs Random: {best_reward - baselines['Random']['mean_reward']:+.2f}")
print(f"  vs Always Treat: {best_reward - baselines['Always_Treat']['mean_reward']:+.2f}")
print(f"  vs Never Treat: {best_reward - baselines['Never_Treat']['mean_reward']:+.2f}")
print(f"  vs Smart Baseline: {best_reward - baselines['Smart_Baseline']['mean_reward']:+.2f}")


📊 SECTION 7: MODEL EVALUATION & COMPARISON

📈 Evaluating trained models on test set...

🔍 Evaluating PPO...
   Mean Episode Reward: 128.46 ± 150.87
   Mean Abnormal Labs: 5.83
   Treatment Rate: 63.2%
   Clinical Appropriateness: 62.6%

🔍 Evaluating A2C...
   Mean Episode Reward: 244.44 ± 193.77
   Mean Abnormal Labs: 5.57
   Treatment Rate: 100.0%
   Clinical Appropriateness: 99.5%

🔍 Evaluating DQN...
   Mean Episode Reward: 133.09 ± 105.03
   Mean Abnormal Labs: 5.25
   Treatment Rate: 51.2%
   Clinical Appropriateness: 51.2%

📊 BASELINE COMPARISON

Baseline Results:
  Random              :  146.68 ± 135.91
  Always_Treat        :   99.77 ± 136.19
  Never_Treat         :  114.69 ± 127.58
  Smart_Baseline      :  168.23 ± 152.59

Trained Models:
  PPO                 :  128.46 ± 150.87
  A2C                 :  244.44 ± 193.77
  DQN                 :  133.09 ± 105.03

🏆 BEST MODEL: A2C
Mean Episode Reward: 244.44
Mean Abnormal Labs: 5.57
Treatment Rate: 100.0%
Clinical Appropriatenes

In [11]:
# ==============================================================================
# SECTION 8: CLINICAL FILTERING & SAFETY CONSTRAINTS
# PURPOSE: Add clinical safety rules to improve appropriateness
# ==============================================================================

print("\n" + "="*80)
print("🏥 SECTION 8: CLINICAL FILTERING & SAFETY")
print("="*80)

# -----------------------------------------------------------------------------
# 8.1: Clinical Filter Implementation
# -----------------------------------------------------------------------------

class ClinicalSafetyFilter:
    """
    Wrapper that enforces clinical safety constraints on RL policies
    
    Safety Rules:
    1. Don't treat stable patients (abnormal < 2)
    2. Always treat sick patients (abnormal ≥ 3)
    3. Trust RL for borderline cases (abnormal = 2)
    """
    
    def __init__(self, base_model, treat_threshold=3, stable_threshold=2):
        self.base_model = base_model
        self.treat_threshold = treat_threshold
        self.stable_threshold = stable_threshold
        
    def predict(self, obs, abnormal_count, deterministic=True):
        """Get clinically-filtered action"""
        
        # Get RL model's suggestion
        rl_action, _ = self.base_model.predict(obs, deterministic=deterministic)
        
        # Apply clinical safety rules
        if abnormal_count < self.stable_threshold:
            # Rule 1: Patient is stable → No treatment
            return 0, None
        
        elif abnormal_count >= self.treat_threshold:
            # Rule 2: Patient is sick → Ensure treatment
            if rl_action == 0:
                return 1, None  # Override to treat
            else:
                return rl_action, None  # Trust RL's treatment choice
        
        else:
            # Rule 3: Borderline case → Trust RL
            return rl_action, None

# -----------------------------------------------------------------------------
# 8.2: Evaluate Raw vs Filtered Policies
# -----------------------------------------------------------------------------

def compare_raw_vs_filtered(model, model_name, episodes, env_class, lab_info):
    """
    Compare performance of raw RL policy vs clinically-filtered policy
    """
    
    filtered_policy = ClinicalSafetyFilter(model)
    
    results = {
        'raw': {
            'episode_rewards': [],
            'treatment_rates': [],
            'appropriateness_scores': [],
            'abnormal_counts': []
        },
        'filtered': {
            'episode_rewards': [],
            'treatment_rates': [],
            'appropriateness_scores': [],
            'abnormal_counts': []
        }
    }
    
    for episode in episodes:
        env = env_class([episode], lab_info)
        
        # -------------------------
        # Evaluate RAW policy
        # -------------------------
        obs = env.reset()
        done = False
        episode_data = {
            'rewards': [],
            'actions': [],
            'abnormal': []
        }
        
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done, info = env.step(action)
            
            episode_data['rewards'].append(reward)
            episode_data['actions'].append(int(action))
            episode_data['abnormal'].append(info.get('abnormal_count', 0))
        
        # Calculate metrics
        total_reward = sum(episode_data['rewards'])
        treatment_rate = np.mean(np.array(episode_data['actions']) > 0)
        
        appropriate = sum(
            1 for action, abnormal in zip(episode_data['actions'], episode_data['abnormal'])
            if (action > 0 and abnormal >= 3) or (action == 0 and abnormal < 2)
        )
        appropriateness = appropriate / len(episode_data['actions'])
        
        results['raw']['episode_rewards'].append(total_reward)
        results['raw']['treatment_rates'].append(treatment_rate)
        results['raw']['appropriateness_scores'].append(appropriateness)
        results['raw']['abnormal_counts'].extend(episode_data['abnormal'])
        
        # -------------------------
        # Evaluate FILTERED policy
        # -------------------------
        obs = env.reset()
        done = False
        episode_data = {
            'rewards': [],
            'actions': [],
            'abnormal': []
        }
        
        info = {'abnormal_count': 3}  # Initialize
        
        while not done:
            abnormal_count = info.get('abnormal_count', 3)
            action, _ = filtered_policy.predict(obs, abnormal_count)
            obs, reward, done, info = env.step(action)
            
            episode_data['rewards'].append(reward)
            episode_data['actions'].append(int(action))
            episode_data['abnormal'].append(info.get('abnormal_count', 0))
        
        total_reward = sum(episode_data['rewards'])
        treatment_rate = np.mean(np.array(episode_data['actions']) > 0)
        
        appropriate = sum(
            1 for action, abnormal in zip(episode_data['actions'], episode_data['abnormal'])
            if (action > 0 and abnormal >= 3) or (action == 0 and abnormal < 2)
        )
        appropriateness = appropriate / len(episode_data['actions'])
        
        results['filtered']['episode_rewards'].append(total_reward)
        results['filtered']['treatment_rates'].append(treatment_rate)
        results['filtered']['appropriateness_scores'].append(appropriateness)
        results['filtered']['abnormal_counts'].extend(episode_data['abnormal'])
    
    # Aggregate results
    summary = {
        'raw': {
            'mean_reward': np.mean(results['raw']['episode_rewards']),
            'std_reward': np.std(results['raw']['episode_rewards']),
            'treatment_rate': np.mean(results['raw']['treatment_rates']),
            'appropriateness': np.mean(results['raw']['appropriateness_scores']),
            'mean_abnormal': np.mean(results['raw']['abnormal_counts'])
        },
        'filtered': {
            'mean_reward': np.mean(results['filtered']['episode_rewards']),
            'std_reward': np.std(results['filtered']['episode_rewards']),
            'treatment_rate': np.mean(results['filtered']['treatment_rates']),
            'appropriateness': np.mean(results['filtered']['appropriateness_scores']),
            'mean_abnormal': np.mean(results['filtered']['abnormal_counts'])
        }
    }
    
    return summary

# -----------------------------------------------------------------------------
# 8.3: Apply Filtering to All Models
# -----------------------------------------------------------------------------

print("\n🔬 Comparing raw vs filtered policies...")

filtering_results = {}

for name, model in models.items():
    print(f"\n{name}:")
    print("-" * 40)
    
    result = compare_raw_vs_filtered(
        model, name, 
        test_episodes, 
        MedicalTreatmentEnv, 
        LAB_ITEM_INFO
    )
    
    filtering_results[name] = result
    
    print(f"RAW POLICY:")
    print(f"  Reward: {result['raw']['mean_reward']:.2f} ± {result['raw']['std_reward']:.2f}")
    print(f"  Treatment Rate: {result['raw']['treatment_rate']:.1%}")
    print(f"  Appropriateness: {result['raw']['appropriateness']:.1%}")
    
    print(f"\nFILTERED POLICY:")
    print(f"  Reward: {result['filtered']['mean_reward']:.2f} ± {result['filtered']['std_reward']:.2f}")
    print(f"  Treatment Rate: {result['filtered']['treatment_rate']:.1%}")
    print(f"  Appropriateness: {result['filtered']['appropriateness']:.1%}")
    
    improvement = (result['filtered']['appropriateness'] - result['raw']['appropriateness']) * 100
    print(f"\n✨ IMPROVEMENT: +{improvement:.1f} percentage points")

# -----------------------------------------------------------------------------
# 8.4: Summary Table
# -----------------------------------------------------------------------------

print("\n" + "="*80)
print("📋 CLINICAL FILTERING SUMMARY")
print("="*80)

summary_data = []
for name in models.keys():
    raw = filtering_results[name]['raw']
    filtered = filtering_results[name]['filtered']
    improvement = (filtered['appropriateness'] - raw['appropriateness']) * 100
    
    summary_data.append({
        'Model': name,
        'Raw Appropriateness': f"{raw['appropriateness']:.1%}",
        'Filtered Appropriateness': f"{filtered['appropriateness']:.1%}",
        'Improvement': f"+{improvement:.1f}%"
    })

summary_df = pd.DataFrame(summary_data)
print(summary_df.to_string(index=False))


🏥 SECTION 8: CLINICAL FILTERING & SAFETY

🔬 Comparing raw vs filtered policies...

PPO:
----------------------------------------
RAW POLICY:
  Reward: 152.27 ± 161.79
  Treatment Rate: 56.8%
  Appropriateness: 55.0%

FILTERED POLICY:
  Reward: 148.94 ± 159.47
  Treatment Rate: 98.4%
  Appropriateness: 94.5%

✨ IMPROVEMENT: +39.5 percentage points

A2C:
----------------------------------------
RAW POLICY:
  Reward: 169.01 ± 144.80
  Treatment Rate: 100.0%
  Appropriateness: 99.4%

FILTERED POLICY:
  Reward: 177.02 ± 152.61
  Treatment Rate: 100.0%
  Appropriateness: 98.8%

✨ IMPROVEMENT: +-0.6 percentage points

DQN:
----------------------------------------
RAW POLICY:
  Reward: 176.11 ± 145.86
  Treatment Rate: 55.8%
  Appropriateness: 54.7%

FILTERED POLICY:
  Reward: 173.51 ± 140.40
  Treatment Rate: 98.0%
  Appropriateness: 95.5%

✨ IMPROVEMENT: +40.8 percentage points

📋 CLINICAL FILTERING SUMMARY
Model Raw Appropriateness Filtered Appropriateness Improvement
  PPO               5

In [None]:
# ==============================================================================
# SECTION 9: DETAILED MODEL ANALYSIS & REPORTING
# PURPOSE: Visualize best model behavior and export final results
# ==============================================================================

import matplotlib.pyplot as plt
import seaborn as sns
import json
import numpy as np
import pandas as pd
from datetime import datetime  # ✅ Fix for NameError

print("\n" + "="*80)
print("📈 SECTION 9: DETAILED MODEL ANALYSIS & REPORTING")
print("="*80)

# -----------------------------------------------------------------------------
# 9.1: Visualization for Best Model
# -----------------------------------------------------------------------------

print(f"\n🔍 Creating detailed analysis for {best_model_name}...")

bm = results[best_model_name]

fig, axes = plt.subplots(2, 2, figsize=(12, 8))
fig.suptitle(f"Detailed Analysis — {best_model_name}", fontsize=16, fontweight='bold')

# --- Subplot 1: Reward Distribution ---
ax = axes[0, 0]
reward_data = bm.get("rewards") or bm.get("all_episode_rewards") or [bm["mean_reward"]]
ax.hist(np.array(reward_data).flatten(), bins=20, color='#45B7D1', edgecolor='black', alpha=0.8)
ax.set_title('Reward Distribution', fontsize=13, fontweight='bold')
ax.set_xlabel('Episode Reward')
ax.set_ylabel('Frequency')
ax.grid(alpha=0.3, linestyle='--')

# --- Subplot 2: Abnormal Labs Over Time ---
ax = axes[0, 1]
episode_steps = list(range(len(bm.get("all_abnormal_counts", []))))
ax.plot(episode_steps, bm.get("all_abnormal_counts", []), color='#FF6B6B', linewidth=2)
ax.set_title('Abnormal Labs Over Time', fontsize=13, fontweight='bold')
ax.set_xlabel('Episode Step')
ax.set_ylabel('Abnormal Lab Count')
ax.grid(alpha=0.3, linestyle='--')

# --- Subplot 3: Action Distribution ---
ax = axes[1, 0]
actions = bm.get('all_actions', [])
if len(actions) > 0:
    sns.countplot(x=actions, palette='viridis', ax=ax)
    ax.set_title('Action Distribution', fontsize=13, fontweight='bold')
    ax.set_xlabel('Action')
    ax.set_ylabel('Frequency')
else:
    ax.text(0.5, 0.5, 'No Actions Recorded', ha='center', va='center')

# --- Subplot 4: Appropriateness Summary ---
ax = axes[1, 1]
bar_data = {
    'Metric': ['Appropriateness', 'Treatment Rate', 'Mean Abnormal'],
    'Value': [
        bm.get('appropriateness', 0) * 100,
        bm.get('treatment_rate', 0) * 100,
        bm.get('mean_abnormal', 0)
    ]
}
sns.barplot(x='Metric', y='Value', data=pd.DataFrame(bar_data), ax=ax, palette='coolwarm')
ax.set_title('Clinical Metrics Summary', fontsize=13, fontweight='bold')
ax.set_ylabel('Value (%)')
ax.grid(alpha=0.3, linestyle='--')

plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

# -----------------------------------------------------------------------------
# 9.2: Generate Final JSON Report (Safe Serialization)
# -----------------------------------------------------------------------------

print("\n📝 Generating final JSON report...")

def safe_convert(obj):
    """Recursively convert NumPy data to JSON-safe Python-native types."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, (np.float32, np.float64)):
        return float(obj)
    if isinstance(obj, (np.int32, np.int64)):
        return int(obj)
    if isinstance(obj, dict):
        return {k: safe_convert(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [safe_convert(v) for v in obj]
    return obj

final_report = {
    "best_model": best_model_name,
    "timestamp": str(datetime.now()),
    "summary": safe_convert({
        "mean_reward": bm["mean_reward"],
        "std_reward": bm["std_reward"],
        "treatment_rate": bm["treatment_rate"],
        "appropriateness": bm["appropriateness"],
        "mean_abnormal": bm["mean_abnormal"],
    }),
    "filtering_results": safe_convert(filtering_results.get(best_model_name, {})),
    "baselines": safe_convert({k: v["mean_reward"] for k, v in baselines.items()}),
}

output_path = "/kaggle/working/final_report.json"
with open(output_path, "w") as f:
    json.dump(safe_convert(final_report), f, indent=4)

print(f"✅ Final report saved to: {output_path}")

# ==============================================================================
# SECTION 9B: COMPREHENSIVE VISUALIZATION & REPORTING
# PURPOSE: Publication-quality visualizations and comparison plots
# ==============================================================================

print("\n" + "="*80)
print("📊 SECTION 9B: VISUALIZATION & FINAL REPORT")
print("="*80)

# -----------------------------------------------------------------------------
# 9B.1: Algorithm Performance Comparison
# -----------------------------------------------------------------------------

print("\n📈 Creating performance comparison plots...")

fig, axes = plt.subplots(2, 2, figsize=(16, 12))
fig.suptitle('Reinforcement Learning Model Performance Comparison', 
             fontsize=16, fontweight='bold', y=0.995)

model_names = list(models.keys())
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']

# Plot 1: Episode Rewards
ax = axes[0, 0]
rewards = [results[name]['mean_reward'] for name in model_names]
errors = [results[name]['std_reward'] for name in model_names]
bars = ax.bar(model_names, rewards, yerr=errors, capsize=8, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
for i, (bar, val) in enumerate(zip(bars, rewards)):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height(), f'{val:.1f}', ha='center', va='bottom', fontweight='bold', fontsize=11)
ax.set_title('Mean Episode Reward', fontsize=13, fontweight='bold')
ax.set_ylabel('Reward')
ax.grid(axis='y', alpha=0.3, linestyle='--')
ax.set_axisbelow(True)

# Plot 2: Mean Abnormal Labs
ax = axes[0, 1]
abnormal = [results[name]['mean_abnormal'] for name in model_names]
bars = ax.bar(model_names, abnormal, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
for bar, val in zip(bars, abnormal):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height(), f'{val:.2f}', ha='center', va='bottom', fontweight='bold', fontsize=11)
ax.axhline(y=5, color='red', linestyle='--', linewidth=2, alpha=0.6, label='Target (≤5)')
ax.set_title('Mean Abnormal Lab Count', fontsize=13, fontweight='bold')
ax.set_ylabel('Count')
ax.legend()
ax.grid(axis='y', alpha=0.3, linestyle='--')
ax.set_axisbelow(True)

# Plot 3: Clinical Appropriateness (Raw vs Filtered)
ax = axes[1, 0]
x = np.arange(len(model_names))
width = 0.35
raw_app = [filtering_results[name]['raw']['appropriateness'] * 100 for name in model_names]
filtered_app = [filtering_results[name]['filtered']['appropriateness'] * 100 for name in model_names]
bars1 = ax.bar(x - width/2, raw_app, width, label='Raw', color='#FFB6B9', edgecolor='black')
bars2 = ax.bar(x + width/2, filtered_app, width, label='Filtered', color='#A8E6CF', edgecolor='black')
for bars in [bars1, bars2]:
    for bar in bars:
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height(), f'{bar.get_height():.1f}%', ha='center', va='bottom', fontsize=9, fontweight='bold')
ax.set_title('Clinical Appropriateness: Raw vs Filtered', fontsize=13, fontweight='bold')
ax.set_ylabel('Appropriateness (%)')
ax.set_xticks(x)
ax.set_xticklabels(model_names)
ax.legend()
ax.grid(axis='y', alpha=0.3, linestyle='--')

# Plot 4: Action Distribution
ax = axes[1, 1]
action_labels = ['No Action', 'Med A', 'Med B', 'Med C']
x = np.arange(len(action_labels))
width = 0.25
for i, name in enumerate(model_names):
    action_dist = results[name]['action_distribution']
    total = action_dist.sum() if np.sum(action_dist) != 0 else 1
    ax.bar(x + i*width, (action_dist / total) * 100, width, label=name, color=colors[i], alpha=0.8, edgecolor='black')
ax.set_title('Action Distribution', fontsize=13, fontweight='bold')
ax.set_ylabel('Percentage (%)')
ax.set_xticks(x + width)
ax.set_xticklabels(action_labels)
ax.legend()
ax.grid(axis='y', alpha=0.3, linestyle='--')

plt.tight_layout()
plt.savefig('/kaggle/working/model_comparison.png', dpi=300, bbox_inches='tight')
print("✅ Saved: model_comparison.png")
plt.show()

# -----------------------------------------------------------------------------
# 9B.2: Final Comprehensive Report (JSON Safe)
# -----------------------------------------------------------------------------

print("\n📝 Generating comprehensive final report...")

final_report_full = {
    "best_model": best_model_name,
    "timestamp": str(datetime.now()),
    "summary": safe_convert({
        "mean_reward": bm["mean_reward"],
        "std_reward": bm["std_reward"],
        "mean_abnormal": bm["mean_abnormal"],
        "treatment_rate": bm["treatment_rate"],
        "appropriateness": bm["appropriateness"],
    }),
    "comparisons": safe_convert(results),
    "filtering_results": safe_convert(filtering_results),
    "baselines": safe_convert(baselines),
}

report_path = "/kaggle/working/final_evaluation_report.json"
with open(report_path, "w") as f:
    json.dump(safe_convert(final_report_full), f, indent=4)

print(f"✅ Final comprehensive report saved to: {report_path}")
print("\n🎯 Visualization & Reporting Complete! All figures and reports saved in /kaggle/working/")
