In [15]:
import pandas as pd
import numpy as np
import requests
import shelve
import os
from datetime import datetime
import matplotlib.pyplot as plt
import re
import ipywidgets as widgets
from IPython.display import display, clear_output
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [16]:
user="lholguin"
#user in personal pc <- "asus"

In [17]:
#New changes to the class
class NDCATCAnalyzer:

    def __init__(self, year, base_path=None):

        self.year = year
        if base_path is None:
            #Lookup the user's base path
            self.base_path = rf"c:\Users\{user}\OneDrive - purdue.edu\VS code\Data"
        else:
            self.base_path = base_path
            
        self.df_cleaned = None
        self.df_merged = None
        self.atc_mapping = None
        self.df_faf = None
        
    def clean_sdud_data(self):
        csv_file = os.path.join(self.base_path, f"SDUD\\SDUD{self.year}.csv")
        print(f"Reading CSV file: {csv_file}")
        
        # Read with NDC as string to preserve leading zeros
        df = pd.read_csv(csv_file, dtype={'NDC': 'object'})
        
        print(f"Total rows in {self.year} before filtering: {len(df)}")
        
        # Remove NA values
        df_filtered = df.dropna(subset=['Units Reimbursed', 'Number of Prescriptions'])
        print(f"Rows after removing NA: {len(df_filtered)}")
        
        # Filter out State='XX'
        df_filtered = df_filtered[df_filtered['State'] != 'XX']
        print(f"Rows after filtering State='XX': {len(df_filtered)}")
        print(f"Unique NDCs: {df_filtered['NDC'].nunique()}")
        
        self.df_cleaned = df_filtered
        return self.df_cleaned
    
    #NEW
    def adding_key(self):
        """Add record_id column to cleaned dataframe."""
        if self.df_cleaned is None:
            raise ValueError("Must run clean_sdud_data() first")
        
        print("Adding record_id column...")
        
        # Create record_id column
        self.df_cleaned['record_id'] = (
            self.df_cleaned['State'].astype(str) + "_" +
            self.df_cleaned['Year'].astype(str) + "_" +
            self.df_cleaned['Quarter'].astype(str) + "_" +
            self.df_cleaned['Utilization Type'].astype(str) + "_" +
            self.df_cleaned['NDC'].astype(str)
        )
        
        print(f"Created {len(self.df_cleaned)} record IDs")
        print(f"Sample record_id: {self.df_cleaned['record_id'].iloc[0]}")
        
        return self.df_cleaned
    
    def generate_ndc_txt(self, output_filename=None):
        """Step 2: Generate text file with unique NDC values and their record_id keys."""
        if self.df_cleaned is None:
            raise ValueError("Must run clean_sdud_data() first")
        
        if 'record_id' not in self.df_cleaned.columns:
            raise ValueError("Must run adding_key() first to create record_id column")
            
        if output_filename is None:
            output_filename = f"NDCNEW_{self.year}.txt"
        
        output_path = os.path.join(self.base_path, f"ATC\\text_files\\{output_filename}")
        
        # Get unique combinations of NDC and record_id
        unique_pairs = self.df_cleaned[['NDC', 'record_id']].drop_duplicates()
        
        with open(output_path, 'w') as f:
            # Write header
            f.write("NDC\trecord_id\n")
            # Write each unique pair
            for _, row in unique_pairs.iterrows():
                f.write(f"{row['NDC']}\t{row['record_id']}\n")
        
        print(f"Exported to: {output_path}")
        print(f"Unique record_id values: {unique_pairs['record_id'].nunique()}")
        return output_path
    
    def analyze_atc4_mapping(self):
        """Step 3: Analyze ATC4 mapping results and identify missing NDCs."""
        atc4_path = os.path.join(self.base_path, f"ATC\\ATC4_classes\\NDCNEW_{self.year}_ATC4_classes.csv")
        
        # Read ATC4 mapping
        df_atc4 = pd.read_csv(atc4_path, dtype={'NDC': 'object', 'record_id': 'string'})
        df_atc4['NDC'] = df_atc4['NDC'].str.zfill(11)
        
        # Ensure consistent data types before merge
        self.df_cleaned['record_id'] = self.df_cleaned['record_id'].astype('string')
        self.df_cleaned['NDC'] = self.df_cleaned['NDC'].astype('object')  
        df_atc4['record_id'] = df_atc4['record_id'].astype('string')
        df_atc4['NDC'] = df_atc4['NDC'].astype('object')  

        # Merge ATC4 mapping with cleaned data using record_id
        if self.df_cleaned is None:
            raise ValueError("Must run clean_sdud_data() and adding_key() first")
        
        if 'record_id' not in self.df_cleaned.columns:
            raise ValueError("Must run adding_key() first to create record_id column")
        
        print(f"Merging ATC4 mapping with cleaned data using record_id and NDC...")
        
        # Merge on BOTH record_id AND NDC
        self.atc_mapping = pd.merge(
            self.df_cleaned,
            df_atc4[['record_id', 'NDC', 'ATC4 Class']],  # Include NDC in the selection
            on=['record_id', 'NDC'],  # Merge on both columns
            how='left'
        )
        
        # Print rows of the merged dataframe
        print(f"Merged dataframe rows: {len(self.atc_mapping)}")
        print(self.atc_mapping.head())
        total_records = len(self.atc_mapping)
        mapped_records = self.atc_mapping['ATC4 Class'].notna().sum()
        print(f"Records with ATC4 mapping: {mapped_records} ({mapped_records/total_records*100:.1f}%)")
        
        # Identify missing mappings
        missing_records = self.atc_mapping[self.atc_mapping['ATC4 Class'].isna()]
        if len(missing_records) > 0:
            print(f"\nRecords without ATC4 mapping: {len(missing_records)}")
            print(f"Unique NDCs without mapping: {missing_records['NDC'].nunique()}")
        
        return self.atc_mapping
        
    def analyze_atc4_distribution(self):

        """Analyze distribution of ATC4 classes per record_id."""
        if self.atc_mapping is None:
            raise ValueError("Must run analyze_atc4_mapping() first")
        
        print(f"\n{'='*60}")
        print("ATC4 CLASSES PER RECORD_ID DISTRIBUTION")
        print(f"{'='*60}")
        
        # Count ATC4 classes per record_id (only valid mappings)
        records_with_mapping = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()].copy()
        
        if len(records_with_mapping) == 0:
            print("No records with valid ATC4 mappings found.")
            return None
        
        # Group by record_id and count unique ATC4 classes
        atc4_per_record = records_with_mapping.groupby('record_id').agg({
            'ATC4 Class': 'nunique',
            'NDC': 'first',  # Get the NDC for reference
            'State': 'first',  # Get the state for reference
            'Year': 'first'    # Get the year for reference
        }).reset_index()
        
        atc4_per_record.columns = ['record_id', 'num_atc4_classes', 'NDC', 'State', 'Year']
        
        # Distribution analysis
        distribution = atc4_per_record['num_atc4_classes'].value_counts().sort_index()
        
        print("Distribution of ATC4 classes per record_id:")
        for classes, count in distribution.items():
            pct = (count / len(atc4_per_record)) * 100
            print(f"  {classes} class(es): {count:,} record_ids ({pct:.1f}%)")
        
        # Show examples of multi-class records
        multi_class = atc4_per_record[atc4_per_record['num_atc4_classes'] > 1].sort_values('num_atc4_classes', ascending=False)
        
        if len(multi_class) > 0:
            print(f"\nTop 10 record_ids with most ATC4 classes:")
            for _, row in multi_class.head(10).iterrows():
                record_classes = records_with_mapping[records_with_mapping['record_id'] == row['record_id']]['ATC4 Class'].unique()
                print(f"  {row['record_id']}: {row['num_atc4_classes']} classes")
                print(f"    NDC: {row['NDC']}, State: {row['State']}, Year: {row['Year']}")
                print(f"    Classes: {list(record_classes)}")
                print()
        
        return atc4_per_record

    def fetch_atc_names(self, cache_path=None):
        """Fetch ATC class names (ATC4, ATC3, ATC2) from RxNav API."""
        if self.atc_mapping is None:
            raise ValueError("Must run analyze_atc4_mapping() first")
        
        if cache_path is None:
            cache_path = os.path.join(self.base_path, "ATC\\cache_files\\atc_names_cache")
        
        print(f"\n{'='*60}")
        print("FETCHING ATC CLASS NAMES")
        print(f"{'='*60}")
        print(f"Using cache: {cache_path}")
        
        # Get only records with valid ATC4 mappings
        df_with_atc = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()].copy()
        
        # Create ATC3 and ATC2 columns from ATC4
        print("\nCreating ATC3 and ATC2 columns from ATC4...")
        df_with_atc['ATC3 Class'] = df_with_atc['ATC4 Class'].str[:4]
        df_with_atc['ATC2 Class'] = df_with_atc['ATC4 Class'].str[:3]
        
        # Get unique codes for each level
        unique_atc4 = df_with_atc['ATC4 Class'].dropna().unique()
        unique_atc3 = df_with_atc['ATC3 Class'].dropna().unique()
        unique_atc2 = df_with_atc['ATC2 Class'].dropna().unique()
        
        # Filter out invalid codes
        unique_atc4 = [c for c in unique_atc4 if c not in ['No ATC Mapping Found', 'No RxCUI Found', '']]
        unique_atc3 = [c for c in unique_atc3 if c not in ['No ATC Mapping Found', 'No RxCUI Found', '', 'No ', 'No']]
        unique_atc2 = [c for c in unique_atc2 if c not in ['No ATC Mapping Found', 'No RxCUI Found', '', 'No ', 'No']]
        
        print(f"\nUnique codes to fetch:")
        print(f"  ATC4: {len(unique_atc4)}")
        print(f"  ATC3: {len(unique_atc3)}")
        print(f"  ATC2: {len(unique_atc2)}")
        
        # Build mappings
        atc4_names = {}
        atc3_names = {}
        atc2_names = {}
        
        with shelve.open(cache_path) as cache:
            start_time = datetime.now()
            
            print("\nFetching ATC4 names...")
            for code in unique_atc4:
                atc4_names[code] = self._get_atc_name(code, cache)
            
            print("Fetching ATC3 names...")
            for code in unique_atc3:
                atc3_names[code] = self._get_atc_name(code, cache)
            
            print("Fetching ATC2 names...")
            for code in unique_atc2:
                atc2_names[code] = self._get_atc_name(code, cache)
            
            print(f"\nTotal processing time: {(datetime.now() - start_time).total_seconds()/60:.1f} minutes")
        
        # Apply names to all records in atc_mapping
        print("\nApplying names to dataframe...")
        self.atc_mapping['ATC3 Class'] = self.atc_mapping['ATC4 Class'].str[:4]
        self.atc_mapping['ATC2 Class'] = self.atc_mapping['ATC4 Class'].str[:3]
        
        self.atc_mapping['ATC4_Name'] = self.atc_mapping['ATC4 Class'].map(atc4_names).fillna('')
        self.atc_mapping['ATC3_Name'] = self.atc_mapping['ATC3 Class'].map(atc3_names).fillna('')
        self.atc_mapping['ATC2_Name'] = self.atc_mapping['ATC2 Class'].map(atc2_names).fillna('')
        
        print(f"\nATC names added successfully!")
        print("\nSample output:")
        sample = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()][['NDC', 'record_id', 'ATC4 Class', 'ATC4_Name', 'ATC3 Class', 'ATC3_Name', 'ATC2 Class', 'ATC2_Name']].head(5)
        print(sample.to_string())
        
        return self.atc_mapping
    
    def prepare_final_dataframe(self):
        """Prepare final dataframe with scaled metrics for export."""
        if self.atc_mapping is None:
            raise ValueError("Must run fetch_atc_names() first")
        
        print(f"\n{'='*60}")
        print("PREPARING FINAL DATAFRAME")
        print(f"{'='*60}")
        
        # Create a copy for final output
        self.df_merged = self.atc_mapping.copy()
        
        # Scale units
        print("\nScaling units...")
        self.df_merged['Units Reimbursed'] = self.df_merged['Units Reimbursed'] / 1e9
        self.df_merged['Number of Prescriptions'] = self.df_merged['Number of Prescriptions'] / 1e6
        
        # Report final statistics
        total_records = len(self.df_merged)
        mapped_records = self.df_merged['ATC4 Class'].notna().sum()
        
        print(f"\nFinal statistics:")
        print(f"Total records: {total_records:,}")
        print(f"Records with ATC4 mapping: {mapped_records:,} ({mapped_records/total_records*100:.1f}%)")
        print(f"Total Units Reimbursed: {self.df_merged['Units Reimbursed'].sum():.2f} Billion")
        print(f"Total Prescriptions: {self.df_merged['Number of Prescriptions'].sum():.2f} Million")
        
        return self.df_merged
    
    def _get_atc_name(self, atc_code, cache):
        """Get ATC class name from code, using cache."""
        cache_key = f"atc_name:{atc_code}"
        if cache_key in cache:
            return cache[cache_key]
        
        try:
            url = f"https://rxnav.nlm.nih.gov/REST/rxclass/class/byId.json?classId={atc_code}"
            response = requests.get(url)
            response.raise_for_status()
            data = response.json()
            
            # Get class name
            if 'rxclassMinConceptList' in data and 'rxclassMinConcept' in data['rxclassMinConceptList']:
                concepts = data['rxclassMinConceptList']['rxclassMinConcept']
                if concepts:
                    name = concepts[0].get('className', '')
                    cache[cache_key] = name
                    return name
            
            cache[cache_key] = ''
            return ''
            
        except Exception as e:
            print(f"Error retrieving name for {atc_code}: {e}")
            cache[cache_key] = ''
            return ''
    
    def analyze_atc3_distribution(self):
        """Analyze distribution of ATC3 classes per record_id."""
        if self.atc_mapping is None:
            raise ValueError("Must run analyze_atc4_mapping() first")
        
        print(f"\n{'='*60}")
        print("ATC3 CLASSES PER RECORD_ID DISTRIBUTION")
        print(f"{'='*60}")
        
        # Create ATC3 classes from ATC4
        records_with_mapping = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()].copy()
        
        if len(records_with_mapping) == 0:
            print("No records with valid ATC4 mappings found.")
            return None
        
        # Create ATC3 class from ATC4 class (first 4 characters)
        records_with_mapping['ATC3 Class'] = records_with_mapping['ATC4 Class'].str[:4]
        
        # Group by record_id and count unique ATC3 classes
        atc3_per_record = records_with_mapping.groupby('record_id').agg({
            'ATC3 Class': 'nunique',
            'NDC': 'first',
            'State': 'first',
            'Year': 'first'
        }).reset_index()
        
        atc3_per_record.columns = ['record_id', 'num_atc3_classes', 'NDC', 'State', 'Year']
        
        # Distribution analysis
        distribution = atc3_per_record['num_atc3_classes'].value_counts().sort_index()
        
        print("Distribution of ATC3 classes per record_id:")
        for classes, count in distribution.items():
            pct = (count / len(atc3_per_record)) * 100
            print(f"  {classes} class(es): {count:,} record_ids ({pct:.1f}%)")
        
        # Show examples of multi-class records
        multi_class = atc3_per_record[atc3_per_record['num_atc3_classes'] > 1].sort_values('num_atc3_classes', ascending=False)
        
        if len(multi_class) > 0:
            print(f"\nTop 10 record_ids with most ATC3 classes:")
            for _, row in multi_class.head(10).iterrows():
                record_classes = records_with_mapping[records_with_mapping['record_id'] == row['record_id']]['ATC3 Class'].unique()
                print(f"  {row['record_id']}: {row['num_atc3_classes']} classes")
                print(f"    NDC: {row['NDC']}, State: {row['State']}, Year: {row['Year']}")
                print(f"    Classes: {list(record_classes)}")
                print()
        
        # Summary statistics
        print(f"\nATC3 Summary:")
        print(f"  Total record_ids with ATC3 mapping: {len(atc3_per_record):,}")
        print(f"  Average ATC3 classes per record_id: {atc3_per_record['num_atc3_classes'].mean():.2f}")
        print(f"  Max ATC3 classes for single record_id: {atc3_per_record['num_atc3_classes'].max()}")
        
        return atc3_per_record

    def analyze_atc2_distribution(self):
        """Analyze distribution of ATC2 classes per record_id."""
        if self.atc_mapping is None:
            raise ValueError("Must run analyze_atc4_mapping() first")
        
        print(f"\n{'='*60}")
        print("ATC2 CLASSES PER RECORD_ID DISTRIBUTION")
        print(f"{'='*60}")
        
        # Create ATC2 classes from ATC4
        records_with_mapping = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()].copy()
        
        if len(records_with_mapping) == 0:
            print("No records with valid ATC4 mappings found.")
            return None
        
        # Create ATC2 class from ATC4 class (first 3 characters)
        records_with_mapping['ATC2 Class'] = records_with_mapping['ATC4 Class'].str[:3]
        
        # Group by record_id and count unique ATC2 classes
        atc2_per_record = records_with_mapping.groupby('record_id').agg({
            'ATC2 Class': 'nunique',
            'NDC': 'first',
            'State': 'first',
            'Year': 'first'
        }).reset_index()
        
        atc2_per_record.columns = ['record_id', 'num_atc2_classes', 'NDC', 'State', 'Year']
        
        # Distribution analysis
        distribution = atc2_per_record['num_atc2_classes'].value_counts().sort_index()
        
        print("Distribution of ATC2 classes per record_id:")
        for classes, count in distribution.items():
            pct = (count / len(atc2_per_record)) * 100
            print(f"  {classes} class(es): {count:,} record_ids ({pct:.1f}%)")
        
        # Show examples of multi-class records
        multi_class = atc2_per_record[atc2_per_record['num_atc2_classes'] > 1].sort_values('num_atc2_classes', ascending=False)
        
        if len(multi_class) > 0:
            print(f"\nTop 10 record_ids with most ATC2 classes:")
            for _, row in multi_class.head(10).iterrows():
                record_classes = records_with_mapping[records_with_mapping['record_id'] == row['record_id']]['ATC2 Class'].unique()
                print(f"  {row['record_id']}: {row['num_atc2_classes']} classes")
                print(f"    NDC: {row['NDC']}, State: {row['State']}, Year: {row['Year']}")
                print(f"    Classes: {list(record_classes)}")
                print()
        
        # Summary statistics
        print(f"\nATC2 Summary:")
        print(f"  Total record_ids with ATC2 mapping: {len(atc2_per_record):,}")
        print(f"  Average ATC2 classes per record_id: {atc2_per_record['num_atc2_classes'].mean():.2f}")
        print(f"  Max ATC2 classes for single record_id: {atc2_per_record['num_atc2_classes'].max()}")
        
        return atc2_per_record

    def export_merged_data(self, output_filename=None):
        """Export the final merged dataframe to CSV after removing duplicate record_ids."""
        if self.df_merged is None:
            raise ValueError("Must run prepare_final_dataframe() first to create merged dataframe")
            
        if output_filename is None:
            output_filename = f"merged_NEWdata_{self.year}.csv"

        output_path = os.path.join(self.base_path, f"ATC\\merged_data\\{output_filename}")

        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        # Check for duplicate record_ids before removal
        print(f"\n{'='*60}")
        print("CHECKING FOR DUPLICATE RECORD_IDs")
        print(f"{'='*60}")
        print(f"Total rows before deduplication: {len(self.df_merged):,}")
        
        duplicate_count = self.df_merged['record_id'].duplicated().sum()
        print(f"Duplicate record_ids found: {duplicate_count:,}")
        
        if duplicate_count > 0:
            print(f"Unique record_ids: {self.df_merged['record_id'].nunique():,}")
            
            # Show sample duplicates BEFORE deduplication
            duplicated_records = self.df_merged[self.df_merged['record_id'].duplicated(keep=False)].sort_values('record_id')
            print(f"\nSample duplicate record_ids BEFORE deduplication (first 10 rows):")
            print(duplicated_records[['record_id', 'NDC', 'State', 'ATC4 Class']].head(10))
            
            # Print two specific record_id values before deduplication
            print(f"\n{'='*60}")
            print("DETAILED VIEW: Two record_ids BEFORE deduplication")
            print(f"{'='*60}")
            sample_record_ids = duplicated_records['record_id'].unique()[:2]
            for rid in sample_record_ids:
                print(f"\nrecord_id: {rid}")
                sample_rows = self.df_merged[self.df_merged['record_id'] == rid]
                print(sample_rows[['record_id', 'NDC', 'State', 'Year', 'Quarter', 'ATC4 Class', 'ATC3 Class','ATC2 Class', 'Units Reimbursed', 'Number of Prescriptions']].to_string(index=False))
        
        # Remove duplicates, keeping first occurrence
        df_deduplicated = self.df_merged.drop_duplicates(subset='record_id', keep='first')
        
        print(f"\nTotal rows after deduplication: {len(df_deduplicated):,}")
        print(f"Rows removed: {len(self.df_merged) - len(df_deduplicated):,}")
        
        # Show sample of data AFTER deduplication
        print(f"\nSample of deduplicated data (first 10 rows):")
        print(df_deduplicated[['record_id', 'NDC', 'State', 'ATC4 Class']].head(10))
        
        # Print the same two record_id values after deduplication
        if duplicate_count > 0:
            print(f"\n{'='*60}")
            print("DETAILED VIEW: Same two record_ids AFTER deduplication")
            print(f"{'='*60}")
            for rid in sample_record_ids:
                print(f"\nrecord_id: {rid}")
                sample_rows = df_deduplicated[df_deduplicated['record_id'] == rid]
                print(sample_rows[['record_id', 'NDC', 'State', 'Year', 'Quarter', 'ATC4 Class', 'ATC3 Class','ATC2 Class','Units Reimbursed', 'Number of Prescriptions']].to_string(index=False))

        # Export deduplicated dataframe
        df_deduplicated.to_csv(output_path, index=False)
        
        print(f"\n{'='*60}")
        print("DATA EXPORT COMPLETE")
        print(f"{'='*60}")
        print(f"Exported to: {output_path}")
        print(f"Total rows exported: {len(df_deduplicated):,}")
        print(f"Columns: {', '.join(df_deduplicated.columns.tolist())}")
        
        # Compare with analyze_atc4_mapping output
        print(f"\n{'='*60}")
        print("COMPARISON WITH analyze_atc4_mapping()")
        print(f"{'='*60}")
        if self.atc_mapping is not None:
            print(f"Rows in atc_mapping (from analyze_atc4_mapping): {len(self.atc_mapping):,}")
            print(f"Rows in df_merged (before deduplication): {len(self.df_merged):,}")
            print(f"Rows in exported file (after deduplication): {len(df_deduplicated):,}")
            print(f"Difference from atc_mapping: {len(df_deduplicated) - len(self.atc_mapping):,}")
        else:
            print("atc_mapping not available for comparison")
        
        return output_path

    def analyze_different_atc2_records(self):
        """Find and display records that have different ATC2 classes with names."""
        if self.atc_mapping is None:
            raise ValueError("Must run analyze_atc4_mapping() first")
        
        print("RECORDS WITH DIFFERENT ATC2 CLASSES")
        print("="*40)
        
        # Get records and create ATC2
        records = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()].copy()
        records['ATC2 Class'] = records['ATC4 Class'].str[:3]
        
        # Find records with multiple ATC2 classes
        multi_atc2 = records.groupby('record_id')['ATC2 Class'].nunique()
        different_records = multi_atc2[multi_atc2 > 1].index.tolist()
        
        if len(different_records) == 0:
            print("No records found with different ATC2 classes.")
            return None
        
        print(f"Found {len(different_records)} records with different ATC2 classes\n")
        
        # Show first 3 examples
        has_names = 'ATC2_Name' in records.columns
        for i, record_id in enumerate(different_records[:3]):
            record_data = records[records['record_id'] == record_id]
            atc_info = record_data[['ATC2 Class', 'ATC2_Name']].drop_duplicates() if has_names else record_data[['ATC2 Class']].drop_duplicates()
            
            print(f"{i+1}. {record_id} | NDC: {record_data['NDC'].iloc[0]}")
            for _, row in atc_info.iterrows():
                if has_names:
                    print(f"   ATC2: {row['ATC2 Class']} - {row['ATC2_Name'][:40]}...")
                else:
                    print(f"   ATC2: {row['ATC2 Class']}")
            print()
        
        # Quick summary
        print(f"Summary: {multi_atc2[multi_atc2 > 1].value_counts().to_dict()}")
        
        return records[records['record_id'].isin(different_records)]

