In [None]:
import pandas as pd
import numpy as np
import requests
import shelve
import os
from datetime import datetime
import matplotlib.pyplot as plt
import re
import scipy.stats as stats
from scipy.stats import pearsonr, spearmanr

In [None]:
# Account name interpolated into the default data folder
# (c:\Users\<user>\OneDrive - purdue.edu\VS code\Data) used when no base_path is given.
user="lholguin"
# On the personal PC ("pc1") the account name is "asus"; change the value above there.

In [34]:
class NDCATCProcessor:

    def __init__(self, year, base_path=None):
        self.year = year
        self.base_path = base_path or rf"c:\Users\{user}\OneDrive - purdue.edu\VS code\Data"
        self.df_cleaned = None
        self.df_merged = None
        self.atc_mapping = None
        
    def clean_sdud_data(self):

        csv_file = os.path.join(self.base_path, f"SDUD\\SDUD{self.year}.csv")
        print(f"Reading CSV: {csv_file}")
        
        df = pd.read_csv(csv_file, dtype={'NDC': 'object'})
        print(f"Initial rows: {len(df):,}")
        
        # Filter data
        df_filtered = df.dropna(subset=['Units Reimbursed', 'Number of Prescriptions']) #dropping na values
        df_filtered = df_filtered[df_filtered['State'] != 'XX']
        
        print(f"After cleaning: {len(df_filtered):,} rows, {df_filtered['NDC'].nunique():,} unique NDCs")
        
        self.df_cleaned = df_filtered
        return self.df_cleaned
    
    def adding_key(self):
        if self.df_cleaned is None:
            raise ValueError("Run clean_sdud_data() first")
        
        self.df_cleaned['record_id'] = (
            self.df_cleaned['State'].astype(str) + "_" +
            self.df_cleaned['Year'].astype(str) + "_" +
            self.df_cleaned['Quarter'].astype(str) + "_" +
            self.df_cleaned['Utilization Type'].astype(str) + "_" +
            self.df_cleaned['NDC'].astype(str)
        )
        
        print(f"Created {len(self.df_cleaned):,} record IDs")
        return self.df_cleaned
    
    def generate_ndc_txt(self, output_filename=None):
        if 'record_id' not in self.df_cleaned.columns:
            raise ValueError("Run adding_key() first")
            
        output_filename = output_filename or f"NDCNEW_{self.year}.txt"
        output_path = os.path.join(self.base_path, f"ATC\\text_files\\{output_filename}")
        
        unique_pairs = self.df_cleaned[['NDC', 'record_id']].drop_duplicates()
        
        with open(output_path, 'w') as f:
            f.write("NDC\trecord_id\n")
            for _, row in unique_pairs.iterrows():
                f.write(f"{row['NDC']}\t{row['record_id']}\n")
        
        print(f"Exported {unique_pairs['record_id'].nunique():,} unique records to {output_path}")
        return output_path
    
    def analyze_atc4_mapping(self):

        if 'record_id' not in self.df_cleaned.columns:
            raise ValueError("Run adding_key() first")
            
        atc4_path = os.path.join(self.base_path, f"ATC\\ATC4_classes\\NDCNEW_{self.year}_ATC4_classes.csv")
        
        # Load ATC4 mapping
        df_atc4 = pd.read_csv(atc4_path, dtype={'NDC': 'object', 'record_id': 'string'})
        df_atc4['NDC'] = df_atc4['NDC'].str.zfill(11)
        
        print(f"ATC4 file: {len(df_atc4):,} rows, {df_atc4['NDC'].nunique():,} unique NDCs")
        
        # Ensure consistent types
        self.df_cleaned['record_id'] = self.df_cleaned['record_id'].astype('string')
        self.df_cleaned['NDC'] = self.df_cleaned['NDC'].astype('object')
        
        # Merge on both record_id and NDC
        self.atc_mapping = pd.merge(
            self.df_cleaned,
            df_atc4[['record_id', 'NDC', 'ATC4 Class']],
            on=['record_id', 'NDC'],
            how='left'
        )
        #deduplication after merge by record_id
        before_count=len(self.atc_mapping) 
        self.atc_mapping=self.atc_mapping.drop_duplicates(subset='record_id', keep='first')
        
        total = len(self.atc_mapping)
     
        mapped = self.atc_mapping['ATC4 Class'].notna().sum()
        print(f"Merged: {total:,} records, {mapped:,} with ATC4 ({mapped/total*100:.1f}%)")
        
        missing = total - mapped
        if missing > 0:
            print(f"Missing: {missing:,} records, {self.atc_mapping[self.atc_mapping['ATC4 Class'].isna()]['NDC'].nunique():,} unique NDCs")
        
        return self.atc_mapping
    
    def analyze_atc_distribution(self, level='ATC3'):

        if self.atc_mapping is None:
            raise ValueError("Run analyze_atc4_mapping() first")
        
        records = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()].copy()
        
        if len(records) == 0:
            print("No records with valid ATC4 mappings.")
            return None
        
        # Create ATC level column if needed
        if level == 'ATC3':
            records['ATC3 Class'] = records['ATC4 Class'].str[:4]
            class_col = 'ATC3 Class'
        elif level == 'ATC2':
            records['ATC2 Class'] = records['ATC4 Class'].str[:3]
            class_col = 'ATC2 Class'
        else:
            class_col = 'ATC4 Class'
        
        # Count classes per record_id
        per_record = records.groupby('record_id')[class_col].nunique().reset_index()
        per_record.columns = ['record_id', 'num_classes']
        
        distribution = per_record['num_classes'].value_counts().sort_index()
        
        print(f"\n{level} CLASSES PER RECORD_ID:")
        for n_classes, count in distribution.items():
            pct = (count / len(per_record)) * 100
            print(f"  {n_classes} class(es): {count:,} records ({pct:.1f}%)")
        
        print(f"\nSummary:")
        print(f"  Avg {level} per record: {per_record['num_classes'].mean():.2f}")
        print(f"  Max {level} per record: {per_record['num_classes'].max()}")
        
        return per_record

    def fetch_atc_names(self, cache_path=None):
        """Fetch ATC class names (ATC4, ATC3, ATC2) from RxNav API."""
        if self.atc_mapping is None:
            raise ValueError("Must run analyze_atc4_mapping() first")
        
        if cache_path is None:
            cache_path = os.path.join(self.base_path, "ATC\\cache_files\\atc_names_cache")
        
        print(f"\n{'='*60}")
        print("FETCHING ATC CLASS NAMES")
        print(f"{'='*60}")
        print(f"Using cache: {cache_path}")
        
        # Get only records with valid ATC4 mappings
        df_with_atc = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()].copy()
        
        # Create ATC3 and ATC2 columns from ATC4
        print("\nCreating ATC3 and ATC2 columns from ATC4...")
        df_with_atc['ATC3 Class'] = df_with_atc['ATC4 Class'].str[:4]
        df_with_atc['ATC2 Class'] = df_with_atc['ATC4 Class'].str[:3]
        
        # Get unique codes for each level
        unique_atc4 = df_with_atc['ATC4 Class'].dropna().unique()
        unique_atc3 = df_with_atc['ATC3 Class'].dropna().unique()
        unique_atc2 = df_with_atc['ATC2 Class'].dropna().unique()
        
        # Filter out invalid codes
        unique_atc4 = [c for c in unique_atc4 if c not in ['No ATC Mapping Found', 'No RxCUI Found', '']]
        unique_atc3 = [c for c in unique_atc3 if c not in ['No ATC Mapping Found', 'No RxCUI Found', '', 'No ', 'No']]
        unique_atc2 = [c for c in unique_atc2 if c not in ['No ATC Mapping Found', 'No RxCUI Found', '', 'No ', 'No']]
        
        print(f"\nUnique codes to fetch:")
        print(f"  ATC4: {len(unique_atc4)}")
        print(f"  ATC3: {len(unique_atc3)}")
        print(f"  ATC2: {len(unique_atc2)}")
        
        # Build mappings
        atc4_names = {}
        atc3_names = {}
        atc2_names = {}
        
        with shelve.open(cache_path) as cache:
            start_time = datetime.now()
            
            print("\nFetching ATC4 names...")
            for code in unique_atc4:
                atc4_names[code] = self._get_atc_name(code, cache)
            
            print("Fetching ATC3 names...")
            for code in unique_atc3:
                atc3_names[code] = self._get_atc_name(code, cache)
            
            print("Fetching ATC2 names...")
            for code in unique_atc2:
                atc2_names[code] = self._get_atc_name(code, cache)
            
            print(f"\nTotal processing time: {(datetime.now() - start_time).total_seconds()/60:.1f} minutes")
        
        # Apply names to all records in atc_mapping
        print("\nApplying names to dataframe...")
        self.atc_mapping['ATC3 Class'] = self.atc_mapping['ATC4 Class'].str[:4]
        self.atc_mapping['ATC2 Class'] = self.atc_mapping['ATC4 Class'].str[:3]
        
        self.atc_mapping['ATC4_Name'] = self.atc_mapping['ATC4 Class'].map(atc4_names).fillna('')
        self.atc_mapping['ATC3_Name'] = self.atc_mapping['ATC3 Class'].map(atc3_names).fillna('')
        self.atc_mapping['ATC2_Name'] = self.atc_mapping['ATC2 Class'].map(atc2_names).fillna('')
        
        print(f"\nATC names added successfully!")
        print("\nSample output:")
        sample = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()][['NDC', 'record_id', 'ATC4 Class', 'ATC4_Name', 'ATC3 Class', 'ATC3_Name', 'ATC2 Class', 'ATC2_Name']].head(5)
        print(sample.to_string())
        
        return self.atc_mapping
    
    def prepare_final_dataframe(self):

        if self.atc_mapping is None:
            raise ValueError("Run fetch_atc_names() first")
        
        self.df_merged = self.atc_mapping.copy()
        
        # Scale units
        self.df_merged['Units Reimbursed'] = self.df_merged['Units Reimbursed'] / 1e9
        self.df_merged['Number of Prescriptions'] = self.df_merged['Number of Prescriptions'] / 1e6
        
        total = len(self.df_merged)
        mapped = self.df_merged['ATC4 Class'].notna().sum()
        
        print(f"\nFinal Statistics:")
        print(f"  Records: {total:,} ({mapped:,} with ATC4, {mapped/total*100:.1f}%)")
        print(f"  Units Reimbursed: {self.df_merged['Units Reimbursed'].sum():.2f} Billion")
        print(f"  Prescriptions: {self.df_merged['Number of Prescriptions'].sum():.2f} Million")
        
        return self.df_merged
    
    def _get_atc_name(self, atc_code, cache):
        """Helper: Fetch ATC name from RxNav API with caching."""
        cache_key = f"atc_name:{atc_code}"
        if cache_key in cache:
            return cache[cache_key]
        
        try:
            url = f"https://rxnav.nlm.nih.gov/REST/rxclass/class/byId.json?classId={atc_code}"
            response = requests.get(url)
            response.raise_for_status()
            data = response.json()
            
            if 'rxclassMinConceptList' in data and 'rxclassMinConcept' in data['rxclassMinConceptList']:
                concepts = data['rxclassMinConceptList']['rxclassMinConcept']
                if concepts:
                    name = concepts[0].get('className', '')
                    cache[cache_key] = name
                    return name
            
            cache[cache_key] = ''
            return ''
            
        except Exception as e:
            print(f"Error retrieving {atc_code}: {e}")
            cache[cache_key] = ''
            return ''

    def export_merged_data(self, output_filename=None, show_details=True):

        if self.df_merged is None:
            raise ValueError("Run prepare_final_dataframe() first")
            
        output_filename = output_filename or f"merged_NEWdata_{self.year}.csv"
        output_path = os.path.join(self.base_path, f"ATC\\merged_data\\{output_filename}")
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        
        # Check duplicates
        initial_count = len(self.df_merged)
        duplicate_count = self.df_merged['record_id'].duplicated().sum()
        
        print(f"\nDeduplication Check:")
        print(f"  Before: {initial_count:,} rows")
        print(f"  Duplicates: {duplicate_count:,}")
        
        # Show sample duplicates if requested
        if show_details and duplicate_count > 0:
            dup_records = self.df_merged[self.df_merged['record_id'].duplicated(keep=False)].sort_values('record_id')
            sample_ids = dup_records['record_id'].unique()[:2]
            
            print(f"\nSample duplicate record_ids:")
            for rid in sample_ids:
                sample = self.df_merged[self.df_merged['record_id'] == rid][
                    ['record_id', 'NDC', 'State', 'ATC4 Class', 'ATC2 Class']
                ]
                print(f"\n{rid}:")
                print(sample.to_string(index=False))
        
        # Deduplicate and export
        df_final = self.df_merged.drop_duplicates(subset='record_id', keep='first')
        df_final.to_csv(output_path, index=False)
        
        print(f"\n  After: {len(df_final):,} rows")
        print(f"  Removed: {initial_count - len(df_final):,}")
        print(f"\nExported to: {output_path}")

        #Showing final atc class mapping in %
        final_count = len(df_final)
        final_mapped_records = df_final['ATC4 Class'].notna().sum()
        final_unmapped_records = final_count - final_mapped_records
        final_mapped_ndcs=df_final[df_final['ATC4 Class'].notna()]['NDC'].nunique()
        final_unmapped_ndcs=df_final['NDC'].nunique()
        
        # Aggregate metrics
        agg = df_final.groupby('record_id').agg({
            'Units Reimbursed': 'sum',
            'Number of Prescriptions': 'sum'
        })
        
        print(f"\nAggregated Totals:")
        print(f"  Units Reimbursed: {agg['Units Reimbursed'].sum():.3f} Billion")
        print(f"  Number of Prescriptions: {agg['Number of Prescriptions'].sum():3f} Million")
        
        return output_path
    
    def export_unscaled_data(self, output_filename=None, show_details=True):
        """Export merged data without scaling units."""
        if self.atc_mapping is None:
            raise ValueError("Run fetch_atc_names() first")
        
        output_filename = output_filename or f"MUD_{self.year}.csv"
        output_path = os.path.join(self.base_path, f"ATC\\merged_data\\unscaled_data\\{output_filename}")
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        
        # Use atc_mapping directly (no scaling)
        df_unscaled = self.atc_mapping.copy()
        
        # Check duplicates
        initial_count = len(df_unscaled)
        duplicate_count = df_unscaled['record_id'].duplicated().sum()
        
        print(f"\nDeduplication Check:")
        print(f"  Before: {initial_count:,} rows")
        print(f"  Duplicates: {duplicate_count:,}")
        
        # Show sample duplicates if requested
        if show_details and duplicate_count > 0:
            dup_records = df_unscaled[df_unscaled['record_id'].duplicated(keep=False)].sort_values('record_id')
            sample_ids = dup_records['record_id'].unique()[:2]
            
            print(f"\nSample duplicate record_ids:")
            for rid in sample_ids:
                sample = df_unscaled[df_unscaled['record_id'] == rid][
                    ['record_id', 'NDC', 'State', 'ATC4 Class', 'ATC2 Class']
                ]
                print(f"\n{rid}:")
                print(sample.to_string(index=False))
        
        # Deduplicate and export
        df_final = df_unscaled.drop_duplicates(subset='record_id', keep='first')
        df_final.to_csv(output_path, index=False)
        
        print(f"\n  After: {len(df_final):,} rows")
        print(f"  Removed: {initial_count - len(df_final):,}")
        print(f"\nExported to: {output_path}")
        
        # Show final ATC class mapping
        final_count = len(df_final)
        final_mapped_records = df_final['ATC4 Class'].notna().sum()
        
        # Aggregate metrics (unscaled)
        agg = df_final.groupby('record_id').agg({
            'Units Reimbursed': 'sum',
            'Number of Prescriptions': 'sum'
        })
        
        print(f"\nAggregated Totals (Unscaled):")
        print(f"  Units Reimbursed: {agg['Units Reimbursed'].sum():,.0f}")
        print(f"  Number of Prescriptions: {agg['Number of Prescriptions'].sum():,.0f}")
        print(f"  Mapped Records: {final_mapped_records:,} ({final_mapped_records/final_count*100:.1f}%)")
        
        return output_path

