In [31]:
import pandas as pd
import numpy as np
import requests
import shelve
import os
from datetime import datetime
import matplotlib.pyplot as plt
import re

In [32]:
# Windows account name interpolated into the default data path built by
# NDCATCAnalyzer.__init__; switch to "asus" when running on the personal PC.
user = "lholguin"

In [None]:
#New changes to the class
class NDCATCAnalyzer:

    def __init__(self, year, base_path=None):

        self.year = year
        if base_path is None:
            #Lookup the user's base path
            self.base_path = rf"c:\Users\{user}\OneDrive - purdue.edu\VS code\Data"
        else:
            self.base_path = base_path
            
        self.df_cleaned = None
        self.df_merged = None
        self.atc_mapping = None
        self.df_faf = None
        
    def clean_sdud_data(self):
        csv_file = os.path.join(self.base_path, f"SDUD\\SDUD{self.year}.csv")
        print(f"Reading CSV file: {csv_file}")
        
        # Read with NDC as string to preserve leading zeros
        df = pd.read_csv(csv_file, dtype={'NDC': 'object'})
        
        print(f"Total rows in {self.year} before filtering: {len(df)}")
        
        # Remove NA values
        df_filtered = df.dropna(subset=['Units Reimbursed', 'Number of Prescriptions'])
        print(f"Rows after removing NA: {len(df_filtered)}")
        
        # Filter out State='XX'
        df_filtered = df_filtered[df_filtered['State'] != 'XX']
        print(f"Rows after filtering State='XX': {len(df_filtered)}")
        print(f"Unique NDCs: {df_filtered['NDC'].nunique()}")
        
        self.df_cleaned = df_filtered
        return self.df_cleaned
    
    #NEW
    def adding_key(self):
        """Add record_id column to cleaned dataframe."""
        if self.df_cleaned is None:
            raise ValueError("Must run clean_sdud_data() first")
        
        print("Adding record_id column...")
        
        # Create record_id column
        self.df_cleaned['record_id'] = (
            self.df_cleaned['State'].astype(str) + "_" +
            self.df_cleaned['Year'].astype(str) + "_" +
            self.df_cleaned['Quarter'].astype(str) + "_" +
            self.df_cleaned['Utilization Type'].astype(str) + "_" +
            self.df_cleaned['NDC'].astype(str)
        )
        
        print(f"Created {len(self.df_cleaned)} record IDs")
        print(f"Sample record_id: {self.df_cleaned['record_id'].iloc[0]}")
        
        return self.df_cleaned
    
    def generate_ndc_txt(self, output_filename=None):
        """Step 2: Generate text file with unique NDC values and their record_id keys."""
        if self.df_cleaned is None:
            raise ValueError("Must run clean_sdud_data() first")
        
        if 'record_id' not in self.df_cleaned.columns:
            raise ValueError("Must run adding_key() first to create record_id column")
            
        if output_filename is None:
            output_filename = f"NDCNEW_{self.year}.txt"
        
        output_path = os.path.join(self.base_path, f"ATC\\text_files\\{output_filename}")
        
        # Get unique combinations of NDC and record_id
        unique_pairs = self.df_cleaned[['NDC', 'record_id']].drop_duplicates()
        
        with open(output_path, 'w') as f:
            # Write header
            f.write("NDC\trecord_id\n")
            # Write each unique pair
            for _, row in unique_pairs.iterrows():
                f.write(f"{row['NDC']}\t{row['record_id']}\n")
        
        print(f"Exported to: {output_path}")
        print(f"Unique record_id values: {unique_pairs['record_id'].nunique()}")
        return output_path
    
    def analyze_atc4_mapping(self):
        """Step 3: Analyze ATC4 mapping results and identify missing NDCs."""
        atc4_path = os.path.join(self.base_path, f"ATC\\ATC4_classes\\NDCNEW_{self.year}_ATC4_classes.csv")
        
        # Read ATC4 mapping
        df_atc4 = pd.read_csv(atc4_path, dtype={'NDC': 'object', 'record_id': 'string'})
        df_atc4['NDC'] = df_atc4['NDC'].str.zfill(11)

        #Print how many unique NDC codes were mapped
        print(f"Total rows in ATC4 mapping file: {len(df_atc4)}")
        print(f"Unique NDCs in ATC4 mapping file: {df_atc4['NDC'].nunique()}")

        # Ensure consistent data types before merge
        self.df_cleaned['record_id'] = self.df_cleaned['record_id'].astype('string')
        self.df_cleaned['NDC'] = self.df_cleaned['NDC'].astype('object')  
        df_atc4['record_id'] = df_atc4['record_id'].astype('string')
        df_atc4['NDC'] = df_atc4['NDC'].astype('object')  

        # Merge ATC4 mapping with cleaned data using record_id
        if self.df_cleaned is None:
            raise ValueError("Must run clean_sdud_data() and adding_key() first")
        
        if 'record_id' not in self.df_cleaned.columns:
            raise ValueError("Must run adding_key() first to create record_id column")
        
        print(f"Merging ATC4 mapping with cleaned data using record_id and NDC...")
        
        # Merge on BOTH record_id AND NDC
        self.atc_mapping = pd.merge(
            self.df_cleaned,
            df_atc4[['record_id', 'NDC', 'ATC4 Class']],  # Include NDC in the selection
            on=['record_id', 'NDC'],  # Merge on both columns
            how='left'
        )
        
        # Print rows of the merged dataframe
        print(f"Merged dataframe rows: {len(self.atc_mapping)}")
        #print(self.atc_mapping.head())
        total_records = len(self.atc_mapping)
        mapped_records = self.atc_mapping['ATC4 Class'].notna().sum()
        print(f"Records with ATC4 mapping: {mapped_records} ({mapped_records/total_records*100:.1f}%)")

        # Identify missing mappings
        missing_records = self.atc_mapping[self.atc_mapping['ATC4 Class'].isna()]
        if len(missing_records) > 0:
            print(f"\nRecords without ATC4 mapping: {len(missing_records)}")
            print(f"Unique NDCs without mapping: {missing_records['NDC'].nunique()}")
        
        return self.atc_mapping
        
    def analyze_atc4_distribution(self):

        if self.atc_mapping is None:
            raise ValueError("Must run analyze_atc4_mapping() first")
        
        print(f"\n{'='*60}")
        print("ATC4 CLASSES PER RECORD_ID DISTRIBUTION")
        print(f"{'='*60}")
        
        # Count ATC4 classes per record_id (only valid mappings)
        records_with_mapping = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()].copy()
        
        if len(records_with_mapping) == 0:
            print("No records with valid ATC4 mappings found.")
            return None
        
        # Group by record_id and count unique ATC4 classes
        atc4_per_record = records_with_mapping.groupby('record_id').agg({
            'ATC4 Class': 'nunique',
            'NDC': 'first',  # Get the NDC for reference
            'State': 'first',  # Get the state for reference
            'Year': 'first'    # Get the year for reference
        }).reset_index()
        
        atc4_per_record.columns = ['record_id', 'num_atc4_classes', 'NDC', 'State', 'Year']
        
        # Distribution analysis
        distribution = atc4_per_record['num_atc4_classes'].value_counts().sort_index()
        
        print("Distribution of ATC4 classes per record_id:")
        for classes, count in distribution.items():
            pct = (count / len(atc4_per_record)) * 100
            print(f"  {classes} class(es): {count:,} record_ids ({pct:.1f}%)")
        
        # Show examples of multi-class records
        #multi_class = atc4_per_record[atc4_per_record['num_atc4_classes'] > 1].sort_values('num_atc4_classes', ascending=False)
        
        #if len(multi_class) > 0:
            #print(f"\nTop 10 record_ids with most ATC4 classes:")
            #for _, row in multi_class.head(10).iterrows():
                #record_classes = records_with_mapping[records_with_mapping['record_id'] == row['record_id']]['ATC4 Class'].unique()
                #print(f"  {row['record_id']}: {row['num_atc4_classes']} classes")
                #print(f"    NDC: {row['NDC']}, State: {row['State']}, Year: {row['Year']}")
                #print(f"    Classes: {list(record_classes)}")
                #print()
        
        #return atc4_per_record

    def fetch_atc_names(self, cache_path=None):
        """Fetch ATC class names (ATC4, ATC3, ATC2) from RxNav API."""
        if self.atc_mapping is None:
            raise ValueError("Must run analyze_atc4_mapping() first")
        
        if cache_path is None:
            cache_path = os.path.join(self.base_path, "ATC\\cache_files\\atc_names_cache")
        
        print(f"\n{'='*60}")
        print("FETCHING ATC CLASS NAMES")
        print(f"{'='*60}")
        print(f"Using cache: {cache_path}")
        
        # Get only records with valid ATC4 mappings
        df_with_atc = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()].copy()
        
        # Create ATC3 and ATC2 columns from ATC4
        print("\nCreating ATC3 and ATC2 columns from ATC4...")
        df_with_atc['ATC3 Class'] = df_with_atc['ATC4 Class'].str[:4]
        df_with_atc['ATC2 Class'] = df_with_atc['ATC4 Class'].str[:3]
        
        # Get unique codes for each level
        unique_atc4 = df_with_atc['ATC4 Class'].dropna().unique()
        unique_atc3 = df_with_atc['ATC3 Class'].dropna().unique()
        unique_atc2 = df_with_atc['ATC2 Class'].dropna().unique()
        
        # Filter out invalid codes
        unique_atc4 = [c for c in unique_atc4 if c not in ['No ATC Mapping Found', 'No RxCUI Found', '']]
        unique_atc3 = [c for c in unique_atc3 if c not in ['No ATC Mapping Found', 'No RxCUI Found', '', 'No ', 'No']]
        unique_atc2 = [c for c in unique_atc2 if c not in ['No ATC Mapping Found', 'No RxCUI Found', '', 'No ', 'No']]
        
        print(f"\nUnique codes to fetch:")
        print(f"  ATC4: {len(unique_atc4)}")
        print(f"  ATC3: {len(unique_atc3)}")
        print(f"  ATC2: {len(unique_atc2)}")
        
        # Build mappings
        atc4_names = {}
        atc3_names = {}
        atc2_names = {}
        
        with shelve.open(cache_path) as cache:
            start_time = datetime.now()
            
            print("\nFetching ATC4 names...")
            for code in unique_atc4:
                atc4_names[code] = self._get_atc_name(code, cache)
            
            print("Fetching ATC3 names...")
            for code in unique_atc3:
                atc3_names[code] = self._get_atc_name(code, cache)
            
            print("Fetching ATC2 names...")
            for code in unique_atc2:
                atc2_names[code] = self._get_atc_name(code, cache)
            
            print(f"\nTotal processing time: {(datetime.now() - start_time).total_seconds()/60:.1f} minutes")
        
        # Apply names to all records in atc_mapping
        print("\nApplying names to dataframe...")
        self.atc_mapping['ATC3 Class'] = self.atc_mapping['ATC4 Class'].str[:4]
        self.atc_mapping['ATC2 Class'] = self.atc_mapping['ATC4 Class'].str[:3]
        
        self.atc_mapping['ATC4_Name'] = self.atc_mapping['ATC4 Class'].map(atc4_names).fillna('')
        self.atc_mapping['ATC3_Name'] = self.atc_mapping['ATC3 Class'].map(atc3_names).fillna('')
        self.atc_mapping['ATC2_Name'] = self.atc_mapping['ATC2 Class'].map(atc2_names).fillna('')
        
        print(f"\nATC names added successfully!")
        print("\nSample output:")
        sample = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()][['NDC', 'record_id', 'ATC4 Class', 'ATC4_Name', 'ATC3 Class', 'ATC3_Name', 'ATC2 Class', 'ATC2_Name']].head(5)
        print(sample.to_string())
        
        return self.atc_mapping
    
    def prepare_final_dataframe(self):
        if self.atc_mapping is None:
            raise ValueError("Must run fetch_atc_names() first")
        
        print(f"\n{'='*60}")
        print("PREPARING FINAL DATAFRAME")
        print(f"{'='*60}")
        
        # Create a copy for final output
        self.df_merged = self.atc_mapping.copy()
        
        # Scale units
        print("\nScaling units...")
        self.df_merged['Units Reimbursed'] = self.df_merged['Units Reimbursed'] / 1e9
        self.df_merged['Number of Prescriptions'] = self.df_merged['Number of Prescriptions'] / 1e6
        
        # Report final statistics
        total_records = len(self.df_merged)
        mapped_records = self.df_merged['ATC4 Class'].notna().sum()

        #Those are with duplicates included (merged_dataframe), so the metrics will be higher than the deduplicated exported file
        #print(f"Records with ATC4 mapping: {mapped_records:,} ({mapped_records/total_records*100:.1f}%)") 
        #print(f"Total Units Reimbursed: {self.df_merged['Units Reimbursed'].sum():.2f} Billion")
        #print(f"Total Prescriptions: {self.df_merged['Number of Prescriptions'].sum():.2f} Million")
        
        return self.df_merged
    
    def _get_atc_name(self, atc_code, cache):
        """Get ATC class name from code, using cache."""
        cache_key = f"atc_name:{atc_code}"
        if cache_key in cache:
            return cache[cache_key]
        
        try:
            url = f"https://rxnav.nlm.nih.gov/REST/rxclass/class/byId.json?classId={atc_code}"
            response = requests.get(url)
            response.raise_for_status()
            data = response.json()
            
            # Get class name
            if 'rxclassMinConceptList' in data and 'rxclassMinConcept' in data['rxclassMinConceptList']:
                concepts = data['rxclassMinConceptList']['rxclassMinConcept']
                if concepts:
                    name = concepts[0].get('className', '')
                    cache[cache_key] = name
                    return name
            
            cache[cache_key] = ''
            return ''
            
        except Exception as e:
            print(f"Error retrieving name for {atc_code}: {e}")
            cache[cache_key] = ''
            return ''
    
    def analyze_atc3_distribution(self):
        """Analyze distribution of ATC3 classes per record_id."""
        if self.atc_mapping is None:
            raise ValueError("Must run analyze_atc4_mapping() first")
        
        print(f"\n{'='*60}")
        print("ATC3 CLASSES PER RECORD_ID DISTRIBUTION")
        print(f"{'='*60}")
        
        # Create ATC3 classes from ATC4
        records_with_mapping = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()].copy()
        
        if len(records_with_mapping) == 0:
            print("No records with valid ATC4 mappings found.")
            return None
        
        # Create ATC3 class from ATC4 class (first 4 characters)
        records_with_mapping['ATC3 Class'] = records_with_mapping['ATC4 Class'].str[:4]
        
        # Group by record_id and count unique ATC3 classes
        atc3_per_record = records_with_mapping.groupby('record_id').agg({
            'ATC3 Class': 'nunique',
            'NDC': 'first',
            'State': 'first',
            'Year': 'first'
        }).reset_index()
        
        atc3_per_record.columns = ['record_id', 'num_atc3_classes', 'NDC', 'State', 'Year']
        
        # Distribution analysis
        distribution = atc3_per_record['num_atc3_classes'].value_counts().sort_index()
        
        print("Distribution of ATC3 classes per record_id:")
        for classes, count in distribution.items():
            pct = (count / len(atc3_per_record)) * 100
            print(f"  {classes} class(es): {count:,} record_ids ({pct:.1f}%)")
        
        # Show examples of multi-class records
        #multi_class = atc3_per_record[atc3_per_record['num_atc3_classes'] > 1].sort_values('num_atc3_classes', ascending=False)
        
        #if len(multi_class) > 0:
            #print(f"\nTop 10 record_ids with most ATC3 classes:")
            #for _, row in multi_class.head(10).iterrows():
                #record_classes = records_with_mapping[records_with_mapping['record_id'] == row['record_id']]['ATC3 Class'].unique()
                #print(f"  {row['record_id']}: {row['num_atc3_classes']} classes")
                #print(f"    NDC: {row['NDC']}, State: {row['State']}, Year: {row['Year']}")
                #print(f"    Classes: {list(record_classes)}")
                #print()
        
        # Summary statistics
        print(f"\nATC3 Summary:")
        print(f"  Total record_ids with ATC3 mapping: {len(atc3_per_record):,}")
        print(f"  Average ATC3 classes per record_id: {atc3_per_record['num_atc3_classes'].mean():.2f}")
        print(f"  Max ATC3 classes for single record_id: {atc3_per_record['num_atc3_classes'].max()}")
        
        return atc3_per_record

    def analyze_atc2_distribution(self):

        if self.atc_mapping is None:
            raise ValueError("Must run analyze_atc4_mapping() first")
        
        print(f"\n{'='*60}")
        print("ATC2 CLASSES PER RECORD_ID DISTRIBUTION")
        print(f"{'='*60}")
        
        # Create ATC2 classes from ATC4
        records_with_mapping = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()].copy()
        
        if len(records_with_mapping) == 0:
            print("No records with valid ATC4 mappings found.")
            return None
        
        # Create ATC2 class from ATC4 class (first 3 characters)
        records_with_mapping['ATC2 Class'] = records_with_mapping['ATC4 Class'].str[:3]
        
        # Group by record_id and count unique ATC2 classes
        atc2_per_record = records_with_mapping.groupby('record_id').agg({
            'ATC2 Class': 'nunique',
            'NDC': 'first',
            'State': 'first',
            'Year': 'first'
        }).reset_index()
        
        atc2_per_record.columns = ['record_id', 'num_atc2_classes', 'NDC', 'State', 'Year']
        
        # Distribution analysis
        distribution = atc2_per_record['num_atc2_classes'].value_counts().sort_index()
        
        print("Distribution of ATC2 classes per record_id:")
        for classes, count in distribution.items():
            pct = (count / len(atc2_per_record)) * 100
            print(f"  {classes} class(es): {count:,} record_ids ({pct:.1f}%)")
        
        # Show examples of multi-class records
        #multi_class = atc2_per_record[atc2_per_record['num_atc2_classes'] > 1].sort_values('num_atc2_classes', ascending=False)
        
        #if len(multi_class) > 0:
            #print(f"\nTop 10 record_ids with most ATC2 classes:")
            #for _, row in multi_class.head(10).iterrows():
                #record_classes = records_with_mapping[records_with_mapping['record_id'] == row['record_id']]['ATC2 Class'].unique()
                #print(f"  {row['record_id']}: {row['num_atc2_classes']} classes")
                #print(f"    NDC: {row['NDC']}, State: {row['State']}, Year: {row['Year']}")
                #print(f"    Classes: {list(record_classes)}")
                #print()
        
        # Summary statistics
        print(f"\nATC2 Summary:")
        print(f"  Total record_ids with ATC2 mapping: {len(atc2_per_record):,}")
        print(f"  Average ATC2 classes per record_id: {atc2_per_record['num_atc2_classes'].mean():.2f}")
        print(f"  Max ATC2 classes for single record_id: {atc2_per_record['num_atc2_classes'].max()}")
        
        return atc2_per_record

    def export_merged_data(self, output_filename=None):
     
        if self.df_merged is None:
            raise ValueError("Must run prepare_final_dataframe() first to create merged dataframe")
            
        if output_filename is None:
            output_filename = f"merged_NEWdata_{self.year}.csv"

        output_path = os.path.join(self.base_path, f"ATC\\merged_data\\{output_filename}")

        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        # Check for duplicate record_ids before removal
        print(f"\n{'='*60}")
        print("CHECKING FOR DUPLICATE RECORD_IDs")
        print(f"{'='*60}")
        print(f"Total rows before deduplication: {len(self.df_merged):,}")
        
        duplicate_count = self.df_merged['record_id'].duplicated().sum()
        print(f"Duplicate record_ids found: {duplicate_count:,}")
        
        if duplicate_count > 0:
            print(f"Unique record_ids: {self.df_merged['record_id'].nunique():,}")
            
            # Show sample duplicates BEFORE deduplication
            duplicated_records = self.df_merged[self.df_merged['record_id'].duplicated(keep=False)].sort_values('record_id')
            print(f"\nSample duplicate record_ids BEFORE deduplication (first 10 rows):")
            print(duplicated_records[['record_id', 'NDC', 'State', 'ATC4 Class']].head(10))
            
            # Print two specific record_id values before deduplication
            print(f"\n{'='*60}")
            print("DETAILED VIEW: Two record_ids BEFORE deduplication")
            print(f"{'='*60}")
            sample_record_ids = duplicated_records['record_id'].unique()[:2]
            for rid in sample_record_ids:
                print(f"\nrecord_id: {rid}")
                sample_rows = self.df_merged[self.df_merged['record_id'] == rid]
                print(sample_rows[['record_id', 'NDC', 'State', 'Year', 'Quarter', 'ATC4 Class', 'ATC3 Class','ATC2 Class', 'Units Reimbursed', 'Number of Prescriptions']].to_string(index=False))
        
        # Remove duplicates, keeping first occurrence
        df_deduplicated = self.df_merged.drop_duplicates(subset='record_id', keep='first')
        
        print(f"\nTotal rows after deduplication: {len(df_deduplicated):,}")
        print(f"Rows removed: {len(self.df_merged) - len(df_deduplicated):,}")
        
        # Show sample of data AFTER deduplication
        print(f"\nSample of deduplicated data (first 10 rows):")
        print(df_deduplicated[['record_id', 'NDC', 'State', 'ATC4 Class']].head(10))
        
        # Print the same two record_id values after deduplication
        if duplicate_count > 0:
            print(f"\n{'='*60}")
            print("DETAILED VIEW: Same two record_ids AFTER deduplication")
            print(f"{'='*60}")
            for rid in sample_record_ids:
                print(f"\nrecord_id: {rid}")
                sample_rows = df_deduplicated[df_deduplicated['record_id'] == rid]
                print(sample_rows[['record_id', 'NDC', 'State', 'Year', 'Quarter', 'ATC4 Class', 'ATC3 Class','ATC2 Class','Units Reimbursed', 'Number of Prescriptions']].to_string(index=False))

        # Export deduplicated dataframe
        df_deduplicated.to_csv(output_path, index=False)
        
        print(f"\n{'='*60}")
        print("DATA EXPORT COMPLETE")
        print(f"{'='*60}")
        print(f"Exported to: {output_path}")
        print(f"Total rows exported: {len(df_deduplicated):,}")
        print(f"Columns: {', '.join(df_deduplicated.columns.tolist())}")
        
        # Compare with analyze_atc4_mapping and cleaned dataframe output
        print(f"\n{'='*60}")
        print("COMPARISON WITH analyze_atc4_mapping()")
        print(f"{'='*60}")
        if self.atc_mapping is not None:
            print(f"Rows in atc_mapping (from analyze_atc4_mapping): {len(self.atc_mapping):,}")
            print(f"Rows in df_merged (before deduplication): {len(self.df_merged):,}")
            print(f"Rows in exported file (after deduplication): {len(df_deduplicated):,}")
            print(f"Difference from df_cleaned: {len(df_deduplicated) - len(self.df_cleaned):,}")
            print(f"Difference from atc_mapping: {len(df_deduplicated) - len(self.atc_mapping):,}")
        else:
            print("atc_mapping not available for comparison")
        
        return output_path

    '''
    def analyze_different_atc2_records(self):

    
        """Find and display records that have different ATC2 classes with names."""
        if self.atc_mapping is None:
            raise ValueError("Must run analyze_atc4_mapping() first")
        
        print("="*40)
        print("RECORDS WITH DIFFERENT ATC2 CLASSES")
        print("="*40)
        
        # Get records and create ATC2
        records = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()].copy()
        records['ATC2 Class'] = records['ATC4 Class'].str[:3]
        
        # Find records with multiple ATC2 classes
        multi_atc2 = records.groupby('record_id')['ATC2 Class'].nunique()
        different_records = multi_atc2[multi_atc2 > 1].index.tolist()
        
        if len(different_records) == 0:
            print("No records found with different ATC2 classes.")
            return None
        
        print(f"Found {len(different_records)} records with different ATC2 classes\n")
        
        # Show first 3 examples
        has_names = 'ATC2_Name' in records.columns
        for i, record_id in enumerate(different_records[:3]):
            record_data = records[records['record_id'] == record_id]
            atc_info = record_data[['ATC2 Class', 'ATC2_Name']].drop_duplicates() if has_names else record_data[['ATC2 Class']].drop_duplicates()
            
            print(f"{i+1}. {record_id} | NDC: {record_data['NDC'].iloc[0]}")
            for _, row in atc_info.iterrows():
                if has_names:
                    print(f"   ATC2: {row['ATC2 Class']} - {row['ATC2_Name'][:40]}...")
                else:
                    print(f"   ATC2: {row['ATC2 Class']}")
            print()
        
        # Quick summary
        print(f"Summary: {multi_atc2[multi_atc2 > 1].value_counts().to_dict()}")
        
        return records[records['record_id'].isin(different_records)]
    '''
    @staticmethod
    def create_multi_year_distribution_analysis(years_list):
        
        print("Creating Multi-Year ATC Distribution Analysis...")
        print("="*60)
        
        # Dictionary to store results
        results = {
            'ATC4_1_class': {},
            'ATC4_2_classes': {},
            'ATC4_3+_classes': {},
            'ATC3_1_class': {},
            'ATC3_2_classes': {},
            'ATC3_3+_classes': {},
            'ATC2_1_class': {},
            'ATC2_2_classes': {},
            'ATC2_3+_classes': {}
        }

        # NDC counts with percentages
        ndc_counts = {
            'Total_unique_NDCs': {},
            'NDCs_ATC4_1_class': {},
            'NDCs_ATC4_2_classes': {},
            'NDCs_ATC4_3+_classes': {},
            'NDCs_ATC3_1_class': {},
            'NDCs_ATC3_2_classes': {},
            'NDCs_ATC3_3+_classes': {},
            'NDCs_ATC2_1_class': {},
            'NDCs_ATC2_2_classes': {},
            'NDCs_ATC2_3+_classes': {}
        }
        
        # Dictionary to store unique ATC class counts
        atc_class_counts = {
            'Total_unique_ATC4_classes': {},
            'Total_unique_ATC3_classes': {},
            'Total_unique_ATC2_classes': {},
            'ATC4_classes_in_1_class_NDCs': {},
            'ATC4_classes_in_2_class_NDCs': {},
            'ATC4_classes_in_3plus_class_NDCs': {},
            'ATC3_classes_in_1_class_NDCs': {},
            'ATC3_classes_in_2_class_NDCs': {},
            'ATC3_classes_in_3plus_class_NDCs': {},
            'ATC2_classes_in_1_class_NDCs': {},
            'ATC2_classes_in_2_class_NDCs': {},
            'ATC2_classes_in_3plus_class_NDCs': {}
        }
        
        for year in years_list:
            print(f"Processing {year}...", end=" ")
            
            try:
                # Create a separate analyzer instance for each year
                analyzer = NDCATCAnalyzer(year=year)
                analyzer.clean_sdud_data()
                analyzer.adding_key()
                analyzer.analyze_atc4_mapping()
                
                # Get records with valid ATC4 mappings
                records = analyzer.atc_mapping[analyzer.atc_mapping['ATC4 Class'].notna()].copy()
                records['ATC2 Class'] = records['ATC4 Class'].str[:3]
                records['ATC3 Class'] = records['ATC4 Class'].str[:4]
                
                # Calculate distributions
                atc4_per_record = records.groupby('record_id')['ATC4 Class'].nunique()
                atc3_per_record = records.groupby('record_id')['ATC3 Class'].nunique()
                atc2_per_record = records.groupby('record_id')['ATC2 Class'].nunique()
                
                # ATC4 percentages
                atc4_dist = atc4_per_record.value_counts().sort_index()
                total_atc4 = len(atc4_per_record)
                results['ATC4_1_class'][year] = f"{(atc4_dist.get(1, 0) / total_atc4 * 100):.1f}%"
                results['ATC4_2_classes'][year] = f"{(atc4_dist.get(2, 0) / total_atc4 * 100):.1f}%"
                results['ATC4_3+_classes'][year] = f"{(atc4_dist[atc4_dist.index >= 3].sum() / total_atc4 * 100):.1f}%"
                
                # ATC3 percentages
                atc3_dist = atc3_per_record.value_counts().sort_index()
                total_atc3 = len(atc3_per_record)
                results['ATC3_1_class'][year] = f"{(atc3_dist.get(1, 0) / total_atc3 * 100):.1f}%"
                results['ATC3_2_classes'][year] = f"{(atc3_dist.get(2, 0) / total_atc3 * 100):.1f}%"
                results['ATC3_3+_classes'][year] = f"{(atc3_dist[atc3_dist.index >= 3].sum() / total_atc3 * 100):.1f}%"
                
                # ATC2 percentages
                atc2_dist = atc2_per_record.value_counts().sort_index()
                total_atc2 = len(atc2_per_record)
                results['ATC2_1_class'][year] = f"{(atc2_dist.get(1, 0) / total_atc2 * 100):.1f}%"
                results['ATC2_2_classes'][year] = f"{(atc2_dist.get(2, 0) / total_atc2 * 100):.1f}%"
                results['ATC2_3+_classes'][year] = f"{(atc2_dist[atc2_dist.index >= 3].sum() / total_atc2 * 100):.1f}%"
                
                # NDC counts with percentages
                total_unique_ndcs = records['NDC'].nunique()
                ndc_counts['Total_unique_NDCs'][year] = total_unique_ndcs
                
                # Get record_ids for each ATC4 category
                atc4_1_class_records = atc4_per_record[atc4_per_record == 1].index
                atc4_2_class_records = atc4_per_record[atc4_per_record == 2].index
                atc4_3plus_class_records = atc4_per_record[atc4_per_record >= 3].index
                
                atc4_1_ndcs = records[records['record_id'].isin(atc4_1_class_records)]['NDC'].nunique()
                atc4_2_ndcs = records[records['record_id'].isin(atc4_2_class_records)]['NDC'].nunique()
                atc4_3plus_ndcs = records[records['record_id'].isin(atc4_3plus_class_records)]['NDC'].nunique()
                
                ndc_counts['NDCs_ATC4_1_class'][year] = f"{atc4_1_ndcs:,} ({atc4_1_ndcs/total_unique_ndcs*100:.1f}%)"
                ndc_counts['NDCs_ATC4_2_classes'][year] = f"{atc4_2_ndcs:,} ({atc4_2_ndcs/total_unique_ndcs*100:.1f}%)"
                ndc_counts['NDCs_ATC4_3+_classes'][year] = f"{atc4_3plus_ndcs:,} ({atc4_3plus_ndcs/total_unique_ndcs*100:.1f}%)"
                
                # Get record_ids for each ATC3 category
                atc3_1_class_records = atc3_per_record[atc3_per_record == 1].index
                atc3_2_class_records = atc3_per_record[atc3_per_record == 2].index
                atc3_3plus_class_records = atc3_per_record[atc3_per_record >= 3].index
                
                atc3_1_ndcs = records[records['record_id'].isin(atc3_1_class_records)]['NDC'].nunique()
                atc3_2_ndcs = records[records['record_id'].isin(atc3_2_class_records)]['NDC'].nunique()
                atc3_3plus_ndcs = records[records['record_id'].isin(atc3_3plus_class_records)]['NDC'].nunique()
                
                ndc_counts['NDCs_ATC3_1_class'][year] = f"{atc3_1_ndcs:,} ({atc3_1_ndcs/total_unique_ndcs*100:.1f}%)"
                ndc_counts['NDCs_ATC3_2_classes'][year] = f"{atc3_2_ndcs:,} ({atc3_2_ndcs/total_unique_ndcs*100:.1f}%)"
                ndc_counts['NDCs_ATC3_3+_classes'][year] = f"{atc3_3plus_ndcs:,} ({atc3_3plus_ndcs/total_unique_ndcs*100:.1f}%)"
                
                # Get record_ids for each ATC2 category
                atc2_1_class_records = atc2_per_record[atc2_per_record == 1].index
                atc2_2_class_records = atc2_per_record[atc2_per_record == 2].index
                atc2_3plus_class_records = atc2_per_record[atc2_per_record >= 3].index
                
                atc2_1_ndcs = records[records['record_id'].isin(atc2_1_class_records)]['NDC'].nunique()
                atc2_2_ndcs = records[records['record_id'].isin(atc2_2_class_records)]['NDC'].nunique()
                atc2_3plus_ndcs = records[records['record_id'].isin(atc2_3plus_class_records)]['NDC'].nunique()
                
                ndc_counts['NDCs_ATC2_1_class'][year] = f"{atc2_1_ndcs:,} ({atc2_1_ndcs/total_unique_ndcs*100:.1f}%)"
                ndc_counts['NDCs_ATC2_2_classes'][year] = f"{atc2_2_ndcs:,} ({atc2_2_ndcs/total_unique_ndcs*100:.1f}%)"
                ndc_counts['NDCs_ATC2_3+_classes'][year] = f"{atc2_3plus_ndcs:,} ({atc2_3plus_ndcs/total_unique_ndcs*100:.1f}%)"
                
                # Calculate unique ATC class counts
                # Total unique ATC classes across all NDCs
                atc_class_counts['Total_unique_ATC4_classes'][year] = records['ATC4 Class'].nunique()
                atc_class_counts['Total_unique_ATC3_classes'][year] = records['ATC3 Class'].nunique()
                atc_class_counts['Total_unique_ATC2_classes'][year] = records['ATC2 Class'].nunique()
                
                # Unique ATC4 classes in each NDC category
                atc4_classes_1_class = records[records['record_id'].isin(atc4_1_class_records)]['ATC4 Class'].nunique()
                atc4_classes_2_class = records[records['record_id'].isin(atc4_2_class_records)]['ATC4 Class'].nunique()
                atc4_classes_3plus = records[records['record_id'].isin(atc4_3plus_class_records)]['ATC4 Class'].nunique()
                
                atc_class_counts['ATC4_classes_in_1_class_NDCs'][year] = atc4_classes_1_class
                atc_class_counts['ATC4_classes_in_2_class_NDCs'][year] = atc4_classes_2_class
                atc_class_counts['ATC4_classes_in_3plus_class_NDCs'][year] = atc4_classes_3plus
                
                # Unique ATC3 classes in each NDC category
                atc3_classes_1_class = records[records['record_id'].isin(atc3_1_class_records)]['ATC3 Class'].nunique()
                atc3_classes_2_class = records[records['record_id'].isin(atc3_2_class_records)]['ATC3 Class'].nunique()
                atc3_classes_3plus = records[records['record_id'].isin(atc3_3plus_class_records)]['ATC3 Class'].nunique()
                
                atc_class_counts['ATC3_classes_in_1_class_NDCs'][year] = atc3_classes_1_class
                atc_class_counts['ATC3_classes_in_2_class_NDCs'][year] = atc3_classes_2_class
                atc_class_counts['ATC3_classes_in_3plus_class_NDCs'][year] = atc3_classes_3plus
                
                # Unique ATC2 classes in each NDC category
                atc2_classes_1_class = records[records['record_id'].isin(atc2_1_class_records)]['ATC2 Class'].nunique()
                atc2_classes_2_class = records[records['record_id'].isin(atc2_2_class_records)]['ATC2 Class'].nunique()
                atc2_classes_3plus = records[records['record_id'].isin(atc2_3plus_class_records)]['ATC2 Class'].nunique()
                
                atc_class_counts['ATC2_classes_in_1_class_NDCs'][year] = atc2_classes_1_class
                atc_class_counts['ATC2_classes_in_2_class_NDCs'][year] = atc2_classes_2_class
                atc_class_counts['ATC2_classes_in_3plus_class_NDCs'][year] = atc2_classes_3plus
                
                print("✓")
                
            except Exception as e:
                print(f"✗ Error: {e}")
                # Fill with N/A for failed years
                for key in results.keys():
                    results[key][year] = "N/A"
                for key in ndc_counts.keys():
                    ndc_counts[key][year] = "N/A"
                for key in atc_class_counts.keys():
                    atc_class_counts[key][year] = "N/A"
        
        # Create DataFrames
        df_percentages = pd.DataFrame(results).T
        df_ndc_counts = pd.DataFrame(ndc_counts).T
        df_atc_class_counts = pd.DataFrame(atc_class_counts).T
        
        print(f"\nATC DISTRIBUTION PERCENTAGES ACROSS YEARS")
        print("="*60)
        print(df_percentages)
        
        print(f"\nUNIQUE NDC COUNTS BY CATEGORY ACROSS YEARS (with percentages)")
        print("="*70)
        print(df_ndc_counts)
        
        print(f"\nUNIQUE ATC CLASS COUNTS BY CATEGORY ACROSS YEARS")
        print("="*60)
        print(df_atc_class_counts)
        
        return df_percentages, df_ndc_counts, df_atc_class_counts
    
    

