In [2]:
import pandas as pd
import numpy as np
import requests
import shelve
import os
from datetime import datetime
import matplotlib.pyplot as plt
import re
import ipywidgets as widgets
from IPython.display import display, clear_output
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [7]:
user="Lilian"
#user in personal pc <- "asus"

In [None]:
#New changes to the class
class NDCATCAnalyzer:

    def __init__(self, year, base_path=None):

        self.year = year
        if base_path is None:
            #Lookup the user's base path
            self.base_path = rf"c:\Users\{user}\OneDrive - purdue.edu\VS code\Data"
        else:
            self.base_path = base_path
            
        self.df_cleaned = None
        self.df_merged = None
        self.atc_mapping = None
        self.df_faf = None
        
    def clean_sdud_data(self):
        csv_file = os.path.join(self.base_path, f"SDUD\\SDUD{self.year}.csv")
        print(f"Reading CSV file: {csv_file}")
        
        # Read with NDC as string to preserve leading zeros
        df = pd.read_csv(csv_file, dtype={'NDC': 'object'})
        
        print(f"Total rows in {self.year} before filtering: {len(df)}")
        
        # Remove NA values
        df_filtered = df.dropna(subset=['Units Reimbursed', 'Number of Prescriptions'])
        print(f"Rows after removing NA: {len(df_filtered)}")
        
        # Filter out State='XX'
        df_filtered = df_filtered[df_filtered['State'] != 'XX']
        print(f"Rows after filtering State='XX': {len(df_filtered)}")
        print(f"Unique NDCs: {df_filtered['NDC'].nunique()}")
        
        self.df_cleaned = df_filtered
        return self.df_cleaned
    
    #NEW
    def adding_key(self):
        """Add record_id column to cleaned dataframe."""
        if self.df_cleaned is None:
            raise ValueError("Must run clean_sdud_data() first")
        
        print("Adding record_id column...")
        
        # Create record_id column
        self.df_cleaned['record_id'] = (
            self.df_cleaned['State'].astype(str) + "_" +
            self.df_cleaned['Year'].astype(str) + "_" +
            self.df_cleaned['Quarter'].astype(str) + "_" +
            self.df_cleaned['Utilization Type'].astype(str) + "_" +
            self.df_cleaned['NDC'].astype(str)
        )
        
        print(f"Created {len(self.df_cleaned)} record IDs")
        print(f"Sample record_id: {self.df_cleaned['record_id'].iloc[0]}")
        
        return self.df_cleaned
    
    def generate_ndc_txt(self, output_filename=None):
        """Step 2: Generate text file with unique NDC values and their record_id keys."""
        if self.df_cleaned is None:
            raise ValueError("Must run clean_sdud_data() first")
        
        if 'record_id' not in self.df_cleaned.columns:
            raise ValueError("Must run adding_key() first to create record_id column")
            
        if output_filename is None:
            output_filename = f"NDCNEW_{self.year}.txt"
        
        output_path = os.path.join(self.base_path, f"ATC\\text_files\\{output_filename}")
        
        # Get unique combinations of NDC and record_id
        unique_pairs = self.df_cleaned[['NDC', 'record_id']].drop_duplicates()
        
        with open(output_path, 'w') as f:
            # Write header
            f.write("NDC\trecord_id\n")
            # Write each unique pair
            for _, row in unique_pairs.iterrows():
                f.write(f"{row['NDC']}\t{row['record_id']}\n")
        
        print(f"Exported to: {output_path}")
        print(f"Unique record_id values: {unique_pairs['record_id'].nunique()}")
        return output_path
    
    def analyze_atc4_mapping(self):
        """Step 3: Analyze ATC4 mapping results and identify missing NDCs."""
        atc4_path = os.path.join(self.base_path, f"ATC\\ATC4_classes\\NDCNEW_{self.year}_ATC4_classes.csv")
        
        # Read ATC4 mapping
        df_atc4 = pd.read_csv(atc4_path, dtype={'NDC': 'object', 'record_id': 'string'})
        df_atc4['NDC'] = df_atc4['NDC'].str.zfill(11)
        
        # Ensure consistent data types before merge
        self.df_cleaned['record_id'] = self.df_cleaned['record_id'].astype('string')
        self.df_cleaned['NDC'] = self.df_cleaned['NDC'].astype('object')  
        df_atc4['record_id'] = df_atc4['record_id'].astype('string')
        df_atc4['NDC'] = df_atc4['NDC'].astype('object')  

        # Merge ATC4 mapping with cleaned data using record_id
        if self.df_cleaned is None:
            raise ValueError("Must run clean_sdud_data() and adding_key() first")
        
        if 'record_id' not in self.df_cleaned.columns:
            raise ValueError("Must run adding_key() first to create record_id column")
        
        print(f"Merging ATC4 mapping with cleaned data using record_id and NDC...")
        
        # Merge on BOTH record_id AND NDC
        self.atc_mapping = pd.merge(
            self.df_cleaned,
            df_atc4[['record_id', 'NDC', 'ATC4 Class']],  # Include NDC in the selection
            on=['record_id', 'NDC'],  # Merge on both columns
            how='left'
        )
        
        # Print rows of the merged dataframe
        print(f"Merged dataframe rows: {len(self.atc_mapping)}")
        print(self.atc_mapping.head())
        total_records = len(self.atc_mapping)
        mapped_records = self.atc_mapping['ATC4 Class'].notna().sum()
        print(f"Records with ATC4 mapping: {mapped_records} ({mapped_records/total_records*100:.1f}%)")
        
        # Identify missing mappings
        missing_records = self.atc_mapping[self.atc_mapping['ATC4 Class'].isna()]
        if len(missing_records) > 0:
            print(f"\nRecords without ATC4 mapping: {len(missing_records)}")
            print(f"Unique NDCs without mapping: {missing_records['NDC'].nunique()}")
        
        # Export the atc_mapping
        output_path = os.path.join(self.base_path, f"ATC\\merged_data\\atc_mapping_{self.year}.csv")
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        self.atc_mapping.to_csv(output_path, index=False)
        print(f"\nExported atc_mapping to: {output_path}")
        print(f"Total rows exported: {len(self.atc_mapping)}")
        
        return self.atc_mapping
        
    def analyze_atc4_distribution(self):

        """Analyze distribution of ATC4 classes per record_id."""
        if self.atc_mapping is None:
            raise ValueError("Must run analyze_atc4_mapping() first")
        
        print(f"\n{'='*60}")
        print("ATC4 CLASSES PER RECORD_ID DISTRIBUTION")
        print(f"{'='*60}")
        
        # Count ATC4 classes per record_id (only valid mappings)
        records_with_mapping = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()].copy()
        
        if len(records_with_mapping) == 0:
            print("No records with valid ATC4 mappings found.")
            return None
        
        # Group by record_id and count unique ATC4 classes
        atc4_per_record = records_with_mapping.groupby('record_id').agg({
            'ATC4 Class': 'nunique',
            'NDC': 'first',  # Get the NDC for reference
            'State': 'first',  # Get the state for reference
            'Year': 'first'    # Get the year for reference
        }).reset_index()
        
        atc4_per_record.columns = ['record_id', 'num_atc4_classes', 'NDC', 'State', 'Year']
        
        # Distribution analysis
        distribution = atc4_per_record['num_atc4_classes'].value_counts().sort_index()
        
        print("Distribution of ATC4 classes per record_id:")
        for classes, count in distribution.items():
            pct = (count / len(atc4_per_record)) * 100
            print(f"  {classes} class(es): {count:,} record_ids ({pct:.1f}%)")
        
        # Show examples of multi-class records
        multi_class = atc4_per_record[atc4_per_record['num_atc4_classes'] > 1].sort_values('num_atc4_classes', ascending=False)
        
        if len(multi_class) > 0:
            print(f"\nTop 10 record_ids with most ATC4 classes:")
            for _, row in multi_class.head(10).iterrows():
                record_classes = records_with_mapping[records_with_mapping['record_id'] == row['record_id']]['ATC4 Class'].unique()
                print(f"  {row['record_id']}: {row['num_atc4_classes']} classes")
                print(f"    NDC: {row['NDC']}, State: {row['State']}, Year: {row['Year']}")
                print(f"    Classes: {list(record_classes)}")
                print()
        
        return atc4_per_record

    def fetch_atc_names(self, cache_path=None):
        """Fetch ATC class names (ATC4, ATC3, ATC2) from RxNav API."""
        if self.atc_mapping is None:
            raise ValueError("Must run analyze_atc4_mapping() first")
        
        if cache_path is None:
            cache_path = os.path.join(self.base_path, "ATC\\cache_files\\atc_names_cache")
        
        print(f"\n{'='*60}")
        print("FETCHING ATC CLASS NAMES")
        print(f"{'='*60}")
        print(f"Using cache: {cache_path}")
        
        # Get only records with valid ATC4 mappings
        df_with_atc = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()].copy()
        
        # Create ATC3 and ATC2 columns from ATC4
        print("\nCreating ATC3 and ATC2 columns from ATC4...")
        df_with_atc['ATC3 Class'] = df_with_atc['ATC4 Class'].str[:4]
        df_with_atc['ATC2 Class'] = df_with_atc['ATC4 Class'].str[:3]
        
        # Get unique codes for each level
        unique_atc4 = df_with_atc['ATC4 Class'].dropna().unique()
        unique_atc3 = df_with_atc['ATC3 Class'].dropna().unique()
        unique_atc2 = df_with_atc['ATC2 Class'].dropna().unique()
        
        # Filter out invalid codes
        unique_atc4 = [c for c in unique_atc4 if c not in ['No ATC Mapping Found', 'No RxCUI Found', '']]
        unique_atc3 = [c for c in unique_atc3 if c not in ['No ATC Mapping Found', 'No RxCUI Found', '', 'No ', 'No']]
        unique_atc2 = [c for c in unique_atc2 if c not in ['No ATC Mapping Found', 'No RxCUI Found', '', 'No ', 'No']]
        
        print(f"\nUnique codes to fetch:")
        print(f"  ATC4: {len(unique_atc4)}")
        print(f"  ATC3: {len(unique_atc3)}")
        print(f"  ATC2: {len(unique_atc2)}")
        
        # Build mappings
        atc4_names = {}
        atc3_names = {}
        atc2_names = {}
        
        with shelve.open(cache_path) as cache:
            start_time = datetime.now()
            
            print("\nFetching ATC4 names...")
            for code in unique_atc4:
                atc4_names[code] = self._get_atc_name(code, cache)
            
            print("Fetching ATC3 names...")
            for code in unique_atc3:
                atc3_names[code] = self._get_atc_name(code, cache)
            
            print("Fetching ATC2 names...")
            for code in unique_atc2:
                atc2_names[code] = self._get_atc_name(code, cache)
            
            print(f"\nTotal processing time: {(datetime.now() - start_time).total_seconds()/60:.1f} minutes")
        
        # Apply names to all records in atc_mapping
        print("\nApplying names to dataframe...")
        self.atc_mapping['ATC3 Class'] = self.atc_mapping['ATC4 Class'].str[:4]
        self.atc_mapping['ATC2 Class'] = self.atc_mapping['ATC4 Class'].str[:3]
        
        self.atc_mapping['ATC4_Name'] = self.atc_mapping['ATC4 Class'].map(atc4_names).fillna('')
        self.atc_mapping['ATC3_Name'] = self.atc_mapping['ATC3 Class'].map(atc3_names).fillna('')
        self.atc_mapping['ATC2_Name'] = self.atc_mapping['ATC2 Class'].map(atc2_names).fillna('')
        
        print(f"\nATC names added successfully!")
        print("\nSample output:")
        sample = self.atc_mapping[self.atc_mapping['ATC4 Class'].notna()][['NDC', 'record_id', 'ATC4 Class', 'ATC4_Name', 'ATC3 Class', 'ATC3_Name', 'ATC2 Class', 'ATC2_Name']].head(5)
        print(sample.to_string())
        
        return self.atc_mapping
    
    def prepare_final_dataframe(self):
        """Prepare final dataframe with scaled metrics for export."""
        if self.atc_mapping is None:
            raise ValueError("Must run fetch_atc_names() first")
        
        print(f"\n{'='*60}")
        print("PREPARING FINAL DATAFRAME")
        print(f"{'='*60}")
        
        # Create a copy for final output
        self.df_merged = self.atc_mapping.copy()
        
        # Scale units
        print("\nScaling units...")
        self.df_merged['Units Reimbursed'] = self.df_merged['Units Reimbursed'] / 1e9
        self.df_merged['Number of Prescriptions'] = self.df_merged['Number of Prescriptions'] / 1e6
        
        # Report final statistics
        total_records = len(self.df_merged)
        mapped_records = self.df_merged['ATC4 Class'].notna().sum()
        
        print(f"\nFinal statistics:")
        print(f"Total records: {total_records:,}")
        print(f"Records with ATC4 mapping: {mapped_records:,} ({mapped_records/total_records*100:.1f}%)")
        print(f"Total Units Reimbursed: {self.df_merged['Units Reimbursed'].sum():.2f} Billion")
        print(f"Total Prescriptions: {self.df_merged['Number of Prescriptions'].sum():.2f} Million")
        
        return self.df_merged
    
    def _get_atc_name(self, atc_code, cache):
        """Get ATC class name from code, using cache."""
        cache_key = f"atc_name:{atc_code}"
        if cache_key in cache:
            return cache[cache_key]
        
        try:
            url = f"https://rxnav.nlm.nih.gov/REST/rxclass/class/byId.json?classId={atc_code}"
            response = requests.get(url)
            response.raise_for_status()
            data = response.json()
            
            # Get class name
            if 'rxclassMinConceptList' in data and 'rxclassMinConcept' in data['rxclassMinConceptList']:
                concepts = data['rxclassMinConceptList']['rxclassMinConcept']
                if concepts:
                    name = concepts[0].get('className', '')
                    cache[cache_key] = name
                    return name
            
            cache[cache_key] = ''
            return ''
            
        except Exception as e:
            print(f"Error retrieving name for {atc_code}: {e}")
            cache[cache_key] = ''
            return ''
    
    def export_merged_data(self, output_filename=None):
        """Export the final merged dataframe to CSV."""
        if self.df_merged is None:
            raise ValueError("Must run prepare_final_dataframe() first to create merged dataframe")
            
        if output_filename is None:
            output_filename = f"merged_NEWdata_{self.year}.csv"

        output_path = os.path.join(self.base_path, f"ATC\\merged_data\\{output_filename}")

        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        self.df_merged.to_csv(output_path, index=False)
        
        print(f"\n{'='*60}")
        print("DATA EXPORT COMPLETE")
        print(f"{'='*60}")
        print(f"Exported to: {output_path}")
        print(f"Total rows exported: {len(self.df_merged):,}")
        print(f"Columns: {', '.join(self.df_merged.columns.tolist())}")
        
        return output_path