In [18]:
analyzer = NDCATCAnalyzer(year=2022)
analyzer.clean_sdud_data()           # Clean SDUD data
analyzer.adding_key()                # Add record_id key
analyzer.generate_ndc_txt()          # Generate NDC text file
analyzer.analyze_atc4_mapping()      # Merge ATC4 by record_id & NDC

Reading CSV file: c:\Users\lholguin\OneDrive - purdue.edu\VS code\Data\SDUD\SDUD2022.csv
Total rows in 2022 before filtering: 5164804
Rows after removing NA: 2621949
Rows after filtering State='XX': 2389418
Unique NDCs: 33005
Adding record_id column...
Created 2389418 record IDs
Sample record_id: AK_2022_4_FFSU_00002143380
Exported to: c:\Users\lholguin\OneDrive - purdue.edu\VS code\Data\ATC\text_files\NDCNEW_2022.txt
Unique record_id values: 2389418
Merging ATC4 mapping with cleaned data using record_id and NDC...
Merged dataframe rows: 4421713
  Utilization Type State          NDC  Labeler Code  Product Code  \
0             FFSU    AK  00002143380             2          1433   
1             FFSU    AK  00002143480             2          1434   
2             FFSU    AK  00002143611             2          1436   
3             FFSU    AK  00002144511             2          1445   
4             FFSU    AK  00002147180             2          1471   

   Package Size  Year  Quarter  S