In [34]:
# Run the 2024 pipeline on one analyzer instance: clean -> add key -> merge ATC4.
# Order matters: adding_key() raises ValueError unless clean_sdud_data() has run first.
analyzer = NDCATCAnalyzer(year=2024)
analyzer.clean_sdud_data()           # Read SDUD2024.csv; drop rows with NA units/prescriptions and State == 'XX'
analyzer.adding_key()                # Add the record_id column to the cleaned frame
#analyzer.generate_ndc_txt()          # Generate NDC text file (disabled in this run)
analyzer.analyze_atc4_mapping()      # Merge ATC4 by record_id & NDC (per its log output)
# NOTE(review): the table rendered under the log appears to be the return value of the
# last call above (it is the cell's final expression) — confirm against the method.

Reading CSV file: c:\Users\lholguin\OneDrive - purdue.edu\VS code\Data\SDUD\SDUD2024.csv
Total rows in 2024 before filtering: 5205065
Rows after removing NA: 2599748
Rows after filtering State='XX': 2362630
Unique NDCs: 33397
Adding record_id column...
Created 2362630 record IDs
Sample record_id: AK_2024_4_FFSU_00002143380
Total rows in ATC4 mapping file: 4336231
Unique NDCs in ATC4 mapping file: 32203
Merging ATC4 mapping with cleaned data using record_id and NDC...
Merged dataframe rows: 4360194
Records with ATC4 mapping: 4336231 (99.5%)