In [12]:
analyzer = NDCATCAnalyzer(year=2023)
analyzer.clean_sdud_data()           # Clean SDUD data
analyzer.adding_key()                # Add record_id key
analyzer.generate_ndc_txt()          # Generate NDC text file
analyzer.analyze_atc4_mapping()      # Merge ATC4 by record_id & NDC
analyzer.analyze_atc4_distribution() # Analyze distribution
analyzer.fetch_atc_names()           # Fetch ATC4, ATC3, ATC2 names
analyzer.prepare_final_dataframe()   # Scale units and finalize
analyzer.export_merged_data()  

Reading CSV file: c:\Users\Lilian\OneDrive - purdue.edu\VS code\Data\SDUD\SDUD2023.csv
Total rows in 2023 before filtering: 5277298
Total rows in 2023 before filtering: 5277298
Rows after removing NA: 2651527
Rows after filtering State='XX': 2413521
Rows after removing NA: 2651527
Rows after filtering State='XX': 2413521
Unique NDCs: 34439
Adding record_id column...
Unique NDCs: 34439
Adding record_id column...
Created 2413521 record IDs
Sample record_id: AK_2023_4_FFSU_00002143380
Created 2413521 record IDs
Sample record_id: AK_2023_4_FFSU_00002143380
Exported to: c:\Users\Lilian\OneDrive - purdue.edu\VS code\Data\ATC\text_files\NDCNEW_2023.txt
Exported to: c:\Users\Lilian\OneDrive - purdue.edu\VS code\Data\ATC\text_files\NDCNEW_2023.txt
Unique record_id values: 2413521
Unique record_id values: 2413521
Merging ATC4 mapping with cleaned data using record_id and NDC...
Merging ATC4 mapping with cleaned data using record_id and NDC...
Merged dataframe rows: 4454193
  Utilization Type Sta