Unnamed: 0,Utilization Type,State,NDC,Labeler Code,Product Code,Package Size,Year,Quarter,Suppression Used,Product Name,Units Reimbursed,Number of Prescriptions,Total Amount Reimbursed,Medicaid Amount Reimbursed,Non Medicaid Amount Reimbursed,record_id,ATC4 Class
0,FFSU,AK,00002143380,2,1433,80,2022,4,False,TRULICITY,473.0,198.0,201743.95,195095.41,6648.54,AK_2022_4_FFSU_00002143380,A10BJ
1,FFSU,AK,00002143480,2,1434,80,2022,4,False,TRULICITY,646.0,231.0,275657.31,267347.63,8309.68,AK_2022_4_FFSU_00002143480,A10BJ
2,FFSU,AK,00002143611,2,1436,11,2022,4,False,EMGALITY P,26.0,25.0,16642.78,16642.78,0.00,AK_2022_4_FFSU_00002143611,N02CD
3,FFSU,AK,00002144511,2,1445,11,2022,4,False,TALTZ AUTO,26.0,23.0,164678.63,164678.63,0.00,AK_2022_4_FFSU_00002144511,L04AC
4,FFSU,AK,00002147180,2,1471,80,2022,4,False,MOUNJARO,26.0,12.0,12214.55,10397.54,1817.01,AK_2022_4_FFSU_00002147180,A10BX
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4421708,FFSU,WY,78206012701,78206,127,1,2022,1,False,DULERA 100,546.0,38.0,13077.27,13077.27,0.00,WY_2022_1_FFSU_78206012701,R03CC
4421709,FFSU,WY,78206012701,78206,127,1,2022,1,False,DULERA 100,546.0,38.0,13077.27,13077.27,0.00,WY_2022_1_FFSU_78206012701,R03AK
4421710,FFSU,WY,78206012701,78206,127,1,2022,1,False,DULERA 100,546.0,38.0,13077.27,13077.27,0.00,WY_2022_1_FFSU_78206012701,R01AD
4421711,FFSU,WY,78206012701,78206,127,1,2022,1,False,DULERA 100,546.0,38.0,13077.27,13077.27,0.00,WY_2022_1_FFSU_78206012701,D07AC