In [None]:
class NDCATC_overview:

    @staticmethod
    def create_multi_year_distribution_analysis(years_list, base_path=None):

        if base_path is None:
            base_path = rf"c:\Users\{user}\OneDrive - purdue.edu\VS code\Data"
            
        print("Creating Multi-Year ATC Distribution Analysis...")
        print("="*70)
        
        results = {
            'ATC4_1_class': {}, 'ATC4_2_classes': {}, 'ATC4_3+_classes': {},
            'ATC3_1_class': {}, 'ATC3_2_classes': {}, 'ATC3_3+_classes': {},
            'ATC2_1_class': {}, 'ATC2_2_classes': {}, 'ATC2_3+_classes': {}
        }
        
        for year in years_list:
            print(f"Processing {year}...", end=" ")
            try:
                # Load the pre-processed CSV file
                csv_path = os.path.join(base_path, f"ATC\\merged_data\\merged_NEWdata_{year}.csv")
                df_merged = pd.read_csv(csv_path)
                
                records = df_merged[df_merged['ATC4 Class'].notna()].copy()
                if records.empty:
                    print("No ATC records")
                    for key in results.keys():
                        results[key][year] = "N/A"
                    continue
                    
                records['ATC2 Class'] = records['ATC4 Class'].str[:3]
                records['ATC3 Class'] = records['ATC4 Class'].str[:4]
                
                # Calculate distributions for each level
                for level, col in [('ATC4', 'ATC4 Class'), ('ATC3', 'ATC3 Class'), ('ATC2', 'ATC2 Class')]:
                    per_record = records.groupby('record_id')[col].nunique()
                    dist = per_record.value_counts().sort_index()
                    total = len(per_record)
                    
                    results[f'{level}_1_class'][year] = f"{(dist.get(1, 0) / total * 100):.1f}%"
                    results[f'{level}_2_classes'][year] = f"{(dist.get(2, 0) / total * 100):.1f}%"
                    results[f'{level}_3+_classes'][year] = f"{(dist[dist.index >= 3].sum() / total * 100):.1f}%"
                
                print("✓")
            except FileNotFoundError:
                print(f"✗ File not found: {csv_path}")
                for key in results.keys():
                    results[key][year] = "N/A"
            except Exception as e:
                print(f"✗ Error: {e}")
                for key in results.keys():
                    results[key][year] = "N/A"
        
        df_percentages = pd.DataFrame(results).T
        print(f"\nATC DISTRIBUTION PERCENTAGES ACROSS YEARS")
        print("="*60)
        print(df_percentages)
        return df_percentages
    
    @staticmethod
    def analyze_general_atc_overview_by_state(years_list, base_path=None, state_filter=None):

        if base_path is None:
            base_path = rf"c:\Users\{user}\OneDrive - purdue.edu\VS code\Data"

        if state_filter:
            if isinstance(state_filter, str):
                state_filter = [state_filter]
            print(f"Creating ATC2 & ATC3 Overview by State for: {', '.join(state_filter)}")
        else:
            print("Creating ATC2 & ATC3 Overview by State (All States)")
        print("="*78)
        
        # Results will be organized by state
        state_results = {}
        all_states = set()
        
        # Collect data for all years first
        for year in years_list:
            print(f"Processing {year}...", end=" ")
            try:
                csv_path = os.path.join(base_path, f"ATC\\merged_data\\merged_NEWdata_{year}.csv")
                df_merged = pd.read_csv(csv_path)
                
                records = df_merged[df_merged['ATC4 Class'].notna()].copy()
                if records.empty:
                    print("No ATC records")
                    continue
                
                # Filter by state if specified
                if state_filter:
                    records = records[records['State'].isin(state_filter)]
                    if records.empty:
                        print(f"No records for states {state_filter}")
                        continue
                    states_to_process = state_filter
                else:
                    states_to_process = records['State'].unique()
                    all_states.update(states_to_process)
                
                for state in states_to_process:
                    state_records = records[records['State'] == state]
                    if state_records.empty:
                        continue
                    
                    # Initialize state data structure
                    if state not in state_results:
                        state_results[state] = {
                            'atc2_year_results': {},
                            'atc3_year_results': {}
                        }
                    
                    # ATC2 summary for this state
                    pairs2 = state_records[['record_id', 'NDC', 'ATC2 Class']].drop_duplicates()
                    atc2_summary = pairs2.groupby('ATC2 Class').agg(
                        Unique_NDCs=('NDC', 'nunique'),
                        Total_Records=('record_id', 'nunique')
                    ).sort_values('Unique_NDCs', ascending=False)
                    atc2_summary['Percentage_of_NDCs'] = (
                        atc2_summary['Unique_NDCs'] / pairs2['NDC'].nunique() * 100
                    ).round(1)
                    
                    # ATC3 summary for this state
                    pairs3 = state_records[['record_id', 'NDC', 'ATC3 Class']].drop_duplicates()
                    atc3_summary = pairs3.groupby('ATC3 Class').agg(
                        Unique_NDCs=('NDC', 'nunique'),
                        Total_Records=('record_id', 'nunique')
                    ).sort_values('Unique_NDCs', ascending=False)
                    atc3_summary['Percentage_of_NDCs'] = (
                        atc3_summary['Unique_NDCs'] / pairs3['NDC'].nunique() * 100
                    ).round(1)
                    
                    state_results[state]['atc2_year_results'][year] = atc2_summary
                    state_results[state]['atc3_year_results'][year] = atc3_summary
                
                processed_states = len([s for s in states_to_process if s in state_results])
                print(f"✓ (States: {processed_states})")
                
            except FileNotFoundError:
                print(f"✗ File not found")
            except Exception as e:
                print(f"✗ Error: {e}")
        
        # Process each state with the same format as original method
        final_state_results = {}
        
        states_to_analyze = state_filter if state_filter else sorted(all_states)
        
        for state in states_to_analyze:
            if state not in state_results:
                continue
                
            print(f"\n" + "="*60)
            print(f"PROCESSING STATE: {state}")
            print("="*60)
            
            atc2_year_results = state_results[state]['atc2_year_results']
            atc3_year_results = state_results[state]['atc3_year_results']
            
            # Print summaries for this state (same format as original)
            print(f"\nUNIQUE NDCs PER ATC2 CLASS BY YEAR - {state}")
            print("="*50)
            for year in years_list:
                if year in atc2_year_results and not atc2_year_results[year].empty:
                    print(f"\n{year}: {len(atc2_year_results[year])} classes, "
                        f"{atc2_year_results[year]['Unique_NDCs'].sum():,} total NDCs")
                    print("Top 10:")
                    print(atc2_year_results[year].head(10))
            
            print(f"\nUNIQUE NDCs PER ATC3 CLASS BY YEAR - {state}")
            print("="*50)
            for year in years_list:
                if year in atc3_year_results and not atc3_year_results[year].empty:
                    print(f"\n{year}: {len(atc3_year_results[year])} classes, "
                        f"{atc3_year_results[year]['Unique_NDCs'].sum():,} total NDCs")
                    print("Top 10:")
                    print(atc3_year_results[year].head(10))
            
            # Build comparison tables (same logic as original)
            def build_comparison(year_tables):
                all_classes = set()
                for tbl in year_tables.values():
                    if not tbl.empty:
                        all_classes.update(tbl.index.tolist())
                comp = {cls: {y: int(year_tables[y].loc[cls, 'Unique_NDCs']) 
                            if y in year_tables and not year_tables[y].empty and cls in year_tables[y].index else 0
                            for y in years_list}
                        for cls in sorted(all_classes)}
                df = pd.DataFrame(comp).T
                return df.loc[df.sum(axis=1).sort_values(ascending=False).index]
            
            atc2_comparison = build_comparison(atc2_year_results)
            atc3_comparison = build_comparison(atc3_year_results)
            
            # Create cumulative frequency tables (same logic as original)
            def create_cumulative_frequency_table(comparison_df, level_name):
                total_ndcs = comparison_df.sum(axis=1).sort_values(ascending=False)
                
                freq_table = pd.DataFrame({
                    'ATC_Class': total_ndcs.index,
                    'Total_Unique_NDCs': total_ndcs.values,
                    'Percentage': (total_ndcs.values / total_ndcs.sum() * 100).round(2)
                })
                
                freq_table['Cumulative_NDCs'] = freq_table['Total_Unique_NDCs'].cumsum()
                freq_table['Cumulative_Percentage'] = freq_table['Percentage'].cumsum().round(2)
                
                freq_table.reset_index(drop=True, inplace=True)
                freq_table.index = freq_table.index + 1
                
                return freq_table
            
            atc2_freq_table = create_cumulative_frequency_table(atc2_comparison, 'ATC2')
            atc3_freq_table = create_cumulative_frequency_table(atc3_comparison, 'ATC3')
            
            # Store results for this state
            final_state_results[state] = {
                'atc2_year_results': atc2_year_results,
                'atc3_year_results': atc3_year_results,
                'atc2_comparison': atc2_comparison,
                'atc3_comparison': atc3_comparison,
                'atc2_freq_table': atc2_freq_table,
                'atc3_freq_table': atc3_freq_table
            }
        
        return final_state_results
    
    @staticmethod
    def get_atc_ndc_details(year, top_n=10, base_path=None):
    
        if base_path is None:
            base_path = rf"c:\Users\{user}\OneDrive - purdue.edu\VS code\Data"
            
        print(f"Analyzing ATC-NDC details for {year}...")
        print("="*60)
        
        try:
            # Load the pre-processed CSV file
            csv_path = os.path.join(base_path, f"ATC\\merged_data\\merged_NEWdata_{year}.csv")
            df_merged = pd.read_csv(csv_path)
            
            records = df_merged[df_merged['ATC4 Class'].notna()].copy()
            if records.empty:
                print("No records with ATC mapping")
                return pd.DataFrame(), pd.DataFrame()
                
            records['ATC2 Class'] = records['ATC4 Class'].str[:3]
            records['ATC3 Class'] = records['ATC4 Class'].str[:4]
            
            # ATC2 details
            atc2_details = records.groupby('ATC2 Class').agg(
                Unique_NDCs=('NDC', 'nunique'),
                Total_Records=('record_id', 'nunique')
            ).sort_values('Unique_NDCs', ascending=False).head(top_n)
            
            # ATC3 details
            atc3_details = records.groupby('ATC3 Class').agg(
                Unique_NDCs=('NDC', 'nunique'),
                Total_Records=('record_id', 'nunique')
            ).sort_values('Unique_NDCs', ascending=False).head(top_n)
            
            print(f"\nTop {top_n} ATC2 Classes:")
            print(atc2_details)
            print(f"\nTop {top_n} ATC3 Classes:")
            print(atc3_details)
            
            return atc2_details, atc3_details
            
        except FileNotFoundError:
            print(f"✗ File not found: {csv_path}")
            return pd.DataFrame(), pd.DataFrame()
        except Exception as e:
            print(f"✗ Error: {e}")
            return pd.DataFrame(), pd.DataFrame()
    
    @staticmethod
    def export_cumulative_frequency_excel(years_list, level='ATC2', base_path=None, output_filename=None, 
                                          include_ndc_counts=True, by_state=False, state_filter=None):

        if base_path is None:
            base_path = rf"c:\Users\{user}\OneDrive - purdue.edu\VS code\Data"
        
        # Print header
        if by_state:
            states_msg = ', '.join(state_filter) if state_filter else 'All States'
            print(f"Creating {level} Cumulative Frequency Analysis Excel for States: {states_msg}")
        else:
            print(f"Creating {level} Cumulative Frequency Analysis Excel...")
        print("="*70)
        
        # Helper function to create ATC level column
        def create_atc_level_column(df, level):
            if level == 'ATC2':
                return df['ATC4 Class'].str[:3]
            elif level == 'ATC3':
                return df['ATC4 Class'].str[:4]
            else:  # ATC4
                return df['ATC4 Class']
        
        # Helper function to get name mapping
        def get_name_mapping(records, level):
            name_col_map = {'ATC2': 'ATC2_Name', 'ATC3': 'ATC3_Name', 'ATC4': 'ATC4_Name'}
            name_col = name_col_map.get(level)
            
            if name_col and name_col in records.columns:
                return records[['ATC_Level', name_col]].drop_duplicates().set_index('ATC_Level')[name_col].to_dict()
            return {}
        
        # Helper function to process year data
        def process_year_data(csv_path, level, state_filter=None):
            try:
                df_merged = pd.read_csv(csv_path)
                records = df_merged[df_merged['ATC4 Class'].notna()].copy()
                
                if records.empty:
                    return None, None, None, None
                
                # Filter by state if needed
                if state_filter:
                    records = records[records['State'].isin(state_filter)]
                    if records.empty:
                        return None, None, None, None
                
                records['ATC_Level'] = create_atc_level_column(records, level)
                name_mapping = get_name_mapping(records, level)
                
                # Aggregate financial data
                financial = records.groupby('ATC_Level').agg(
                    Units_Reimbursed=('Units Reimbursed', 'sum'),
                    Number_of_Prescriptions=('Number of Prescriptions', 'sum')
                )
                
                # Count unique NDCs
                ndc_counts = records.groupby('ATC_Level').agg(
                    Unique_NDCs=('NDC', 'nunique')
                )
                
                states_processed = records['State'].nunique() if 'State' in records.columns else 1
                
                return financial, ndc_counts, name_mapping, states_processed
                
            except Exception as e:
                print(f"Error processing file: {e}")
                return None, None, None, None
        
        # Collect data for all years
        if by_state:
            state_data = {}  # {state: {year: (financial, ndc_counts, name_mapping)}}
            all_states = set()
        else:
            year_results = {}
            ndc_counts = {}
            name_mapping = {}
        
        # Process each year
        for year in years_list:
            print(f"Processing {year}...", end=" ")
            csv_path = os.path.join(base_path, f"ATC\\merged_data\\merged_NEWdata_{year}.csv")
            
            if by_state:
                # Load full data first
                try:
                    df = pd.read_csv(csv_path)
                    records = df[df['ATC4 Class'].notna()].copy()
                    
                    if state_filter:
                        states_to_process = state_filter if isinstance(state_filter, list) else [state_filter]
                    else:
                        states_to_process = records['State'].unique()
                        all_states.update(states_to_process)
                    
                    for state in states_to_process:
                        financial, ndcs, names, _ = process_year_data(csv_path, level, [state])
                        
                        if financial is not None:
                            if state not in state_data:
                                state_data[state] = {}
                            state_data[state][year] = (financial, ndcs, names)
                    
                    print(f"✓ (States: {len(states_to_process)})")
                except:
                    print("✗")
            else:
                financial, ndcs, names, _ = process_year_data(csv_path, level)
                
                if financial is not None:
                    year_results[year] = financial
                    ndc_counts[year] = ndcs
                    name_mapping.update(names)
                    print(f"✓ ({len(financial)} classes)")
                else:
                    print("✗")
        
        # Helper function to build comparison dataframes
        def build_comparison_dfs(year_financial_dict, year_ndc_dict, years_list):
            # Collect all unique ATC classes
            all_classes = set()
            for df in year_financial_dict.values():
                if df is not None and not df.empty:
                    all_classes.update(df.index)
            for df in year_ndc_dict.values():
                if df is not None and not df.empty:
                    all_classes.update(df.index)
            
            if not all_classes:
                return None, None, None
            
            # Build comparison dictionaries
            units_comp = {}
            presc_comp = {}
            ndc_comp = {}
            
            for cls in sorted(all_classes):
                units_comp[cls] = {
                    y: float(year_financial_dict[y].loc[cls, 'Units_Reimbursed']) 
                    if y in year_financial_dict and year_financial_dict[y] is not None 
                    and not year_financial_dict[y].empty and cls in year_financial_dict[y].index 
                    else 0.0 
                    for y in years_list
                }
                
                presc_comp[cls] = {
                    y: float(year_financial_dict[y].loc[cls, 'Number_of_Prescriptions']) 
                    if y in year_financial_dict and year_financial_dict[y] is not None 
                    and not year_financial_dict[y].empty and cls in year_financial_dict[y].index 
                    else 0.0 
                    for y in years_list
                }
                
                ndc_comp[cls] = {
                    y: int(year_ndc_dict[y].loc[cls, 'Unique_NDCs']) 
                    if y in year_ndc_dict and year_ndc_dict[y] is not None 
                    and not year_ndc_dict[y].empty and cls in year_ndc_dict[y].index 
                    else 0 
                    for y in years_list
                }
            
            # Convert to DataFrames and sort by total units
            units_df = pd.DataFrame(units_comp).T
            presc_df = pd.DataFrame(presc_comp).T
            ndc_df = pd.DataFrame(ndc_comp).T
            
            units_total = units_df.sum(axis=1).sort_values(ascending=False)
            units_df = units_df.loc[units_total.index]
            presc_df = presc_df.loc[units_total.index]
            ndc_df = ndc_df.loc[units_total.index]
            
            return units_df, presc_df, ndc_df
        
        # Helper function to create cumulative frequency DataFrame
        def create_cumulative_df(comparison_df, metric_name, name_mapping):
            totals = comparison_df.sum(axis=1)
            total_sum = totals.sum()
            
            cumulative_total = 0
            df_data = []
            
            for atc_class in comparison_df.index:
                class_total = totals[atc_class]
                cumulative_total += class_total
                percentage = (class_total / total_sum * 100) if total_sum > 0 else 0
                cumulative_pct = (cumulative_total / total_sum * 100) if total_sum > 0 else 0
                
                row = {'ATC_Class': atc_class}
                
                # Add ATC name if available
                row['ATC_Name'] = name_mapping.get(atc_class, '')
                
                # Add year-by-year data
                for year in years_list:
                    if metric_name == 'NDCs':
                        row[f'{metric_name}_{year}'] = int(comparison_df.loc[atc_class, year])
                    else:
                        row[f'{metric_name}_{year}'] = round(comparison_df.loc[atc_class, year], 3)
                
                # Add summary columns
                if metric_name == 'NDCs':
                    row[f'Total_{metric_name}'] = int(class_total)
                    row['Percentage'] = round(percentage, 2)
                    row[f'Cumulative_{metric_name}'] = int(cumulative_total)
                    row['Cumulative_Percentage_NDCs'] = round(cumulative_pct, 2)
                else:
                    row[f'Total_{metric_name}'] = round(class_total, 3)
                    row['Percentage'] = round(percentage, 2)
                    row[f'Cumulative_{metric_name}'] = round(cumulative_total, 3)
                    row['Cumulative_Percentage'] = round(cumulative_pct, 2)
                
                df_data.append(row)
            
            return pd.DataFrame(df_data)
        
        # Export function
        def export_to_excel(units_df, prescriptions_df, ndc_df, output_path, include_ndc_counts):
            with pd.ExcelWriter(output_path, engine='openpyxl') as writer:
                units_df.to_excel(writer, sheet_name='Units_Reimbursed', index=False)
                prescriptions_df.to_excel(writer, sheet_name='Prescriptions', index=False)
                
                if include_ndc_counts:
                    ndc_df.to_excel(writer, sheet_name='NDC_Counts', index=False)
            
            print(f"Exported to Excel: {output_path}")
        
        # Process and export
        output_dir = os.path.join(base_path, "ATC\\exported_analysis")
        os.makedirs(output_dir, exist_ok=True)
        
        if by_state:
            states_to_export = state_filter if state_filter else sorted(all_states)
            all_output_paths = {}
            
            for state in states_to_export:
                if state not in state_data:
                    continue
                
                # Extract data for this state
                state_year_financial = {y: data[0] for y, data in state_data[state].items()}
                state_year_ndc = {y: data[1] for y, data in state_data[state].items()}
                state_name_mapping = {}
                for data in state_data[state].values():
                    state_name_mapping.update(data[2])
                
                # Build comparison DataFrames
                units_comp_df, presc_comp_df, ndc_comp_df = build_comparison_dfs(
                    state_year_financial, state_year_ndc, years_list
                )
                
                if units_comp_df is None:
                    continue
                
                # Create cumulative DataFrames
                units_cum = create_cumulative_df(units_comp_df, 'Units', state_name_mapping)
                presc_cum = create_cumulative_df(presc_comp_df, 'Prescriptions', state_name_mapping)
                ndc_cum = create_cumulative_df(ndc_comp_df, 'NDCs', state_name_mapping)
                
                # Generate output filename
                if output_filename:
                    name_parts = output_filename.rsplit('.', 1)
                    state_output_filename = f"{name_parts[0]}_{state}.{name_parts[1]}"
                else:
                    state_output_filename = f"{level}_Cumulative_Analysis_{state}_with_NDC_Counts.xlsx"
                
                output_path = os.path.join(output_dir, state_output_filename)
                export_to_excel(units_cum, presc_cum, ndc_cum, output_path, include_ndc_counts)
                all_output_paths[state] = output_path
            
            return all_output_paths
        
        else:
            # Build comparison DataFrames
            units_comp_df, presc_comp_df, ndc_comp_df = build_comparison_dfs(
                year_results, ndc_counts, years_list
            )
            
            if units_comp_df is None:
                print("No data found!")
                return None
            
            # Create cumulative DataFrames
            units_cum = create_cumulative_df(units_comp_df, 'Units', name_mapping)
            presc_cum = create_cumulative_df(presc_comp_df, 'Prescriptions', name_mapping)
            ndc_cum = create_cumulative_df(ndc_comp_df, 'NDCs', name_mapping)
            
            # Generate output filename
            if not output_filename:
                output_filename = f"{level}_Cumulative_Analysis_with_NDC_Counts.xlsx"
            
            output_path = os.path.join(output_dir, output_filename)
            export_to_excel(units_cum, presc_cum, ndc_cum, output_path, include_ndc_counts)
            
            return output_path
    
    @staticmethod
    def compare_cumulative_80_analysis(base_path=None, cumulative_threshold=80.0):

        if base_path is None:
            base_path = rf"c:\Users\{user}\OneDrive - purdue.edu\VS code\Data"
        
        print(f"COMPARATIVE ANALYSIS: Indiana vs National ATC2 Classes at {cumulative_threshold}% Threshold")
        print("="*85)
        
        # File paths
        indiana_file = os.path.join(base_path, "ATC\\exported_analysis\\ATC2_Cumulative_Analysis_IN_with_NDC_Counts.xlsx")
        national_file = os.path.join(base_path, "ATC\\exported_analysis\\ATC2_Cumulative_Analysis_with_NDC_Counts.xlsx")
        
        # Load data
        try:
            indiana_units = pd.read_excel(indiana_file, sheet_name='Units_Reimbursed')
            indiana_prescriptions = pd.read_excel(indiana_file, sheet_name='Prescriptions')
            national_units = pd.read_excel(national_file, sheet_name='Units_Reimbursed')
            national_prescriptions = pd.read_excel(national_file, sheet_name='Prescriptions')
            
            print(f"✓ Loaded Indiana data: {len(indiana_units)} ATC2 classes")
            print(f"✓ Loaded National data: {len(national_units)} ATC2 classes")
        except FileNotFoundError:
            print("✗ Error: Required Excel files not found. Please run export functions first.")
            return None
        except Exception as e:
            print(f"✗ Error loading files: {e}")
            return None
        
        # Helper function to get classes at threshold
        def get_classes_at_threshold(df, threshold):
            df_sorted = df.sort_values('Cumulative_Percentage').reset_index(drop=True)
            threshold_idx = df_sorted[df_sorted['Cumulative_Percentage'] >= threshold].index
            
            if len(threshold_idx) > 0:
                return df_sorted.iloc[:threshold_idx[0] + 1]['ATC_Class'].tolist()
            return df_sorted['ATC_Class'].tolist()
        
        # Get classes at threshold for each dataset
        in_units_80 = get_classes_at_threshold(indiana_units, cumulative_threshold)
        in_presc_80 = get_classes_at_threshold(indiana_prescriptions, cumulative_threshold)
        nat_units_80 = get_classes_at_threshold(national_units, cumulative_threshold)
        nat_presc_80 = get_classes_at_threshold(national_prescriptions, cumulative_threshold)
        
        # Calculate overlaps
        in_overlap = set(in_units_80) & set(in_presc_80)
        in_only_units = set(in_units_80) - set(in_presc_80)
        in_only_presc = set(in_presc_80) - set(in_units_80)
        
        nat_overlap = set(nat_units_80) & set(nat_presc_80)
        nat_only_units = set(nat_units_80) - set(nat_presc_80)
        nat_only_presc = set(nat_presc_80) - set(nat_units_80)
        
        # Print Indiana Analysis
        print(f"\n{'='*60}")
        print("1. INDIANA ANALYSIS")
        print(f"{'='*60}")
        print(f"\nClasses reaching {cumulative_threshold}% cumulative:")
        print(f"  Units Reimbursed: {len(in_units_80)} classes")
        print(f"  Prescriptions: {len(in_presc_80)} classes")
        print(f"\nComparison:")
        print(f"  Overlap (both metrics): {len(in_overlap)} classes")
        print(f"  Only in Units: {len(in_only_units)} classes")
        print(f"  Only in Prescriptions: {len(in_only_presc)} classes")
        
        # Print National Analysis
        print(f"\n{'='*60}")
        print("2. NATIONAL ANALYSIS")
        print(f"{'='*60}")
        print(f"\nClasses reaching {cumulative_threshold}% cumulative:")
        print(f"  Units Reimbursed: {len(nat_units_80)} classes")
        print(f"  Prescriptions: {len(nat_presc_80)} classes")
        print(f"\nComparison:")
        print(f"  Overlap (both metrics): {len(nat_overlap)} classes")
        print(f"  Only in Units: {len(nat_only_units)} classes")
        print(f"  Only in Prescriptions: {len(nat_only_presc)} classes")
        
        # Helper function for detailed comparison
        def print_detailed_comparison(set1, set2, metric_name):
            overlap = set1 & set2
            only_in = set1 - set2
            only_nat = set2 - set1
            
            print(f"\n{metric_name.upper()} - Detailed Comparison:")
            print("-" * 50)
            print(f"Classes in BOTH Indiana and National ({len(overlap)}):")
            print(f"  {sorted(overlap) if overlap else 'None'}")
            print(f"\nClasses ONLY in Indiana ({len(only_in)}):")
            print(f"  {sorted(only_in) if only_in else 'None'}")
            print(f"\nClasses ONLY in National ({len(only_nat)}):")
            print(f"  {sorted(only_nat) if only_nat else 'None'}")
        
        # Print detailed comparisons
        print(f"\n{'='*85}")
        print("3. DETAILED CLASS ANALYSIS")
        print(f"{'='*85}")
        
        print_detailed_comparison(set(in_units_80), set(nat_units_80), "Units Reimbursed")
        print_detailed_comparison(set(in_presc_80), set(nat_presc_80), "Prescriptions")
        
        # Helper function to calculate category totals
        def get_category_totals(df_units, df_presc, class_list):
            if not class_list:
                return 0.0, 0.0
            units_total = df_units[df_units['ATC_Class'].isin(class_list)]['Total_Units'].sum()
            presc_total = df_presc[df_presc['ATC_Class'].isin(class_list)]['Total_Prescriptions'].sum()
            return units_total, presc_total
        
        # Calculate totals for all categories
        in_units_only_u, in_units_only_p = get_category_totals(indiana_units, indiana_prescriptions, list(in_only_units))
        in_presc_only_u, in_presc_only_p = get_category_totals(indiana_units, indiana_prescriptions, list(in_only_presc))
        in_overlap_u, in_overlap_p = get_category_totals(indiana_units, indiana_prescriptions, list(in_overlap))
        
        nat_units_only_u, nat_units_only_p = get_category_totals(national_units, national_prescriptions, list(nat_only_units))
        nat_presc_only_u, nat_presc_only_p = get_category_totals(national_units, national_prescriptions, list(nat_only_presc))
        nat_overlap_u, nat_overlap_p = get_category_totals(national_units, national_prescriptions, list(nat_overlap))
        
        # Create summary DataFrame
        totals_summary = pd.DataFrame({
            'Geography': ['Indiana', 'Indiana', 'Indiana', 'National', 'National', 'National'],
            'Category': ['Only in Units', 'Only in Prescriptions', 'In Both (Overlap)'] * 2,
            'Num_Classes': [
                len(in_only_units), len(in_only_presc), len(in_overlap),
                len(nat_only_units), len(nat_only_presc), len(nat_overlap)
            ],
            'Total_Units': [
                in_units_only_u, in_presc_only_u, in_overlap_u,
                nat_units_only_u, nat_presc_only_u, nat_overlap_u
            ],
            'Total_Prescriptions': [
                in_units_only_p, in_presc_only_p, in_overlap_p,
                nat_units_only_p, nat_presc_only_p, nat_overlap_p
            ]
        })
        
        print(f"\n{'='*85}")
        print("4. SUMMARY TOTALS")
        print(f"{'='*85}")
        print(totals_summary.to_string(index=False))
        
        # Return comprehensive results
        return {
            'totals_summary': totals_summary,
            'indiana': {
                'units_80': in_units_80,
                'prescriptions_80': in_presc_80,
                'overlap': list(in_overlap),
                'only_units': list(in_only_units),
                'only_prescriptions': list(in_only_presc)
            },
            'national': {
                'units_80': nat_units_80,
                'prescriptions_80': nat_presc_80,
                'overlap': list(nat_overlap),
                'only_units': list(nat_only_units),
                'only_prescriptions': list(nat_only_presc)
            }
        }
    @staticmethod
    def create_pareto_charts(base_path=None, top_n=30):
        """Draw Indiana vs national Pareto charts for ATC2 units and prescriptions.

        Loads the 'Units_Reimbursed' and 'Prescriptions' sheets of the Indiana and
        national ATC2 cumulative-analysis workbooks found in
        <base_path>/ATC/exported_analysis, shows a 2x2 grid of Pareto charts
        (bars = class totals, line = cumulative percentage, dashed line = 80%),
        then prints a per-sheet summary and a top-10 class comparison.

        Every sheet is re-ranked by its own total column first. The exporter in
        this file orders all sheets by total units, so without re-ranking the
        prescriptions bars, "top" classes and cumulative line would follow the
        units ranking rather than the prescriptions ranking.

        Args:
            base_path: Data root folder. When None, the per-user OneDrive path is
                built from the module-level `user`.
            top_n: Maximum number of leading classes drawn in each chart.

        Returns:
            None when the workbooks cannot be loaded. Otherwise a dict with the
            four re-ranked DataFrames under 'indiana_units',
            'indiana_prescriptions', 'national_units', 'national_prescriptions'.
        """
        if base_path is None:
            base_path = rf"c:\Users\{user}\OneDrive - purdue.edu\VS code\Data"

        print("CREATING PARETO CHARTS: Indiana vs National")
        print("="*60)

        # File paths
        indiana_file = os.path.join(base_path, "ATC\\exported_analysis\\ATC2_Cumulative_Analysis_IN_with_NDC_Counts.xlsx")
        national_file = os.path.join(base_path, "ATC\\exported_analysis\\ATC2_Cumulative_Analysis_with_NDC_Counts.xlsx")

        # Load data
        try:
            indiana_units = pd.read_excel(indiana_file, sheet_name='Units_Reimbursed')
            indiana_prescriptions = pd.read_excel(indiana_file, sheet_name='Prescriptions')
            national_units = pd.read_excel(national_file, sheet_name='Units_Reimbursed')
            national_prescriptions = pd.read_excel(national_file, sheet_name='Prescriptions')
            print("✓ Data loaded successfully")
        except FileNotFoundError:
            print("✗ Error: Required Excel files not found. Please run export functions first.")
            return None
        except Exception as e:
            print(f"✗ Error loading files: {e}")
            return None

        def rank_by_metric(df, metric_name):
            """Return a copy of `df` ordered by its own total, with share columns recomputed."""
            total_col = f'Total_{metric_name}'
            # mergesort is stable, so classes with equal totals keep their file order.
            ranked = df.sort_values(total_col, ascending=False, kind='mergesort').reset_index(drop=True)
            running_total = ranked[total_col].cumsum()
            grand_total = ranked[total_col].sum()
            if grand_total > 0:
                # Same rounding the exporter applies to its percentage columns.
                ranked['Percentage'] = (ranked[total_col] / grand_total * 100).round(2)
                ranked['Cumulative_Percentage'] = (running_total / grand_total * 100).round(2)
            else:
                ranked['Percentage'] = 0.0
                ranked['Cumulative_Percentage'] = 0.0
            # Keep the running-total column consistent with the new row order.
            cumulative_col = f'Cumulative_{metric_name}'
            if cumulative_col in ranked.columns:
                ranked[cumulative_col] = running_total.round(3)
            return ranked

        indiana_units = rank_by_metric(indiana_units, 'Units')
        indiana_prescriptions = rank_by_metric(indiana_prescriptions, 'Prescriptions')
        national_units = rank_by_metric(national_units, 'Units')
        national_prescriptions = rank_by_metric(national_prescriptions, 'Prescriptions')

        # Create figure with 2x2 subplots
        fig, axes = plt.subplots(2, 2, figsize=(24, 16))

        def create_single_pareto(ax, df, metric_name, geography, top_n):
            """Draw one Pareto chart (bars plus cumulative-percentage line) on `ax`."""
            df_top = df.head(top_n)

            # Prepare data
            classes = df_top['ATC_Class']
            # NOTE(review): the axis labels claim Billions/Millions but the values are plotted
            # exactly as stored in the workbook — confirm the exported data uses those units.
            if metric_name == 'Units':
                values = df_top['Total_Units']
                y_label = 'Total Units Reimbursed (Billions)'
            else:
                values = df_top['Total_Prescriptions']
                y_label = 'Total Prescriptions (Millions)'

            cumulative_pct = df_top['Cumulative_Percentage']
            positions = range(len(classes))

            # Create bar chart
            bars = ax.bar(positions, values, alpha=0.7, color='steelblue')
            ax.set_xlabel('ATC2 Classes', fontsize=12)
            ax.set_ylabel(y_label, fontsize=12, color='steelblue')
            ax.tick_params(axis='y', labelcolor='steelblue')
            ax.set_xticks(positions)
            ax.set_xticklabels(classes, rotation=90, ha='center', fontsize=8)

            # Add value labels on top 3 bars
            for bar in list(bars)[:3]:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height,
                       f'{height:.2f}', ha='center', va='bottom', fontsize=8)

            # Create cumulative percentage line on a secondary y-axis
            ax2 = ax.twinx()
            ax2.plot(positions, cumulative_pct, color='red', marker='o', linewidth=2, markersize=3)
            ax2.set_ylabel('Cumulative Percentage (%)', fontsize=12, color='red')
            ax2.tick_params(axis='y', labelcolor='red')
            ax2.set_ylim(0, 100)

            # Add 80% threshold line
            ax2.axhline(y=80, color='red', linestyle='--', alpha=0.7, linewidth=1)
            ax2.text(len(classes)*0.7, 82, '80%', color='red', fontweight='bold')

            # Title
            ax.set_title(f'{geography} - {metric_name}\n(Top {top_n} ATC2 Classes)',
                        fontsize=12, fontweight='bold', pad=15)

        # Create all four Pareto charts
        create_single_pareto(axes[0,0], indiana_units, 'Units', 'INDIANA', top_n)
        create_single_pareto(axes[0,1], national_units, 'Units', 'NATIONAL', top_n)
        create_single_pareto(axes[1,0], indiana_prescriptions, 'Prescriptions', 'INDIANA', top_n)
        create_single_pareto(axes[1,1], national_prescriptions, 'Prescriptions', 'NATIONAL', top_n)

        plt.tight_layout()
        plt.show()

        def print_pareto_summary(df, metric_name, geography):
            """Print how many leading classes are needed to reach 80% and the top 5 classes."""
            total_col = f'Total_{metric_name}'
            print(f"\n{geography} - {metric_name}:")

            total_classes = len(df)
            if total_classes == 0:
                # Nothing to rank; avoids dividing by zero below.
                print("  No ATC2 classes available")
                return

            reached = (df['Cumulative_Percentage'] >= 80).to_numpy()
            if reached.any():
                # argmax on a boolean array gives the position of the first True.
                classes_80 = df.iloc[:int(reached.argmax()) + 1]
            else:
                classes_80 = df

            classes_for_80 = len(classes_80)
            pct_classes_for_80 = (classes_for_80 / total_classes) * 100

            total_value = df[total_col].sum()
            value_80 = classes_80[total_col].sum()
            actual_pct_covered = (value_80 / total_value) * 100 if total_value > 0 else 0.0

            print(f"  Total ATC2 classes: {total_classes}")
            print(f"  Classes needed for ~80%: {classes_for_80} ({pct_classes_for_80:.1f}% of classes)")
            print(f"  Actual coverage: {actual_pct_covered:.1f}%")
            print(f"  Total {metric_name.lower()}: {total_value:.3f}")
            print(f"  Value in top classes: {value_80:.3f}")
            print(f"  Top 5 classes:")
            for i, (_, row) in enumerate(df.head(5).iterrows(), 1):
                print(f"    {i}. {row['ATC_Class']}: {row[total_col]:.3f} ({row['Percentage']:.1f}%)")

        # Print summaries
        print(f"\n{'='*80}")
        print("PARETO ANALYSIS SUMMARY")
        print(f"{'='*80}")

        print_pareto_summary(indiana_units, 'Units', 'INDIANA')
        print_pareto_summary(national_units, 'Units', 'NATIONAL')
        print_pareto_summary(indiana_prescriptions, 'Prescriptions', 'INDIANA')
        print_pareto_summary(national_prescriptions, 'Prescriptions', 'NATIONAL')

        # Compare top classes
        print(f"\n{'='*80}")
        print("KEY INSIGHTS - Top 10 Classes Comparison")
        print(f"{'='*80}")

        in_top_units = set(indiana_units.head(10)['ATC_Class'])
        nat_top_units = set(national_units.head(10)['ATC_Class'])
        in_top_presc = set(indiana_prescriptions.head(10)['ATC_Class'])
        nat_top_presc = set(national_prescriptions.head(10)['ATC_Class'])

        units_overlap = in_top_units & nat_top_units
        presc_overlap = in_top_presc & nat_top_presc

        print(f"\nUnits Reimbursed - Common classes: {len(units_overlap)}/10")
        print(f"  Shared: {sorted(units_overlap)}")
        print(f"  Only Indiana: {sorted(in_top_units - nat_top_units)}")
        print(f"  Only National: {sorted(nat_top_units - in_top_units)}")

        print(f"\nPrescriptions - Common classes: {len(presc_overlap)}/10")
        print(f"  Shared: {sorted(presc_overlap)}")
        print(f"  Only Indiana: {sorted(in_top_presc - nat_top_presc)}")
        print(f"  Only National: {sorted(nat_top_presc - in_top_presc)}")

        return {
            'indiana_units': indiana_units,
            'indiana_prescriptions': indiana_prescriptions,
            'national_units': national_units,
            'national_prescriptions': national_prescriptions
        }

