# Generating Data

# In [None]:
# Generating the synthetic dataset for schizophrenia clinical trials

import pandas as pd
import numpy as np
import random
import time
import math
import matplotlib.pyplot as plt
import seaborn as sns

# --- Parameters & Configurations ---

# --- Regional Placebo Effects ---
# Normal(mean, SD) of the placebo-driven change in PANSS Total, selected by the
# patient's region (before per-patient responder scaling).
PLACEBO_PANSS_CHANGE_MEAN_US, PLACEBO_PANSS_CHANGE_STD_US = -15.0, 4.0
PLACEBO_PANSS_CHANGE_MEAN_EU, PLACEBO_PANSS_CHANGE_STD_EU = -4.0, 2.0

# --- Regional Demographic Parameters ---
# Ethnicity label -> (US sampling weight, EU sampling weight). Insertion order
# defines the category order shared by the three lists derived below.
_ETHNICITY_REGION_WEIGHTS = {
    'White': (0.60, 0.88),
    'Black': (0.15, 0.03),
    'Hispanic/Latino': (0.15, 0.01),
    'Asian': (0.06, 0.03),
    'Other/Mixed': (0.04, 0.05),
}
ETHNICITY_CATEGORIES = list(_ETHNICITY_REGION_WEIGHTS)
ETHNICITY_WEIGHTS_US = [us_w for us_w, _eu_w in _ETHNICITY_REGION_WEIGHTS.values()]
ETHNICITY_WEIGHTS_EU = [eu_w for _us_w, eu_w in _ETHNICITY_REGION_WEIGHTS.values()]

# Probability that a patient is flagged with comorbid substance use, by region.
COMORBIDITY_SUBSTANCE_USE_RATE_US, COMORBIDITY_SUBSTANCE_USE_RATE_EU = 0.40, 0.30

# --- Drug Profiles ---
# Per-arm simulation parameters, keyed by the drug names listed in each study's "Drugs".
# Per-unit effects are multiplied by dose_units = Dose_mg_day / unit_mg
# (dose_units is 0 when unit_mg is 0, as for Placebo):
#   effect_per_unit        mean drug change in PANSS Total per unit (negative lowers the score)
#   effect_std_bonus       SD of that PANSS drug effect, scaled by sqrt(max(0.1, dose_units))
#   neg_symptom_bonus      per-unit part of the PANSS drug effect assigned wholly to the
#                          Negative subscale; it is subtracted before the remainder is split
#                          across subscales, so it moves effect between subscales rather than
#                          adding to the total (negative values favour the Negative subscale)
#   cog_effect_per_unit    mean change in cognitive score per unit
#   weight_gain_per_unit, bmi_change_per_unit, glucose_change_per_unit,
#   ldl_change_per_unit, eps_effect_per_unit
#                          mean per-unit safety changes (kg, BMI, glucose mg/dL, LDL mg/dL, SAS)
#   ae_bonus_factor        extra adverse-event count per unit
#   available_doses        dose list used only when the study config gives none for the drug
#   MOA, Drug_Class, Sponsor  copied verbatim into the patient record
DRUG_PROFILES = {
    'Placebo': {
        'effect_per_unit': 0, 'unit_mg': 0, 'effect_std_bonus': 0, 'neg_symptom_bonus': 0,
        'cog_effect_per_unit': 0, 'weight_gain_per_unit': 0, 'bmi_change_per_unit': 0,
        'glucose_change_per_unit': 0, 'ldl_change_per_unit': 0, 'eps_effect_per_unit': 0,
        'ae_bonus_factor': 0,
        'MOA': 'None', 'Drug_Class': 'Placebo', 'Sponsor': 'N/A', 'available_doses': [0] 
    },
    'Drug_A': { # Moderate all-rounder
        'effect_per_unit': -2.5, 'unit_mg': 10, 'effect_std_bonus': 1.0, 'neg_symptom_bonus': 0,
        'cog_effect_per_unit': 0.02, 'weight_gain_per_unit': 0.7, 'bmi_change_per_unit': 0.25,
        'glucose_change_per_unit': 2.0, 'ldl_change_per_unit': 3.0, 'eps_effect_per_unit': 0.05,
        'ae_bonus_factor': 0.05, 'available_doses': [10, 20, 40],
        'MOA': 'D2/5HT2A Antagonist', 'Drug_Class': 'SGA', 'Sponsor': 'PharmaCo A' 
    },
    'Drug_B': { # Stronger per mg, slight negative-symptom focus, some EPS
        'effect_per_unit': -4.0, 'unit_mg': 5, 'effect_std_bonus': 1.5, 'neg_symptom_bonus': -1.0,
        'cog_effect_per_unit': 0.08, 'weight_gain_per_unit': 0.2, 'bmi_change_per_unit': 0.08,
        'glucose_change_per_unit': 0.5, 'ldl_change_per_unit': 0.8, 'eps_effect_per_unit': 0.25,
        'ae_bonus_factor': 0.16, 'available_doses': [5, 10],
        'MOA': 'Potent D2 Antagonist', 'Drug_Class': 'SGA', 'Sponsor': 'PharmaCo B' 
    },
    'Drug_C': { # Negative Symptom Focus, Moderate Efficacy, Metabolic Risk
        'effect_per_unit': -2.0, 'unit_mg': 20, 'effect_std_bonus': 1.2, 'neg_symptom_bonus': -1.5,
        'cog_effect_per_unit': 0.03, 'weight_gain_per_unit': 1.0, 'bmi_change_per_unit': 0.35,
        'glucose_change_per_unit': 2.5, 'ldl_change_per_unit': 4.0, 'eps_effect_per_unit': 0.08,
        'ae_bonus_factor': 0.04, 'available_doses': [20, 40, 60],
        'MOA': '5HT2A/D2 Antagonist', 'Drug_Class': 'SGA', 'Sponsor': 'PharmaCo A' 
    },
    'Drug_D': { # High overall efficacy; neg_symptom_bonus > 0 shifts effect away from the Negative subscale; high EPS risk
        'effect_per_unit': -5.0, 'unit_mg': 10, 'effect_std_bonus': 1.8, 'neg_symptom_bonus': 0.5,
        'cog_effect_per_unit': 0.01, 'weight_gain_per_unit': 0.3, 'bmi_change_per_unit': 0.1,
        'glucose_change_per_unit': 0.8, 'ldl_change_per_unit': 1.0, 'eps_effect_per_unit': 0.40,
        'ae_bonus_factor': 0.10, 'available_doses': [5, 10, 15],
        'MOA': 'Strong D2 Antagonist', 'Drug_Class': 'FGA-like SGA', 'Sponsor': 'PharmaCo C' 
    },
    'Drug_E': { # Lower Efficacy, Good Tolerability, Cognitive Edge
        'effect_per_unit': -1.5, 'unit_mg': 50, 'effect_std_bonus': 0.8, 'neg_symptom_bonus': -0.5,
        'cog_effect_per_unit': 0.12, 'weight_gain_per_unit': 0.1, 'bmi_change_per_unit': 0.04,
        'glucose_change_per_unit': 0.2, 'ldl_change_per_unit': 0.3, 'eps_effect_per_unit': 0.02,
        'ae_bonus_factor': 0.01, 'available_doses': [50, 100, 150],
        'MOA': 'Partial D2 Agonist', 'Drug_Class': 'TGA', 'Sponsor': 'PharmaCo B' 
    },
    'Drug_F': { # Similar Efficacy to A, Less Metabolic, Higher AE count
        'effect_per_unit': -2.6, 'unit_mg': 10, 'effect_std_bonus': 1.1, 'neg_symptom_bonus': -0.2,
        'cog_effect_per_unit': 0.03, 'weight_gain_per_unit': 0.4, 'bmi_change_per_unit': 0.15,
        'glucose_change_per_unit': 1.0, 'ldl_change_per_unit': 1.5,
        'eps_effect_per_unit': 0.06,
        'ae_bonus_factor': 0.08, 'available_doses': [10, 20, 30],
        'MOA': 'D2/5HT2A/5HT1A', 'Drug_Class': 'SGA', 'Sponsor': 'PharmaCo C' 
    },
    'Drug_G': { # Adjunctive candidate: low PANSS effect, largest per-unit cognitive effect, low metabolic/EPS burden
        'effect_per_unit': -1.0, 'unit_mg': 25, 'effect_std_bonus': 0.7, 'neg_symptom_bonus': -1.0,
        'cog_effect_per_unit': 0.15,
        'weight_gain_per_unit': 0.05, 'bmi_change_per_unit': 0.02,
        'glucose_change_per_unit': 0.1, 'ldl_change_per_unit': 0.2,
        'eps_effect_per_unit': 0.01,
        'ae_bonus_factor': 0.02, 'available_doses': [25, 50, 75],
        'MOA': 'Glycine Modulator', 'Drug_Class': 'Adjunctive', 'Sponsor': 'Academia Inc' 
    }
}