Records without ATC4 mapping: 23963
Unique NDCs without mapping: 1194


Unnamed: 0,Utilization Type,State,NDC,Labeler Code,Product Code,Package Size,Year,Quarter,Suppression Used,Product Name,Units Reimbursed,Number of Prescriptions,Total Amount Reimbursed,Medicaid Amount Reimbursed,Non Medicaid Amount Reimbursed,record_id,ATC4 Class
0,FFSU,AK,00002143380,2,1433,80,2024,4,False,TRULICITY,230.0,110.0,108512.57,105868.53,2644.04,AK_2024_4_FFSU_00002143380,A10BJ
1,FFSU,AK,00002143480,2,1434,80,2024,4,False,TRULICITY,216.0,108.0,102247.58,97153.90,5093.68,AK_2024_4_FFSU_00002143480,A10BJ
2,FFSU,AK,00002143611,2,1436,11,2024,4,False,EMGALITY P,32.0,31.0,22193.94,22193.94,0.00,AK_2024_4_FFSU_00002143611,N02CD
3,FFSU,AK,00002144511,2,1445,11,2024,4,False,TALTZ AUTO,38.0,37.0,253852.99,226811.15,27041.84,AK_2024_4_FFSU_00002144511,L04AC
4,FFSU,AK,00002145780,2,1457,80,2024,4,False,MOUNJARO,182.0,91.0,94241.98,92199.80,2042.18,AK_2024_4_FFSU_00002145780,A10BX
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4360189,FFSU,WY,82347050505,82347,505,5,2024,1,False,LIDOCAINE,1149.0,38.0,2587.33,2587.33,0.00,WY_2024_1_FFSU_82347050505,C01BB
4360190,FFSU,WY,82347050505,82347,505,5,2024,1,False,LIDOCAINE,1149.0,38.0,2587.33,2587.33,0.00,WY_2024_1_FFSU_82347050505,N01BB
4360191,FFSU,WY,82347050505,82347,505,5,2024,1,False,LIDOCAINE,1149.0,38.0,2587.33,2587.33,0.00,WY_2024_1_FFSU_82347050505,S02DA
4360192,FFSU,WY,82347050505,82347,505,5,2024,1,False,LIDOCAINE,1149.0,38.0,2587.33,2587.33,0.00,WY_2024_1_FFSU_82347050505,S01HA