In [47]:
# Run the per-year pipeline for 2016. Order matters: adding_key() raises unless
# clean_sdud_data() has populated the cleaned frame first.
processor = NDCATCProcessor(year=2016)
processor.clean_sdud_data()           # Clean SDUD data: drop rows missing units/prescriptions and rows with State == 'XX'
processor.adding_key()                # Add record_id key (State_Year_Quarter_UtilizationType_NDC)
#processor.generate_ndc_txt()         # Generate NDC text file (disabled; requires adding_key() to have run)
# presumably joins the SDUD records to the ATC4 class file (see the "Merged:" line in the output) — confirm in the method
processor.analyze_atc4_mapping() 
# presumably attaches ATC4/ATC3/ATC2 names, using the on-disk cache reported in the output — confirm in the method
processor.fetch_atc_names() 
#processor.export_unscaled_data()

processor.prepare_final_dataframe()   
# Kept as the last bare expression: the notebook displays its return value, which in the
# recorded output is the path of the written merged_NEWdata_2016.csv.
processor.export_merged_data()  
  

Reading CSV: c:\Users\lholguin\OneDrive - purdue.edu\VS code\Data\SDUD\SDUD2016.csv
Initial rows: 4,702,214
After cleaning: 2,291,493 rows, 28,875 unique NDCs
Created 2,291,493 record IDs
ATC4 file: 3,431,883 rows, 19,956 unique NDCs
Merged: 2,291,493 records, 1,818,174 with ATC4 (79.3%)
Missing: 473,319 records, 8,919 unique NDCs

