In [18]:
import pandas as pd
import numpy as np
import requests
import shelve
import os
from datetime import datetime
import matplotlib.pyplot as plt
import re

In [19]:
# Windows account name used to build the default OneDrive data path.
# Set the NDC_ATC_USER environment variable to override it per machine
# (e.g. "asus" on the personal PC) instead of editing this cell.
user = os.environ.get("NDC_ATC_USER", "lholguin")

In [None]:
#Class with new methods for NDC-ATC analysis
class NDCATCAnalyzer:

    def __init__(self, year, base_path=None):

        self.year = year
        if base_path is None:
            #Lookup the user's base path
            self.base_path = rf"c:\Users\{user}\OneDrive - purdue.edu\VS code\Data"
        else:
            self.base_path = base_path
            
        self.df_cleaned = None
        self.df_merged = None
        self.atc_mapping = None
        self.df_faf = None
        
    def clean_sdud_data(self):
        csv_file = os.path.join(self.base_path, f"SDUD\\SDUD{self.year}.csv")
        print(f"Reading CSV file: {csv_file}")
        
        # Read with NDC as string to preserve leading zeros
        df = pd.read_csv(csv_file, dtype={'NDC': 'object'})
        
        print(f"Total rows in {self.year} before filtering: {len(df)}")
        
        # Remove NA values
        df_filtered = df.dropna(subset=['Units Reimbursed', 'Number of Prescriptions'])
        print(f"Rows after removing NA: {len(df_filtered)}")
        
        # Filter out State='XX'
        df_filtered = df_filtered[df_filtered['State'] != 'XX']
        print(f"Rows after filtering State='XX': {len(df_filtered)}")
        print(f"Unique NDCs: {df_filtered['NDC'].nunique()}")
        
        self.df_cleaned = df_filtered
        return self.df_cleaned
    
    #NEW
    def adding_key(self):
        if self.df_cleaned is None:
            raise ValueError("Must run clean_sdud_data() first")
        
        print("Adding record_id column...")
        
        # Create record_id column
        self.df_cleaned['record_id'] = (
            self.df_cleaned['State'].astype(str) + "_" +
            self.df_cleaned['Year'].astype(str) + "_" +
            self.df_cleaned['Quarter'].astype(str) + "_" +
            self.df_cleaned['Utilization Type'].astype(str) + "_" +
            self.df_cleaned['NDC'].astype(str)
        )
        
        print(f"Created {len(self.df_cleaned)} record IDs")
        print(f"Sample record_id: {self.df_cleaned['record_id'].iloc[0]}")
        
        return self.df_cleaned
    
    def generate_ndc_txt(self, output_filename=None):
        """Step 2: Generate text file with unique NDC values and their record_id keys."""
        if self.df_cleaned is None:
            raise ValueError("Must run clean_sdud_data() first")
        
        if 'record_id' not in self.df_cleaned.columns:
            raise ValueError("Must run adding_key() first to create record_id column")
            
        if output_filename is None:
            output_filename = f"NDCNEW_{self.year}.txt"
        
        output_path = os.path.join(self.base_path, f"ATC\\text_files\\{output_filename}")
        
        # Get unique combinations of NDC and record_id
        unique_pairs = self.df_cleaned[['NDC', 'record_id']].drop_duplicates()
        
        with open(output_path, 'w') as f:
            # Write header
            f.write("NDC\trecord_id\n")
            # Write each unique pair
            for _, row in unique_pairs.iterrows():
                f.write(f"{row['NDC']}\t{row['record_id']}\n")
        
        print(f"Exported to: {output_path}")
        print(f"Unique record_id values: {unique_pairs['record_id'].nunique()}")
        return output_path
    
    def analyze_atc4_mapping(self):
        atc4_path = os.path.join(self.base_path, f"ATC\\ATC4_classes\\NDCNEW_{self.year}_ATC4_classes.csv")
        
        # Read ATC4 mapping
        df_atc4 = pd.read_csv(atc4_path, dtype={'NDC': 'object', 'record_id': 'string'})
        df_atc4['NDC'] = df_atc4['NDC'].str.zfill(11)

        #Print how many unique NDC codes were mapped
        print(f"Total rows in ATC4 mapping file: {len(df_atc4)}")
        print(f"Unique NDCs in ATC4 mapping file: {df_atc4['NDC'].nunique()}")

        # Ensure consistent data types before merge
        self.df_cleaned['record_id'] = self.df_cleaned['record_id'].astype('string')
        self.df_cleaned['NDC'] = self.df_cleaned['NDC'].astype('object')  
        df_atc4['record_id'] = df_atc4['record_id'].astype('string')
        df_atc4['NDC'] = df_atc4['NDC'].astype('object')  

        # Merge ATC4 mapping with cleaned data using record_id
        if self.df_cleaned is None:
            raise ValueError("Must run clean_sdud_data() and adding_key() first")
        
        if 'record_id' not in self.df_cleaned.columns:
            raise ValueError("Must run adding_key() first to create record_id column")
        
        print(f"Merging ATC4 mapping with cleaned data using record_id and NDC...")
        
        # Merge on BOTH record_id AND NDC
        self.atc_mapping = pd.merge(
            self.df_cleaned,
            df_atc4[['record_id', 'NDC', 'ATC4 Class']],  # Include NDC in the selection
            on=['record_id', 'NDC'],  # Merge on both columns
            how='left'
        )
        
        # Print rows of the merged dataframe
        print(f"Merged dataframe rows: {len(self.atc_mapping)}")
        #print(self.atc_mapping.head())
        total_records = len(self.atc_mapping)
        mapped_records = self.atc_mapping['ATC4 Class'].notna().sum()
        print(f"Records with ATC4 mapping: {mapped_records} ({mapped_records/total_records*100:.1f}%)")

        # Identify missing mappings
        missing_records = self.atc_mapping[self.atc_mapping['ATC4 Class'].isna()]
        if len(missing_records) > 0:
            print(f"\nRecords without ATC4 mapping: {len(missing_records)}")
            print(f"Unique NDCs without mapping: {missing_records['NDC'].nunique()}")
        
        return self.atc_mapping

    def select_one_atc_per_record(self, strategy="ndc_mode_then_priority"):
        if self.atc_mapping is None:
            raise ValueError("Must run analyze_atc4_mapping() first")

        df = self.atc_mapping.copy()
        df = df[df['ATC4 Class'].notna()].copy()

        # NDC-level ATC4 frequency (mode by NDC)
        ndc_atc_counts = (
            df.groupby(['NDC','ATC4 Class'])
            .size().rename('ndc_atc_count')
            .reset_index()
        )
        df = df.merge(ndc_atc_counts, on=['NDC','ATC4 Class'], how='left')

        def atc_priority(atc4):
            atc2 = atc4[:3] if isinstance(atc4, str) else ""
            if atc2.startswith('V03'):  # 'Other therapeutic products'
                return 100
            if atc2.startswith('V'):    # 'Various'
                return 80
            # illustrative booster: tune as needed
            boosters = {'A':10,'B':15,'C':12,'D':20,'G':18,'H':14,'J':11,'L':16,'M':17,'N':13,'R':19,'S':21}
            return boosters.get(atc2[:1], 50)

        df['priority_score'] = df['ATC4 Class'].map(atc_priority).fillna(60)

        df = df.sort_values(
            by=['record_id','ndc_atc_count','priority_score','Units Reimbursed','Number of Prescriptions','ATC4 Class'],
            ascending=[True, False, True, False, False, True]
        )

        df_one = df.drop_duplicates(subset='record_id', keep='first').copy()

        # small diagnostic
        total_records = self.atc_mapping['record_id'].nunique()
        kept = len(df_one)
        multi = (self.atc_mapping
                .groupby('record_id')['ATC4 Class']
                .nunique()
                .reset_index(name='n'))
        pct_multi = (multi['n']>1).mean()*100

        print("\n" + "="*60)
        print("SELECT-ONE-ATC PER RECORD_ID (DETERMINISTIC)")
        print("="*60)
        print(f"Unique record_ids (raw): {total_records:,}")
        print(f"Kept record_ids (one per id): {kept:,}")
        print(f"Records with >1 ATC4 candidates: {int((multi['n']>1).sum()):,} ({pct_multi:.1f}%)")
        print("Tie-break order: NDC-mode → priority(list) → Units → Prescriptions → alphabetical")

        self.df_selected = df_one
        return self.df_selected

    def analyze_atc4_distribution(self):

        if self.atc_mapping is None:
            raise ValueError("Must run analyze_atc4_mapping() first")
        
        print(f"\n{'='*60}")
        print("ATC4 CLASSES PER RECORD_ID DISTRIBUTION")
        print(f"{'='*60}")
        
        # Count ATC4 classes per record_id (only valid mappings)
        records_with_mapping = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()].copy()
        
        if len(records_with_mapping) == 0:
            print("No records with valid ATC4 mappings found.")
            return None
        
        # Group by record_id and count unique ATC4 classes
        atc4_per_record = records_with_mapping.groupby('record_id').agg({
            'ATC4 Class': 'nunique',
            'NDC': 'first',  # Get the NDC for reference
            'State': 'first',  # Get the state for reference
            'Year': 'first'    # Get the year for reference
        }).reset_index()
        
        atc4_per_record.columns = ['record_id', 'num_atc4_classes', 'NDC', 'State', 'Year']
        
        # Distribution analysis
        distribution = atc4_per_record['num_atc4_classes'].value_counts().sort_index()
        
        print("Distribution of ATC4 classes per record_id:")
        for classes, count in distribution.items():
            pct = (count / len(atc4_per_record)) * 100
            print(f"  {classes} class(es): {count:,} record_ids ({pct:.1f}%)")
        
        # Show examples of multi-class records
        #multi_class = atc4_per_record[atc4_per_record['num_atc4_classes'] > 1].sort_values('num_atc4_classes', ascending=False)
        
        #if len(multi_class) > 0:
            #print(f"\nTop 10 record_ids with most ATC4 classes:")
            #for _, row in multi_class.head(10).iterrows():
                #record_classes = records_with_mapping[records_with_mapping['record_id'] == row['record_id']]['ATC4 Class'].unique()
                #print(f"  {row['record_id']}: {row['num_atc4_classes']} classes")
                #print(f"    NDC: {row['NDC']}, State: {row['State']}, Year: {row['Year']}")
                #print(f"    Classes: {list(record_classes)}")
                #print()
        
        #return atc4_per_record

    def fetch_atc_names(self, cache_path=None):
        """Fetch ATC class names (ATC4, ATC3, ATC2) from RxNav API."""
        if self.atc_mapping is None:
            raise ValueError("Must run analyze_atc4_mapping() first")
        
        if cache_path is None:
            cache_path = os.path.join(self.base_path, "ATC\\cache_files\\atc_names_cache")
        
        print(f"\n{'='*60}")
        print("FETCHING ATC CLASS NAMES")
        print(f"{'='*60}")
        print(f"Using cache: {cache_path}")
        
        # Get only records with valid ATC4 mappings
        df_with_atc = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()].copy()
        
        # Create ATC3 and ATC2 columns from ATC4
        print("\nCreating ATC3 and ATC2 columns from ATC4...")
        df_with_atc['ATC3 Class'] = df_with_atc['ATC4 Class'].str[:4]
        df_with_atc['ATC2 Class'] = df_with_atc['ATC4 Class'].str[:3]
        
        # Get unique codes for each level
        unique_atc4 = df_with_atc['ATC4 Class'].dropna().unique()
        unique_atc3 = df_with_atc['ATC3 Class'].dropna().unique()
        unique_atc2 = df_with_atc['ATC2 Class'].dropna().unique()
        
        # Filter out invalid codes
        unique_atc4 = [c for c in unique_atc4 if c not in ['No ATC Mapping Found', 'No RxCUI Found', '']]
        unique_atc3 = [c for c in unique_atc3 if c not in ['No ATC Mapping Found', 'No RxCUI Found', '', 'No ', 'No']]
        unique_atc2 = [c for c in unique_atc2 if c not in ['No ATC Mapping Found', 'No RxCUI Found', '', 'No ', 'No']]
        
        print(f"\nUnique codes to fetch:")
        print(f"  ATC4: {len(unique_atc4)}")
        print(f"  ATC3: {len(unique_atc3)}")
        print(f"  ATC2: {len(unique_atc2)}")
        
        # Build mappings
        atc4_names = {}
        atc3_names = {}
        atc2_names = {}
        
        with shelve.open(cache_path) as cache:
            start_time = datetime.now()
            
            print("\nFetching ATC4 names...")
            for code in unique_atc4:
                atc4_names[code] = self._get_atc_name(code, cache)
            
            print("Fetching ATC3 names...")
            for code in unique_atc3:
                atc3_names[code] = self._get_atc_name(code, cache)
            
            print("Fetching ATC2 names...")
            for code in unique_atc2:
                atc2_names[code] = self._get_atc_name(code, cache)
            
            print(f"\nTotal processing time: {(datetime.now() - start_time).total_seconds()/60:.1f} minutes")
        
        # Apply names to all records in atc_mapping
        print("\nApplying names to dataframe...")
        self.atc_mapping['ATC3 Class'] = self.atc_mapping['ATC4 Class'].str[:4]
        self.atc_mapping['ATC2 Class'] = self.atc_mapping['ATC4 Class'].str[:3]
        
        self.atc_mapping['ATC4_Name'] = self.atc_mapping['ATC4 Class'].map(atc4_names).fillna('')
        self.atc_mapping['ATC3_Name'] = self.atc_mapping['ATC3 Class'].map(atc3_names).fillna('')
        self.atc_mapping['ATC2_Name'] = self.atc_mapping['ATC2 Class'].map(atc2_names).fillna('')
        
        print(f"\nATC names added successfully!")
        print("\nSample output:")
        sample = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()][['NDC', 'record_id', 'ATC4 Class', 'ATC4_Name', 'ATC3 Class', 'ATC3_Name', 'ATC2 Class', 'ATC2_Name']].head(5)
        print(sample.to_string())
        
        return self.atc_mapping
    
    def prepare_final_dataframe(self):
        if not hasattr(self, 'df_selected'):
        # ensure selection has run
            self.select_one_atc_per_record()

        print("\n" + "="*60)
        print("PREPARING FINAL DATAFRAME")
        print("="*60)

        self.df_merged = self.df_selected.copy()

        return self.df_merged
    
    def _get_atc_name(self, atc_code, cache):
        """Get ATC class name from code, using cache."""
        cache_key = f"atc_name:{atc_code}"
        if cache_key in cache:
            return cache[cache_key]
        
        try:
            url = f"https://rxnav.nlm.nih.gov/REST/rxclass/class/byId.json?classId={atc_code}"
            response = requests.get(url)
            response.raise_for_status()
            data = response.json()
            
            # Get class name
            if 'rxclassMinConceptList' in data and 'rxclassMinConcept' in data['rxclassMinConceptList']:
                concepts = data['rxclassMinConceptList']['rxclassMinConcept']
                if concepts:
                    name = concepts[0].get('className', '')
                    cache[cache_key] = name
                    return name
            
            cache[cache_key] = ''
            return ''
            
        except Exception as e:
            print(f"Error retrieving name for {atc_code}: {e}")
            cache[cache_key] = ''
            return ''
    
    def analyze_atc3_distribution(self):

        if self.atc_mapping is None:
            raise ValueError("Must run analyze_atc4_mapping() first")
        
        print(f"\n{'='*60}")
        print("ATC3 CLASSES PER RECORD_ID DISTRIBUTION")
        print(f"{'='*60}")
        
        # Create ATC3 classes from ATC4
        records_with_mapping = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()].copy()
        
        if len(records_with_mapping) == 0:
            print("No records with valid ATC4 mappings found.")
            return None
        
        # Create ATC3 class from ATC4 class (first 4 characters)
        records_with_mapping['ATC3 Class'] = records_with_mapping['ATC4 Class'].str[:4]
        
        # Group by record_id and count unique ATC3 classes
        atc3_per_record = records_with_mapping.groupby('record_id').agg({
            'ATC3 Class': 'nunique',
            'NDC': 'first',
            'State': 'first',
            'Year': 'first'
        }).reset_index()
        
        atc3_per_record.columns = ['record_id', 'num_atc3_classes', 'NDC', 'State', 'Year']
        
        # Distribution analysis
        distribution = atc3_per_record['num_atc3_classes'].value_counts().sort_index()
        
        print("Distribution of ATC3 classes per record_id:")
        for classes, count in distribution.items():
            pct = (count / len(atc3_per_record)) * 100
            print(f"  {classes} class(es): {count:,} record_ids ({pct:.1f}%)")
        
        # Show examples of multi-class records
        #multi_class = atc3_per_record[atc3_per_record['num_atc3_classes'] > 1].sort_values('num_atc3_classes', ascending=False)
        
        #if len(multi_class) > 0:
            #print(f"\nTop 10 record_ids with most ATC3 classes:")
            #for _, row in multi_class.head(10).iterrows():
                #record_classes = records_with_mapping[records_with_mapping['record_id'] == row['record_id']]['ATC3 Class'].unique()
                #print(f"  {row['record_id']}: {row['num_atc3_classes']} classes")
                #print(f"    NDC: {row['NDC']}, State: {row['State']}, Year: {row['Year']}")
                #print(f"    Classes: {list(record_classes)}")
                #print()
        
        # Summary statistics
        #print(f"\nATC3 Summary:")
        print(f"  Total record_ids with ATC3 mapping: {len(atc3_per_record):,}")
        #print(f"  Average ATC3 classes per record_id: {atc3_per_record['num_atc3_classes'].mean():.2f}")
        #print(f"  Max ATC3 classes for single record_id: {atc3_per_record['num_atc3_classes'].max()}")
        
        return atc3_per_record

    def analyze_atc2_distribution(self):

        if self.atc_mapping is None:
            raise ValueError("Must run analyze_atc4_mapping() first")
        
        print(f"\n{'='*60}")
        print("ATC2 CLASSES PER RECORD_ID DISTRIBUTION")
        print(f"{'='*60}")
        
        # Create ATC2 classes from ATC4
        records_with_mapping = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()].copy()
        
        if len(records_with_mapping) == 0:
            print("No records with valid ATC4 mappings found.")
            return None
        
        # Create ATC2 class from ATC4 class (first 3 characters)
        records_with_mapping['ATC2 Class'] = records_with_mapping['ATC4 Class'].str[:3]
        
        # Group by record_id and count unique ATC2 classes
        atc2_per_record = records_with_mapping.groupby('record_id').agg({
            'ATC2 Class': 'nunique',
            'NDC': 'first',
            'State': 'first',
            'Year': 'first'
        }).reset_index()
        
        atc2_per_record.columns = ['record_id', 'num_atc2_classes', 'NDC', 'State', 'Year']
        
        # Distribution analysis
        distribution = atc2_per_record['num_atc2_classes'].value_counts().sort_index()
        
        print("Distribution of ATC2 classes per record_id:")
        for classes, count in distribution.items():
            pct = (count / len(atc2_per_record)) * 100
            print(f"  {classes} class(es): {count:,} record_ids ({pct:.1f}%)")
        
        # Show examples of multi-class records
        #multi_class = atc2_per_record[atc2_per_record['num_atc2_classes'] > 1].sort_values('num_atc2_classes', ascending=False)
        
        #if len(multi_class) > 0:
            #print(f"\nTop 10 record_ids with most ATC2 classes:")
            #for _, row in multi_class.head(10).iterrows():
                #record_classes = records_with_mapping[records_with_mapping['record_id'] == row['record_id']]['ATC2 Class'].unique()
                #print(f"  {row['record_id']}: {row['num_atc2_classes']} classes")
                #print(f"    NDC: {row['NDC']}, State: {row['State']}, Year: {row['Year']}")
                #print(f"    Classes: {list(record_classes)}")
                #print()
        
        # Summary statistics
        print(f"\nATC2 Summary:")
        print(f"  Total record_ids with ATC2 mapping: {len(atc2_per_record):,}")
        print(f"  Average ATC2 classes per record_id: {atc2_per_record['num_atc2_classes'].mean():.2f}")
        print(f"  Max ATC2 classes for single record_id: {atc2_per_record['num_atc2_classes'].max()}")
        
        return atc2_per_record

    def export_merged_data(self, output_filename=None):
     
        if not hasattr(self, 'df_selected'):
            # Ensure we’ve selected one ATC per record deterministically
            self.select_one_atc_per_record()

        df_final = self.df_selected.copy()

        # (Optional) scale here if you prefer the output to be scaled
        # If you scale, do the sum check using unscaled copies to avoid rounding issues
        # Keep unscaled copies to verify sums:
        src_units_sum = self.df_cleaned['Units Reimbursed'].sum()
        src_rx_sum    = self.df_cleaned['Number of Prescriptions'].sum()
        out_units_sum = df_final['Units Reimbursed'].sum()
        out_rx_sum    = df_final['Number of Prescriptions'].sum()

        # now scale for the CSV
        df_final['Units Reimbursed'] /= 1e9
        df_final['Number of Prescriptions'] /= 1e6

        if output_filename is None:
            output_filename = f"merged_NEWdata_{self.year}.csv"
        output_path = os.path.join(self.base_path, f"ATC\\merged_data\\{output_filename}")
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        df_final.to_csv(output_path, index=False)

        print("\n" + "="*60)
        print("DATA EXPORT COMPLETE (ONE ATC PER RECORD_ID)")
        print("="*60)
        print(f"Exported: {output_path}")
        print(f"Rows exported: {len(df_final):,}")
        print(f"Sums preserved? Units: {abs(src_units_sum - out_units_sum) < 1e-6}, "
            f"Prescriptions: {abs(src_rx_sum - out_rx_sum) < 1e-6}")
        return output_path

    @staticmethod #ATC distribution analysis across multiple years 
    def create_multi_year_distribution_analysis_simple(years_list):
        
        print("Creating Multi-Year ATC Distribution Analysis (Percentages Only)...")
        print("="*70)
        
        # Dictionary to store only percentage results
        results = {
            'ATC4_1_class': {},'ATC4_2_classes': {},'ATC4_3+_classes': {},
            'ATC3_1_class': {},'ATC3_2_classes': {},'ATC3_3+_classes': {},
            'ATC2_1_class': {},'ATC2_2_classes': {},'ATC2_3+_classes': {}
        }
        
        for year in years_list:
            print(f"Processing {year}...", end=" ")
            
            try:
                # Create a separate analyzer instance for each year
                analyzer = NDCATCAnalyzer(year=year)
                analyzer.clean_sdud_data()
                analyzer.adding_key()
                analyzer.analyze_atc4_mapping()
                
                # Get records with valid ATC4 mappings
                records = analyzer.atc_mapping[analyzer.atc_mapping['ATC4 Class'].notna()].copy()
                records['ATC2 Class'] = records['ATC4 Class'].str[:3]
                records['ATC3 Class'] = records['ATC4 Class'].str[:4]
                
                # Calculate distributions
                atc4_per_record = records.groupby('record_id')['ATC4 Class'].nunique()
                atc3_per_record = records.groupby('record_id')['ATC3 Class'].nunique()
                atc2_per_record = records.groupby('record_id')['ATC2 Class'].nunique()
                
                # ATC4 percentages
                atc4_dist = atc4_per_record.value_counts().sort_index()
                total_atc4 = len(atc4_per_record)
                results['ATC4_1_class'][year] = f"{(atc4_dist.get(1, 0) / total_atc4 * 100):.1f}%"
                results['ATC4_2_classes'][year] = f"{(atc4_dist.get(2, 0) / total_atc4 * 100):.1f}%"
                results['ATC4_3+_classes'][year] = f"{(atc4_dist[atc4_dist.index >= 3].sum() / total_atc4 * 100):.1f}%"
                
                # ATC3 percentages
                atc3_dist = atc3_per_record.value_counts().sort_index()
                total_atc3 = len(atc3_per_record)
                results['ATC3_1_class'][year] = f"{(atc3_dist.get(1, 0) / total_atc3 * 100):.1f}%"
                results['ATC3_2_classes'][year] = f"{(atc3_dist.get(2, 0) / total_atc3 * 100):.1f}%"
                results['ATC3_3+_classes'][year] = f"{(atc3_dist[atc3_dist.index >= 3].sum() / total_atc3 * 100):.1f}%"
                
                # ATC2 percentages
                atc2_dist = atc2_per_record.value_counts().sort_index()
                total_atc2 = len(atc2_per_record)
                results['ATC2_1_class'][year] = f"{(atc2_dist.get(1, 0) / total_atc2 * 100):.1f}%"
                results['ATC2_2_classes'][year] = f"{(atc2_dist.get(2, 0) / total_atc2 * 100):.1f}%"
                results['ATC2_3+_classes'][year] = f"{(atc2_dist[atc2_dist.index >= 3].sum() / total_atc2 * 100):.1f}%"
                
                print("✓")
                
            except Exception as e:
                print(f"✗ Error: {e}")
                # Fill with N/A for failed years
                for key in results.keys():
                    results[key][year] = "N/A"
        
        # Create DataFrame (only percentages)
        df_percentages = pd.DataFrame(results).T
        
        print(f"\nATC DISTRIBUTION PERCENTAGES ACROSS YEARS")
        print("="*60)
        print(df_percentages)
        
        return df_percentages
    
    @staticmethod
    def analyze_general_atc_overview(years_list):

        print("Creating ATC2 & ATC3 Overview: Unique NDCs per Class Across Years...")
        print("="*78)

        # Per-year summary tables (ATC2 & ATC3)
        atc2_year_results = {}
        atc3_year_results = {}

        # Keep the per-year “clean” dataframes if you want to inspect / reuse
        overview_dataframes_atc2 = {}
        overview_dataframes_atc3 = {}

        for year in years_list:
            print(f"Processing {year}...", end=" ")
            try:
                analyzer = NDCATCAnalyzer(year=year)
                analyzer.clean_sdud_data()
                analyzer.adding_key()
                analyzer.analyze_atc4_mapping()

                # Keep only mapped rows
                records = analyzer.atc_mapping[analyzer.atc_mapping['ATC4 Class'].notna()].copy()
                if records.empty:
                    print("No records with ATC mapping found")
                    atc2_year_results[year] = pd.DataFrame()
                    atc3_year_results[year] = pd.DataFrame()
                    overview_dataframes_atc2[year] = pd.DataFrame()
                    overview_dataframes_atc3[year] = pd.DataFrame()
                    continue

                # Derive ATC2/ATC3 from ATC4
                records['ATC2 Class'] = records['ATC4 Class'].str[:3]
                records['ATC3 Class'] = records['ATC4 Class'].str[:4]

                # --- ATC2 ---
                pairs2 = records[['record_id', 'NDC', 'ATC2 Class']].drop_duplicates()
                atc2_summary = (
                    pairs2.groupby('ATC2 Class')
                        .agg(Unique_NDCs=('NDC', 'nunique'),
                            Total_Records=('record_id', 'nunique'))
                        .sort_values('Unique_NDCs', ascending=False)
                )
                total_unique_ndcs2 = pairs2['NDC'].nunique()
                atc2_summary['Percentage_of_NDCs'] = (
                    atc2_summary['Unique_NDCs'] / total_unique_ndcs2 * 100
                ).round(1)

                # --- ATC3 ---
                pairs3 = records[['record_id', 'NDC', 'ATC3 Class']].drop_duplicates()
                atc3_summary = (
                    pairs3.groupby('ATC3 Class')
                        .agg(Unique_NDCs=('NDC', 'nunique'),
                            Total_Records=('record_id', 'nunique'))
                        .sort_values('Unique_NDCs', ascending=False)
                )
                total_unique_ndcs3 = pairs3['NDC'].nunique()
                atc3_summary['Percentage_of_NDCs'] = (
                    atc3_summary['Unique_NDCs'] / total_unique_ndcs3 * 100
                ).round(1)

                # Store per-year tables
                atc2_year_results[year] = atc2_summary
                atc3_year_results[year] = atc3_summary
                overview_dataframes_atc2[year] = pairs2
                overview_dataframes_atc3[year] = pairs3

                print(f"✓ (ATC2: {len(atc2_summary)} classes, {total_unique_ndcs2:,} unique NDCs; "
                    f"ATC3: {len(atc3_summary)} classes, {total_unique_ndcs3:,} unique NDCs)")

            except Exception as e:
                print(f"✗ Error: {e}")
                atc2_year_results[year] = pd.DataFrame()
                atc3_year_results[year] = pd.DataFrame()
                overview_dataframes_atc2[year] = pd.DataFrame()
                overview_dataframes_atc3[year] = pd.DataFrame()

        # ---- Pretty print per-year (optional) ----
        print("\nUNIQUE NDCs PER ATC2 CLASS BY YEAR")
        print("="*60)
        for year in years_list:
            if not atc2_year_results[year].empty:
                print(f"\n{year}:")
                print("-"*40)
                print(f"Total ATC2 Classes: {len(atc2_year_results[year])}")
                print(f"Total Unique NDCs: {overview_dataframes_atc2[year]['NDC'].nunique():,}")
                print("\nTop 10 ATC2 Classes by Unique NDCs:")
                print(atc2_year_results[year].head(10))
            else:
                print(f"\n{year}: No ATC2 data available")

        print("\nUNIQUE NDCs PER ATC3 CLASS BY YEAR")
        print("="*60)
        for year in years_list:
            if not atc3_year_results[year].empty:
                print(f"\n{year}:")
                print("-"*40)
                print(f"Total ATC3 Classes: {len(atc3_year_results[year])}")
                print(f"Total Unique NDCs: {overview_dataframes_atc3[year]['NDC'].nunique():,}")
                print("\nTop 10 ATC3 Classes by Unique NDCs:")
                print(atc3_year_results[year].head(10))
            else:
                print(f"\n{year}: No ATC3 data available")

        # ---- Build comparison tables (rows = classes, cols = years, values = Unique_NDCs) ----
        def build_comparison(year_tables):
            all_classes = set()
            for tbl in year_tables.values():
                if not tbl.empty:
                    all_classes.update(tbl.index.tolist())
            comp = {}
            for cls in sorted(all_classes):
                comp[cls] = {}
                for y in years_list:
                    if not year_tables[y].empty and cls in year_tables[y].index:
                        comp[cls][y] = int(year_tables[y].loc[cls, 'Unique_NDCs'])
                    else:
                        comp[cls][y] = 0
            dfc = pd.DataFrame(comp).T
            return dfc.loc[dfc.sum(axis=1).sort_values(ascending=False).index]

        atc2_comparison = build_comparison(atc2_year_results)
        atc3_comparison = build_comparison(atc3_year_results)

        print("\nTop 15 ATC2 Classes by Total Unique NDCs Across All Years:")
        print(atc2_comparison.head(15))
        print("\nTop 15 ATC3 Classes by Total Unique NDCs Across All Years:")
        print(atc3_comparison.head(15))

        return atc2_year_results, atc3_year_results, atc2_comparison, atc3_comparison
    
    @staticmethod
    def get_atc_ndc_details(year, top_n=10):
        
    
        print(f"\nDETAILED ATC2 & ATC3 ANALYSIS FOR {year}")
        print("=" * 70)

        try:
            analyzer = NDCATCAnalyzer(year=year)
            analyzer.clean_sdud_data()
            analyzer.adding_key()
            analyzer.analyze_atc4_mapping()

            df = analyzer.atc_mapping[analyzer.atc_mapping['ATC4 Class'].notna()].copy()
            if df.empty:
                print("No valid ATC mappings found.")
                return pd.DataFrame(), pd.DataFrame()

            # Derive ATC3 & ATC2 from ATC4
            df['ATC3 Class'] = df['ATC4 Class'].str[:4]
            df['ATC2 Class'] = df['ATC4 Class'].str[:3]

            # --- Deduplicate to avoid inflated counts ---
            pairs2 = df[['record_id', 'NDC', 'ATC2 Class']].drop_duplicates()
            pairs3 = df[['record_id', 'NDC', 'ATC3 Class']].drop_duplicates()

            # === Compute denominators ===
            total_unique_ndcs = df['NDC'].nunique()
            total_records = len(df)
            print(f"Total records analyzed: {total_records:,}")
            print(f"Total unique NDCs mapped: {total_unique_ndcs:,}")

            # === ATC2 Summary ===
            atc2_summary = (
                pairs2.groupby('ATC2 Class')
                    .agg(Unique_NDCs=('NDC', 'nunique'),
                        Total_Records=('record_id', 'nunique'))
            )
            atc2_summary['Percentage_of_NDCs'] = (
                atc2_summary['Unique_NDCs'] / total_unique_ndcs * 100
            ).round(1)
            atc2_summary['Avg_Records_per_NDC'] = (
                atc2_summary['Total_Records'] / atc2_summary['Unique_NDCs']
            ).round(1)
            atc2_summary = atc2_summary.sort_values('Unique_NDCs', ascending=False)

            # === ATC3 Summary ===
            atc3_summary = (
                pairs3.groupby('ATC3 Class')
                    .agg(Unique_NDCs=('NDC', 'nunique'),
                        Total_Records=('record_id', 'nunique'))
            )
            atc3_summary['Percentage_of_NDCs'] = (
                atc3_summary['Unique_NDCs'] / total_unique_ndcs * 100
            ).round(1)
            atc3_summary['Avg_Records_per_NDC'] = (
                atc3_summary['Total_Records'] / atc3_summary['Unique_NDCs']
            ).round(1)
            atc3_summary = atc3_summary.sort_values('Unique_NDCs', ascending=False)

            # === Display ===
            print(f"\nTOP {top_n} ATC2 CLASSES BY UNIQUE NDCs")
            print(atc2_summary.head(top_n))
            print(f"\nTOP {top_n} ATC3 CLASSES BY UNIQUE NDCs")
            print(atc3_summary.head(top_n))

            # === Sample NDCs ===
            print("\nSample NDCs for top 3 ATC2 classes:")
            for i, atc2 in enumerate(atc2_summary.head(3).index, start=1):
                sample = df[df['ATC2 Class'] == atc2]['NDC'].drop_duplicates().head(5).tolist()
                print(f"{i}. {atc2} ({atc2_summary.loc[atc2, 'Unique_NDCs']} NDCs): {', '.join(sample)}")

            print("\nSample NDCs for top 3 ATC3 classes:")
            for i, atc3 in enumerate(atc3_summary.head(3).index, start=1):
                sample = df[df['ATC3 Class'] == atc3]['NDC'].drop_duplicates().head(5).tolist()
                print(f"{i}. {atc3} ({atc3_summary.loc[atc3, 'Unique_NDCs']} NDCs): {', '.join(sample)}")

            return atc2_summary, atc3_summary

        except Exception as e:
            print(f"Error analyzing ATC2/ATC3 details for {year}: {e}")
            return pd.DataFrame(), pd.DataFrame()