In [None]:
# Distribution analyses
atc4_dist = analyzer.analyze_atc4_distribution()
atc3_dist = analyzer.analyze_atc3_distribution() 
atc2_dist = analyzer.analyze_atc2_distribution()


ATC4 CLASSES PER RECORD_ID DISTRIBUTION
Distribution of ATC4 classes per record_id:
  1 class(es): 1,614,669 record_ids (68.6%)
  2 class(es): 285,692 record_ids (12.1%)
  3 class(es): 226,981 record_ids (9.6%)
  4 class(es): 62,089 record_ids (2.6%)
  5 class(es): 50,428 record_ids (2.1%)
  6 class(es): 14,941 record_ids (0.6%)
  7 class(es): 25,751 record_ids (1.1%)
  8 class(es): 23,149 record_ids (1.0%)
  9 class(es): 14,940 record_ids (0.6%)
  10 class(es): 719 record_ids (0.0%)
  11 class(es): 22,722 record_ids (1.0%)
  12 class(es): 2,300 record_ids (0.1%)
  13 class(es): 680 record_ids (0.0%)
  14 class(es): 3,738 record_ids (0.2%)
  15 class(es): 519 record_ids (0.0%)
  16 class(es): 884 record_ids (0.0%)
  17 class(es): 388 record_ids (0.0%)
  20 class(es): 1,475 record_ids (0.1%)
  21 class(es): 52 record_ids (0.0%)
  22 class(es): 1,064 record_ids (0.0%)