FETCHING ATC CLASS NAMES
Using cache: c:\Users\lholguin\OneDrive - purdue.edu\VS code\Data\ATC\cache_files\atc_names_cache

Creating ATC3 and ATC2 columns from ATC4...

Unique codes to fetch:
  ATC4: 589
  ATC3: 208
  ATC2: 89

Fetching ATC4 names...
Fetching ATC3 names...
Fetching ATC2 names...

Total processing time: 0.0 minutes

Applying names to dataframe...

ATC names added successfully!

Sample output:
           NDC                   record_id ATC4 Class                                  ATC4_Name ATC3 Class                                              ATC3_Name ATC2 Class               ATC2_Name
0  00002143380  AK_2016_4_FFSU_00002143380      A10BJ  G

'c:\\Users\\lholguin\\OneDrive - purdue.edu\\VS code\\Data\\ATC\\merged_data\\merged_NEWdata_2016.csv'

In [None]:
# Sanity check: compare the NDC coverage of the ATC4 class files built with and without the record key.
# The year is defined once so the file paths and the labels printed below cannot drift apart
# (the labels previously said 2024 while the 2020 files were loaded).
overlap_year = 2020
atc4_dir = rf'C:\Users\{user}\OneDrive - purdue.edu\VS code\Data\ATC\ATC4_classes'
nokey_path = rf'{atc4_dir}\Classes_notgood\NDCf_{overlap_year}_ATC4_classes.csv'
keyed_path = rf'{atc4_dir}\NDCNEW_{overlap_year}_ATC4_classes.csv'