In [None]:
# Run the single-year pipeline for 2024; each step stores its result on `analyzer`.
analyzer = NDCATCAnalyzer(year=2024)
analyzer.clean_sdud_data()           # Load the SDUD CSV; drop rows with missing units/prescriptions and State == 'XX'
analyzer.adding_key()                # Add the record_id key (State_Year_Quarter_Utilization Type_NDC)
#analyzer.generate_ndc_txt()          # Generate NDC text file (disabled)
analyzer.analyze_atc4_mapping()      # Left-merge the ATC4 mapping on record_id & NDC

In [None]:
# Distribution analyses at ATC4, ATC3 and ATC2 level for the single-year `analyzer`;
# each result is kept in its own variable.
# NOTE(review): presumably these need analyzer.analyze_atc4_mapping() to have run
# first (the method bodies are not visible from this cell) — confirm.
atc4_dist = analyzer.analyze_atc4_distribution()
atc3_dist = analyzer.analyze_atc3_distribution() 
atc2_dist = analyzer.analyze_atc2_distribution()
# Disabled: per-record comparison of differing ATC2 classes.
#different_atc2_records = analyzer.analyze_different_atc2_records()


In [None]:
# Final steps for the single-year `analyzer`, executed in this order.
# NOTE(review): the method bodies are not visible from this cell; the step
# descriptions below are inferred from the method names — confirm.
analyzer.fetch_atc_names()           # presumably retrieves names for the ATC codes
analyzer.prepare_final_dataframe()   # presumably assembles the final output table
analyzer.export_merged_data()        # presumably writes the merged data to disk
# Simple version - only percentages
# NOTE(review): the comment above has no code following it in this cell; it looks
# like a leftover from another cell — confirm or remove.