In [36]:
# Distribution analyses: one call per ATC level (ATC4, ATC3, ATC2).
# The output below shows the ATC4 and ATC3 calls printing a
# "classes per record_id" report; the ATC2 call presumably does the same.
# NOTE(review): these methods are defined outside this cell, so the type of the
# values bound to atc4_dist / atc3_dist / atc2_dist is not confirmed here.
atc4_dist = analyzer.analyze_atc4_distribution()
atc3_dist = analyzer.analyze_atc3_distribution() 
atc2_dist = analyzer.analyze_atc2_distribution()
# Disabled in this run:
#different_atc2_records = analyzer.analyze_different_atc2_records()



ATC4 CLASSES PER RECORD_ID DISTRIBUTION
Distribution of ATC4 classes per record_id:
  1 class(es): 1,623,867 record_ids (69.4%)
  2 class(es): 282,747 record_ids (12.1%)
  3 class(es): 207,971 record_ids (8.9%)
  4 class(es): 60,195 record_ids (2.6%)
  5 class(es): 49,616 record_ids (2.1%)
  6 class(es): 13,854 record_ids (0.6%)
  7 class(es): 24,746 record_ids (1.1%)
  8 class(es): 23,866 record_ids (1.0%)
  9 class(es): 16,379 record_ids (0.7%)
  10 class(es): 578 record_ids (0.0%)
  11 class(es): 23,370 record_ids (1.0%)
  12 class(es): 2,457 record_ids (0.1%)
  13 class(es): 670 record_ids (0.0%)
  14 class(es): 3,567 record_ids (0.2%)
  15 class(es): 467 record_ids (0.0%)
  16 class(es): 1,431 record_ids (0.1%)
  17 class(es): 483 record_ids (0.0%)
  20 class(es): 1,250 record_ids (0.1%)
  21 class(es): 48 record_ids (0.0%)
  22 class(es): 1,105 record_ids (0.0%)