# --- Study Configurations ---
# One dict per simulated trial:
#   Study_ID        identifier copied into every patient record of the study
#   N_Patients      number of patient records generated for the study
#   Drugs           treatment arms; each patient is assigned one uniformly at random
#                   (studies without 'Placebo' are active-comparator designs)
#   Doses           drug -> list of mg/day doses, sampled uniformly per patient; this
#                   overrides the drug profile's available_doses (used only as fallback)
#   Percent_US      probability that a patient is assigned to the US region (else EU)
#   Duration_Weeks  copied to Treatment_Duration_Weeks; it does not currently scale any
#                   simulated effect in generate_patient_record
# Several studies use doses outside a profile's available_doses (e.g. Drug_A 5/60,
# Drug_B 2.5, Drug_D 20); simulated effects extrapolate linearly in dose units.
STUDIES_CONFIG = [
    {"Study_ID": "ACUTE_EFF_001", "N_Patients": 800, "Drugs": ['Placebo', 'Drug_A'], "Doses": {'Drug_A': [10, 20, 40]}, "Percent_US": 0.60, "Duration_Weeks": 12},
    {"Study_ID": "NEG_SYM_002", "N_Patients": 750, "Drugs": ['Placebo', 'Drug_B', 'Drug_C'], "Doses": {'Drug_B': [5, 10], 'Drug_C': [20, 40]}, "Percent_US": 0.30, "Duration_Weeks": 12},
    {"Study_ID": "COMPARE_003", "N_Patients": 1200, "Drugs": ['Placebo', 'Drug_A', 'Drug_B'], "Doses": {'Drug_A': [20], 'Drug_B': [10]}, "Percent_US": 0.50, "Duration_Weeks": 10},
    {"Study_ID": "DOSE_FIND_A_004", "N_Patients": 600, "Drugs": ['Placebo', 'Drug_A'], "Doses": {'Drug_A': [5, 10, 20, 40]}, "Percent_US": 0.80, "Duration_Weeks": 8},
    {"Study_ID": "NEG_FOCUS_005", "N_Patients": 900, "Drugs": ['Placebo', 'Drug_C', 'Drug_E'], "Doses": {'Drug_C': [40], 'Drug_E': [100, 150]}, "Percent_US": 0.40, "Duration_Weeks": 16},
    {"Study_ID": "HIGH_EPS_COMP_006", "N_Patients": 1000, "Drugs": ['Placebo', 'Drug_B', 'Drug_D'], "Doses": {'Drug_B': [10], 'Drug_D': [10, 15]}, "Percent_US": 0.55, "Duration_Weeks": 10},
    {"Study_ID": "LOW_DOSE_E_007", "N_Patients": 700, "Drugs": ['Placebo', 'Drug_E'], "Doses": {'Drug_E': [50, 100]}, "Percent_US": 0.70, "Duration_Weeks": 12},
    {"Study_ID": "EU_DRUG_D_008", "N_Patients": 850, "Drugs": ['Placebo', 'Drug_D'], "Doses": {'Drug_D': [5, 10]}, "Percent_US": 0.15, "Duration_Weeks": 8},
    {"Study_ID": "US_DRUG_C_009", "N_Patients": 950, "Drugs": ['Placebo', 'Drug_C'], "Doses": {'Drug_C': [20, 40, 60]}, "Percent_US": 0.85, "Duration_Weeks": 12},
    {"Study_ID": "MIX_ALL_010", "N_Patients": 1500, "Drugs": ['Placebo', 'Drug_A', 'Drug_C', 'Drug_E'], "Doses": {'Drug_A': [20], 'Drug_C': [40], 'Drug_E': [100]}, "Percent_US": 0.50, "Duration_Weeks": 12},
    {"Study_ID": "ACUTE_COMP_011", "N_Patients": 1100, "Drugs": ['Placebo', 'Drug_A', 'Drug_D'], "Doses": {'Drug_A': [40], 'Drug_D': [10]}, "Percent_US": 0.65, "Duration_Weeks": 8},
    {"Study_ID": "MAINTENANCE_012", "N_Patients": 1300, "Drugs": ['Placebo', 'Drug_A', 'Drug_E'], "Doses": {'Drug_A': [20], 'Drug_E': [100]}, "Percent_US": 0.45, "Duration_Weeks": 24},
    {"Study_ID": "EU_MIX_013", "N_Patients": 1000, "Drugs": ['Placebo', 'Drug_B', 'Drug_C'], "Doses": {'Drug_B': [5], 'Drug_C': [20, 40]}, "Percent_US": 0.25, "Duration_Weeks": 10},
    {"Study_ID": "US_MIX_014", "N_Patients": 1200, "Drugs": ['Placebo', 'Drug_A', 'Drug_E'], "Doses": {'Drug_A': [10, 20], 'Drug_E': [50, 100]}, "Percent_US": 0.75, "Duration_Weeks": 12},
    {"Study_ID": "DOSE_FIND_D_015", "N_Patients": 800, "Drugs": ['Placebo', 'Drug_D'], "Doses": {'Drug_D': [5, 10, 15, 20]}, "Percent_US": 0.50, "Duration_Weeks": 6},
    {"Study_ID": "DOSE_FIND_C_016", "N_Patients": 900, "Drugs": ['Placebo', 'Drug_C'], "Doses": {'Drug_C': [20, 40, 60, 80]}, "Percent_US": 0.60, "Duration_Weeks": 10},
    {"Study_ID": "BALANCED_AE_017", "N_Patients": 1100, "Drugs": ['Placebo', 'Drug_A', 'Drug_E'], "Doses": {'Drug_A': [20], 'Drug_E': [100]}, "Percent_US": 0.50, "Duration_Weeks": 12},
    {"Study_ID": "HEAD_TO_HEAD_018", "N_Patients": 1400, "Drugs": ['Drug_A', 'Drug_D'], "Doses": {'Drug_A': [40], 'Drug_D': [10]}, "Percent_US": 0.50, "Duration_Weeks": 10}, # No Placebo
    {"Study_ID": "ADD_ON_NEG_019", "N_Patients": 850, "Drugs": ['Placebo', 'Drug_C'], "Doses": {'Drug_C': [20, 40]}, "Percent_US": 0.35, "Duration_Weeks": 16}, # Assumes add-on context for effect size interpretation
    {"Study_ID": "LONG_TERM_TOL_020", "N_Patients": 1000, "Drugs": ['Placebo', 'Drug_E'], "Doses": {'Drug_E': [100, 150]}, "Percent_US": 0.60, "Duration_Weeks": 26},
    {"Study_ID": "DOSE_FIND_E_021", "N_Patients": 750, "Drugs": ['Placebo', 'Drug_E'], "Doses": {'Drug_E': [50, 100, 150, 200]}, "Percent_US": 0.45, "Duration_Weeks": 10},
    {"Study_ID": "COMPARE_F_A_022", "N_Patients": 1300, "Drugs": ['Placebo', 'Drug_A', 'Drug_F'], "Doses": {'Drug_A': [20], 'Drug_F': [20]}, "Percent_US": 0.50, "Duration_Weeks": 12}, # Compare A and new F
    {"Study_ID": "US_DRUG_F_023", "N_Patients": 900, "Drugs": ['Placebo', 'Drug_F'], "Doses": {'Drug_F': [10, 20, 30]}, "Percent_US": 0.90, "Duration_Weeks": 8},
    {"Study_ID": "EU_DRUG_F_024", "N_Patients": 800, "Drugs": ['Placebo', 'Drug_F'], "Doses": {'Drug_F': [10, 20]}, "Percent_US": 0.10, "Duration_Weeks": 10},
    {"Study_ID": "ADJUNCT_COG_G_025", "N_Patients": 950, "Drugs": ['Placebo', 'Drug_G'], "Doses": {'Drug_G': [50, 75]}, "Percent_US": 0.55, "Duration_Weeks": 16}, # Test G as adjunct (assume context)
    {"Study_ID": "COMPARE_G_E_026", "N_Patients": 1100, "Drugs": ['Placebo', 'Drug_E', 'Drug_G'], "Doses": {'Drug_E': [100], 'Drug_G': [50]}, "Percent_US": 0.40, "Duration_Weeks": 12}, # Compare low AE / Cog drugs
    {"Study_ID": "LARGE_GLOBAL_027", "N_Patients": 2000, "Drugs": ['Placebo', 'Drug_A', 'Drug_F'], "Doses": {'Drug_A': [20], 'Drug_F': [20]}, "Percent_US": 0.50, "Duration_Weeks": 12}, # Very large Phase 3
    {"Study_ID": "LONG_MAINT_028", "N_Patients": 1500, "Drugs": ['Placebo', 'Drug_A', 'Drug_F'], "Doses": {'Drug_A': [20], 'Drug_F': [10]}, "Percent_US": 0.60, "Duration_Weeks": 52}, # 1 year study
    {"Study_ID": "SHORT_ACUTE_D_029", "N_Patients": 650, "Drugs": ['Placebo', 'Drug_D'], "Doses": {'Drug_D': [10]}, "Percent_US": 0.70, "Duration_Weeks": 6},
    {"Study_ID": "EU_COG_FOCUS_030", "N_Patients": 850, "Drugs": ['Placebo', 'Drug_E', 'Drug_G'], "Doses": {'Drug_E': [150], 'Drug_G': [75]}, "Percent_US": 0.20, "Duration_Weeks": 16},
    {"Study_ID": "MIX_BCF_031", "N_Patients": 1250, "Drugs": ['Placebo', 'Drug_B', 'Drug_C', 'Drug_F'], "Doses": {'Drug_B': [10], 'Drug_C': [40], 'Drug_F': [20]}, "Percent_US": 0.50, "Duration_Weeks": 10},
    {"Study_ID": "HEAD_TO_HEAD_BF_032", "N_Patients": 1000, "Drugs": ['Drug_B', 'Drug_F'], "Doses": {'Drug_B': [10], 'Drug_F': [30]}, "Percent_US": 0.60, "Duration_Weeks": 12}, # Active comparator
    {"Study_ID": "PHASE_II_G_033", "N_Patients": 450, "Drugs": ['Placebo', 'Drug_G'], "Doses": {'Drug_G': [25, 50, 75]}, "Percent_US": 0.40, "Duration_Weeks": 8}, # Smaller Phase II size
    {"Study_ID": "PHASE_II_F_034", "N_Patients": 500, "Drugs": ['Placebo', 'Drug_F'], "Doses": {'Drug_F': [5, 10, 20, 30]}, "Percent_US": 0.70, "Duration_Weeks": 8},
    {"Study_ID": "SWITCH_STUDY_AE_035", "N_Patients": 900, "Drugs": ['Drug_A', 'Drug_E'], "Doses": {'Drug_A': [20], 'Drug_E': [100]}, "Percent_US": 0.50, "Duration_Weeks": 24}, # Simulating switch for tolerability (analyze baseline differently in reality)
    {"Study_ID": "GLOBAL_NEG_SYM_036", "N_Patients": 1600, "Drugs": ['Placebo', 'Drug_C', 'Drug_G'], "Doses": {'Drug_C': [40], 'Drug_G': [50]}, "Percent_US": 0.48, "Duration_Weeks": 16},
    {"Study_ID": "US_ONLY_A_HIGH_037", "N_Patients": 700, "Drugs": ['Placebo', 'Drug_A'], "Doses": {'Drug_A': [40, 60]}, "Percent_US": 1.00, "Duration_Weeks": 10}, # Higher doses of A, US only
    {"Study_ID": "EU_ONLY_B_LOW_038", "N_Patients": 600, "Drugs": ['Placebo', 'Drug_B'], "Doses": {'Drug_B': [2.5, 5]}, "Percent_US": 0.00, "Duration_Weeks": 12}, # Lower doses of B, EU only
    {"Study_ID": "COMPARE_D_F_039", "N_Patients": 1150, "Drugs": ['Placebo', 'Drug_D', 'Drug_F'], "Doses": {'Drug_D': [10], 'Drug_F': [20]}, "Percent_US": 0.50, "Duration_Weeks": 8},
    {"Study_ID": "LONG_TERM_G_ADJ_040", "N_Patients": 1000, "Drugs": ['Placebo', 'Drug_G'], "Doses": {'Drug_G': [50]}, "Percent_US": 0.55, "Duration_Weeks": 52}, # Long term adjunct G
]


# Expected number of generated records across every study (sum of N_Patients),
# used for the run banner and progress reporting.
TOTAL_PATIENTS_TO_GENERATE = sum([cfg["N_Patients"] for cfg in STUDIES_CONFIG])


# --- Baseline & Simulation Parameters ---
# Demographic and metabolic baselines: *_MEAN / *_STD are the arguments of the
# normal draws in generate_patient_record; *_DIST dicts map a category to its
# sampling weight.
AGE_MEAN = 35
AGE_STD = 8
SEX_DIST = {'Male': 0.65, 'Female': 0.35}
YEARS_DX_MEAN = 7  # NOTE(review): generate_patient_record reads only YEARS_DX_STD
YEARS_DX_STD = 5
AGE_ONSET_MEAN = 24
AGE_ONSET_STD = 6
EDUCATION_DIST = {
    'High School/GED': 0.45,
    'Some College/Associate': 0.35,
    'Bachelor Degree+': 0.20,
}
SES_DIST = {'Low': 0.40, 'Medium': 0.50, 'High': 0.10}
BMI_MEAN = 27
BMI_STD = 4
GLUCOSE_MEAN = 95
GLUCOSE_STD = 15
LDL_MEAN = 110
LDL_STD = 25
CV_RISK_RATE = 0.20  # probability that Cardiovascular_Risk_Factor is 1

# Baseline Clinical
PANSS_TOTAL_MEAN = 90
PANSS_TOTAL_STD = 12
# Expected shares of PANSS Total held by the Negative and Positive subscales;
# the General subscale takes the remainder.
PANSS_NEG_MEAN_RATIO = 0.25
PANSS_POS_MEAN_RATIO = 0.23
# Baseline BPRS = baseline PANSS Total x Normal(mean ratio, SD ratio).
BPRS_BASELINE_MEAN_RATIO = 0.5
BPRS_BASELINE_STD_RATIO = 0.05
# Baseline SANS = baseline PANSS Negative x Normal(mean ratio, SD ratio).
SANS_BASELINE_MEAN_RATIO = 1.5
SANS_BASELINE_STD_RATIO = 0.2
CGI_S_BASELINE_DIST = {4: 0.2, 5: 0.5, 6: 0.3}  # baseline CGI-S score -> weight
GAF_BASELINE_MEAN = 45
GAF_BASELINE_STD = 8
QOL_BASELINE_MEAN = 50
QOL_BASELINE_STD = 10
COG_BASELINE_MEAN = -1.0
COG_BASELINE_STD = 0.7
PREV_AP_MEAN = 2.5
PREV_AP_STD = 1.5

# Outcome Simulation Parameters
# Secondary-scale changes are derived from PANSS changes through these multipliers.
BPRS_CHANGE_FACTOR = 0.55      # x Change_PANSS_Total
SANS_CHANGE_FACTOR = 1.6       # x Change_PANSS_Negative
CGI_CHANGE_SCALE_FACTOR = 0.1  # x Change_PANSS_Total
# Background (placebo-level) changes applied to every patient: Normal(mean, SD).
PLACEBO_GAF_CHANGE_MEAN = 1.5
PLACEBO_GAF_CHANGE_STD = 4
PLACEBO_QOL_CHANGE_MEAN = 2
PLACEBO_QOL_CHANGE_STD = 5
PLACEBO_COG_CHANGE_MEAN = 0.05
PLACEBO_COG_CHANGE_STD = 0.2
PLACEBO_WEIGHT_GAIN_MEAN = 0.5
PLACEBO_WEIGHT_GAIN_STD = 1.0
PLACEBO_BMI_CHANGE_MEAN = 0.2
PLACEBO_BMI_CHANGE_STD = 0.4
PLACEBO_GLUCOSE_CHANGE_MEAN = 1
PLACEBO_GLUCOSE_CHANGE_STD = 5
PLACEBO_LDL_CHANGE_MEAN = 1
PLACEBO_LDL_CHANGE_STD = 8
PLACEBO_EPS_SAS_CHANGE_MEAN = 0.1
PLACEBO_EPS_SAS_CHANGE_STD = 0.3
PLACEBO_AE_COUNT_MEAN = 1.5
PLACEBO_AE_COUNT_STD = 1.0
# Discontinuation probability = base rate + per-AE, per-kg weight-gain and
# per-EPS-unit increments (before the response-based multipliers).
AE_DISCONTINUE_BASE_RATE = 0.03
AE_DISCONTINUE_PER_AE = 0.02
AE_DISCONTINUE_PER_KG_WG = 0.01
AE_DISCONTINUE_PER_EPS_UNIT = 0.03