In [22]:
# Unique NDCs per ATC2 and ATC3 class across years.
years = [2020, 2021, 2022, 2023, 2024]
# Disabled alternative multi-year analysis:
#NDCATCAnalyzer.create_multi_year_distribution_analysis_simple(years)
# Called on the class itself (no instance). Per the recorded output below, each
# year's full SDUD CSV (millions of rows) is re-read, cleaned and merged.
NDCATCAnalyzer.analyze_general_atc_overview(years)
# Detailed ATC2 & ATC3 analysis for a specific year. Prints the top_n classes and
# returns the (ATC2, ATC3) summary tables sorted by Unique_NDCs; on failure the
# method prints the error and returns two empty DataFrames.
# NOTE(review): the variables are named *_2023 but the call passes year=2024, so
# they hold 2024 results — confirm the intended year and rename or change the argument.
atc2_2023, atc3_2023 = NDCATCAnalyzer.get_atc_ndc_details(year=2024, top_n=10)


Creating ATC2 & ATC3 Overview: Unique NDCs per Class Across Years...
Processing 2020... Reading CSV file: c:\Users\lholguin\OneDrive - purdue.edu\VS code\Data\SDUD\SDUD2020.csv
Total rows in 2020 before filtering: 4922728
Rows after removing NA: 2508077
Rows after filtering State='XX': 2284815
Unique NDCs: 32220
Adding record_id column...
Created 2284815 record IDs
Sample record_id: AK_2020_4_FFSU_00002143380
Total rows in ATC4 mapping file: 4011219
Unique NDCs in ATC4 mapping file: 27661
Merging ATC4 mapping with cleaned data using record_id and NDC...
Merged dataframe rows: 4122259
Records with ATC4 mapping: 4011219 (97.3%)

