## Synthetic data validation

In [1]:
# ================================================================
# Pharma PoC — Synthetic Generator (ml_optimized toggle) + ML Pipeline
# + compare_datasets + generate_ml_benchmark_datasets
# (compatible with previous artifacts and API)
# ================================================================
import pandas as pd
import numpy as np
import random
from typing import Dict, Tuple, Optional, Any, List
from datetime import datetime
import warnings, os
warnings.filterwarnings('ignore')

try:
    import openpyxl  # needed for Excel I/O
except Exception as e:
    # Best-effort import: only report the problem and keep going.
    print("openpyxl not found. Excel export will be skipped.", e)

# ------------------------------------------------
# 1) Data generator (orthogonal rules)
# Signal/noise parameters are configured via ml_optimized
# ------------------------------------------------
class EnhancedPharmaDataGenerator:
    """Synthetic pharmacy-claim generator with rule-driven error labels.

    Every record is a plain dict keyed by the names in ``column_headers``.
    A record's error category is either derived from its own fields
    (``_decide_error_category``) or forced by the caller, in which case the
    record is first re-tuned (``_tune_record_for_error``) so that its fields
    agree with the forced category.

    All randomness comes from the global ``random`` module; passing ``seed``
    seeds both ``random`` and ``numpy.random`` globally.
    """

    def __init__(self, seed: Optional[int] = None, signal_strength: float = 1.0, noise_level: float = 0.01):
        """Build the lookup tables and reset the id counters.

        Args:
            seed: when not None, seeds the global ``random`` and
                ``numpy.random`` generators.
            signal_strength: clipped to [0, 1]. Higher values shrink the
                per-drug price volatility and narrow the underpayment ranges.
            noise_level: clipped to [0, 0.05]. Slightly lowers the share of
                reference drugs picked for unforced records.
        """
        if seed is not None:
            random.seed(seed); np.random.seed(seed)
        self.signal_strength = float(np.clip(signal_strength, 0.0, 1.0))
        self.noise_level = float(np.clip(noise_level, 0.0, 0.05))

        # Output column order; the final DataFrame is reindexed to this list.
        self.column_headers = [
            'ADJ DATE MEDICARE TXN DATE/TIME','NPI PHARMACY NPI','RX NBR PRESCRIPTION NUMBER',
            'SERV DATE DATE OF SERVICE','CLAIM ID MEDICARE ICN/AUTH #','PROD SVC ID PRODUCT/SERVICE ID',
            'DRUG NAME N/A','QTY DISP QUANTITY DISPENSED','WAC UNIT PRICE WAC UNIT PRICE',
            'MFP UNIT PRICE MFP UNIT PRICE','EST REIMB AMT Estimated MFG Reimbursement Amt',
            'MFP RULE PRC POINT MFP Rule Price Point','MFP RULE DISC MFP Rule Disc %',
            'MFP RULE UNIT PRICE MFP Rule Unit Price','MFG NAME MFG Name','MFG CUST ID MFG Customer ID',
            'Medispan/FDB WAC Price','Medispan/FDB MFP Price','Medispan/FDB Effective Date',
            'Medispan/FDB Termination Date','835 report Check Number','835 report Claim Number',
            '835 report Pharmacy Number','835 report Rx Number','835 report Refill Number',
            '835 report Estimated refund from adjudication','835 report Actual Payment Amount',
            '835 report Date Filled/Date of Service','835 report Quantity',
            '835 report Adjudicated Procedure Code (Product/Service ID)',
            '835 report Adjustment Code / CARC codes','835 report Adjustment Amount',
            '835 report Adjustment Quantity','835 report Qualifier Code / RARC codes',
            'Expected Outcomes Error category','Expected Outcomes Clerk Input',' Questions/comments','Case'
        ]

        def pv(v):
            # Scale the nominal volatility down as signal_strength grows,
            # but never below 0.5%.
            base = v*(1.0 - 0.8*self.signal_strength)
            return max(0.005, base)

        # Insertion order matters: the first six entries are treated as the
        # "reference" drugs by _select_drug_with_variations.
        self.drugs_db = {
            'STELARA INJ 5MG/ML': {'base_wac': 81.5254,'base_mfp': 25.1081,'qty_range': [5],
                'ndc_codes': ['57894005427'],'manufacturer': 'JANSSEN BIOTECH','therapeutic_class': 'Immunosuppressant','price_volatility': pv(0.02)},
            'ENBREL INJ 25/0.5ML': {'base_wac': 2039.4,'base_mfp': 583.58,'qty_range': [1],
                'ndc_codes': ['58406001001'],'manufacturer': 'AMGEN/ IMMUNEX','therapeutic_class': 'Immunosuppressant','price_volatility': pv(0.03)},
            'STELARA INJ 90MG/ML': {'base_wac': 29151.46,'base_mfp': 8980.1446,'qty_range': [1],
                'ndc_codes': ['57894006103'],'manufacturer': 'JANSSEN BIOTECH','therapeutic_class': 'Immunosuppressant','price_volatility': pv(0.02)},
            'ENBREL SRCLK INJ 50MG/ML': {'base_wac': 1635.15,'base_mfp': 583.7669,'qty_range': [1],
                'ndc_codes': ['58406044501'],'manufacturer': 'AMGEN/ IMMUNEX','therapeutic_class': 'Immunosuppressant','price_volatility': pv(0.03)},
            'ENTRESTO CAP 15-16MG': {'base_wac': 11.7535,'base_mfp': 1.2292,'qty_range': [30],
                'ndc_codes': ['00078123820'],'manufacturer': 'NOVARTIS','therapeutic_class': 'Cardiovascular','price_volatility': pv(0.05)},
            'FIASP F/P PEN 100U/ML': {'base_wac': 37.26,'base_mfp': 8.96,'qty_range': [10,15],
                'ndc_codes': ['00169184110'],'manufacturer': 'NOVO NORDISK','therapeutic_class': 'Diabetes','price_volatility': pv(0.04)},
            # extended
            'HUMIRA INJ 40MG/0.8ML': {'base_wac': 2400.85,'base_mfp': 720.25,'qty_range': [1,2],
                'ndc_codes': ['00074123456'],'manufacturer': 'ABBVIE INC','therapeutic_class': 'Immunosuppressant','price_volatility': pv(0.03)},
            'KEYTRUDA INJ 100MG/4ML': {'base_wac': 12500.0,'base_mfp': 3750.0,'qty_range': [1],
                'ndc_codes': ['00006234567'],'manufacturer': 'MERCK & CO','therapeutic_class': 'Oncology','price_volatility': pv(0.02)},
            'OZEMPIC INJ 0.25MG/1.5ML': {'base_wac': 850.5,'base_mfp': 255.15,'qty_range': [1,3],
                'ndc_codes': ['00169345678'],'manufacturer': 'SANOFI-AVENTIS','therapeutic_class': 'Diabetes','price_volatility': pv(0.04)},
            'DUPIXENT INJ 300MG/2ML': {'base_wac': 3200.75,'base_mfp': 960.23,'qty_range': [1,2],
                'ndc_codes': ['00024456789'],'manufacturer': 'REGENERON PHARM','therapeutic_class': 'Dermatology','price_volatility': pv(0.03)},
            'SKYRIZI INJ 150MG/ML': {'base_wac': 18500.0,'base_mfp': 5550.0,'qty_range': [1],
                'ndc_codes': ['00074567890'],'manufacturer': 'BRISTOL MYERS SQUIBB','therapeutic_class': 'Dermatology','price_volatility': pv(0.02)}
        }

        self.mfg_customer_ids = {
            'JANSSEN BIOTECH': 'MFP00008','AMGEN/ IMMUNEX': 'MFP00003','NOVARTIS': 'MFP00002',
            'NOVO NORDISK': 'MFP00009','ABBVIE INC': 'MFP00010','MERCK & CO': 'MFP00011',
            'SANOFI-AVENTIS': 'MFP00012','REGENERON PHARM': 'MFP00013','BRISTOL MYERS SQUIBB': 'MFP00014'
        }

        self.pharmacy_npis = [1013998921,1518039858,1043382302,1063510097,1326029232,1629341177,1578836698,
                              1234567890,1345678901,1456789012,1567890123,1678901234]
        # Behaviour profile per pharmacy; 'sloppy' feeds the wrong-claim rule.
        self.pharmacy_profiles = {
            1013998921:'careful',1518039858:'careful',1043382302:'sloppy',1063510097:'sloppy',
            1326029232:'neutral',1629341177:'neutral',1578836698:'neutral',1234567890:'sloppy',
            1345678901:'careful',1456789012:'neutral',1567890123:'sloppy',1678901234:'neutral'
        }

        self.rarc_codes = ['N907','N908','N910','N911']
        self.error_categories = [
            'Write off due to manual data entry error',
            'Write off due to payment applied to the wrong claim',
            'Write off due to WAC UNIT PRICE diff from WAC Medispan',
            'Write off due to duplicate payment',
            'Sent to collection due to MFG timing difference'
        ]
        self.error_to_rarc_primary = {
            'Write off due to manual data entry error': 'N907',
            'Write off due to payment applied to the wrong claim': 'N908',
            'Write off due to WAC UNIT PRICE diff from WAC Medispan': 'N910',
            'Write off due to duplicate payment': 'N911',
            'Sent to collection due to MFG timing difference': 'N910'
        }
        self.error_to_carc = {
            'Write off due to manual data entry error': '129',
            'Write off due to payment applied to the wrong claim': '22',
            'Write off due to WAC UNIT PRICE diff from WAC Medispan': '237',
            'Write off due to duplicate payment': '97',
            'Sent to collection due to MFG timing difference': '45'
        }
        self.clerk_input_mapping = {
            'Write off due to manual data entry error': [
                'Manual correction needed','Data entry validation required',
                'Review input accuracy','Store these 340B claims and compare with manufacturer payment'
            ],
            'Write off due to payment applied to the wrong claim': [
                'Payment reallocation required','Accepting MFP price and business has to review all 908',
                'Payment reversal and reapply','Business is to review'
            ],
            'Write off due to WAC UNIT PRICE diff from WAC Medispan': [
                'Credit notes not accounted','WAC price verification needed','Medispan price reconciliation required'
            ],
            'Write off due to duplicate payment': [
                'Duplicate payment identified','Payment consolidation required','Refund processing needed','Business is to review'
            ],
            'Sent to collection due to MFG timing difference': [
                'Exception handling','Manufacturer timing review required',
                'Collection process initiated','Business is to review'
            ]
        }
        self.slow_mfg = {'JANSSEN BIOTECH','AMGEN/ IMMUNEX','BRISTOL MYERS SQUIBB'}
        # (low, high) fraction by which the actual payment undercuts the estimate.
        self.error_underpayment_ranges = {
            'Write off due to manual data entry error': (0.05, 0.09),
            'Write off due to payment applied to the wrong claim': (0.10, 0.15),
            'Write off due to WAC UNIT PRICE diff from WAC Medispan': (0.18, 0.25),
            'Write off due to duplicate payment': (0.35, 0.45),
            'Sent to collection due to MFG timing difference': (0.26, 0.32)
        }
        self.dup_rate = 0.06
        self.claim_id_counter = 1820000001
        self.rx_counter = 1820001

    # ---- helpers
    def _generate_timestamp(self, year_range=(2025,2026)):
        """Return a random 15-digit 'YYYYMMDDhhmmss0' string.

        Day is limited to 1-28 and hour to 8-17; the trailing '0' is a fixed
        suffix appended after the seconds.
        """
        year = random.randint(*year_range); month = random.randint(1,12); day = random.randint(1,28)
        hour = random.randint(8,17); minute = random.randint(0,59); second = random.randint(0,59)
        return f"{year:04d}{month:02d}{day:02d}{hour:02d}{minute:02d}{second:02d}0"

    def _generate_date(self, format_type="service", force_month: Optional[int]=None) -> str:
        """Return a random 'DD/MM/YYYY' date string.

        The year is drawn from 2025/2026 for "service", 2026/2027 for
        "payment", and is 2026 for any other ``format_type``. A falsy
        ``force_month`` (None or 0) means "pick a random month".
        """
        if format_type == "service":
            year = random.choice([2025,2026])
        elif format_type == "payment":
            year = random.choice([2026,2027])
        else:
            year = 2026
        month = force_month if force_month else random.randint(1,12)
        day = random.randint(1,28)
        return f"{day:02d}/{month:02d}/{year}"

    def _sample_drug(self, name: str) -> Dict[str, Any]:
        """Draw concrete prices, quantity and NDC for the drug called ``name``.

        Shared by ``_pick_drug_by`` and ``_select_drug_with_variations``.
        The dict literal is evaluated top to bottom, so the random draws
        happen in the order WAC price, MFP price, quantity, NDC code.
        """
        d = self.drugs_db[name]
        vol = d.get('price_volatility', 0.02)
        return {
            'name': name,
            'wac_price': round(d['base_wac'] * random.uniform(1-vol, 1+vol), 4),
            'mfp_price': round(d['base_mfp'] * random.uniform(1-vol, 1+vol), 4),
            'quantity': random.choice(d['qty_range']),
            'ndc_code': random.choice(d['ndc_codes']),
            'manufacturer': d['manufacturer'],
            'therapeutic_class': d.get('therapeutic_class','General')
        }

    def _pick_drug_by(self, *, classes=None, manufacturers=None, qty_options=None) -> Dict[str, Any]:
        """Sample a drug that satisfies every given filter.

        A filter that is None/empty is ignored. If no drug matches, the
        choice falls back to the whole catalogue.
        """
        pool = []
        for name, d in self.drugs_db.items():
            if classes and d.get('therapeutic_class') not in classes: continue
            if manufacturers and d.get('manufacturer') not in manufacturers: continue
            if qty_options and not any(q in d['qty_range'] for q in qty_options): continue
            pool.append(name)
        name = random.choice(pool if pool else list(self.drugs_db.keys()))
        return self._sample_drug(name)

    def _select_drug_with_variations(self) -> Dict[str, Any]:
        """Sample a drug for an unforced record.

        The first six catalogue entries are picked with probability
        ``0.90 - 0.3*noise_level``; otherwise one of the remaining
        ("extended") entries is used.
        """
        all_drugs = list(self.drugs_db.keys())
        reference_drugs = all_drugs[:6]; extended_drugs = all_drugs[6:]
        name = random.choice(reference_drugs) if random.random() < (0.90 - 0.3*self.noise_level) else random.choice(extended_drugs)
        return self._sample_drug(name)

    def _calculate_pricing_rules(self, wac_price: float) -> Dict[str, Any]:
        """Return either a WAC-discount rule (p=0.78) or a fixed-unit-price rule.

        Unused fields are set to '' so the two rule kinds are mutually
        exclusive.
        """
        p_wac = 0.78
        if random.random() < p_wac:
            return {'rule_point':'WAC','discount':round(random.uniform(2.0, 20.0), 1),'rule_unit_price':''}
        return {'rule_point':'','discount':'','rule_unit_price':round(wac_price * random.uniform(0.66, 0.85), 4)}

    def _calculate_estimated_reimbursement(self, wac_price: float, mfp_price: float, quantity: int, rules: Dict[str, Any]) -> float:
        """Estimate the reimbursement for one claim, rounded to 2 decimals.

        The result is never below ``mfp_price * quantity``.
        """
        base_amount = wac_price * quantity
        if rules['rule_point']=='WAC' and rules['discount']!='':
            estimated = base_amount * (1 - float(rules['discount'])/100) * random.uniform(0.82, 0.9)
        elif rules['rule_unit_price']!='':
            estimated = float(rules['rule_unit_price']) * quantity * random.uniform(0.995, 1.01)
        else:
            estimated = base_amount * random.uniform(0.72, 0.8)
        return round(max(estimated, mfp_price * quantity), 2)

    # ---- flags for strict rules
    def _rule_flags(self, rec: Dict[str, Any]) -> Dict[str, bool]:
        """Derive the features used by ``_decide_error_category`` from ``rec``.

        Besides booleans the result also carries ``month``, ``quarter`` and
        two ratio features (``est_vs_wac_amt``, ``mfp_vs_wac_unit``).
        """
        def month_from_ddmmyyyy(s: str) -> int:
            # Malformed dates fall back to January. Only parsing failures are
            # caught, so interrupts are no longer swallowed.
            try: return int(s.split('/')[1])
            except (AttributeError, IndexError, ValueError): return 1
        month = month_from_ddmmyyyy(rec['SERV DATE DATE OF SERVICE'])
        quarter = (month-1)//3 + 1

        is_wac = (rec['MFP RULE PRC POINT MFP Rule Price Point'] == 'WAC')
        # An empty discount (fixed-price rule) is treated as 0.0.
        disc = float(rec['MFP RULE DISC MFP Rule Disc %']) if rec['MFP RULE DISC MFP Rule Disc %']!='' else 0.0
        has_fixed = (rec['MFP RULE UNIT PRICE MFP Rule Unit Price']!='')
        qty1 = (rec['QTY DISP QUANTITY DISPENSED']==1)

        drug_name = rec['DRUG NAME N/A']
        manuf = rec['MFG NAME MFG Name']
        th_class = self.drugs_db.get(drug_name,{}).get('therapeutic_class','General')
        npi = rec['NPI PHARMACY NPI']
        npi_prof = self.pharmacy_profiles.get(npi,'neutral')

        est_amt = float(rec['EST REIMB AMT Estimated MFG Reimbursement Amt'])
        unit_wac = float(rec['WAC UNIT PRICE WAC UNIT PRICE'])
        unit_mfp = float(rec['MFP UNIT PRICE MFP UNIT PRICE'])

        flags = dict(
            month=month, quarter=quarter,
            is_wac_rule=is_wac,
            has_fixed_unit_price=has_fixed,
            disc_is_high18=(disc>=18.0),
            disc_lt3=(disc<3.0),
            qty_is_1=qty1,
            est_lt_1200=(est_amt<1200.0),
            is_slow_mfg=(manuf in self.slow_mfg),
            is_immuno_or_derm=(th_class in {'Immunosuppressant','Dermatology'}),
            is_cardio_or_diab=(th_class in {'Cardiovascular','Diabetes'}),
            npi_sloppy=(npi_prof=='sloppy'),
            npi_careful=(npi_prof=='careful'),
            q_in_14=(quarter in {1,4}),
            # The 1e-9 terms guard against division by zero.
            est_vs_wac_amt=(est_amt/(unit_wac*max(rec['QTY DISP QUANTITY DISPENSED'],1)+1e-9)),
            mfp_vs_wac_unit=(unit_mfp/(unit_wac+1e-9))
        )
        return flags

    def _decide_error_category(self, rec: Dict[str, Any], force_class: Optional[str]=None) -> str:
        """Return the error category for ``rec``.

        ``force_class`` wins when given. Otherwise the rules are checked in
        order (first match wins) and the manual-entry category is the
        default. The duplicate-payment category is never produced here.
        """
        if force_class:
            return force_class
        f = self._rule_flags(rec)
        if f['is_wac_rule'] and f['disc_is_high18'] and f['is_immuno_or_derm'] and f['qty_is_1']:
            return 'Write off due to WAC UNIT PRICE diff from WAC Medispan'
        if f['has_fixed_unit_price'] and f['is_slow_mfg'] and f['q_in_14'] and f['qty_is_1']:
            return 'Sent to collection due to MFG timing difference'
        if f['npi_sloppy'] and f['is_cardio_or_diab'] and f['disc_lt3'] and f['est_lt_1200']:
            return 'Write off due to payment applied to the wrong claim'
        return 'Write off due to manual data entry error'

    def _select_target_variables(self, rec: Dict[str, Any], force_class: Optional[str]=None) -> Tuple[str,str,str]:
        """Return ``(rarc_code, error_category, clerk_input)`` for ``rec``."""
        err = self._decide_error_category(rec, force_class=force_class)
        rarc = self.error_to_rarc_primary[err]
        clerk = random.choice(self.clerk_input_mapping[err])
        return rarc, err, clerk

    def _tune_record_for_error(self, rec: Dict[str, Any], err: str) -> Dict[str, Any]:
        """Rewrite ``rec`` in place so its fields fit the forced category ``err``.

        The drug, pricing rule, pharmacy and (for the timing category) the
        service month are replaced as needed, then the estimated
        reimbursement is recomputed. Duplicate-payment records are returned
        untouched. Returns the same dict that was passed in.
        """
        drug = None; rules = None; force_month = None
        if err == 'Write off due to payment applied to the wrong claim':
            drug = self._pick_drug_by(classes={'Cardiovascular','Diabetes'})
            if drug['name']=='ENTRESTO CAP 15-16MG': drug['quantity'] = 30
            elif drug['name']=='FIASP F/P PEN 100U/ML': drug['quantity'] = random.choice([10,15])
            rules = {'rule_point':'WAC','discount':round(random.uniform(0.0, 2.5),1),'rule_unit_price':''}
            sloppy_npis = [n for n,p in self.pharmacy_profiles.items() if p=='sloppy']
            rec['NPI PHARMACY NPI'] = random.choice(sloppy_npis)
        elif err == 'Sent to collection due to MFG timing difference':
            drug = self._pick_drug_by(classes={'Immunosuppressant','Dermatology'}, manufacturers=self.slow_mfg, qty_options=[1])
            drug['quantity'] = 1
            rules = {'rule_point':'','discount':'','rule_unit_price': round(drug['wac_price']*random.uniform(0.66,0.84),4)}
            # Months of quarters 1 and 4 only.
            force_month = random.choice([1,2,3,10,11,12])
        elif err == 'Write off due to WAC UNIT PRICE diff from WAC Medispan':
            drug = self._pick_drug_by(classes={'Immunosuppressant','Dermatology'}, qty_options=[1])
            drug['quantity'] = 1
            rules = {'rule_point':'WAC','discount': round(random.uniform(20.0,25.0),1),'rule_unit_price':''}
        elif err == 'Write off due to duplicate payment':
            return rec
        else:
            drug = self._pick_drug_by(classes=None)
            rules = {'rule_point':'WAC','discount': round(random.uniform(6.0,12.0),1),'rule_unit_price':''}
            non_sloppy = [n for n,p in self.pharmacy_profiles.items() if p!='sloppy']
            rec['NPI PHARMACY NPI'] = random.choice(non_sloppy)

        if drug is not None:
            rec.update({
                'DRUG NAME N/A': drug['name'],
                'PROD SVC ID PRODUCT/SERVICE ID': drug['ndc_code'],
                'QTY DISP QUANTITY DISPENSED': drug['quantity'],
                'WAC UNIT PRICE WAC UNIT PRICE': drug['wac_price'],
                'MFP UNIT PRICE MFP UNIT PRICE': drug['mfp_price'],
                'MFG NAME MFG Name': drug['manufacturer'],
                'MFG CUST ID MFG Customer ID': self.mfg_customer_ids[drug['manufacturer']],
                'Medispan/FDB WAC Price': drug['wac_price'],
                'Medispan/FDB MFP Price': drug['mfp_price'],
            })
            if force_month:
                rec['SERV DATE DATE OF SERVICE'] = self._generate_date("service", force_month=force_month)

        if rules is not None:
            rec['MFP RULE PRC POINT MFP Rule Price Point'] = rules['rule_point']
            rec['MFP RULE DISC MFP Rule Disc %'] = rules['discount']
            rec['MFP RULE UNIT PRICE MFP Rule Unit Price'] = rules['rule_unit_price']

        est = self._calculate_estimated_reimbursement(
            rec['WAC UNIT PRICE WAC UNIT PRICE'], rec['MFP UNIT PRICE MFP UNIT PRICE'],
            rec['QTY DISP QUANTITY DISPENSED'],
            {'rule_point': rec['MFP RULE PRC POINT MFP Rule Price Point'],
             'discount': rec['MFP RULE DISC MFP Rule Disc %'],
             'rule_unit_price': rec['MFP RULE UNIT PRICE MFP Rule Unit Price']}
        )
        rec['EST REIMB AMT Estimated MFG Reimbursement Amt'] = est
        return rec

    def _generate_base_record(self, force_error: Optional[str]=None) -> Tuple[Dict[str, Any], float]:
        """Create the claim-side fields plus labels for one record.

        Returns ``(record, estimated_reimbursement)``. The claim id and
        prescription number come from instance counters that are incremented
        on every call. The 835-report fields are added by the callers.
        """
        drug = self._select_drug_with_variations()
        rules = self._calculate_pricing_rules(drug['wac_price'])
        est = self._calculate_estimated_reimbursement(drug['wac_price'], drug['mfp_price'], drug['quantity'], rules)
        adj_date = self._generate_timestamp()
        serv_date = self._generate_date("service")
        claim_id = self.claim_id_counter; self.claim_id_counter += 1
        rx_nbr = self.rx_counter; self.rx_counter += 1
        npi = random.choice(self.pharmacy_npis)

        rec = {
            'ADJ DATE MEDICARE TXN DATE/TIME': adj_date,
            'NPI PHARMACY NPI': npi,
            'RX NBR PRESCRIPTION NUMBER': rx_nbr,
            'SERV DATE DATE OF SERVICE': serv_date,
            'CLAIM ID MEDICARE ICN/AUTH #': claim_id,
            'PROD SVC ID PRODUCT/SERVICE ID': drug['ndc_code'],
            'DRUG NAME N/A': drug['name'],
            'QTY DISP QUANTITY DISPENSED': drug['quantity'],
            'WAC UNIT PRICE WAC UNIT PRICE': drug['wac_price'],
            'MFP UNIT PRICE MFP UNIT PRICE': drug['mfp_price'],
            'EST REIMB AMT Estimated MFG Reimbursement Amt': est,
            'MFP RULE PRC POINT MFP Rule Price Point': rules['rule_point'],
            'MFP RULE DISC MFP Rule Disc %': rules['discount'],
            'MFP RULE UNIT PRICE MFP Rule Unit Price': rules['rule_unit_price'],
            'MFG NAME MFG Name': drug['manufacturer'],
            'MFG CUST ID MFG Customer ID': self.mfg_customer_ids[drug['manufacturer']],
            'Medispan/FDB WAC Price': drug['wac_price'],
            'Medispan/FDB MFP Price': drug['mfp_price'],
            'Medispan/FDB Effective Date': '01/01/2026',
            'Medispan/FDB Termination Date': '01/01/2027',
        }

        if force_error is not None:
            rec = self._tune_record_for_error(rec, force_error)
            # Tuning recomputes the estimate; keep the returned value in sync.
            est = rec['EST REIMB AMT Estimated MFG Reimbursement Amt']

        rarc, err, clerk = self._select_target_variables(rec, force_class=force_error)
        rec.update({
            '835 report Qualifier Code / RARC codes': rarc,
            'Expected Outcomes Error category': err,
            'Expected Outcomes Clerk Input': clerk,
            ' Questions/comments': f"Review case for {rec['DRUG NAME N/A']} - {err}",
        })
        return rec, est

    def _calculate_actual_payment(self, estimated_amount: float, error_category: str) -> float:
        """Return the underpaid amount for ``estimated_amount``.

        The underpayment fraction is drawn around the midpoint of the
        category's range; the range width shrinks by ``0.6*signal_strength``.
        The result is rounded to 2 decimals and is at least 1.0.
        """
        lo, hi = self.error_underpayment_ranges[error_category]
        span = (hi - lo) * (1 - 0.6*self.signal_strength)
        mid = (lo + hi)/2
        lo2, hi2 = max(0.0, mid - span/2), mid + span/2
        return round(max(estimated_amount * (1 - random.uniform(lo2, hi2)), 1.0), 2)

    def generate_step1_record(self, force_error: Optional[str]=None) -> Dict[str, Any]:
        """Return a fully populated 'Step 1 - Train' record.

        The 835-report identifiers mirror the claim-side fields of the same
        record.
        """
        base, est = self._generate_base_record(force_error=force_error)
        actual = self._calculate_actual_payment(est, base['Expected Outcomes Error category'])
        base.update({
            '835 report Check Number': 'N/A',
            '835 report Claim Number': base['CLAIM ID MEDICARE ICN/AUTH #'],
            '835 report Pharmacy Number': base['NPI PHARMACY NPI'],
            '835 report Rx Number': base['RX NBR PRESCRIPTION NUMBER'],
            '835 report Refill Number': '00',
            '835 report Estimated refund from adjudication': 'N/A',
            '835 report Actual Payment Amount': actual,
            '835 report Date Filled/Date of Service': self._generate_date("payment"),
            '835 report Quantity': base['QTY DISP QUANTITY DISPENSED'],
            '835 report Adjudicated Procedure Code (Product/Service ID)': base['PROD SVC ID PRODUCT/SERVICE ID'],
            '835 report Adjustment Code / CARC codes': self.error_to_carc[base['Expected Outcomes Error category']],
            '835 report Adjustment Amount': 'N/A',
            '835 report Adjustment Quantity': 'N/A',
            'Case': 'Step 1 - Train'
        })
        return base

    def generate_step2_record(self) -> Dict[str, Any]:
        """Return an unforced Step-1-style record tagged 'Step 2 - Suggest'."""
        rec = self.generate_step1_record()
        rec['Case'] = 'Step 2 - Suggest'
        return rec

    def generate_step3_record(self) -> Dict[str, Any]:
        """Return a 'Step 3 - Forecast' record with every 835-report field
        listed below left as an empty string."""
        base, _ = self._generate_base_record()
        base.update({
            '835 report Check Number': '',
            '835 report Claim Number': '',
            '835 report Pharmacy Number': '',
            '835 report Rx Number': '',
            '835 report Refill Number': '',
            '835 report Estimated refund from adjudication': '',
            '835 report Actual Payment Amount': '',
            '835 report Date Filled/Date of Service': '',
            '835 report Quantity': '',
            '835 report Adjudicated Procedure Code (Product/Service ID)': '',
            '835 report Adjustment Code / CARC codes': '',
            '835 report Adjustment Amount': '',
            '835 report Adjustment Quantity': '',
            'Case': 'Step 3 - Forecast'
        })
        return base

    def _make_duplicates(self, records: list, rate: float = 0.06) -> None:
        """Append ``int(len(records) * rate)`` duplicate-payment rows in place.

        Each duplicate copies a randomly chosen non-forecast record, gets
        fresh adjudication/service dates and is relabelled as a duplicate
        payment (category, clerk input, RARC, CARC and the review comment).
        Nothing is appended when there is no non-forecast record to copy.
        """
        n = int(len(records) * rate)
        # Snapshot taken before appending, so duplicates are never re-copied.
        source_pool = [r for r in records if r['Case']!='Step 3 - Forecast']
        if n <= 0 or not source_pool:
            # random.choice([]) would raise IndexError on a forecast-only list.
            return
        dup_label = 'Write off due to duplicate payment'
        for _ in range(n):
            base = random.choice(source_pool)
            dup = base.copy()
            dup['ADJ DATE MEDICARE TXN DATE/TIME'] = self._generate_timestamp()
            dup['SERV DATE DATE OF SERVICE'] = self._generate_date("service")
            dup['Expected Outcomes Error category'] = dup_label
            dup['Expected Outcomes Clerk Input'] = random.choice(
                self.clerk_input_mapping[dup_label]
            )
            dup['835 report Qualifier Code / RARC codes'] = 'N911'
            dup['835 report Adjustment Code / CARC codes'] = self.error_to_carc[dup_label]
            # Keep the free-text comment consistent with the new label; the
            # copied text still named the source record's category.
            dup[' Questions/comments'] = f"Review case for {dup['DRUG NAME N/A']} - {dup_label}"
            records.append(dup)

    def generate_etalon_based_dataset(self, n_step1: int = 350, n_step2: int = 350, n_step3: int = 350) -> pd.DataFrame:
        """Generate the three record groups plus duplicates as one DataFrame.

        Step 1 records are forced, category by category, until each category
        reaches 20% of ``n_step1`` (rounded); any remainder is generated
        unforced. Duplicates are appended at ``self.dup_rate`` and the rows
        are shuffled before the frame is built with ``column_headers`` order.
        """
        all_records = []
        step1_quota = {  # balanced training split
            'Write off due to manual data entry error': 0.20,
            'Write off due to payment applied to the wrong claim': 0.20,
            'Write off due to WAC UNIT PRICE diff from WAC Medispan': 0.20,
            'Write off due to duplicate payment': 0.20,
            'Sent to collection due to MFG timing difference': 0.20
        }
        min_counts = {k: int(round(n_step1 * v)) for k,v in step1_quota.items()}
        counts = {k: 0 for k in min_counts}
        for _ in range(n_step1):
            # First category still below its quota, or None once all are met.
            force = next((c for c in min_counts if counts[c] < min_counts[c]), None)
            rec = self.generate_step1_record(force_error=force)
            counts[rec['Expected Outcomes Error category']] += 1
            all_records.append(rec)
        for _ in range(n_step2):
            all_records.append(self.generate_step2_record())
        for _ in range(n_step3):
            all_records.append(self.generate_step3_record())
        self._make_duplicates(all_records, rate=self.dup_rate)
        random.shuffle(all_records)
        df = pd.DataFrame(all_records).reindex(columns=self.column_headers)
        return df

    def quality_check(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Summarise structural checks on a generated frame.

        Returns a dict with ``case_distribution``, optional
        ``payment_validation`` and ``step3_validation``,
        ``relationship_validation``, ``diversity_stats``, ``pricing_stats``
        and a 0-100 ``quality_score``. An empty frame yields 0.0 percentages
        instead of raising ``ZeroDivisionError``.
        """
        quality_stats = {}
        case_dist = df['Case'].value_counts()
        quality_stats['case_distribution'] = case_dist.to_dict()

        filled = df[df['Case']!='Step 3 - Forecast']
        if len(filled)>0:
            est = pd.to_numeric(filled['EST REIMB AMT Estimated MFG Reimbursement Amt'], errors='coerce')
            act = pd.to_numeric(filled['835 report Actual Payment Amount'], errors='coerce')
            valid = (~est.isna()) & (~act.isna())
            if valid.sum()>0:
                under = ((est[valid]-act[valid])/est[valid]*100)
                quality_stats['payment_validation'] = {
                    'all_underpaid': bool((act[valid] < est[valid]).all()),
                    'avg_underpayment_pct': round(float(under.mean()),2),
                    'underpayment_range': [round(float(under.min()),2), round(float(under.max()),2)]
                }

        step3 = df[df['Case']=='Step 3 - Forecast']
        if len(step3)>0:
            u_ag = ['835 report Check Number','835 report Claim Number','835 report Pharmacy Number','835 report Rx Number',
                    '835 report Refill Number','835 report Estimated refund from adjudication','835 report Actual Payment Amount',
                    '835 report Date Filled/Date of Service','835 report Quantity','835 report Adjudicated Procedure Code (Product/Service ID)',
                    '835 report Adjustment Code / CARC codes','835 report Adjustment Amount','835 report Adjustment Quantity']
            empty_ok = {f:int((step3[f]=='').sum()) for f in u_ag}
            quality_stats['step3_validation'] = {'empty_u_ag_fields': empty_ok}

        # Percentage of rows where the source column equals its mirror column.
        rel = {}
        for s,t,desc in [
            ('WAC UNIT PRICE WAC UNIT PRICE','Medispan/FDB WAC Price','I → Q'),
            ('CLAIM ID MEDICARE ICN/AUTH #','835 report Claim Number','E → V'),
            ('NPI PHARMACY NPI','835 report Pharmacy Number','B → W'),
            ('RX NBR PRESCRIPTION NUMBER','835 report Rx Number','C → X'),
            ('QTY DISP QUANTITY DISPENSED','835 report Quantity','H → AC')
        ]:
            # 835-report mirrors are blank on forecast rows, so skip those rows.
            check_df = df if '835 report' not in t else df[df['Case']!='Step 3 - Forecast']
            if len(check_df)>0:
                rel[desc] = round(100*(check_df[s]==check_df[t]).sum()/len(check_df),1)
        quality_stats['relationship_validation'] = rel

        diversity_stats = {
            'unique_drugs': int(df['DRUG NAME N/A'].nunique()),
            'unique_manufacturers': int(df['MFG NAME MFG Name'].nunique()),
            'unique_pharmacies': int(df['NPI PHARMACY NPI'].nunique()),
            'unique_rarc_codes': int(df['835 report Qualifier Code / RARC codes'].nunique()),
            'unique_error_categories': int(df['Expected Outcomes Error category'].nunique())
        }
        quality_stats['diversity_stats'] = diversity_stats

        wac_rules = int((df['MFP RULE PRC POINT MFP Rule Price Point']=='WAC').sum())
        fixed_rules = int((df['MFP RULE UNIT PRICE MFP Rule Unit Price']!='').sum())
        total = len(df)

        def pct(count: int) -> float:
            # Share of all rows in percent; 0.0 for an empty frame.
            return round(100*count/total,1) if total else 0.0

        pricing_stats = {
            'wac_rules_pct': pct(wac_rules),
            'fixed_rules_pct': pct(fixed_rules),
            'mutually_exclusive': bool(len(df[(df['MFP RULE PRC POINT MFP Rule Price Point']=='WAC') & (df['MFP RULE UNIT PRICE MFP Rule Unit Price']!='')])==0)
        }
        quality_stats['pricing_stats'] = pricing_stats

        indicators = 0
        indicators += 1 if quality_stats.get('payment_validation',{}).get('all_underpaid',False) else 0
        indicators += 1 if pricing_stats['mutually_exclusive'] else 0
        indicators += 1 if diversity_stats['unique_drugs']>=6 else 0
        indicators += 1 if all(v>95 for v in rel.values()) else 0
        # NOTE(review): the fifth indicator is granted unconditionally, so the
        # score never drops below 20. Kept as-is for score compatibility.
        indicators += 1
        quality_stats['quality_score'] = round(100*indicators/5,0)
        return quality_stats

    def save_to_excel(self, df: pd.DataFrame, filename: str) -> Optional[str]:
        """Write ``df`` to one Excel sheet with auto-sized columns (max width 50).

        Returns ``filename`` on success. Any failure is printed and reported
        as ``None`` rather than raised.
        """
        try:
            with pd.ExcelWriter(filename, engine='openpyxl') as writer:
                df.to_excel(writer, sheet_name='Scenario Data Structure', index=False)
                ws = writer.sheets['Scenario Data Structure']
                for col in ws.columns:
                    max_len = max(len(str(c.value)) if c.value is not None else 0 for c in col)
                    ws.column_dimensions[col[0].column_letter].width = min(max_len+2, 50)
            return filename
        except Exception as e:
            print("Excel save error:", e)
            return None

    def save_to_csv(self, df: pd.DataFrame, filename: str) -> Optional[str]:
        """Write ``df`` to CSV without the index.

        Returns ``filename`` on success. Any failure is printed and reported
        as ``None`` rather than raised.
        """
        try:
            df.to_csv(filename, index=False)
            return filename
        except Exception as e:
            print("CSV save error:", e)
            return None


# ------------------------------------------------
# 2) Compatible wrappers (added ml_optimized)
# ------------------------------------------------
def _make_generator_from_flag(seed: Optional[int], ml_optimized: bool) -> EnhancedPharmaDataGenerator:
    """
    Build a generator whose signal/noise settings follow ``ml_optimized``.

    Truthy flag  -> signal_strength=1.0, noise_level=0.005 (strongest signal, least noise).
    Falsy flag   -> signal_strength=0.7, noise_level=0.025 (benchmark-style data).
    ``seed`` is forwarded to the generator unchanged.
    """
    # (signal_strength, noise_level) for the selected mode.
    signal, noise = (1.0, 0.005) if ml_optimized else (0.7, 0.025)
    return EnhancedPharmaDataGenerator(seed=seed, signal_strength=signal, noise_level=noise)

def generate_etalon_Pharma_dataset(seed: Optional[int] = None, ml_optimized: bool = False):
    """
    Generate the fixed 350/350/350 dataset, quality-check it and export it.

    Compatible API:
    df, stats, xlsx, csv = generate_etalon_Pharma_dataset(seed=42, ml_optimized=True)

    Returns the generated frame, the ``quality_check`` result, and the Excel
    and CSV paths that the export was attempted to (the save helpers' own
    return values are not inspected here).
    """
    generator = _make_generator_from_flag(seed, ml_optimized)
    dataset = generator.generate_etalon_based_dataset(350, 350, 350)
    quality = generator.quality_check(dataset)

    # File names are stable: they depend only on the ml_optimized flag.
    base = 'optimized' if ml_optimized else 'baseline'
    excel_path = f'Pharma_poc_enhanced_1050_{base}.xlsx'
    csv_path = f'Pharma_poc_enhanced_1050_{base}.csv'

    for save, target in ((generator.save_to_excel, excel_path),
                         (generator.save_to_csv, csv_path)):
        save(dataset, target)
    return dataset, quality, excel_path, csv_path

def generate_custom_dataset(step1_count: int = 100,
                            step2_count: int = 100,
                            step3_count: int = 100,
                            seed: Optional[int] = None,
                            ml_optimized: bool = False):
    """
    Generate a dataset with caller-chosen step sizes, quality-check and export it.

    Compatible API:
    df, stats, xlsx, csv = generate_custom_dataset(500, 250, 250, ml_optimized=True)

    Args:
        step1_count, step2_count, step3_count: Row counts forwarded to
            ``generate_etalon_based_dataset`` in that order.
        seed: Forwarded to the generator.
        ml_optimized: Selects the signal/noise preset and the file-name suffix
            ('optimized' vs 'baseline').

    Returns:
        ``(df, stats, excel_path, csv_path)``. The two paths are always the
        intended target names, as before; whether each file was actually
        written is reported on stdout.
    """
    gen = _make_generator_from_flag(seed, ml_optimized)
    df = gen.generate_etalon_based_dataset(step1_count, step2_count, step3_count)
    stats = gen.quality_check(df)
    total = step1_count + step2_count + step3_count
    base = 'optimized' if ml_optimized else 'baseline'
    excel_path = f'Pharma_poc_custom_{total}_{base}.xlsx'
    csv_path = f'Pharma_poc_custom_{total}_{base}.csv'

    # The save helpers return None on failure; keep their results so the
    # report below does not claim files that were never written.
    save_results = [
        (excel_path, gen.save_to_excel(df, excel_path)),
        (csv_path, gen.save_to_csv(df, csv_path)),
    ]
    created = [path for path, result in save_results if result is not None]
    failed = [path for path, result in save_results if result is None]

    print(f"[Custom Generation] {step1_count}+{step2_count}+{step3_count} = {total} records")
    if created:
        print("Files created:\n" + "\n".join(f" - {path}" for path in created))
    if failed:
        print("Files NOT created (see save error above):\n" + "\n".join(f" - {path}" for path in failed))
    return df, stats, excel_path, csv_path

def analyze_existing_data(filename: str) -> Optional[Dict[str, Any]]:
    """Run the generator's quality check on an existing Excel file and print a summary.

    Args:
        filename: Path handed to ``pd.read_excel``.

    Returns:
        The ``quality_check`` result, or ``None`` when the file could not be
        read (the read error is printed).
    """
    try:
        df = pd.read_excel(filename)
    except Exception as e:
        print(f"❌ Could not read file '{filename}': {e}")
        return None

    # A throwaway generator is created only to reach quality_check().
    stats = EnhancedPharmaDataGenerator(seed=0).quality_check(df)

    # The memory figure is best-effort; the line is simply omitted on failure.
    try:
        mem_kb = df.memory_usage(deep=True).sum() / 1024
    except Exception:
        mem_kb = None

    report_lines = [
        "\n[Existing File Analysis]",
        f"  • Filename: {filename}",
        f"  • Rows × Cols: {df.shape[0]} × {df.shape[1]}",
    ]
    if mem_kb is not None:
        report_lines.append(f"  • Approx size in memory: {mem_kb:.1f} KB")
    report_lines.append(f"  • Case distribution: {stats.get('case_distribution', {})}")
    report_lines.append(f"  • Quality score: {stats.get('quality_score', 'NA')}%")
    for line in report_lines:
        print(line)
    return stats


# ------------------------------------------------
# 3) ML pipeline (as before; produces compatible artifacts)
# ------------------------------------------------
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import shap
import matplotlib.pyplot as plt
import seaborn as sns

class PharmaPoCMLPipeline:      # Baseline w Random Forest
    """Rule-first Random Forest pipeline over a 'Case'-partitioned workbook.

    Rows whose 'Case' is 'Step 1 - Train' are used for fitting. Rows marked
    'Step 2 - Suggest' are scored with "Model 1" and rows marked
    'Step 3 - Forecast' with "Model 2", which uses the same candidate columns
    minus every column whose name starts with '835 report'. Each model is a
    pair of classifiers: one for the AH target
    ('835 report Qualifier Code / RARC codes') and one for the AI target
    ('Expected Outcomes Error category'). Deterministic rules override the
    model output whenever they fire.
    """

    def __init__(self, input_file):
        """Remember the workbook path and create the four unfitted forests.

        Args:
            input_file: Path later passed to ``pd.read_excel``.
        """
        self.input_file = input_file
        self.df = None
        self.train_data = None
        self.step2_data = None
        self.step3_data = None
        self.model1_ah = RandomForestClassifier(n_estimators=400, random_state=42)
        self.model1_ai = RandomForestClassifier(n_estimators=400, random_state=42, class_weight='balanced')
        self.model2_ah = RandomForestClassifier(n_estimators=300, random_state=42)
        self.model2_ai = RandomForestClassifier(n_estimators=300, random_state=42, class_weight='balanced')
        # Column name -> fitted LabelEncoder. Membership also records that the
        # column was treated as categorical at fit time.
        self.label_encoders = {}
        self.target_encoders = {}
        self.explainers = {}
        self.model1_features = []
        self.model2_features = []
        self.engineered_flags = []
        print("🚀 Pharma PoC ML Pipeline initialized (rule-first)")

    def load_and_prepare_data(self):
        """Read the workbook and split it into train / Step 2 / Step 3 frames."""
        print("📁 Loading data from:", self.input_file)
        self.df = pd.read_excel(self.input_file)
        print(f"✅ Loaded {len(self.df)} records")
        self.train_data = self.df[self.df['Case'] == 'Step 1 - Train'].copy()
        self.step2_data = self.df[self.df['Case'] == 'Step 2 - Suggest'].copy()
        self.step3_data = self.df[self.df['Case'] == 'Step 3 - Forecast'].copy()
        case_dist = self.df['Case'].value_counts()
        print("📊 Case distribution:")
        for case, count in case_dist.items():
            pct = (count/len(self.df))*100
            print(f"   {case}: {count} ({pct:.1f}%)")
        print(f"🎯 Training data: {len(self.train_data)}")
        print(f"🎯 Step 2 (Suggest): {len(self.step2_data)}")
        print(f"🎯 Step 3 (Forecast): {len(self.step3_data)}")

    def identify_significant_features(self):
        """Choose the base feature columns for both models.

        Starting from every workbook column, drops the identifier/date/target
        columns listed below and then any column that has a single distinct
        value (NaN counted as a value) across the whole workbook. Model 2
        additionally drops every column starting with '835 report'. The
        engineered feature names are recorded for display only; they are
        computed later by ``_compute_engineered``.
        """
        print("\n🔍 Identifying significant features...")
        all_columns = list(self.df.columns)
        exclude_fields = [
            'ADJ DATE MEDICARE TXN DATE/TIME','RX NBR PRESCRIPTION NUMBER','CLAIM ID MEDICARE ICN/AUTH #',
            'Medispan/FDB Effective Date','Medispan/FDB Termination Date','835 report Check Number',
            '835 report Refill Number','835 report Estimated refund from adjudication',
            '835 report Adjustment Amount','835 report Adjustment Quantity','Case'
        ]
        copy_fields = [
            'Medispan/FDB WAC Price','835 report Claim Number','835 report Pharmacy Number',
            '835 report Rx Number','835 report Date Filled/Date of Service','835 report Quantity'
        ]
        targets = ['835 report Qualifier Code / RARC codes','Expected Outcomes Error category','Expected Outcomes Clerk Input',' Questions/comments']
        exclude_all = exclude_fields + copy_fields + targets
        m1_candidates = [c for c in all_columns if c not in exclude_all]
        m2_exclude_835 = [c for c in all_columns if c.startswith('835 report')]
        m2_candidates = [c for c in all_columns if c not in (exclude_all + m2_exclude_835)]

        def has_var(col): return self.df[col].nunique(dropna=False) > 1
        self.model1_features = [c for c in m1_candidates if has_var(c)]
        self.model2_features = [c for c in m2_candidates if has_var(c)]

        print(f"📊 Model 1 base features ({len(self.model1_features)}):")
        for i, f in enumerate(self.model1_features, 1): print(f"{i:>4}. {f}")
        print(f"\n📊 Model 2 base features ({len(self.model2_features)}):")
        for i, f in enumerate(self.model2_features, 1): print(f"{i:>4}. {f}")

        self.engineered_flags = [
            '_month','_quarter','_is_wac_rule','_has_fixed_unit_price','_disc_is_high18',
            '_disc_lt3','_qty_is_1','_est_lt_1200','_q_in_14','_is_slow_mfg',
            '_is_immuno_or_derm','_is_cardio_or_diab','_est_vs_wac_amt','_mfp_vs_wac_unit',
            '_rule_wac18_immunoDerm_qty1','_rule_fixed_slow_q14_qty1'
        ]
        print("\n➕ Engineered features to be added at transform-time:")
        print("   " + ", ".join(self.engineered_flags))

    @staticmethod
    def _month_of_service(value) -> int:
        """Return the month of a service date, or 1 when it cannot be read.

        Text is taken to be 'dd/mm/yyyy' (second '/'-separated field).
        Date-like objects (anything non-string exposing ``.month``) use that
        attribute directly, so date-typed Excel cells no longer collapse to 1.
        """
        if not isinstance(value, str) and hasattr(value, 'month'):
            try:
                return int(value.month)
            except (TypeError, ValueError):   # e.g. NaT, whose month is NaN
                return 1
        try:
            return int(str(value).split('/')[1])
        except (IndexError, ValueError):
            return 1

    def _compute_engineered(self, df: pd.DataFrame) -> pd.DataFrame:
        """Derive the engineered flag/ratio columns for ``df``.

        Returns a frame indexed like ``df`` holding the 14 basic features and
        the two composite rule flags built from them. Unparseable numerics
        fall back to 0 (amounts, discount, quantity test) or 1 (unit prices
        and the quantity used as a divisor).
        """
        m = df['SERV DATE DATE OF SERVICE'].map(self._month_of_service)
        q = ((m-1)//3 + 1).astype(int)
        is_wac = (df['MFP RULE PRC POINT MFP Rule Price Point'].astype(str) == 'WAC')
        disc = pd.to_numeric(df['MFP RULE DISC MFP Rule Disc %'], errors='coerce').fillna(0.0)
        # Blank Excel cells arrive as NaN; str(NaN) is 'nan', which the old
        # `!= ''` test counted as "has a price". Require a real value.
        fixed_raw = df['MFP RULE UNIT PRICE MFP Rule Unit Price']
        has_fixed = fixed_raw.notna() & (fixed_raw.astype(str) != '')
        qty1 = (pd.to_numeric(df['QTY DISP QUANTITY DISPENSED'], errors='coerce').fillna(0)==1)
        est_amt = pd.to_numeric(df['EST REIMB AMT Estimated MFG Reimbursement Amt'], errors='coerce').fillna(0.0)
        wac_unit = pd.to_numeric(df['WAC UNIT PRICE WAC UNIT PRICE'], errors='coerce').fillna(1.0)
        mfp_unit = pd.to_numeric(df['MFP UNIT PRICE MFP UNIT PRICE'], errors='coerce').fillna(1.0)
        qty = pd.to_numeric(df['QTY DISP QUANTITY DISPENSED'], errors='coerce').fillna(1.0)
        immuno_derm_set = {'Immunosuppressant','Dermatology'}
        cardio_diab_set = {'Cardiovascular','Diabetes'}
        slow_mfg_set = {'JANSSEN BIOTECH','AMGEN/ IMMUNEX','BRISTOL MYERS SQUIBB'}

        def infer_class(drug: str) -> str:
            # Substring match on the drug name; first hit wins.
            drug = str(drug)
            if any(k in drug for k in ['STELARA','ENBREL','HUMIRA']): return 'Immunosuppressant'
            if any(k in drug for k in ['DUPIXENT','SKYRIZI']): return 'Dermatology'
            if 'ENTRESTO' in drug: return 'Cardiovascular'
            if any(k in drug for k in ['OZEMPIC','FIASP']): return 'Diabetes'
            return 'General'

        th_cls = df['DRUG NAME N/A'].astype(str).map(infer_class)
        is_immuno_or_derm = th_cls.isin(immuno_derm_set)
        is_cardio_or_diab = th_cls.isin(cardio_diab_set)
        is_slow_mfg = df['MFG NAME MFG Name'].isin(slow_mfg_set)
        # The tiny epsilon keeps the ratios finite when a denominator is zero.
        est_vs_wac_amt = est_amt / (wac_unit*qty + 1e-9)
        mfp_vs_wac_unit = mfp_unit / (wac_unit + 1e-9)
        engineered = pd.DataFrame({
            '_month': m.astype(int),
            '_quarter': q.astype(int),
            '_is_wac_rule': is_wac.astype(int),
            '_has_fixed_unit_price': has_fixed.astype(int),
            '_disc_is_high18': (disc>=18.0).astype(int),
            '_disc_lt3': (disc<3.0).astype(int),
            '_qty_is_1': qty1.astype(int),
            '_est_lt_1200': (est_amt<1200.0).astype(int),
            '_q_in_14': q.isin([1,4]).astype(int),
            '_is_slow_mfg': is_slow_mfg.astype(int),
            '_is_immuno_or_derm': is_immuno_or_derm.astype(int),
            '_is_cardio_or_diab': is_cardio_or_diab.astype(int),
            '_est_vs_wac_amt': est_vs_wac_amt.values,
            '_mfp_vs_wac_unit': mfp_vs_wac_unit.values
        }, index=df.index)
        engineered['_rule_wac18_immunoDerm_qty1'] = (
            engineered['_is_wac_rule'].eq(1) &
            engineered['_disc_is_high18'].eq(1) &
            engineered['_is_immuno_or_derm'].eq(1) &
            engineered['_qty_is_1'].eq(1)
        ).astype(int)
        engineered['_rule_fixed_slow_q14_qty1'] = (
            engineered['_has_fixed_unit_price'].eq(1) &
            engineered['_is_slow_mfg'].eq(1) &
            engineered['_q_in_14'].eq(1) &
            engineered['_qty_is_1'].eq(1)
        ).astype(int)
        return engineered

    @staticmethod
    def _is_categorical_column(series: pd.Series) -> bool:
        """True for object, string and category dtypes."""
        dtype = series.dtype
        return dtype == 'object' or dtype.name in ('category', 'string', 'str')

    def _fit_transform_encoders(self, X: pd.DataFrame, fit: bool) -> pd.DataFrame:
        """Turn ``X`` into an all-numeric frame.

        With ``fit=True`` each column is classified once: object/string/
        category columns get a fresh ``LabelEncoder`` (missing values encoded
        as the text '__MISSING__'); every other column is coerced to numeric
        with missing/unparseable values set to 0. With ``fit=False`` that
        stored decision is reused, so a column is encoded the same way at
        scoring time as at training time even if its dtype differs in the
        subset. Values unseen at fit time map to the encoder's first class.

        Previously the whole frame was filled with '__MISSING__' first, which
        turned numeric columns containing blanks into label-encoded text, and
        the categorical/numeric choice was re-made for every subset.
        """
        Xc = X.copy()
        for col in X.columns:
            column = X[col]
            if fit:
                categorical = self._is_categorical_column(column)
                if not categorical:
                    # Forget an encoder left over from an earlier fit.
                    self.label_encoders.pop(col, None)
            else:
                categorical = col in self.label_encoders

            if not categorical:
                Xc[col] = pd.to_numeric(column, errors='coerce').fillna(0)
                continue

            text = column.astype(object).where(column.notna(), '__MISSING__').astype(str)
            if fit:
                encoder = LabelEncoder()
                Xc[col] = encoder.fit_transform(text)
                self.label_encoders[col] = encoder
            else:
                encoder = self.label_encoders[col]
                known = set(encoder.classes_)
                # Fallback is the first (lexicographically smallest) class,
                # not the most frequent one.
                fallback = encoder.classes_[0]
                text = text.where(text.isin(known), fallback)
                Xc[col] = encoder.transform(text)
        return Xc

    def _prepare_X(self, df: pd.DataFrame, base_feats: List[str], fit=False) -> pd.DataFrame:
        """Concatenate ``base_feats`` with the engineered columns and encode them."""
        eng = self._compute_engineered(df)
        X = pd.concat([df[base_feats].copy(), eng], axis=1)
        Xenc = self._fit_transform_encoders(X, fit=fit)
        return Xenc

    def prepare_targets(self, data, fit_encoders=False):
        """Encode the AH and AI target columns as integer class ids.

        Missing targets are encoded as the text '__MISSING__'. With
        ``fit_encoders=True`` new target encoders are fitted and stored;
        otherwise the stored ones are applied.

        Returns:
            ``(y_ah, y_ai)`` arrays aligned with ``data``.
        """
        ah = data['835 report Qualifier Code / RARC codes'].fillna('__MISSING__').astype(str)
        ai = data['Expected Outcomes Error category'].fillna('__MISSING__').astype(str)
        if fit_encoders:
            self.target_encoders['AH'] = LabelEncoder()
            self.target_encoders['AI'] = LabelEncoder()
            y_ah = self.target_encoders['AH'].fit_transform(ah)
            y_ai = self.target_encoders['AI'].fit_transform(ai)
        else:
            y_ah = self.target_encoders['AH'].transform(ah)
            y_ai = self.target_encoders['AI'].transform(ai)
        return y_ah, y_ai

    def train_models(self):
        """Fit, validate, cross-validate and finally refit all four models.

        An 80/20 hold-out (stratified on the AH target) is used for the
        printed accuracies and the confusion matrices, followed by 5-fold
        cross-validation on the full training set. The models are then refit
        on all training rows, which is the state used for prediction.
        """
        print("\n🎓 Training models...")
        # Model 2's columns are a subset of Model 1's, so the encoders fitted
        # for X1 cover X2 as well.
        X1 = self._prepare_X(self.train_data, self.model1_features, fit=True)
        X2 = self._prepare_X(self.train_data, self.model2_features, fit=False)
        y_ah, y_ai = self.prepare_targets(self.train_data, fit_encoders=True)
        print(f"📊 Training shapes: X1={X1.shape}, X2={X2.shape}")
        print(f"   AH classes: {len(np.unique(y_ah))}, AI classes: {len(np.unique(y_ai))}")

        # One joint split keeps every feature matrix and both targets on the
        # same rows. Splitting y_ai separately with stratify=y_ai selected
        # different rows than X1/X2, so the AI models were trained and scored
        # against labels belonging to other records.
        (X1_tr, X1_val, X2_tr, X2_val,
         yah_tr, yah_val, yai_tr, yai_val) = train_test_split(
            X1, X2, y_ah, y_ai, test_size=0.2, random_state=42, stratify=y_ah)

        self.model1_ah.fit(X1_tr, yah_tr); self.model1_ai.fit(X1_tr, yai_tr)
        self.model2_ah.fit(X2_tr, yah_tr); self.model2_ai.fit(X2_tr, yai_tr)
        print("✅ Models trained")

        print("\n📊 Validation Performance (model-only):")
        y1_ah_pred = self.model1_ah.predict(X1_val)
        y1_ai_pred = self.model1_ai.predict(X1_val)
        y2_ah_pred = self.model2_ah.predict(X2_val)
        y2_ai_pred = self.model2_ai.predict(X2_val)
        print(f"\n🔸 Model 1 AH Accuracy: {accuracy_score(yah_val, y1_ah_pred):.3f}")
        print(f"🔸 Model 1 AI Accuracy: {accuracy_score(yai_val, y1_ai_pred):.3f}")
        print(f"🔸 Model 2 AH Accuracy: {accuracy_score(yah_val, y2_ah_pred):.3f}")
        print(f"🔸 Model 2 AI Accuracy: {accuracy_score(yai_val, y2_ai_pred):.3f}")

        self._create_confusion_matrices(yah_val, yai_val, y1_ah_pred, y1_ai_pred, y2_ah_pred, y2_ai_pred)

        print("\n📈 Cross-validation (5-fold) on full training set:")
        for name, mdl, Xd, yd in [
            ('Model 1 AH', self.model1_ah, X1, y_ah),
            ('Model 1 AI', self.model1_ai, X1, y_ai),
            ('Model 2 AH', self.model2_ah, X2, y_ah),
            ('Model 2 AI', self.model2_ai, X2, y_ai),
        ]:
            scores = cross_val_score(mdl, Xd, yd, cv=5)
            print(f"   {name}: {scores.mean():.3f} ± {scores.std():.3f}")

        # Final refit on every training row.
        self.model1_ah.fit(X1, y_ah); self.model1_ai.fit(X1, y_ai)
        self.model2_ah.fit(X2, y_ah); self.model2_ai.fit(X2, y_ai)

    def _create_confusion_matrices(self, y_ah_true, y_ai_true, y_ah_pred1, y_ai_pred1, y_ah_pred2, y_ai_pred2):
        """Save a 2x2 grid of validation confusion matrices as a PNG.

        Every matrix is built over the full set of encoded classes so its
        shape always matches the tick labels, even when a class is absent
        from the validation fold. Failures are printed, never raised.
        """
        print("\n📊 Creating confusion matrices...")
        try:
            fig, axes = plt.subplots(2, 2, figsize=(15, 12))
            ah_labels = self.target_encoders['AH'].classes_
            ai_labels = self.target_encoders['AI'].classes_
            # AI labels are long sentences; shorten them for the axes.
            ai_ticks = [l[:15]+'...' if len(l)>15 else l for l in ai_labels]
            ah_codes = np.arange(len(ah_labels))
            ai_codes = np.arange(len(ai_labels))

            panels = [
                (axes[0,0], y_ah_true, y_ah_pred1, ah_codes, ah_labels, 'Blues',   'Model 1 - AH (A–AG)', False),
                (axes[0,1], y_ai_true, y_ai_pred1, ai_codes, ai_ticks,  'Reds',    'Model 1 - AI (A–AG)', True),
                (axes[1,0], y_ah_true, y_ah_pred2, ah_codes, ah_labels, 'Greens',  'Model 2 - AH (A–T)',  False),
                (axes[1,1], y_ai_true, y_ai_pred2, ai_codes, ai_ticks,  'Purples', 'Model 2 - AI (A–T)',  True),
            ]
            for ax, y_true, y_pred, codes, ticks, cmap, title, rotate in panels:
                cm = confusion_matrix(y_true, y_pred, labels=codes)
                sns.heatmap(cm, annot=True, fmt='d', cmap=cmap,
                            xticklabels=ticks, yticklabels=ticks, ax=ax)
                ax.set_title(title, fontweight='bold')
                if rotate:
                    ax.tick_params(axis='x', rotation=45)

            plt.tight_layout()
            plt.savefig('confusion_matrices_validation.png', dpi=300, bbox_inches='tight', facecolor='white')
            plt.close()
            print("   ✅ Saved: confusion_matrices_validation.png")
        except Exception as e:
            print(f"   ⚠️ Confusion matrix creation warning: {e}")

    def analyze_feature_importance(self):
        """Write SHAP summary plots and a feature-importance chart to PNG files.

        Best-effort: any exception is printed as a warning so the rest of the
        pipeline can continue.
        """
        print("\n🔍 SHAP explainers (TreeExplainer for RF)...")
        try:
            X1 = self._prepare_X(self.train_data, self.model1_features, fit=False)
            X2 = self._prepare_X(self.train_data, self.model2_features, fit=False)
            self.explainers['model1_ah'] = shap.TreeExplainer(self.model1_ah)
            self.explainers['model1_ai'] = shap.TreeExplainer(self.model1_ai)
            self.explainers['model2_ah'] = shap.TreeExplainer(self.model2_ah)
            self.explainers['model2_ai'] = shap.TreeExplainer(self.model2_ai)

            def shap_summary(mdl_key, X, title, out_png):
                expl = self.explainers[mdl_key]
                sv = expl.shap_values(X)
                # Only the first class is plotted. Multi-class output is either
                # a list of per-class matrices or one
                # (samples, features, classes) array, depending on the shap
                # release.
                if isinstance(sv, list):
                    sv_to_plot = sv[0]
                elif getattr(sv, 'ndim', 2) == 3:
                    sv_to_plot = sv[:, :, 0]
                else:
                    sv_to_plot = sv
                plt.figure(figsize=(12,8))
                shap.summary_plot(sv_to_plot, X, show=False, max_display=20)
                plt.title(title, fontsize=14, fontweight='bold')
                plt.tight_layout()
                plt.savefig(out_png, dpi=300, bbox_inches='tight'); plt.close()
                print(f"   ✅ {out_png}")

            shap_summary('model1_ah', X1, 'SHAP Summary - Model 1 (AH)', 'shap_model1_ah_summary.png')
            shap_summary('model1_ai', X1, 'SHAP Summary - Model 1 (AI)', 'shap_model1_ai_summary.png')
            shap_summary('model2_ah', X2, 'SHAP Summary - Model 2 (AH)', 'shap_model2_ah_summary.png')
            shap_summary('model2_ai', X2, 'SHAP Summary - Model 2 (AI)', 'shap_model2_ai_summary.png')

            fig, axes = plt.subplots(2,2, figsize=(16,12))
            def top_imp(ax, model, feat_names, title):
                imp = pd.Series(model.feature_importances_, index=feat_names).sort_values(ascending=False).head(15)
                ax.barh(imp.index[::-1], imp.values[::-1]); ax.set_title(title, fontweight='bold'); ax.set_xlabel('Importance')
            top_imp(axes[0,0], self.model1_ah, X1.columns, 'Model 1 - AH (Top 15)')
            top_imp(axes[0,1], self.model1_ai, X1.columns, 'Model 1 - AI (Top 15)')
            top_imp(axes[1,0], self.model2_ah, X2.columns, 'Model 2 - AH (Top 15)')
            top_imp(axes[1,1], self.model2_ai, X2.columns, 'Model 2 - AI (Top 15)')
            plt.tight_layout(); plt.savefig('feature_importance_comparison.png', dpi=300, bbox_inches='tight'); plt.close()
            print("   ✅ feature_importance_comparison.png")
        except Exception as e:
            print(f"⚠️ SHAP/importance warning: {e}")

    def _rule_first_predict_ai(self, row: pd.Series) -> Optional[str]:
        """Return the AI label dictated by the hard rules for one row, else ``None``.

        Rules are checked in order: WAC rule with a discount of at least 18 on
        an immuno/derm drug with quantity 1; fixed unit price from a slow
        manufacturer in quarter 1 or 4 with quantity 1; cardio/diabetes drug
        with discount below 3 and estimate below 1200 at one of the listed
        pharmacy NPIs.
        """
        tmp = self._compute_engineered(row.to_frame().T).iloc[0]
        if tmp['_rule_wac18_immunoDerm_qty1'] == 1:
            return 'Write off due to WAC UNIT PRICE diff from WAC Medispan'
        if tmp['_rule_fixed_slow_q14_qty1'] == 1:
            return 'Sent to collection due to MFG timing difference'
        if (tmp['_is_cardio_or_diab']==1) and (tmp['_disc_lt3']==1) and (tmp['_est_lt_1200']==1):
            sloppy_npies = {1043382302,1063510097,1234567890,1567890123}
            if row['NPI PHARMACY NPI'] in sloppy_npies:
                return 'Write off due to payment applied to the wrong claim'
        return None

    def _rule_first_predict_ah(self, ai_label: str) -> Optional[str]:
        """Map an AI label to its AH code; ``None`` for labels without a mapping."""
        mapping = {
            'Write off due to manual data entry error': 'N907',
            'Write off due to payment applied to the wrong claim': 'N908',
            'Write off due to WAC UNIT PRICE diff from WAC Medispan': 'N910',
            'Write off due to duplicate payment': 'N911',
            'Sent to collection due to MFG timing difference': 'N910'
        }
        return mapping.get(ai_label)

    def _predict_and_store(self, subset, features, model_ah, model_ai, step_label):
        """Score one step, apply rule overrides and write results into ``self.df``.

        Args:
            subset: Rows of ``self.df`` to score (their index selects the rows
                to update).
            features: Base feature names for the model pair.
            model_ah, model_ai: Fitted classifiers for the two targets.
            step_label: Text used in the progress message, e.g. 'Step 2'.

        Returns:
            ``(ah_labels, ai_labels)`` object arrays aligned with ``subset``;
            both are empty when ``subset`` has no rows.
        """
        if subset is None or len(subset) == 0:
            # Nothing to score; predicting on zero rows would raise.
            print(f"✅ Updated 0 {step_label} rows")
            return np.array([], dtype=object), np.array([], dtype=object)

        X = self._prepare_X(subset, features, fit=False)
        # Object dtype guarantees that a longer rule label is never truncated
        # on assignment.
        ah_labels = self.target_encoders['AH'].inverse_transform(model_ah.predict(X)).astype(object)
        ai_labels = self.target_encoders['AI'].inverse_transform(model_ai.predict(X)).astype(object)

        for pos, idx in enumerate(subset.index):
            ai_rule = self._rule_first_predict_ai(self.df.loc[idx])
            if ai_rule is not None:
                ai_labels[pos] = ai_rule
                ah_labels[pos] = self._rule_first_predict_ah(ai_rule)

        self.df.loc[subset.index, '835 report Qualifier Code / RARC codes'] = ah_labels
        self.df.loc[subset.index, 'Expected Outcomes Error category'] = ai_labels
        self.df.loc[subset.index, ' Questions/comments'] = [
            f"ML/Rule Prediction: AH='{ah}', AI='{ai}'" for ah, ai in zip(ah_labels, ai_labels)
        ]
        print(f"✅ Updated {len(subset)} {step_label} rows")
        return ah_labels, ai_labels

    def apply_predictions(self):
        """Fill the target columns of Step 2 (Model 1) and Step 3 (Model 2) rows."""
        print("\n🔮 Applying predictions (rule-first)...")
        ah_labels_2, ai_labels_2 = self._predict_and_store(
            self.step2_data, self.model1_features, self.model1_ah, self.model1_ai, 'Step 2')
        ah_labels_3, ai_labels_3 = self._predict_and_store(
            self.step3_data, self.model2_features, self.model2_ah, self.model2_ai, 'Step 3')

        def cnt(arr):
            u,c = np.unique(arr, return_counts=True); return dict(zip(u,c))
        print("\n📊 Prediction Summary:")
        print(f"Step 2 AH: {cnt(ah_labels_2)}")
        print(f"Step 2 AI: {cnt(ai_labels_2)}")
        print(f"Step 3 AH: {cnt(ah_labels_3)}")
        print(f"Step 3 AI: {cnt(ai_labels_3)}")

    def save_results(self, output_file='Pharma_poc_ml_results.xlsx'):
        """Write ``self.df`` to ``output_file`` and list the expected artifacts.

        Columns are widened to their longest cell text plus two characters,
        capped at 100. Unlike the generator's save helpers, errors propagate.

        Returns:
            ``output_file``.
        """
        print(f"\n💾 Saving results to {output_file}...")
        with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
            self.df.to_excel(writer, sheet_name='Scenario Data Structure', index=False)
            worksheet = writer.sheets['Scenario Data Structure']
            for column in worksheet.columns:
                column_letter = column[0].column_letter
                max_length = max((len(str(cell.value)) for cell in column), default=0)
                worksheet.column_dimensions[column_letter].width = min(max_length + 2, 100)
        print("✅ Results saved")
        print("\n📊 Artifacts:")
        print("   📈 shap_model1_ah_summary.png")
        print("   📈 shap_model1_ai_summary.png")
        print("   📈 shap_model2_ah_summary.png")
        print("   📈 shap_model2_ai_summary.png")
        print("   📊 feature_importance_comparison.png")
        print("   🎯 confusion_matrices_validation.png")
        return output_file

    def run_full_pipeline(self, output_file='Pharma_poc_ml_results.xlsx'):
        """Run load → feature selection → training → SHAP → prediction → save.

        Returns:
            The path returned by ``save_results``.
        """
        print("🚀 STARTING Pharma PoC ML PIPELINE")
        print("="*60)
        print("📋 Steps:")
        print("  1) Load data")
        print("  2) Identify significant features")
        print("  3) Train models")
        print("  4) SHAP & feature importance")
        print("  5) Apply predictions with rule-first overrides")
        print("  6) Save results")
        print("="*60)
        self.load_and_prepare_data()
        self.identify_significant_features()
        self.train_models()
        self.analyze_feature_importance()
        self.apply_predictions()
        result_file = self.save_results(output_file)
        print("\n🎉 PIPELINE COMPLETED")
        print(f"📁 Results saved to: {result_file}")
        return result_file


# ------------------------------------------------
# 4) compare datasets — comparison of "original vs improved"
# ------------------------------------------------
def _compute_engineered_for_compare(df: pd.DataFrame) -> pd.DataFrame:
    # repeats logic from pipeline (without class)
    def month_from_ddmmyyyy(s: str) -> int:
        try: return int(str(s).split('/')[1])
        except: return 1
    m = df['SERV DATE DATE OF SERVICE'].astype(str).map(month_from_ddmmyyyy)
    q = ((m-1)//3 + 1).astype(int)
    is_wac = (df['MFP RULE PRC POINT MFP Rule Price Point'].astype(str) == 'WAC')
    disc = pd.to_numeric(df['MFP RULE DISC MFP Rule Disc %'], errors='coerce').fillna(0.0)
    has_fixed = (df['MFP RULE UNIT PRICE MFP Rule Unit Price'].astype(str)!='')
    qty1 = (pd.to_numeric(df['QTY DISP QUANTITY DISPENSED'], errors='coerce').fillna(0)==1)
    est_amt = pd.to_numeric(df['EST REIMB AMT Estimated MFG Reimbursement Amt'], errors='coerce').fillna(0.0)
    wac_unit = pd.to_numeric(df['WAC UNIT PRICE WAC UNIT PRICE'], errors='coerce').fillna(1.0)
    mfp_unit = pd.to_numeric(df['MFP UNIT PRICE MFP UNIT PRICE'], errors='coerce').fillna(1.0)
    qty = pd.to_numeric(df['QTY DISP QUANTITY DISPENSED'], errors='coerce').fillna(1.0)
    immuno_derm_set = {'Immunosuppressant','Dermatology'}
    cardio_diab_set = {'Cardiovascular','Diabetes'}
    slow_mfg_set = {'JANSSEN BIOTECH','AMGEN/ IMMUNEX','BRISTOL MYERS SQUIBB'}
    def infer_class(drug: str) -> str:
        drug = str(drug)
        if any(k in drug for k in ['STELARA','ENBREL','HUMIRA']): return 'Immunosuppressant'
        if any(k in drug for k in ['DUPIXENT','SKYRIZI']): return 'Dermatology'
        if 'ENTRESTO' in drug: return 'Cardiovascular'
        if any(k in drug for k in ['OZEMPIC','FIASP']): return 'Diabetes'
        return 'General'
    th_cls = df['DRUG NAME N/A'].astype(str).map(infer_class)
    is_immuno_or_derm = th_cls.isin(immuno_derm_set)
    is_cardio_or_diab = th_cls.isin(cardio_diab_set)
    is_slow_mfg = df['MFG NAME MFG Name'].isin(slow_mfg_set)
    engineered = pd.DataFrame({
        '_is_wac_rule': is_wac.astype(int),
        '_disc_is_high18': (disc>=18.0).astype(int),
        '_qty_is_1': qty1.astype(int),
        '_has_fixed_unit_price': has_fixed.astype(int),
        '_q_in_14': q.isin([1,4]).astype(int),
        '_is_slow_mfg': is_slow_mfg.astype(int),
        '_is_immuno_or_derm': is_immuno_or_derm.astype(int),
        '_is_cardio_or_diab': is_cardio_or_diab.astype(int),
        '_est_lt_1200': (est_amt<1200.0).astype(int),
        '_disc_lt3': (disc<3.0).astype(int),
        '_est_vs_wac_amt': (est_amt / (wac_unit*qty + 1e-9)).values
    }, index=df.index)
    engineered['_rule_wac18_immunoDerm_qty1'] = (
        engineered['_is_wac_rule'].eq(1) &
        engineered['_disc_is_high18'].eq(1) &
        engineered['_is_immuno_or_derm'].eq(1) &
        engineered['_qty_is_1'].eq(1)
    ).astype(int)
    engineered['_rule_fixed_slow_q14_qty1'] = (
        engineered['_has_fixed_unit_price'].eq(1) &
        engineered['_is_slow_mfg'].eq(1) &
        engineered['_q_in_14'].eq(1) &
        engineered['_qty_is_1'].eq(1)
    ).astype(int)
    return engineered

def _dataset_quick_stats(df: pd.DataFrame) -> Dict[str, Any]:
    """Collect the summary metrics used to compare two datasets.

    Returns a dict with:
        rows, case_distribution, ah_distribution, ai_distribution — size and
            value counts of the 'Case' column and the two target columns;
        underpayment_avg_pct / underpayment_minmax_pct — (estimate - actual) /
            estimate * 100 over non-'Step 3 - Forecast' rows that have both
            amounts and a non-zero estimate (keys are absent when no such row
            exists);
        rule_hit_pct_* — share of rows on which each composite rule fires;
        ai_balance_cv — std/mean of the AI class counts (absent when the AI
            column has no values).
    """
    out: Dict[str, Any] = {}
    out['rows'] = len(df)
    out['case_distribution'] = df['Case'].value_counts().to_dict()
    out['ah_distribution'] = df['835 report Qualifier Code / RARC codes'].value_counts().to_dict()
    out['ai_distribution'] = df['Expected Outcomes Error category'].value_counts().to_dict()
    filled = df[df['Case']!='Step 3 - Forecast']
    if len(filled)>0:
        est = pd.to_numeric(filled['EST REIMB AMT Estimated MFG Reimbursement Amt'], errors='coerce')
        act = pd.to_numeric(filled['835 report Actual Payment Amount'], errors='coerce')
        # A zero estimate would divide by zero and push inf/NaN into the
        # mean/min/max, so such rows are left out of the percentage.
        valid = est.notna() & act.notna() & est.ne(0)
        if valid.sum()>0:
            under = ((est[valid]-act[valid])/est[valid]*100)
            out['underpayment_avg_pct'] = round(float(under.mean()),2)
            out['underpayment_minmax_pct'] = [round(float(under.min()),2), round(float(under.max()),2)]
    eng = _compute_engineered_for_compare(df)
    out['rule_hit_pct_wac18_immunoDerm_qty1'] = round(100*eng['_rule_wac18_immunoDerm_qty1'].mean(),1)
    out['rule_hit_pct_fixed_slow_q14_qty1'] = round(100*eng['_rule_fixed_slow_q14_qty1'].mean(),1)
    # Rough estimate of how balanced the AI classes are (coefficient of variation).
    if out['ai_distribution']:
        counts = np.array(list(out['ai_distribution'].values()), dtype=float)
        out['ai_balance_cv'] = round(float(counts.std()/ (counts.mean()+1e-9)),3)
    return out

def compare_datasets(original_file: str, enhanced_file: str,
                     report_file: str = 'comparison_report.xlsx') -> Dict[str, Any]:
    """
    Compare two Excel datasets that share the generator's column layout.

    Both workbooks are read with `pd.read_excel` defaults (i.e. their first sheet).

    Args:
        original_file: Path to the reference workbook.
        enhanced_file: Path to the workbook being compared against the reference.
        report_file: Where the Excel report is written (defaults to the historical
            'comparison_report.xlsx').

    Returns:
        {'summary': ..., 'original': ..., 'enhanced': ...} with the per-dataset stats from
        `_dataset_quick_stats` and their differences, or {} when either file cannot be read.
        A failure to write the report is reported on stdout but does not affect the return value.
    """
    try:
        df_orig = pd.read_excel(original_file)
        df_new = pd.read_excel(enhanced_file)
    except Exception as e:
        print(f"❌ compare_datasets: can't read files: {e}")
        return {}
    stats_orig = _dataset_quick_stats(df_orig)
    stats_new  = _dataset_quick_stats(df_new)

    # Summary of enhanced-minus-original differences ("pp" = percentage points)
    summary = {
        'rows_diff': stats_new['rows'] - stats_orig['rows'],
        'ai_classes_orig': list(stats_orig['ai_distribution'].keys()),
        'ai_classes_new': list(stats_new['ai_distribution'].keys()),
        'rule_hit_wac18_change_pp': round(stats_new['rule_hit_pct_wac18_immunoDerm_qty1'] - stats_orig['rule_hit_pct_wac18_immunoDerm_qty1'], 1),
        'rule_hit_fixed_slow_change_pp': round(stats_new['rule_hit_pct_fixed_slow_q14_qty1'] - stats_orig['rule_hit_pct_fixed_slow_q14_qty1'], 1),
        'ai_balance_cv_change': round(stats_new.get('ai_balance_cv',0) - stats_orig.get('ai_balance_cv',0), 3),
        'underpayment_avg_change_pp': round(stats_new.get('underpayment_avg_pct',0) - stats_orig.get('underpayment_avg_pct',0), 2),
    }
    result = {'summary': summary, 'original': stats_orig, 'enhanced': stats_new}

    def _flat_row(stats: Dict[str, Any]) -> Dict[str, Any]:
        # openpyxl cannot store dict/list objects in a cell, so nested values are written as text
        return {k: (str(v) if isinstance(v, (dict, list)) else v) for k, v in stats.items()}

    # Repack into an Excel report
    try:
        with pd.ExcelWriter(report_file, engine='openpyxl') as writer:
            pd.DataFrame([_flat_row(stats_orig)]).to_excel(writer, sheet_name='original', index=False)
            pd.DataFrame([_flat_row(stats_new)]).to_excel(writer, sheet_name='enhanced', index=False)
            # Class frequencies side by side; a class missing from one file counts as 0
            ai_df = pd.DataFrame({'original': stats_orig['ai_distribution']}).join(
                    pd.DataFrame({'enhanced': stats_new['ai_distribution']}), how='outer').fillna(0).astype(int)
            ai_df.to_excel(writer, sheet_name='ai_distribution', index=True)
    except Exception as e:
        print(f"⚠️ failed to write {report_file}:", e)
    else:
        # Only claim success when the report was actually written
        print(f"✅ compare_datasets: done. See '{report_file}'")
    return result


# ------------------------------------------------
# 5) Generation Strategies Benchmark + Fast ML Assessment
# ------------------------------------------------
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

def _quick_eval_rf(df: pd.DataFrame, random_state: int = 42) -> Dict[str,float]:
    """
    Quick metric: take the 'Step 1 - Train' rows, split them 80/20, build features the same
    way the pipeline does, fit one RandomForest per target (AH, AI) and return the
    hold-out accuracies.

    Args:
        df: Generated dataset with a 'Case' column.
        random_state: Seed for the split and for both forests.

    Returns:
        {'ah_acc': float, 'ai_acc': float}, rounded to 3 decimals; both are NaN when fewer
        than 100 training rows are available.

    NOTE(review): this relies on the pipeline exposing `_prepare_X`, `prepare_targets` and
    `model1_features`. The `PharmaPoCMLPipeline` defined later in this notebook names its
    helpers `_prepare_feature_matrix` / `_prepare_targets` / `model1_base_features` - confirm
    which class definition is in scope when this function runs.
    """
    train_df = df[df['Case']=='Step 1 - Train'].copy()
    if len(train_df) < 100:
        return {'ah_acc': np.nan, 'ai_acc': np.nan}
    pipe = PharmaPoCMLPipeline(input_file=None)
    pipe.df = train_df
    pipe.train_data = train_df
    pipe.step2_data = pd.DataFrame(columns=train_df.columns)
    pipe.step3_data = pd.DataFrame(columns=train_df.columns)
    pipe.identify_significant_features()

    X = pipe._prepare_X(train_df, pipe.model1_features, fit=True)
    y_ah, y_ai = pipe.prepare_targets(train_df, fit_encoders=True)

    # Split X and BOTH targets in one call so every array is cut along the same rows.
    # Two separate calls with different `stratify` arrays produce different partitions,
    # which would pair AI labels with the feature rows of other samples.
    X_tr, X_va, yah_tr, yah_va, yai_tr, yai_va = train_test_split(
        X, y_ah, y_ai, test_size=0.2, random_state=random_state, stratify=y_ah
    )
    mdl_ah = RandomForestClassifier(n_estimators=300, random_state=random_state).fit(X_tr, yah_tr)
    mdl_ai = RandomForestClassifier(n_estimators=300, random_state=random_state, class_weight='balanced').fit(X_tr, yai_tr)
    ah_acc = accuracy_score(yah_va, mdl_ah.predict(X_va))
    ai_acc = accuracy_score(yai_va, mdl_ai.predict(X_va))
    return {'ah_acc': round(float(ah_acc),3), 'ai_acc': round(float(ai_acc),3)}

def generate_ml_benchmark_datasets(seed: int = 42) -> pd.DataFrame:
    """
    Build three benchmark datasets (baseline / medium / optimized), persist each one as
    XLSX and CSV, and score each with the quick RandomForest evaluation.

    Args:
        seed: Seed handed to every generator and to the quick evaluation.

    Returns:
        One row per preset with its settings, quality score, hold-out accuracies and file
        names. The same table is also written to 'ml_benchmark_results.csv' (best effort).
    """
    # preset name -> (signal_strength, noise_level)
    presets = {
        'baseline': (0.6, 0.03),
        'medium': (0.85, 0.015),
        'optimized': (1.00, 0.005),
    }

    records = []
    for label, (signal, noise_lvl) in presets.items():
        generator = EnhancedPharmaDataGenerator(seed=seed, signal_strength=signal, noise_level=noise_lvl)
        data = generator.generate_etalon_based_dataset(350, 350, 350)
        quality = generator.quality_check(data)

        xlsx_name = f'Pharma_poc_benchmark_{label}.xlsx'
        csv_name = f'Pharma_poc_benchmark_{label}.csv'
        generator.save_to_excel(data, xlsx_name)
        generator.save_to_csv(data, csv_name)

        scores = _quick_eval_rf(data, random_state=seed)
        record = {
            'name': label,
            'signal_strength': signal,
            'noise_level': noise_lvl,
            'rows': len(data),
            'quality_score': quality.get('quality_score'),
            # informational only: filled from the generator's 'wac_rules_pct' pricing stat
            'rule_wac18_immunoDerm_qty1_%': quality.get('pricing_stats',{}).get('wac_rules_pct', np.nan),
            'ah_acc_val': scores['ah_acc'],
            'ai_acc_val': scores['ai_acc'],
            'xlsx': xlsx_name,
            'csv': csv_name,
        }
        records.append(record)
        print(f"[Benchmark] {label}: AH={scores['ah_acc']}, AI={scores['ai_acc']} -> {xlsx_name}")

    results = pd.DataFrame(records)
    try:
        results.to_csv('ml_benchmark_results.csv', index=False)
    except Exception as e:
        print("⚠️ failed to write ml_benchmark_results.csv:", e)
    return results


# ------------------------------------------------
# 6) Launch utilities (same as before)
# ------------------------------------------------
def run_Pharma_ml_analysis(input_file='Pharma_poc_custom_1000.xlsx', output_file='Pharma_poc_ml_results.xlsx'):
    """
    Run the complete ML pipeline on one workbook.

    Args:
        input_file: Path handed to `PharmaPoCMLPipeline`.
        output_file: Path handed to `run_full_pipeline`.

    Returns:
        Whatever `PharmaPoCMLPipeline.run_full_pipeline` returns.
    """
    return PharmaPoCMLPipeline(input_file).run_full_pipeline(output_file)

def create_custom_waterfall_plots(input_file, row_indices, output_dir='custom_waterfalls'):
    """
    Placeholder kept for API compatibility: creates `output_dir` and prints a notice.

    `input_file` and `row_indices` are accepted but not used; no plots are produced.
    """
    notice = "ℹ️ Per-row waterfalls are omitted in this slim version (global SHAP saved)."
    os.makedirs(output_dir, exist_ok=True)
    print(notice)


####  Run synthetic data generation

In [2]:
# Generate an ML-optimized dataset
df, stats, xlsx, csv = generate_etalon_Pharma_dataset(seed=42, ml_optimized=True)

# Custom size with optimization (note: rebinds df/stats/xlsx/csv from the call above)
df, stats, xlsx, csv = generate_custom_dataset(500, 250, 250, ml_optimized=True)

# Benchmark for different strategies
benchmark_results = generate_ml_benchmark_datasets(seed=42)

[Custom Generation] 500+250+250 = 1000 records
Files created:
 - Pharma_poc_custom_1000_optimized.xlsx
 - Pharma_poc_custom_1000_optimized.csv
🚀 Pharma PoC ML Pipeline initialized (rule-first)

🔍 Identifying significant features...
📊 Model 1 base features (17):
   1. NPI PHARMACY NPI
   2. SERV DATE DATE OF SERVICE
   3. PROD SVC ID PRODUCT/SERVICE ID
   4. DRUG NAME N/A
   5. QTY DISP QUANTITY DISPENSED
   6. WAC UNIT PRICE WAC UNIT PRICE
   7. MFP UNIT PRICE MFP UNIT PRICE
   8. EST REIMB AMT Estimated MFG Reimbursement Amt
   9. MFP RULE PRC POINT MFP Rule Price Point
  10. MFP RULE DISC MFP Rule Disc %
  11. MFP RULE UNIT PRICE MFP Rule Unit Price
  12. MFG NAME MFG Name
  13. MFG CUST ID MFG Customer ID
  14. Medispan/FDB MFP Price
  15. 835 report Actual Payment Amount
  16. 835 report Adjudicated Procedure Code (Product/Service ID)
  17. 835 report Adjustment Code / CARC codes

📊 Model 2 base features (14):
   1. NPI PHARMACY NPI
   2. SERV DATE DATE OF SERVICE
   3. PROD SVC ID

## Enhanced ML Modelling

In [3]:
%%time

# Pharma PoC ML Pipeline (Upgraded: rule-first + engineered features + HistGBDT + robust SHAP)
# - Backward-compatible artifacts and structure
# - Big boost for AI via rule head + tighter feature engineering
# - Keeps the same public API (run_Pharma_ml_analysis, create_custom_waterfall_plots)

import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings('ignore')

from datetime import datetime
from typing import Dict, List, Tuple, Optional

# Models & metrics
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import (
    classification_report, confusion_matrix, accuracy_score
)
from sklearn.ensemble import HistGradientBoostingClassifier

# Visualization
import shap
import matplotlib.pyplot as plt
import seaborn as sns

class PharmaPoCMLPipeline:
    def __init__(self, input_file: str):
        """
        Set up the upgraded Pharma PoC ML pipeline.

        Two targets are predicted while keeping the historical outputs and artifacts:
          - AH: '835 report Qualifier Code / RARC codes'
          - AI: 'Expected Outcomes Error category'

        Nothing is loaded here; the constructor only stores the path, creates four
        untrained classifiers and fills the static lookup tables.

        Args:
            input_file (str): Path to input Excel file (e.g., 'Pharma_poc_custom_1000_scored.xlsx')
        """
        self.input_file = input_file

        # Data holders, filled by load_and_prepare_data()
        self.df: Optional[pd.DataFrame] = None
        self.train_data: Optional[pd.DataFrame] = None
        self.step2_data: Optional[pd.DataFrame] = None
        self.step3_data: Optional[pd.DataFrame] = None

        # --- Models ---
        # One HistGradientBoostingClassifier per (model set, target) pair, all sharing
        # the same hyper-parameters. class_weight is deliberately not passed so the
        # code stays compatible with different sklearn versions.
        hgb_params = {
            'max_depth': 6,
            'learning_rate': 0.10,
            'max_bins': 255,
            'l2_regularization': 0.01,
            'early_stopping': False,
            'random_state': 42,
        }
        for model_attr in ('model1_ah', 'model1_ai', 'model2_ah', 'model2_ai'):
            setattr(self, model_attr, HistGradientBoostingClassifier(**hgb_params))

        # Encoders for feature columns and for the two targets
        self.label_encoders: Dict[str, LabelEncoder] = {}
        self.target_encoders: Dict[str, LabelEncoder] = {}

        # SHAP explainers (created lazily)
        self.explainers: Dict[str, shap.Explainer] = {}

        # Feature lists, filled by identify_significant_features()
        self.model1_base_features: List[str] = []      # base columns (A–AG excluding constants/copies/targets)
        self.model2_base_features: List[str] = []      # base columns (A–T-only scenario)
        self.engineered_feature_names: List[str] = []  # engineered features added to both models

        # Static dictionaries mirroring the generator's logic
        # (used by the engineered features and the rule head)
        self.slow_mfg = {
            'JANSSEN BIOTECH',
            'AMGEN/ IMMUNEX',
            'BRISTOL MYERS SQUIBB',
        }
        self.npi_profiles = {
            1013998921: 'careful',
            1518039858: 'careful',
            1043382302: 'sloppy',
            1063510097: 'sloppy',
            1326029232: 'neutral',
            1629341177: 'neutral',
            1578836698: 'neutral',
            1234567890: 'sloppy',
            1345678901: 'careful',
            1456789012: 'neutral',
            1567890123: 'sloppy',
            1678901234: 'neutral',
        }

        # Drug classes mirroring the generator's patterns
        self.immuno_or_derm = {
            'STELARA INJ 5MG/ML',
            'ENBREL INJ 25/0.5ML',
            'STELARA INJ 90MG/ML',
            'ENBREL SRCLK INJ 50MG/ML',
            'HUMIRA INJ 40MG/0.8ML',
            'DUPIXENT INJ 300MG/2ML',
            'SKYRIZI INJ 150MG/ML',
        }
        self.cardio_or_diab = {
            'ENTRESTO CAP 15-16MG',
            'FIASP F/P PEN 100U/ML',
            'OZEMPIC INJ 0.25MG/1.5ML',
        }

        # Primary mapping AI -> AH (from generator)
        self.mapping_primary_ah = {
            'Write off due to manual data entry error': 'N907',
            'Write off due to payment applied to the wrong claim': 'N908',
            'Write off due to WAC UNIT PRICE diff from WAC Medispan': 'N910',
            'Write off due to duplicate payment': 'N911',
            'Sent to collection due to MFG timing difference': 'N910',
        }

        print("🚀 Pharma PoC ML Pipeline initialized (upgraded)")
        print("=" * 60)

    # ===========================================================
    # Data loading & split
    # ===========================================================
    def load_and_prepare_data(self):
        print("📁 Loading data from:", self.input_file)
        try:
            self.df = pd.read_excel(self.input_file)
        except Exception as e:
            print(f"❌ Error loading data: {e}")
            raise

        print(f"✅ Loaded {len(self.df)} records")
        case_dist = self.df['Case'].value_counts()
        print("📊 Case distribution:")
        for case, count in case_dist.items():
            pct = (count / len(self.df)) * 100
            print(f"   {case}: {count} ({pct:.1f}%)")

        self.train_data = self.df[self.df['Case'] == 'Step 1 - Train'].copy()
        self.step2_data = self.df[self.df['Case'] == 'Step 2 - Suggest'].copy()
        self.step3_data = self.df[self.df['Case'] == 'Step 3 - Forecast'].copy()

        print(f"🎯 Training data: {len(self.train_data)}")
        print(f"🎯 Step 2 (Suggest): {len(self.step2_data)}")
        print(f"🎯 Step 3 (Forecast): {len(self.step3_data)}")

    def _save_beeswarm_multiclass(self, explainer_key, X_sample, model, class_names, title_prefix, out_prefix, per_class=False):
        """
        Save SHAP beeswarm plot(s) for `X_sample` using the explainer stored under `explainer_key`.

        Does nothing when no such explainer is registered. For 2-D SHAP values a single
        plot is written to '<out_prefix>.png'. For 3-D values (one slice per class) each
        row contributes the slice of its predicted class (argmax of `model.predict_proba`);
        with `per_class=True` one extra plot per entry of `class_names` is written to
        '<out_prefix>_class_<name>.png'.
        """
        if explainer_key not in self.explainers:
            return

        explanation = self.explainers[explainer_key](X_sample)  # shap.Explanation
        shap_values = explanation.values                        # (n, C, F) or (n, F)
        columns = list(X_sample.columns)

        def _render(values_2d, base_vals_1d, subtitle, out_path):
            # Wrap one 2-D slice in its own Explanation, draw it and write it to disk
            single = shap.Explanation(
                values=values_2d,
                base_values=base_vals_1d,
                data=explanation.data,
                feature_names=columns
            )
            shap.plots.beeswarm(single, max_display=15, show=False)
            plt.title(f"{title_prefix}{subtitle}", fontsize=14, fontweight='bold')
            plt.tight_layout()
            plt.savefig(out_path, dpi=300, bbox_inches='tight', facecolor='white')
            plt.close()
            print(f"   ✅ {out_path} saved")

        main_path = f"{out_prefix}.png"

        # Values already 2-D: plot them directly
        if shap_values.ndim == 2:
            if np.ndim(explanation.base_values) == 1:
                flat_base = explanation.base_values
            else:
                flat_base = np.array(explanation.base_values).ravel()
            _render(shap_values, flat_base, "", main_path)
            return

        if shap_values.ndim != 3:
            return

        # One slice per class: keep, for every row, the slice of its predicted class
        predicted = np.argmax(model.predict_proba(X_sample), axis=1)  # (n,)
        row_ids = range(shap_values.shape[0])
        chosen_values = np.stack([shap_values[r, predicted[r], :] for r in row_ids], axis=0)
        chosen_base = np.array([explanation.base_values[r, predicted[r]] for r in row_ids])
        _render(chosen_values, chosen_base, " (predicted class)", main_path)

        if not per_class:
            return

        for class_pos, cname in enumerate(class_names):
            class_values = shap_values[:, class_pos, :]                    # (n, F)
            class_base = np.array(explanation.base_values)[:, class_pos]   # (n,)
            safe_name = str(cname).replace('/', '_').replace(' ', '_')
            _render(class_values, class_base, f" (class={cname})", f"{out_prefix}_class_{safe_name}.png")



    def _build_row_explanation(self, explainer_key: str, X_single: pd.DataFrame, pred_class: int) -> shap.Explanation:
        """
        Build a single-row `shap.Explanation` for the first row of `X_single`.

        The explainer stored under `explainer_key` is evaluated on `X_single`; when its
        output carries a class axis, only the `pred_class` slice (values and base value)
        is kept.
        """
        raw = self.explainers[explainer_key](X_single)   # shap.Explanation
        shap_values = raw.values                         # (1, C, F) or (1, F)
        base_values = raw.base_values                    # (1, C) or (1,) / scalar
        base_rank = np.ndim(base_values)

        if shap_values.ndim == 3:
            # (n=1, n_classes, n_features): pick the predicted class
            row_values = shap_values[0, pred_class, :]
            if base_rank == 2:
                row_base = base_values[0, pred_class]
            elif base_rank == 1:
                row_base = base_values[pred_class]
            else:
                row_base = float(base_values)
        elif shap_values.ndim == 2:
            # (1, n_features)
            row_values = shap_values[0, :]
            if base_rank:
                row_base = base_values[0]
            else:
                row_base = float(base_values)
        else:
            # (n_features,): already a single row
            row_values = shap_values
            row_base = base_values

        return shap.Explanation(
            values=row_values,
            base_values=row_base,
            data=X_single.values[0],
            feature_names=X_single.columns.tolist()
        )
    
    def _save_shap_row_waterfall(self, X_single: pd.DataFrame, explainer_key: str, pred_class: int,
                                 save_path: str, title: str):
        """
        Save a SHAP waterfall plot for one row, falling back to a horizontal bar chart.

        Args:
            X_single: Single-row feature frame to explain.
            explainer_key: Key of the explainer in `self.explainers`.
            pred_class: Class index whose SHAP slice is plotted.
            save_path: Output image path.
            title: Figure title (the fallback chart appends " (top SHAP)").

        Returns:
            True when either plot was written to `save_path`, False otherwise.
            Failures are reported on stdout instead of being raised.
        """
        # The explanation is shared by both plot variants, so it is computed once;
        # if it cannot be built, neither plot is possible.
        try:
            ex = self._build_row_explanation(explainer_key, X_single, pred_class)
        except Exception as e:
            print(f"   ⚠️ SHAP explanation failed for {save_path}: {e}")
            return False

        fig = None
        try:
            fig = plt.figure(figsize=(12, 8))
            shap.plots.waterfall(ex, max_display=15, show=False)
            plt.title(title, fontsize=14)
            plt.tight_layout()
            plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
            plt.close()
            return True
        except Exception as e:
            # Release the half-drawn figure so failed attempts do not accumulate in memory
            if fig is not None:
                plt.close(fig)
            print(f"   ⚠️ waterfall plot failed ({e}); falling back to bar chart")

        # Fallback: horizontal bars for the 15 features with the largest |SHAP|
        fig = None
        try:
            vals = np.asarray(ex.values)
            names = ex.feature_names
            # argsort is ascending, so the largest contribution ends up as the top bar
            top_idx = np.argsort(np.abs(vals))[-15:]
            fig = plt.figure(figsize=(12, 8))
            plt.barh([names[i] for i in top_idx], vals[top_idx])
            plt.title(title + " (top SHAP)", fontsize=14)
            plt.tight_layout()
            plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
            plt.close()
            return True
        except Exception as e:
            if fig is not None:
                plt.close(fig)
            print(f"   ⚠️ fallback SHAP bar chart failed for {save_path}: {e}")
            return False
    


    
    # ===========================================================
    # Feature selection & engineering
    # ===========================================================
    def identify_significant_features(self):
        """
        Build the base feature lists for both model sets and define engineered features.
        """
        print("\n🔍 Identifying significant features...")

        all_columns = list(self.df.columns)

        # Constants & copied & targets to exclude
        exclude_fields = [
            'ADJ DATE MEDICARE TXN DATE/TIME',     # A - (quasi) unique timestamp, low signal
            'RX NBR PRESCRIPTION NUMBER',          # C - auto-increment
            'CLAIM ID MEDICARE ICN/AUTH #',        # E - auto-increment
            'Medispan/FDB Effective Date',         # S - constant
            'Medispan/FDB Termination Date',       # T - constant
            '835 report Check Number',             # U - mostly constant 'N/A'
            '835 report Refill Number',            # Y - constant '00'
            '835 report Estimated refund from adjudication',  # Z - constant 'N/A'
            '835 report Adjustment Code / CARC codes',        # AE - fixed/low variance
            '835 report Adjustment Amount',        # AF - mostly 'N/A'
            '835 report Adjustment Quantity',      # AG - mostly 'N/A'
            'Case'                                 # group variable
        ]
        # Full copy duplicates
        copy_fields = [
            'Medispan/FDB WAC Price',              # Q == I mostly
            '835 report Claim Number',             # V == E
            '835 report Pharmacy Number',          # W == B
            '835 report Rx Number',                # X == C
            '835 report Date Filled/Date of Service',  # AB == D (service date-ish)
            '835 report Quantity'                  # AC == H
        ]
        # Targets
        target_fields = [
            '835 report Qualifier Code / RARC codes',  # AH
            'Expected Outcomes Error category',        # AI
            'Expected Outcomes Clerk Input',           # AJ
            ' Questions/comments'                      # AK
        ]

        # Build Model 1 (full A–AG minus constants/copies/targets)
        exclude_all = set(exclude_fields + copy_fields + target_fields)
        model1_candidates = [c for c in all_columns if c not in exclude_all]

        # Model 2 (A–T scenario: no "835 report *" features at all)
        model2_exclude_835 = [c for c in all_columns if c.startswith('835 report')]
        model2_exclude = set(exclude_fields + target_fields + model2_exclude_835 + copy_fields)
        model2_candidates = [c for c in all_columns if c not in model2_exclude]

        # Filter out non-varying columns
        self.model1_base_features = [c for c in model1_candidates if self.df[c].nunique() > 1]
        self.model2_base_features = [c for c in model2_candidates if self.df[c].nunique() > 1]

        # Keep '835 report Actual Payment Amount' only in Model 1 (Step-2 use case has it)
        # Model 2 excludes it by design to avoid leakage for Forecast
        print(f"📊 Model 1 base features ({len(self.model1_base_features)}):")
        for i, feat in enumerate(self.model1_base_features, 1):
            print(f"   {i:2d}. {feat}")

        print(f"\n📊 Model 2 base features ({len(self.model2_base_features)}):")
        for i, feat in enumerate(self.model2_base_features, 1):
            print(f"   {i:2d}. {feat}")

        # Engineered features (added to both models at transform-time)
        self.engineered_feature_names = [
            '_month','_quarter',
            '_is_wac_rule','_has_fixed_unit_price',
            '_est_vs_wac_amt','_mfp_vs_wac_unit',
            '_unit_est','_unit_wac',
            '_disc_bin','_disc_is_high18',
            '_is_slow_mfg',
            '_npi_profile_sloppy','_npi_profile_careful',
            '_is_immuno_or_derm','_is_cardio_or_diab',
            # New rule-support features
            '_qty_is_1','_disc_lt3','_est_lt_1200','_q_in_14',
            '_rule_wac18_immunoDerm_qty1','_rule_fixed_slow_q14_qty1'
        ]

        print("\n➕ Engineered features to be added at transform-time:")
        for name in self.engineered_feature_names:
            print(f"   • {name}")

    def _engineer_features_frame(self, df_subset: pd.DataFrame) -> pd.DataFrame:
        """
        Create engineered features DataFrame aligned with df_subset index.
        """
        # Base fields with robust type handling
        def _to_float(s):
            try:
                return pd.to_numeric(s, errors='coerce')
            except Exception:
                return pd.Series([np.nan]*len(df_subset), index=df_subset.index)

        wac = _to_float(df_subset.get('WAC UNIT PRICE WAC UNIT PRICE'))
        mfp = _to_float(df_subset.get('MFP UNIT PRICE MFP UNIT PRICE'))
        qty = _to_float(df_subset.get('QTY DISP QUANTITY DISPENSED'))
        est = _to_float(df_subset.get('EST REIMB AMT Estimated MFG Reimbursement Amt'))
        disc = _to_float(df_subset.get('MFP RULE DISC MFP Rule Disc %')).fillna(0.0)

        # Time features
        # SERV DATE format: 'DD/MM/YYYY' (from generator)
        serv = df_subset.get('SERV DATE DATE OF SERVICE').astype(str).fillna('')
        months = serv.str.split('/', expand=True)[1]
        months = pd.to_numeric(months, errors='coerce').fillna(1).astype(int)
        month = months.clip(1, 12)
        quarter = ((month - 1) // 3 + 1)

        # Rule flags
        is_wac = (df_subset.get('MFP RULE PRC POINT MFP Rule Price Point').astype(str) == 'WAC').astype(int)
        has_fixed = _to_float(df_subset.get('MFP RULE UNIT PRICE MFP Rule Unit Price')).fillna(0.0)
        has_fixed = (has_fixed > 0).astype(int)

        # Ratios and unit amounts
        base_amt = (wac.fillna(0.0) * qty.fillna(0.0)).replace(0, np.nan)
        est_vs_wac_amt = (est / base_amt).replace([np.inf, -np.inf], np.nan).fillna(0.0)
        mfp_vs_wac_unit = (mfp / wac.replace(0, np.nan)).replace([np.inf, -np.inf], np.nan).fillna(0.0)
        unit_est = (est / qty.replace(0, np.nan)).replace([np.inf, -np.inf], np.nan).fillna(0.0)
        unit_wac = wac.fillna(0.0)

        # Discount bins
        disc_bin = pd.cut(disc, bins=[-1, 0, 3, 10, 18, 1000], labels=[0,1,2,3,4]).astype(int)
        disc_hi18 = (disc >= 18.0).astype(int)

        # Slow MFG
        mfg = df_subset.get('MFG NAME MFG Name').astype(str).fillna('')
        is_slow = mfg.isin(self.slow_mfg).astype(int)

        # NPI profiles
        npi = df_subset.get('NPI PHARMACY NPI')
        npi = pd.to_numeric(npi, errors='coerce').fillna(-1).astype(int)
        profiles = npi.map(self.npi_profiles).fillna('neutral')
        npi_sloppy = (profiles == 'sloppy').astype(int)
        npi_careful = (profiles == 'careful').astype(int)

        # Drug class flags
        drug = df_subset.get('DRUG NAME N/A').astype(str).fillna('')
        is_immuno_derm = drug.isin(self.immuno_or_derm).astype(int)
        is_cardio_diab = drug.isin(self.cardio_or_diab).astype(int)

        # Additional rule-helper flags
        qty_is_1 = (qty == 1).astype(int)
        disc_lt3 = (disc < 3.0).astype(int)
        est_lt_1200 = (est < 1200).astype(int)
        q_in_14 = ((quarter == 1) | (quarter == 4)).astype(int)

        # Composite rule matches (direct mirrors of generator logic)
        rule_wac18_immunoDerm_qty1 = ((is_wac == 1) & (disc >= 18.0) & (is_immuno_derm == 1) & (qty_is_1 == 1)).astype(int)
        rule_fixed_slow_q14_qty1 = ((has_fixed == 1) & (is_slow == 1) & (q_in_14 == 1) & (qty_is_1 == 1)).astype(int)

        eng = pd.DataFrame({
            '_month': month.astype(int),
            '_quarter': quarter.astype(int),
            '_is_wac_rule': is_wac.astype(int),
            '_has_fixed_unit_price': has_fixed.astype(int),
            '_est_vs_wac_amt': est_vs_wac_amt.astype(float),
            '_mfp_vs_wac_unit': mfp_vs_wac_unit.astype(float),
            '_unit_est': unit_est.astype(float),
            '_unit_wac': unit_wac.astype(float),
            '_disc_bin': disc_bin.astype(int),
            '_disc_is_high18': disc_hi18.astype(int),
            '_is_slow_mfg': is_slow.astype(int),
            '_npi_profile_sloppy': npi_sloppy.astype(int),
            '_npi_profile_careful': npi_careful.astype(int),
            '_is_immuno_or_derm': is_immuno_derm.astype(int),
            '_is_cardio_or_diab': is_cardio_diab.astype(int),
            '_qty_is_1': qty_is_1.astype(int),
            '_disc_lt3': disc_lt3.astype(int),
            '_est_lt_1200': est_lt_1200.astype(int),
            '_q_in_14': q_in_14.astype(int),
            '_rule_wac18_immunoDerm_qty1': rule_wac18_immunoDerm_qty1.astype(int),
            '_rule_fixed_slow_q14_qty1': rule_fixed_slow_q14_qty1.astype(int),
        }, index=df_subset.index)

        return eng

    # ===========================================================
    # Encoding & dataset builders
    # ===========================================================
    def _encode_frame(self, X: pd.DataFrame, fit_encoders: bool) -> pd.DataFrame:
        """
        Label-encode object columns; numeric columns stay as they are.
        Unknown categories map to the most frequent (for transform phase).
        """
        X = X.copy()
        X = X.fillna('__MISSING__')

        for col in X.columns:
            if X[col].dtype == 'object' or X[col].dtype.name == 'category':
                if fit_encoders:
                    if col not in self.label_encoders:
                        self.label_encoders[col] = LabelEncoder()
                    X[col] = self.label_encoders[col].fit_transform(X[col].astype(str))
                else:
                    if col in self.label_encoders:
                        le = self.label_encoders[col]
                        x_str = X[col].astype(str)
                        known = set(le.classes_)
                        unique_vals = set(x_str.unique())
                        unseen = unique_vals - known
                        if unseen:
                            most_freq = le.classes_[0]
                            for u in unseen:
                                x_str = x_str.replace(u, most_freq)
                        X[col] = le.transform(x_str)
                    else:
                        # Temporary fit (should not happen in normal flow)
                        temp = LabelEncoder()
                        X[col] = temp.fit_transform(X[col].astype(str))
            else:
                # ensure numeric
                X[col] = pd.to_numeric(X[col], errors='coerce').fillna(0.0)
        return X

    def _prepare_feature_matrix(
        self,
        data: pd.DataFrame,
        base_features: List[str],
        fit_encoders: bool
    ) -> pd.DataFrame:
        """
        Build feature matrix = base_features + engineered_feature_names,
        then encode categoricals.
        """
        eng = self._engineer_features_frame(data)
        # Ensure we keep only engineered features we declared (order matters)
        eng = eng.reindex(columns=self.engineered_feature_names)
        X = pd.concat([data[base_features], eng], axis=1)
        X = self._encode_frame(X, fit_encoders=fit_encoders)
        return X

    def _prepare_targets(self, data: pd.DataFrame, fit_encoders: bool) -> Tuple[np.ndarray, np.ndarray]:
        """Encode AH and AI targets."""
        ah = data['835 report Qualifier Code / RARC codes'].fillna('__MISSING__').astype(str)
        ai = data['Expected Outcomes Error category'].fillna('__MISSING__').astype(str)
        if fit_encoders:
            self.target_encoders['AH'] = LabelEncoder()
            self.target_encoders['AI'] = LabelEncoder()
            y_ah = self.target_encoders['AH'].fit_transform(ah)
            y_ai = self.target_encoders['AI'].fit_transform(ai)
        else:
            y_ah = self.target_encoders['AH'].transform(ah)
            y_ai = self.target_encoders['AI'].transform(ai)
        return y_ah, y_ai

    # ===========================================================
    # Training, validation & SHAP
    # ===========================================================
    def train_models(self):
        """
        Train the four HistGBDT models, report validation metrics and refit on all data.

        Steps:
          1. Build the Model 1 / Model 2 feature matrices and both encoded targets from
             `self.train_data` (feature encoders are fitted on the Model 1 matrix).
          2. Hold out 20% of the rows, fit on the rest and print hold-out accuracies;
             confusion matrices are produced by `_create_confusion_matrices`.
          3. Run 5-fold cross-validation on the full training set (best effort).
          4. Refit all four models on the full training set.
        """
        print("\n🎓 Training ML models...")
        print("=" * 40)

        # Build training matrices
        X1_train = self._prepare_feature_matrix(self.train_data, self.model1_base_features, fit_encoders=True)
        X2_train = self._prepare_feature_matrix(self.train_data, self.model2_base_features, fit_encoders=False)
        y_ah_train, y_ai_train = self._prepare_targets(self.train_data, fit_encoders=True)

        print(f"📊 Training shapes: X1={X1_train.shape}, X2={X2_train.shape}")
        print(f"   AH classes: {len(np.unique(y_ah_train))}, AI classes: {len(np.unique(y_ai_train))}")

        # Stratified split. All four arrays go through ONE call so they are cut along the
        # same rows: splitting the AI target separately with its own `stratify` yields a
        # different partition and pairs AI labels with the features of other samples.
        (X1_tr, X1_val,
         X2_tr, X2_val,
         y_ah_tr, y_ah_val,
         y_ai_tr, y_ai_val) = train_test_split(
            X1_train, X2_train, y_ah_train, y_ai_train,
            test_size=0.2, random_state=42, stratify=y_ah_train
        )

        print("\n🔸 Training Model 1 (Full features + engineered)...")
        self.model1_ah.fit(X1_tr, y_ah_tr)
        self.model1_ai.fit(X1_tr, y_ai_tr)

        print("🔸 Training Model 2 (Limited features + engineered)...")
        self.model2_ah.fit(X2_tr, y_ah_tr)
        self.model2_ai.fit(X2_tr, y_ai_tr)

        print("✅ Models trained")

        # Validation
        print("\n📊 Validation Performance:")
        y_ah_pred1 = self.model1_ah.predict(X1_val)
        y_ai_pred1 = self.model1_ai.predict(X1_val)
        print(f"\n🔸 Model 1 AH Accuracy: {accuracy_score(y_ah_val, y_ah_pred1):.3f}")
        print(f"🔸 Model 1 AI Accuracy: {accuracy_score(y_ai_val, y_ai_pred1):.3f}")

        y_ah_pred2 = self.model2_ah.predict(X2_val)
        y_ai_pred2 = self.model2_ai.predict(X2_val)
        print(f"🔸 Model 2 AH Accuracy: {accuracy_score(y_ah_val, y_ah_pred2):.3f}")
        print(f"🔸 Model 2 AI Accuracy: {accuracy_score(y_ai_val, y_ai_pred2):.3f}")

        self._create_confusion_matrices(y_ah_val, y_ai_val, y_ah_pred1, y_ai_pred1, y_ah_pred2, y_ai_pred2)

        # Cross-validation
        print("\n📈 Cross-validation (5-fold) on full training set:")
        try:
            cv1_ah = cross_val_score(self.model1_ah, X1_train, y_ah_train, cv=5)
            cv1_ai = cross_val_score(self.model1_ai, X1_train, y_ai_train, cv=5)
            cv2_ah = cross_val_score(self.model2_ah, X2_train, y_ah_train, cv=5)
            cv2_ai = cross_val_score(self.model2_ai, X2_train, y_ai_train, cv=5)
            print(f"   Model 1 AH: {cv1_ah.mean():.3f} ± {cv1_ah.std():.3f}")
            print(f"   Model 1 AI: {cv1_ai.mean():.3f} ± {cv1_ai.std():.3f}")
            print(f"   Model 2 AH: {cv2_ah.mean():.3f} ± {cv2_ah.std():.3f}")
            print(f"   Model 2 AI: {cv2_ai.mean():.3f} ± {cv2_ai.std():.3f}")
        except Exception as e:
            # CV is informational only; a failure must not abort training
            print(f"   ⚠️ CV failed: {e}")

        # Retrain on full data so the models used afterwards have seen every training row
        print("\n🔄 Retraining on full dataset...")
        self.model1_ah.fit(X1_train, y_ah_train)
        self.model1_ai.fit(X1_train, y_ai_train)
        self.model2_ah.fit(X2_train, y_ah_train)
        self.model2_ai.fit(X2_train, y_ai_train)

    def _create_confusion_matrices(self, y_ah_true, y_ai_true, y_ah_pred1, y_ai_pred1, y_ah_pred2, y_ai_pred2):
        """Create confusion matrices and classification reports."""
        print("\n📊 Creating confusion matrices...")
        try:
            fig, axes = plt.subplots(2, 2, figsize=(15, 12))
            ah_labels = self.target_encoders['AH'].classes_
            ai_labels = self.target_encoders['AI'].classes_

            # Model 1 AH
            cm1_ah = confusion_matrix(y_ah_true, y_ah_pred1)
            sns.heatmap(cm1_ah, annot=True, fmt='d', cmap='Blues',
                        xticklabels=ah_labels, yticklabels=ah_labels, ax=axes[0,0])
            axes[0,0].set_title('Model 1 - AH (Full+Eng)', fontweight='bold')
            axes[0,0].set_xlabel('Predicted'); axes[0,0].set_ylabel('Actual')

            # Model 1 AI
            cm1_ai = confusion_matrix(y_ai_true, y_ai_pred1)
            sns.heatmap(cm1_ai, annot=True, fmt='d', cmap='Reds',
                        xticklabels=[l[:15]+'...' if len(l)>15 else l for l in ai_labels],
                        yticklabels=[l[:15]+'...' if len(l)>15 else l for l in ai_labels],
                        ax=axes[0,1])
            axes[0,1].set_title('Model 1 - AI (Full+Eng)', fontweight='bold')
            axes[0,1].set_xlabel('Predicted'); axes[0,1].set_ylabel('Actual')
            axes[0,1].tick_params(axis='x', rotation=45)

            # Model 2 AH
            cm2_ah = confusion_matrix(y_ah_true, y_ah_pred2)
            sns.heatmap(cm2_ah, annot=True, fmt='d', cmap='Greens',
                        xticklabels=ah_labels, yticklabels=ah_labels, ax=axes[1,0])
            axes[1,0].set_title('Model 2 - AH (A–T+Eng)', fontweight='bold')
            axes[1,0].set_xlabel('Predicted'); axes[1,0].set_ylabel('Actual')

            # Model 2 AI
            cm2_ai = confusion_matrix(y_ai_true, y_ai_pred2)
            sns.heatmap(cm2_ai, annot=True, fmt='d', cmap='Purples',
                        xticklabels=[l[:15]+'...' if len(l)>15 else l for l in ai_labels],
                        yticklabels=[l[:15]+'...' if len(l)>15 else l for l in ai_labels],
                        ax=axes[1,1])
            axes[1,1].set_title('Model 2 - AI (A–T+Eng)', fontweight='bold')
            axes[1,1].set_xlabel('Predicted'); axes[1,1].set_ylabel('Actual')
            axes[1,1].tick_params(axis='x', rotation=45)

            plt.tight_layout()
            plt.savefig('confusion_matrices_validation.png', dpi=300, bbox_inches='tight', facecolor='white')
            plt.close()
            print("   ✅ Saved: confusion_matrices_validation.png")

            # Detailed reports
            print("\n📋 Classification Reports:")
            print("\n🔸 Model 1 AH")
            print(classification_report(y_ah_true, y_ah_pred1, target_names=ah_labels))
            print("\n🔸 Model 1 AI")
            print(classification_report(y_ai_true, y_ai_pred1, target_names=self.target_encoders['AI'].classes_))
            print("\n🔸 Model 2 AH")
            print(classification_report(y_ah_true, y_ah_pred2, target_names=ah_labels))
            print("\n🔸 Model 2 AI")
            print(classification_report(y_ai_true, y_ai_pred2, target_names=self.target_encoders['AI'].classes_))
        except Exception as e:
            print(f"   ⚠️ Confusion matrix creation warning: {e}")

    def analyze_feature_importance(self):
        """Build permutation-based SHAP explainers for the four models, then render the SHAP plots.

        The explainers are stored in ``self.explainers`` under the keys
        ``model1_ah``, ``model1_ai``, ``model2_ah`` and ``model2_ai``. If building
        them fails, a warning is printed and the visualization step still runs
        with whatever explainers were created before the failure.
        """
        print("\n🔍 Analyzing feature importance...")
        print("=" * 40)

        # Rebuild both feature matrices with the already-fitted encoders (no refit).
        full_matrix = self._prepare_feature_matrix(self.train_data, self.model1_base_features, fit_encoders=False)
        limited_matrix = self._prepare_feature_matrix(self.train_data, self.model2_base_features, fit_encoders=False)

        print("\n🔍 Creating SHAP explainers...")
        try:
            # A small background sample (at most 200 rows) keeps the permutation explainer fast.
            backgrounds = {
                'model1': full_matrix.sample(min(200, len(full_matrix)), random_state=42),
                'model2': limited_matrix.sample(min(200, len(limited_matrix)), random_state=42),
            }
            # The explainer keys double as the model attribute names on the pipeline.
            for key in ('model1_ah', 'model1_ai', 'model2_ah', 'model2_ai'):
                predict_fn = getattr(self, key).predict_proba
                self.explainers[key] = shap.Explainer(predict_fn, backgrounds[key[:6]], algorithm="permutation")
        except Exception as e:
            print(f"⚠️ SHAP warning: {e}")

        # Plotting is attempted regardless of the outcome above.
        self._create_shap_visualizations(full_matrix, limited_matrix)

    def _shap_values_to_2d(self, explainer, X_sample, predicted_class: int) -> Optional[np.ndarray]:
        """
        Turn various SHAP outputs (Explanation/list/ndarray) into 1D per-feature vector for a single sample.
        """
        try:
            # Using explainer(X) returns shap.Explanation for permutation explainer
            exp = explainer(X_sample)
            vals = exp.values  # shape could be (n, n_classes, n_features) or (n, n_features)
            # We expect n=1 here
            if vals.ndim == 3:
                # (1, n_classes, n_features) -> pick predicted_class
                if predicted_class < vals.shape[1]:
                    return vals[0, predicted_class, :]
                return vals[0, 0, :]
            elif vals.ndim == 2:
                # (1, n_features)
                return vals[0, :]
            elif vals.ndim == 1:
                return vals
            else:
                return None
        except Exception:
            return None

    def _create_shap_visualizations(self, X1_train, X2_train):

        print("\n📊 Generating SHAP visualizations...")
        try:
            X1_sample = X1_train.sample(min(200, len(X1_train)), random_state=42)
            X2_sample = X2_train.sample(min(200, len(X2_train)), random_state=42)
    
            # AH/AI названия классов (для файлов per_class)
            ah_classes = list(self.target_encoders['AH'].classes_) if 'AH' in self.target_encoders else []
            ai_classes = list(self.target_encoders['AI'].classes_) if 'AI' in self.target_encoders else []
    
            # По умолчанию сохраняем один файл на модель — по предсказанному классу
            # per_class=True добавит ещё отдельные файлы по каждому классу (опционально)
            self._beeswarm_for_predicted_class('model1_ah', self.model1_ah, X1_sample,
                                               'SHAP Summary - Model 1 (AH)', 'shap_model1_ah_summary.png')
            self._beeswarm_for_predicted_class('model1_ai', self.model1_ai, X1_sample,
                                               'SHAP Summary - Model 1 (AI)', 'shap_model1_ai_summary.png')
            self._beeswarm_for_predicted_class('model2_ah', self.model2_ah, X2_sample,
                                               'SHAP Summary - Model 2 (AH)', 'shap_model2_ah_summary.png')
            self._beeswarm_for_predicted_class('model2_ai', self.model2_ai, X2_sample,
                                               'SHAP Summary - Model 2 (AI)', 'shap_model2_ai_summary.png')

    

            def _mean_abs_predclass_importance(explainer_key, X_sample, model, feat_names):
                if explainer_key not in self.explainers:
                    return None
                exp = self.explainers[explainer_key](X_sample)
                vals = exp.values
                if vals.ndim == 3:
                    proba = model.predict_proba(X_sample)
                    idx = np.argmax(proba, axis=1)
                    values_2d = np.stack([vals[i, idx[i], :] for i in range(vals.shape[0])], axis=0)
                    imp = np.mean(np.abs(values_2d), axis=0)
                elif vals.ndim == 2:
                    imp = np.mean(np.abs(vals), axis=0)
                else:
                    return None
                return pd.DataFrame({'feature': feat_names, 'importance': imp}).sort_values('importance', ascending=False)
    
            f1 = _mean_abs_predclass_importance('model1_ah', X1_sample, self.model1_ah, list(X1_train.columns))
            f2 = _mean_abs_predclass_importance('model1_ai', X1_sample, self.model1_ai, list(X1_train.columns))
            f3 = _mean_abs_predclass_importance('model2_ah', X2_sample, self.model2_ah, list(X2_train.columns))
            f4 = _mean_abs_predclass_importance('model2_ai', X2_sample, self.model2_ai, list(X2_train.columns))
    
            fig, axes = plt.subplots(2, 2, figsize=(16, 12))
            for ax, df_imp, title in [
                (axes[0,0], f1, 'Model 1 - AH (Top 15)'),
                (axes[0,1], f2, 'Model 1 - AI (Top 15)'),
                (axes[1,0], f3, 'Model 2 - AH (Top 15)'),
                (axes[1,1], f4, 'Model 2 - AI (Top 15)'),
            ]:
                try:
                    top = df_imp.head(15).sort_values('importance', ascending=True) if df_imp is not None else None
                    if top is not None and len(top) > 0:
                        ax.barh(top['feature'], top['importance'])
                        ax.set_title(title, fontweight='bold')
                except Exception:
                    ax.text(0.5, 0.5, 'N/A', ha='center', va='center')
            plt.tight_layout()
            plt.savefig('feature_importance_comparison.png', dpi=300, bbox_inches='tight', facecolor='white')
            plt.close()
            print("   ✅ feature_importance_comparison.png saved")
    
        except Exception as e:
            print(f"   ⚠️ SHAP visualization warning: {e}")


    # ===========================================================
    # Rule head for AI (and implied AH)
    # ===========================================================
    def _rule_predict_ai(self, xrow: Dict[str, float]) -> Optional[str]:
        """
        Deterministic rule-based AI classification, mirroring generator logic.
        Returns AI string or None if no rule fires.
        """
        try:
            if xrow.get('_rule_wac18_immunoDerm_qty1', 0) == 1:
                return 'Write off due to WAC UNIT PRICE diff from WAC Medispan'
            if xrow.get('_rule_fixed_slow_q14_qty1', 0) == 1:
                return 'Sent to collection due to MFG timing difference'
            if xrow.get('_npi_profile_sloppy', 0) == 1 and xrow.get('_disc_lt3', 0) == 1 \
               and xrow.get('_is_cardio_or_diab', 0) == 1 and xrow.get('_est_lt_1200', 0) == 1:
                return 'Write off due to payment applied to the wrong claim'
            return None
        except Exception:
            return None

    # ===========================================================
    # Insights (best-effort SHAP) and fallbacks
    # ===========================================================
    def _test_shap_explainers(self, X_test: pd.DataFrame, model_prefix: str) -> bool:
        """Quick test if SHAP explainers exist and work."""
        try:
            print(f"   Testing SHAP for {model_prefix}...")
            X_sample = X_test.iloc[[0]]
            ah_exp = self.explainers[f'{model_prefix}_ah'](X_sample)
            ai_exp = self.explainers[f'{model_prefix}_ai'](X_sample)
            # print shapes (optional)
            v_ah = ah_exp.values
            v_ai = ai_exp.values
            print(f"   ✅ SHAP test ok: AH {np.array(v_ah).shape}, AI {np.array(v_ai).shape}")
            return True
        except Exception as e:
            print(f"   ❌ SHAP test failed for {model_prefix}: {e}")
            return False

    def _generate_fallback_insights(
        self, X: pd.DataFrame, feature_names: List[str],
        pred_ah_label: str, pred_ai_label: str, rule_note: Optional[str] = None
    ) -> str:
        """
        Lightweight, robust fallback insight string.
        """
        if rule_note:
            return f"Rule-based: {pred_ai_label} → AH={pred_ah_label} ({rule_note})"
        return f"ML Prediction: AH='{pred_ah_label}', AI='{pred_ai_label}' (no SHAP)"

    # ===========================================================
    # Inference on Step 2 / Step 3
    # ===========================================================
    def apply_predictions(self):
        """Predict AH/AI for Step 2 (Model 1) and Step 3 (Model 2) and write them into ``self.df``.

        For every row of a step the model predictions are decoded with the target
        encoders. When ``_rule_predict_ai`` fires for the row's feature values, the
        rule's AI label and its ``mapping_primary_ah`` AH label replace the model
        output. The labels and a short insight string are written to the AH, AI
        and comments columns of ``self.df``; a per-step summary and one sample
        insight per step are printed at the end.
        """
        print("\n🔮 Applying predictions...")
        print("=" * 40)

        ah_col = '835 report Qualifier Code / RARC codes'
        ai_col = 'Expected Outcomes Error category'
        note_col = ' Questions/comments'

        # The output columns receive string labels. A column that was read as numeric
        # (e.g. entirely empty -> float NaN) is widened to object first, because pandas
        # is phasing out silent upcasting on item assignment (PDEP-6).
        for col in (ah_col, ai_col, note_col):
            if col in self.df.columns and pd.api.types.is_numeric_dtype(self.df[col]):
                self.df[col] = self.df[col].astype(object)

        def _predict_step(step_name, banner, step_data, base_features, model_ah, model_ai, model_prefix):
            """Predict one step, apply rule overrides, write back; return (row indices, insights)."""
            print(banner)
            X_step = self._prepare_feature_matrix(step_data, base_features, fit_encoders=False)
            pred_ah = model_ah.predict(X_step)
            pred_ai = model_ai.predict(X_step)
            print(f"   Predictions: AH={len(pred_ah)}, AI={len(pred_ai)}")

            # Smoke test only: its result does not change the insight text, which
            # always comes from _generate_fallback_insights.
            if f'{model_prefix}_ah' in self.explainers:
                self._test_shap_explainers(X_step, model_prefix)

            # Decode all predictions once instead of one encoder call per row.
            ah_labels = self.target_encoders['AH'].inverse_transform(pred_ah)
            ai_labels = self.target_encoders['AI'].inverse_transform(pred_ai)

            feature_cols = list(X_step.columns)
            indices = list(step_data.index)
            insights = []
            for i, idx in enumerate(indices):
                if i % 50 == 0:
                    print(f"   {step_name} progress: {i+1}/{len(indices)}")

                ah_label = ah_labels[i]
                ai_label = ai_labels[i]

                # Rule override (deterministic): a matching rule wins over the model.
                xrow = dict(zip(feature_cols, X_step.iloc[i].tolist()))
                ai_rule = self._rule_predict_ai(xrow)
                if ai_rule:
                    ai_label = ai_rule
                    ah_label = self.mapping_primary_ah[ai_rule]
                    insight = self._generate_fallback_insights(X_step, feature_cols, ah_label, ai_label,
                                                               rule_note="matched generator pattern")
                else:
                    insight = self._generate_fallback_insights(X_step, feature_cols, ah_label, ai_label)

                # Write back
                self.df.at[idx, ah_col] = ah_label
                self.df.at[idx, ai_col] = ai_label
                self.df.at[idx, note_col] = insight
                insights.append(insight)

            print(f"✅ Updated {len(indices)} {step_name} rows")
            return indices, insights

        # --- Step 2 using Model 1 (full features) ---
        step2_indices, insights_step2 = _predict_step(
            "Step 2", "🔸 Model 1 → Step 2 (Suggest)", self.step2_data,
            self.model1_base_features, self.model1_ah, self.model1_ai, 'model1')

        # --- Step 3 using Model 2 (A–T features only, per the original note) ---
        step3_indices, insights_step3 = _predict_step(
            "Step 3", "🔸 Model 2 → Step 3 (Forecast)", self.step3_data,
            self.model2_base_features, self.model2_ah, self.model2_ai, 'model2')

        # Summaries: label -> count per step and target.
        print("\n📊 Prediction Summary:")
        for step_name, indices in (("Step 2", step2_indices), ("Step 3", step3_indices)):
            for tag, col in (("AH", ah_col), ("AI", ai_col)):
                counts = dict(zip(*np.unique(self.df.loc[indices, col], return_counts=True)))
                print(f"{step_name} {tag}: {counts}")

        # Sample
        print(f"\n📋 Sample insights:")
        if len(insights_step2) > 0:
            print(f"Step 2 sample: {insights_step2[0]}")
        if len(insights_step3) > 0:
            print(f"Step 3 sample: {insights_step3[0]}")

    # ===========================================================
    # Waterfalls (best-effort; fallbacks to simple plots)
    # ===========================================================
    def create_individual_waterfall_plots(self, row_indices=[0, 1], output_dir='individual_shap_plots'):
        """Save per-row SHAP plots for Step 2 (Model 1) and Step 3 (Model 2).

        For each requested positional row of a non-empty step, the AH and AI
        predictions are computed and one plot per target is delegated to
        ``_save_shap_row_waterfall``. Rows beyond the step's length are skipped
        with a warning; a failing row is reported and does not stop the others.

        Args:
            row_indices: Positional row numbers within each step's data.
            output_dir: Directory for the PNG files (created if missing).
        """
        print(f"\n📊 Creating individual SHAP plots for rows {row_indices}...")
        os.makedirs(output_dir, exist_ok=True)

        # Make sure the explainers exist before plotting.
        if not self.explainers:
            self.analyze_feature_importance()

        # (step label, file tag, data attribute, feature-list attribute, model/explainer prefix, banner)
        step_specs = (
            ('Step 2', 'step2', 'step2_data', 'model1_base_features', 'model1',
             "🔸 Waterfalls for Step 2 (Model 1)"),
            ('Step 3', 'step3', 'step3_data', 'model2_base_features', 'model2',
             "🔸 Waterfalls for Step 3 (Model 2)"),
        )
        for step_label, file_tag, data_attr, features_attr, model_prefix, banner in step_specs:
            if len(getattr(self, data_attr)) == 0:
                continue
            print(banner)
            X_step = self._prepare_feature_matrix(getattr(self, data_attr), getattr(self, features_attr),
                                                  fit_encoders=False)
            for row_idx in row_indices:
                if row_idx >= len(getattr(self, data_attr)):
                    print(f"   ⚠️ Row {row_idx} out of range for {step_label}")
                    continue
                try:
                    single_row = X_step.iloc[[row_idx]]
                    pred_ah = int(getattr(self, f'{model_prefix}_ah').predict(single_row)[0])
                    pred_ai = int(getattr(self, f'{model_prefix}_ai').predict(single_row)[0])
                    ah_label = self.target_encoders['AH'].inverse_transform([pred_ah])[0]
                    ai_label = self.target_encoders['AI'].inverse_transform([pred_ai])[0]

                    # One plot per target, AH first.
                    for target, code, label in (('AH', pred_ah, ah_label), ('AI', pred_ai, ai_label)):
                        self._save_shap_row_waterfall(
                            single_row, f'{model_prefix}_{target.lower()}', code,
                            f'{output_dir}/{file_tag}_row_{row_idx}_{target}_shap.png',
                            f'{step_label} Row {row_idx}: {target} = {label}'
                        )
                    print(f"   ✅ {step_label} Row {row_idx} done")
                except Exception as e:
                    print(f"   ❌ {step_label} Row {row_idx} failed: {e}")


    def _save_simple_row_plot(self, X_single: pd.DataFrame, feat_names: List[str], save_path: str,
                              title: str, original_row: pd.Series):
        """Simple per-row bar chart as a robust fallback for waterfall."""
        try:
            vals = X_single.iloc[0].values
            short_names = [f.split()[-1] if ' ' in f else f for f in feat_names]
            plt.figure(figsize=(12, 7))
            plt.barh(range(len(vals)), vals)
            plt.yticks(range(len(vals)), short_names)
            plt.title(title, fontweight='bold')
            drug = original_row.get('DRUG NAME N/A', 'Unknown Drug')
            mfg = original_row.get('MFG NAME MFG Name', 'Unknown Manufacturer')
            plt.figtext(0.02, 0.02, f'Drug: {drug} | MFG: {mfg}', fontsize=10, style='italic')
            plt.tight_layout()
            plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
            plt.close()
        except Exception:
            pass

    # ===========================================================
    # Save results
    # ===========================================================
    def save_results(self, output_file='Pharma_poc_ml_results.xlsx'):
        """Write ``self.df`` to an Excel workbook, then produce best-effort extras.

        The frame is written to the sheet ``'Scenario Data Structure'`` with
        auto-sized columns (capped at width 100). Afterwards the per-row plots
        are generated and sample insights and the artifact list are printed;
        none of these extras can change the return value.

        Args:
            output_file: Path of the workbook to write.

        Returns:
            ``output_file`` when the workbook was written, ``None`` when writing failed.
        """
        print(f"\n💾 Saving results to {output_file}...")
        try:
            with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
                self.df.to_excel(writer, sheet_name='Scenario Data Structure', index=False)
                ws = writer.sheets['Scenario Data Structure']
                for column in ws.columns:
                    widest = max((len(str(cell.value)) for cell in column), default=0)
                    ws.column_dimensions[column[0].column_letter].width = min(widest + 2, 100)
            print("✅ Results saved")
        except Exception as e:
            print(f"❌ Error saving results: {e}")
            return None

        # From here on the workbook is written. A failure in the extras below must
        # not be reported as a failed save (previously it returned None).
        print("\n🎨 Creating individual waterfall plots...")
        try:
            self.create_individual_waterfall_plots([0, 1], 'individual_waterfall_plots')
        except Exception as e:
            print(f"   ⚠️ Individual plots skipped: {e}")

        # Print sample insights
        print("\n📋 Sample insights from predictions:")
        print("=" * 50)
        for case, heading in (('Step 2 - Suggest', "🔸 Step 2 Sample:"),
                              ('Step 3 - Forecast', "\n🔸 Step 3 Sample:")):
            try:
                sample = self.df[self.df['Case'] == case].iloc[0]
                print(heading)
                print(f"   Drug: {sample['DRUG NAME N/A']}")
                print(f"   Predicted AH: {sample['835 report Qualifier Code / RARC codes']}")
                print(f"   Predicted AI: {sample['Expected Outcomes Error category']}")
                print(f"   Insight: {str(sample[' Questions/comments'])[:150]}...")
            except Exception:
                # No row for this case, or an expected column is missing.
                print(f"   (no '{case}' sample available)")

        print("\n📊 Artifacts:")
        print("   📈 shap_model1_ah_summary.png")
        print("   📈 shap_model1_ai_summary.png")
        print("   📈 shap_model2_ah_summary.png")
        print("   📈 shap_model2_ai_summary.png")
        print("   📊 feature_importance_comparison.png")
        print("   🎯 confusion_matrices_validation.png")
        print("   🎯 individual_waterfall_plots/ (PNG per-row)")

        return output_file

    # ===========================================================
    # Orchestration
    # ===========================================================
    def run_full_pipeline(self, output_file='Pharma_poc_ml_results.xlsx'):
        """Run every pipeline stage in order and save the results.

        Args:
            output_file: Path handed to ``save_results``.

        Returns:
            Whatever ``save_results`` returns for ``output_file``.
        """
        rule = "=" * 60
        banner = (
            "🚀 STARTING Pharma PoC ML PIPELINE",
            rule,
            "📋 Pipeline steps:",
            "   1. Load and prepare data",
            "   2. Identify significant features",
            "   3. Train models (AH, AI)",
            "   4. Analyze feature importance & SHAP",
            "   5. Apply predictions with insights (rule-first overrides)",
            "   6. Save results to Excel",
            rule,
        )
        print("\n".join(banner))

        # Stages run strictly in this order.
        for stage in ('load_and_prepare_data', 'identify_significant_features', 'train_models',
                      'analyze_feature_importance', 'apply_predictions'):
            getattr(self, stage)()
        result_file = self.save_results(output_file)

        print("\n🎉 PIPELINE COMPLETED")
        print(f"📁 Results saved to: {result_file}")
        print(rule)
        return result_file


# =============================================================================
# MAIN EXECUTION HELPERS (same API)
# =============================================================================

def run_Pharma_ml_analysis(input_file='Pharma_poc_custom_1000_scored.xlsx', output_file='Pharma_poc_ml_results.xlsx'):
    """
    Build a ``PharmaPoCMLPipeline`` for ``input_file`` and run it end to end.

    Args:
        input_file (str): Input Excel file with Pharma data
        output_file (str): Output Excel file with ML predictions and insights

    Returns:
        The value returned by ``run_full_pipeline`` for ``output_file``.
    """
    return PharmaPoCMLPipeline(input_file).run_full_pipeline(output_file)


def create_custom_waterfall_plots(input_file, row_indices, output_dir='custom_waterfalls'):
    """
    Build a fresh pipeline, run its preparation stages, and request per-row plots.

    Any exception raised along the way is printed and reported through the
    return value instead of propagating to the caller.

    Args:
        input_file (str): Passed to the ``PharmaPoCMLPipeline`` constructor
        row_indices (list): Forwarded to ``create_individual_waterfall_plots``
        output_dir (str): Forwarded to ``create_individual_waterfall_plots``

    Returns:
        bool: True when every step finished without raising, False otherwise
    """
    # Stages invoked, in this order, before the per-row plots are requested.
    preparation_stages = (
        'load_and_prepare_data',
        'identify_significant_features',
        'train_models',
        'analyze_feature_importance',
    )
    try:
        print(f"🎨 Creating custom waterfalls for rows {row_indices}...")
        ml_pipeline = PharmaPoCMLPipeline(input_file)
        # Look each stage up only when it is about to run, one at a time.
        for stage_name in preparation_stages:
            getattr(ml_pipeline, stage_name)()
        ml_pipeline.create_individual_waterfall_plots(row_indices, output_dir)
        print(f"✅ Custom waterfalls created in {output_dir}/")
    except Exception as err:
        print(f"❌ Error creating custom waterfall plots: {err}")
        return False
    return True


# =============================================================================
# USAGE EXAMPLES
# =============================================================================
if __name__ == "__main__":
    # Usage demo: runs the full pipeline once, then rebuilds a pipeline to
    # draw per-row plots for a handful of rows, and reports how both went.
    print("🎯 Pharma PoC ML Pipeline - Usage Examples")
    print("=" * 50)

    # Both examples read the same workbook; name it once so they cannot drift apart.
    example_input = 'Pharma_poc_custom_1000_optimized.xlsx'   # or 'Pharma_poc_custom_1000.xlsx'

    # 1) Full pipeline
    print("\n1. Running full ML pipeline...")
    result_file = run_Pharma_ml_analysis(
        input_file=example_input,
        output_file='Pharma_poc_custom_1000_optimized_ml_results.xlsx'
    )

    # 2) Custom plots
    print("\n2. Creating custom waterfalls...")
    waterfalls_ok = create_custom_waterfall_plots(
        input_file=example_input,
        row_indices=[0, 1, 5, 10],
        output_dir='my_custom_waterfalls'
    )

    # Both helpers report failure through their return value (None / False)
    # instead of raising, so check them before claiming success.
    if result_file is not None and waterfalls_ok:
        print("\n✅ Analysis complete!")
    else:
        print("\n⚠️ Analysis finished with errors:")
        if result_file is None:
            print("   - full pipeline did not produce a results file")
        if not waterfalls_ok:
            print("   - custom waterfall plots were not created")
    print("• Two model sets (Model1 full+eng, Model2 A–T+eng)")
    print("• HistGradientBoostingClassifier + rule-first AI overrides")
    print("• Rich engineered features matching generator logic")
    print("• SHAP (permutation) + confusion matrices + per-row plots")
    print("• Same artifacts filenames and Excel structure")


🎯 Pharma PoC ML Pipeline - Usage Examples

1. Running full ML pipeline...
🚀 Pharma PoC ML Pipeline initialized (upgraded)
🚀 STARTING Pharma PoC ML PIPELINE
📋 Pipeline steps:
   1. Load and prepare data
   2. Identify significant features
   3. Train models (AH, AI)
   4. Analyze feature importance & SHAP
   5. Apply predictions with insights (rule-first overrides)
   6. Save results to Excel
📁 Loading data from: Pharma_poc_custom_1000_optimized.xlsx
✅ Loaded 1060 records
📊 Case distribution:
   Step 1 - Train: 533 (50.3%)
   Step 2 - Suggest: 277 (26.1%)
   Step 3 - Forecast: 250 (23.6%)
🎯 Training data: 533
🎯 Step 2 (Suggest): 277
🎯 Step 3 (Forecast): 250

🔍 Identifying significant features...
📊 Model 1 base features (15):
    1. NPI PHARMACY NPI
    2. SERV DATE DATE OF SERVICE
    3. PROD SVC ID PRODUCT/SERVICE ID
    4. DRUG NAME N/A
    5. QTY DISP QUANTITY DISPENSED
    6. WAC UNIT PRICE WAC UNIT PRICE
    7. MFP UNIT PRICE MFP UNIT PRICE
    8. EST REIMB AMT Estimated MFG Reimbu

PermutationExplainer explainer: 2it [00:17, 17.43s/it]                                                                 


   ✅ SHAP test ok: AH (1, 36, 4), AI (1, 36, 5)
   Step 2 progress: 1/277
   Step 2 progress: 51/277
   Step 2 progress: 101/277
   Step 2 progress: 151/277
   Step 2 progress: 201/277
   Step 2 progress: 251/277
✅ Updated 277 Step 2 rows
🔸 Model 2 → Step 3 (Forecast)
   Predictions: AH=250, AI=250
   Testing SHAP for model2...
   ✅ SHAP test ok: AH (1, 34, 4), AI (1, 34, 5)
   Step 3 progress: 1/250
   Step 3 progress: 51/250
   Step 3 progress: 101/250
   Step 3 progress: 151/250
   Step 3 progress: 201/250
✅ Updated 250 Step 3 rows

📊 Prediction Summary:
Step 2 AH: {'N907': 119, 'N908': 21, 'N910': 43, 'N911': 94}
Step 2 AI: {'Sent to collection due to MFG timing difference': 24, 'Write off due to WAC UNIT PRICE diff from WAC Medispan': 8, 'Write off due to duplicate payment': 95, 'Write off due to manual data entry error': 128, 'Write off due to payment applied to the wrong claim': 22}
Step 3 AH: {'N907': 40, 'N908': 16, 'N910': 21, 'N911': 173}
Step 3 AI: {'Sent to collection due 

In [4]:
# Scratch cell: evaluates a constant expression; the notebook displays 2 for it.
# NOTE(review): looks like a leftover kernel sanity check — confirm and remove before sharing.
1+1

2