# Load everything as text so leading zeros in the NDC codes survive.
keyed = pd.read_csv(keyed_path, dtype=str)
nokey = pd.read_csv(nokey_path, dtype=str)

# Normalize NDCs (remove hyphens, pad to 11 digits)
for ndc_frame in (keyed, nokey):
    ndc_frame["NDC"] = ndc_frame["NDC"].str.replace("-", "", regex=False).str.zfill(11)


def count_mapped_ndcs(ndc_frame):
    """Number of distinct NDCs that have at least one non-null ATC4 class."""
    return ndc_frame.loc[ndc_frame["ATC4 Class"].notna(), "NDC"].nunique()


# --- Summary stats ---
summary = {
    "File": [
        f"With key (NDCNEW_{overlap_year}_ATC4_classes)",
        f"Without key (NDCf_{overlap_year}_ATC4_classes)",
    ],
    "Total rows": [len(keyed), len(nokey)],
    "Unique NDCs": [keyed["NDC"].nunique(), nokey["NDC"].nunique()],
    # Counted per distinct NDC, not per row, so the figure is comparable with "Unique NDCs"
    # even when a file holds several rows for the same NDC.
    "Mapped NDCs (non-null ATC)": [count_mapped_ndcs(keyed), count_mapped_ndcs(nokey)],
}
summary_df = pd.DataFrame(summary)