Records without ATC4 mapping: 111040
Unique NDCs without mapping: 4559
✓ (ATC2: 90 classes, 27,661 unique NDCs; ATC3: 212 classes, 27,661 unique NDCs)
Processing 2021... Reading CSV file: c:\Users\lholguin\OneDrive - purdue.edu\VS code\Data\SDUD\SDUD2021.csv
Total rows in 2021 before filtering: 5042532
Rows after removing NA: 2575044
Rows after filtering State='X

In [None]:
# Check the NDC overlap between the ATC4 class files built with and without
# the record_id key.
#
# One year drives both the file paths and the report labels so they cannot
# drift apart (the labels previously said 2024 while the 2023 files were read).
overlap_year = 2023
nokey_path=rf'C:\Users\{user}\OneDrive - purdue.edu\VS code\Data\ATC\ATC4_classes\Classes_notgood\NDCf_{overlap_year}_ATC4_classes.csv'
keyed_path=rf'C:\Users\{user}\OneDrive - purdue.edu\VS code\Data\ATC\ATC4_classes\NDCNEW_{overlap_year}_ATC4_classes.csv'


def _load_atc4_classes(path):
    """Read an ATC4 class CSV with every column as string and normalize NDC.

    Normalization removes hyphens and left-pads with zeros to 11 characters.
    A new frame is returned, so re-running the cell is idempotent.
    """
    classes = pd.read_csv(path, dtype=str)
    normalized_ndc = classes["NDC"].str.replace("-", "", regex=False).str.zfill(11)
    return classes.assign(NDC=normalized_ndc)