# --- Data Generation Function ---
def _sample_panss_proportions(ratio_std):
    """Draw Negative/Positive/General PANSS shares, renormalised to sum to 1.

    The Negative and Positive shares are drawn (in that order) from normal
    distributions centred on PANSS_NEG_MEAN_RATIO and PANSS_POS_MEAN_RATIO with
    standard deviation ``ratio_std``; the General share is the remainder,
    floored at 0, and all three are then divided by their sum.

    Returns:
        tuple: ``(neg_prop, pos_prop, gen_prop)``.
    """
    neg_prop = np.random.normal(PANSS_NEG_MEAN_RATIO, ratio_std)
    pos_prop = np.random.normal(PANSS_POS_MEAN_RATIO, ratio_std)
    gen_prop = max(0, 1.0 - neg_prop - pos_prop)
    total = neg_prop + pos_prop + gen_prop
    return neg_prop / total, pos_prop / total, gen_prop / total


def generate_patient_record(patient_id_counter, study_config):
    """Simulate one synthetic patient record for the given study.

    Args:
        patient_id_counter: Integer used to build ``Patient_ID`` ("P<n>").
        study_config: One entry of STUDIES_CONFIG. Keys read: "Study_ID",
            "Percent_US", "Drugs", "Duration_Weeks" and, optionally, "Doses"
            (drug -> list of mg/day doses; the drug profile's
            ``available_doses`` are used when absent or empty).

    Returns:
        dict: Column name -> value covering demographics, baseline scales,
        treatment assignment, simulated efficacy/safety outcomes, AE
        discontinuation and responder/remission flags.

    Randomness comes from the global ``random`` and ``np.random`` generators;
    seed both for reproducible output.
    """
    patient = {}
    patient['Patient_ID'] = f"P{patient_id_counter}"
    patient['Study_ID'] = study_config["Study_ID"]

    # --- Region and region-specific sampling parameters ---
    is_us = random.random() < study_config["Percent_US"]
    patient['Region'] = 'US' if is_us else 'EU'

    if is_us:
        patient['Ethnicity'] = random.choices(ETHNICITY_CATEGORIES, weights=ETHNICITY_WEIGHTS_US, k=1)[0]
        patient['Site_Type'] = random.choices(
            ['Academic Hospital', 'Private Clinic', 'VA Hospital', 'Community Mental Health Center'],
            weights=[0.35, 0.40, 0.15, 0.10], k=1)[0]
        age_mod, years_dx_mod, panss_total_mod = 0, -0.5, 1
        ses_weights = [0.35, 0.55, 0.10]
        substance_use_rate = COMORBIDITY_SUBSTANCE_USE_RATE_US
    else: # EU
        patient['Ethnicity'] = random.choices(ETHNICITY_CATEGORIES, weights=ETHNICITY_WEIGHTS_EU, k=1)[0]
        patient['Site_Type'] = random.choices(
            ['University Hospital', 'National Health Service Clinic', 'Private Practice'],
            weights=[0.55, 0.35, 0.10], k=1)[0]
        age_mod, years_dx_mod, panss_total_mod = 0, 0.5, -1
        ses_weights = [0.45, 0.48, 0.07]
        substance_use_rate = COMORBIDITY_SUBSTANCE_USE_RATE_EU

    # --- Demographics & Baseline ---
    patient['Age'] = max(18, int(np.random.normal(AGE_MEAN + age_mod, AGE_STD)))
    patient['Sex'] = random.choices(list(SEX_DIST.keys()), weights=list(SEX_DIST.values()), k=1)[0]
    # The onset draw is centred at or below Age - 1 but its tail can still exceed
    # the current age; cap it at Age so onset never postdates enrolment
    # (Age >= 18 > 14, so the two bounds never conflict).
    sampled_onset = max(14, int(np.random.normal(min(patient['Age'] - 1, AGE_ONSET_MEAN), AGE_ONSET_STD)))
    patient['Age_at_Onset'] = min(patient['Age'], sampled_onset)
    years_since_onset = patient['Age'] - patient['Age_at_Onset']
    patient['Years_Since_Dx'] = max(0.1, round(np.random.normal(years_since_onset + years_dx_mod, YEARS_DX_STD / 2), 1))
    patient['Education_Level'] = random.choices(list(EDUCATION_DIST.keys()), weights=list(EDUCATION_DIST.values()), k=1)[0]
    patient['Socioeconomic_Status'] = random.choices(list(SES_DIST.keys()), weights=ses_weights, k=1)[0]
    patient['Baseline_BMI'] = round(max(16, np.random.normal(BMI_MEAN, BMI_STD)), 1)
    patient['Baseline_Glucose_mgdL'] = int(max(60, np.random.normal(GLUCOSE_MEAN, GLUCOSE_STD)))
    patient['Baseline_LDL_mgdL'] = int(max(50, np.random.normal(LDL_MEAN, LDL_STD)))
    patient['Cardiovascular_Risk_Factor'] = 1 if random.random() < CV_RISK_RATE else 0
    patient['Comorbidity_Substance_Use'] = 1 if random.random() < substance_use_rate else 0

    # --- Baseline Clinical ---
    patient['Baseline_PANSS_Total'] = int(np.random.normal(PANSS_TOTAL_MEAN + panss_total_mod, PANSS_TOTAL_STD))
    base_neg_prop, base_pos_prop, base_gen_prop = _sample_panss_proportions(0.06)
    patient['Baseline_PANSS_Negative'] = max(7, int(patient['Baseline_PANSS_Total'] * base_neg_prop))
    patient['Baseline_PANSS_Positive'] = max(7, int(patient['Baseline_PANSS_Total'] * base_pos_prop))
    patient['Baseline_PANSS_General'] = max(16, int(patient['Baseline_PANSS_Total'] * base_gen_prop))
    # Truncation and subscale floors move the sum away from the drawn total, so the
    # stored total is redefined as the sum of its subscales.
    patient['Baseline_PANSS_Total'] = patient['Baseline_PANSS_Negative'] + patient['Baseline_PANSS_Positive'] + patient['Baseline_PANSS_General']

    patient['Baseline_BPRS'] = int(patient['Baseline_PANSS_Total'] * np.random.normal(BPRS_BASELINE_MEAN_RATIO, BPRS_BASELINE_STD_RATIO))
    patient['Baseline_SANS'] = int(patient['Baseline_PANSS_Negative'] * np.random.normal(SANS_BASELINE_MEAN_RATIO, SANS_BASELINE_STD_RATIO))
    patient['Baseline_CGI_S'] = random.choices(list(CGI_S_BASELINE_DIST.keys()), weights=list(CGI_S_BASELINE_DIST.values()), k=1)[0]
    patient['Baseline_GAF'] = int(max(20, min(70, np.random.normal(GAF_BASELINE_MEAN, GAF_BASELINE_STD))))
    patient['Baseline_QoL'] = int(max(10, min(90, np.random.normal(QOL_BASELINE_MEAN, QOL_BASELINE_STD))))
    patient['Baseline_Cognitive_Score'] = round(np.random.normal(COG_BASELINE_MEAN, COG_BASELINE_STD), 2)
    patient['Previous_Antipsychotics_Count'] = max(0, int(np.random.normal(PREV_AP_MEAN, PREV_AP_STD)))

    # --- Treatment Assignment based on Study Config ---
    # Equal-probability allocation across every listed arm (placebo included when
    # present); a study with no arms falls back to placebo.
    drugs_in_study = study_config["Drugs"]
    patient['Drug'] = random.choice(drugs_in_study) if drugs_in_study else 'Placebo'
    drug_profile = DRUG_PROFILES[patient['Drug']]

    # --- Drug metadata ---
    patient['Drug_MOA'] = drug_profile['MOA']
    patient['Drug_Class'] = drug_profile['Drug_Class']
    patient['Drug_Sponsor'] = drug_profile['Sponsor']

    # --- Assign Dose ---
    if patient['Drug'] == 'Placebo':
        patient['Dose_mg_day'] = 0
    else:
        # Study-specific doses take precedence; otherwise use the profile's doses.
        study_doses = study_config.get("Doses", {}).get(patient['Drug'])
        if isinstance(study_doses, (list, tuple)) and len(study_doses) > 0:
            available_doses = study_doses
        else:
            available_doses = drug_profile['available_doses']
        patient['Dose_mg_day'] = random.choice(available_doses)

    patient['Treatment_Duration_Weeks'] = study_config["Duration_Weeks"]

    # --- Simulate Outcomes ---
    dose_units = (patient['Dose_mg_day'] / drug_profile['unit_mg']) if drug_profile['unit_mg'] > 0 else 0

    # Base (placebo-level) effects, applied to every patient regardless of arm.
    placebo_effect_mean = PLACEBO_PANSS_CHANGE_MEAN_US if is_us else PLACEBO_PANSS_CHANGE_MEAN_EU
    placebo_effect_std = PLACEBO_PANSS_CHANGE_STD_US if is_us else PLACEBO_PANSS_CHANGE_STD_EU
    individual_responder_factor = np.random.normal(1.0, 0.3) # Individual variability

    base_panss_change = np.random.normal(placebo_effect_mean, placebo_effect_std) * individual_responder_factor
    # Distribute the placebo PANSS change across subscales with a fresh random split.
    neg_prop, pos_prop, gen_prop = _sample_panss_proportions(0.08)
    base_panss_neg_change = base_panss_change * neg_prop
    base_panss_pos_change = base_panss_change * pos_prop
    base_panss_gen_change = base_panss_change * gen_prop

    base_gaf_change = np.random.normal(PLACEBO_GAF_CHANGE_MEAN, PLACEBO_GAF_CHANGE_STD) * individual_responder_factor
    base_qol_change = np.random.normal(PLACEBO_QOL_CHANGE_MEAN, PLACEBO_QOL_CHANGE_STD) * individual_responder_factor
    base_cog_change = np.random.normal(PLACEBO_COG_CHANGE_MEAN, PLACEBO_COG_CHANGE_STD) * max(0.5, individual_responder_factor) # Less variable placebo cog effect
    base_weight_change = np.random.normal(PLACEBO_WEIGHT_GAIN_MEAN, PLACEBO_WEIGHT_GAIN_STD)
    base_bmi_change = np.random.normal(PLACEBO_BMI_CHANGE_MEAN, PLACEBO_BMI_CHANGE_STD)
    base_glucose_change = np.random.normal(PLACEBO_GLUCOSE_CHANGE_MEAN, PLACEBO_GLUCOSE_CHANGE_STD)
    base_ldl_change = np.random.normal(PLACEBO_LDL_CHANGE_MEAN, PLACEBO_LDL_CHANGE_STD)
    base_eps_change = np.random.normal(PLACEBO_EPS_SAS_CHANGE_MEAN, PLACEBO_EPS_SAS_CHANGE_STD)
    base_ae_count = max(0, np.random.normal(PLACEBO_AE_COUNT_MEAN, PLACEBO_AE_COUNT_STD))

    # Drug effects (all zero for placebo)
    drug_panss_neg_effect, drug_panss_pos_effect, drug_panss_gen_effect = 0, 0, 0
    drug_gaf_effect, drug_qol_effect, drug_cog_effect = 0, 0, 0
    drug_weight_effect, drug_bmi_effect, drug_glucose_effect, drug_ldl_effect = 0, 0, 0, 0
    drug_eps_effect, drug_ae_bonus_calc = 0, 0

    if patient['Drug'] != 'Placebo':
        # Noise SD of every drug effect grows with sqrt(dose units), floored at 0.1 units.
        dose_sd_scale = math.sqrt(max(0.1, dose_units))
        effect_mean = drug_profile['effect_per_unit'] * dose_units
        effect_std = drug_profile['effect_std_bonus'] * dose_sd_scale
        drug_panss_effect = np.random.normal(effect_mean, effect_std)

        # The negative-symptom bonus is carved out of the total effect and assigned to
        # the Negative subscale; the remainder is split by the mean baseline ratios.
        neg_bonus = drug_profile['neg_symptom_bonus'] * dose_units
        remaining_effect = drug_panss_effect - neg_bonus
        neg_prop_drug = PANSS_NEG_MEAN_RATIO
        pos_prop_drug = PANSS_POS_MEAN_RATIO
        gen_prop_drug = 1 - PANSS_NEG_MEAN_RATIO - PANSS_POS_MEAN_RATIO

        drug_panss_neg_effect = remaining_effect * neg_prop_drug * np.random.normal(1, 0.1) + neg_bonus
        drug_panss_pos_effect = remaining_effect * pos_prop_drug * np.random.normal(1, 0.1)
        drug_panss_gen_effect = remaining_effect * gen_prop_drug * np.random.normal(1, 0.1)

        # Other efficacy effects (linear in dose units + noise)
        drug_gaf_effect = np.random.normal(1.0 * dose_units, 1.0 * dose_sd_scale)
        drug_qol_effect = np.random.normal(1.2 * dose_units, 1.2 * dose_sd_scale)
        drug_cog_effect = np.random.normal(drug_profile['cog_effect_per_unit'] * dose_units, 0.05 * dose_sd_scale)

        # Safety effects
        drug_weight_effect = np.random.normal(drug_profile['weight_gain_per_unit'] * dose_units, 0.5 * dose_sd_scale)
        drug_bmi_effect = np.random.normal(drug_profile['bmi_change_per_unit'] * dose_units, 0.1 * dose_sd_scale)
        drug_glucose_effect = np.random.normal(drug_profile['glucose_change_per_unit'] * dose_units, 1.0 * dose_sd_scale)
        drug_ldl_effect = np.random.normal(drug_profile['ldl_change_per_unit'] * dose_units, 1.5 * dose_sd_scale)
        drug_eps_effect = np.random.normal(drug_profile['eps_effect_per_unit'] * dose_units, 0.1 * dose_sd_scale)
        drug_ae_bonus_calc = drug_profile['ae_bonus_factor'] * dose_units * np.random.normal(1.0, 0.2) # AE count

    # Individual responder factor scales the drug efficacy effects
    drug_panss_neg_effect *= individual_responder_factor
    drug_panss_pos_effect *= individual_responder_factor
    drug_panss_gen_effect *= individual_responder_factor
    drug_gaf_effect *= individual_responder_factor
    drug_qol_effect *= individual_responder_factor
    drug_cog_effect *= max(0.5, individual_responder_factor)

    # --- Final PANSS Outcomes ---
    # Endpoints are truncated to whole points and floored at the same subscale minima
    # used at baseline. Every reported change is the realised Endpoint - Baseline, so
    # subscale changes stay consistent with their endpoints and sum to the total change.
    subscale_floors = {'Negative': 7, 'Positive': 7, 'General': 16}
    simulated_subscale_changes = {
        'Negative': base_panss_neg_change + drug_panss_neg_effect,
        'Positive': base_panss_pos_change + drug_panss_pos_effect,
        'General': base_panss_gen_change + drug_panss_gen_effect,
    }
    for subscale, simulated_change in simulated_subscale_changes.items():
        baseline = patient[f'Baseline_PANSS_{subscale}']
        endpoint = max(subscale_floors[subscale], int(baseline + round(simulated_change, 2)))
        patient[f'Endpoint_PANSS_{subscale}'] = endpoint
        patient[f'Change_PANSS_{subscale}'] = endpoint - baseline
    patient['Endpoint_PANSS_Total'] = patient['Endpoint_PANSS_Negative'] + patient['Endpoint_PANSS_Positive'] + patient['Endpoint_PANSS_General']
    patient['Change_PANSS_Total'] = patient['Endpoint_PANSS_Total'] - patient['Baseline_PANSS_Total']

    # Other efficacy endpoints
    patient['Change_BPRS'] = round(patient['Change_PANSS_Total'] * np.random.normal(BPRS_CHANGE_FACTOR, 0.1), 1)
    patient['Change_SANS'] = round(patient['Change_PANSS_Negative'] * np.random.normal(SANS_CHANGE_FACTOR, 0.2), 1)
    patient['Change_CGI_S'] = round(patient['Change_PANSS_Total'] * CGI_CHANGE_SCALE_FACTOR * np.random.normal(1, 0.1), 1)
    patient['Endpoint_CGI_S'] = round(max(1, min(7, patient['Baseline_CGI_S'] + patient['Change_CGI_S'])), 1)

    # Estimate CGI-I from PANSS Total change bins, then jitter by -1/0/+1
    panss_change = patient['Change_PANSS_Total']
    if panss_change <= -25: cgi_i_base = 1
    elif panss_change <= -18: cgi_i_base = 2
    elif panss_change <= -8: cgi_i_base = 3
    elif panss_change < 5: cgi_i_base = 4
    elif panss_change < 15: cgi_i_base = 5
    elif panss_change < 25: cgi_i_base = 6
    else: cgi_i_base = 7
    patient['Endpoint_CGI_I'] = max(1, min(7, cgi_i_base + random.choice([-1, 0, 0, 1])))

    patient['Change_GAF'] = round(base_gaf_change + drug_gaf_effect, 1)
    patient['Endpoint_GAF'] = int(max(1, min(100, patient['Baseline_GAF'] + patient['Change_GAF'])))
    patient['Change_QoL'] = round(base_qol_change + drug_qol_effect, 1)
    patient['Endpoint_QoL'] = int(max(0, min(100, patient['Baseline_QoL'] + patient['Change_QoL'])))
    patient['Change_Cognitive_Score'] = round(base_cog_change + drug_cog_effect, 2)
    patient['Endpoint_Cognitive_Score'] = round(patient['Baseline_Cognitive_Score'] + patient['Change_Cognitive_Score'], 2)

    # Safety outcomes
    patient['Weight_Gain_kg'] = round(max(-5, base_weight_change + drug_weight_effect), 1) # Allow slight weight loss
    patient['Change_BMI'] = round(base_bmi_change + drug_bmi_effect, 1)
    patient['Endpoint_BMI'] = round(max(15, patient['Baseline_BMI'] + patient['Change_BMI']), 1) # Ensure BMI >= 15
    patient['Change_Glucose'] = int(base_glucose_change + drug_glucose_effect)
    patient['Endpoint_Glucose_mgdL'] = max(50, patient['Baseline_Glucose_mgdL'] + patient['Change_Glucose']) # Ensure Glucose >= 50
    patient['Change_LDL'] = int(base_ldl_change + drug_ldl_effect)
    patient['Endpoint_LDL_mgdL'] = max(40, patient['Baseline_LDL_mgdL'] + patient['Change_LDL']) # Ensure LDL >= 40
    patient['EPS_SAS_Change'] = round(max(0, base_eps_change + drug_eps_effect), 1) # SAS change >= 0
    patient['AE_Count'] = int(max(0, base_ae_count + drug_ae_bonus_calc)) # AE count >= 0

    # Discontinuation: additive AE/weight/EPS risk, reduced for good responders and
    # increased when CGI-I shows worsening, then clamped to [0, 1].
    discontinuation_prob = AE_DISCONTINUE_BASE_RATE + \
                           patient['AE_Count'] * AE_DISCONTINUE_PER_AE + \
                           max(0, patient['Weight_Gain_kg']) * AE_DISCONTINUE_PER_KG_WG + \
                           patient['EPS_SAS_Change'] * AE_DISCONTINUE_PER_EPS_UNIT
    if patient['Change_PANSS_Total'] < -15 or patient['Change_GAF'] > 8: discontinuation_prob *= 0.4
    if patient['Endpoint_CGI_I'] >= 6: discontinuation_prob *= 1.5
    patient['Discontinued_Due_To_AE'] = 1 if random.random() < max(0, min(1, discontinuation_prob)) else 0

    # Responder: >=30% reduction from the raw baseline PANSS Total, or CGI-I <= 2
    percent_change = (patient['Change_PANSS_Total'] / patient['Baseline_PANSS_Total']) if patient['Baseline_PANSS_Total'] > 0 else 0
    patient['Responder_Status'] = 1 if (percent_change <= -0.30 or patient['Endpoint_CGI_I'] <= 2) else 0
    # Simple remission definition for simulation purposes
    is_low_symptoms = patient['Endpoint_PANSS_Total'] <= 40
    is_good_function = patient['Endpoint_GAF'] >= 61
    is_stable_cgis = patient['Endpoint_CGI_S'] <= 3
    patient['Remission_Status'] = 1 if (is_low_symptoms and is_good_function and is_stable_cgis) else 0

    return patient