'c:\\Users\\Lilian\\OneDrive - purdue.edu\\VS code\\Data\\ATC\\merged_data\\merged_NEWdata_2023.csv'

In [55]:
#Looking at the new key identifier
path=r"C:\Users\lholguin\OneDrive - purdue.edu\VS code\Data\ATC\ATC4_classes\NDCNEW_2023_ATC4_classes.csv"
df_atcnew=pd.read_csv(path)
print(df_atcnew.head(10))
print(df_atcnew.columns)

           NDC                   record_id ATC4 Class
0  63323010601  WV_2023_1_MCOU_63323010601      A06AD
1  65862016901  NC_2023_2_MCOU_65862016901      C07AB
2  51672206902  MN_2023_1_MCOU_51672206902      S02BA
3  65162083594  CT_2023_3_FFSU_65162083594      J05AB
4  16714098502  IA_2023_4_MCOU_16714098502      D07AB
5    409653311  VA_2023_1_MCOU_00409653311      A07AA
6  65162008203  NV_2023_3_MCOU_65162008203      N05AE
7  59651003212  MA_2023_3_FFSU_59651003212      M02AA
8  67877025130  NY_2023_3_FFSU_67877025130      D07AB
9  42858011830  SD_2023_3_FFSU_42858011830      C01BB
Index(['NDC', 'record_id', 'ATC4 Class'], dtype='object')