Top 10 record_ids with most ATC4 classes:
  IA_2022_3_MCOU_24208083060: 22 classes
    NDC: 24208083060, State: IA, Yea

In [22]:
analyzer.fetch_atc_names()           # Fetch ATC4, ATC3, ATC2 names
analyzer.prepare_final_dataframe()   # Scale units and finalize
#Just to see something
different_atc2_records = analyzer.analyze_different_atc2_records()
analyzer.export_merged_data()  


FETCHING ATC CLASS NAMES
Using cache: c:\Users\lholguin\OneDrive - purdue.edu\VS code\Data\ATC\cache_files\atc_names_cache

Creating ATC3 and ATC2 columns from ATC4...

Unique codes to fetch:
  ATC4: 606
  ATC3: 209
  ATC2: 89

Fetching ATC4 names...
Fetching ATC3 names...
Fetching ATC2 names...

Total processing time: 0.0 minutes

Applying names to dataframe...

ATC names added successfully!

Sample output:
           NDC                   record_id ATC4 Class                                           ATC4_Name ATC3 Class                                     ATC3_Name ATC2 Class               ATC2_Name
0  00002143380  AK_2022_4_FFSU_00002143380      A10BJ           Glucagon-like peptide-1 (GLP-1) analogues       A10B  BLOOD GLUCOSE LOWERING DRUGS, EXCL. INSULINS        A10  DRUGS USED IN DIABETES
1  00002143480  AK_2022_4_FFSU_00002143480      A10BJ           Glucagon-like peptide-1 (GLP-1) analogues       A10B  BLOOD GLUCOSE LOWERING DRUGS, EXCL. INSULINS        A10  DRUGS USED IN DI