ATC3 CLASSES PER RECORD_ID DISTRIBUTION
Distribution of ATC3 classes per record_id:
  1 class(es): 1,669,581 record_i

In [37]:
# Final stage: look up ATC names, build the final frame, export the merged data.
# NOTE(review): the three methods are defined outside this cell; they presumably
# rely on state left on `analyzer` by the earlier cells, so keep this order.
# Per the log output, fetch_atc_names() uses an on-disk cache and derives the
# ATC3/ATC2 columns from ATC4 before applying names to the dataframe.
analyzer.fetch_atc_names()           
analyzer.prepare_final_dataframe()   
# Final expression of the cell: the CSV path string shown in the output appears
# to be this call's return value — confirm against the method.
analyzer.export_merged_data()  


FETCHING ATC CLASS NAMES
Using cache: c:\Users\lholguin\OneDrive - purdue.edu\VS code\Data\ATC\cache_files\atc_names_cache

Creating ATC3 and ATC2 columns from ATC4...

Unique codes to fetch:
  ATC4: 612
  ATC3: 212
  ATC2: 89

Fetching ATC4 names...
Fetching ATC3 names...
Fetching ATC2 names...

Total processing time: 0.0 minutes

Applying names to dataframe...

ATC names added successfully!

Sample output:
           NDC                   record_id ATC4 Class                                           ATC4_Name ATC3 Class                                     ATC3_Name ATC2 Class               ATC2_Name