In [None]:
# Check total sums
total_units = analyzer.df_merged['Units Reimbursed'].sum()
total_prescriptions = analyzer.df_merged['Number of Prescriptions'].sum()

print(f"\nTotal Statistics:")
print(f"Total Units Reimbursed: {total_units:.4f} Billion")
print(f"Total Number of Prescriptions: {total_prescriptions:.4f} Million")
print(f"Total rows in dataset: {len(analyzer.df_merged):,}")

In [None]:
# Plot Year vs Top 5 ATC Classes for a Single State
if 'analyzer' in globals() and hasattr(analyzer, 'df_merged') and analyzer.df_merged is not None:
    # Select state to analyze (change this to any state you want)
    target_state = 'IN'  # Indiana - change to any state code
    
    # Filter data for the target state
    state_data = analyzer.df_merged[analyzer.df_merged['State'] == target_state].copy()
    
    if len(state_data) > 0:
        print(f"Analyzing state: {target_state}")
        print(f"Total records for {target_state}: {len(state_data):,}")
        
        # Group by Quarter (year) and ATC4 Class, sum prescriptions
        quarterly_atc = state_data.groupby(['Quarter', 'ATC4 Class']).agg({
            'Number of Prescriptions': 'sum',
            'Units Reimbursed': 'sum'
        }).reset_index()
        
        # Find top 5 ATC classes by total prescriptions across all quarters
        top_atc_classes = quarterly_atc.groupby('ATC4 Class')['Number of Prescriptions'].sum().sort_values(ascending=False).head(5)
        top_5_classes = top_atc_classes.index.tolist()
        
        print(f"\nTop 5 ATC Classes in {target_state}:")
        for i, atc_class in enumerate(top_5_classes, 1):
            total_prescriptions = top_atc_classes[atc_class]
            print(f"{i}. {atc_class}: {total_prescriptions:.2f} million prescriptions")
        
        # Filter data for top 5 classes only
        top_5_data = quarterly_atc[quarterly_atc['ATC4 Class'].isin(top_5_classes)].copy()
        
        # Create the plot
        plt.figure(figsize=(12, 8))
        
        # Plot each ATC class as a separate line
        for atc_class in top_5_classes:
            class_data = top_5_data[top_5_data['ATC4 Class'] == atc_class]
            plt.plot(class_data['Quarter'], class_data['Number of Prescriptions'], 
                    marker='o', linewidth=2, label=atc_class)
        
        plt.title(f'Top 5 ATC Classes by Quarter in {target_state} ({analyzer.year})', fontsize=14, fontweight='bold')
        plt.xlabel('Quarter', fontsize=12)
        plt.ylabel('Number of Prescriptions (Millions)', fontsize=12)
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()
        
        # Show summary table
        print(f"\nSummary table for {target_state}:")
        summary_pivot = top_5_data.pivot(index='Quarter', columns='ATC4 Class', values='Number of Prescriptions').fillna(0)
        print(summary_pivot.round(2))
        
    else:
        print(f"No data found for state: {target_state}")
        print(f"Available states: {sorted(analyzer.df_merged['State'].unique())}")
        