'c:\\Users\\lholguin\\OneDrive - purdue.edu\\VS code\\Data\\ATC\\merged_data\\merged_NEWdata_2022.csv'

In [21]:
#Looking at the new key identifier
path=r"C:\Users\lholguin\OneDrive - purdue.edu\VS code\Data\ATC\ATC4_classes\NDCNEW_2023_ATC4_classes.csv"
df_atcnew=pd.read_csv(path)
print(df_atcnew.head(10))
print(df_atcnew.columns)

           NDC                   record_id ATC4 Class
0  63323010601  WV_2023_1_MCOU_63323010601      A06AD
1  65862016901  NC_2023_2_MCOU_65862016901      C07AB
2  51672206902  MN_2023_1_MCOU_51672206902      S02BA
3  65162083594  CT_2023_3_FFSU_65162083594      J05AB
4  16714098502  IA_2023_4_MCOU_16714098502      D07AB
5    409653311  VA_2023_1_MCOU_00409653311      A07AA
6  65162008203  NV_2023_3_MCOU_65162008203      N05AE
7  59651003212  MA_2023_3_FFSU_59651003212      M02AA
8  67877025130  NY_2023_3_FFSU_67877025130      D07AB
9  42858011830  SD_2023_3_FFSU_42858011830      C01BB
Index(['NDC', 'record_id', 'ATC4 Class'], dtype='object')