def _mapped_ndc_count(classes):
    """Return the number of distinct NDCs having at least one non-null ATC4 class."""
    return classes.loc[classes["ATC4 Class"].notna(), "NDC"].nunique()


# Load them (NDCs normalized on load)
keyed = _load_atc4_classes(keyed_path)
nokey = _load_atc4_classes(nokey_path)

# --- Summary stats ---
# "Mapped NDCs" counts distinct NDCs; the row-level count is reported separately
# because a file can hold many rows per NDC.
summary = {
    "File": [
        f"With key (NDCNEW_{overlap_year}_ATC4_classes)",
        f"Without key (NDCf_{overlap_year}_ATC4_classes)",
    ],
    "Total rows": [len(keyed), len(nokey)],
    "Unique NDCs": [keyed["NDC"].nunique(), nokey["NDC"].nunique()],
    "Mapped NDCs (non-null ATC)": [
        _mapped_ndc_count(keyed),
        _mapped_ndc_count(nokey),
    ],
    "Mapped rows (non-null ATC)": [
        keyed["ATC4 Class"].notna().sum(),
        nokey["ATC4 Class"].notna().sum(),
    ],
}
summary_df = pd.DataFrame(summary)

# --- Compare overlap of unique NDCs ---
# Missing NDCs are dropped so the sets agree with nunique(), which ignores NaN.
ndc_keyed = set(keyed["NDC"].dropna().unique())
ndc_nokey = set(nokey["NDC"].dropna().unique())

overlap_ndcs = len(ndc_keyed & ndc_nokey)
only_in_nokey = len(ndc_nokey - ndc_keyed)
only_in_keyed = len(ndc_keyed - ndc_nokey)

# Share of the without-key NDCs that also appear in the with-key file;
# undefined (NaN) when the without-key file has no NDCs at all.
percent_overlap = overlap_ndcs / len(ndc_nokey) * 100 if ndc_nokey else float("nan")

comparison = pd.DataFrame({
    "Metric": ["Overlap NDCs", "Only in without-key file", "Only in with-key file", "Percent overlap"],
    "Value": [overlap_ndcs, only_in_nokey, only_in_keyed, percent_overlap]
})

print("\n=== Summary of Each File ===")
print(summary_df.to_string(index=False))

print("\n=== NDC Overlap Comparison ===")
print(comparison.to_string(index=False))