# --- Compare overlap of unique NDCs ---
ndc_keyed = set(keyed["NDC"].unique())
ndc_nokey = set(nokey["NDC"].unique())

overlap_ndcs = len(ndc_keyed & ndc_nokey)
only_in_nokey = len(ndc_nokey - ndc_keyed)
only_in_keyed = len(ndc_keyed - ndc_nokey)
# Share of the without-key NDCs that also appear in the keyed file; 0 when that file is empty.
percent_overlap = overlap_ndcs / len(ndc_nokey) * 100 if ndc_nokey else 0.0

comparison = pd.DataFrame({
    "Metric": ["Overlap NDCs", "Only in without-key file", "Only in with-key file", "Percent overlap"],
    "Value": [overlap_ndcs, only_in_nokey, only_in_keyed, percent_overlap]
})

print("\n=== Summary of Each File ===")
print(summary_df.to_string(index=False))

print("\n=== NDC Overlap Comparison ===")
print(comparison.to_string(index=False))

In [None]:
# Indiana only
# years_list is only consumed by the commented-out export calls below.
years_list = [2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024]

# One-time exports, left disabled. The two calls further down read
# ATC2_Cumulative_Analysis_with_NDC_Counts.xlsx and ATC2_Cumulative_Analysis_IN_with_NDC_Counts.xlsx,
# so these exports must have been run at least once beforehand.
#Generating only one excel file for all the states
#output_path = NDCATC_overview.export_cumulative_frequency_excel(years_list, by_state=False)  # This combines all states into one analysis
#NDCATC_overview.export_cumulative_frequency_excel(years_list, level='ATC2', by_state=True, state_filter=['IN'])  # Indiana-only workbook; use level='ATC4' for the ATC4 version — TODO confirm keyword names against the method signature

print("\nComparing Indiana vs National...")
# Returns None (after printing an error) when the workbooks are missing.
ind_vs_national = NDCATC_overview.compare_cumulative_80_analysis()

# Shows the 2x2 Pareto figure and returns the four loaded sheets as DataFrames.
ndc_pareto_results = NDCATC_overview.create_pareto_charts()


Statistical analysis.
This section analyzes the unscaled data (the `MUD_<year>.csv` files).