In [None]:
def compare_atc_distributions_across_years(self, years_list):
    """Compare ATC distribution metrics across multiple years."""
    
    print(f"\n{'='*60}")
    print("ATC DISTRIBUTION COMPARISON ACROSS YEARS")
    print(f"{'='*60}")
    
    results_summary = {}
    
    for year in years_list:
        print(f"\nProcessing year {year}...")
        
        try:
            # Create analyzer for this year
            year_analyzer = NDCATCAnalyzer(year=year)
            year_analyzer.clean_sdud_data()
            year_analyzer.adding_key()
            year_analyzer.analyze_atc4_mapping()
            
            # Get distribution data
            atc4_dist = year_analyzer.analyze_atc4_distribution()
            atc3_dist = year_analyzer.analyze_atc3_distribution() 
            atc2_dist = year_analyzer.analyze_atc2_distribution()
            
            # Extract key metrics
            results_summary[year] = {
                'total_records': len(year_analyzer.atc_mapping[year_analyzer.atc_mapping['ATC4 Class'].notna()]),
                'unique_record_ids': len(atc4_dist) if atc4_dist is not None else 0,
                'atc4_avg': atc4_dist['num_atc4_classes'].mean() if atc4_dist is not None else 0,
                'atc4_max': atc4_dist['num_atc4_classes'].max() if atc4_dist is not None else 0,
                'atc3_avg': atc3_dist['num_atc3_classes'].mean() if atc3_dist is not None else 0,
                'atc3_max': atc3_dist['num_atc3_classes'].max() if atc3_dist is not None else 0,
                'atc2_avg': atc2_dist['num_atc2_classes'].mean() if atc2_dist is not None else 0,
                'atc2_max': atc2_dist['num_atc2_classes'].max() if atc2_dist is not None else 0,
            }
            
        except Exception as e:
            print(f"Error processing year {year}: {e}")
            results_summary[year] = None
    
    # Display comparison table
    print(f"\n{'='*80}")
    print("SUMMARY COMPARISON TABLE")
    print(f"{'='*80}")
    
    import pandas as pd
    comparison_df = pd.DataFrame(results_summary).T
    comparison_df = comparison_df.dropna()
    
    print("\nKey Metrics Across Years:")
    print(comparison_df.round(2))
    
    # Show trends
    if len(comparison_df) > 1:
        print(f"\nTrends:")
        print(f"Total Records: {comparison_df['total_records'].iloc[0]:,.0f} → {comparison_df['total_records'].iloc[-1]:,.0f}")
        print(f"Avg ATC4/record: {comparison_df['atc4_avg'].iloc[0]:.2f} → {comparison_df['atc4_avg'].iloc[-1]:.2f}")
        print(f"Avg ATC3/record: {comparison_df['atc3_avg'].iloc[0]:.2f} → {comparison_df['atc3_avg'].iloc[-1]:.2f}")
        print(f"Avg ATC2/record: {comparison_df['atc2_avg'].iloc[0]:.2f} → {comparison_df['atc2_avg'].iloc[-1]:.2f}")
    
    return comparison_df