# --- Main Generation Loop ---
# Optional seed for reproducible datasets; None keeps the unseeded behaviour.
RANDOM_SEED = None
if RANDOM_SEED is not None:
    # generate_patient_record draws from both generators, so seed both.
    random.seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)

start_main_time = time.time()
print(f"\nStarting V3 Enhanced Generation for {len(STUDIES_CONFIG)} studies...")
print(f"Targeting ~{TOTAL_PATIENTS_TO_GENERATE:,} total patient records.")
print("Including regional demographics and Drugs A-G with MOA/Class/Sponsor.")
print(f"Start time: {time.strftime('%Y-%m-%d %H:%M:%S')}")

all_patient_data = []
patient_id_counter = 50000  # first Patient_ID is "P50000"; IDs keep increasing across studies

total_generated = 0
for study_conf in STUDIES_CONFIG:
    study_start_time = time.time()
    n_study = study_conf['N_Patients']
    print(f"  Generating {n_study} patients for Study: {study_conf['Study_ID']} ({study_conf['Percent_US']*100:.0f}% US)...")

    for _ in range(n_study):
        all_patient_data.append(generate_patient_record(patient_id_counter, study_conf))
        patient_id_counter += 1
    total_generated += n_study
    study_end_time = time.time()
    # Per-study timing and cumulative progress against the configured total.
    print(f"    Done in {study_end_time - study_start_time:.2f}s "
          f"({total_generated:,}/{TOTAL_PATIENTS_TO_GENERATE:,} patients so far).")


# --- Create Final DataFrame ---
df_multi_study_v3 = pd.DataFrame(all_patient_data)
end_main_time = time.time()
print(f"\nFinished generating all {len(df_multi_study_v3):,} patient records from {len(STUDIES_CONFIG)} studies.")
print(f"Total generation time: {end_main_time - start_main_time:.2f} seconds.")


# --- Column Order ---
baseline_cols = [
    'Baseline_PANSS_Total', 'Baseline_PANSS_Positive', 'Baseline_PANSS_Negative', 'Baseline_PANSS_General',
    'Baseline_BPRS', 'Baseline_SANS', 'Baseline_CGI_S', 'Baseline_GAF', 'Baseline_QoL', 'Baseline_Cognitive_Score',
    'Baseline_BMI', 'Baseline_Glucose_mgdL', 'Baseline_LDL_mgdL'
]
demographic_cols = [
    'Patient_ID', 'Study_ID', 'Region', 'Site_Type', 'Age', 'Sex', 'Ethnicity', 'Education_Level', 'Socioeconomic_Status',
    'Age_at_Onset', 'Years_Since_Dx', 'Previous_Antipsychotics_Count', 'Comorbidity_Substance_Use', 'Cardiovascular_Risk_Factor'
]
# Drug Info
drug_info_cols = ['Drug', 'Dose_mg_day', 'Drug_MOA', 'Drug_Class', 'Drug_Sponsor']
treatment_cols = ['Treatment_Duration_Weeks']

endpoint_efficacy_cols = [
    'Endpoint_PANSS_Total', 'Endpoint_PANSS_Positive', 'Endpoint_PANSS_Negative', 'Endpoint_PANSS_General',
    'Endpoint_CGI_S', 'Endpoint_CGI_I', 'Endpoint_GAF', 'Endpoint_QoL', 'Endpoint_Cognitive_Score',
    'Responder_Status', 'Remission_Status'
]
change_efficacy_cols = [
     'Change_PANSS_Total', 'Change_PANSS_Positive', 'Change_PANSS_Negative', 'Change_PANSS_General',
     'Change_BPRS', 'Change_SANS', 'Change_CGI_S', 'Change_GAF', 'Change_QoL', 'Change_Cognitive_Score'
]
safety_cols = [
    'Weight_Gain_kg', 'Change_BMI', 'Endpoint_BMI',
    'Change_Glucose', 'Endpoint_Glucose_mgdL', 'Change_LDL', 'Endpoint_LDL_mgdL',
    'EPS_SAS_Change', 'AE_Count', 'Discontinued_Due_To_AE'
]
# Preferred column order for the final dataset
all_cols = demographic_cols + drug_info_cols + treatment_cols + baseline_cols + \
           endpoint_efficacy_cols + change_efficacy_cols + safety_cols

# Listed columns that exist, in the preferred order, followed by any generated
# column missing from the lists above, so a new generator field is never
# silently dropped from the dataset.
present_cols = [col for col in all_cols if col in df_multi_study_v3.columns]
listed_cols = set(all_cols)
extra_cols = [col for col in df_multi_study_v3.columns if col not in listed_cols]
df_multi_study_v3 = df_multi_study_v3[present_cols + extra_cols]