0  00002143380  AK_2024_4_FFSU_00002143380      A10BJ           Glucagon-like peptide-1 (GLP-1) analogues       A10B  BLOOD GLUCOSE LOWERING DRUGS, EXCL. INSULINS        A10  DRUGS USED IN DIABETES
1  00002143480  AK_2024_4_FFSU_00002143480      A10BJ           Glucagon-like peptide-1 (GLP-1) analogues       A10B  BLOOD GLUCOSE LOWERING DRUGS, EXCL. INSULINS        A10  DRUGS USED IN DI

'c:\\Users\\lholguin\\OneDrive - purdue.edu\\VS code\\Data\\ATC\\merged_data\\merged_NEWdata_2024.csv'

In [None]:
# Compare two ATC4 mapping exports -- the "with key" file and the "without key"
# file -- by row count, unique NDCs, mapped NDCs, and overlap of their NDC sets.
nokey_path = rf'C:\Users\{user}\OneDrive - purdue.edu\VS code\Data\ATC\ATC4_classes\Classes_notgood\NDCf_2023_ATC4_classes.csv'
keyed_path = rf'C:\Users\{user}\OneDrive - purdue.edu\VS code\Data\ATC\ATC4_classes\NDCNEW_2023_ATC4_classes.csv'

# Load them (dtype=str so NDC codes keep their leading zeros)
keyed = pd.read_csv(keyed_path, dtype=str)
nokey = pd.read_csv(nokey_path, dtype=str)

