Simulation done on Google Colab

In [None]:
import pandas as pd
import numpy as np

In [None]:
# Load the estimation dataset from the mounted Google Drive folder.
df = pd.read_csv("/content/drive/MyDrive/QIC2025-EstDat.csv")

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.optimize import curve_fit
import warnings
# NOTE(review): with no category/module arguments this silences every warning
# for the rest of the session, including deprecation and numerical warnings
# raised by the libraries used below.
warnings.filterwarnings('ignore')

Exploratory Data Analysis (EDA) for Selecting an Apt PK/PD Design

In [None]:
# Global figure appearance: matplotlib defaults, seaborn "husl" colour cycle,
# then the figure size and font size used by every plot in this notebook.
plt.style.use('default')
sns.set_palette("husl")
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 10,
})

In [None]:
# Quick look at the raw dataset: dimensions, column names, a preview of the
# first rows, the dtype / non-null report and numeric summary statistics.
print(f"Shape: {df.shape}")
print(f"Columns: {df.columns.tolist()}")
print("\nFirst few rows:")
print(df.head())
print("\nDataset Info:")
# DataFrame.info() writes its report to stdout itself and returns None, so it
# must not be wrapped in print() (that would emit a stray "None" line).
df.info()
print("\nSummary Statistics:")
print(df.describe())

In [None]:
def analyze_dataset_structure(df):
    """Print a structural overview of the dataset; returns nothing.

    Reads the columns ID, DOSE, TIME, BW, COMED, CMT, DVID, EVID and MDV of
    ``df`` and reports: number of subjects, the distinct positive dose
    values, time and body-weight ranges, COMED counts, the CMT / DVID codes
    present, and how many records have EVID == 1, EVID == 0 and MDV == 1.
    """
    # --- subject / dose / covariate spread ---------------------------------
    print("\nData Distribution:")
    n_subjects = df['ID'].nunique()
    print(f"Unique subjects: {n_subjects}")

    dose_col = df['DOSE']
    active_doses = sorted(dose_col[dose_col > 0].unique())
    print(f"Dose levels: {active_doses}")

    t_min, t_max = df['TIME'].min(), df['TIME'].max()
    print(f"Time range: {t_min:.1f} - {t_max:.1f} hours")

    bw_min, bw_max = df['BW'].min(), df['BW'].max()
    print(f"Body weight range: {bw_min:.1f} - {bw_max:.1f} kg")

    comed_counts = df['COMED'].value_counts().to_dict()
    print(f"COMED distribution: {comed_counts}")

    # --- compartment and observation-type codes ----------------------------
    cmt_codes = sorted(df['CMT'].unique())
    print(f"\nCompartments (CMT): {cmt_codes}")
    dvid_codes = sorted(df['DVID'].unique())
    print(f"DVID types: {dvid_codes}")

    # --- record counts by event type ---------------------------------------
    n_dosing = int((df['EVID'] == 1).sum())
    n_observations = int((df['EVID'] == 0).sum())
    print(f"\nDosing events: {n_dosing}")
    print(f"Observation events: {n_observations}")
    n_missing = len(df[df['MDV'] == 1])
    print(f"Missing DV: {n_missing}")

In [None]:
def compartment_structure_analysis(df):
    """Plot PK concentration profiles and screen for multi-compartment kinetics.

    Observation records (EVID == 0 and MDV == 0) are split into a PK subset
    (DVID == 1) and a PD subset (DVID == 2).  A 2x2 figure is then drawn:

    * top-left: individual and mean (+/- SD) PK profiles per positive dose on
      a log y-axis;
    * top-right: mean PK profile after 2 h per dose on a semilog axis (only
      for doses with more than 5 such time points);
    * bottom-left: mean DV over time for every CMT value in ``df``;
    * bottom-right: mean PK profile over the first 4 h per dose.

    Finally log(mean concentration) after 2 h in the highest dose group is
    regressed linearly on time; an R² below 0.95 is reported as a hint
    towards a 2-compartment model.

    Args:
        df: DataFrame with the columns ID, TIME, DV, DOSE, CMT, DVID, EVID
            and MDV.

    Returns:
        Tuple ``(pk_data, pd_data)`` with copies of the PK and PD observation
        rows of ``df``.
    """
    # The original banner contained the invalid escape "\C", which printed a
    # stray backslash and triggers a SyntaxWarning on current Python.
    print("\nCOMPARTMENT STRUCTURE ANALYSIS")

    # Separate PK and PD observations (observed, non-missing records only)
    observed = (df['EVID'] == 0) & (df['MDV'] == 0)
    pk_data = df[(df['DVID'] == 1) & observed].copy()
    pd_data = df[(df['DVID'] == 2) & observed].copy()

    print(f"PK observations: {len(pk_data)}")
    print(f"PD observations: {len(pd_data)}")

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    dose_levels = sorted(pk_data[pk_data['DOSE'] > 0]['DOSE'].unique())
    colors = plt.cm.Set1(np.linspace(0, 1, len(dose_levels)))

    # ---- Panel 1: PK concentration profiles by dose ------------------------
    ax1 = axes[0, 0]
    for color, dose in zip(colors, dose_levels):
        dose_data = pk_data[pk_data['DOSE'] == dose]

        # Individual profiles (light lines), subjects in order of appearance
        for _, subj_data in dose_data.groupby('ID', sort=False):
            subj_data = subj_data.sort_values('TIME')
            ax1.plot(subj_data['TIME'], subj_data['DV'],
                     color=color, alpha=0.3, linewidth=0.5)

        # Mean profile (bold line) with a +/- 1 SD band
        mean_profile = dose_data.groupby('TIME')['DV'].agg(['mean', 'std']).reset_index()
        ax1.plot(mean_profile['TIME'], mean_profile['mean'],
                 color=color, linewidth=2, label=f'{dose} mg', marker='o', markersize=4)
        ax1.fill_between(mean_profile['TIME'],
                         mean_profile['mean'] - mean_profile['std'],
                         mean_profile['mean'] + mean_profile['std'],
                         color=color, alpha=0.2)

    ax1.set_xlabel('Time (hours)')
    ax1.set_ylabel('Concentration (mg/L)')
    ax1.set_title('PK Profiles: Concentration vs Time')
    ax1.legend()
    ax1.set_yscale('log')
    ax1.grid(True, alpha=0.3)

    # ---- Panel 2: semilog plot to check for biphasic decay -----------------
    ax2 = axes[0, 1]
    for color, dose in zip(colors, dose_levels):
        dose_data = pk_data[pk_data['DOSE'] == dose]
        mean_profile = dose_data.groupby('TIME')['DV'].mean().reset_index()
        # Focus on elimination phase (after 2 hours)
        elim_data = mean_profile[mean_profile['TIME'] > 2]
        if len(elim_data) > 5:
            ax2.semilogy(elim_data['TIME'], elim_data['DV'],
                         'o-', color=color, label=f'{dose} mg', linewidth=2)

    ax2.set_xlabel('Time (hours)')
    ax2.set_ylabel('Concentration (mg/L) - Log Scale')
    ax2.set_title('Elimination Phase Analysis')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # ---- Panel 3: mean observed DV per compartment code --------------------
    ax3 = axes[1, 0]
    for cmt in sorted(df['CMT'].unique()):
        cmt_data = df[(df['CMT'] == cmt) & observed]
        if len(cmt_data) > 0:
            mean_by_time = cmt_data.groupby('TIME')['DV'].mean()
            ax3.plot(mean_by_time.index, mean_by_time.values,
                     'o-', label=f'CMT {cmt}', linewidth=2, markersize=4)

    ax3.set_xlabel('Time (hours)')
    ax3.set_ylabel('Mean DV')
    ax3.set_title('Mean Response by Compartment')
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    # ---- Panel 4: absorption analysis (first 4 hours) ----------------------
    ax4 = axes[1, 1]
    early_pk = pk_data[pk_data['TIME'] <= 4]
    for color, dose in zip(colors, dose_levels):
        dose_data = early_pk[early_pk['DOSE'] == dose]
        if len(dose_data) > 0:
            mean_profile = dose_data.groupby('TIME')['DV'].mean().reset_index()
            ax4.plot(mean_profile['TIME'], mean_profile['DV'],
                     'o-', color=color, label=f'{dose} mg', linewidth=2)

    ax4.set_xlabel('Time (hours)')
    ax4.set_ylabel('Concentration (mg/L)')
    ax4.set_title('Absorption Phase (0-4 hours)')
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # ---- Interpretation: is the elimination phase log-linear? --------------
    print("\nCompartment Structure Insights:")

    if len(dose_levels) > 0:
        # Analyze the highest dose group
        max_dose_data = pk_data[pk_data['DOSE'] == max(dose_levels)]
        elim_phase = max_dose_data[max_dose_data['TIME'] > 2]

        if len(elim_phase) > 10:
            time_elim = elim_phase.groupby('TIME')['DV'].mean()
            # log() is undefined for non-positive means; keeping them would
            # feed -inf/NaN into the regression and invalidate R².
            time_elim = time_elim[time_elim > 0]
            log_conc = np.log(time_elim.values)
            time_vals = time_elim.index.values

            # Linear regression on log(conc) vs time
            if len(time_vals) > 5:
                slope, intercept, r_value, _, _ = stats.linregress(time_vals, log_conc)
                print(f"   • Elimination R² = {r_value**2:.3f}")
                if r_value**2 < 0.95:
                    print("Poor linear fit in elimination → Consider 2-compartment model")
                else:
                    print("Good linear fit → 1-compartment may be sufficient")

    return pk_data, pd_data

In [None]:
def covariate_analysis(df, pk_data, pd_data):
    """Explore body-weight, co-medication and dose effects on PK and PD.

    Per-subject PK summaries (trapezoidal AUC and apparent clearance
    DOSE / AUC) and PD summaries (baseline at TIME == 0, minimum DV and
    percent suppression from baseline) are computed and shown in a 2x3
    figure:

    * clearance vs BW with a log-log (allometric) regression line,
    * clearance by COMED (box plot + two-sample t-test p-value),
    * mean AUC vs dose with a linear fit,
    * biomarker suppression vs BW,
    * biomarker suppression by COMED (box plot + t-test p-value),
    * mean (+/- SD) biomarker profile by COMED with a line at 3.3.

    Args:
        df: Full dataset; not read by this function, kept for a signature
            parallel to the other analysis functions.
        pk_data: PK observation rows with columns ID, TIME, DV, DOSE, BW
            and COMED.
        pd_data: PD observation rows with the same columns.

    Returns:
        Tuple ``(pk_summary_df, pd_summary_df)``; each element is ``None``
        when no subject summary could be computed.
    """
    print("\n COVARIATE ANALYSIS")

    # np.trapz is deprecated in favour of np.trapezoid on NumPy 2.x; pick
    # whichever exists (np.trapz is only touched when trapezoid is absent).
    integrate = getattr(np, 'trapezoid', None) or np.trapz

    # Defined up front so every later panel and the return statement work
    # even when pk_data / pd_data are empty (previously a NameError).
    pk_summary_df = pd.DataFrame()
    pd_summary_df = pd.DataFrame()

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # ---- Panel 1: body weight effect on apparent clearance -----------------
    ax1 = axes[0, 0]
    if len(pk_data) > 0:
        pk_summary = []
        for subject, subj_data in pk_data.groupby('ID', sort=False):
            subj_data = subj_data.sort_values('TIME')
            if len(subj_data) <= 3:
                continue
            # Simple AUC approximation using the trapezoidal rule
            auc = integrate(subj_data['DV'], subj_data['TIME'])
            if auc > 0:
                dose = subj_data['DOSE'].iloc[0]
                pk_summary.append({
                    'ID': subject,
                    'BW': subj_data['BW'].iloc[0],
                    'COMED': subj_data['COMED'].iloc[0],
                    'DOSE': dose, 'AUC': auc, 'CL_app': dose / auc
                })

        pk_summary_df = pd.DataFrame(pk_summary)

        if len(pk_summary_df) > 0:
            scatter = ax1.scatter(pk_summary_df['BW'], pk_summary_df['CL_app'],
                                  c=pk_summary_df['DOSE'], cmap='viridis', s=60, alpha=0.7)
            plt.colorbar(scatter, ax=ax1, label='Dose (mg)')

            # Allometric relationship: linear fit of log(CL) on log(BW).
            # Only strictly positive pairs have a finite logarithm.
            fit_rows = pk_summary_df[(pk_summary_df['BW'] > 0) & (pk_summary_df['CL_app'] > 0)]
            if len(fit_rows) > 5:
                slope, intercept, r_val, _, _ = stats.linregress(
                    np.log(fit_rows['BW']), np.log(fit_rows['CL_app']))

                bw_range = np.linspace(fit_rows['BW'].min(), fit_rows['BW'].max(), 100)
                cl_pred = np.exp(intercept) * (bw_range ** slope)
                ax1.plot(bw_range, cl_pred, 'r--', linewidth=2,
                         label=f'Allometric: BW^{slope:.2f} (R²={r_val**2:.3f})')
                ax1.legend()

    ax1.set_xlabel('Body Weight (kg)')
    ax1.set_ylabel('Apparent Clearance (L/h)')
    ax1.set_title('PK: Clearance vs Body Weight')
    ax1.grid(True, alpha=0.3)

    # ---- Panel 2: COMED effect on apparent clearance -----------------------
    ax2 = axes[0, 1]
    if len(pk_summary_df) > 0:
        comed_0 = pk_summary_df[pk_summary_df['COMED'] == 0]['CL_app']
        comed_1 = pk_summary_df[pk_summary_df['COMED'] == 1]['CL_app']

        ax2.boxplot([comed_0, comed_1], labels=['No COMED', 'COMED'])
        ax2.scatter(np.ones(len(comed_0)), comed_0, alpha=0.6, color='blue')
        ax2.scatter(np.ones(len(comed_1))*2, comed_1, alpha=0.6, color='orange')

        # Two-sample t-test, only with at least 3 subjects per group
        if len(comed_0) > 2 and len(comed_1) > 2:
            t_stat, p_val = stats.ttest_ind(comed_0, comed_1)
            ax2.text(0.5, 0.95, f'p-value: {p_val:.4f}', transform=ax2.transAxes)

    ax2.set_ylabel('Apparent Clearance (L/h)')
    ax2.set_title('PK: COMED Effect on Clearance')
    ax2.grid(True, alpha=0.3)

    # ---- Panel 3: dose proportionality of AUC ------------------------------
    ax3 = axes[0, 2]
    if len(pk_summary_df) > 0:
        dose_groups = pk_summary_df.groupby('DOSE')['AUC'].agg(['mean', 'std']).reset_index()
        ax3.errorbar(dose_groups['DOSE'], dose_groups['mean'],
                     yerr=dose_groups['std'], marker='o', linewidth=2, markersize=8)

        # Linearity check needs at least three dose groups
        if len(dose_groups) > 2:
            slope, intercept, r_val, _, _ = stats.linregress(dose_groups['DOSE'], dose_groups['mean'])
            dose_pred = np.linspace(0, dose_groups['DOSE'].max()*1.1, 100)
            auc_pred = intercept + slope * dose_pred
            ax3.plot(dose_pred, auc_pred, 'r--', linewidth=2,
                     label=f'Linear fit (R²={r_val**2:.3f})')
            ax3.legend()

    ax3.set_xlabel('Dose (mg)')
    ax3.set_ylabel('AUC (mg⋅h/L)')
    ax3.set_title('Dose Proportionality')
    ax3.grid(True, alpha=0.3)

    # ---- Panel 4: body weight effect on biomarker suppression --------------
    ax4 = axes[1, 0]
    if len(pd_data) > 0:
        pd_summary = []
        for subject, subj_data in pd_data.groupby('ID', sort=False):
            subj_data = subj_data.sort_values('TIME')
            if len(subj_data) <= 1:
                continue
            baseline = subj_data[subj_data['TIME'] == 0]['DV'].values
            min_response = subj_data['DV'].min()

            if len(baseline) > 0:
                suppression = (baseline[0] - min_response) / baseline[0] * 100
            else:
                # No TIME == 0 sample: suppression is recorded as 0
                suppression = 0

            pd_summary.append({
                'ID': subject,
                'BW': subj_data['BW'].iloc[0],
                'COMED': subj_data['COMED'].iloc[0],
                'DOSE': subj_data['DOSE'].iloc[0],
                'baseline': baseline[0] if len(baseline) > 0 else np.nan,
                'min_response': min_response,
                'suppression_pct': suppression
            })

        pd_summary_df = pd.DataFrame(pd_summary)

        if len(pd_summary_df) > 0:
            scatter = ax4.scatter(pd_summary_df['BW'], pd_summary_df['suppression_pct'],
                                  c=pd_summary_df['DOSE'], cmap='plasma', s=60, alpha=0.7)
            plt.colorbar(scatter, ax=ax4, label='Dose (mg)')

    ax4.set_xlabel('Body Weight (kg)')
    ax4.set_ylabel('Biomarker Suppression (%)')
    ax4.set_title('PD: Response vs Body Weight')
    ax4.grid(True, alpha=0.3)

    # ---- Panel 5: COMED effect on biomarker suppression --------------------
    ax5 = axes[1, 1]
    if len(pd_summary_df) > 0:
        comed_0_pd = pd_summary_df[pd_summary_df['COMED'] == 0]['suppression_pct']
        comed_1_pd = pd_summary_df[pd_summary_df['COMED'] == 1]['suppression_pct']

        ax5.boxplot([comed_0_pd, comed_1_pd], labels=['No COMED', 'COMED'])
        ax5.scatter(np.ones(len(comed_0_pd)), comed_0_pd, alpha=0.6, color='blue')
        ax5.scatter(np.ones(len(comed_1_pd))*2, comed_1_pd, alpha=0.6, color='orange')

        if len(comed_0_pd) > 2 and len(comed_1_pd) > 2:
            t_stat, p_val = stats.ttest_ind(comed_0_pd, comed_1_pd)
            ax5.text(0.5, 0.95, f'p-value: {p_val:.4f}', transform=ax5.transAxes)

    ax5.set_ylabel('Biomarker Suppression (%)')
    ax5.set_title('PD: COMED Effect on Response')
    ax5.grid(True, alpha=0.3)

    # ---- Panel 6: biomarker time profiles by COMED -------------------------
    ax6 = axes[1, 2]
    if len(pd_data) > 0:
        # (COMED value, line format, band colour, legend label, marker)
        comed_styles = [
            (0, 'b-', 'blue', 'No COMED', 'o'),
            (1, 'r-', 'red', 'COMED', 's'),
        ]
        for comed_value, line_fmt, band_color, label, marker in comed_styles:
            comed_data = pd_data[pd_data['COMED'] == comed_value]
            if len(comed_data) == 0:
                continue
            profile = comed_data.groupby('TIME')['DV'].agg(['mean', 'std']).reset_index()
            ax6.plot(profile['TIME'], profile['mean'], line_fmt, linewidth=2,
                     label=label, marker=marker, markersize=4)
            ax6.fill_between(profile['TIME'],
                             profile['mean'] - profile['std'],
                             profile['mean'] + profile['std'],
                             color=band_color, alpha=0.2)

        ax6.axhline(y=3.3, color='green', linestyle='--', linewidth=2,
                    label='Target (3.3 ng/mL)')
        ax6.legend()

    ax6.set_xlabel('Time (hours)')
    ax6.set_ylabel('Biomarker (ng/mL)')
    ax6.set_title('Biomarker Profiles by COMED')
    ax6.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # ---- Printed insights --------------------------------------------------
    print("\nCovariate Analysis Insights:")
    if len(pk_summary_df) > 0:
        bw_corr = np.corrcoef(pk_summary_df['BW'], pk_summary_df['CL_app'])[0, 1]
        print(f"BW-Clearance correlation: {bw_corr:.3f}")

        if len(comed_0) > 0 and len(comed_1) > 0:
            # Ratio of mean apparent clearance, COMED group over no-COMED group
            comed_effect = np.mean(comed_1) / np.mean(comed_0)
            print(f"COMED effect on CL: {comed_effect:.2f}x")

    pk_result = pk_summary_df if len(pk_summary_df) > 0 else None
    pd_result = pd_summary_df if len(pd_summary_df) > 0 else None
    return pk_result, pd_result