# --- Show and Save ---
# Print summary tables of the final dataset, then write it to CSV.
print("\n--- V3 Enhanced Multi-Study Synthetic Schizophrenia Clinical Trial Dataset ---")
print(f"Generated {len(df_multi_study_v3):,} total patient records across {len(STUDIES_CONFIG)} studies.")
print(f"Total Drugs Modeled (Active + Placebo): {len(DRUG_PROFILES)}")
print("Dataset includes Drug MOA, Class, and Sponsor.")  # V3 Update

print("\nDataset Head (First 5 rows):")
df_head = df_multi_study_v3.head()
try:
    print(df_head.to_markdown(index=False))
except ImportError:
    # DataFrame.to_markdown relies on the optional 'tabulate' package; fall
    # back to plain text instead of aborting before the CSV is saved below.
    print(df_head.to_string(index=False))

print("\nDataset Info:")
df_multi_study_v3.info()

print("\nValue Counts for Study ID (Top 10):")
print(df_multi_study_v3['Study_ID'].value_counts().head(10))

print("\nDrugs Tested per Study (Sample):")
drug_counts_per_study = df_multi_study_v3.groupby('Study_ID')['Drug'].value_counts()
if len(drug_counts_per_study) >= 15:
    # Reproducible random sample of 15 (study, drug) pairs.
    print(drug_counts_per_study.sample(15, random_state=1))
else:
    # Too few pairs to sample 15 without replacement: show them all.
    print(drug_counts_per_study)


print("\nRegion Distribution per Study (Summary):")
# Per-study regional shares, summarised across studies.
print(df_multi_study_v3.groupby('Study_ID')['Region'].value_counts(normalize=True).unstack().fillna(0).agg(['mean', 'median', 'min', 'max']))

print("\nEthnicity Distribution by Region:")
print(df_multi_study_v3.groupby('Region')['Ethnicity'].value_counts(normalize=True).unstack().fillna(0))

# Save to CSV
output_filename_v3 = "synthetic_schizophrenia_multi_study_v3.csv"
df_multi_study_v3.to_csv(output_filename_v3, index=False)
print(f"\nV3 Dataset saved to {output_filename_v3}")

# Display completion time
current_time_str = time.strftime('%Y-%m-%d %H:%M:%S')
print(f"\nProcessing completed at: {current_time_str}")


# --- Data Visualization ---
print("\n--- Generating Data Visualizations ---")

sns.set_theme(style="whitegrid")
df_to_plot = df_multi_study_v3

# Each chart is wrapped in its own try/except, so one failing plot (missing
# column, absent drug, degenerate data) does not stop the remaining ones.
if df_to_plot is not None and not df_to_plot.empty:

    # Placebo subset shared by plots 5 and 6. It is built once here so that
    # plot 6 does not depend on plot 5 having succeeded.
    if 'Drug' in df_to_plot.columns:
        df_placebo = df_to_plot[df_to_plot['Drug'] == 'Placebo'].copy()
    else:
        df_placebo = df_to_plot.iloc[0:0].copy()

    # --- General Distributions ---

    # 1. Distribution of Baseline PANSS Total
    print("  Plotting: Baseline PANSS Total Distribution")
    try:
        plt.figure(figsize=(10, 6))
        sns.histplot(data=df_to_plot, x='Baseline_PANSS_Total', kde=True, bins=30)
        plt.title('Distribution of Baseline PANSS Total')
        plt.xlabel('Baseline PANSS Total Score')
        plt.ylabel('Frequency')
        plt.tight_layout()
        # plt.savefig('v3_hist_baseline_panss.png')
        plt.show()
        plt.close()
    except Exception as e:
        print(f"  Error plotting Baseline PANSS histogram: {e}")

    # 2. Distribution of Change in PANSS Total
    print("  Plotting: Change in PANSS Total Distribution (All Arms)")
    try:
        plt.figure(figsize=(10, 6))
        sns.histplot(data=df_to_plot, x='Change_PANSS_Total', kde=True, bins=40)
        plt.title('Distribution of Change in PANSS Total (All Arms)')
        plt.xlabel('Change from Baseline PANSS Total Score')
        plt.ylabel('Frequency')
        plt.tight_layout()
        # plt.savefig('v3_hist_change_panss.png')
        plt.show()
        plt.close()
    except Exception as e:
        print(f"  Error plotting Change PANSS histogram: {e}")

    # 3. Drug Assignment Counts
    print("  Plotting: Patient Counts per Drug Assignment")
    try:
        plt.figure(figsize=(12, 7))
        sns.countplot(data=df_to_plot, y='Drug', order=df_to_plot['Drug'].value_counts().index)
        plt.title('Patient Counts per Drug Assignment')
        plt.xlabel('Number of Patients')
        plt.ylabel('Drug')
        plt.tight_layout()
        # plt.savefig('v3_count_drug.png')
        plt.show()
        plt.close()
    except Exception as e:
        print(f"  Error plotting Drug counts: {e}")


    # --- US vs EU Comparisons ---

    # 4. Change in PANSS Total by Region (Overall)
    print("  Plotting: Change in PANSS Total by Region (Boxplot)")
    try:
        plt.figure(figsize=(10, 7))
        sns.boxplot(data=df_to_plot, x='Region', y='Change_PANSS_Total', order=['EU', 'US'])
        plt.title('Change in PANSS Total by Region (All Arms)')
        plt.xlabel('Region')
        plt.ylabel('Change from Baseline PANSS Total Score')
        plt.tight_layout()
        # plt.savefig('v3_box_change_panss_region.png')
        plt.show()
        plt.close()
    except Exception as e:
        print(f"  Error plotting Change PANSS boxplot by region: {e}")

    # 5. Placebo Response: Change in PANSS Total by Region
    print("  Plotting: Placebo Response Distribution by Region (KDE)")
    try:
        plt.figure(figsize=(12, 7))
        palette = sns.color_palette()
        # One KDE per region, each normalised on its own (as hue= with
        # common_norm=False would). Every curve carries a label, so the single
        # legend below lists both the regions and their mean lines; calling
        # plt.legend() after a hue= plot would replace seaborn's region legend.
        for region_color, region in zip(palette, ['EU', 'US']):
            region_change = df_placebo.loc[df_placebo['Region'] == region, 'Change_PANSS_Total'].dropna()
            if region_change.empty:
                continue
            region_mean = region_change.mean()
            sns.kdeplot(x=region_change, fill=True, color=region_color, label=region)
            plt.axvline(region_mean, color=region_color, linestyle='--', label=f'{region} Mean: {region_mean:.2f}')
        plt.title('Placebo Response Distribution (Change in PANSS Total) by Region')
        plt.xlabel('Change from Baseline PANSS Total Score')
        plt.ylabel('Density')
        plt.legend()
        plt.tight_layout()
        # plt.savefig('v3_kde_placebo_panss_region.png')
        plt.show()
        plt.close()
    except Exception as e:
        print(f"  Error plotting Placebo PANSS KDE by region: {e}")

    # 6. Example Drug Response: Change in PANSS for Drug D by Region
    print("  Plotting: Drug D vs Placebo Response by Region")
    try:
        plt.figure(figsize=(12, 7))
        df_drug_d = df_to_plot[df_to_plot['Drug'] == 'Drug_D'].copy()
        df_placebo_d_comp = pd.concat([df_placebo, df_drug_d])
        sns.boxplot(data=df_placebo_d_comp, x='Region', y='Change_PANSS_Total', hue='Drug', order=['EU', 'US'], hue_order=['Placebo', 'Drug_D'])
        plt.title('Drug D vs Placebo Response (Change PANSS) by Region')
        plt.xlabel('Region')
        plt.ylabel('Change from Baseline PANSS Total Score')
        plt.tight_layout()
        # plt.savefig('v3_box_drug_d_vs_placebo_region.png')
        plt.show()
        plt.close()
    except Exception as e:
        print(f"  Error plotting Drug D vs Placebo boxplot: {e}")


    # 7. Baseline PANSS Distribution by Region
    print("  Plotting: Baseline PANSS Total Distribution by Region (KDE)")
    try:
        plt.figure(figsize=(10, 6))
        sns.kdeplot(data=df_to_plot, x='Baseline_PANSS_Total', hue='Region', fill=True, common_norm=False, hue_order=['EU', 'US'])
        plt.title('Baseline PANSS Total Distribution by Region')
        plt.xlabel('Baseline PANSS Total Score')
        plt.ylabel('Density')
        plt.tight_layout()
        # plt.savefig('v3_kde_baseline_panss_region.png')
        plt.show()
        plt.close()
    except Exception as e:
        print(f"  Error plotting Baseline PANSS KDE by region: {e}")


    # 8. Ethnicity Distribution by Region
    print("  Plotting: Ethnicity Distribution by Region (Bar)")
    try:
        ethnicity_props = df_to_plot.groupby('Region')['Ethnicity'].value_counts(normalize=True).unstack().fillna(0)
        # Draw into an explicit Axes: DataFrame.plot() without ax= opens its own
        # figure, which would leave a separately created figure empty.
        fig, ax = plt.subplots(figsize=(10, 6))
        ethnicity_props.plot(kind='bar', stacked=True, ax=ax)
        plt.title('Proportion of Ethnicity Categories by Region')
        plt.xlabel('Region')
        plt.ylabel('Proportion')
        plt.xticks(rotation=0)
        plt.legend(title='Ethnicity', bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout(rect=[0, 0, 0.85, 1])
        # plt.savefig('v3_bar_ethnicity_region.png')
        plt.show()
        plt.close(fig)
    except Exception as e:
        print(f"  Error plotting Ethnicity bar chart: {e}")

    # 9. Substance Use Comorbidity Rate by Region
    print("  Plotting: Substance Use Comorbidity Rate by Region (Bar)")
    try:
        plt.figure(figsize=(7, 5))
        # Mean of the comorbidity indicator per region = proportion with comorbidity.
        sns.barplot(data=df_to_plot, x='Region', y='Comorbidity_Substance_Use', order=['EU', 'US'], estimator=np.mean, errorbar=None)
        plt.title('Substance Use Comorbidity Rate by Region')
        plt.xlabel('Region')
        plt.ylabel('Proportion with Comorbidity')
        plt.ylim(0, 1)
        plt.tight_layout()
        # plt.savefig('v3_bar_substance_use_region.png')
        plt.show()
        plt.close()
    except Exception as e:
        print(f"  Error plotting Substance Use bar chart: {e}")


    # 10. Distribution of Drug Classes
    print("  Plotting: Distribution of Drug Classes")
    try:
        plt.figure(figsize=(10, 6))
        sns.countplot(data=df_to_plot, y='Drug_Class', order=df_to_plot['Drug_Class'].value_counts().index)
        plt.title('Patient Counts per Drug Class')
        plt.xlabel('Number of Patients')
        plt.ylabel('Drug Class')
        plt.tight_layout()
        # plt.savefig('v3_count_drug_class.png')
        plt.show()
        plt.close()
    except Exception as e:
        print(f"  Error plotting Drug Class counts: {e}")

    # 11. Change in PANSS by Drug Class (Placebo excluded)
    print("  Plotting: Change in PANSS by Drug Class (Boxplot)")
    try:
        plt.figure(figsize=(12, 8))
        sns.boxplot(data=df_to_plot[df_to_plot['Drug_Class'] != 'Placebo'], x='Change_PANSS_Total', y='Drug_Class')
        plt.title('Change in PANSS Total by Active Drug Class')
        plt.xlabel('Change from Baseline PANSS Total Score')
        plt.ylabel('Drug Class')
        plt.tight_layout()
        # plt.savefig('v3_box_change_panss_drug_class.png')
        plt.show()
        plt.close()
    except Exception as e:
        print(f"  Error plotting Change PANSS by Drug Class: {e}")

    # 12. Change in PANSS by MOA (Top MOAs)
    print("  Plotting: Change in PANSS by Top MOAs (Boxplot)")
    try:
        plt.figure(figsize=(14, 8))
        # Top 5 MOAs by patient count, excluding 'None' (the Placebo MOA).
        top_moas = df_to_plot[df_to_plot['Drug_MOA'] != 'None']['Drug_MOA'].value_counts().nlargest(5).index
        df_top_moa = df_to_plot[df_to_plot['Drug_MOA'].isin(top_moas)]
        sns.boxplot(data=df_top_moa, x='Change_PANSS_Total', y='Drug_MOA', order=top_moas)
        plt.title('Change in PANSS Total by Top 5 MOAs (Excluding Placebo)')
        plt.xlabel('Change from Baseline PANSS Total Score')
        plt.ylabel('Mechanism of Action (MOA)')
        plt.tight_layout()
        # plt.savefig('v3_box_change_panss_moa.png')
        plt.show()
        plt.close()
    except Exception as e:
        print(f" Error plotting Change PANSS by MOA: {e}")


    print("--- Visualization generation complete ---")

else:
    print("No patient data available to plot; skipping visualizations.")


# --- End of Script ---

# Conditional Cycle GAN

In [None]:
# Conditional CycleGAN

import os
import time
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Dense, LayerNormalization, LeakyReLU, Add,
    Embedding, Concatenate, Flatten
)
from tensorflow.keras.models import Model
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import NearestNeighbors
import scipy.stats as stats
import matplotlib.pyplot as plt