In [64]:
class NDCATC_ind:  #Creating a new class to analyze correlations
    """Covariance / correlation reports for 'Units Reimbursed' vs 'Number of Prescriptions'.

    Every public method is static, reads the per-year merged CSV files found
    under ``base_path`` and prints its report to stdout.
    """

    @staticmethod
    def _class_name(subset, fallback=''):
        """Return the first non-null 'ATC2_Name' in ``subset``, or ``fallback``.

        ``fallback`` is returned when the column is absent or entirely null, so
        callers never receive NaN (which cannot be sliced or used as a group key).
        """
        if 'ATC2_Name' not in subset.columns:
            return fallback
        names = subset['ATC2_Name'].dropna()
        return names.iloc[0] if not names.empty else fallback

    @staticmethod
    def covariance_look(years_list, base_path=None, min_records=25):
        """Print within-class, between-class and pooled covariance of units vs prescriptions.

        For each year, ``ATC\\merged_data\\unscaled_data\\MUD_{year}.csv`` is read from
        ``base_path``; rows without an 'ATC4 Class' are discarded and both variables
        are divided by 1e3 before any covariance is computed.

        Args:
            years_list: Iterable of years to load.
            base_path: Data root; defaults to the OneDrive folder of the global ``user``.
            min_records: Minimum number of rows an ATC2 class needs in a year to be used.

        Returns:
            ``(within_class_df, between_class_df)``: one row per (year, ATC2 class) and
            one row per year respectively. Two empty DataFrames are returned when no
            class-year reached ``min_records``.
        """
        if base_path is None:
            base_path = rf"c:\Users\{user}\OneDrive - purdue.edu\VS code\Data"

        units_col = 'Units Reimbursed'
        rx_col = 'Number of Prescriptions'
        rule = "=" * 70

        print("COVARIANCE ANALYSIS: Units Reimbursed vs Number of Prescriptions")
        print("NATIONAL ANALYSIS - By ATC Class")
        print(rule)

        all_within_class_cov = []
        all_data_for_between = {}

        # Load and process all data
        for year in years_list:
            print(f"Processing {year}...", end=" ")
            try:
                csv_path = os.path.join(base_path, f"ATC\\merged_data\\unscaled_data\\MUD_{year}.csv")
                df_merged = pd.read_csv(csv_path)

                records = df_merged[df_merged['ATC4 Class'].notna()].copy()
                if records.empty:
                    print("No valid records")
                    continue

                # Express both variables in thousands (divide by 1e3) to keep
                # the covariance magnitudes readable.
                records[units_col] = records[units_col] / 1e3
                records[rx_col] = records[rx_col] / 1e3

                print(f"✓ ({len(records):,} records)")
                # Only these three columns are needed later; dropping the rest keeps
                # memory bounded when many multi-million-row years are held at once.
                all_data_for_between[year] = records[['ATC2 Class', units_col, rx_col]]

                # Calculate within-class covariances for this year. sort=False keeps
                # the classes in order of first appearance.
                for atc2, subset in records.groupby('ATC2 Class', sort=False):
                    if len(subset) < min_records:
                        continue
                    try:
                        covariance = np.cov(subset[units_col], subset[rx_col])[0, 1]
                        all_within_class_cov.append({
                            'Year': year,
                            'ATC2_Class': atc2,
                            'ATC2_Name': NDCATC_ind._class_name(subset),
                            'N_Records': len(subset),
                            'Within_Class_Covariance': covariance,
                            'Mean_Units': subset[units_col].mean(),
                            'Mean_Prescriptions': subset[rx_col].mean()
                        })
                    except Exception as e:
                        print(f"\nError processing {atc2}: {e}")
            except Exception as e:
                # Best effort: a year that cannot be read or processed is reported and skipped.
                print(f"✗ Error: {e}")

        within_class_df = pd.DataFrame(all_within_class_cov)

        if within_class_df.empty:
            print("No within-class covariance data collected!")
            return pd.DataFrame(), pd.DataFrame()

        # ==================== WITHIN-CLASS ANALYSIS ====================
        within_cov = within_class_df['Within_Class_Covariance']
        print(f"\n{rule}\n1. WITHIN-CLASS COVARIANCE ANALYSIS\n{rule}")
        print(f"Total class-year combinations: {len(within_class_df):,}")
        print(f"Years analyzed: {sorted(within_class_df['Year'].unique())}")
        print(f"Unique ATC2 classes: {len(within_class_df['ATC2_Class'].unique())}")

        print(f"\nOverall Within-Class Covariance Statistics:")
        print(f"  Mean: {within_cov.mean():,.10f}")
        print(f"  Median: {within_cov.median():,.10f}")
        print(f"  Std Dev: {within_cov.std():,.10f}")
        print(f"  Range: {within_cov.min():,.10f} to {within_cov.max():,.10f}")

        # Aggregate by ATC class across all years (values are rounded to 2 decimals
        # before they are printed).
        atc_within_summary = within_class_df.groupby(['ATC2_Class', 'ATC2_Name']).agg({
            'Within_Class_Covariance': ['mean', 'std'],
            'N_Records': 'sum'
        }).round(2)
        atc_within_summary.columns = ['Avg_Within_Cov', 'Std_Within_Cov', 'Total_Records']
        atc_within_summary = atc_within_summary.sort_values('Avg_Within_Cov', ascending=False)

        print(f"\n{rule}\nWITHIN-CLASS COVARIANCE BY ATC2 CLASS (Averaged across years)\n{rule}")
        print(f"{'ATC2':<5} {'Name':<30} {'Avg Cov':>15} {'Std':>12} {'Total N':>10}")
        print("-" * 80)
        for (atc_class, atc_name), row in atc_within_summary.head(15).iterrows():
            name_short = atc_name[:28] if atc_name else atc_class
            print(f"{atc_class:<5} {name_short:<30} {row['Avg_Within_Cov']:>15,.10f} {row['Std_Within_Cov']:>12,.10f} {row['Total_Records']:>10,.0f}")

        # ==================== BETWEEN-CLASS ANALYSIS ====================
        print(f"\n{rule}\n2. BETWEEN-CLASS COVARIANCE ANALYSIS\n{rule}")

        between_class_results = []

        for year in sorted(all_data_for_between):
            year_data = all_data_for_between[year]

            # Class means for this year. Only the two numeric columns are aggregated,
            # so a file without 'ATC2_Name' no longer fails here.
            class_means = year_data.groupby('ATC2 Class')[[units_col, rx_col]].mean().reset_index()

            # Filter classes with sufficient records
            class_counts = year_data['ATC2 Class'].value_counts()
            valid_classes = class_counts[class_counts >= min_records].index
            class_means = class_means[class_means['ATC2 Class'].isin(valid_classes)]

            if len(class_means) < 2:
                print(f"{year}: Not enough classes for between-class analysis")
                continue

            # Covariance of the per-class means
            between_cov = np.cov(class_means[units_col], class_means[rx_col])[0, 1]

            between_class_results.append({
                'Year': year,
                'Between_Class_Covariance': between_cov,
                'N_Classes': len(class_means),
                'Total_Records': len(year_data)
            })

            print(f"Year {year}: Between-Class Cov = {between_cov:>15,.10f} (across {len(class_means)} classes, {len(year_data):,} records)")

        between_class_df = pd.DataFrame(between_class_results)

        if not between_class_df.empty:
            between_cov_series = between_class_df['Between_Class_Covariance']
            print(f"\n{rule}\nBETWEEN-CLASS COVARIANCE SUMMARY\n{rule}")
            print(f"Average across years: {between_cov_series.mean():,.10f}")
            print(f"Median: {between_cov_series.median():,.10f}")
            print(f"Std Dev: {between_cov_series.std():,.10f}")
            print(f"Range: {between_cov_series.min():,.10f} to {between_cov_series.max():,.10f}")

            # Overall covariance (all data pooled across all classes and years)
            print(f"\n{rule}\n3. OVERALL COVARIANCE (All classes and years combined)\n{rule}")
            all_combined = pd.concat(all_data_for_between.values(), ignore_index=True)
            overall_cov = np.cov(all_combined[units_col], all_combined[rx_col])[0, 1]
            print(f"Overall Covariance: {overall_cov:,.10f} ({len(all_combined):,} total records)")

            # Comparison
            print(f"\n{rule}\nCOMPARISON\n{rule}")
            avg_within = within_cov.mean()
            avg_between = between_cov_series.mean()
            print(f"Average Within-Class Covariance:  {avg_within:>15,.10f}")
            print(f"Average Between-Class Covariance: {avg_between:>15,.10f}")
            print(f"Overall Covariance:               {overall_cov:>15,.10f}")
            # The text is formatted separately: applying a float format spec to the
            # string 'undefined' raised ValueError whenever avg_within was 0.
            if avg_within != 0:
                ratio_text = f"{avg_between / avg_within:>15.4f}"
            else:
                ratio_text = f"{'undefined':>15}"
            print(f"\nRatio (Between/Within):           {ratio_text}")

        return within_class_df, between_class_df

    @staticmethod
    def analyze_correlation_by_state_atc(years_list, base_path=None, min_records=25):
        """Print Pearson and Spearman correlations per ATC2 class for Indiana rows.

        For each year, ``ATC\\merged_data\\merged_NEWdata_{year}.csv`` is read from
        ``base_path`` and restricted to rows with an 'ATC4 Class' and ``State == 'IN'``.

        Args:
            years_list: Iterable of years to load.
            base_path: Data root; defaults to the OneDrive folder of the global ``user``.
            min_records: Minimum number of rows an ATC2 class needs in a year to be used.

        Returns:
            DataFrame with one row per (year, ATC2 class) holding both coefficients and
            their p-values; an empty DataFrame when nothing qualified.
        """
        if base_path is None:
            base_path = rf"c:\Users\{user}\OneDrive - purdue.edu\VS code\Data"

        units_col = 'Units Reimbursed'
        rx_col = 'Number of Prescriptions'
        rule = "=" * 70

        print("CORRELATION ANALYSIS: Units Reimbursed vs Number of Prescriptions")
        print("INDIANA ONLY - By ATC Class")
        print(rule)

        all_correlations = []

        for year in years_list:
            print(f"Processing {year}...", end=" ")
            try:
                csv_path = os.path.join(base_path, f"ATC\\merged_data\\merged_NEWdata_{year}.csv")
                df_merged = pd.read_csv(csv_path)

                records = df_merged[(df_merged['ATC4 Class'].notna()) & (df_merged['State'] == 'IN')].copy()
                if records.empty:
                    print("No valid Indiana records")
                    continue

                print(f"✓ ({len(records):,} records)")

                # sort=False keeps the classes in order of first appearance.
                for atc2, subset in records.groupby('ATC2 Class', sort=False):
                    if len(subset) < min_records:
                        continue
                    try:
                        pearson_r, pearson_p = pearsonr(subset[units_col], subset[rx_col])
                        spearman_r, spearman_p = spearmanr(subset[units_col], subset[rx_col])

                        all_correlations.append({
                            'Year': year, 'State': 'IN', 'ATC2_Class': atc2,
                            'ATC2_Name': NDCATC_ind._class_name(subset),
                            'N_Records': len(subset), 'Pearson_r': pearson_r, 'Pearson_p': pearson_p,
                            'Spearman_r': spearman_r, 'Spearman_p': spearman_p
                        })
                    except Exception as e:
                        print(f"\nError processing IN-{atc2}: {e}")
            except Exception as e:
                # Best effort: a year that cannot be read or processed is reported and skipped.
                print(f"✗ Error: {e}")

        correlations_df = pd.DataFrame(all_correlations)
        if correlations_df.empty:
            print("No correlation data collected!")
            return pd.DataFrame()

        # SUMMARY STATISTICS
        print(f"\n{rule}\nSUMMARY STATISTICS FOR INDIANA\n{rule}")
        print(f"Total combinations: {len(correlations_df):,} | Years: {sorted(correlations_df['Year'].unique())} | ATC classes: {len(correlations_df['ATC2_Class'].unique())}")

        # PEARSON CORRELATION RESULTS
        pearson = correlations_df['Pearson_r']
        print(f"\n{rule}\nPEARSON CORRELATION RESULTS\n{rule}")
        print(f"Average: {pearson.mean():.4f} | Range: {pearson.min():.4f} to {pearson.max():.4f} | Std Dev: {pearson.std():.4f}")

        atc_pearson = correlations_df.groupby(['ATC2_Class', 'ATC2_Name']).agg({
            'Pearson_r': ['mean', 'std'], 'Pearson_p': 'mean', 'N_Records': 'sum'
        }).round(4)
        atc_pearson.columns = ['Avg_Pearson', 'Std_Pearson', 'Avg_P_Value', 'Total_Records']
        atc_pearson = atc_pearson.sort_values('Avg_Pearson', ascending=False)

        print(f"\nPEARSON BY ATC CLASS (Average across years):")
        print(f"{'ATC2':<5} {'Name':<25} {'Avg Pearson':<12} {'Std':<8} {'p-val':<8} {'Total N':<8}\n{'-' * 75}")
        for (atc_class, atc_name), row in atc_pearson.iterrows():
            name_short = atc_name[:23] if atc_name else atc_class
            print(f"{atc_class:<5} {name_short:<25} {row['Avg_Pearson']:<12.4f} {row['Std_Pearson']:<8.4f} {row['Avg_P_Value']:<8.4f} {row['Total_Records']:<8.0f}")

        # SPEARMAN CORRELATION RESULTS
        spearman = correlations_df['Spearman_r']
        print(f"\n{rule}\nSPEARMAN RANK CORRELATION RESULTS\n{rule}")
        print(f"Average: {spearman.mean():.4f} | Range: {spearman.min():.4f} to {spearman.max():.4f} | Std Dev: {spearman.std():.4f}")

        atc_spearman = correlations_df.groupby(['ATC2_Class', 'ATC2_Name']).agg({
            'Spearman_r': ['mean', 'std'], 'Spearman_p': 'mean', 'N_Records': 'sum'
        }).round(4)
        atc_spearman.columns = ['Avg_Spearman', 'Std_Spearman', 'Avg_P_Value', 'Total_Records']
        atc_spearman = atc_spearman.sort_values('Avg_Spearman', ascending=False)

        print(f"\nSPEARMAN BY ATC CLASS (Average across years):")
        print(f"{'ATC2':<5} {'Name':<25} {'Avg Spearman':<12} {'Std':<8} {'p-val':<8} {'Total N':<8}\n{'-' * 75}")
        for (atc_class, atc_name), row in atc_spearman.iterrows():
            name_short = atc_name[:23] if atc_name else atc_class
            print(f"{atc_class:<5} {name_short:<25} {row['Avg_Spearman']:<12.4f} {row['Std_Spearman']:<8.4f} {row['Avg_P_Value']:<8.4f} {row['Total_Records']:<8.0f}")

        # YEAR-OVER-YEAR TRENDS
        print(f"\n{rule}")
        print("YEAR-OVER-YEAR TRENDS (Top 3 by Pearson)")
        print(f"{rule}")

        top_classes = atc_pearson.head(3).index.get_level_values(0).tolist()
        for atc_class in top_classes:
            atc_data = correlations_df[correlations_df['ATC2_Class'] == atc_class].sort_values('Year')
            atc_name = atc_data['ATC2_Name'].iloc[0] if not atc_data.empty else atc_class

            print(f"\n{atc_class} - {atc_name[:30]}:")
            print(f"{'Year':<6} {'Pearson':<8} {'Spearman':<9} {'N':<6}")
            print("-" * 35)
            # itertuples keeps each column's own dtype (iterrows would upcast a row
            # to a single common dtype).
            for rec in atc_data.itertuples(index=False):
                print(f"{rec.Year:<6} {rec.Pearson_r:<8.4f} {rec.Spearman_r:<9.4f} {rec.N_Records:<6}")

        return correlations_df

    @staticmethod
    def plot_units_vs_prescriptions_by_atc(years_list, base_path=None, min_records=25, include_negative=True):
        """Scatter-plot units vs prescriptions per ATC2 class for Indiana, all years pooled.

        Args:
            years_list: Iterable of years to load (``merged_NEWdata_{year}.csv``).
            base_path: Data root; defaults to the OneDrive folder of the global ``user``.
            min_records: Minimum pooled row count for a class to be considered.
            include_negative: If True, plot the 8 most positively and up to 4 most
                negatively correlated classes; otherwise the 12 classes with most rows.

        Returns:
            The pooled rows belonging to the plotted classes, or ``None`` when no
            year could be loaded.
        """
        if base_path is None:
            base_path = rf"c:\Users\{user}\OneDrive - purdue.edu\VS code\Data"

        units_col = 'Units Reimbursed'
        rx_col = 'Number of Prescriptions'

        print("Creating plots for Indiana ATC classes...")

        # Combine all years of data
        all_data = []
        for year in years_list:
            try:
                csv_path = os.path.join(base_path, f"ATC\\merged_data\\merged_NEWdata_{year}.csv")
                df_merged = pd.read_csv(csv_path)
                records = df_merged[(df_merged['ATC4 Class'].notna()) & (df_merged['State'] == 'IN')].copy()
                records['Year'] = year
                all_data.append(records)
            except Exception as e:
                print(f"Error loading {year}: {e}")

        if not all_data:
            print("No data loaded!")
            return

        combined_df = pd.concat(all_data, ignore_index=True)

        # Classes with enough pooled rows (value_counts orders them by row count)
        atc_counts = combined_df['ATC2 Class'].value_counts()
        sufficient_data_classes = atc_counts[atc_counts >= min_records].index

        # Slice each class once and compute its correlation once; both are reused
        # for selection, the plot annotation and the summary table.
        class_subsets = {
            atc_class: combined_df[combined_df['ATC2 Class'] == atc_class]
            for atc_class in sufficient_data_classes
        }
        class_correlations = {
            atc_class: subset[rx_col].corr(subset[units_col])
            for atc_class, subset in class_subsets.items()
            if len(subset) > 1
        }

        # Select classes to plot
        if include_negative:
            # NaN correlations satisfy neither comparison and are therefore left out.
            positive_corrs = {k: v for k, v in class_correlations.items() if v >= 0}
            negative_corrs = {k: v for k, v in class_correlations.items() if v < 0}

            # Strongest positive first / most negative first
            positive_sorted = sorted(positive_corrs.items(), key=lambda x: x[1], reverse=True)
            negative_sorted = sorted(negative_corrs.items(), key=lambda x: x[1])

            # Top 8 positive plus at most 4 negative
            selected_positive = [x[0] for x in positive_sorted[:8]]
            selected_negative = [x[0] for x in negative_sorted[:4]]

            valid_atc_classes = selected_positive + selected_negative

            print(f"\nSelected classes: {len(selected_positive)} positive correlations + {len(selected_negative)} negative correlations")
            if selected_negative:
                print(f"Negative correlation classes: {selected_negative}")
        else:
            # Original behavior - top classes by count
            valid_atc_classes = list(sufficient_data_classes[:12])

        n_classes = len(valid_atc_classes)
        if n_classes == 0:
            # Nothing to draw: skip the all-hidden blank figure.
            print("No ATC2 classes with enough records to plot!")
            return combined_df.iloc[0:0]

        # Grid: 2x3 up to 6 classes, 3x3 up to 9, 3x4 up to 12 (4 columns beyond that).
        cols = 3 if n_classes <= 9 else 4
        rows = max(2, -(-n_classes // cols))  # integer ceiling division

        # Set up the plot grid
        fig, axes = plt.subplots(rows, cols, figsize=(cols*5, rows*4))
        axes = axes.flatten()

        colors = plt.cm.Set3(np.linspace(0, 1, n_classes))

        for ax, color, atc_class in zip(axes, colors, valid_atc_classes):
            subset = class_subsets[atc_class]
            # First non-null name; a NaN first row used to break the title slice.
            atc_name = NDCATC_ind._class_name(subset, fallback=atc_class)

            # Create scatter plot
            ax.scatter(subset[rx_col], subset[units_col], alpha=0.6, color=color, s=20)

            # Add trend line, fitted on complete pairs only and skipped when x has
            # no spread (polyfit is ill-conditioned there and fails on NaN input).
            fit_points = subset[[rx_col, units_col]].dropna()
            if len(fit_points) > 1 and fit_points[rx_col].nunique() > 1:
                slope, intercept = np.polyfit(fit_points[rx_col], fit_points[units_col], 1)
                ax.plot(fit_points[rx_col], slope * fit_points[rx_col] + intercept,
                        "r--", alpha=0.8, linewidth=1)

            # Format axes
            ax.set_xlabel('Number of Prescriptions')
            ax.set_ylabel('Units Reimbursed')
            ax.set_title(f'{atc_class}\n{str(atc_name)[:30]}', fontsize=10)
            ax.grid(True, alpha=0.3)

            # Add correlation coefficient with color coding
            if len(subset) > 1:
                corr = class_correlations.get(atc_class, float('nan'))
                color_code = 'red' if corr < 0 else 'blue'
                ax.text(0.05, 0.95, f'r = {corr:.3f}',
                        transform=ax.transAxes,
                        bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8, edgecolor=color_code),
                        fontsize=9, color=color_code)

        # Hide unused subplots
        for ax in axes[n_classes:]:
            ax.set_visible(False)

        plt.tight_layout()
        title_suffix = " (Including Negative Correlations)" if include_negative else ""
        plt.suptitle(f'Indiana: Units Reimbursed vs Number of Prescriptions by ATC2 Class{title_suffix}\n(All Years Combined)',
                     fontsize=16, y=1.02)
        plt.show()

        # Summary table
        print(f"\n{'='*80}")
        print("PLOT SUMMARY - INDIANA ATC CLASSES")
        print(f"{'='*80}")
        print(f"{'ATC2':<5} {'Name':<30} {'Records':<8} {'Correlation':<12} {'Type':<8}")
        print("-" * 90)

        for atc_class in valid_atc_classes:
            subset = class_subsets[atc_class]
            atc_name = NDCATC_ind._class_name(subset, fallback=atc_class)
            corr = class_correlations.get(atc_class, float('nan'))
            corr_type = "Negative" if corr < 0 else "Positive"

            print(f"{atc_class:<5} {str(atc_name)[:28]:<30} {len(subset):<8} {corr:<12.4f} {corr_type:<8}")

        return combined_df[combined_df['ATC2 Class'].isin(valid_atc_classes)]

In [None]:
def correlation_picture(years_list, base_path=None, min_records=25):
    """Print within-class, between-class and pooled Pearson correlation of units vs prescriptions.

    This is a module-level function, so it is defined without ``@staticmethod``:
    a bare staticmethod object at module level is not callable before Python 3.10.

    For each year, ``ATC\\merged_data\\merged_NEWdata_{year}.csv`` is read from
    ``base_path`` and rows without an 'ATC4 Class' are discarded.

    Args:
        years_list: Iterable of years to load.
        base_path: Data root; defaults to the OneDrive folder of the global ``user``.
        min_records: Minimum number of rows an ATC2 class needs in a year to be used.

    Returns:
        ``(within_class_df, between_class_df)``: one row per (year, ATC2 class) and
        one row per year respectively. Two empty DataFrames are returned when no
        class-year reached ``min_records``.
    """
    if base_path is None:
        base_path = rf"c:\Users\{user}\OneDrive - purdue.edu\VS code\Data"

    units_col = 'Units Reimbursed'
    rx_col = 'Number of Prescriptions'
    rule = "=" * 70

    print("CORRELATION ANALYSIS: Units Reimbursed vs Number of Prescriptions")
    print("NATIONAL ANALYSIS - By ATC Class")
    print(rule)

    all_within_class_corr = []
    all_data_for_between = {}

    # Load and process all data
    for year in years_list:
        print(f"Processing {year}...", end=" ")
        try:
            csv_path = os.path.join(base_path, f"ATC\\merged_data\\merged_NEWdata_{year}.csv")
            df_merged = pd.read_csv(csv_path)

            records = df_merged[df_merged['ATC4 Class'].notna()].copy()
            if records.empty:
                print("No valid records")
                continue

            print(f"✓ ({len(records):,} records)")
            # Only these three columns are needed for the between-class and pooled steps.
            all_data_for_between[year] = records[['ATC2 Class', units_col, rx_col]]

            has_names = 'ATC2_Name' in records.columns

            # Within-class correlations for this year; sort=False keeps the classes
            # in order of first appearance.
            for atc2, subset in records.groupby('ATC2 Class', sort=False):
                if len(subset) < min_records:
                    continue
                try:
                    correlation = subset[units_col].corr(subset[rx_col])

                    # First non-null name, '' when absent: a NaN name would be
                    # dropped silently by the grouped summary below.
                    names = subset['ATC2_Name'].dropna() if has_names else None
                    atc2_name = names.iloc[0] if names is not None and not names.empty else ''

                    all_within_class_corr.append({
                        'Year': year,
                        'ATC2_Class': atc2,
                        'ATC2_Name': atc2_name,
                        'N_Records': len(subset),
                        'Within_Class_Correlation': correlation,
                        'Mean_Units': subset[units_col].mean(),
                        'Mean_Prescriptions': subset[rx_col].mean()
                    })
                except Exception as e:
                    print(f"\nError processing {atc2}: {e}")
        except Exception as e:
            # Best effort: a year that cannot be read or processed is reported and skipped.
            print(f"✗ Error: {e}")

    within_class_df = pd.DataFrame(all_within_class_corr)

    if within_class_df.empty:
        print("No within-class correlation data collected!")
        return pd.DataFrame(), pd.DataFrame()

    # ==================== WITHIN-CLASS ANALYSIS ====================
    within_corr = within_class_df['Within_Class_Correlation']
    print(f"\n{rule}\n1. WITHIN-CLASS CORRELATION ANALYSIS\n{rule}")
    print(f"Total class-year combinations: {len(within_class_df):,}")
    print(f"Years analyzed: {sorted(within_class_df['Year'].unique())}")
    print(f"Unique ATC2 classes: {len(within_class_df['ATC2_Class'].unique())}")

    print(f"\nOverall Within-Class Correlation Statistics:")
    print(f"  Mean: {within_corr.mean():.4f}")
    print(f"  Median: {within_corr.median():.4f}")
    print(f"  Std Dev: {within_corr.std():.4f}")
    print(f"  Range: {within_corr.min():.4f} to {within_corr.max():.4f}")

    # Aggregate by ATC class across all years
    atc_within_summary = within_class_df.groupby(['ATC2_Class', 'ATC2_Name']).agg({
        'Within_Class_Correlation': ['mean', 'std'],
        'N_Records': 'sum'
    }).round(4)
    atc_within_summary.columns = ['Avg_Within_Corr', 'Std_Within_Corr', 'Total_Records']
    atc_within_summary = atc_within_summary.sort_values('Avg_Within_Corr', ascending=False)

    print(f"\n{rule}\nWITHIN-CLASS CORRELATION BY ATC2 CLASS (Averaged across years)\n{rule}")
    print(f"{'ATC2':<5} {'Name':<30} {'Avg Corr':>12} {'Std':>10} {'Total N':>10}")
    print("-" * 75)
    for (atc_class, atc_name), row in atc_within_summary.head(15).iterrows():
        name_short = atc_name[:28] if atc_name else atc_class
        print(f"{atc_class:<5} {name_short:<30} {row['Avg_Within_Corr']:>12.4f} {row['Std_Within_Corr']:>10.4f} {row['Total_Records']:>10,.0f}")

    # ==================== BETWEEN-CLASS ANALYSIS ====================
    print(f"\n{rule}\n2. BETWEEN-CLASS CORRELATION ANALYSIS\n{rule}")

    between_class_results = []

    for year in sorted(all_data_for_between):
        year_data = all_data_for_between[year]

        # Class means for this year. Only the two numeric columns are aggregated,
        # so a file without 'ATC2_Name' no longer fails here.
        class_means = year_data.groupby('ATC2 Class')[[units_col, rx_col]].mean().reset_index()

        # Filter classes with sufficient records
        class_counts = year_data['ATC2 Class'].value_counts()
        valid_classes = class_counts[class_counts >= min_records].index
        class_means = class_means[class_means['ATC2 Class'].isin(valid_classes)]

        if len(class_means) < 2:
            print(f"{year}: Not enough classes for between-class analysis")
            continue

        # Correlation of the per-class means
        between_corr = class_means[units_col].corr(class_means[rx_col])

        between_class_results.append({
            'Year': year,
            'Between_Class_Correlation': between_corr,
            'N_Classes': len(class_means),
            'Total_Records': len(year_data)
        })

        print(f"Year {year}: Between-Class Corr = {between_corr:.4f} (across {len(class_means)} classes, {len(year_data):,} records)")

    between_class_df = pd.DataFrame(between_class_results)

    if not between_class_df.empty:
        between_corr_series = between_class_df['Between_Class_Correlation']
        print(f"\n{rule}\nBETWEEN-CLASS CORRELATION SUMMARY\n{rule}")
        print(f"Average across years: {between_corr_series.mean():.4f}")
        print(f"Median: {between_corr_series.median():.4f}")
        print(f"Std Dev: {between_corr_series.std():.4f}")
        print(f"Range: {between_corr_series.min():.4f} to {between_corr_series.max():.4f}")

        # Overall correlation (all data pooled across all classes and years)
        print(f"\n{rule}\n3. OVERALL CORRELATION (All classes and years combined)\n{rule}")
        all_combined = pd.concat(all_data_for_between.values(), ignore_index=True)
        overall_corr = all_combined[units_col].corr(all_combined[rx_col])
        print(f"Overall Correlation: {overall_corr:.4f} ({len(all_combined):,} total records)")

        # Comparison
        print(f"\n{rule}\nCOMPARISON\n{rule}")
        avg_within = within_corr.mean()
        avg_between = between_corr_series.mean()
        print(f"Average Within-Class Correlation:  {avg_within:>12.4f}")
        print(f"Average Between-Class Correlation: {avg_between:>12.4f}")
        print(f"Overall Correlation:               {overall_corr:>12.4f}")

    return within_class_df, between_class_df

In [65]:
# Years 2016 through 2024 inclusive.
years_to_analyze = list(range(2016, 2025))
# covariance_look returns a (within_class_df, between_class_df) tuple.
covarience_results = NDCATC_ind.covariance_look(years_list=years_to_analyze)

COVARIANCE ANALYSIS: Units Reimbursed vs Number of Prescriptions
NATIONAL ANALYSIS - By ATC Class
Processing 2016... ✓ (1,818,174 records)
Processing 2017... ✓ (1,931,088 records)
Processing 2018... ✓ (2,022,672 records)
Processing 2019... ✓ (2,109,684 records)
Processing 2020... ✓ (2,173,775 records)
Processing 2021... ✓ (2,287,508 records)
Processing 2022... ✓ (2,353,181 records)
Processing 2023... ✓ (2,385,896 records)
Processing 2024... ✓ (2,338,667 records)

1. WITHIN-CLASS COVARIANCE ANALYSIS
Total class-year combinations: 782
Years analyzed: [np.int64(2016), np.int64(2017), np.int64(2018), np.int64(2019), np.int64(2020), np.int64(2021), np.int64(2022), np.int64(2023), np.int64(2024)]
Unique ATC2 classes: 88

Overall Within-Class Covariance Statistics:
  Mean: 220.1820741807
  Median: 33.8981366064
  Std Dev: 1,308.0082293363
  Range: -4.2830359326 to 32,557.2624805437

WITHIN-CLASS COVARIANCE BY ATC2 CLASS (Averaged across years)
ATC2  Name                                   Avg 

In [None]:
# Quick inspection of the container types returned by the analysis calls.
# NOTE(review): `corr_results` is not assigned in the cells visible around here —
# confirm an earlier cell defines it, otherwise this raises NameError on a fresh kernel.
print(type(corr_results))
# `covarience_results` comes from NDCATC_ind.covariance_look, which returns a tuple.
print(type(covarience_results))


In [None]:

# Disabled call. NOTE(review): NDCATC_ind as defined above has no
# `analyze_covariance_by_state_atc` method — confirm before re-enabling.
#covarience_results = NDCATC_ind.analyze_covariance_by_state_atc(years_to_analyze)

# Disabled: scatter plots with the default selection (include_negative defaults to True).
#plot_data = NDCATC_ind.plot_units_vs_prescriptions_by_atc(years_to_analyze)
# Indiana-only Pearson/Spearman report; relies on `years_to_analyze` from the earlier cell.
correlation_results = NDCATC_ind.analyze_correlation_by_state_atc(years_to_analyze)
# Disabled: same plots with include_negative passed explicitly.
#plots= NDCATC_ind.plot_units_vs_prescriptions_by_atc(years_to_analyze, include_negative=True)
