In [1]:
# CMS code evaluation for diagnoses given as free text. Point the `df` load cell at your predictions CSV, then run the evaluation (EVAL) cells.

In [1]:
import pandas as pd

In [2]:
# Load the predictions CSV to evaluate; change the filename to score a different run.
# NOTE(review): the column listing below shows 'Unnamed: 0', 'Unnamed: 0.1', ... —
# these look like index columns accumulated by earlier to_csv calls; confirm and drop if unneeded.
df = pd.read_csv('llama_diag.csv')

In [3]:
# Sanity check: column names and (rows, columns) shape of the loaded frame.
df.columns, df.shape

(Index(['Unnamed: 0', 'Unnamed: 0.1', 'Unnamed: 0.1.1', 'patientdurablekey',
        'encounterkey', 'ArrivalDateKey', 'DepartureDateKeyValue',
        'DepartureDateKey', 'DispositionDateKeyValue',
        'primarychiefcomplaintname', 'primaryeddiagnosisname', 'sex',
        'birthdate', 'firstrace', 'preferredlanguage',
        'highestlevelofeducation', 'maritalstatus', 'Age',
        'Discharge_Summary_Date', 'Discharge_Summary_Note_Key',
        'Progress_Note_Date', 'Progress_Note_Key', 'HP_Note_Date',
        'HP_Note_Key', 'Echo_Date', 'Echo_Key', 'Imaging_Date', 'Imaging_Key',
        'Consult_Date', 'Consult_Key', 'ED_Provider_Notes_Date',
        'ED_Provider_Notes_Key', 'ECG_Date', 'ECG_Key',
        'Discharge_Summary_Text', 'Progress_Note_Text', 'HP_Note_Text',
        'Echo_Text', 'Imaging_Text', 'Consult_Text', 'ECG_Text',
        'ED_Provider_Notes_Text', 'One_Sentence_Extracted', 'note_count',
        'Predicted_Diagnosis', 'Requested_Notes', 'Prediction_Correct',
   

In [4]:
# 2018 CMS ICD-10-CM crosswalk (diagnosis code, description and HCC columns).
# Swap the path here to evaluate against a different crosswalk file.
cms_df = pd.read_csv('icd_10_cm_mappings2018.csv')

In [5]:
# Inspect the crosswalk columns; the matchers below read 'dgns_cd' and 'description'.
cms_df.columns

Index(['dgns_cd', 'description', 'hccesrdv21', 'hcc22', 'hccrxv05', 'hccesrd',
       'hcc', 'hccrx', 'fyear'],
      dtype='object')

In [42]:
# Preview the first 100 rows of the crosswalk.
cms_df.head(100)

Unnamed: 0,dgns_cd,description,hccesrdv21,hcc22,hccrxv05,hccesrd,hcc,hccrx,fyear
0,A0103,Typhoid pneumonia,115.0,115.0,,Yes,Yes,No,2018
1,A0104,Typhoid arthritis,39.0,39.0,,Yes,Yes,No,2018
2,A0105,Typhoid osteomyelitis,39.0,39.0,,Yes,Yes,No,2018
3,A021,Salmonella sepsis,2.0,2.0,,Yes,Yes,No,2018
4,A0222,Salmonella pneumonia,115.0,115.0,,Yes,Yes,No,2018
...,...,...,...,...,...,...,...,...,...
95,B3324,Viral cardiomyopathy,85.0,85.0,186.0,Yes,Yes,Yes,2018
96,B371,Pulmonary candidiasis,6.0,6.0,5.0,Yes,Yes,Yes,2018
97,B377,Candidal sepsis,2.0,2.0,5.0,Yes,Yes,Yes,2018
98,B377,Candidal sepsis,6.0,6.0,,Yes,Yes,--,2018


In [6]:
# Pair each recorded ED diagnosis with its 'Clean_Diagnosis' counterpart.
# NOTE(review): 'Clean_Diagnosis' is not visible in the (truncated) column listing above —
# confirm the column exists in the loaded CSV.
variations = df[['primaryeddiagnosisname', 'Clean_Diagnosis']]

In [7]:
# Write the diagnosis pairs to disk.
# NOTE(review): the row index is written as an extra unnamed column; pass index=False
# if downstream readers should not see an 'Unnamed: 0' column.
variations.to_csv('variations.csv')

In [34]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import re
import time

# Function to clean and standardize text
def standardize_text(text):
    """Normalise free text so that descriptions can be compared.

    Values that ``pd.isna`` reports as missing become ``""``. Anything else is
    converted to ``str``, lower-cased, has every character that is neither a
    word character nor whitespace replaced by a space, and finally has runs of
    whitespace collapsed to one space with both ends trimmed.
    """
    if pd.isna(text):
        return ""
    lowered = str(text).lower()
    # Punctuation is replaced by a space (not deleted) so "a/b" yields two words.
    spaced = re.sub(r'[^\w\s]', ' ', lowered)
    return re.sub(r'\s+', ' ', spaced).strip()

# Prepare CMS lookup data structures for efficient matching
def prepare_cms_lookup(cms_df):
    """Build the lookup tables consumed by ``efficient_match_to_cms``.

    Parameters
    ----------
    cms_df : pandas.DataFrame
        CMS crosswalk with ``dgns_cd`` and ``description`` columns.
        Side effect (kept for backward compatibility): a ``std_description``
        column holding the standardized description is added to this frame.

    Returns
    -------
    dict
        ``code_to_desc``     - code -> raw description (last row wins if a
                               code appears on several rows).
        ``keyword_to_codes`` - standardized full description, and each of its
                               words longer than 3 characters, -> list of
                               codes in first-seen order. Each code is listed
                               at most once per key.
        ``std_descriptions`` - standardized descriptions, row-aligned with
                               ``codes``.
        ``codes``            - ``dgns_cd`` values in row order.
    """
    print("Preparing efficient CMS lookup tables...")

    # Create dictionaries for fast lookups
    code_to_desc = dict(zip(cms_df['dgns_cd'], cms_df['description']))

    # Create standardized descriptions
    cms_df['std_description'] = cms_df['description'].apply(standardize_text)

    # Keyword index. `seen` guarantees a code is recorded once per key: without
    # it a word repeated inside one description, a one-word description (full
    # text == keyword) or a duplicated crosswalk row would add the same code
    # several times and inflate keyword hit counts in the matcher.
    keyword_to_codes = {}
    seen = set()
    for std_desc, code in zip(cms_df['std_description'], cms_df['dgns_cd']):
        if not std_desc:
            # Missing/blank description: nothing meaningful to index.
            continue

        # Full description first, then the individual meaningful words.
        keys = [std_desc]
        keys.extend(word for word in std_desc.split() if len(word) > 3)

        for key in keys:
            if (key, code) in seen:
                continue
            seen.add((key, code))
            keyword_to_codes.setdefault(key, []).append(code)

    return {
        'code_to_desc': code_to_desc,
        'keyword_to_codes': keyword_to_codes,
        'std_descriptions': cms_df['std_description'].tolist(),
        'codes': cms_df['dgns_cd'].tolist()
    }

# Optimized function to find matching CMS code
def efficient_match_to_cms(text, cms_lookup):
    """Map a free-text diagnosis to a CMS code.

    Parameters
    ----------
    text : str or missing
        Free-text diagnosis.
    cms_lookup : dict
        Structure returned by ``prepare_cms_lookup`` (uses the keys
        ``keyword_to_codes``, ``codes`` and ``std_descriptions``).

    Returns
    -------
    (code, score) : tuple
        * exact key hit in the keyword index  -> (first code, 100)
        * keyword overlap                     -> (code hit by the most words,
          percentage of the input's words longer than 3 chars that hit it)
        * category fallback                   -> (first code of the 3-char
          category with the most substring hits, 40)
        * nothing matched / missing input     -> ("UNKNOWN", 0)
    """
    if pd.isna(text) or text == "":
        return "UNKNOWN", 0

    # Standardize input text
    std_text = standardize_text(text)
    keyword_index = cms_lookup['keyword_to_codes']

    # 1. Try exact match first
    exact_matches = keyword_index.get(std_text)
    if exact_matches:
        return exact_matches[0], 100

    # 2. Keyword matching: count, per code, how many input words hit it.
    long_words = [word for word in std_text.split() if len(word) > 3]
    matches = {}
    for word in long_words:
        # dict.fromkeys de-duplicates while keeping order, so one input word
        # contributes at most one hit per code even if the index lists the
        # code more than once for that word.
        for code in dict.fromkeys(keyword_index.get(word, ())):
            matches[code] = matches.get(code, 0) + 1

    if matches:
        # max() returns the first maximal entry, i.e. ties resolve to the
        # earliest-inserted code (same choice as a stable descending sort).
        best_code, match_count = max(matches.items(), key=lambda item: item[1])
        score = min(100, int((match_count / max(1, len(long_words))) * 100))
        return best_code, score

    # 3. Category fallback: count, per 3-character code prefix, the
    # descriptions that contain any input word as a substring.
    words = std_text.split()
    category_hits = {}
    # std_descriptions is row-aligned with codes, so descriptions do not have
    # to be re-standardized on every call.
    for code, std_desc in zip(cms_lookup['codes'], cms_lookup['std_descriptions']):
        category = code[:3]
        category_hits.setdefault(category, 0)
        if any(word in std_desc for word in words):
            category_hits[category] += 1

    if category_hits:
        best_category, hits = max(category_hits.items(), key=lambda item: item[1])
        # Only trust the category if something actually matched; otherwise the
        # "best" category is just the first one in the table.
        if hits > 0:
            for code in cms_lookup['codes']:
                if code.startswith(best_category):
                    return code, 40

    return "UNKNOWN", 0

# Optimized function to map diagnoses to CMS codes with batching
def map_diagnoses_to_cms_optimized(df, cms_lookup, batch_size=1000):
    """Attach CMS codes, match scores and match flags to ``df`` (in place).

    Parameters
    ----------
    df : pandas.DataFrame
        Frame optionally holding ``primaryeddiagnosisname`` and
        ``Predicted_Diagnosis``. A missing column leaves its outputs at
        "UNKNOWN" / 0. The frame is modified in place and also returned.
    cms_lookup : dict
        Structure returned by ``prepare_cms_lookup``.
    batch_size : int
        Rows per progress message; it does not affect the results.

    Returns
    -------
    pandas.DataFrame
        ``df`` with these added columns: ``ChiefComplaint_CMS``,
        ``ChiefComplaint_Match_Score``, ``PredictedDiagnosis_CMS``,
        ``PredictedDiagnosis_Match_Score``, ``CMS_Match``,
        ``ChiefComplaint_Category``, ``PredictedDiagnosis_Category``,
        ``Category_Match``, ``ChiefComplaint_Description``,
        ``PredictedDiagnosis_Description``.
    """
    print("Mapping diagnoses to CMS codes with batching...")
    start_time = time.time()

    total_rows = len(df)
    # (source column, output code column, output score column)
    column_plan = (
        ('primaryeddiagnosisname', 'ChiefComplaint_CMS', 'ChiefComplaint_Match_Score'),
        ('Predicted_Diagnosis', 'PredictedDiagnosis_CMS', 'PredictedDiagnosis_Match_Score'),
    )

    # Read each source column once and collect results in plain arrays;
    # per-row df.iloc access builds a full row Series on every lookup.
    pending = []
    for source_col, code_col, score_col in column_plan:
        texts = df[source_col].tolist() if source_col in df.columns else None
        codes = np.full(total_rows, "UNKNOWN", dtype=object)
        scores = np.zeros(total_rows, dtype='int64')
        pending.append((texts, codes, scores, code_col, score_col))

    # Diagnosis strings repeat heavily; match each distinct string only once.
    match_cache = {}

    for start_idx in range(0, total_rows, batch_size):
        end_idx = min(start_idx + batch_size, total_rows)
        print(f"Processing batch {start_idx//batch_size + 1}/{(total_rows-1)//batch_size + 1} (rows {start_idx}-{end_idx-1})...")

        for texts, codes, scores, _, _ in pending:
            if texts is None:
                continue
            for i in range(start_idx, end_idx):
                text = texts[i]
                if pd.isna(text):
                    continue  # stays "UNKNOWN" / 0
                if text not in match_cache:
                    match_cache[text] = efficient_match_to_cms(text, cms_lookup)
                codes[i], scores[i] = match_cache[text]

    for _, codes, scores, code_col, score_col in pending:
        df[code_col] = codes
        df[score_col] = scores

    # Add match indicators. Two unmapped values ("UNKNOWN" on both sides) are
    # not an agreement between diagnoses, so they must not count as a match.
    chief = df['ChiefComplaint_CMS']
    predicted = df['PredictedDiagnosis_CMS']
    both_mapped = (chief != "UNKNOWN") & (predicted != "UNKNOWN")

    df['CMS_Match'] = (chief == predicted) & both_mapped
    df['ChiefComplaint_Category'] = chief.astype(str).str[:3]
    df['PredictedDiagnosis_Category'] = predicted.astype(str).str[:3]
    df['Category_Match'] = (
        (df['ChiefComplaint_Category'] == df['PredictedDiagnosis_Category']) & both_mapped
    )

    # Add descriptions (NaN where the code is not in the crosswalk)
    df['ChiefComplaint_Description'] = df['ChiefComplaint_CMS'].map(cms_lookup['code_to_desc'])
    df['PredictedDiagnosis_Description'] = df['PredictedDiagnosis_CMS'].map(cms_lookup['code_to_desc'])

    elapsed_time = time.time() - start_time
    print(f"Mapping completed in {elapsed_time:.2f} seconds")

    return df

# Main evaluation function
def evaluate_diagnoses_optimized(df, cms_df):
    """Map both diagnosis columns to CMS codes and summarise the agreement.

    Returns a dict with:
      * ``overall_accuracy``         - mean of ``CMS_Match``
      * ``category_accuracy``        - mean of ``Category_Match``
      * ``high_confidence_accuracy`` - mean of ``CMS_Match`` over rows whose
        two match scores are both >= 70 (0 when there are no such rows)
      * ``df_mapped``                - the mapped frame
      * ``top_cms_codes``            - the 10 most frequent chief-complaint codes
      * ``confusion_matrix``         - sklearn confusion matrix restricted to
        those codes, or None when no rows carry them
    """
    lookup_tables = prepare_cms_lookup(cms_df)
    mapped = map_diagnoses_to_cms_optimized(df, lookup_tables)

    overall = mapped['CMS_Match'].mean()
    by_category = mapped['Category_Match'].mean()

    # Rows where both sides were matched with a score of at least 70.
    chief_confident = mapped['ChiefComplaint_Match_Score'] >= 70
    predicted_confident = mapped['PredictedDiagnosis_Match_Score'] >= 70
    confident_rows = mapped[chief_confident & predicted_confident]
    if len(confident_rows) > 0:
        confident_accuracy = confident_rows['CMS_Match'].mean()
    else:
        confident_accuracy = 0

    # Confusion matrix over the ten most common chief-complaint codes.
    top_cms_codes = mapped['ChiefComplaint_CMS'].value_counts().head(10).index.tolist()
    top_rows = mapped[mapped['ChiefComplaint_CMS'].isin(top_cms_codes)]

    if len(top_rows) > 0:
        confusion = confusion_matrix(
            top_rows['ChiefComplaint_CMS'],
            top_rows['PredictedDiagnosis_CMS'],
            labels=top_cms_codes
        )
    else:
        confusion = None

    return {
        'overall_accuracy': overall,
        'category_accuracy': by_category,
        'high_confidence_accuracy': confident_accuracy,
        'df_mapped': mapped,
        'top_cms_codes': top_cms_codes,
        'confusion_matrix': confusion
    }

# Simplified visualization function
def visualize_results_simple(results):
    """Draw a 2x2 summary figure, save it as a PNG and show it.

    Panels: accuracy bars, top-10 predicted codes, top-10 chief-complaint
    codes and, when ``results`` carries a confusion matrix and a non-empty
    ``top_cms_codes`` list, a heatmap of that matrix.

    Returns the status string naming the saved file.
    """
    df_mapped = results['df_mapped']

    plt.figure(figsize=(18, 12))

    # Panel 1: accuracy metrics
    ax_metrics = plt.subplot(2, 2, 1)
    metrics = {
        'Exact Match': results['overall_accuracy'],
        'Category Match': results['category_accuracy'],
        'High Confidence': results.get('high_confidence_accuracy', 0)
    }
    ax_metrics.bar(metrics.keys(), metrics.values())
    ax_metrics.set_title('Accuracy Metrics')
    ax_metrics.set_ylabel('Accuracy')
    for position, value in enumerate(metrics.values()):
        # Percentage label just above each bar.
        ax_metrics.text(position, value + 0.02, f'{value:.1%}', ha='center')
    ax_metrics.set_ylim(0, 1)

    # Panels 2 and 3: most frequent codes on each side
    count_panels = (
        (2, 'PredictedDiagnosis_CMS', 'Top 10 Predicted CMS Codes'),
        (3, 'ChiefComplaint_CMS', 'Top 10 Chief Complaint CMS Codes'),
    )
    for slot, column, title in count_panels:
        ax_counts = plt.subplot(2, 2, slot)
        counts = df_mapped[column].value_counts().head(10)
        sns.barplot(x=counts.values, y=counts.index, ax=ax_counts)
        ax_counts.set_title(title)
        ax_counts.set_xlabel('Count')

    # Panel 4: confusion matrix, if available
    has_matrix = results.get('confusion_matrix') is not None
    if has_matrix and results.get('top_cms_codes'):
        ax_cm = plt.subplot(2, 2, 4)
        cm = results['confusion_matrix']
        top_codes = results['top_cms_codes']

        # Limit to first 10 codes
        if len(top_codes) > 10:
            cm = cm[:10, :10]
            top_codes = top_codes[:10]

        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=top_codes, yticklabels=top_codes, ax=ax_cm)
        ax_cm.set_title('Confusion Matrix (Top CMS Codes)')
        ax_cm.set_ylabel('True CMS Code')
        ax_cm.set_xlabel('Predicted CMS Code')

    plt.tight_layout()
    plt.savefig('diagnosis_accuracy_analysis.png')
    plt.show()

    return 'Analysis results saved to diagnosis_accuracy_analysis.png'

# Main analysis function
def analyze_diagnosis_accuracy_optimized(df, cms_df):
    """Run the evaluation, print a summary, draw the figure and time the run.

    Delegates to ``evaluate_diagnoses_optimized`` and
    ``visualize_results_simple`` and returns the evaluation results dict
    unchanged.
    """
    print("Starting optimized analysis of diagnosis accuracy...")
    began_at = time.time()

    results = evaluate_diagnoses_optimized(df, cms_df)

    # Summary block
    print("\nSummary Statistics:")
    print(f"Total records analyzed: {len(df)}")
    print(f"Overall accuracy (exact CMS code match): {results['overall_accuracy']:.2%}")
    print(f"Category-level accuracy (first 3 chars): {results['category_accuracy']:.2%}")

    has_high_confidence = 'high_confidence_accuracy' in results
    if has_high_confidence:
        print(f"High confidence accuracy: {results['high_confidence_accuracy']:.2%}")

    # Figure
    print("\nGenerating visualizations...")
    print(visualize_results_simple(results))

    # Wall-clock duration of the whole analysis
    elapsed_time = time.time() - began_at
    print(f"Analysis completed in {elapsed_time:.2f} seconds")

    return results

# Example usage
# results = analyze_diagnosis_accuracy_optimized(df, cms_df)

In [8]:
# Import required libraries
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, IFrame, HTML
import warnings
warnings.filterwarnings('ignore')

In [9]:
import pandas as pd
import numpy as np
import re
from fuzzywuzzy import fuzz
from tqdm.notebook import tqdm  # For progress bar in notebooks
import time

class OptimizedDiagnosticEvaluator:
    def __init__(self):
        # Common medical acronyms and variations 
        self.acronym_expansions = {
            'mi': 'myocardial infarction',
            'ami': 'acute myocardial infarction',
            'uti': 'urinary tract infection',
            'aki': 'acute kidney injury',
            'arf': ['acute renal failure', 'acute respiratory failure'],
            'cad': 'coronary artery disease',
            'copd': 'chronic obstructive pulmonary disease',
            'ckd': 'chronic kidney disease',
            'chf': 'congestive heart failure',
            'hf': 'heart failure',
            'esrd': 'end stage renal disease',
            'cva': 'cerebrovascular accident',
            'tia': 'transient ischemic attack',
            'dvt': 'deep vein thrombosis',
            'pe': 'pulmonary embolism',
            'gi': 'gastrointestinal',
            'af': 'atrial fibrillation',
            'afib': 'atrial fibrillation',
            'cap': 'community acquired pneumonia',
            'hap': 'hospital acquired pneumonia',
            'ams': 'altered mental status',
            'sbo': 'small bowel obstruction',
            'lbo': 'large bowel obstruction',
            'osa': 'obstructive sleep apnea',
            'ms': 'multiple sclerosis'
        }
        
        # Modifiers to be normalized or removed
        self.modifiers_to_normalize = [
            'acute', 'chronic', 'unspecified', 'recurrent', 'initial', 
            'primary', 'secondary', 'type', 'encounter', 'due to', 'with',
            'without', 'specified', 'mild', 'moderate', 'severe'
        ]
        
        # Body system and HCC mapping derived from CMS categories
        self.body_system_mappings = {
            'A': 'Infectious and Parasitic Diseases',
            'B': 'Infectious and Parasitic Diseases',
            'C': 'Neoplasms',
            'D': 'Blood and Immune Systems',
            'E': 'Endocrine and Metabolic',
            'F': 'Mental and Behavioral',
            'G': 'Nervous System',
            'H': 'Eye and Ear',
            'I': 'Circulatory System',
            'J': 'Respiratory System',
            'K': 'Digestive System',
            'L': 'Skin and Subcutaneous',
            'M': 'Musculoskeletal System',
            'N': 'Genitourinary System',
            'O': 'Pregnancy and Childbirth',
            'P': 'Perinatal Period',
            'Q': 'Congenital Malformations',
            'R': 'Symptoms and Signs',
            'S': 'Injury and Poisoning',
            'T': 'Injury and Poisoning',
            'V': 'External Causes',
            'W': 'External Causes',
            'X': 'External Causes',
            'Y': 'External Causes',
            'Z': 'Health Status Factors'
        }
        
        # Clinical clusters based on analysis of variations.csv
        self.clinical_clusters = {
            'chest_pain': ['chest pain', 'acute coronary syndrome', 'angina', 'myocardial infarction', 'mi', 'cardiac chest pain'],
            'altered_mental_status': ['altered mental status', 'encephalopathy', 'confusion', 'delirium', 'ams', 'acute confusion'],
            'headache': ['headache', 'migraine', 'tension headache', 'cluster headache', 'intracranial hypertension'],
            'respiratory_distress': ['shortness of breath', 'dyspnea', 'respiratory failure', 'hypoxia', 'breathing difficulty', 'respiratory distress'],
            'gi_bleed': ['gi bleed', 'gastrointestinal hemorrhage', 'melena', 'hematochezia', 'upper gi bleed', 'lower gi bleed'],
            'pneumonia': ['pneumonia', 'cap', 'hap', 'aspiration pneumonia', 'bilateral pneumonia', 'lobar pneumonia'],
            'heart_failure': ['heart failure', 'chf', 'congestive heart failure', 'acute heart failure', 'cardiac failure', 'left heart failure'],
            'kidney_disease': ['acute kidney injury', 'chronic kidney disease', 'aki', 'ckd', 'renal failure', 'esrd', 'kidney failure'],
            'uti': ['uti', 'urinary tract infection', 'cystitis', 'pyelonephritis', 'urinary infection'],
            'stroke': ['stroke', 'cva', 'cerebrovascular accident', 'cerebral infarction', 'brain attack', 'ischemic stroke'],
            'abdominal_pain': ['abdominal pain', 'belly pain', 'epigastric pain', 'flank pain', 'periumbilical pain'],
            'sepsis': ['sepsis', 'septic shock', 'bacteremia', 'severe sepsis', 'systemic inflammatory response syndrome', 'sirs'],
            'copd': ['copd', 'chronic obstructive pulmonary disease', 'emphysema', 'chronic bronchitis', 'copd exacerbation'],
            'bowel_obstruction': ['bowel obstruction', 'sbo', 'small bowel obstruction', 'ileus', 'intestinal obstruction', 'lbo', 'large bowel obstruction']
        }
        
        # Keyword to body system mapping (used when ICD-10 code isn't available)
        self.keyword_body_system = {
            'heart': 'Circulatory System',
            'cardiac': 'Circulatory System',
            'coronary': 'Circulatory System',
            'myocardial': 'Circulatory System',
            'infarction': 'Circulatory System',
            'stroke': 'Circulatory System',
            'cva': 'Circulatory System',
            'lung': 'Respiratory System',
            'pulmonary': 'Respiratory System',
            'pneumonia': 'Respiratory System',
            'respiratory': 'Respiratory System',
            'copd': 'Respiratory System',
            'asthma': 'Respiratory System',
            'kidney': 'Genitourinary System',
            'renal': 'Genitourinary System',
            'urinary': 'Genitourinary System',
            'uti': 'Genitourinary System',
            'liver': 'Digestive System',
            'bowel': 'Digestive System',
            'intestinal': 'Digestive System',
            'abdominal': 'Digestive System',
            'gastric': 'Digestive System',
            'gi': 'Digestive System',
            'brain': 'Nervous System',
            'neuro': 'Nervous System',
            'mental': 'Mental and Behavioral',
            'psychiatric': 'Mental and Behavioral',
            'depression': 'Mental and Behavioral',
            'anxiety': 'Mental and Behavioral',
            'joint': 'Musculoskeletal System',
            'bone': 'Musculoskeletal System',
            'fracture': 'Musculoskeletal System',
            'arthritis': 'Musculoskeletal System',
            'skin': 'Skin and Subcutaneous',
            'dermatitis': 'Skin and Subcutaneous',
            'cellulitis': 'Skin and Subcutaneous',
            'cancer': 'Neoplasms',
            'tumor': 'Neoplasms',
            'malignant': 'Neoplasms',
            'infection': 'Infectious and Parasitic Diseases',
            'sepsis': 'Infectious and Parasitic Diseases'
        }
        
        # Caches to avoid redundant computations
        self.normalization_cache = {}
        self.icd_hcc_cache = {}
        self.cluster_cache = {}
        self.body_system_cache = {}
        
    def normalize_diagnosis(self, diagnosis):
        """
        Apply comprehensive normalization to diagnosis strings, with caching
        """
        if not isinstance(diagnosis, str) or not diagnosis:
            return ""
            
        # Check if already in cache
        if diagnosis in self.normalization_cache:
            return self.normalization_cache[diagnosis]
        
        # Convert to lowercase
        text = diagnosis.lower()
        
        # Take only the first part of comma-separated diagnoses
        if ',' in text:
            text = text.split(',')[0].strip()
        
        # Remove standard format strings
        text = re.sub(r'\(cms code\)', '', text)
        
        # Remove parenthetical content
        text = re.sub(r'\([^)]*\)', '', text)
        
        # Replace common acronyms with their expanded forms
        for acronym, expansion in self.acronym_expansions.items():
            if isinstance(expansion, list):
                # For acronyms with multiple possible expansions, we need context
                if acronym == 'arf' and ('kidney' in text or 'renal' in text):
                    text = re.sub(r'\b' + acronym + r'\b', 'acute renal failure', text)
                elif acronym == 'arf' and ('lung' in text or 'respiratory' in text or 'breathing' in text):
                    text = re.sub(r'\b' + acronym + r'\b', 'acute respiratory failure', text)
            else:
                text = re.sub(r'\b' + acronym + r'\b', expansion, text)
        
        # Remove modifiers
        for modifier in self.modifiers_to_normalize:
            text = re.sub(r'\b' + modifier + r'\b', '', text)
        
        # Remove punctuation and special characters
        text = re.sub(r'[^\w\s]', ' ', text)
        text = re.sub(r'\s+', ' ', text).strip()
        
        # Store in cache
        self.normalization_cache[diagnosis] = text
        
        return text

    def prepare_cms_mapping(self, cms_df):
        """
        Preprocess CMS DataFrame to optimize lookups
        """
        if cms_df is None:
            return None
            
        # Create a lookup dictionary for faster mapping
        lookup = {}
        
        for _, row in cms_df.iterrows():
            description = row.get('description', '')
            if not isinstance(description, str) or not description:
                continue
                
            # Normalize the description for better matching
            norm_desc = description.lower()
            
            # Get ICD-10 code
            icd = row.get('dgns_cd')
            
            # Get HCC category (try multiple columns)
            hcc = row.get('hcc')
            if pd.isna(hcc) or str(hcc).lower() == 'nan':
                for alt_col in ['hcc22', 'hccesrdv21', 'hccrx']:
                    if alt_col in row and not pd.isna(row[alt_col]) and str(row[alt_col]).lower() != 'nan':
                        hcc = row[alt_col]
                        break
            
            # Add to lookup dictionary
            lookup[norm_desc] = (icd, hcc)
        
        return lookup
    
    def map_to_icd_and_hcc(self, diagnosis, cms_lookup):
        """
        Map diagnosis to ICD-10 code and HCC category using preprocessed CMS data
        """
        if not cms_lookup or not isinstance(diagnosis, str) or not diagnosis:
            return None, None
            
        # Check if already in cache
        if diagnosis in self.icd_hcc_cache:
            return self.icd_hcc_cache[diagnosis]
            
        # Normalize the diagnosis
        normalized = self.normalize_diagnosis(diagnosis)
        
        # Try exact match first
        if normalized in cms_lookup:
            result = cms_lookup[normalized]
            self.icd_hcc_cache[diagnosis] = result
            return result
            
        # Not found in exact match, use basic keyword matching
        # This is much faster than fuzzy matching but less accurate
        best_match = None
        max_words = 0
        
        words = set(normalized.split())
        if not words:
            return None, None
            
        for desc in cms_lookup.keys():
            desc_words = set(desc.split())
            common_words = words.intersection(desc_words)
            
            if len(common_words) > max_words:
                max_words = len(common_words)
                best_match = desc
        
        # If we found a reasonable match, return it
        if best_match and max_words >= min(2, len(words)):
            result = cms_lookup[best_match]
            self.icd_hcc_cache[diagnosis] = result
            return result
            
        # No good match found
        self.icd_hcc_cache[diagnosis] = (None, None)
        return None, None
    
    def map_to_clinical_cluster(self, diagnosis):
        """
        Map a diagnosis to a clinical cluster based on keyword matching
        """
        if not isinstance(diagnosis, str) or not diagnosis:
            return None
            
        # Check if already in cache
        if diagnosis in self.cluster_cache:
            return self.cluster_cache[diagnosis]
            
        normalized = self.normalize_diagnosis(diagnosis)
        
        # Simple approach: check if any cluster keywords are in the normalized diagnosis
        for cluster_name, variations in self.clinical_clusters.items():
            for variation in variations:
                # If the variation is a substantial part of the diagnosis
                if variation in normalized or normalized in variation:
                    self.cluster_cache[diagnosis] = cluster_name
                    return cluster_name
        
        # No cluster found
        self.cluster_cache[diagnosis] = None
        return None
    
    def map_to_body_system(self, diagnosis, cms_lookup=None):
        """
        Map diagnosis to a body system using ICD code or keyword matching
        """
        if not isinstance(diagnosis, str) or not diagnosis:
            return None
            
        # Check if already in cache
        if diagnosis in self.body_system_cache:
            return self.body_system_cache[diagnosis]
            
        # Try using CMS data first if available
        if cms_lookup is not None:
            icd_code, _ = self.map_to_icd_and_hcc(diagnosis, cms_lookup)
            if icd_code and isinstance(icd_code, str) and len(icd_code) > 0:
                system_code = icd_code[0].upper()
                body_system = self.body_system_mappings.get(system_code)
                if body_system:
                    self.body_system_cache[diagnosis] = body_system
                    return body_system
        
        # Fallback to keyword matching
        normalized = self.normalize_diagnosis(diagnosis)
        words = normalized.split()
        
        for word in words:
            for keyword, system in self.keyword_body_system.items():
                if keyword in word:
                    self.body_system_cache[diagnosis] = system
                    return system
        
        # Try matching to a clinical cluster
        cluster = self.map_to_clinical_cluster(normalized)
        if cluster:
            # Map the cluster to a body system
            cluster_to_system = {
                'chest_pain': 'Circulatory System',
                'altered_mental_status': 'Nervous System',
                'headache': 'Nervous System',
                'respiratory_distress': 'Respiratory System',
                'gi_bleed': 'Digestive System',
                'pneumonia': 'Respiratory System',
                'heart_failure': 'Circulatory System',
                'kidney_disease': 'Genitourinary System',
                'uti': 'Genitourinary System',
                'stroke': 'Circulatory System',
                'abdominal_pain': 'Digestive System',
                'sepsis': 'Infectious and Parasitic Diseases',
                'copd': 'Respiratory System',
                'bowel_obstruction': 'Digestive System'
            }
            system = cluster_to_system.get(cluster)
            if system:
                self.body_system_cache[diagnosis] = system
                return system
        
        # Default classification
        self.body_system_cache[diagnosis] = 'Unclassified'
        return 'Unclassified'
    
    def evaluate_predictions(self, df, cms_df=None, sample_size=None):
        """
        Efficiently evaluate diagnostic predictions with metrics
        
        Parameters:
        -----------
        df : DataFrame
            DataFrame with 'primaryeddiagnosisname' and 'Predicted_Diagnosis' columns
        cms_df : DataFrame, optional
            DataFrame with CMS crosswalk data
        sample_size : int, optional
            Number of rows to sample for faster processing (None = use all)
        
        Returns:
        --------
        metrics : dict
            Dictionary of evaluation metrics
        """
        start_time = time.time()
        print("Starting evaluation...")
        
        # Sample data if requested
        if sample_size is not None and sample_size < len(df):
            eval_df = df.sample(sample_size, random_state=42).copy()
            print(f"Using sample of {sample_size} records from {len(df)} total")
        else:
            eval_df = df.copy()
        
        # Initialize metrics dictionary
        metrics = {}
        
        # Preprocess CMS data for faster lookups if available
        cms_lookup = None
        if cms_df is not None:
            print("Preprocessing CMS data...")
            cms_lookup = self.prepare_cms_mapping(cms_df)
            print(f"Created lookup dictionary with {len(cms_lookup) if cms_lookup else 0} entries")
        
        # Add normalized versions of diagnoses (with progress bar)
        print("Normalizing diagnoses...")
        eval_df['normalized_true'] = eval_df['primaryeddiagnosisname'].apply(self.normalize_diagnosis)
        eval_df['normalized_pred'] = eval_df['Predicted_Diagnosis'].apply(self.normalize_diagnosis)
        
        # Calculate exact match metrics
        print("Calculating match metrics...")
        eval_df['raw_exact_match'] = eval_df['primaryeddiagnosisname'] == eval_df['Predicted_Diagnosis']
        metrics['Exact Match Accuracy'] = f"{eval_df['raw_exact_match'].mean():.2%}"
        
        eval_df['normalized_exact_match'] = eval_df['normalized_true'] == eval_df['normalized_pred']
        metrics['Normalized Match Accuracy'] = f"{eval_df['normalized_exact_match'].mean():.2%}"
        
        # Initialize additional match columns
        eval_df['cluster_match'] = False
        eval_df['body_system_match'] = False
        eval_df['icd_match'] = False
        eval_df['hcc_match'] = False
        
        # Process in batches for better progress tracking
        batch_size = 100
        num_batches = (len(eval_df) + batch_size - 1) // batch_size
        
        # Pre-allocate arrays for results
        true_clusters = np.empty(len(eval_df), dtype=object)
        pred_clusters = np.empty(len(eval_df), dtype=object)
        true_body_systems = np.empty(len(eval_df), dtype=object)
        pred_body_systems = np.empty(len(eval_df), dtype=object)
        
        # Additionally, for CMS-based metrics
        if cms_lookup:
            true_icds = np.empty(len(eval_df), dtype=object)
            pred_icds = np.empty(len(eval_df), dtype=object)
            true_hccs = np.empty(len(eval_df), dtype=object)
            pred_hccs = np.empty(len(eval_df), dtype=object)
        
        print("Mapping diagnoses to clinical categories...")
        for i in tqdm(range(num_batches)):
            start_idx = i * batch_size
            end_idx = min((i + 1) * batch_size, len(eval_df))
            
            # Get batch of diagnoses
            batch_true = eval_df['primaryeddiagnosisname'].iloc[start_idx:end_idx].values
            batch_pred = eval_df['Predicted_Diagnosis'].iloc[start_idx:end_idx].values
            
            # Map to clinical clusters
            for j, (true_diag, pred_diag) in enumerate(zip(batch_true, batch_pred)):
                idx = start_idx + j
                
                # Clinical clusters
                true_cluster = self.map_to_clinical_cluster(true_diag)
                pred_cluster = self.map_to_clinical_cluster(pred_diag)
                true_clusters[idx] = true_cluster
                pred_clusters[idx] = pred_cluster
                
                # Body systems
                true_system = self.map_to_body_system(true_diag, cms_lookup)
                pred_system = self.map_to_body_system(pred_diag, cms_lookup)
                true_body_systems[idx] = true_system
                pred_body_systems[idx] = pred_system
                
                # Set match flags
                if true_cluster is not None and pred_cluster is not None:
                    eval_df.loc[start_idx + j, 'cluster_match'] = (true_cluster == pred_cluster)
                
                if true_system is not None and pred_system is not None:
                    eval_df.loc[start_idx + j, 'body_system_match'] = (true_system == pred_system)
                
                # ICD and HCC mapping (only if CMS data is available)
                if cms_lookup:
                    true_icd, true_hcc = self.map_to_icd_and_hcc(true_diag, cms_lookup)
                    pred_icd, pred_hcc = self.map_to_icd_and_hcc(pred_diag, cms_lookup)
                    
                    true_icds[idx] = true_icd
                    pred_icds[idx] = pred_icd
                    true_hccs[idx] = true_hcc
                    pred_hccs[idx] = pred_hcc
                    
                    # Set match flags
                    if true_icd is not None and pred_icd is not None:
                        eval_df.loc[start_idx + j, 'icd_match'] = (true_icd == pred_icd)
                    
                    if true_hcc is not None and pred_hcc is not None:
                        eval_df.loc[start_idx + j, 'hcc_match'] = (true_hcc == pred_hcc)
        
        # Store results in DataFrame for analysis
        eval_df['true_cluster'] = true_clusters
        eval_df['pred_cluster'] = pred_clusters
        eval_df['true_body_system'] = true_body_systems
        eval_df['pred_body_system'] = pred_body_systems
        
        if cms_lookup:
            eval_df['true_icd'] = true_icds
            eval_df['pred_icd'] = pred_icds
            eval_df['true_hcc'] = true_hccs
            eval_df['pred_hcc'] = pred_hccs
        
        # Calculate clinical cluster metrics
        clustered_rows = eval_df[pd.notna(eval_df['true_cluster']) & pd.notna(eval_df['pred_cluster'])]
        if len(clustered_rows) > 0:
            metrics['Clinical Cluster Match Accuracy'] = f"{clustered_rows['cluster_match'].mean():.2%}"
            metrics['Cluster Coverage'] = f"{len(clustered_rows) / len(eval_df):.2%}"
        else:
            metrics['Clinical Cluster Match Accuracy'] = "0.00%"
            metrics['Cluster Coverage'] = "0.00%"
        
        # Calculate body system metrics
        body_system_rows = eval_df[pd.notna(eval_df['true_body_system']) & pd.notna(eval_df['pred_body_system'])]
        if len(body_system_rows) > 0:
            metrics['Body System Match Accuracy'] = f"{body_system_rows['body_system_match'].mean():.2%}"
            metrics['Body System Coverage'] = f"{len(body_system_rows) / len(eval_df):.2%}"
        else:
            metrics['Body System Match Accuracy'] = "0.00%"
            metrics['Body System Coverage'] = "0.00%"
        
        # Calculate ICD and HCC metrics if CMS data was used
        if cms_lookup:
            icd_rows = eval_df[pd.notna(eval_df['true_icd']) & pd.notna(eval_df['pred_icd'])]
            if len(icd_rows) > 0:
                metrics['ICD-10 Code Match Accuracy'] = f"{icd_rows['icd_match'].mean():.2%}"
                metrics['ICD-10 Coverage'] = f"{len(icd_rows) / len(eval_df):.2%}"
            else:
                metrics['ICD-10 Code Match Accuracy'] = "0.00%"
                metrics['ICD-10 Coverage'] = "0.00%"
            
            hcc_rows = eval_df[pd.notna(eval_df['true_hcc']) & pd.notna(eval_df['pred_hcc'])]
            if len(hcc_rows) > 0:
                metrics['HCC Category Match Accuracy'] = f"{hcc_rows['hcc_match'].mean():.2%}"
                metrics['HCC Coverage'] = f"{len(hcc_rows) / len(eval_df):.2%}"
            else:
                metrics['HCC Category Match Accuracy'] = "0.00%"
                metrics['HCC Coverage'] = "0.00%"
        
        # Calculate overall clinical relevance
        print("Calculating clinical relevance...")
        eval_df['clinically_relevant'] = (
            eval_df['normalized_exact_match'] | 
            eval_df['cluster_match'] | 
            eval_df['body_system_match']
        )
        
        if cms_lookup:
            eval_df['clinically_relevant'] = (
                eval_df['clinically_relevant'] | 
                eval_df['icd_match'] | 
                eval_df['hcc_match']
            )
            
        metrics['Overall Clinical Relevance'] = f"{eval_df['clinically_relevant'].mean():.2%}"
        
        # Calculate string similarity for a sample (can be slow for large datasets)
        if len(eval_df) > 1000:
            similarity_sample = eval_df.sample(1000, random_state=42)
            similarity_sample['string_similarity'] = similarity_sample.apply(
                lambda row: fuzz.token_sort_ratio(row['normalized_true'], row['normalized_pred']), 
                axis=1
            )
            metrics['Mean String Similarity'] = f"{similarity_sample['string_similarity'].mean():.1f}"
            metrics['High Similarity (>= 80)'] = f"{(similarity_sample['string_similarity'] >= 80).mean():.2%}"
        else:
            eval_df['string_similarity'] = eval_df.apply(
                lambda row: fuzz.token_sort_ratio(row['normalized_true'], row['normalized_pred']), 
                axis=1
            )
            metrics['Mean String Similarity'] = f"{eval_df['string_similarity'].mean():.1f}"
            metrics['High Similarity (>= 80)'] = f"{(eval_df['string_similarity'] >= 80).mean():.2%}"
        
        end_time = time.time()
        metrics['Evaluation Time'] = f"{end_time - start_time:.2f} seconds"
        
        return metrics, eval_df

def fast_evaluate_diagnoses(df, cms_df=None, sample_size=None):
    """
    Quickly evaluate diagnostic predictions with metrics including ICD and HCC accuracy

    Validates the inputs, delegates the actual scoring to
    OptimizedDiagnosticEvaluator.evaluate_predictions, and prints the
    resulting metrics as an aligned table.

    Parameters:
    -----------
    df : DataFrame or str
        DataFrame with 'primaryeddiagnosisname' and 'Predicted_Diagnosis' columns,
        or path to a CSV file containing these columns
    cms_df : DataFrame or None, optional
        DataFrame with CMS crosswalk data. It must contain 'dgns_cd' and
        'description'; if either is missing a warning is printed and the
        evaluation proceeds without CMS data. Columns whose name contains
        'hcc' are optional; a warning is printed when none are present.
    sample_size : int or None, optional
        Number of rows to sample for faster processing (None = use all data)

    Returns:
    --------
    metrics : dict
        Dictionary of evaluation metrics
    eval_df : DataFrame
        DataFrame with all evaluation columns

    Raises:
    -------
    ValueError
        If `df` lacks one of the required diagnosis columns.
    """
    # Load data if it's a file path
    if isinstance(df, str):
        df = pd.read_csv(df)

    # Check for required columns
    required_cols = ['primaryeddiagnosisname', 'Predicted_Diagnosis']
    for col in required_cols:
        if col not in df.columns:
            raise ValueError(f"Missing required column: {col}")

    # Check CMS dataframe if provided
    if cms_df is not None:
        print("\nCMS DataFrame Check:")
        print(f"Shape: {cms_df.shape}")
        print(f"Columns: {cms_df.columns.tolist()}")

        # Check required columns
        cms_required_cols = ['dgns_cd', 'description']
        cms_missing_cols = [col for col in cms_required_cols if col not in cms_df.columns]
        if cms_missing_cols:
            print(f"WARNING: Missing required columns in CMS dataframe: {cms_missing_cols}")
            # Fall back to an evaluation without CMS-based metrics.
            cms_df = None
        else:
            # Only inspect HCC columns while the crosswalk is still usable;
            # touching cms_df.columns after it was reset to None would raise
            # AttributeError instead of degrading gracefully.
            hcc_columns = [col for col in cms_df.columns if 'hcc' in col.lower()]
            if not hcc_columns:
                print("WARNING: No HCC columns found in CMS dataframe")

    # Create evaluator and run evaluation
    evaluator = OptimizedDiagnosticEvaluator()
    metrics, eval_df = evaluator.evaluate_predictions(df, cms_df, sample_size)

    # Print the metrics
    print("\n=== Diagnostic Prediction Evaluation Metrics ===")
    # default=0 keeps the report from raising if the evaluator returns no metrics
    max_len = max((len(key) for key in metrics), default=0)
    for key, value in metrics.items():
        print(f"{key.ljust(max_len)}: {value}")

    return metrics, eval_df

In [10]:
# Score the loaded predictions against the 2018 CMS ICD-10 crosswalk.
metrics, eval_df = fast_evaluate_diagnoses(df, cms_df=cms_df)


CMS DataFrame Check:
Shape: (10768, 9)
Columns: ['dgns_cd', 'description', 'hccesrdv21', 'hcc22', 'hccrxv05', 'hccesrd', 'hcc', 'hccrx', 'fyear']
Starting evaluation...
Preprocessing CMS data...
Created lookup dictionary with 10163 entries
Normalizing diagnoses...
Calculating match metrics...
Mapping diagnoses to clinical categories...


  0%|          | 0/40 [00:00<?, ?it/s]

Calculating clinical relevance...

=== Diagnostic Prediction Evaluation Metrics ===
Exact Match Accuracy           : 0.03%
Normalized Match Accuracy      : 0.58%
Clinical Cluster Match Accuracy: 47.53%
Cluster Coverage               : 9.14%
Body System Match Accuracy     : 11.50%
Body System Coverage           : 100.00%
ICD-10 Code Match Accuracy     : 2.66%
ICD-10 Coverage                : 34.96%
HCC Category Match Accuracy    : 73.44%
HCC Coverage                   : 34.96%
Overall Clinical Relevance     : 35.24%
Mean String Similarity         : 30.4
High Similarity (>= 80)        : 1.00%
Evaluation Time                : 88.40 seconds


In [None]:
### HCC and ICD-10 codes are standardized, and were reported on.