# -------------------------------

# Hyperparameters & Config

# Input data and checkpointing
DATASET_FILENAME = "synthetic_schizophrenia_multi_study_v3.csv"
CHECKPOINT_DIR = "./tf_checkpoints/cyclegan_cond_aug"
SAVE_FREQ_EPOCHS = 10  # write a checkpoint every N epochs

# Optimisation schedule
BATCH_SIZE = 32
TOTAL_EPOCHS = 50

# Sizes of the learned embeddings for the categorical context (Study_ID, Drug)
STUDY_EMBED_DIM = 16
DRUG_EMBED_DIM = 8

# Length of the noise vector z fed to the generators (one-to-many translation)
Z_DIM = 16

# Weights of the cycle-consistency and identity terms in the generator loss
LAMBDA_CYCLE = 10.0
LAMBDA_ID = 0.1

# Outcome features translated between the EU and US domains
FEATURES_X = [
    'Change_PANSS_Total',
    'Change_PANSS_Positive',
    'Change_PANSS_Negative',
    'Change_GAF',
    'Change_Cognitive_Score',
    'Weight_Gain_kg',
    'EPS_SAS_Change',
    'AE_Count',
]
# Continuous conditioning context
FEATURES_C_CONT = [
    'Age',
    'Baseline_PANSS_Total',
    'Age_at_Onset',
    'Years_Since_Dx',
    'Previous_Antipsychotics_Count',
    'Dose_mg_day',
    'Treatment_Duration_Weeks',
]
# Categorical conditioning context (integer-encoded, then embedded)
FEATURES_C_CAT = ['Study_ID', 'Drug']

# -------------------------------

# Dense Residual Block
def dense_residual_block(x, width):
    """Apply a two-layer Dense/LayerNorm residual unit to a Keras tensor.

    Computes ``LeakyReLU(x + LN(Dense(LeakyReLU(LN(Dense(x))))))`` with a 0.2
    negative slope.

    Args:
        x: Keras tensor; its last dimension must equal ``width`` for the
            ``Add`` skip connection to be valid.
        width: Number of units in both Dense layers.

    Returns:
        The output Keras tensor, same width as ``x``.
    """
    branch = x
    for layer_idx in range(2):
        branch = Dense(width)(branch)
        branch = LayerNormalization()(branch)
        # Only the first Dense/LN pair is activated inside the branch; the
        # second is activated after the skip connection is added.
        if layer_idx == 0:
            branch = LeakyReLU(alpha=0.2)(branch)
    merged = Add()([x, branch])
    return LeakyReLU(alpha=0.2)(merged)

# -------------------------------

# Generator with Learned Embeddings & Noise

def build_generator_cGAN(
    input_dim_x, cont_dim,
    n_studies, n_drugs,
    study_emb_dim, drug_emb_dim, z_dim,
    num_residual=6, layer_width=128
):
    """Build a conditional generator that maps a feature vector to the other domain.

    Inputs, in order: features x (``input_dim_x``), continuous context
    (``cont_dim``), integer study index, integer drug index, and noise z
    (``z_dim``). The study and drug indices pass through learned embeddings.
    Everything is concatenated and fed through a Dense/LayerNorm/LeakyReLU stem,
    ``num_residual`` residual blocks of width ``layer_width``, and a linear head
    producing ``input_dim_x`` outputs.

    Returns:
        A ``tf.keras.Model`` named ``Generator_R{num_residual}_W{layer_width}_Z{z_dim}``.
    """
    x_in = Input(shape=(input_dim_x,), name='gen_input_x')
    cont_in = Input(shape=(cont_dim,), name='gen_input_cont')
    study_in = Input(shape=(), dtype='int32', name='gen_input_study')
    drug_in = Input(shape=(), dtype='int32', name='gen_input_drug')
    z_in = Input(shape=(z_dim,), name='gen_input_z')

    def embed(index_input, vocab_size, dim, layer_name):
        # Learned lookup table for one categorical index, flattened to a vector.
        vectors = Embedding(vocab_size, dim, name=layer_name)(index_input)
        return Flatten()(vectors)

    study_vec = embed(study_in, n_studies, study_emb_dim, 'study_emb')
    drug_vec = embed(drug_in, n_drugs, drug_emb_dim, 'drug_emb')

    hidden = Concatenate()([x_in, cont_in, study_vec, drug_vec, z_in])
    hidden = Dense(layer_width)(hidden)
    hidden = LayerNormalization()(hidden)
    hidden = LeakyReLU(alpha=0.2)(hidden)
    for _ in range(num_residual):
        hidden = dense_residual_block(hidden, layer_width)
    # Linear head (no activation).
    translated = Dense(input_dim_x, name='gen_output')(hidden)

    model_inputs = [x_in, cont_in, study_in, drug_in, z_in]
    model_name = f'Generator_R{num_residual}_W{layer_width}_Z{z_dim}'
    return Model(inputs=model_inputs, outputs=translated, name=model_name)

# -------------------------------

# Discriminator with Learned Embeddings