else:
    print("Please run the NDC-ATC workflow first!")

In [None]:
# Interactive Plot with Filters using Plotly
if 'analyzer' in globals() and hasattr(analyzer, 'df_merged') and analyzer.df_merged is not None:
    
    # Get unique values for filters
    states = ['All States']+sorted(analyzer.df_merged['State'].unique())
    years = sorted(analyzer.df_merged['Year'].unique()) if 'Year' in analyzer.df_merged.columns else [analyzer.year]
    quarters = ['All'] + sorted([str(q) for q in analyzer.df_merged['Quarter'].unique()])
    
    # Create dropdown widgets
    state_dropdown = widgets.Dropdown(
        options=states,
        value=states[0],
        description='State:',
        disabled=False,
    )
    
    year_dropdown = widgets.Dropdown(
        options=years,
        value=years[0],
        description='Year:',
        disabled=False,
    )
    
    quarter_dropdown = widgets.Dropdown(
        options=quarters,
        value='All',
        description='Quarter:',
        disabled=False,
    )
    
    top_n_slider = widgets.IntSlider(
        value=10,
        min=5,
        max=20,
        step=5,
        description='Top N:',
        disabled=False,
    )
    
    # Create output widget
    output = widgets.Output()
    
    def update_plot(change=None):
        with output:
            clear_output(wait=True)
            
            state = state_dropdown.value
            year = year_dropdown.value
            quarter = quarter_dropdown.value
            top_n = top_n_slider.value
            #Filtering
            if state=='All States':
                filtered_data=analyzer.df_merged.copy()
            else:
                filtered_data = analyzer.df_merged[analyzer.df_merged['State'] == state].copy()

            # Filter data by state and year
            filtered_data = analyzer.df_merged[
                analyzer.df_merged['State'] == state
            ].copy()
            
            if 'Year' in analyzer.df_merged.columns:
                filtered_data = filtered_data[filtered_data['Year'] == year]
            
            # Filter by quarter if not 'All'
            if quarter != 'All':
                filtered_data = filtered_data[filtered_data['Quarter'] == int(quarter)]
            
            quarter_text = f"Q{quarter}" if quarter != 'All' else "All Quarters"
            
            if len(filtered_data) == 0:
                print(f"No data available for State: {state}, Year: {year}, Quarter: {quarter_text}")
                return
            
            # Group by ATC3 Class and sum metrics
            atc3_summary = filtered_data.groupby(['ATC3 Class', 'ATC3_Name']).agg({
                'Units Reimbursed': 'sum',
                'Number of Prescriptions': 'sum'
            }).reset_index()
            
            # Create labels combining code and name
            atc3_summary['Label'] = atc3_summary['ATC3 Class'] + ': ' + atc3_summary['ATC3_Name']
            
            # Get top N for each metric
            top_units = atc3_summary.nlargest(top_n, 'Units Reimbursed').sort_values('Units Reimbursed', ascending=True)
            top_prescriptions = atc3_summary.nlargest(top_n, 'Number of Prescriptions').sort_values('Number of Prescriptions', ascending=True)
            
            # Create subplots
            fig = make_subplots(
                rows=2, cols=1,
                subplot_titles=(
                    f'Top {top_n} ATC3 Classes by Units Reimbursed<br>{state} - {year} - {quarter_text}',
                    f'Top {top_n} ATC3 Classes by Number of Prescriptions<br>{state} - {year} - {quarter_text}'
                ),
                vertical_spacing=0.15
            )
            
            # Add first bar chart (Units Reimbursed)
            fig.add_trace(
                go.Bar(
                    y=top_units['Label'],
                    x=top_units['Units Reimbursed'],
                    orientation='h',
                    marker=dict(
                        color=top_units['Units Reimbursed'],
                        colorscale='YlGnBu',
                        showscale=False
                    ),
                    text=[f'{val:.2f}B' for val in top_units['Units Reimbursed']],
                    textposition='outside',
                    hovertemplate='<b>%{y}</b><br>Units: %{x:.2f}B<extra></extra>'
                ),
                row=1, col=1
            )
            
            # Add second bar chart (Number of Prescriptions)
            fig.add_trace(
                go.Bar(
                    y=top_prescriptions['Label'],
                    x=top_prescriptions['Number of Prescriptions'],
                    orientation='h',
                    marker=dict(
                        color=top_prescriptions['Number of Prescriptions'],
                        colorscale='OrRd',
                        showscale=False
                    ),
                    text=[f'{val:.2f}M' for val in top_prescriptions['Number of Prescriptions']],
                    textposition='outside',
                    hovertemplate='<b>%{y}</b><br>Prescriptions: %{x:.2f}M<extra></extra>'
                ),
                row=2, col=1
            )
            
            # Update layout
            fig.update_xaxes(title_text='Units Reimbursed (Billions)', row=1, col=1, gridcolor='lightgray')
            fig.update_xaxes(title_text='Number of Prescriptions (Millions)', row=2, col=1, gridcolor='lightgray')
            
            fig.update_layout(
                height=800,
                showlegend=False,
                plot_bgcolor='white',
                font=dict(size=10)
            )
            
            fig.show()
            
            # Print summary statistics
            print(f"\n{'='*60}")
            print(f"Summary for {state} - {year} - {quarter_text}")
            print(f"{'='*60}")
            print(f"\nTotal ATC3 Classes: {len(atc3_summary)}")
            print(f"Total Units Reimbursed: {filtered_data['Units Reimbursed'].sum():.2f} Billion")
            print(f"Total Prescriptions: {filtered_data['Number of Prescriptions'].sum():.2f} Million")
    
    # Attach the update function to dropdown changes
    state_dropdown.observe(update_plot, names='value')
    year_dropdown.observe(update_plot, names='value')
    quarter_dropdown.observe(update_plot, names='value')
    top_n_slider.observe(update_plot, names='value')
    
    # Display widgets and output
    print("\n" + "="*60)
    print("Interactive ATC3 Class Analysis")
    print("="*60 + "\n")
    display(widgets.VBox([state_dropdown, year_dropdown, quarter_dropdown, top_n_slider, output]))
    
    # Generate initial plot
    update_plot()
    
else:
    print("Please run the NDC-ATC workflow first!")