def quick_atc_metrics_only(self, years_list):
    """Get only the ATC distribution metrics for multiple years (faster version)."""
    
    print("ATC DISTRIBUTION METRICS ACROSS YEARS")
    print("="*50)
    
    metrics = {}
    
    for year in years_list:
        print(f"Year {year}:", end=" ")
        
        try:
            # Quick setup
            analyzer = NDCATCAnalyzer(year=year)
            analyzer.clean_sdud_data()
            analyzer.adding_key()
            analyzer.analyze_atc4_mapping()
            
            # Get only the metrics we need
            records = analyzer.atc_mapping[analyzer.atc_mapping['ATC4 Class'].notna()].copy()
            records['ATC2 Class'] = records['ATC4 Class'].str[:3]
            records['ATC3 Class'] = records['ATC4 Class'].str[:4]
            
            # Quick calculations
            atc4_per_record = records.groupby('record_id')['ATC4 Class'].nunique()
            atc3_per_record = records.groupby('record_id')['ATC3 Class'].nunique()
            atc2_per_record = records.groupby('record_id')['ATC2 Class'].nunique()
            
            metrics[year] = {
                'Records': len(records),
                'Unique_IDs': len(atc4_per_record),
                'ATC4_Avg': atc4_per_record.mean(),
                'ATC4_Max': atc4_per_record.max(),
                'ATC3_Avg': atc3_per_record.mean(),
                'ATC3_Max': atc3_per_record.max(),
                'ATC2_Avg': atc2_per_record.mean(),
                'ATC2_Max': atc2_per_record.max(),
            }
            
            print("✓")
            
        except Exception as e:
            print(f"✗ ({e})")
            metrics[year] = None
    
    # Display results
    print(f"\nRESULTS:")
    import pandas as pd
    results_df = pd.DataFrame(metrics).T.dropna()
    print(results_df.round(2))
    
    return results_df

In [None]:
# Option 1: Full analysis for each year (slower but complete)
years_to_analyze = [2021, 2020, 2021, 2022]
comparison_results = analyzer.compare_atc_distributions_across_years(years_to_analyze)

# Option 2: Quick metrics only (faster)
quick_results = analyzer.quick_atc_metrics_only([2019, 2020, 2021, 2022])

# Option 3: Just call one year at a time and collect metrics manually
years_metrics = {}
for year in [2019, 2020, 2021]:
    analyzer = NDCATCAnalyzer(year=year)
    analyzer.clean_sdud_data()
    analyzer.adding_key()
    analyzer.analyze_atc4_mapping()
    
    # Get just the distribution summaries
    atc4_dist = analyzer.analyze_atc4_distribution()
    atc3_dist = analyzer.analyze_atc3_distribution() 
    atc2_dist = analyzer.analyze_atc2_distribution()
    
    years_metrics[year] = {
        'atc4_avg': atc4_dist['num_atc4_classes'].mean(),
        'atc3_avg': atc3_dist['num_atc3_classes'].mean(),
        'atc2_avg': atc2_dist['num_atc2_classes'].mean()
    }

print(years_metrics)