def build_discriminator_cGAN(
    input_dim_x, cont_dim,
    n_studies, n_drugs,
    study_emb_dim, drug_emb_dim,
    num_layers=4, start_width=128, name_suffix=""
):
    """Build a conditional discriminator that scores a feature vector as real or fake.

    Inputs, in order: features x, continuous context, integer study index and
    integer drug index (the latter two via learned embeddings). The MLP has a
    first hidden layer of ``start_width`` units followed by ``num_layers - 2``
    layers that halve the width each time (never below 32), all LeakyReLU(0.2),
    and a single linear output score. ``name_suffix`` is appended to every
    named layer so several discriminators can coexist.

    Returns:
        A ``tf.keras.Model`` named ``Discriminator_L{num_layers}_W{start_width}{name_suffix}``.
    """
    x_in = Input(shape=(input_dim_x,), name=f'disc_input_x{name_suffix}')
    cont_in = Input(shape=(cont_dim,), name=f'disc_input_cont{name_suffix}')
    study_in = Input(shape=(), dtype='int32', name=f'disc_input_study{name_suffix}')
    drug_in = Input(shape=(), dtype='int32', name=f'disc_input_drug{name_suffix}')

    study_vec = Embedding(n_studies, study_emb_dim, name=f'study_emb{name_suffix}')(study_in)
    study_vec = Flatten()(study_vec)
    drug_vec = Embedding(n_drugs, drug_emb_dim, name=f'drug_emb{name_suffix}')(drug_in)
    drug_vec = Flatten()(drug_vec)

    # Hidden-layer widths: start_width, then successive halvings floored at 32.
    hidden_widths = [start_width]
    for _ in range(num_layers - 2):
        hidden_widths.append(max(32, hidden_widths[-1] // 2))

    hidden = Concatenate()([x_in, cont_in, study_vec, drug_vec])
    for units in hidden_widths:
        hidden = Dense(units)(hidden)
        hidden = LeakyReLU(alpha=0.2)(hidden)
    # Linear real/fake score (no sigmoid).
    score = Dense(1, name=f'disc_output{name_suffix}')(hidden)

    return Model(
        inputs=[x_in, cont_in, study_in, drug_in],
        outputs=score,
        name=f'Discriminator_L{num_layers}_W{start_width}{name_suffix}',
    )

# -------------------------------

# Load & Prepare Data

if not os.path.exists(DATASET_FILENAME):
    raise FileNotFoundError(f"Dataset file not found: {DATASET_FILENAME}")
df = pd.read_csv(DATASET_FILENAME)

# Schema check: every column used below must be present.
required = ['Region', *FEATURES_X, *FEATURES_C_CONT, *FEATURES_C_CAT]
_available_cols = set(df.columns)
missing = [col for col in required if col not in _available_cols]
if missing:
    raise ValueError(f"Missing columns in data: {missing}")

# Integer-encode the categorical context so it can index the embedding tables.
le_study = LabelEncoder()
df['Study_IDX'] = le_study.fit_transform(df['Study_ID'])
n_studies = len(le_study.classes_)

le_drug = LabelEncoder()
df['Drug_IDX'] = le_drug.fit_transform(df['Drug'])
n_drugs = len(le_drug.classes_)

# Domain split: EU rows form one domain, US rows the other.
df_eu = df.loc[df['Region'].eq('EU')].copy()
df_us = df.loc[df['Region'].eq('US')].copy()
if df_eu.empty or df_us.empty:
    raise ValueError("EU or US split has zero rows.")

# Standardisation statistics come from the pooled (EU + US) data.
scaler_c = StandardScaler()
scaler_c.fit(df[FEATURES_C_CONT])
scaler_x = StandardScaler()
scaler_x.fit(df[FEATURES_X])

# transform
def to_numpy(subdf):
    """Convert a regional DataFrame into model-ready arrays.

    Uses the module-level ``scaler_x`` / ``scaler_c`` to standardise the
    FEATURES_X and FEATURES_C_CONT columns.

    Returns:
        Tuple ``(x, cont, study_idx, drug_idx)``: the two standardised matrices
        as float32 and the ``Study_IDX`` / ``Drug_IDX`` columns as int32.
    """
    features = scaler_x.transform(subdf[FEATURES_X])
    context = scaler_c.transform(subdf[FEATURES_C_CONT])
    study_idx, drug_idx = (
        subdf[col].to_numpy(dtype='int32') for col in ('Study_IDX', 'Drug_IDX')
    )
    return features.astype('float32'), context.astype('float32'), study_idx, drug_idx

x_eu_np, c_eu_np, s_eu_np, d_eu_np = to_numpy(df_eu)
x_us_np, c_us_np, s_us_np, d_us_np = to_numpy(df_us)


def _region_batches(x_np, c_np, s_np, d_np):
    """Shuffled, fixed-size batches of (x, cont, study, drug) for one region."""
    # The shuffle buffer always holds at least the whole region.
    buffer_size = max(10000, len(x_np))
    slices = tf.data.Dataset.from_tensor_slices((x_np, c_np, s_np, d_np))
    return slices.shuffle(buffer_size).batch(BATCH_SIZE, drop_remainder=True)


ds_eu = _region_batches(x_eu_np, c_eu_np, s_eu_np, d_eu_np)
ds_us = _region_batches(x_us_np, c_us_np, s_us_np, d_us_np)
# Paired (EU batch, US batch) stream; zip stops at the shorter region and
# repeat() restarts both for the next epoch.
train_ds = tf.data.Dataset.zip((ds_eu, ds_us)).prefetch(tf.data.AUTOTUNE).repeat()
steps_per_epoch = min(len(x_eu_np), len(x_us_np)) // BATCH_SIZE
if steps_per_epoch == 0:
    raise ValueError("Not enough data for one batch.")

# -------------------------------

# Build Models & Checkpoints

input_dim_x = len(FEATURES_X)
cont_dim    = len(FEATURES_C_CONT)

# Dimensions shared by both generators and both discriminators.
_common_model_dims = dict(
    input_dim_x=input_dim_x, cont_dim=cont_dim,
    n_studies=n_studies, n_drugs=n_drugs,
    study_emb_dim=STUDY_EMBED_DIM, drug_emb_dim=DRUG_EMBED_DIM,
)

# One generator per translation direction.
G_EU2US = build_generator_cGAN(**_common_model_dims, z_dim=Z_DIM, num_residual=6, layer_width=128)
G_US2EU = build_generator_cGAN(**_common_model_dims, z_dim=Z_DIM, num_residual=6, layer_width=128)
# One discriminator per domain; the suffix keeps their layer names distinct.
D_EU = build_discriminator_cGAN(**_common_model_dims, num_layers=4, start_width=128, name_suffix="_EU")
D_US = build_discriminator_cGAN(**_common_model_dims, num_layers=4, start_width=128, name_suffix="_US")

# Adam(2e-4, beta_1=0.5) for both players; MSE for adversarial terms, MAE for cycle/identity.
gen_opt  = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
disc_opt = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
mse_loss = tf.keras.losses.MeanSquaredError()
mae_loss = tf.keras.losses.MeanAbsoluteError()

# Object-based checkpoint of models, optimizers and the epoch counter;
# training resumes from the newest checkpoint if one exists.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
ckpt = tf.train.Checkpoint(
    G_EU2US=G_EU2US, G_US2EU=G_US2EU,
    D_EU=D_EU, D_US=D_US,
    gen_opt=gen_opt, disc_opt=disc_opt,
    epoch=tf.Variable(0),
)
ckpt_mgr = tf.train.CheckpointManager(ckpt, CHECKPOINT_DIR, max_to_keep=5)
_latest_ckpt = ckpt_mgr.latest_checkpoint
if _latest_ckpt:
    ckpt.restore(_latest_ckpt).expect_partial()
    print("Restored from", _latest_ckpt)
initial_epoch = int(ckpt.epoch.numpy())

# -------------------------------

# Training

@tf.function
def train_step(eu_batch, us_batch):
    """Run one joint generator/discriminator update on a paired EU/US batch.

    Each batch is an ``(x, cont, study_idx, drug_idx)`` tuple from the regional
    tf.data pipelines. Adversarial terms are least-squares (MSE against 1/0
    targets); cycle-consistency and identity terms are L1 (MAE). Generator and
    discriminator gradients come from the same forward pass and are applied
    once each with their own optimizer.

    Returns:
        dict of scalar tensors: "G_adv", "cycle_loss", "id_loss", "D_loss".
    """
    x_eu, c_eu, s_eu, d_eu = eu_batch
    x_us, c_us, s_us, d_us = us_batch
    # Batch size taken from the EU batch; both pipelines batch with
    # drop_remainder=True, so the US batch has the same size.
    bs = tf.shape(x_eu)[0]
    z_eu = tf.random.normal((bs, Z_DIM))
    z_us = tf.random.normal((bs, Z_DIM))
    z0   = tf.zeros((bs, Z_DIM))

    # Persistent tape: two gradient calls (generators, discriminators) are made from it.
    with tf.GradientTape(persistent=True) as tape:
        # forward
        fake_us = G_EU2US([x_eu, c_eu, s_eu, d_eu, z_eu], training=True)
        fake_eu = G_US2EU([x_us, c_us, s_us, d_us, z_us], training=True)
        # cycle (use zero‐noise for cycle/identity)
        cycled_eu = G_US2EU([fake_us, c_eu, s_eu, d_eu, z0], training=True)
        cycled_us = G_EU2US([fake_eu, c_us, s_us, d_us, z0], training=True)
        # identity: each generator applied to data already in its target domain
        same_eu   = G_US2EU([x_eu, c_eu, s_eu, d_eu, z0], training=True)
        same_us   = G_EU2US([x_us, c_us, s_us, d_us, z0], training=True)

        # adversarial loss
        # Fakes are scored with the context of the patient they were translated from.
        D_US_real = D_US([x_us, c_us, s_us, d_us], training=True)
        D_US_fake = D_US([fake_us, c_eu, s_eu, d_eu], training=True)
        D_EU_real = D_EU([x_eu, c_eu, s_eu, d_eu], training=True)
        D_EU_fake = D_EU([fake_eu, c_us, s_us, d_us], training=True)

        # Generators try to push discriminator scores of fakes towards 1.
        loss_G_adv = (
            mse_loss(tf.ones_like(D_US_fake), D_US_fake) +
            mse_loss(tf.ones_like(D_EU_fake), D_EU_fake)
        )
        loss_cycle = mae_loss(x_eu, cycled_eu) + mae_loss(x_us, cycled_us)
        loss_id    = mae_loss(x_eu, same_eu) + mae_loss(x_us, same_us)
        total_G    = loss_G_adv + LAMBDA_CYCLE * loss_cycle + LAMBDA_ID * loss_id

        # Discriminators: real -> 1, fake -> 0, averaged per domain.
        loss_D_US = 0.5 * (
            mse_loss(tf.ones_like(D_US_real), D_US_real) +
            mse_loss(tf.zeros_like(D_US_fake), D_US_fake)
        )
        loss_D_EU = 0.5 * (
            mse_loss(tf.ones_like(D_EU_real), D_EU_real) +
            mse_loss(tf.zeros_like(D_EU_fake), D_EU_fake)
        )
        total_D = loss_D_US + loss_D_EU

    # Each gradient call requests only its own player's variables, so the
    # generator update ignores D parameters and vice versa.
    grads_G = tape.gradient(total_G, G_EU2US.trainable_variables + G_US2EU.trainable_variables)
    grads_D = tape.gradient(total_D, D_US.trainable_variables + D_EU.trainable_variables)
    gen_opt.apply_gradients(zip(grads_G, G_EU2US.trainable_variables + G_US2EU.trainable_variables))
    disc_opt.apply_gradients(zip(grads_D, D_US.trainable_variables + D_EU.trainable_variables))
    return {
        "G_adv": loss_G_adv, "cycle_loss": loss_cycle,
        "id_loss": loss_id, "D_loss": total_D
    }

# Training loop: resumes at `initial_epoch` (from a restored checkpoint, if
# any) and runs `steps_per_epoch` paired EU/US batches per epoch.
train_iter = iter(train_ds)
for epoch in range(initial_epoch, TOTAL_EPOCHS):
    start = time.time()
    epoch_losses = {"G_adv": 0.0, "cycle_loss": 0.0, "id_loss": 0.0, "D_loss": 0.0}
    for step in range(steps_per_epoch):
        eu_b, us_b = next(train_iter)
        losses = train_step(eu_b, us_b)
        for k in epoch_losses:
            epoch_losses[k] += losses[k]
    # Average over the epoch and convert to plain Python floats for printing.
    for k in epoch_losses:
        epoch_losses[k] = float(epoch_losses[k]) / steps_per_epoch
    print(f"Epoch {epoch+1}/{TOTAL_EPOCHS} | "
          f"G_adv={epoch_losses['G_adv']:.4f} "
          f"cycle={epoch_losses['cycle_loss']:.4f} "
          f"id={epoch_losses['id_loss']:.4f} "
          f"D={epoch_losses['D_loss']:.4f} "
          f"({time.time()-start:.1f}s)")
    # Checkpoint every SAVE_FREQ_EPOCHS epochs and always after the final
    # epoch, so trailing epochs are not lost when TOTAL_EPOCHS is not a
    # multiple of SAVE_FREQ_EPOCHS.
    is_last_epoch = (epoch + 1) == TOTAL_EPOCHS
    if (epoch + 1) % SAVE_FREQ_EPOCHS == 0 or is_last_epoch:
        ckpt.epoch.assign(epoch + 1)
        p = ckpt_mgr.save()
        print("Saved checkpoint:", p)

# -------------------------------

# Generate Fake Translations

# Translate a fixed-size subset of EU patients into the US domain and pair it
# with an equally sized real US subset. n_eval is bounded by BOTH domains so
# the real and fake US arrays always have the same number of rows.
n_eval = min(1000, x_eu_np.shape[0], x_us_np.shape[0])
x_eu_sub   = x_eu_np[:n_eval]
c_eu_sub   = c_eu_np[:n_eval]
s_eu_sub   = s_eu_np[:n_eval]
d_eu_sub   = d_eu_np[:n_eval]
z_eval     = np.random.randn(n_eval, Z_DIM).astype('float32')
fake_us_np = G_EU2US.predict([x_eu_sub, c_eu_sub, s_eu_sub, d_eu_sub, z_eval])

real_us_np = x_us_np[:n_eval]

In [None]:
# Visualizations

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors

# 1) Take the first n_eval rows of each domain (same count in both).
n_eval = min(1000, x_eu_np.shape[0], x_us_np.shape[0])
x_eu_sub, c_eu_sub, s_eu_sub, d_eu_sub = (
    arr[:n_eval] for arr in (x_eu_np, c_eu_np, s_eu_np, d_eu_np)
)
x_us_sub, c_us_sub, s_us_sub, d_us_sub = (
    arr[:n_eval] for arr in (x_us_np, c_us_np, s_us_np, d_us_np)
)

# 2) Translate in both directions, reusing one noise draw for both.
z_eval = np.random.randn(n_eval, Z_DIM).astype('float32')
fake_us_np = G_EU2US.predict([x_eu_sub, c_eu_sub, s_eu_sub, d_eu_sub, z_eval])
fake_eu_np = G_US2EU.predict([x_us_sub, c_us_sub, s_us_sub, d_us_sub, z_eval])
real_us_np, real_eu_np = x_us_sub, x_eu_sub

# 3) MMD helper
def gaussian_kernel(X, Y, sigma=1.0):
    """Gaussian (RBF) kernel matrix between the rows of X and Y.

    ``K[i, j] = exp(-||X[i] - Y[j]||^2 / (2 * sigma^2))``.

    Args:
        X: 2-D array of shape (n, d).
        Y: 2-D array of shape (m, d).
        sigma: Kernel bandwidth.

    Returns:
        (n, m) array of kernel values in (0, 1].
    """
    XX = np.sum(X**2, axis=1).reshape(-1, 1)
    YY = np.sum(Y**2, axis=1).reshape(1, -1)
    # The expansion ||x||^2 - 2 x.y + ||y||^2 can come out slightly negative
    # through floating-point cancellation; clamp so no kernel value exceeds 1.
    d2 = np.maximum(XX - 2 * X.dot(Y.T) + YY, 0.0)
    return np.exp(-d2 / (2 * sigma**2))

def compute_mmd(X, Y, sigma=1.0):
    """Squared maximum mean discrepancy between samples X and Y.

    Uses ``gaussian_kernel`` with bandwidth ``sigma`` and averages over the full
    kernel matrices, diagonals included (the biased estimator).
    """
    within_x = gaussian_kernel(X, X, sigma).mean()
    within_y = gaussian_kernel(Y, Y, sigma).mean()
    cross = gaussian_kernel(X, Y, sigma).mean()
    return within_x + within_y - 2 * cross

# 4) Feature‐wise KDE plots (first 6 features)
nplots = min(6, len(FEATURES_X))
# squeeze=False keeps `axes` 2-D even when nplots == 1, so axes[i] is always a
# row of two Axes that can be unpacked.
fig, axes = plt.subplots(nplots, 2, figsize=(12, 3*nplots), squeeze=False)
for i, feat in enumerate(FEATURES_X[:nplots]):
    ax_eu, ax_us = axes[i]
    sns.kdeplot(real_eu_np[:, i], label='real EU', ax=ax_eu, fill=True)
    sns.kdeplot(fake_eu_np[:, i], label='fake EU', ax=ax_eu, fill=True)
    # fake EU is G_US2EU applied to US rows: a translation, not a reconstruction.
    ax_eu.set_title(f'EU Trans.: {feat}')
    sns.kdeplot(real_us_np[:, i], label='real US', ax=ax_us, fill=True)
    sns.kdeplot(fake_us_np[:, i], label='fake US', ax=ax_us, fill=True)
    ax_us.set_title(f'US Trans.: {feat}')
    ax_eu.legend()
    ax_us.legend()
plt.tight_layout()
plt.show()

# 5) PCA 2D projection
def plot_pca2d(datasets, labels, title):
    """Scatter several datasets in one shared 2-D PCA projection.

    PCA is fit on the row-wise concatenation of ``datasets``, so every group is
    projected onto the same two components; ``labels[k]`` names ``datasets[k]``.
    """
    projected = PCA(n_components=2).fit_transform(np.vstack(datasets))
    plt.figure(figsize=(6, 5))
    offset = 0
    for idx, lab in enumerate(labels):
        n_rows = len(datasets[idx])
        chunk = projected[offset:offset + n_rows]
        offset += n_rows
        plt.scatter(chunk[:, 0], chunk[:, 1], s=15, alpha=0.6, label=lab)
    plt.title(title)
    plt.xlabel('PC1')
    plt.ylabel('PC2')
    plt.legend()
    plt.tight_layout()
    plt.show()

for _domain, _real, _fake in (('EU', real_eu_np, fake_eu_np), ('US', real_us_np, fake_us_np)):
    plot_pca2d([_real, _fake], [f'real {_domain}', f'fake {_domain}'], f'PCA: {_domain} real vs fake')

# 6) Correlation‐matrix heatmaps for US domain
real_corr_us = np.corrcoef(real_us_np, rowvar=False)
fake_corr_us = np.corrcoef(fake_us_np, rowvar=False)
# Label both axes with the feature names (array columns follow FEATURES_X
# order) and show one shared colour bar so the cells can be read.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
sns.heatmap(real_corr_us, vmin=-1, vmax=1, cmap='coolwarm', ax=ax1, cbar=False,
            xticklabels=FEATURES_X, yticklabels=FEATURES_X)
ax1.set_title('Real US Corr')
sns.heatmap(fake_corr_us, vmin=-1, vmax=1, cmap='coolwarm', ax=ax2, cbar=True,
            xticklabels=FEATURES_X, yticklabels=FEATURES_X)
ax2.set_title('Fake US Corr')
plt.tight_layout()
plt.show()

# 7) NN‐distance histogram (privacy proxy)
# Distance from each fake US row to its closest real US row (standardised space).
nbr = NearestNeighbors(n_neighbors=1)
nbr.fit(real_us_np)
dists = nbr.kneighbors(fake_us_np)[0].ravel()

plt.figure(figsize=(6, 4))
sns.histplot(dists, bins=50, kde=True)
plt.title('NN distance: fake US → nearest real US')
plt.xlabel('Euclidean dist')
plt.ylabel('Count')
plt.tight_layout()
plt.show()
print(f"NN distances  mean={dists.mean():.3f}, med={np.median(dists):.3f}, min={dists.min():.3f}, max={dists.max():.3f}")

# 8) MMD vs σ
# MMD between real and fake US samples across several kernel bandwidths.
sigmas = [0.5, 1.0, 2.0, 5.0]
mmd_vals = [compute_mmd(real_us_np, fake_us_np, sigma=bw) for bw in sigmas]
fig_mmd, ax_mmd = plt.subplots(figsize=(6, 4))
ax_mmd.plot(sigmas, mmd_vals, '-o')
ax_mmd.set_title('MMD(real US, fake US) vs σ')
ax_mmd.set_xlabel('σ')
ax_mmd.set_ylabel('MMD')
fig_mmd.tight_layout()
plt.show()

In [None]:
# Conditional CycleGAN Evaluation

import numpy as np, pandas as pd, scipy.stats as stats
from   sklearn.linear_model   import LogisticRegression
from   sklearn.metrics        import accuracy_score, roc_auc_score, mean_squared_error
from   sklearn.neighbors      import NearestNeighbors


# Full (non-subsampled) arrays for each domain, under the names used below.
real_eu_np, cont_eu_np, study_eu_idx, drug_eu_idx = x_eu_np, c_eu_np, s_eu_np, d_eu_np
real_us_np, cont_us_np, study_us_idx, drug_us_idx = x_us_np, c_us_np, s_us_np, d_us_np

# Rebuild the regional DataFrames only when either is not defined yet.
if 'df_us' not in globals() or 'df_eu' not in globals():
    df_us = df[df['Region'] == 'US'].copy()
    df_eu = df[df['Region'] == 'EU'].copy()

def ks_test(real, fake, feature_names):
    """Two-sample Kolmogorov-Smirnov test for each feature column.

    Column ``i`` of ``real`` and ``fake`` is compared and reported under
    ``feature_names[i]``.

    Returns:
        DataFrame with columns ['feature', 'ks_stat', 'p_value'], one row per feature.
    """
    records = []
    for col_idx, name in enumerate(feature_names):
        outcome = stats.ks_2samp(real[:, col_idx], fake[:, col_idx])
        records.append((name, outcome.statistic, outcome.pvalue))
    return pd.DataFrame(records, columns=['feature', 'ks_stat', 'p_value'])

def chi2_categorical(real_df, fake_df, columns):
    """
    χ² homogeneity test for each categorical column.

    For every column, a 2×k contingency table (real vs fake counts over the k
    observed categories) is passed to ``scipy.stats.chi2_contingency`` without
    continuity correction. Missing values are ignored. Categories are taken
    from the values actually observed, so the table never contains an all-zero
    column (which would make chi2_contingency raise). If either frame has no
    non-missing value for a column, its statistic and p-value are NaN.

    Returns:
        DataFrame with columns ['column', 'chi2_stat', 'p_value'].
    """
    rows = []
    for col in columns:
        real_counts = real_df[col].value_counts()
        fake_counts = fake_df[col].value_counts()
        cats = real_counts.index.union(fake_counts.index)
        table = np.vstack([
            real_counts.reindex(cats, fill_value=0).to_numpy(),
            fake_counts.reindex(cats, fill_value=0).to_numpy(),
        ])
        # An empty row (no observations on one side) gives zero expected counts.
        if table.shape[1] == 0 or (table.sum(axis=1) == 0).any():
            rows.append((col, np.nan, np.nan))
            continue
        chi2, p, _, _ = stats.chi2_contingency(table, correction=False)
        rows.append((col, chi2, p))
    return pd.DataFrame(rows, columns=['column', 'chi2_stat', 'p_value'])

def corr_frobenius(r, f):
    """Frobenius norm of the difference between the column-correlation matrices of r and f."""
    corr_real = np.corrcoef(r, rowvar=False)
    corr_fake = np.corrcoef(f, rowvar=False)
    return np.linalg.norm(corr_real - corr_fake)

def _rbf(X, Y, s=1.0):
    XX = np.sum(X*X, 1)[:, None]
    YY = np.sum(Y*Y, 1)[None, :]
    return np.exp(-(XX - 2*X@Y.T + YY) / (2*s*s))

def mmd(X, Y, s=1.0):
    """Squared MMD (biased, diagonals included) between samples X and Y using ``_rbf`` with bandwidth s."""
    k_xx, k_yy, k_xy = (_rbf(A, B, s).mean() for A, B in ((X, X), (Y, Y), (X, Y)))
    return k_xx + k_yy - 2*k_xy

def domain_confusion(real_src, real_tgt, fake_tgt):
    """Fraction of fake target rows that a source-vs-target classifier assigns to the target.

    A logistic regression is fit to separate real source rows (label 0) from
    real target rows (label 1), then applied to ``fake_tgt``.
    """
    features = np.concatenate([real_src, real_tgt], axis=0)
    labels = np.concatenate([np.zeros(len(real_src)), np.ones(len(real_tgt))])
    classifier = LogisticRegression(max_iter=1000)
    classifier.fit(features, labels)
    predicted = classifier.predict(fake_tgt)
    return np.mean(predicted == 1)

def tstr(fake_df, real_df, feats, label):
    """Train-on-Synthetic, Test-on-Real score of a logistic-regression classifier.

    Fits ``LogisticRegression`` on ``fake_df[feats] -> fake_df[label]`` and
    evaluates it on ``real_df``.

    Returns:
        ``(accuracy, roc_auc)`` on the real data. ROC AUC is computed for the
        class whose probability is column 1 of ``predict_proba``
        (``clf.classes_[1]``), so non-0/1 labels such as strings also work; it
        is NaN when it cannot be computed (e.g. only one class in ``real_df``).
        Both values are NaN when ``fake_df`` has fewer than two label classes,
        because no classifier can be fit.
    """
    y_fake = fake_df[label]
    if y_fake.nunique() < 2:
        return np.nan, np.nan
    clf = LogisticRegression(max_iter=1000).fit(fake_df[feats], y_fake)
    y_real = real_df[label]
    acc = accuracy_score(y_real, clf.predict(real_df[feats]))
    try:
        # Binarise the real labels against the class scored by column 1.
        auc = roc_auc_score(y_real == clf.classes_[1],
                            clf.predict_proba(real_df[feats])[:, 1])
    except ValueError:
        auc = np.nan
    return acc, auc

def nn_stats(real, fake):
    """Summarise the distance from each fake row to its nearest real row.

    Returns:
        dict with keys 'mean', 'median', 'min', 'max'.
    """
    distances = NearestNeighbors(n_neighbors=1).fit(real).kneighbors(fake)[0][:, 0]
    return {
        "mean": distances.mean(),
        "median": np.median(distances),
        "min": distances.min(),
        "max": distances.max(),
    }

def cycle_id(G, F, real, cont, s_idx, d_idx, zdim):
    """Cycle-consistency and identity reconstruction errors for two models.

    ``G`` is applied to ``real`` with standard-normal noise. ``F`` is then
    applied with zero noise in two ways: to G's output (cycle) and directly
    to ``real`` (identity). Every call receives the same conditioning
    inputs, in the order ``[x, cont, s_idx, d_idx, z]``.

    Returns
    -------
    (cycle_mse, identity_mse) : tuple
        ``mean_squared_error`` of ``real`` against F(G(real)) and against
        F(real).
    """
    n_rows = real.shape[0]
    noise = np.random.randn(n_rows, zdim).astype('float32')
    no_noise = np.zeros((n_rows, zdim), dtype='float32')
    context = [cont, s_idx, d_idx]

    translated = G.predict([real, *context, noise], verbose=0)
    reconstructed = F.predict([translated, *context, no_noise], verbose=0)
    identity = F.predict([real, *context, no_noise], verbose=0)

    cycle_err = mean_squared_error(real, reconstructed)
    identity_err = mean_squared_error(real, identity)
    return (cycle_err, identity_err)

# Generate samples
#
# Each generator is conditioned on its source rows' own context (continuous
# covariates, study index, drug index) plus Gaussian noise drawn from
# NumPy's global RNG.

z_us  = np.random.randn(len(real_eu_np), Z_DIM).astype('float32')
fake_us_np = G_EU2US.predict([real_eu_np, cont_eu_np,
                              study_eu_idx, drug_eu_idx, z_us],
                             verbose=0)

z_eu  = np.random.randn(len(real_us_np), Z_DIM).astype('float32')
fake_eu_np = G_US2EU.predict([real_us_np, cont_us_np,
                              study_us_idx, drug_us_idx, z_eu],
                             verbose=0)
# NOTE(review): fake_eu_np is generated but never evaluated below; only the
# EU→US direction is scored.


df_fake_us = pd.DataFrame(fake_us_np, columns=FEATURES_X)
n_fake = len(df_fake_us)

# Row i of df_fake_us is the translation of EU row i, so every per-row
# attribute is taken positionally from df_eu. .to_numpy() bypasses index
# alignment, because df_eu's index need not be 0..n-1.
df_fake_us['Study_ID'] = df_eu['Study_ID'].to_numpy()[:n_fake]
df_fake_us['Drug']     = df_eu['Drug'].to_numpy()[:n_fake]

# Attach labels for TSTR from the same EU patients that were translated.
# Do not take them from df_us: the first n US rows are unrelated patients.
# Pairing them with the synthetic rows randomises the feature/label
# relationship, and it leaves NaN labels when df_us is shorter than df_eu.
for label_col in ['Responder_Status', 'Remission_Status']:
    df_fake_us[label_col] = df_eu[label_col].to_numpy()[:n_fake]


print("\n================  E V A L U A T I O N  =================\n")

# a) Univariate KS
print("Univariate KS tests (US domain):")
print(ks_test(real_us_np, fake_us_np, FEATURES_X).to_string(index=False))

# b) χ² on categorical context
# NOTE(review): Study_ID and Drug are copied from the EU source rows, so
# this compares the EU and US context mixes, not anything the generator
# produced.
print("\nCategorical χ² tests (US domain):")
print(chi2_categorical(df_us, df_fake_us, ['Study_ID', 'Drug']).to_string(index=False))

# c) Correlation / MMD
print(f"\nCorrelation Frobenius norm (US): {corr_frobenius(real_us_np, fake_us_np):.5f}")
print(f"MMD (σ=1) (US)                  : {mmd(real_us_np, fake_us_np):.6f}")

# d) Domain confusion: share of fake-US rows an EU-vs-US classifier calls US
print(f"\nDomain classifier – fake‑US recognised as US: "
      f"{domain_confusion(real_eu_np, real_us_np, fake_us_np):.3f}")

# e) TSTR utility
# NOTE(review): the model is trained on generator outputs but tested on
# df_us[FEATURES_X]. Confirm that both share the same scaling as
# real_us_np (e.g. that df_us does not hold raw, unstandardised values).
acc, auc = tstr(df_fake_us, df_us, FEATURES_X, 'Responder_Status')
print(f"\nTSTR Responder_Status  acc={acc:.3f}, AUC={auc:.3f}")

# f) Privacy proxy: very small distances would suggest memorised records
print("\nNearest‑Neighbour distance fake‑US ➜ real‑US:")
for k, v in nn_stats(real_us_np, fake_us_np).items():
    print(f"  {k:<6s}: {v:.4f}")

# g) Cycle & identity consistency
mse_cycle_us, mse_id_us = cycle_id(G_EU2US, G_US2EU,
                                   real_eu_np, cont_eu_np,
                                   study_eu_idx, drug_eu_idx, Z_DIM)
mse_cycle_eu, mse_id_eu = cycle_id(G_US2EU, G_EU2US,
                                   real_us_np, cont_us_np,
                                   study_us_idx, drug_us_idx, Z_DIM)

print(f"\nEU→US→EU cycle MSE : {mse_cycle_us:.5f}   |   EU identity MSE : {mse_id_us:.5f}")
print(f"US→EU→US cycle MSE : {mse_cycle_eu:.5f}   |   US identity MSE : {mse_id_eu:.5f}")

print("\n================  E N D   O F   E V A L  ================\n")