In [None]:
def pd_link_analysis(df, pk_data, pd_data):
    """Explore the relationship between PK concentration and PD biomarker.

    PK and PD observations are paired per subject on identical TIME values,
    then a 2x2 figure is drawn:

    * concentration vs biomarker coloured by dose, with an inhibitory Emax
      fit ``E0 - Emax*C / (EC50 + C)`` when more than 10 pairs exist,
    * the same scatter coloured by time with arrows tracing the first
      subject's trajectory (hysteresis check),
    * min-max normalised mean PK profile against the inverted normalised
      mean PD profile, plus a cross-correlation lag estimate when both
      profiles have the same number of time points,
    * mean (+/- SD) biomarker profile per positive dose with a line at 3.3.

    A short interpretation based on the Pearson correlation between
    concentration and biomarker is printed at the end.

    Args:
        df: Full dataset; only its ID column is read (list of subjects).
        pk_data: PK observation rows with columns ID, TIME, DV and DOSE.
        pd_data: PD observation rows with columns ID, TIME, DV and DOSE.

    Returns:
        DataFrame of time-matched pairs with columns ID, TIME, CONC,
        BIOMARKER and DOSE, or ``None`` when no pair was found.
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # ---- Align PK and PD data by subject and time --------------------------
    pk_pd_aligned = []
    for subject in df['ID'].unique():
        subj_pk_rows = pk_data[pk_data['ID'] == subject]
        subj_pk = subj_pk_rows.set_index('TIME')['DV']
        subj_pd = pd_data[pd_data['ID'] == subject].set_index('TIME')['DV']

        # Only time points sampled for both PK and PD can be paired
        common_times = subj_pk.index.intersection(subj_pd.index)
        if len(common_times) == 0:
            continue

        # The dose is looked up once per subject rather than once per pair
        subject_dose = subj_pk_rows['DOSE'].iloc[0]
        for time in common_times:
            pk_pd_aligned.append({
                'ID': subject,
                'TIME': time,
                'CONC': subj_pk[time],
                'BIOMARKER': subj_pd[time],
                'DOSE': subject_dose
            })

    pk_pd_df = pd.DataFrame(pk_pd_aligned)

    # ---- Panel 1: concentration-response relationship ----------------------
    ax1 = axes[0, 0]
    if len(pk_pd_df) > 0:
        dose_levels = sorted(pk_pd_df['DOSE'].unique())
        colors = plt.cm.Set1(np.linspace(0, 1, len(dose_levels)))

        for color, dose in zip(colors, dose_levels):
            dose_data = pk_pd_df[pk_pd_df['DOSE'] == dose]
            ax1.scatter(dose_data['CONC'], dose_data['BIOMARKER'],
                        color=color, label=f'{dose} mg', alpha=0.6, s=40)
        ax1.legend()

        # Fit Emax model
        if len(pk_pd_df) > 10:
            def emax_model(conc, e0, emax, ec50):
                return e0 - (emax * conc) / (ec50 + conc)

            # Initial parameter estimates taken from the data range
            e0_init = pk_pd_df['BIOMARKER'].max()
            emax_init = e0_init - pk_pd_df['BIOMARKER'].min()
            ec50_init = pk_pd_df['CONC'].median()

            try:
                popt, _ = curve_fit(emax_model, pk_pd_df['CONC'], pk_pd_df['BIOMARKER'],
                                    p0=[e0_init, emax_init, ec50_init],
                                    bounds=([0, 0, 0], [np.inf, np.inf, np.inf]))
            except (RuntimeError, ValueError) as err:
                # curve_fit raises RuntimeError when the optimisation fails
                # and ValueError for NaN data or an infeasible start point.
                # Anything else (e.g. KeyboardInterrupt) must propagate.
                print(f"Could not fit Emax model ({err})")
            else:
                conc_range = np.linspace(0, pk_pd_df['CONC'].max(), 100)
                response_pred = emax_model(conc_range, *popt)
                ax1.plot(conc_range, response_pred, 'r--', linewidth=2,
                         label=f'Emax fit: EC50={popt[2]:.3f}')
                ax1.legend()

                print(f"Emax model parameters: E0={popt[0]:.2f}, Emax={popt[1]:.2f}, EC50={popt[2]:.3f}")

    # Labels are set unconditionally, like the other panels
    ax1.set_xlabel('Concentration (mg/L)')
    ax1.set_ylabel('Biomarker (ng/mL)')
    ax1.set_title('Concentration-Response Relationship')
    ax1.grid(True, alpha=0.3)

    # ---- Panel 2: hysteresis plot (time-matched PK-PD) ---------------------
    ax2 = axes[0, 1]
    if len(pk_pd_df) > 0:
        # Colour by time to show hysteresis
        scatter = ax2.scatter(pk_pd_df['CONC'], pk_pd_df['BIOMARKER'],
                              c=pk_pd_df['TIME'], cmap='viridis', s=50, alpha=0.7)
        plt.colorbar(scatter, ax=ax2, label='Time (h)')

        # Arrows show the time progression for the first subject only
        sample_subject = pk_pd_df['ID'].iloc[0]
        subj_data = pk_pd_df[pk_pd_df['ID'] == sample_subject].sort_values('TIME')

        if len(subj_data) > 3:
            points = list(zip(subj_data['CONC'], subj_data['BIOMARKER']))
            for start, end in zip(points[:-1], points[1:]):
                ax2.annotate('', xy=end, xytext=start,
                             arrowprops=dict(arrowstyle='->', color='red', alpha=0.5, lw=1))

    ax2.set_xlabel('Concentration (mg/L)')
    ax2.set_ylabel('Biomarker (ng/mL)')
    ax2.set_title('Hysteresis Plot (colored by time)')
    ax2.grid(True, alpha=0.3)

    # ---- Panel 3: time-aligned normalised profiles -------------------------
    ax3 = axes[1, 0]
    if len(pk_data) > 0 and len(pd_data) > 0:
        pk_mean = pk_data.groupby('TIME')['DV'].mean()
        pd_mean = pd_data.groupby('TIME')['DV'].mean()

        # Normalize to 0-1 scale for comparison
        pk_norm = (pk_mean - pk_mean.min()) / (pk_mean.max() - pk_mean.min())
        pd_norm = (pd_mean - pd_mean.min()) / (pd_mean.max() - pd_mean.min())
        pd_norm_inv = 1 - pd_norm  # Invert since biomarker decreases

        ax3.plot(pk_mean.index, pk_norm, 'b-', linewidth=2, label='PK (normalized)', marker='o')
        ax3.plot(pd_mean.index, pd_norm_inv, 'r-', linewidth=2, label='PD (inverted, normalized)', marker='s')

        ax3.set_xlabel('Time (hours)')
        ax3.set_ylabel('Normalized Response')
        ax3.set_title('Temporal Alignment: PK vs PD')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # Cross-correlation analysis; the lag is in sample positions, not hours
        if len(pk_mean) == len(pd_mean):
            cross_corr = np.correlate(pk_norm, pd_norm_inv, mode='full')
            lags = np.arange(-len(pd_norm_inv)+1, len(pk_norm))
            optimal_lag = lags[np.argmax(cross_corr)]

            print(f"Optimal PK-PD lag: {optimal_lag} time points")

    # ---- Panel 4: biomarker response vs time by dose -----------------------
    ax4 = axes[1, 1]
    if len(pd_data) > 0:
        dose_levels = sorted(pd_data[pd_data['DOSE'] > 0]['DOSE'].unique())
        colors = plt.cm.Set1(np.linspace(0, 1, len(dose_levels)))

        for color, dose in zip(colors, dose_levels):
            dose_data = pd_data[pd_data['DOSE'] == dose]
            mean_profile = dose_data.groupby('TIME')['DV'].agg(['mean', 'std']).reset_index()

            ax4.plot(mean_profile['TIME'], mean_profile['mean'],
                     color=color, linewidth=2, label=f'{dose} mg', marker='o', markersize=4)
            ax4.fill_between(mean_profile['TIME'],
                             mean_profile['mean'] - mean_profile['std'],
                             mean_profile['mean'] + mean_profile['std'],
                             color=color, alpha=0.2)

        ax4.axhline(y=3.3, color='green', linestyle='--', linewidth=2, label='Target')
        ax4.set_xlabel('Time (hours)')
        ax4.set_ylabel('Biomarker (ng/mL)')
        ax4.set_title('Biomarker Response by Dose Level')
        ax4.legend()
        ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # ---- Printed insights --------------------------------------------------
    print("\nPD Link Analysis Insights:")
    if len(pk_pd_df) > 0:
        # Heuristic: strength of the instantaneous correlation between
        # concentration and biomarker decides the suggested link model
        conc_biomarker_corr = np.corrcoef(pk_pd_df['CONC'], pk_pd_df['BIOMARKER'])[0, 1]
        print(f"Concentration-Biomarker correlation: {conc_biomarker_corr:.3f}")

        if abs(conc_biomarker_corr) > 0.7:
            print("Strong correlation → Direct effect model likely")
        elif abs(conc_biomarker_corr) < 0.3:
            print("Weak correlation → Consider indirect effect model")
        else:
            print("Moderate correlation → May need effect compartment")

    return pk_pd_df if len(pk_pd_df) > 0 else None


In [None]:
def variability_analysis(df, pk_data, pd_data):
    """Visualise and summarise between-subject variability in PK and PD.

    A 2x2 figure is drawn:

    * CV% (SD / mean * 100 across subjects) of PK DV per time point and dose,
      with reference lines at 30% and 50%,
    * individual PK profiles of the highest dose with their mean (log y-axis),
    * CV% of PD DV per time point and dose with a reference line at 30%,
    * individual PD profiles of the highest dose with their mean, a line at
      3.3 and the fraction of subjects whose minimum DV is below 3.3.

    The average CV% over time points and doses is then printed for PK
    (with a qualitative rating) and for PD.

    Args:
        df: Full dataset; not read by this function, kept for a signature
            parallel to the other analysis functions.
        pk_data: PK observation rows with columns ID, TIME, DV and DOSE.
        pd_data: PD observation rows with columns ID, TIME, DV and DOSE.

    Returns:
        None.
    """

    def positive_dose_levels(frame):
        # Sorted distinct doses greater than zero present in ``frame``
        return sorted(frame.loc[frame['DOSE'] > 0, 'DOSE'].unique())

    def cv_by_time(frame):
        # Per-TIME mean, SD and CV% of DV across the rows of ``frame``
        time_stats = frame.groupby('TIME')['DV'].agg(['mean', 'std']).reset_index()
        time_stats['cv_percent'] = (time_stats['std'] / time_stats['mean']) * 100
        return time_stats

    # Guarded so that an empty frame without columns is never indexed
    pk_doses = positive_dose_levels(pk_data) if len(pk_data) > 0 else []
    pd_doses = positive_dose_levels(pd_data) if len(pd_data) > 0 else []

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # ---- Panel 1: inter-subject variability in PK --------------------------
    ax1 = axes[0, 0]
    if len(pk_data) > 0:
        for dose in pk_doses:
            time_stats = cv_by_time(pk_data[pk_data['DOSE'] == dose])
            ax1.plot(time_stats['TIME'], time_stats['cv_percent'],
                     'o-', label=f'{dose} mg', linewidth=2, markersize=4)

        ax1.axhline(y=30, color='red', linestyle='--', alpha=0.5, label='30% CV')
        ax1.axhline(y=50, color='orange', linestyle='--', alpha=0.5, label='50% CV')
        ax1.set_xlabel('Time (hours)')
        ax1.set_ylabel('Coefficient of Variation (%)')
        ax1.set_title('PK Inter-subject Variability')
        # legend() only lists artists that already exist, so it has to come
        # after the reference lines for their labels to be shown.
        ax1.legend()
        ax1.grid(True, alpha=0.3)

    # ---- Panel 2: individual PK profiles, highest dose ---------------------
    ax2 = axes[0, 1]
    if pk_doses:
        max_dose = pk_doses[-1]
        max_dose_data = pk_data[pk_data['DOSE'] == max_dose]

        for _, subj_data in max_dose_data.groupby('ID', sort=False):
            subj_data = subj_data.sort_values('TIME')
            ax2.plot(subj_data['TIME'], subj_data['DV'], 'b-', alpha=0.4, linewidth=1)

        mean_profile = max_dose_data.groupby('TIME')['DV'].mean()
        ax2.plot(mean_profile.index, mean_profile.values, 'r-', linewidth=3, label='Population Mean')

        ax2.set_xlabel('Time (hours)')
        ax2.set_ylabel('Concentration (mg/L)')
        ax2.set_title(f'Individual PK Profiles ({max_dose} mg dose)')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        ax2.set_yscale('log')

    # ---- Panel 3: inter-subject variability in PD --------------------------
    ax3 = axes[1, 0]
    if len(pd_data) > 0:
        for dose in pd_doses:
            time_stats = cv_by_time(pd_data[pd_data['DOSE'] == dose])
            ax3.plot(time_stats['TIME'], time_stats['cv_percent'],
                     'o-', label=f'{dose} mg', linewidth=2, markersize=4)

        ax3.axhline(y=30, color='red', linestyle='--', alpha=0.5, label='30% CV')
        ax3.set_xlabel('Time (hours)')
        ax3.set_ylabel('Coefficient of Variation (%)')
        ax3.set_title('PD Inter-subject Variability')
        ax3.legend()  # after axhline so the reference line is listed
        ax3.grid(True, alpha=0.3)

    # ---- Panel 4: individual PD profiles, highest dose ---------------------
    ax4 = axes[1, 1]
    if pd_doses:
        max_dose = pd_doses[-1]
        max_dose_data = pd_data[pd_data['DOSE'] == max_dose]

        target_achievers = 0
        total_subjects = 0
        for _, subj_data in max_dose_data.groupby('ID', sort=False):
            subj_data = subj_data.sort_values('TIME')
            ax4.plot(subj_data['TIME'], subj_data['DV'], 'b-', alpha=0.4, linewidth=1)

            # A subject counts as achieving the target when any DV is < 3.3
            if subj_data['DV'].min() < 3.3:
                target_achievers += 1
            total_subjects += 1

        mean_profile = max_dose_data.groupby('TIME')['DV'].mean()
        ax4.plot(mean_profile.index, mean_profile.values, 'r-', linewidth=3, label='Population Mean')
        ax4.axhline(y=3.3, color='green', linestyle='--', linewidth=2, label='Target (3.3 ng/mL)')

        target_rate = target_achievers / total_subjects if total_subjects > 0 else 0
        ax4.text(0.05, 0.95, f'Target Achievement: {target_rate:.1%}',
                 transform=ax4.transAxes, fontsize=12,
                 bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.7))

        ax4.set_xlabel('Time (hours)')
        ax4.set_ylabel('Biomarker (ng/mL)')
        ax4.set_title(f'Individual PD Profiles ({max_dose} mg dose)')
        ax4.legend()
        ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # ---- Quantitative variability assessment -------------------------------
    print("\nVariability Analysis Insights:")

    # Skipped when there is no positive dose: the mean of an empty list is
    # NaN, which would otherwise fall through to the "High" rating.
    if pk_doses:
        pk_cv_summary = [
            cv_by_time(pk_data[pk_data['DOSE'] == dose])['cv_percent'].mean()
            for dose in pk_doses
        ]
        mean_pk_cv = np.mean(pk_cv_summary)
        print(f"Average PK CV%: {mean_pk_cv:.1f}%")

        if mean_pk_cv < 30:
            print("Low PK variability - simple random effects may suffice")
        elif mean_pk_cv < 50:
            print("Moderate PK variability - consider covariate effects")
        else:
            print("High PK variability - strong covariate relationships needed")

    if pd_doses:
        pd_cv_summary = [
            cv_by_time(pd_data[pd_data['DOSE'] == dose])['cv_percent'].mean()
            for dose in pd_doses
        ]
        mean_pd_cv = np.mean(pd_cv_summary)
        print(f"Average PD CV%: {mean_pd_cv:.1f}%")

In [None]:
def _residual_records(subject, dose, times, observed, predicted):
    """Return one residual record (a dict) per observation of one subject.

    ``RES`` is ``observed - predicted`` and ``REL_RES`` is that residual
    expressed as a percentage of the prediction.
    """
    residuals = observed - predicted
    rel_residuals = residuals / predicted * 100  # Percent residuals
    return [
        {'ID': subject, 'TIME': t, 'OBS': obs, 'PRED': pred,
         'RES': res, 'REL_RES': rel_res, 'DOSE': dose}
        for t, obs, pred, res, rel_res in zip(times, observed, predicted,
                                              residuals, rel_residuals)
    ]


def residual_analysis(df, pk_data, pd_data):
    """
    Fit simple per-subject PK and PD curves and plot goodness-of-fit panels.

    Top row (PK): sample fits, observed vs predicted, relative residuals vs
    time, using a 1-compartment first-order absorption model fitted with
    ``curve_fit``. Bottom row (PD): the same three panels for a simple
    saturable-dose / exponential-onset response model.

    Parameters
    ----------
    df : DataFrame
        Full dataset (not used directly; kept for a uniform call signature).
    pk_data, pd_data : DataFrame
        Observation rows with at least ``ID``, ``TIME``, ``DV`` and ``DOSE``.

    Returns
    -------
    (pk_res_df, pd_res_df) : tuple of DataFrame
        One row per fitted observation with columns ID, TIME, OBS, PRED, RES,
        REL_RES and DOSE. Either frame is empty when no subject was fitted.

    Notes
    -----
    Subjects whose fit raises are skipped (best effort); the number skipped
    is printed so failures are no longer invisible.
    PK predictions are floored at 1e-6, so relative residuals at TIME == 0
    are computed against that floor and can dominate the first time bin.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # Simple 1-compartment PK model for residual analysis
    def one_comp_model(t, dose, ka, ke, v):
        """Simple 1-compartment model with first-order absorption"""
        if ka != ke:
            conc = (dose / v) * (ka / (ka - ke)) * (np.exp(-ke * t) - np.exp(-ka * t))
        else:
            # Limiting form when absorption and elimination rates coincide
            conc = (dose / v) * ka * t * np.exp(-ka * t)
        return np.maximum(conc, 1e-6)  # Avoid log(0)

    # Fit simple PK model to each subject
    pk_residuals = []

    ax1 = axes[0, 0]
    ax2 = axes[0, 1]
    ax3 = axes[0, 2]

    if len(pk_data) > 0:
        fitted_subjects = 0
        failed_pk_fits = 0

        for subject in pk_data['ID'].unique():
            subj_data = pk_data[pk_data['ID'] == subject].sort_values('TIME')

            if len(subj_data) > 4:  # Need enough points to fit
                try:
                    # Initial parameter estimates
                    dose = subj_data['DOSE'].iloc[0]
                    times = subj_data['TIME'].values
                    concs = subj_data['DV'].values

                    # Remove zero times for fitting
                    nonzero_idx = times > 0
                    if np.sum(nonzero_idx) > 3:
                        fit_times = times[nonzero_idx]
                        fit_concs = concs[nonzero_idx]

                        # Simple parameter bounds
                        bounds = ([0.1, 0.01, 1], [5, 2, 100])
                        p0 = [1, 0.1, 10]  # ka, ke, v

                        popt, _ = curve_fit(lambda t, ka, ke, v: one_comp_model(t, dose, ka, ke, v),
                                          fit_times, fit_concs, p0=p0, bounds=bounds, maxfev=1000)

                        # Predictions are made at every sample time, including
                        # the pre-dose points that were excluded from the fit
                        pred_concs = one_comp_model(times, dose, *popt)

                        # Store for analysis
                        pk_residuals.extend(
                            _residual_records(subject, dose, times, concs, pred_concs)
                        )

                        fitted_subjects += 1

                        # Plot first few subjects for visualization
                        if fitted_subjects <= 3:
                            ax1.plot(times, concs, 'o', alpha=0.7, markersize=4)
                            ax1.plot(times, pred_concs, '-', alpha=0.7, linewidth=2)

                except Exception:
                    # Best effort: a subject that cannot be fitted is skipped,
                    # but counted so the summary below stays honest
                    failed_pk_fits += 1
                    continue

        print(f"Successfully fitted {fitted_subjects} subjects with 1-compartment model")
        if failed_pk_fits:
            print(f"Skipped {failed_pk_fits} subjects whose 1-compartment fit failed")

    pk_res_df = pd.DataFrame(pk_residuals)

    if len(pk_res_df) > 0:
        # Goodness-of-fit plots
        ax1.set_xlabel('Time (hours)')
        ax1.set_ylabel('Concentration (mg/L)')
        ax1.set_title('PK Model Fits (Sample Subjects)')
        ax1.set_yscale('log')
        ax1.grid(True, alpha=0.3)

        # Observed vs Predicted
        ax2.scatter(pk_res_df['PRED'], pk_res_df['OBS'], alpha=0.6, s=30)
        ax2.plot([pk_res_df['PRED'].min(), pk_res_df['PRED'].max()],
                [pk_res_df['PRED'].min(), pk_res_df['PRED'].max()], 'r--', linewidth=2)
        ax2.set_xlabel('Predicted Concentration (mg/L)')
        ax2.set_ylabel('Observed Concentration (mg/L)')
        ax2.set_title('Observed vs Predicted (PK)')
        ax2.set_xscale('log')
        ax2.set_yscale('log')
        ax2.grid(True, alpha=0.3)

        # Calculate R² (squared Pearson correlation of observed vs predicted)
        r2 = np.corrcoef(pk_res_df['OBS'], pk_res_df['PRED'])[0,1]**2
        ax2.text(0.05, 0.95, f'R² = {r2:.3f}', transform=ax2.transAxes,
                bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

        # Residuals vs Time
        ax3.scatter(pk_res_df['TIME'], pk_res_df['REL_RES'], alpha=0.6, s=30)
        ax3.axhline(y=0, color='red', linestyle='-', linewidth=2)
        ax3.axhline(y=20, color='orange', linestyle='--', alpha=0.7)
        ax3.axhline(y=-20, color='orange', linestyle='--', alpha=0.7)
        ax3.set_xlabel('Time (hours)')
        ax3.set_ylabel('Relative Residuals (%)')
        ax3.set_title('PK Residuals vs Time')
        ax3.grid(True, alpha=0.3)

        # Check for systematic bias: mean relative residual in 5 equal-width
        # time bins (empty bins give NaN, which max() below ignores)
        time_bins = pd.cut(pk_res_df['TIME'], bins=5)
        residual_bias = pk_res_df.groupby(time_bins, observed=False)['REL_RES'].mean()

        print(f"   • PK Model R²: {r2:.3f}")
        print(f"   • Mean absolute residual: {np.abs(pk_res_df['REL_RES']).mean():.1f}%")

        if np.abs(residual_bias).max() > 10:
            print("Systematic bias detected → Consider 2-compartment model")
        else:
            print("No major systematic bias in PK model")

    # Simple PD model analysis
    ax4 = axes[1, 0]
    ax5 = axes[1, 1]
    ax6 = axes[1, 2]

    # Fit the simple PD response model (saturable dose effect with an
    # exponential onset) to each actively treated subject
    pd_residuals = []

    if len(pd_data) > 0:
        # Simple indirect response model approximation
        def simple_pd_model(t, dose, e0, emax, ke0):
            """Simple PD model: E = E0 - Emax * (1 - exp(-ke0*t)) * dose_effect"""
            dose_effect = dose / (dose + 5)  # Simple saturation
            effect = e0 - emax * (1 - np.exp(-ke0 * t)) * dose_effect
            return np.maximum(effect, 0.1)

        fitted_pd_subjects = 0
        failed_pd_fits = 0

        for subject in pd_data['ID'].unique():
            subj_data = pd_data[pd_data['ID'] == subject].sort_values('TIME')

            if len(subj_data) > 4:
                try:
                    dose = subj_data['DOSE'].iloc[0]
                    times = subj_data['TIME'].values
                    responses = subj_data['DV'].values

                    if dose > 0:  # Only fit active treatment
                        # Initial estimates
                        e0_est = responses[0] if times[0] == 0 else responses.max()
                        emax_est = e0_est - responses.min()
                        ke0_est = 0.1

                        bounds = ([0, 0, 0.01], [50, 50, 1])
                        p0 = [e0_est, emax_est, ke0_est]

                        popt, _ = curve_fit(lambda t, e0, emax, ke0: simple_pd_model(t, dose, e0, emax, ke0),
                                          times, responses, p0=p0, bounds=bounds, maxfev=1000)

                        pred_responses = simple_pd_model(times, dose, *popt)

                        pd_residuals.extend(
                            _residual_records(subject, dose, times, responses, pred_responses)
                        )

                        fitted_pd_subjects += 1

                        # Plot first few subjects
                        if fitted_pd_subjects <= 3:
                            ax4.plot(times, responses, 'o', alpha=0.7, markersize=4)
                            ax4.plot(times, pred_responses, '-', alpha=0.7, linewidth=2)

                except Exception:
                    # Best effort, e.g. an initial estimate outside the bounds
                    # makes curve_fit raise; skip the subject but count it
                    failed_pd_fits += 1
                    continue

        print(f"Successfully fitted {fitted_pd_subjects} subjects with simple PD model")
        if failed_pd_fits:
            print(f"Skipped {failed_pd_fits} subjects whose simple PD fit failed")

    pd_res_df = pd.DataFrame(pd_residuals)

    if len(pd_res_df) > 0:
        # PD goodness-of-fit plots
        ax4.set_xlabel('Time (hours)')
        ax4.set_ylabel('Biomarker (ng/mL)')
        ax4.set_title('PD Model Fits (Sample Subjects)')
        ax4.axhline(y=3.3, color='green', linestyle='--', linewidth=2, alpha=0.7)
        ax4.grid(True, alpha=0.3)

        # Observed vs Predicted
        ax5.scatter(pd_res_df['PRED'], pd_res_df['OBS'], alpha=0.6, s=30)
        ax5.plot([pd_res_df['PRED'].min(), pd_res_df['PRED'].max()],
                [pd_res_df['PRED'].min(), pd_res_df['PRED'].max()], 'r--', linewidth=2)
        ax5.set_xlabel('Predicted Biomarker (ng/mL)')
        ax5.set_ylabel('Observed Biomarker (ng/mL)')
        ax5.set_title('Observed vs Predicted (PD)')
        ax5.grid(True, alpha=0.3)

        # Calculate R² (squared Pearson correlation of observed vs predicted)
        r2_pd = np.corrcoef(pd_res_df['OBS'], pd_res_df['PRED'])[0,1]**2
        ax5.text(0.05, 0.95, f'R² = {r2_pd:.3f}', transform=ax5.transAxes,
                bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

        # PD Residuals vs Time
        ax6.scatter(pd_res_df['TIME'], pd_res_df['REL_RES'], alpha=0.6, s=30)
        ax6.axhline(y=0, color='red', linestyle='-', linewidth=2)
        ax6.axhline(y=20, color='orange', linestyle='--', alpha=0.7)
        ax6.axhline(y=-20, color='orange', linestyle='--', alpha=0.7)
        ax6.set_xlabel('Time (hours)')
        ax6.set_ylabel('Relative Residuals (%)')
        ax6.set_title('PD Residuals vs Time')
        ax6.grid(True, alpha=0.3)

        print(f"PD Model R²: {r2_pd:.3f}")
        print(f"Mean absolute PD residual: {np.abs(pd_res_df['REL_RES']).mean():.1f}%")

    plt.tight_layout()
    plt.show()

    return pk_res_df, pd_res_df

In [None]:
# Driver that chains every EDA stage in a fixed order
def run_complete_analysis(df):
    """
    Run all EDA stages on ``df`` and gather their outputs in one dict.

    The stages are called in sequence: dataset overview, compartment
    structure, covariates, PK/PD link, variability and residuals. The PK and
    PD frames returned by the compartment stage are passed to every later
    stage.

    Returns a dict with the keys ``pk_data``, ``pd_data``, ``pk_summary``,
    ``pd_summary``, ``pk_pd_aligned``, ``pk_residuals`` and ``pd_residuals``.
    """
    outputs = {}

    # Stage 0: dataset overview (its return value is not used)
    analyze_dataset_structure(df)

    # Stage 1: compartment structure -> PK / PD frames shared by later stages
    pk_obs, pd_obs = compartment_structure_analysis(df)
    outputs['pk_data'] = pk_obs
    outputs['pd_data'] = pd_obs

    # Stage 2: covariate summaries
    outputs['pk_summary'], outputs['pd_summary'] = covariate_analysis(df, pk_obs, pd_obs)

    # Stage 3: PK/PD link
    outputs['pk_pd_aligned'] = pd_link_analysis(df, pk_obs, pd_obs)

    # Stage 4: variability (its return value is not used)
    variability_analysis(df, pk_obs, pd_obs)

    # Stage 5: residual diagnostics
    outputs['pk_residuals'], outputs['pd_residuals'] = residual_analysis(df, pk_obs, pd_obs)

    return outputs

results = run_complete_analysis(df)


Classical Setup

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.integrate import solve_ivp
from scipy.optimize import minimize, differential_evolution
from scipy.stats import multivariate_normal, norm
from scipy.interpolate import interp1d
import warnings
warnings.filterwarnings('ignore')

# Advanced imports for performance and ML
from numba import jit, prange
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import multiprocessing as mp
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern, WhiteKernel
from tqdm import tqdm
import time
import psutil

# Try to import neural ODE components
try:
    import torch
    import torch.nn as nn
    from torchdiffeq import odeint
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    print("PyTorch/torchdiffeq not available. Using optimized scipy methods.")

# Plot defaults for this section: reset to matplotlib's default style first,
# then apply the figure size and the seaborn "husl" palette on top of it
plt.style.use('default')
plt.rcParams.update({'figure.figsize': (12, 8)})
sns.set_palette("husl")

# Simple GPU detection
# torch is imported again here so this check stands on its own even when the
# optional torch/torchdiffeq import earlier in the cell failed.
try:
    import torch
    USE_GPU = torch.cuda.is_available()
except Exception:
    # Best effort: any failure to import or query torch means "no GPU".
    # (Exception rather than a bare except, so Ctrl-C is not swallowed.)
    USE_GPU = False

if USE_GPU:
    print("CUDA detected and available")
else:
    print("CUDA not available, using CPU optimization")

class PerformanceMonitor:
    """
    Collect and print run-time counters for the estimation/simulation code.

    All values live in the ``metrics`` dict; ``reset`` defines the standard
    keys. Timing keys are accumulated by callers through ``log_metric``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Restart the wall clock and set every standard counter to zero."""
        self.start_time = time.time()
        self.metrics = {
            'total_time': 0,
            'ode_solve_time': 0,
            'optimization_time': 0,
            'simulation_time': 0,
            'memory_peak_mb': 0,
            'cpu_count': mp.cpu_count(),
            'successful_subjects': 0,
            'failed_subjects': 0,
            'ode_evaluations': 0,
            'convergence_iterations': 0
        }

    def log_metric(self, key, value, increment=False):
        """
        Store ``value`` under ``key``, or add it to the current value.

        With ``increment=True`` a key that has not been seen yet starts from
        zero instead of raising ``KeyError``.
        """
        if increment:
            self.metrics[key] = self.metrics.get(key, 0) + value
        else:
            self.metrics[key] = value

    def get_memory_usage(self):
        """Return the current resident set size of this process in MB."""
        return psutil.Process().memory_info().rss / 1024 / 1024  # MB

    def report(self):
        """Refresh the derived metrics and print the full summary."""
        self.metrics['total_time'] = time.time() - self.start_time
        # Memory is only sampled here, so "peak" is the largest value seen
        # across report() calls since the last reset, not a true process peak.
        self.metrics['memory_peak_mb'] = max(
            self.metrics['memory_peak_mb'], self.get_memory_usage()
        )

        print(f"Total Runtime: {self.metrics['total_time']:.2f} seconds")
        print(f"ODE Solving: {self.metrics['ode_solve_time']:.2f} seconds")
        print(f"Optimization: {self.metrics['optimization_time']:.2f} seconds")
        print(f"Simulation: {self.metrics['simulation_time']:.2f} seconds")
        print(f"Peak Memory: {self.metrics['memory_peak_mb']:.1f} MB")
        print(f"CPU Cores Used: {self.metrics['cpu_count']}")
        print(f"Successful Subjects: {self.metrics['successful_subjects']}")
        print(f"Failed Subjects: {self.metrics['failed_subjects']}")
        print(f"ODE Evaluations: {self.metrics['ode_evaluations']}")
        print(f"Optimization Iterations: {self.metrics['convergence_iterations']}")

        successful = self.metrics['successful_subjects']
        attempts = successful + self.metrics['failed_subjects']

        # Report the rate whenever anything was attempted, so a run in which
        # every simulation failed shows 0.0% instead of printing nothing
        if attempts > 0:
            success_rate = successful / attempts * 100
            print(f"Success Rate: {success_rate:.1f}%")

        if successful > 0 and self.metrics['ode_solve_time'] > 0:
            subjects_per_second = successful / self.metrics['ode_solve_time']
            print(f"Simulation Speed: {subjects_per_second:.1f} subjects/second")

# Global performance monitor
perf_monitor = PerformanceMonitor()

@jit(nopython=True)  # Removed cache=True to avoid the error
def pk_pd_system_numba(t, y, dose_rate, params, bw, comed):
    """
    Right-hand side of the four-state PK/PD ODE system (Numba-compiled).

    ``y`` holds the central amount, the peripheral amount, the effect
    compartment amount and the response. ``params`` is indexed positionally:
    0-4 are CL, V1, Q, V2, KA; 5-9 are KE0, IMAX, IC50, KIN, KOUT; 10-13 are
    the covariate coefficients CLBW, V1BW, CLCOMED, KINCOMED.

    There is no separate depot state: the dosing input ``dose_rate`` is
    multiplied by KA and added directly to the central amount. ``t`` is
    accepted for the ODE-solver signature but not used.

    Returns the four derivatives as a NumPy array in the same order as ``y``.
    """
    central, peripheral, effect_amt, response = y

    # Structural PK parameters
    cl_pop = params[0]
    v1_pop = params[1]
    q_inter = params[2]
    v2 = params[3]
    ka = params[4]

    # PD parameters
    ke0 = params[5]
    imax = params[6]
    ic50 = params[7]
    kin_pop = params[8]
    kout = params[9]

    # Covariate coefficients
    cl_bw_exp = params[10]
    v1_bw_exp = params[11]
    cl_comed = params[12]
    kin_comed = params[13]

    # Individual values: power model on body weight (reference 70.0) and a
    # proportional shift with the co-medication flag
    cl_ind = cl_pop * ((bw/70.0)**cl_bw_exp) * (1.0 + cl_comed * comed)
    v1_ind = v1_pop * ((bw/70.0)**v1_bw_exp)
    kin_ind = kin_pop * (1.0 + kin_comed * comed)

    # Effect-compartment amount scaled by the individual central volume
    effect_conc = effect_amt / v1_ind

    # First-order rate constants out of / between the PK compartments
    k_leave_central = cl_ind/v1_ind + q_inter/v1_ind
    k_to_peripheral = q_inter/v1_ind
    k_from_peripheral = q_inter/v2

    # PK amounts
    d_central = ka * dose_rate - k_leave_central * central + k_from_peripheral * peripheral
    d_peripheral = k_to_peripheral * central - k_from_peripheral * peripheral

    # Effect compartment follows the central amount with rate KE0
    d_effect = ke0 * central - ke0 * effect_amt

    # Indirect response: production inhibited by the effect-site value
    inhibition = (imax * effect_conc) / (ic50 + effect_conc)
    d_response = kin_ind * (1.0 - inhibition) - kout * response

    return np.array([d_central, d_peripheral, d_effect, d_response])

@jit(nopython=True)  # Removed cache=True
def create_dose_schedule(times, dose_times, doses):
    """
    Convert bolus doses into a dose-rate value per point of the time grid.

    Each dose is assigned to the last grid point at or before its dose time
    and stored as ``dose / width``, where ``width`` is the spacing of the grid
    at that point. The grid must be sorted ascending but does not have to be
    uniform; on a uniform grid the result matches the previous
    ``dose / (times[1] - times[0])`` behaviour. Doses that fall on the same
    grid point are added together instead of overwriting each other.

    Doses before ``times[0]`` or after ``times[-1]`` are ignored. A
    single-point grid uses a nominal width of 0.1.

    Returns a float array with the same length as ``times``.
    """
    n_points = len(times)
    dose_rates = np.zeros(n_points)
    if n_points == 0:
        return dose_rates

    for i in range(len(dose_times)):
        t_dose = dose_times[i]

        # Outside the simulated window: nothing to schedule
        if t_dose < times[0] or t_dose > times[n_points - 1]:
            continue

        # Last grid point at or before the dose time
        idx = np.searchsorted(times, t_dose, side='right') - 1

        # Local grid spacing: forward interval, or the backward one at the end
        if n_points == 1:
            width = 0.1
        elif idx < n_points - 1:
            width = times[idx + 1] - times[idx]
        else:
            width = times[idx] - times[idx - 1]

        # A zero-width interval (duplicated grid point) cannot carry a rate
        if width > 0:
            dose_rates[idx] += doses[i] / width  # Convert to rate

    return dose_rates

class NeuralODESolver:
    """
    Placeholder for a neural-network surrogate of the PK/PD ODE system.

    Only the network definition exists so far: ``train_on_traditional_solver``
    performs no training and ``solve`` always returns ``(None, None)``, which
    callers treat as "fall back to the traditional solver".

    Attributes
    ----------
    available : bool
        Mirrors the module-level ``TORCH_AVAILABLE`` flag.
    trained : bool
        Set by ``train_on_traditional_solver``.
    device, model :
        Torch device and network; ``None`` when torch is unavailable.
    """

    def __init__(self):
        # Define every attribute up front so the object is fully formed even
        # when torch is missing (previously only `available` was set then).
        self.available = bool(TORCH_AVAILABLE)
        self.trained = False
        self.device = None
        self.model = None

        if not self.available:
            return

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        class PKPDNet(nn.Module):
            """Small MLP mapping (state, params, covariates) to 4 derivatives."""

            def __init__(self):
                super().__init__()
                self.net = nn.Sequential(
                    nn.Linear(4 + 14 + 2, 64),  # state + params + covariates
                    nn.Tanh(),
                    nn.Linear(64, 32),
                    nn.Tanh(),
                    nn.Linear(32, 4)  # output derivatives
                )

            def forward(self, t, y, params, bw, comed):
                # Combine state, parameters, and covariates
                input_tensor = torch.cat([
                    y, params,
                    torch.tensor([bw, comed], device=y.device, dtype=y.dtype)
                ], dim=-1)
                return self.net(input_tensor)

        self.model = PKPDNet().to(self.device)

    def train_on_traditional_solver(self, n_samples=1000):
        """
        Mark the surrogate as trained (no training is actually performed).

        ``n_samples`` is accepted but currently unused. Returns ``False`` when
        torch is unavailable, ``True`` otherwise.
        """
        if not self.available:
            return False

        print("Training Neural ODE surrogate...")
        # This would involve generating training data and training the neural net
        # For brevity, marking as trained
        self.trained = True
        return True

    def solve(self, params, times, doses, dose_times, bw, comed, baseline_R=8.0):
        """
        Placeholder solve: always returns ``(None, None)``.

        The arguments mirror the traditional simulation call; none of them is
        used yet.
        """
        if not (self.available and self.trained):
            return None, None

        # Implementation would use torchdiffeq.odeint here
        # For now, fallback to traditional method
        return None, None

class PK_PD_Model:
    """
    Two-compartment PK model with effect compartment and indirect-response PD.

    Simulations try, in order: the neural surrogate (when available and
    trained), a cache of earlier solutions, and finally ``solve_ivp`` on the
    Numba-compiled right-hand side.

    Cache keys are built from *rounded* inputs, so parameter sets that agree
    after rounding share one cached solution.
    """

    def __init__(self):
        # Population parameter names
        self.pk_params = ['CL', 'V1', 'Q', 'V2', 'KA']
        self.pd_params = ['KE0', 'IMAX', 'IC50', 'KIN', 'KOUT']
        self.covariate_params = ['CLBW', 'V1BW', 'CLCOMED', 'KINCOMED']
        self.error_params = ['SIGMA_PK', 'SIGMA_PD']

        self.all_params = (self.pk_params + self.pd_params +
                          self.covariate_params + self.error_params)

        # Initialize neural ODE solver
        self.neural_solver = NeuralODESolver()

        # Cached solutions for similar parameter sets
        self.solution_cache = {}
        self.cache_tolerance = 1e-2  # currently not read anywhere in this class

    def pk_pd_system(self, t, y, dose_func, params, bw, comed):
        """
        ODE right-hand side: count the evaluation and delegate to Numba.

        ``dose_func(t)`` is converted to a plain float first, so the compiled
        function always receives a scalar dose rate even when the callable
        returns a 0-d array (as ``interp1d`` does for scalar input).
        """
        perf_monitor.log_metric('ode_evaluations', 1, increment=True)

        dose_rate = float(dose_func(t))
        return pk_pd_system_numba(t, y, dose_rate, params, bw, comed)

    def get_cache_key(self, params, bw, comed, doses, dose_times, baseline_R=8.0):
        """
        Build a hashable cache key from rounded simulation inputs.

        ``baseline_R`` is part of the key because it sets the initial
        response; without it, runs with different baselines would share a
        cached solution.
        """
        # Round parameters to create cache keys
        key_params = tuple(np.round(params, 3))
        key_bw = round(bw, 1)
        key_comed = int(comed)
        key_doses = tuple(np.round(doses, 1))
        key_dose_times = tuple(np.round(dose_times, 1))
        key_baseline = round(float(baseline_R), 3)

        return (key_params, key_bw, key_comed, key_doses, key_dose_times, key_baseline)

    @staticmethod
    def _nan_profiles(times):
        """Return two float NaN arrays shaped like ``times`` (failure result)."""
        shape = np.shape(times)
        return np.full(shape, np.nan), np.full(shape, np.nan)

    def simulate_individual_fast(self, params, times, doses, dose_times, bw, comed,
                                baseline_R=8.0):
        """
        Simulate concentration and response for one individual.

        Parameters
        ----------
        params : array-like
            The 14 structural and covariate parameters, in the order expected
            by ``pk_pd_system_numba``.
        times : array-like
            Sorted output time grid.
        doses, dose_times : array-like
            Dose amounts and their administration times.
        bw, comed :
            Body weight and co-medication flag.
        baseline_R : float
            Initial value of the response state.

        Returns
        -------
        (concentrations, response) : tuple of ndarray
            Arrays aligned with ``times``; both are all-NaN when the solver
            fails or raises.
        """
        start_time = time.time()

        # Work on float arrays so lists and integer-typed columns behave the
        # same as float ndarrays in the compiled code and in the NaN fallback
        params = np.asarray(params, dtype=float)
        times = np.asarray(times, dtype=float)
        doses = np.asarray(doses, dtype=float)
        dose_times = np.asarray(dose_times, dtype=float)

        # Try neural ODE first
        if self.neural_solver.available and self.neural_solver.trained:
            conc, response = self.neural_solver.solve(
                params, times, doses, dose_times, bw, comed, baseline_R
            )
            if conc is not None:
                perf_monitor.log_metric('ode_solve_time', time.time() - start_time, increment=True)
                return conc, response

        # Check cache
        cache_key = self.get_cache_key(params, bw, comed, doses, dose_times, baseline_R)
        cached = self.solution_cache.get(cache_key)
        if cached is not None:
            cached_times, cached_conc, cached_response = cached
            # Compare shapes first: np.allclose raises on grids that cannot
            # be broadcast against each other
            if cached_times.shape == times.shape and np.allclose(times, cached_times, rtol=1e-2):
                # Interpolate cached solution
                conc = np.interp(times, cached_times, cached_conc)
                response = np.interp(times, cached_times, cached_response)
                perf_monitor.log_metric('ode_solve_time', time.time() - start_time, increment=True)
                return conc, response

        try:
            # Use optimized traditional solver
            dose_rates = create_dose_schedule(times, dose_times, doses)
            dose_interp = interp1d(times, dose_rates, kind='linear', bounds_error=False, fill_value=0)

            def dose_func(t):
                return float(dose_interp(t))

            # Initial conditions
            y0 = [0, 0, 0, baseline_R]

            # Use optimized solver settings
            sol = solve_ivp(
                lambda t, y: self.pk_pd_system(t, y, dose_func, params, bw, comed),
                [times[0], times[-1]], y0, t_eval=times,
                method='DOP853',  # Higher order method
                rtol=1e-4, atol=1e-7,  # Relaxed tolerances for speed
                max_step=0.5
            )
        except Exception:
            # Any failure while building the schedule or integrating is
            # reported as an all-NaN profile rather than propagated
            perf_monitor.log_metric('failed_subjects', 1, increment=True)
            perf_monitor.log_metric('ode_solve_time', time.time() - start_time, increment=True)
            return self._nan_profiles(times)

        # Time is logged for successful and unsuccessful solves alike
        perf_monitor.log_metric('ode_solve_time', time.time() - start_time, increment=True)

        if not sol.success:
            perf_monitor.log_metric('failed_subjects', 1, increment=True)
            return self._nan_profiles(times)

        A1, A2, AE, R = sol.y

        # Calculate concentrations
        V1BW = params[11]
        V1_i = params[1] * (bw/70)**V1BW
        concentrations = A1 / V1_i

        # Cache successful solutions
        if len(self.solution_cache) < 10000:  # Limit cache size
            self.solution_cache[cache_key] = (times.copy(), concentrations.copy(), R.copy())

        perf_monitor.log_metric('successful_subjects', 1, increment=True)
        return concentrations, R

    # Keep original method for compatibility
    def simulate_individual(self, params, times, doses, dose_times, bw, comed, baseline_R=8.0):
        """Backward-compatible alias of ``simulate_individual_fast``."""
        return self.simulate_individual_fast(params, times, doses, dose_times, bw, comed, baseline_R)

class SAEM_Estimator:
    """
    Stochastic Approximation EM (SAEM) style estimator for the PK/PD model.

    Each iteration samples individual random effects (eta) with a few
    Metropolis-Hastings steps per subject and then updates the
    between-subject covariance ``omega`` by stochastic approximation.

    NOTE: only ``omega`` is updated during ``fit``; ``theta`` and ``sigma``
    keep the values they are initialised with.
    """

    def __init__(self, model):
        self.model = model
        self.data = None
        self.theta = None
        self.omega = None
        self.sigma = None

        # SAEM specific parameters
        self.n_burn_in = 4
        self.n_iterations = 20
        self.n_chains = 4
        self.step_size = 1.0

        # Gaussian Process for parameter exploration
        self.gp_surrogate = None

        # Latest eta sample per subject, carried from one iteration to the next
        self.eta_chains = {}

    def load_data(self, df):
        """Load and prepare dataset with preprocessing"""
        self.data = df.copy()

        # Separate PK and PD data
        self.pk_data = df[(df['DVID'] == 1) & (df['EVID'] == 0) & (df['MDV'] == 0)].copy()
        self.pd_data = df[(df['DVID'] == 2) & (df['EVID'] == 0) & (df['MDV'] == 0)].copy()

        # Preprocess for faster access
        self.unique_subjects = df['ID'].unique()
        self.subject_data_dict = {
            subject_id: df[df['ID'] == subject_id].copy()
            for subject_id in self.unique_subjects
        }

        print(f"Loaded {len(self.pk_data)} PK observations and {len(self.pd_data)} PD observations")
        print(f"Preprocessed {len(self.unique_subjects)} subjects")

    def individual_likelihood_fast(self, eta_i, theta, omega, sigma, subject_data):
        """
        Log joint density of one subject's observations and random effects.

        PK observations use a log-normal error model with SD ``sigma[0]``; PD
        observations use a proportional normal error with coefficient
        ``sigma[1]``; ``eta_i`` gets a zero-mean normal term with covariance
        ``omega``.

        Returns ``-inf`` when the subject has no observations, when the
        simulation raises, or when the simulated profiles are not finite.
        """
        # Individual parameters: theta_k * exp(eta_k) for the first (up to 10)
        # structural parameters
        params = np.array(theta, dtype=float)
        eta_i = np.asarray(eta_i, dtype=float)
        n_random = min(10, len(eta_i), len(params))
        params[:n_random] *= np.exp(eta_i[:n_random])

        # Extract subject info
        bw = subject_data['BW'].iloc[0]
        comed = subject_data['COMED'].iloc[0]

        # Keep real observations only (same definition as load_data): dosing
        # records and rows flagged as missing do not enter the likelihood
        is_observation = subject_data['EVID'] == 0
        if 'MDV' in subject_data.columns:
            is_observation = is_observation & (subject_data['MDV'] == 0)
        observations = subject_data[is_observation]
        pk_obs = observations[observations['DVID'] == 1]
        pd_obs = observations[observations['DVID'] == 2]

        if len(pk_obs) == 0 and len(pd_obs) == 0:
            return -np.inf

        # Get dosing info
        dosing = subject_data[subject_data['EVID'] == 1]
        doses = dosing['AMT'].values
        dose_times = dosing['TIME'].values

        # Simulation grid: every distinct record time of the subject
        all_times = np.sort(subject_data['TIME'].unique())

        try:
            # Use fast simulation
            conc_pred, response_pred = self.model.simulate_individual_fast(
                params, all_times, doses, dose_times, bw, comed
            )
            conc_pred = np.asarray(conc_pred, dtype=float)
            response_pred = np.asarray(response_pred, dtype=float)

            # A failed simulation comes back as NaN; without this check it
            # would be scored with the prior term alone
            if not (np.all(np.isfinite(conc_pred)) and np.all(np.isfinite(response_pred))):
                return -np.inf

            log_likelihood = 0

            # Vectorized PK likelihood
            if len(pk_obs) > 0:
                pk_times = pk_obs['TIME'].values
                pk_observed = pk_obs['DV'].values

                pk_predicted = np.interp(pk_times, all_times, conc_pred)
                valid_idx = (pk_predicted > 0) & (pk_observed > 0) & np.isfinite(pk_predicted)

                if np.sum(valid_idx) > 0:
                    pk_residuals = np.log(pk_observed[valid_idx]) - np.log(pk_predicted[valid_idx])
                    log_likelihood += -0.5 * np.sum(pk_residuals**2 / sigma[0]**2)
                    log_likelihood += -0.5 * len(pk_residuals) * np.log(2 * np.pi * sigma[0]**2)
                    # Jacobian of the log transform of the observations
                    log_likelihood += -np.sum(np.log(pk_observed[valid_idx]))

            # Vectorized PD likelihood
            if len(pd_obs) > 0:
                pd_times = pd_obs['TIME'].values
                pd_observed = pd_obs['DV'].values

                pd_predicted = np.interp(pd_times, all_times, response_pred)
                valid_idx = (pd_predicted > 0) & (pd_observed > 0) & np.isfinite(pd_predicted)

                if np.sum(valid_idx) > 0:
                    pd_residuals = pd_observed[valid_idx] - pd_predicted[valid_idx]
                    pd_variance = (sigma[1] * pd_predicted[valid_idx])**2
                    log_likelihood += -0.5 * np.sum(pd_residuals**2 / pd_variance)
                    log_likelihood += -0.5 * np.sum(np.log(2 * np.pi * pd_variance))

            # Prior for random effects
            if len(eta_i) > 0:
                log_likelihood += -0.5 * eta_i.T @ np.linalg.solve(omega, eta_i)

        except Exception:
            return -np.inf

        return log_likelihood

    def mcmc_step(self, eta, theta, omega, sigma, subject_data, step_size, ll_current=None):
        """
        One random-walk Metropolis-Hastings step on a subject's eta.

        ``ll_current`` may carry the log-density of ``eta`` from a previous
        step under the same ``theta``/``omega``/``sigma`` so it is not
        recomputed; when ``None`` it is evaluated here.

        Returns ``(eta, log_density)`` of the retained sample.
        """
        # Propose new eta
        eta_prop = eta + np.random.normal(0, step_size, len(eta))

        # Calculate acceptance probability
        if ll_current is None:
            ll_current = self.individual_likelihood_fast(eta, theta, omega, sigma, subject_data)
        ll_prop = self.individual_likelihood_fast(eta_prop, theta, omega, sigma, subject_data)

        alpha = min(1, np.exp(ll_prop - ll_current))

        if np.random.random() < alpha:
            return eta_prop, ll_prop
        else:
            return eta, ll_current

    def saem_iteration(self, theta, omega, sigma):
        """
        Simulation step: advance every subject's eta chain by 3 MH steps.

        Each chain continues from the subject's previous sample; a fresh
        start is drawn from N(0, 0.1 * omega) only when no previous sample of
        the right dimension exists. Subjects are processed in a thread pool.

        Returns a dict mapping subject id to its new eta sample.
        """
        n_eta = len(omega)
        chains = getattr(self, 'eta_chains', None)
        if chains is None:
            chains = self.eta_chains = {}

        def process_subject(subject_id):
            subject_data = self.subject_data_dict[subject_id]

            # Initialize or get previous eta
            eta = chains.get(subject_id)
            if eta is None or len(eta) != n_eta:
                eta = np.random.multivariate_normal(np.zeros(n_eta), omega * 0.1)

            # MCMC steps for this subject; the log-density is recomputed on
            # the first step because omega may have changed since last time
            ll = None
            for _ in range(3):  # Few MCMC steps per SAEM iteration
                eta, ll = self.mcmc_step(eta, theta, omega, sigma, subject_data,
                                         self.step_size, ll_current=ll)

            return subject_id, eta

        # Parallel processing of subjects (at least one worker, so an empty
        # subject list yields an empty result instead of an error)
        n_workers = max(1, min(mp.cpu_count(), len(self.unique_subjects)))
        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            results = list(executor.map(process_subject, self.unique_subjects))

        # Collect eta estimates and remember them for the next iteration
        eta_estimates = dict(results)
        chains.update(eta_estimates)

        return eta_estimates

    def update_parameters(self, eta_estimates, iteration):
        """
        Stochastic-approximation update of ``omega`` from the sampled etas.

        The step size is ``min(1, 10 / (iteration + 10))``. The update is
        skipped when fewer than two subjects are available, because a sample
        covariance is undefined then. Eigenvalues are floored at 1e-6 to keep
        ``omega`` positive definite.
        """
        # Learning rate schedule
        gamma = min(1.0, 10.0 / (iteration + 10))

        # Collect sufficient statistics
        eta_values = np.array(list(eta_estimates.values()))
        if eta_values.ndim != 2 or eta_values.shape[0] < 2:
            return

        # Update omega (between-subject variability)
        empirical_cov = np.atleast_2d(np.cov(eta_values.T))
        self.omega = (1 - gamma) * self.omega + gamma * empirical_cov

        # Ensure positive definiteness
        eigenvals, eigenvecs = np.linalg.eigh(self.omega)
        eigenvals = np.maximum(eigenvals, 1e-6)
        self.omega = eigenvecs @ np.diag(eigenvals) @ eigenvecs.T

    def _average_log_likelihood(self, eta_estimates):
        """Mean finite log-density over (at most) the first 50 subjects, or None."""
        total_ll = 0
        n_successful = 0

        for subject_id in list(self.unique_subjects)[:50]:
            try:
                eta = eta_estimates.get(subject_id, np.zeros(len(self.omega)))
                ll = self.individual_likelihood_fast(
                    eta, self.theta, self.omega, self.sigma,
                    self.subject_data_dict[subject_id]
                )
            except Exception:
                continue
            if np.isfinite(ll):
                total_ll += ll
                n_successful += 1

        if n_successful == 0:
            return None
        return total_ll / n_successful

    def fit(self, initial_params=None):
        """
        Run the SAEM loop and return ``(theta, best_ll)``.

        ``initial_params`` are the 14 structural and covariate parameters;
        built-in starting values are used when omitted. The first
        ``n_burn_in`` iterations only sample etas; the remaining
        ``n_iterations - n_burn_in`` iterations also update ``omega``.
        ``best_ll`` is the best monitored average log-density (monitoring
        runs every 50 iterations and on the final iteration).
        """
        start_time = time.time()

        if initial_params is None:
            # Smart initial parameter estimates
            initial_params = np.array([
                2.0, 10.0, 1.0, 20.0, 0.5,  # PK
                0.1, 0.8, 2.0, 5.0, 0.1,    # PD
                0.75, 1.0, 0.1, 0.1          # Covariates
            ])

        self.theta = np.array(initial_params, dtype=float)
        self.omega = np.eye(10) * 0.1  # Initial IIV
        self.sigma = np.array([0.2, 0.15])  # Initial residual error
        self.eta_chains = {}  # every fit starts its chains afresh

        print("Starting SAEM parameter estimation...")
        print(f"Burn-in: {self.n_burn_in} iterations")
        print(f"Estimation: {self.n_iterations - self.n_burn_in} iterations")

        best_ll = -np.inf
        best_params = self.theta.copy()
        last_iteration = self.n_iterations - 1

        # SAEM iterations with progress bar
        for iteration in tqdm(range(self.n_iterations), desc="SAEM Progress"):
            # E-step: Sample individual parameters
            eta_estimates = self.saem_iteration(self.theta, self.omega, self.sigma)

            # M-step: Update population parameters. Iterations 0..n_burn_in-1
            # are burn-in, so updates start at iteration == n_burn_in, which
            # matches the iteration counts printed above.
            if iteration >= self.n_burn_in:
                self.update_parameters(eta_estimates, iteration - self.n_burn_in)

            # Monitor convergence every 50 iterations and on the last one, so
            # the reported value also reflects the final state
            if iteration % 50 == 0 or iteration == last_iteration:
                avg_ll = self._average_log_likelihood(eta_estimates)

                if avg_ll is not None:
                    if avg_ll > best_ll:
                        best_ll = avg_ll
                        best_params = self.theta.copy()

                    print(f"Iteration {iteration}: Avg LL = {avg_ll:.2f}, Best = {best_ll:.2f}")

            perf_monitor.log_metric('convergence_iterations', 1, increment=True)

        # Final parameter estimates
        self.theta = best_params

        perf_monitor.log_metric('optimization_time', time.time() - start_time, increment=True)

        print(f"\nSAEM converged after {self.n_iterations} iterations")
        print("Final parameter estimates:")
        # The last two names are the error parameters, which are not in theta
        for param_name, value in zip(self.model.all_params[:-2], self.theta):
            print(f"  {param_name}: {value:.4f}")

        return self.theta, best_ll

# Keep NLME_Estimator for backward compatibility but make it use SAEM
class NLME_Estimator(SAEM_Estimator):
    """
    Backward-compatible name for ``SAEM_Estimator``.

    The subclass adds no attributes or methods of its own, so code that still
    instantiates ``NLME_Estimator`` runs the SAEM implementation unchanged.
    """

class PopulationSimulator:
    """
    Enhanced Monte Carlo population simulation with GPU acceleration
    """

    def __init__(self, model, theta, omega, sigma):
        self.model = model
        self.theta = theta
        self.omega = omega
        self.sigma = sigma

        # Prepare for batch processing
        self.batch_size = min(100, mp.cpu_count() * 2)

    def generate_virtual_population_fast(self, n_subjects, bw_range=(50, 100),
                                        comed_prob=0.5, seed=None):
        """
        Vectorized virtual population generation
        """
        if seed:
            np.random.seed(seed)

        # Vectorized covariate generation
        body_weights = np.random.uniform(bw_range[0], bw_range[1], n_subjects)
        comed_status = np.random.binomial(1, comed_prob, n_subjects)

        # Batch generate random effects
        try:
            eta_samples = np.random.multivariate_normal(
                np.zeros(len(self.omega)), self.omega, size=n_subjects
            )
        except np.linalg.LinAlgError:
            # Handle singular covariance
            eta_samples = np.random.normal(0, 0.1, (n_subjects, len(self.omega)))

        # Vectorized parameter calculation
        virtual_population = []
        theta_broadcast = np.broadcast_to(self.theta, (n_subjects, len(self.theta)))

        for i in range(n_subjects):
            individual_params = theta_broadcast[i].copy()

            # Apply random effects
            for j in range(min(eta_samples.shape[1], 10)):
                if j < len(individual_params):
                    individual_params[j] *= np.exp(eta_samples[i, j])

            virtual_population.append({
                'subject_id': i,
                'bw': body_weights[i],
                'comed': comed_status[i],
                'params': individual_params,
                'eta': eta_samples[i]
            })

        return virtual_population

    def simulate_batch(self, batch_subjects, dose_mg, dosing_interval_h,
                      simulation_days=28, steady_state_days=21):
        """
        Simulate a batch of subjects in parallel
        """
        def simulate_single(subject):
            dt = 0.5
            total_hours = simulation_days * 24
            times = np.arange(0, total_hours + dt, dt)

            dose_times = np.arange(0, total_hours, dosing_interval_h)
            doses = np.full(len(dose_times), dose_mg)

            try:
                conc, response = self.model.simulate_individual_fast(
                    subject['params'], times, doses, dose_times,
                    subject['bw'], subject['comed']
                )

                if not (np.isnan(conc).all() or np.isnan(response).all()):
                    steady_start_idx = int(steady_state_days * 24 / dt)
                    steady_response = response[steady_start_idx:]
                    steady_conc = conc[steady_start_idx:]

                    # Target achievement logic
                    if dosing_interval_h == 24:
                        interval_size = int(24/dt)
                    else:
                        interval_size = int(168/dt)

                    target_achieved = True
                    for start in range(0, len(steady_response), interval_size):
                        end = min(start + interval_size, len(steady_response))
                        interval_response = steady_response[start:end]
                        if len(interval_response) > 0 and not np.all(interval_response < 3.3):
                            target_achieved = False
                            break

                    return {
                        'subject_id': subject['subject_id'],
                        'bw': subject['bw'],
                        'comed': subject['comed'],
                        'target_achieved': target_achieved,
                        'min_response': np.min(steady_response),
                        'mean_response': np.mean(steady_response),
                        'max_conc': np.max(steady_conc)
                    }
                else:
                    return None
            except:
                return None

        # Use ThreadPoolExecutor for I/O bound simulation
        with ThreadPoolExecutor(max_workers=min(len(batch_subjects), mp.cpu_count())) as executor:
            results = list(executor.map(simulate_single, batch_subjects))

        return [r for r in results if r is not None]

    def simulate_dosing_regimen(self, virtual_population, dose_mg, dosing_interval_h,
                               simulation_days=28, steady_state_days=21):
        """
        Simulate one dosing regimen for a whole virtual population.

        The population is cut into consecutive chunks of ``self.batch_size``
        subjects and each chunk is handed to ``simulate_batch``. Elapsed time
        and the successful/failed subject counts are added to the global
        ``perf_monitor``.

        Args:
            virtual_population: List of subject dicts.
            dose_mg: Dose per administration.
            dosing_interval_h: Hours between administrations.
            simulation_days: Total simulated duration in days.
            steady_state_days: Day from which the response is evaluated.

        Returns:
            pandas.DataFrame: One row per successfully simulated subject. The
            frame is empty when no subject could be simulated.
        """
        started_at = time.time()
        n_total = len(virtual_population)

        print(f"Simulating {dose_mg} mg every {dosing_interval_h}h for {n_total} subjects...")
        print(f"Using batch processing with {self.batch_size} subjects per batch")

        collected = []
        chunk_offsets = range(0, n_total, self.batch_size)

        # tqdm shows one progress tick per chunk.
        for offset in tqdm(chunk_offsets, desc="Batch Progress"):
            chunk = virtual_population[offset:offset + self.batch_size]
            collected += self.simulate_batch(
                chunk, dose_mg, dosing_interval_h,
                simulation_days, steady_state_days
            )

        n_ok = len(collected)
        perf_monitor.log_metric('simulation_time', time.time() - started_at, increment=True)
        perf_monitor.log_metric('successful_subjects', n_ok, increment=True)
        perf_monitor.log_metric('failed_subjects', n_total - n_ok, increment=True)

        print(f"Successfully simulated {n_ok}/{n_total} subjects")

        return pd.DataFrame(collected)

    def _administered_dose(self, daily_equivalent, dosing_interval):
        """Convert a daily-equivalent dose into the rounded dose per administration."""
        if dosing_interval == 24:
            # Daily dosing: nearest 0.5 mg.
            return round(float(daily_equivalent) * 2) / 2
        # Any other interval is simulated as 7x the daily equivalent, so the
        # reported dose is that weekly total, rounded to the nearest 5 mg.
        return round(float(daily_equivalent) * 7 / 5) * 5

    def find_optimal_dose_adaptive(self, target_achievement=0.9, dose_range=(0.5, 20),
                                  dosing_interval=24, n_subjects=5000, **population_kwargs):
        """
        Adaptive dose finding with Gaussian Process surrogate.

        The search runs on a daily-equivalent dose scale bounded by
        ``dose_range``. With ``dosing_interval == 24`` a candidate is simulated
        as is; with any other interval it is multiplied by 7 before simulation.
        Five evenly spaced doses are evaluated first. Then up to 10 further
        doses are chosen by an upper-confidence-bound score that is penalised
        by the predicted distance from ``target_achievement``.

        Args:
            target_achievement: Required fraction of subjects reaching target.
            dose_range: ``(low, high)`` bounds of the daily-equivalent dose.
            dosing_interval: Hours between administrations.
            n_subjects: Size of the virtual population (generated once).
            **population_kwargs: Passed to ``generate_virtual_population_fast``.

        Returns:
            The dose per administration in mg. Daily dosing is rounded to the
            nearest 0.5 mg. Other intervals return the 7x total rounded to the
            nearest 5 mg, which is the unit the weekly callers expect.

        Raises:
            ValueError: If no candidate dose produced any simulation result.
        """
        print(f"\nAdaptive dose optimization for {target_achievement*100}% target achievement")
        print(f"Dosing interval: {dosing_interval} hours")

        # Generate virtual population once so all doses are compared on the
        # same subjects.
        virtual_pop = self.generate_virtual_population_fast(n_subjects, **population_kwargs)

        # Gaussian Process surrogate of achievement as a function of dose.
        kernel = Matern(length_scale=1.0, nu=2.5) + WhiteKernel(noise_level=0.01)
        gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=2)

        dose_history = []
        achievement_history = []

        def evaluate(dose):
            """Simulate one candidate, record it, and return its achievement (or None)."""
            test_dose = dose if dosing_interval == 24 else dose * 7
            results = self.simulate_dosing_regimen(virtual_pop, test_dose, dosing_interval)
            if len(results) == 0:
                return None
            achievement = results['target_achieved'].mean()
            dose_history.append(dose)
            achievement_history.append(achievement)
            return achievement

        def refit():
            """Refit the surrogate on everything evaluated so far."""
            gp.fit(np.array(dose_history).reshape(-1, 1), np.array(achievement_history))

        # Initial evaluations on an even grid.
        for dose in np.linspace(dose_range[0], dose_range[1], 5):
            achievement = evaluate(dose)
            if achievement is not None:
                print(f"  Dose {dose:.1f} mg: {achievement:.1%} achievement")

        # The surrogate is only used once at least three points exist.
        can_adapt = len(dose_history) >= 3
        if can_adapt:
            refit()

        test_doses = np.linspace(dose_range[0], dose_range[1], 100)
        beta = 2.0  # Exploration weight of the UCB term

        for iteration in range(10):  # Maximum 10 adaptive iterations
            if not can_adapt:
                break

            mean_pred, std_pred = gp.predict(test_doses.reshape(-1, 1), return_std=True)

            # UCB score, pulled towards doses predicted to sit near the target.
            acquisition = mean_pred + beta * std_pred
            acquisition_adjusted = acquisition - np.abs(mean_pred - target_achievement)
            next_dose = test_doses[np.argmax(acquisition_adjusted)]

            # Stop once the proposal is within 0.2 mg of an evaluated dose.
            if np.min(np.abs(np.array(dose_history) - next_dose)) < 0.2:
                break

            achievement = evaluate(next_dose)
            if achievement is None:
                continue

            print(f"  Adaptive iteration {iteration+1}: Dose {next_dose:.1f} mg: {achievement:.1%}")

            # Converged: achievement within 2 percentage points of the target.
            if abs(achievement - target_achievement) < 0.02:
                print(f"  Converged at dose {next_dose:.1f} mg")
                return self._administered_dose(next_dose, dosing_interval)

            refit()

        if not dose_history:
            raise ValueError(
                "No candidate dose could be simulated; cannot select an optimal dose"
            )

        # Pick from the evaluated doses: the smallest one that meets the
        # target, otherwise the one with the highest achievement.
        achievement_array = np.array(achievement_history)
        target_mask = achievement_array >= target_achievement

        if np.any(target_mask):
            optimal_dose = np.min(np.array(dose_history)[target_mask])
        else:
            optimal_dose = dose_history[np.argmax(achievement_array)]

        return self._administered_dose(optimal_dose, dosing_interval)

    # Keep original method for backward compatibility
    def find_optimal_dose(self, target_achievement=0.9, dose_range=(0.5, 20),
                         dosing_interval=24, n_subjects=5000, **population_kwargs):
        return self.find_optimal_dose_adaptive(target_achievement, dose_range,
                                              dosing_interval, n_subjects, **population_kwargs)

    def generate_virtual_population(self, n_subjects, **kwargs):
        """Backward compatibility method"""
        return self.generate_virtual_population_fast(n_subjects, **kwargs)

def run_complete_simulation_enhanced(df):
    """
    Fit the PK/PD model with SAEM, then search optimal doses for every scenario.

    Steps, in order: reset the global ``perf_monitor``, fit ``SAEM_Estimator``
    on ``df``, build a ``PopulationSimulator`` from the fitted parameters, run
    the adaptive dose search for the 90% and 75% targets (daily and weekly,
    three populations each), print the summary tables and the performance
    report.

    Args:
        df: Dataset handed to ``estimator.load_data``.

    Returns:
        tuple: ``(results, estimator, simulator)`` where ``results`` maps keys
        of the form ``'{daily|weekly}_{90|75}_{original|heavy|no_comed}'`` to
        the selected dose.
    """
    perf_monitor.reset()

    print("STARTING ENHANCED PKPD SIMULATION")
    print(f"CPU Cores Available: {mp.cpu_count()}")
    print(f"GPU Available: {USE_GPU}")
    print("Neural ODE Support: Available" if TORCH_AVAILABLE
          else "Neural ODE Support: Not available")

    # Model and estimator (SAEM rather than FOCE).
    model = PK_PD_Model()
    estimator = SAEM_Estimator(model)
    estimator.load_data(df)

    # Optional neural surrogate of the ODE solver.
    if model.neural_solver.available:
        print("\nTraining Neural ODE surrogate...")
        model.neural_solver.train_on_traditional_solver()

    print("\nFITTING NLME MODEL WITH SAEM...")
    # The fitted values are read back from the estimator's attributes below.
    _fitted_theta, _fitted_ll = estimator.fit()

    simulator = PopulationSimulator(model, estimator.theta, estimator.omega, estimator.sigma)

    print("ENHANCED DOSE OPTIMIZATION RESULTS")

    results = {}

    def optimise(target, interval_h, bw_range, comed_prob, n_subjects):
        """Run one adaptive dose search on the fitted simulator."""
        return simulator.find_optimal_dose_adaptive(
            target_achievement=target, dosing_interval=interval_h,
            bw_range=bw_range, comed_prob=comed_prob, n_subjects=n_subjects
        )

    # Scenarios 1 and 2: original population, reported one regimen at a time.
    print("\n1. Daily dose for 90% target achievement (original population):")
    daily_original = optimise(0.9, 24, (50, 100), 0.5, 3000)
    results['daily_90_original'] = daily_original
    print(f"   Optimal daily dose: {daily_original:.1f} mg")

    print("\n2. Weekly dose for 90% target achievement (original population):")
    weekly_original = optimise(0.9, 168, (50, 100), 0.5, 3000)
    results['weekly_90_original'] = weekly_original
    print(f"   Optimal weekly dose: {weekly_original:.0f} mg")

    # Scenarios 3 and 4: both regimens are searched before either is printed.
    paired_scenarios = [
        ("\n3. Effect of changed body weight distribution (70-140 kg):",
         'heavy', (70, 140), 0.5, "heavy population"),
        ("\n4. Effect of restricting concomitant medication:",
         'no_comed', (50, 100), 0.0, "no COMED"),
    ]
    for heading, key, bw_range, comed_prob, label in paired_scenarios:
        print(heading)
        daily_dose = optimise(0.9, 24, bw_range, comed_prob, 3000)
        weekly_dose = optimise(0.9, 168, bw_range, comed_prob, 3000)
        results[f'daily_90_{key}'] = daily_dose
        results[f'weekly_90_{key}'] = weekly_dose
        print(f"   Daily dose ({label}): {daily_dose:.1f} mg")
        print(f"   Weekly dose ({label}): {weekly_dose:.0f} mg")

    print("\n5. Doses for 75% target achievement:")

    # 75% target: the same three populations, run one after another with
    # 2000 subjects each.
    scenarios_75 = [
        ('original', (50, 100), 0.5),
        ('heavy', (70, 140), 0.5),
        ('no_comed', (50, 100), 0.0)
    ]
    for scenario_name, bw_range, comed_prob in scenarios_75:
        daily_dose_75 = optimise(0.75, 24, bw_range, comed_prob, 2000)
        weekly_dose_75 = optimise(0.75, 168, bw_range, comed_prob, 2000)

        results[f'daily_75_{scenario_name}'] = daily_dose_75
        results[f'weekly_75_{scenario_name}'] = weekly_dose_75

        print(f"   {scenario_name.replace('_', ' ').title()} population:")
        print(f"     Daily (75%): {daily_dose_75:.1f} mg")
        print(f"     Weekly (75%): {weekly_dose_75:.0f} mg")

    # Summary table.
    print("\n" + "="*70)
    print("ENHANCED SUMMARY OF OPTIMAL DOSES")
    print("="*70)

    scenarios = ['original', 'heavy', 'no_comed']
    scenario_names = ['Original Pop', 'Heavy Pop (70-140kg)', 'No COMED']

    summary_data = []
    for scenario, name in zip(scenarios, scenario_names):
        for pct in (90, 75):
            daily = results[f'daily_{pct}_{scenario}']
            weekly = results[f'weekly_{pct}_{scenario}']
            summary_data.append({
                'Scenario': f'{name} ({pct}%)',
                'Daily (mg)': daily,
                'Weekly (mg)': weekly,
                # Weekly dose relative to seven daily doses.
                'Daily vs Weekly Ratio': weekly / (daily * 7)
            })

    summary_df = pd.DataFrame(summary_data)
    print(summary_df.to_string(index=False, float_format='%.2f'))

    # Dose saved by relaxing the target from 90% to 75%.
    print("\nENHANCED DOSE REDUCTION ANALYSIS (75% vs 90% achievement):")
    for scenario, name in zip(scenarios, scenario_names):
        daily_90 = results[f'daily_90_{scenario}']
        weekly_90 = results[f'weekly_90_{scenario}']

        daily_reduction = daily_90 - results[f'daily_75_{scenario}']
        weekly_reduction = weekly_90 - results[f'weekly_75_{scenario}']

        daily_pct_reduction = (daily_reduction / daily_90) * 100
        weekly_pct_reduction = (weekly_reduction / weekly_90) * 100

        print(f"   {name}:")
        print(f"     Daily: -{daily_reduction:.1f} mg ({daily_pct_reduction:.1f}% reduction)")
        print(f"     Weekly: -{weekly_reduction:.0f} mg ({weekly_pct_reduction:.1f}% reduction)")
        print(f"     Efficiency gain: {weekly_pct_reduction - daily_pct_reduction:.1f}% better with weekly dosing")

    perf_monitor.report()

    return results, estimator, simulator

def create_enhanced_diagnostic_plots(estimator, simulator, results):
    """
    Create enhanced diagnostic plots with performance insights.

    Draws a 3x3 matplotlib figure (eight panels used, one hidden) and shows it.

    Args:
        estimator: Fitted estimator; ``theta``, ``omega`` and ``sigma`` are read.
        simulator: ``PopulationSimulator`` used to generate and simulate the
            populations shown in the panels; ``simulator.model.solution_cache``
            is read for the metrics panel.
        results: Dict of optimal doses. The keys ``daily_90_*`` and
            ``weekly_90_*`` are looked up with a default of 0.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    print("\nGENERATING ENHANCED DIAGNOSTIC PLOTS...")

    fig, axes = plt.subplots(3, 3, figsize=(20, 16))

    # Generate comprehensive simulation data
    virtual_pop = simulator.generate_virtual_population_fast(1500, seed=42)
    sim_results_daily = simulator.simulate_dosing_regimen(virtual_pop, 5.0, 24)
    sim_results_weekly = simulator.simulate_dosing_regimen(virtual_pop, 35.0, 168)

    # Plot 1: Enhanced Parameter Estimates with Confidence Intervals
    ax1 = axes[0, 0]
    param_names = ['CL', 'V1', 'Q', 'V2', 'KA', 'KE0', 'IMAX', 'IC50', 'KIN', 'KOUT']
    param_values = estimator.theta[:10]
    # NOTE: the error bar is 10% of the random-effect SD, a placeholder and
    # not a standard error derived from the fit.
    param_se = np.sqrt(np.diag(estimator.omega))[:10] * 0.1

    ax1.bar(range(len(param_names)), param_values,
            yerr=param_se, capsize=5, color='skyblue', alpha=0.7)
    ax1.set_xlabel('Parameters')
    ax1.set_ylabel('Estimated Values ± SE')
    ax1.set_title('Population Parameter Estimates with Uncertainty')
    ax1.set_xticks(range(len(param_names)))
    ax1.set_xticklabels(param_names, rotation=45)
    ax1.grid(True, alpha=0.3)

    # Plot 2: Inter-Individual Variability Heatmap
    ax2 = axes[0, 1]
    omega_matrix = estimator.omega[:len(param_names), :len(param_names)]
    # omega is already a covariance matrix, so the correlation is
    # cov_ij / (sd_i * sd_j). np.corrcoef would instead treat each row of
    # omega as a series of observations.
    omega_sd = np.sqrt(np.diag(omega_matrix))
    sd_product = omega_sd[:, None] * omega_sd[None, :]
    correlation_matrix = np.array(omega_matrix, dtype=float)
    defined = sd_product > 0
    correlation_matrix[defined] /= sd_product[defined]
    # Correlation is undefined where either variance is zero.
    correlation_matrix[~defined] = np.nan

    im = ax2.imshow(correlation_matrix, cmap='RdYlBu_r', vmin=-1, vmax=1)
    ax2.set_xticks(range(len(param_names)))
    ax2.set_yticks(range(len(param_names)))
    ax2.set_xticklabels(param_names, rotation=45)
    ax2.set_yticklabels(param_names)
    ax2.set_title('Parameter Correlation Matrix')
    plt.colorbar(im, ax=ax2)

    # Plot 3: Target Achievement vs Body Weight (Both Dosing Regimens)
    ax3 = axes[0, 2]
    if len(sim_results_daily) > 0 and len(sim_results_weekly) > 0:
        bw_bins = np.linspace(50, 100, 6)
        daily_achievement = []
        weekly_achievement = []

        for i in range(len(bw_bins)-1):
            mask_daily = (sim_results_daily['bw'] >= bw_bins[i]) & (sim_results_daily['bw'] < bw_bins[i+1])
            mask_weekly = (sim_results_weekly['bw'] >= bw_bins[i]) & (sim_results_weekly['bw'] < bw_bins[i+1])

            daily_achievement.append(sim_results_daily[mask_daily]['target_achieved'].mean() * 100)
            weekly_achievement.append(sim_results_weekly[mask_weekly]['target_achieved'].mean() * 100)

        bw_centers = (bw_bins[:-1] + bw_bins[1:]) / 2
        ax3.plot(bw_centers, daily_achievement, 'o-', label='Daily', linewidth=2, markersize=6)
        ax3.plot(bw_centers, weekly_achievement, 's-', label='Weekly', linewidth=2, markersize=6)
        ax3.axhline(y=90, color='red', linestyle='--', alpha=0.7, label='90% Target')
        ax3.set_xlabel('Body Weight (kg)')
        ax3.set_ylabel('Achievement Rate (%)')
        ax3.set_title('Target Achievement vs Body Weight')
        ax3.grid(True, alpha=0.3)
        ax3.legend()

    # Plot 4: Advanced Dose-Response Surface
    ax4 = axes[1, 0]
    dose_range = np.linspace(1, 12, 8)
    bw_range = np.linspace(50, 100, 6)
    achievement_surface = np.zeros((len(bw_range), len(dose_range)))

    for i, bw_center in enumerate(bw_range):
        for j, dose in enumerate(dose_range):
            # 200 subjects in a +/-5 kg band, with a distinct seed per cell.
            test_pop = simulator.generate_virtual_population_fast(
                200, bw_range=(bw_center-5, bw_center+5), seed=42+i*10+j
            )
            test_results = simulator.simulate_dosing_regimen(test_pop, dose, 24)
            if len(test_results) > 0:
                achievement_surface[i, j] = test_results['target_achieved'].mean() * 100

    im4 = ax4.contourf(dose_range, bw_range, achievement_surface,
                       levels=np.linspace(0, 100, 11), cmap='viridis')
    ax4.contour(dose_range, bw_range, achievement_surface,
                levels=[75, 90], colors=['orange', 'red'], linewidths=2)
    ax4.set_xlabel('Daily Dose (mg)')
    ax4.set_ylabel('Body Weight (kg)')
    ax4.set_title('Achievement Rate Surface (%)')
    plt.colorbar(im4, ax=ax4)

    # Plot 5: COMED Effect with Statistical Significance
    ax5 = axes[1, 1]
    if len(sim_results_daily) > 0:
        comed_groups = sim_results_daily.groupby('comed')['target_achieved']
        comed_means = comed_groups.mean() * 100
        comed_stds = comed_groups.std() * 100
        comed_counts = comed_groups.count()
        comed_se = comed_stds / np.sqrt(comed_counts)

        # Labels and colours follow the groups actually present, so a
        # population with a single COMED status still plots.
        comed_labels = ['COMED' if flag else 'No COMED' for flag in comed_means.index]
        comed_colors = ['orange' if flag else 'blue' for flag in comed_means.index]
        ax5.bar(comed_labels, comed_means.values,
                yerr=comed_se.values, capsize=5,
                color=comed_colors, alpha=0.7)
        ax5.set_ylabel('Achievement Rate (%) ± SE')
        ax5.set_title('Effect of Concomitant Medication')
        ax5.axhline(y=90, color='red', linestyle='--', alpha=0.7)
        ax5.grid(True, alpha=0.3)

        # Add significance test result
        from scipy.stats import ttest_ind
        group0 = sim_results_daily[sim_results_daily['comed']==0]['target_achieved']
        group1 = sim_results_daily[sim_results_daily['comed']==1]['target_achieved']
        if len(group0) > 0 and len(group1) > 0:
            t_stat, p_val = ttest_ind(group0, group1)
            ax5.text(0.5, max(comed_means) + 5, f'p = {p_val:.3f}',
                    ha='center', transform=ax5.transData)

    # Plot 6: Performance Metrics Dashboard
    ax6 = axes[1, 2]
    metrics_names = ['Successful\nSubjects', 'ODE\nEvaluations', 'Cache\nHits', 'Memory\n(MB)']
    metrics_values = [
        perf_monitor.metrics['successful_subjects'],
        perf_monitor.metrics['ode_evaluations'],
        # NOTE: this is the number of cached entries, not a hit counter.
        len(simulator.model.solution_cache),
        perf_monitor.metrics['memory_peak_mb']
    ]

    # Normalize values for display, relative to the largest metric.
    # Fall back to 1 so an all-zero set does not divide by zero.
    largest_metric = max(metrics_values)
    metric_scale = largest_metric if largest_metric > 0 else 1
    normalized_values = [v / metric_scale * 100 for v in metrics_values]
    bars6 = ax6.bar(metrics_names, normalized_values,
                   color=['green', 'blue', 'purple', 'red'], alpha=0.7)
    ax6.set_ylabel('Normalized Performance Score')
    ax6.set_title('Performance Metrics Dashboard')

    # Add actual values as text
    for bar, actual in zip(bars6, metrics_values):
        height = bar.get_height()
        ax6.text(bar.get_x() + bar.get_width()/2, height + 2,
                f'{actual:.0f}', ha='center', va='bottom', fontsize=9)

    # Plot 7: Residual Error Analysis
    ax7 = axes[2, 0]
    error_types = ['PK Error\n(Log-normal)', 'PD Error\n(Proportional)']
    error_values = [estimator.sigma[0] * 100, estimator.sigma[1] * 100]
    error_colors = ['purple', 'brown']

    bars7 = ax7.bar(error_types, error_values, color=error_colors, alpha=0.7)
    ax7.set_ylabel('Error Magnitude (%)')
    ax7.set_title('Residual Error Model Parameters')
    ax7.grid(True, alpha=0.3)

    for bar, val in zip(bars7, error_values):
        ax7.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                f'{val:.1f}%', ha='center', va='bottom', fontsize=10)

    # The centre cell of the bottom row has no plot; hide its empty frame.
    axes[2, 1].axis('off')

    # Plot 8: Dosing Efficiency Comparison
    ax9 = axes[2, 2]
    scenarios = ['Original', 'Heavy BW', 'No COMED']
    daily_doses = [
        results.get('daily_90_original', 0),
        results.get('daily_90_heavy', 0),
        results.get('daily_90_no_comed', 0)
    ]
    weekly_doses = [
        results.get('weekly_90_original', 0) / 7,  # Convert to daily equivalent
        results.get('weekly_90_heavy', 0) / 7,
        results.get('weekly_90_no_comed', 0) / 7
    ]

    x = np.arange(len(scenarios))
    width = 0.35

    ax9.bar(x - width/2, daily_doses, width, label='Daily', alpha=0.7)
    ax9.bar(x + width/2, weekly_doses, width, label='Weekly (equiv.)', alpha=0.7)

    ax9.set_xlabel('Population Scenario')
    ax9.set_ylabel('Daily Dose Equivalent (mg)')
    ax9.set_title('Dosing Efficiency: Daily vs Weekly')
    ax9.set_xticks(x)
    ax9.set_xticklabels(scenarios)
    ax9.legend()
    ax9.grid(True, alpha=0.3)

    # Add efficiency percentages (skipped when no daily dose is available)
    for i, (daily, weekly) in enumerate(zip(daily_doses, weekly_doses)):
        if daily > 0:
            efficiency = (daily - weekly) / daily * 100
            ax9.text(i, max(daily, weekly) + 0.1, f'{efficiency:.1f}%\nsavings',
                    ha='center', va='bottom', fontsize=8)

    plt.tight_layout()
    plt.show()


# Run the full estimation and dose-optimisation pipeline, then plot it.
# create_enhanced_diagnostic_plots takes three required arguments; `results`
# feeds its dosing-efficiency panel.
results, estimator, simulator = run_complete_simulation_enhanced(df)
create_enhanced_diagnostic_plots(estimator, simulator, results)

Quantum-Enhanced Implementation

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.integrate import solve_ivp
from scipy.optimize import minimize, differential_evolution
from scipy.stats import multivariate_normal, norm
from scipy.interpolate import interp1d
import warnings
warnings.filterwarnings('ignore')

# Quantum computing imports
import pennylane as qml
from pennylane import numpy as qnp
import jax
import jax.numpy as jnp
from jax import jit, vmap
from functools import partial

# Performance imports
from numba import jit as numba_jit
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import multiprocessing as mp
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern, WhiteKernel
from tqdm import tqdm
import time
import psutil

# Plotting defaults for this cell. Statement order is kept: the matplotlib
# style is selected first, then the seaborn palette and the figure size are
# applied on top of it.
plt.style.use('default')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 8)

# Quantum device shared by the circuits defined below.
n_qubits = 12  # 4 compartments x 3 qubits each (2**3 = 8 levels per compartment)
dev = qml.device('default.qubit', wires=n_qubits)

# Start-up banner.
print("Quantum-Enhanced PK/PD Framework Initialized")
print(f"Quantum Device: {n_qubits} qubits for compartment encoding")

class PerformanceMonitor:
    """
    Collect timing, memory and counter metrics for a run and print a report.

    Metrics live in the ``metrics`` dict. ``reset`` defines the standard keys,
    and ``log_metric`` may add further keys.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Restart the wall clock and set every standard metric back to zero."""
        self.start_time = time.time()
        self.metrics = {
            'total_time': 0,
            'ode_solve_time': 0,
            'optimization_time': 0,
            'simulation_time': 0,
            'memory_peak_mb': 0,
            'cpu_count': mp.cpu_count(),
            'successful_subjects': 0,
            'failed_subjects': 0,
            'ode_evaluations': 0,
            'convergence_iterations': 0
        }

    def log_metric(self, key, value, increment=False):
        """
        Record a metric value.

        Args:
            key: Metric name.
            value: Value to store, or amount to add when ``increment`` is true.
            increment: Add ``value`` to the current value instead of replacing
                it. A key that has not been recorded yet starts from 0.
        """
        if increment:
            self.metrics[key] = self.metrics.get(key, 0) + value
        else:
            self.metrics[key] = value

    def get_memory_usage(self):
        """Return the resident set size of the current process in MB."""
        return psutil.Process().memory_info().rss / 1024 / 1024

    def report(self):
        """Update the time and memory metrics and print the full report."""
        self.metrics['total_time'] = time.time() - self.start_time
        # Keep the highest RSS seen at any report since the last reset, so a
        # later report cannot lower the recorded peak.
        self.metrics['memory_peak_mb'] = max(
            self.metrics.get('memory_peak_mb', 0), self.get_memory_usage()
        )

        print("\n" + "="*60)
        print("PERFORMANCE REPORT")
        print(f"Total Runtime: {self.metrics['total_time']:.2f} seconds")
        print(f"ODE Solving: {self.metrics['ode_solve_time']:.2f} seconds")
        print(f"Optimization: {self.metrics['optimization_time']:.2f} seconds")
        print(f"Simulation: {self.metrics['simulation_time']:.2f} seconds")
        print(f"Peak Memory: {self.metrics['memory_peak_mb']:.1f} MB")
        print(f"CPU Cores Used: {self.metrics['cpu_count']}")
        print(f"Successful Subjects: {self.metrics['successful_subjects']}")
        print(f"Failed Subjects: {self.metrics['failed_subjects']}")
        print(f"ODE Evaluations: {self.metrics['ode_evaluations']}")
        print(f"Optimization Iterations: {self.metrics['convergence_iterations']}")

        # The rate is only printed when at least one subject succeeded.
        succeeded = self.metrics['successful_subjects']
        if succeeded > 0:
            success_rate = succeeded / (succeeded + self.metrics['failed_subjects']) * 100
            print(f"Success Rate: {success_rate:.1f}%")

# Module-level monitor instance. The estimator, simulator and pipeline code in
# this file reach it through the global name `perf_monitor`.
perf_monitor = PerformanceMonitor()

class QuantumCompartmentSimulator:
    """
    Circuit-based surrogate for the four PK/PD compartments (A1, A2, AE, R).

    Each compartment is assigned a register of ``n_qubits_per_comp`` qubits
    (``d_levels`` truncated occupation levels).  Classical rate constants are
    converted into small rotation angles of a PennyLane circuit, and the
    measured basis-state probabilities are mapped back to one occupation
    value per compartment.

    The circuit uses the module-level PennyLane device ``dev`` and qubit count
    ``n_qubits``, which are defined outside this class.
    """

    def __init__(self):
        self.d_levels = 8  # Truncated bosonic levels per compartment
        self.n_qubits_per_comp = 3  # log2(8) = 3 qubits per compartment
        self.compartments = ['A1', 'A2', 'AE', 'R']  # 4 compartments
        self.use_quantum = True

    def create_lindblad_operators(self, params, bw, comed):
        """
        Map the classical PK/PD parameters to the rate constants used by the circuit.

        Despite the name, this returns plain rate constants, not operators.

        Args:
            params: parameter vector ordered as CL, V1, Q, V2, KA, KE0, IMAX,
                IC50, KIN, KOUT, CLBW, V1BW, CLCOMED, KINCOMED.  Only CL, V1,
                Q, V2, KE0, KOUT, CLBW, V1BW and CLCOMED are read here.
            bw: body weight; covariate effects are scaled relative to 70.
            comed: co-medication indicator multiplying CLCOMED.

        Returns:
            dict with keys ``'k12'``, ``'k21'``, ``'k_el_1'``, ``'KE0'`` and
            ``'KOUT'``.
        """
        CL, V1, Q, V2 = params[0], params[1], params[2], params[3]
        KE0, KOUT = params[5], params[9]
        CLBW, V1BW, CLCOMED = params[10], params[11], params[12]

        # Covariate effects (same formulas as the classical ODE system)
        CL_i = CL * ((bw/70.0)**CLBW) * (1.0 + CLCOMED * comed)
        V1_i = V1 * ((bw/70.0)**V1BW)

        return {
            'k12': Q / V1_i,        # A1 -> A2 flow
            'k21': Q / V2,          # A2 -> A1 flow
            'k_el_1': CL_i / V1_i,  # A1 elimination
            'KE0': KE0,             # A1 <-> AE coupling
            'KOUT': KOUT,           # R elimination
        }

    def quantum_evolution_circuit(self, params, rates, dt, dose_rate):
        """
        Build and run the circuit for one time step and return its probabilities.

        The register is always initialised to the all-zero basis state, so
        every call is independent of previous calls.  Each rate ``k`` enters
        as a rotation angle ``sqrt(k * dt) * 0.01``.

        Args:
            params: accepted for interface compatibility; not used here.
            rates: dict produced by ``create_lindblad_operators``.
            dt: time-step length.
            dose_rate: dosing input for this step; rotations are applied to
                the A1 register only when it is positive.

        Returns:
            The result of ``qml.probs`` over all ``n_qubits`` wires.
        """
        reg = self.n_qubits_per_comp
        a2_offset, ae_offset, r_offset = reg, 2 * reg, 3 * reg
        angle_scale = 0.01

        def angle(rate):
            return np.sqrt(rate * dt) * angle_scale

        @qml.qnode(dev)
        def circuit():
            # Initialize state (ground state for empty compartments)
            qml.BasisState([0]*n_qubits, wires=range(n_qubits))

            # Apply dosing to the A1 register
            if dose_rate > 0:
                for qubit in range(reg):
                    qml.RY(angle(dose_rate), wires=qubit)

            # Inter-compartment transfers (A1 <-> A2)
            for i in range(reg):
                qml.CRY(angle(rates['k12']), wires=[i, i + a2_offset])
                qml.CRY(angle(rates['k21']), wires=[i + a2_offset, i])

            # A1 <-> AE coupling
            for i in range(reg):
                qml.CRY(angle(rates['KE0']), wires=[i, i + ae_offset])
                qml.CRY(angle(rates['KE0']), wires=[i + ae_offset, i])

            # Elimination from A1 and R (negative rotation angles)
            for qubit in range(reg):
                qml.RY(-angle(rates['k_el_1']), wires=qubit)
                qml.RY(-angle(rates['KOUT']), wires=qubit + r_offset)

            return qml.probs(wires=range(n_qubits))

        return circuit()

    def measure_concentrations(self, quantum_probs):
        """
        Convert basis-state probabilities into one occupation value per compartment.

        For compartment ``c`` the register is qubits
        ``c * n_qubits_per_comp ... (c + 1) * n_qubits_per_comp - 1`` and bit
        ``i`` of the occupation level is read from qubit ``start + i``.  The
        expected level is ``sum_i 2**i * P(qubit start+i = 1)`` and the
        returned value is that expectation divided by ``d_levels``.

        Args:
            quantum_probs: flat probability vector of length ``2**n`` in which
                wire 0 is the most significant bit of the basis-state index.
                # NOTE(review): this is PennyLane's documented ordering for
                # qml.probs; confirm if another backend is ever used.

        Returns:
            list with one non-negative value per entry of ``self.compartments``.

        Raises:
            ValueError: if the vector length is not a power of two or it
                describes fewer qubits than the compartment registers need.
        """
        probs = np.asarray(quantum_probs, dtype=float).ravel()
        reg = self.n_qubits_per_comp
        n_comp = len(self.compartments)

        n_total = int(round(np.log2(probs.size))) if probs.size else 0
        if probs.size != 2 ** n_total or n_total < n_comp * reg:
            raise ValueError(
                f"expected probabilities over at least {n_comp * reg} qubits, "
                f"got a vector of length {probs.size}"
            )

        # One axis per qubit; axis q corresponds to wire q.
        tensor = probs.reshape((2,) * n_total)

        concentrations = []
        for comp_idx in range(n_comp):
            start_qubit = comp_idx * reg
            expected_level = 0.0
            for bit in range(reg):
                # Marginal probability that this qubit is measured as 1.
                p_one = tensor.take(1, axis=start_qubit + bit).sum()
                expected_level += (1 << bit) * p_one
            concentrations.append(max(0, expected_level / self.d_levels))

        return concentrations

    def simulate_quantum_pkpd(self, params, times, doses, dose_times, bw, comed, baseline_R=8.0):
        """
        Evaluate the circuit at every time point and return (concentration, response).

        Args:
            params: PK/PD parameter vector (see ``create_lindblad_operators``).
            times: evaluation times; the step length is ``times[1] - times[0]``
                (0.1 when only one time is given).
            doses, dose_times: dosing amounts and times, linearly interpolated
                to obtain the dose rate at each evaluation time.
            bw, comed: covariates.
            baseline_R: accepted for interface compatibility; not used here.

        Returns:
            Two arrays (central concentration, response), or ``(None, None)``
            if anything in the quantum path raises, so the caller can fall
            back to the classical solver.
        """
        try:
            rates = self.create_lindblad_operators(params, bw, comed)

            dt = times[1] - times[0] if len(times) > 1 else 0.1
            if len(dose_times) > 0:
                dose_interp = interp1d(dose_times, doses, kind='linear', bounds_error=False, fill_value=0)
            else:
                def dose_interp(t):
                    return 0

            # Individual central volume; constant over the time loop.
            V1_i = params[1] * ((bw/70)**params[11])

            concentrations_over_time = []
            response_over_time = []

            for t in times:
                dose_rate = dose_interp(t)
                quantum_probs = self.quantum_evolution_circuit(params, rates, dt, dose_rate)
                measured = self.measure_concentrations(quantum_probs)

                conc_central = measured[0] / V1_i  # A1 occupation scaled by volume
                response = measured[3]             # R occupation

                concentrations_over_time.append(max(0, conc_central))
                response_over_time.append(max(0, response))

            return np.array(concentrations_over_time), np.array(response_over_time)

        except Exception:
            # Best-effort backend: any failure means "use the classical solver".
            return None, None

# JAX-optimized classical fallback (same ODE system as original)
@jit
def pk_pd_system_jax(t, y, dose_rate, params, bw, comed):
    """Right-hand side of the four-state PK/PD ODE system, JIT-compiled with JAX.

    ``y`` holds the amounts (A1, A2, AE, R); ``params`` is ordered as CL, V1,
    Q, V2, KA, KE0, IMAX, IC50, KIN, KOUT, CLBW, V1BW, CLCOMED, KINCOMED.
    ``t`` is part of the ODE signature but does not enter the equations.
    Returns the four time derivatives as a ``jnp`` array.
    """
    central, peripheral, effect, resp = y

    # Structural parameters
    cl_pop, v1_pop, q_flow, v2, ka = params[0], params[1], params[2], params[3], params[4]
    ke0, imax, ic50, kin_pop, kout = params[5], params[6], params[7], params[8], params[9]
    # Covariate coefficients
    cl_bw_exp, v1_bw_exp, cl_comed, kin_comed = params[10], params[11], params[12], params[13]

    # Individual parameters after body-weight and co-medication adjustment
    cl_ind = cl_pop * ((bw/70.0)**cl_bw_exp) * (1.0 + cl_comed * comed)
    v1_ind = v1_pop * ((bw/70.0)**v1_bw_exp)
    kin_ind = kin_pop * (1.0 + kin_comed * comed)

    effect_conc = effect / v1_ind

    d_central = ka * dose_rate - (cl_ind/v1_ind + q_flow/v1_ind) * central + (q_flow/v2) * peripheral
    d_peripheral = (q_flow/v1_ind) * central - (q_flow/v2) * peripheral
    d_effect = ke0 * central - ke0 * effect

    inhibition = (imax * effect_conc) / (ic50 + effect_conc)
    d_resp = kin_ind * (1.0 - inhibition) - kout * resp

    return jnp.array([d_central, d_peripheral, d_effect, d_resp])

@numba_jit(nopython=True)
def pk_pd_system_numba(t, y, dose_rate, params, bw, comed):
    """Right-hand side of the four-state PK/PD ODE system, compiled with numba.

    ``y`` holds the amounts (A1, A2, AE, R); ``params`` is ordered as CL, V1,
    Q, V2, KA, KE0, IMAX, IC50, KIN, KOUT, CLBW, V1BW, CLCOMED, KINCOMED.
    ``t`` is part of the ODE signature but does not enter the equations.
    Returns the four time derivatives as a NumPy array.
    """
    central, peripheral, effect, resp = y

    # Structural parameters
    cl_pop, v1_pop, q_flow, v2, ka = params[0], params[1], params[2], params[3], params[4]
    ke0, imax, ic50, kin_pop, kout = params[5], params[6], params[7], params[8], params[9]
    # Covariate coefficients
    cl_bw_exp, v1_bw_exp, cl_comed, kin_comed = params[10], params[11], params[12], params[13]

    # Individual parameters after body-weight and co-medication adjustment
    cl_ind = cl_pop * ((bw/70.0)**cl_bw_exp) * (1.0 + cl_comed * comed)
    v1_ind = v1_pop * ((bw/70.0)**v1_bw_exp)
    kin_ind = kin_pop * (1.0 + kin_comed * comed)

    effect_conc = effect / v1_ind

    d_central = ka * dose_rate - (cl_ind/v1_ind + q_flow/v1_ind) * central + (q_flow/v2) * peripheral
    d_peripheral = (q_flow/v1_ind) * central - (q_flow/v2) * peripheral
    d_effect = ke0 * central - ke0 * effect

    inhibition = (imax * effect_conc) / (ic50 + effect_conc)
    d_resp = kin_ind * (1.0 - inhibition) - kout * resp

    return np.array([d_central, d_peripheral, d_effect, d_resp])

class PK_PD_Model:
    """
    Four-state (A1, A2, AE, R) PK/PD model with an optional quantum backend.

    ``simulate_individual_fast`` first tries ``QuantumCompartmentSimulator``
    and falls back to the numba-compiled ODE system integrated by
    ``solve_ivp``.  Solutions are memoised in ``solution_cache`` and timings
    and counters are logged to the module-level ``perf_monitor``.
    """

    # Upper bound on the number of memoised solutions.
    _MAX_CACHE_ENTRIES = 10000

    def __init__(self):
        self.pk_params = ['CL', 'V1', 'Q', 'V2', 'KA']
        self.pd_params = ['KE0', 'IMAX', 'IC50', 'KIN', 'KOUT']
        self.covariate_params = ['CLBW', 'V1BW', 'CLCOMED', 'KINCOMED']
        self.error_params = ['SIGMA_PK', 'SIGMA_PD']

        self.all_params = (self.pk_params + self.pd_params +
                          self.covariate_params + self.error_params)

        # Quantum backend tried before the classical solver
        self.quantum_simulator = QuantumCompartmentSimulator()

        # Memoised solutions and the relative tolerance for time-grid matching
        self.solution_cache = {}
        self.cache_tolerance = 1e-2

    def pk_pd_system(self, t, y, dose_func, params, bw, comed):
        """Evaluate the ODE right-hand side at ``t`` and count the evaluation."""
        perf_monitor.log_metric('ode_evaluations', 1, increment=True)
        dose_rate = dose_func(t)
        return pk_pd_system_numba(t, y, dose_rate, params, bw, comed)

    def get_cache_key(self, params, bw, comed, doses, dose_times, baseline_R=None):
        """
        Build a hashable cache key from rounded simulation inputs.

        Parameters are rounded to 3 decimals, body weight, doses and dose
        times to 1 decimal, so nearly identical requests share one entry.
        When ``baseline_R`` is given it is appended (rounded to 3 decimals)
        so runs with different baselines never share a cached solution;
        without it the original 5-element key is returned.
        """
        key = (
            tuple(np.round(params, 3)),
            round(bw, 1),
            int(comed),
            tuple(np.round(doses, 1)),
            tuple(np.round(dose_times, 1)),
        )
        if baseline_R is not None:
            key += (round(float(baseline_R), 3),)
        return key

    def _store_in_cache(self, cache_key, times, conc, response):
        """Memoise copies of a solution unless the cache is already full."""
        if len(self.solution_cache) < self._MAX_CACHE_ENTRIES:
            self.solution_cache[cache_key] = (np.array(times), conc.copy(), response.copy())

    def _solve_classical(self, params, times, doses, dose_times, bw, comed, baseline_R):
        """
        Integrate the ODE system; return (concentration, R) or (None, None).

        Each dose is placed on the grid cell ``int((t_dose - times[0]) / dt)``
        as a rate ``dose / dt`` and the rate profile is linearly interpolated.
        NOTE(review): ``dt`` is taken from the first two time points, so this
        placement assumes an evenly spaced grid.
        """
        n_points = len(times)
        dt = times[1] - times[0] if n_points > 1 else 0.1

        dose_rates = np.zeros(n_points)
        for amount, when in zip(doses, dose_times):
            time_idx = int((when - times[0]) / dt)
            if 0 <= time_idx < n_points:
                # Accumulate so doses landing on the same cell are not lost.
                dose_rates[time_idx] += amount / dt

        dose_interp = interp1d(times, dose_rates, kind='linear', bounds_error=False, fill_value=0)

        sol = solve_ivp(
            lambda t, y: self.pk_pd_system(t, y, dose_interp, params, bw, comed),
            [times[0], times[-1]], [0, 0, 0, baseline_R], t_eval=times,
            method='DOP853',
            rtol=1e-4, atol=1e-7,
            max_step=1.0
        )
        if not sol.success:
            return None, None

        central_amount, response = sol.y[0], sol.y[3]
        V1_i = params[1] * (bw/70)**params[11]
        return central_amount / V1_i, response

    def simulate_individual_fast(self, params, times, doses, dose_times, bw, comed, baseline_R=8.0):
        """
        Simulate one subject and return (central concentration, response).

        Order of attempts: cached solution, quantum backend (when enabled),
        classical ODE solve.  If the classical solve fails or raises, two
        float arrays filled with NaN (same shape as ``times``) are returned
        and the subject is counted as failed.

        Args:
            params: PK/PD parameter vector (14 structural/covariate entries).
            times: array of evaluation times.
            doses, dose_times: dose amounts and their administration times.
            bw, comed: covariates.
            baseline_R: initial value of the response state R.
        """
        start_time = time.time()

        def record_elapsed():
            perf_monitor.log_metric('ode_solve_time', time.time() - start_time, increment=True)

        cache_key = self.get_cache_key(params, bw, comed, doses, dose_times, baseline_R)
        cached = self.solution_cache.get(cache_key)
        if cached is not None:
            cached_times, cached_conc, cached_response = cached
            # Compare shapes first: np.allclose raises on non-broadcastable grids.
            same_grid = (
                np.shape(cached_times) == np.shape(times)
                and np.allclose(times, cached_times, rtol=self.cache_tolerance)
            )
            if same_grid:
                conc = np.interp(times, cached_times, cached_conc)
                response = np.interp(times, cached_times, cached_response)
                record_elapsed()
                return conc, response

        # Try quantum simulation first
        if self.quantum_simulator.use_quantum:
            conc, response = self.quantum_simulator.simulate_quantum_pkpd(
                params, times, doses, dose_times, bw, comed, baseline_R
            )
            if conc is not None and response is not None:
                self._store_in_cache(cache_key, times, conc, response)
                record_elapsed()
                perf_monitor.log_metric('successful_subjects', 1, increment=True)
                return conc, response

        # Classical fallback
        try:
            concentrations, response = self._solve_classical(
                params, times, doses, dose_times, bw, comed, baseline_R
            )
        except Exception:
            # Setup or solver errors are reported to the caller as NaN output.
            concentrations, response = None, None

        record_elapsed()

        if concentrations is None:
            perf_monitor.log_metric('failed_subjects', 1, increment=True)
            # Explicit float arrays: full_like on an integer grid cannot hold NaN.
            shape = np.shape(times)
            return np.full(shape, np.nan), np.full(shape, np.nan)

        self._store_in_cache(cache_key, times, concentrations, response)
        perf_monitor.log_metric('successful_subjects', 1, increment=True)
        return concentrations, response

    # Keep original method for compatibility
    def simulate_individual(self, params, times, doses, dose_times, bw, comed, baseline_R=8.0):
        """Alias of ``simulate_individual_fast`` kept for older callers."""
        return self.simulate_individual_fast(params, times, doses, dose_times, bw, comed, baseline_R)

class QuantumEnhancedSAEM:
    """
    SAEM-style estimator with a PennyLane sampling hook in the MCMC step.

    NOTE(review): in the code below only ``omega`` is updated between
    iterations (see ``update_parameters``); ``theta`` and ``sigma`` keep the
    values set at the start of ``fit``, so ``fit`` returns the initial theta.
    """

    # Number of Metropolis steps run per subject in each SAEM iteration.
    _MCMC_STEPS_PER_SUBJECT = 3

    def __init__(self, model):
        self.model = model
        self.data = None
        self.theta = None
        self.omega = None
        self.sigma = None

        # SAEM settings (n_chains is stored but not read by this class)
        self.n_burn_in = 4
        self.n_iterations = 20
        self.n_chains = 4
        self.step_size = 1.0

        # Device for the parameter-sampling circuit
        self.n_param_qubits = 6
        self.param_dev = qml.device('default.qubit', wires=self.n_param_qubits)

    def load_data(self, df):
        """
        Store the dataset and pre-split it by subject.

        PK/PD observation frames keep rows with EVID == 0 and MDV == 0 and
        DVID 1 (PK) or 2 (PD).  ``subject_data_dict`` maps each ID to a copy
        of all of that subject's rows.
        """
        self.data = df.copy()
        is_observation = (df['EVID'] == 0) & (df['MDV'] == 0)
        self.pk_data = df[(df['DVID'] == 1) & is_observation].copy()
        self.pd_data = df[(df['DVID'] == 2) & is_observation].copy()

        self.unique_subjects = df['ID'].unique()
        self.subject_data_dict = {
            subject_id: df[df['ID'] == subject_id].copy()
            for subject_id in self.unique_subjects
        }

        print(f"Loaded {len(self.pk_data)} PK observations and {len(self.pd_data)} PD observations")
        print(f"Preprocessed {len(self.unique_subjects)} subjects")

    def quantum_parameter_sampler(self, phi_params):
        """
        Run the variational circuit and return its basis-state probabilities.

        Args:
            phi_params: at least ``2 * n_param_qubits`` rotation angles; the
                first half feeds the first RY layer, the second half the RY
                layer after the CNOT chain.
        """
        @qml.qnode(self.param_dev)
        def circuit(phi):
            # First rotation layer
            for i in range(self.n_param_qubits):
                qml.RY(phi[i], wires=i)

            # Entangling layer
            for i in range(self.n_param_qubits - 1):
                qml.CNOT(wires=[i, i + 1])

            # Second rotation layer
            for i in range(self.n_param_qubits):
                qml.RY(phi[i + self.n_param_qubits], wires=i)

            return qml.probs(wires=range(self.n_param_qubits))

        return circuit(phi_params)

    def individual_likelihood_fast(self, eta_i, theta, omega, sigma, subject_data):
        """
        Log-likelihood of one subject's data given random effects ``eta_i``.

        The first ``min(len(eta_i), 10)`` parameters are scaled by
        ``exp(eta)``.  PK observations use a log-normal error model with
        ``sigma[0]``; PD observations use a proportional normal error with
        ``sigma[1]``.  Only points where prediction and observation are
        positive and the prediction is finite contribute.  A Gaussian prior
        term for ``eta_i`` under ``omega`` is added.

        Returns:
            The log-likelihood, or ``-inf`` when the subject has no PK/PD
            rows, the simulation produced no finite prediction at all, or any
            step raised.
        """
        params = theta.copy()
        for idx in range(min(len(eta_i), 10)):
            params[idx] *= np.exp(eta_i[idx])

        bw = subject_data['BW'].iloc[0]
        comed = subject_data['COMED'].iloc[0]

        pk_obs = subject_data[subject_data['DVID'] == 1]
        pd_obs = subject_data[subject_data['DVID'] == 2]

        if len(pk_obs) == 0 and len(pd_obs) == 0:
            return -np.inf

        dosing_rows = subject_data[subject_data['EVID'] == 1]
        doses = dosing_rows['AMT'].values
        dose_times = dosing_rows['TIME'].values

        all_times = np.sort(subject_data['TIME'].unique())

        try:
            conc_pred, response_pred = self.model.simulate_individual_fast(
                params, all_times, doses, dose_times, bw, comed
            )

            # A failed simulation returns all-NaN arrays.  Without this check
            # every data term would be skipped and the proposal would be
            # scored by the prior alone, i.e. rewarded for failing.
            if not (np.isfinite(conc_pred).any() or np.isfinite(response_pred).any()):
                return -np.inf

            log_likelihood = 0

            if len(pk_obs) > 0:
                pk_times = pk_obs['TIME'].values
                pk_observed = pk_obs['DV'].values

                pk_predicted = np.interp(pk_times, all_times, conc_pred)
                valid_idx = (pk_predicted > 0) & (pk_observed > 0) & np.isfinite(pk_predicted)

                if np.sum(valid_idx) > 0:
                    pk_residuals = np.log(pk_observed[valid_idx]) - np.log(pk_predicted[valid_idx])
                    log_likelihood += -0.5 * np.sum(pk_residuals**2 / sigma[0]**2)
                    log_likelihood += -0.5 * len(pk_residuals) * np.log(2 * np.pi * sigma[0]**2)
                    # Jacobian of the log transform
                    log_likelihood += -np.sum(np.log(pk_observed[valid_idx]))

            if len(pd_obs) > 0:
                pd_times = pd_obs['TIME'].values
                pd_observed = pd_obs['DV'].values

                pd_predicted = np.interp(pd_times, all_times, response_pred)
                valid_idx = (pd_predicted > 0) & (pd_observed > 0) & np.isfinite(pd_predicted)

                if np.sum(valid_idx) > 0:
                    pd_residuals = pd_observed[valid_idx] - pd_predicted[valid_idx]
                    pd_variance = (sigma[1] * pd_predicted[valid_idx])**2
                    log_likelihood += -0.5 * np.sum(pd_residuals**2 / pd_variance)
                    log_likelihood += -0.5 * np.sum(np.log(2 * np.pi * pd_variance))

            # Prior for random effects
            if len(eta_i) > 0:
                log_likelihood += -0.5 * eta_i.T @ np.linalg.solve(omega, eta_i)

        except Exception:
            return -np.inf

        return log_likelihood

    def quantum_mcmc_step(self, eta, theta, omega, sigma, subject_data, step_size):
        """
        One random-walk Metropolis step for a subject's random effects.

        The sampling circuit is run on a best-effort basis before the
        proposal is drawn.  The proposal itself is always
        ``eta + Normal(0, step_size)``.

        Returns:
            ``(eta, log_likelihood)`` of the accepted state.
        """
        try:
            phi_params = np.random.uniform(0, 2*np.pi, 2 * self.n_param_qubits)

            # quantum_parameter_sampler builds and executes its own QNode, so
            # it is called directly.  Wrapping it in a second qml.QNode makes
            # the outer quantum function return an array instead of a
            # measurement, which PennyLane rejects.
            probs = np.asarray(self.quantum_parameter_sampler(phi_params), dtype=float)
            probs = probs / probs.sum()

            # NOTE(review): the sampled basis state is drawn but does not yet
            # influence the proposal below.
            sample_idx = np.random.choice(len(probs), p=probs)
        except Exception:
            # Quantum sampling is optional; the classical proposal still runs.
            sample_idx = None

        eta_prop = eta + np.random.normal(0, step_size, len(eta))

        # Classical Metropolis acceptance
        ll_current = self.individual_likelihood_fast(eta, theta, omega, sigma, subject_data)
        ll_prop = self.individual_likelihood_fast(eta_prop, theta, omega, sigma, subject_data)

        alpha = min(1, np.exp(ll_prop - ll_current))

        if np.random.random() < alpha:
            return eta_prop, ll_prop
        return eta, ll_current

    def mcmc_step(self, eta, theta, omega, sigma, subject_data, step_size):
        """Alias of ``quantum_mcmc_step`` kept for the original interface."""
        return self.quantum_mcmc_step(eta, theta, omega, sigma, subject_data, step_size)

    def saem_iteration(self, theta, omega, sigma):
        """
        Draw one set of random effects per subject.

        Each subject starts from ``Normal(0, 0.1 * omega)`` and runs a few
        Metropolis steps; subjects are processed on a thread pool.

        Returns:
            dict mapping subject ID to its final ``eta`` (empty when no
            subjects are loaded).
        """
        if len(self.unique_subjects) == 0:
            # ThreadPoolExecutor rejects max_workers=0.
            return {}

        def process_subject(subject_id):
            subject_data = self.subject_data_dict[subject_id]
            eta = np.random.multivariate_normal(np.zeros(len(omega)), omega * 0.1)

            for _ in range(self._MCMC_STEPS_PER_SUBJECT):
                eta, _ = self.quantum_mcmc_step(eta, theta, omega, sigma, subject_data, self.step_size)

            return subject_id, eta

        n_workers = min(mp.cpu_count(), len(self.unique_subjects))
        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            results = list(executor.map(process_subject, self.unique_subjects))

        return dict(results)

    def update_parameters(self, eta_estimates, iteration):
        """
        Stochastic-approximation update of ``omega`` from the sampled etas.

        ``omega`` moves towards the empirical covariance of the etas with
        step ``gamma = min(1, 10 / (iteration + 10))`` and is then projected
        to have eigenvalues of at least 1e-6.  With fewer than two subjects
        the empirical covariance is undefined, so ``omega`` is left unchanged.
        """
        gamma = min(1.0, 10.0 / (iteration + 10))
        eta_values = np.array(list(eta_estimates.values()))

        if len(eta_values) < 2:
            return

        empirical_cov = np.atleast_2d(np.cov(eta_values.T))
        self.omega = (1 - gamma) * self.omega + gamma * empirical_cov

        # Ensure positive definiteness
        eigenvals, eigenvecs = np.linalg.eigh(self.omega)
        eigenvals = np.maximum(eigenvals, 1e-6)
        self.omega = eigenvecs @ np.diag(eigenvals) @ eigenvecs.T

    def fit(self, initial_params=None):
        """
        Run the SAEM loop and return ``(theta, best_average_log_likelihood)``.

        The first ``n_burn_in`` iterations only sample; every later iteration
        also updates ``omega``.  Every 50th iteration the average
        log-likelihood over up to 50 subjects is evaluated and the best value
        is tracked.

        Args:
            initial_params: optional vector of the 14 structural/covariate
                parameters; defaults are used when omitted.
        """
        start_time = time.time()

        if initial_params is None:
            initial_params = np.array([
                2.0, 10.0, 1.0, 20.0, 0.5,  # PK
                0.1, 0.8, 2.0, 5.0, 0.1,    # PD
                0.75, 1.0, 0.1, 0.1          # Covariates
            ])

        # Float copy so the in-place exp(eta) scaling also works for integer input.
        self.theta = np.array(initial_params, dtype=float)
        self.omega = np.eye(10) * 0.1
        self.sigma = np.array([0.2, 0.15])

        print("Starting SAEM parameter estimation...")
        print(f"Burn-in: {self.n_burn_in} iterations")
        print(f"Estimation: {self.n_iterations - self.n_burn_in} iterations")

        best_ll = -np.inf
        best_params = self.theta.copy()

        for iteration in tqdm(range(self.n_iterations), desc="SAEM Progress"):
            eta_estimates = self.saem_iteration(self.theta, self.omega, self.sigma)

            # Iterations 0 .. n_burn_in-1 are burn-in, matching the printed schedule.
            if iteration >= self.n_burn_in:
                self.update_parameters(eta_estimates, iteration - self.n_burn_in)

            # Monitor convergence
            if iteration % 50 == 0:
                total_ll = 0
                n_successful = 0

                for subject_id in list(self.unique_subjects)[:min(50, len(self.unique_subjects))]:
                    try:
                        eta = eta_estimates.get(subject_id, np.zeros(len(self.omega)))
                        ll = self.individual_likelihood_fast(
                            eta, self.theta, self.omega, self.sigma,
                            self.subject_data_dict[subject_id]
                        )
                    except Exception:
                        # Skip subjects whose likelihood cannot be evaluated.
                        continue
                    if np.isfinite(ll):
                        total_ll += ll
                        n_successful += 1

                if n_successful > 0:
                    avg_ll = total_ll / n_successful
                    if avg_ll > best_ll:
                        best_ll = avg_ll
                        best_params = self.theta.copy()

                    print(f"Iteration {iteration}: Avg LL = {avg_ll:.2f}, Best = {best_ll:.2f}")

            perf_monitor.log_metric('convergence_iterations', 1, increment=True)

        self.theta = best_params
        perf_monitor.log_metric('optimization_time', time.time() - start_time, increment=True)

        print(f"\nSAEM converged after {self.n_iterations} iterations")
        print("Final parameter estimates:")
        # all_params ends with the two error parameters, which theta does not hold.
        for param_name, value in zip(self.model.all_params[:-2], self.theta):
            print(f"  {param_name}: {value:.4f}")

        return self.theta, best_ll

# Backward-compatible aliases: both names refer to the QuantumEnhancedSAEM
# class itself, so code written against SAEM_Estimator or NLME_Estimator
# keeps working.
SAEM_Estimator = QuantumEnhancedSAEM
NLME_Estimator = QuantumEnhancedSAEM

class PopulationSimulator:
    """
    Simulate virtual populations under a dosing regimen and search for a dose.

    Holds a model plus the population parameters ``theta`` (typical values),
    ``omega`` (random-effect covariance) and ``sigma``.  A subject counts as
    reaching the target when its response stays below ``RESPONSE_TARGET``
    throughout the steady-state window.
    """

    # Response threshold that must not be reached during steady state.
    RESPONSE_TARGET = 3.3

    def __init__(self, model, theta, omega, sigma):
        self.model = model
        self.theta = theta
        self.omega = omega
        self.sigma = sigma
        self.batch_size = 60

    @staticmethod
    def _administered_dose(dose, dosing_interval):
        """Convert a search-scale dose to the amount given per administration.

        Daily regimens (24 h) use the dose as is; any other interval is
        treated as weekly and multiplies by 7.
        """
        return dose if dosing_interval == 24 else dose * 7

    @staticmethod
    def _round_dose(dose_mg, dosing_interval):
        """Round to 0.5 mg steps for daily regimens, 5 mg steps otherwise."""
        if dosing_interval == 24:
            return round(dose_mg * 2) / 2
        return round(dose_mg / 5) * 5

    def generate_virtual_population_fast(self, n_subjects, bw_range=(50, 100), comed_prob=0.5, seed=None):
        """
        Draw covariates and individual parameters for ``n_subjects`` subjects.

        Body weight is uniform on ``bw_range``, co-medication is Bernoulli
        with ``comed_prob``, and random effects are multivariate normal with
        covariance ``omega`` (independent Normal(0, 0.1) draws if that
        sampling raises ``LinAlgError``).  The leading parameters (at most
        10) are scaled by ``exp(eta)``.

        Args:
            seed: when not ``None`` (0 included), seeds NumPy's global RNG
                first so the population is reproducible.

        Returns:
            list of dicts with keys ``subject_id``, ``bw``, ``comed``,
            ``params`` and ``eta``.
        """
        if seed is not None:
            np.random.seed(seed)

        body_weights = np.random.uniform(bw_range[0], bw_range[1], n_subjects)
        comed_status = np.random.binomial(1, comed_prob, n_subjects)

        n_eta = len(self.omega)
        try:
            eta_samples = np.random.multivariate_normal(
                np.zeros(n_eta), self.omega, size=n_subjects
            )
        except np.linalg.LinAlgError:
            eta_samples = np.random.normal(0, 0.1, (n_subjects, n_eta))

        # One row of typical values per subject, scaled by exp(eta) where
        # a random effect exists.
        param_matrix = np.tile(np.asarray(self.theta, dtype=float), (n_subjects, 1))
        n_scaled = min(eta_samples.shape[1], 10, param_matrix.shape[1])
        param_matrix[:, :n_scaled] *= np.exp(eta_samples[:, :n_scaled])

        return [
            {
                'subject_id': i,
                'bw': body_weights[i],
                'comed': comed_status[i],
                'params': param_matrix[i],
                'eta': eta_samples[i],
            }
            for i in range(n_subjects)
        ]

    def simulate_batch(self, batch_subjects, dose_mg, dosing_interval_h, simulation_days=28, steady_state_days=21):
        """
        Simulate one batch of subjects on a thread pool.

        Each subject is simulated on an hourly grid for ``simulation_days``
        with ``dose_mg`` given every ``dosing_interval_h`` hours; summary
        statistics are taken from ``steady_state_days`` onwards.

        Returns:
            list of per-subject result dicts; subjects whose simulation
            failed or raised are omitted.  An empty batch yields ``[]``.
        """
        if len(batch_subjects) == 0:
            # ThreadPoolExecutor rejects max_workers=0.
            return []

        def simulate_single(subject):
            dt = 1.0
            total_hours = simulation_days * 24
            times = np.arange(0, total_hours + dt, dt)

            dose_times = np.arange(0, total_hours, dosing_interval_h)
            doses = np.full(len(dose_times), dose_mg)

            try:
                conc, response = self.model.simulate_individual_fast(
                    subject['params'], times, doses, dose_times,
                    subject['bw'], subject['comed']
                )

                if np.isnan(conc).all() or np.isnan(response).all():
                    return None

                steady_start_idx = int(steady_state_days * 24 / dt)
                steady_response = response[steady_start_idx:]
                steady_conc = conc[steady_start_idx:]

                # "Below target in every dosing interval" is the same as
                # "below target at every steady-state point".
                target_achieved = bool(np.all(steady_response < self.RESPONSE_TARGET))

                return {
                    'subject_id': subject['subject_id'],
                    'bw': subject['bw'],
                    'comed': subject['comed'],
                    'target_achieved': target_achieved,
                    'min_response': np.min(steady_response),
                    'mean_response': np.mean(steady_response),
                    'max_conc': np.max(steady_conc)
                }
            except Exception:
                # A subject that cannot be simulated or summarised is dropped.
                return None

        with ThreadPoolExecutor(max_workers=min(len(batch_subjects), mp.cpu_count())) as executor:
            results = list(executor.map(simulate_single, batch_subjects))

        return [r for r in results if r is not None]

    def simulate_dosing_regimen(self, virtual_population, dose_mg, dosing_interval_h, simulation_days=28, steady_state_days=21):
        """
        Simulate the whole population in batches of ``batch_size``.

        Logs simulation time and success/failure counts to ``perf_monitor``.

        Returns:
            DataFrame with one row per successfully simulated subject (empty
            when none succeeded).
        """
        start_time = time.time()
        n_total = len(virtual_population)

        print(f"Simulating {dose_mg} mg every {dosing_interval_h}h for {n_total} subjects...")
        print(f"Using batch processing with {self.batch_size} subjects per batch")

        all_results = []
        for batch_start in tqdm(range(0, n_total, self.batch_size), desc="Batch Progress"):
            batch_subjects = virtual_population[batch_start:batch_start + self.batch_size]
            all_results.extend(
                self.simulate_batch(
                    batch_subjects, dose_mg, dosing_interval_h,
                    simulation_days, steady_state_days
                )
            )

        perf_monitor.log_metric('simulation_time', time.time() - start_time, increment=True)
        perf_monitor.log_metric('successful_subjects', len(all_results), increment=True)
        perf_monitor.log_metric('failed_subjects', n_total - len(all_results), increment=True)

        print(f"Successfully simulated {len(all_results)}/{n_total} subjects")

        return pd.DataFrame(all_results)

    def find_optimal_dose_adaptive(self, target_achievement=0.9, dose_range=(0.5, 20), dosing_interval=24, n_subjects=5000, **population_kwargs):
        """
        Search for the smallest dose reaching ``target_achievement``.

        Five evenly spaced doses in ``dose_range`` are simulated first; a
        Gaussian-process surrogate then proposes up to 10 further doses.
        ``dose_range`` is on the daily scale; for non-daily intervals each
        candidate is multiplied by 7 before it is administered.

        Returns:
            The selected dose in administered units (per dose actually
            given), rounded to 0.5 mg for daily and 5 mg for other intervals.

        Raises:
            ValueError: if no candidate dose could be simulated.
        """
        print(f"\nAdaptive dose optimization for {target_achievement*100}% target achievement")
        print(f"Dosing interval: {dosing_interval} hours")

        virtual_pop = self.generate_virtual_population_fast(n_subjects, **population_kwargs)

        kernel = Matern(length_scale=1.0, nu=2.5) + WhiteKernel(noise_level=0.01)
        gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=2)

        dose_history = []
        achievement_history = []

        def evaluate(dose):
            """Simulate one candidate; return (administered dose, attainment or None)."""
            administered = self._administered_dose(dose, dosing_interval)
            results = self.simulate_dosing_regimen(virtual_pop, administered, dosing_interval)
            if len(results) == 0:
                return administered, None
            achievement = results['target_achieved'].mean()
            dose_history.append(dose)
            achievement_history.append(achievement)
            return administered, achievement

        def refit():
            gp.fit(np.array(dose_history).reshape(-1, 1), np.array(achievement_history))

        for dose in np.linspace(dose_range[0], dose_range[1], 5):
            administered, achievement = evaluate(dose)
            if achievement is not None:
                print(f"  Dose {administered:.1f} mg: {achievement:.1%} achievement")

        if len(dose_history) >= 3:
            refit()

        # Adaptive refinement with the GP surrogate
        for iteration in range(10):
            if len(dose_history) < 3:
                break

            test_doses = np.linspace(dose_range[0], dose_range[1], 100)
            mean_pred, std_pred = gp.predict(test_doses.reshape(-1, 1), return_std=True)

            beta = 2.0
            acquisition = mean_pred + beta * std_pred
            # Penalise candidates predicted to be far from the target.
            acquisition_adjusted = acquisition - np.abs(mean_pred - target_achievement)

            next_dose = test_doses[np.argmax(acquisition_adjusted)]

            # Stop when the proposal is within 0.2 of a dose already tried.
            if min(np.abs(np.array(dose_history) - next_dose)) < 0.2:
                break

            administered, achievement = evaluate(next_dose)
            if achievement is None:
                continue

            print(f"  Adaptive iteration {iteration+1}: Dose {administered:.1f} mg: {achievement:.1%}")
            refit()

            if abs(achievement - target_achievement) < 0.02:
                print(f"  Converged at dose {administered:.1f} mg")
                # Round the administered amount, not the daily-scale search value.
                return self._round_dose(administered, dosing_interval)

        if not dose_history:
            raise ValueError("no candidate dose could be simulated; cannot select an optimal dose")

        achievement_array = np.array(achievement_history)
        target_mask = achievement_array >= target_achievement

        if np.any(target_mask):
            optimal_dose = np.min(np.array(dose_history)[target_mask])
        else:
            optimal_dose = dose_history[np.argmax(achievement_array)]

        return self._round_dose(
            self._administered_dose(optimal_dose, dosing_interval), dosing_interval
        )

    # Keep original method names for compatibility
    def find_optimal_dose(self, target_achievement=0.9, dose_range=(0.5, 20), dosing_interval=24, n_subjects=5000, **population_kwargs):
        """Alias of ``find_optimal_dose_adaptive`` kept for older callers."""
        return self.find_optimal_dose_adaptive(target_achievement, dose_range, dosing_interval, n_subjects, **population_kwargs)

    def generate_virtual_population(self, n_subjects, **kwargs):
        """Alias of ``generate_virtual_population_fast`` kept for older callers."""
        return self.generate_virtual_population_fast(n_subjects, **kwargs)

def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when ``denominator`` is zero."""
    if denominator == 0:
        return float('nan')
    return numerator / denominator


def run_complete_simulation_enhanced(df, n_subjects=200):
    """
    Fit the PK/PD model with SAEM, then search for optimal daily and weekly doses.

    The model is fitted to ``df`` and a ``PopulationSimulator`` is built from
    the estimator's ``theta``, ``omega`` and ``sigma``. The adaptive dose
    search is then run for two achievement targets (0.9 and 0.75), two dosing
    intervals (24, reported as daily, and 168, reported as weekly) and three
    virtual populations:

    * ``original`` - ``bw_range=(50, 100)``, ``comed_prob=0.5``
    * ``heavy``    - ``bw_range=(70, 140)``, ``comed_prob=0.5``
    * ``no_comed`` - ``bw_range=(50, 100)``, ``comed_prob=0.0``

    Progress, a summary table and a 75%-vs-90% dose reduction analysis are
    printed to stdout.

    Args:
        df: Dataset handed to ``SAEM_Estimator.load_data``.
        n_subjects: Forwarded as ``n_subjects`` to every
            ``find_optimal_dose_adaptive`` call. Defaults to 200, the value
            that was previously hard-coded.

    Returns:
        Tuple ``(results, estimator, simulator)`` where ``results`` maps keys
        of the form ``'{daily|weekly}_{90|75}_{original|heavy|no_comed}'`` to
        the dose returned by the search.
    """
    daily_interval = 24
    weekly_interval = 168

    perf_monitor.reset()

    print("STARTING ENHANCED PKPD SIMULATION")
    print(f"CPU Cores Available: {mp.cpu_count()}")
    print("Quantum Enhancement: PennyLane-based compartment simulation")

    # Initialize enhanced model and estimator (same interface)
    model = PK_PD_Model()
    estimator = SAEM_Estimator(model)
    estimator.load_data(df)

    # Fit model. The fitted values are read back from the estimator's
    # attributes below, so fit()'s return value is not kept.
    print("\nFITTING NLME MODEL WITH SAEM...")
    estimator.fit()

    # Initialize simulator (same interface)
    simulator = PopulationSimulator(model, estimator.theta, estimator.omega, estimator.sigma)

    # Single definition of the virtual populations, shared by the 90% and 75%
    # searches. Insertion order fixes the order of the 75% section.
    populations = {
        'original': {'bw_range': (50, 100), 'comed_prob': 0.5},
        'heavy': {'bw_range': (70, 140), 'comed_prob': 0.5},
        'no_comed': {'bw_range': (50, 100), 'comed_prob': 0.0},
    }

    def optimise(target, interval, scenario):
        # One adaptive search for the given target, interval and population.
        return simulator.find_optimal_dose_adaptive(
            target_achievement=target, dosing_interval=interval,
            n_subjects=n_subjects, **populations[scenario]
        )

    # Dose optimization
    print("ENHANCED DOSE OPTIMIZATION RESULTS")

    results = {}

    print("\n1. Daily dose for 90% target achievement (original population):")
    daily_dose_90 = optimise(0.9, daily_interval, 'original')
    results['daily_90_original'] = daily_dose_90
    print(f"   Optimal daily dose: {daily_dose_90:.1f} mg")

    print("\n2. Weekly dose for 90% target achievement (original population):")
    weekly_dose_90 = optimise(0.9, weekly_interval, 'original')
    results['weekly_90_original'] = weekly_dose_90
    print(f"   Optimal weekly dose: {weekly_dose_90:.0f} mg")

    # Sections 3 and 4 share one shape: both searches, then both report lines.
    comparison_sections = [
        ('heavy', "\n3. Effect of changed body weight distribution (70-140 kg):", 'heavy population'),
        ('no_comed', "\n4. Effect of restricting concomitant medication:", 'no COMED'),
    ]
    for scenario, heading, label in comparison_sections:
        print(heading)
        daily_dose = optimise(0.9, daily_interval, scenario)
        weekly_dose = optimise(0.9, weekly_interval, scenario)
        results[f'daily_90_{scenario}'] = daily_dose
        results[f'weekly_90_{scenario}'] = weekly_dose
        print(f"   Daily dose ({label}): {daily_dose:.1f} mg")
        print(f"   Weekly dose ({label}): {weekly_dose:.0f} mg")

    print("\n5. Doses for 75% target achievement:")

    for scenario_name in populations:
        daily_dose_75 = optimise(0.75, daily_interval, scenario_name)
        weekly_dose_75 = optimise(0.75, weekly_interval, scenario_name)

        results[f'daily_75_{scenario_name}'] = daily_dose_75
        results[f'weekly_75_{scenario_name}'] = weekly_dose_75

        print(f"   {scenario_name.replace('_', ' ').title()} population:")
        print(f"     Daily (75%): {daily_dose_75:.1f} mg")
        print(f"     Weekly (75%): {weekly_dose_75:.0f} mg")

    # Enhanced summary (same format as original)
    print("ENHANCED SUMMARY OF OPTIMAL DOSES")

    scenarios = ['original', 'heavy', 'no_comed']
    scenario_names = ['Original Pop', 'Heavy Pop (70-140kg)', 'No COMED']

    summary_data = []
    for scenario, name in zip(scenarios, scenario_names):
        for level in (90, 75):
            daily_dose = results[f'daily_{level}_{scenario}']
            weekly_dose = results[f'weekly_{level}_{scenario}']
            summary_data.append({
                'Scenario': f'{name} ({level}%)',
                'Daily (mg)': daily_dose,
                'Weekly (mg)': weekly_dose,
                # A dose rounded down to zero would otherwise raise
                # ZeroDivisionError here and discard every result computed so far.
                'Daily vs Weekly Ratio': _safe_ratio(weekly_dose, daily_dose * 7),
            })

    summary_df = pd.DataFrame(summary_data)
    print(summary_df.to_string(index=False, float_format='%.2f'))

    # Dose reduction analysis (same as original)
    print("\nENHANCED DOSE REDUCTION ANALYSIS (75% vs 90% achievement):")
    for scenario, name in zip(scenarios, scenario_names):
        daily_90 = results[f'daily_90_{scenario}']
        daily_75 = results[f'daily_75_{scenario}']
        weekly_90 = results[f'weekly_90_{scenario}']
        weekly_75 = results[f'weekly_75_{scenario}']

        daily_reduction = daily_90 - daily_75
        weekly_reduction = weekly_90 - weekly_75

        # The weekly dose is rounded to a multiple of 5 and can therefore be 0;
        # report NaN instead of crashing in that case.
        daily_pct_reduction = _safe_ratio(daily_reduction, daily_90) * 100
        weekly_pct_reduction = _safe_ratio(weekly_reduction, weekly_90) * 100

        print(f"   {name}:")
        print(f"     Daily: -{daily_reduction:.1f} mg ({daily_pct_reduction:.1f}% reduction)")
        print(f"     Weekly: -{weekly_reduction:.0f} mg ({weekly_pct_reduction:.1f}% reduction)")
        print(f"     Efficiency gain: {weekly_pct_reduction - daily_pct_reduction:.1f}% better with weekly dosing")

    # Performance report
    perf_monitor.report()

    return results, estimator, simulator

# Run the full fit-and-optimise pipeline on the loaded dataset; the returned
# dose results, estimator and simulator stay bound at module level.
results, estimator, simulator = run_complete_simulation_enhanced(df)
# Pass the estimator, simulator and dose results on to the diagnostic plotting routine.
create_enhanced_diagnostic_plots(estimator, simulator, results)