# Normalize NDCs (remove hyphens, pad to 11 digits)
for mapping_df in (keyed, nokey):
    mapping_df["NDC"] = mapping_df["NDC"].str.replace("-", "", regex=False).str.zfill(11)

# Label each file by the name that was actually loaded, so the report cannot
# drift out of sync with the paths above (the labels used to be hard-coded to a
# different year than the files being read).
keyed_label = os.path.splitext(os.path.basename(keyed_path))[0]
nokey_label = os.path.splitext(os.path.basename(nokey_path))[0]

# --- Summary stats ---
summary = {
    "File": [f"With key ({keyed_label})", f"Without key ({nokey_label})"],
    "Total rows": [len(keyed), len(nokey)],
    "Unique NDCs": [keyed["NDC"].nunique(), nokey["NDC"].nunique()],
    # Count distinct NDCs that have at least one non-null ATC4 class. Counting
    # rows instead would inflate the figure whenever an NDC appears on several
    # rows (multiple classes and/or multiple records).
    "Mapped NDCs (non-null ATC)": [
        keyed.loc[keyed["ATC4 Class"].notna(), "NDC"].nunique(),
        nokey.loc[nokey["ATC4 Class"].notna(), "NDC"].nunique(),
    ],
}
summary_df = pd.DataFrame(summary)

# --- Compare overlap of unique NDCs ---
# Drop missing NDCs first: nunique() above ignores NaN, so the sets must too,
# otherwise the overlap counts would disagree with the "Unique NDCs" column.
ndc_keyed = set(keyed["NDC"].dropna().unique())
ndc_nokey = set(nokey["NDC"].dropna().unique())

overlap_ndcs = len(ndc_keyed & ndc_nokey)
only_in_nokey = len(ndc_nokey - ndc_keyed)
only_in_keyed = len(ndc_keyed - ndc_nokey)

# Percent overlap is relative to the without-key file; report NaN instead of
# raising ZeroDivisionError when that file contains no NDCs.
percent_overlap = overlap_ndcs / len(ndc_nokey) * 100 if ndc_nokey else float("nan")

comparison = pd.DataFrame({
    "Metric": ["Overlap NDCs", "Only in without-key file", "Only in with-key file", "Percent overlap"],
    "Value": [overlap_ndcs, only_in_nokey, only_in_keyed, percent_overlap]
})

print("\n=== Summary of Each File ===")
print(summary_df.to_string(index=False))

print("\n=== NDC Overlap Comparison ===")
print(comparison.to_string(index=False))