# Medical Symptoms Checker - Apriori Association Rule Mining

This notebook mines association rules from symptom patterns using the Apriori algorithm.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
import sys
from mlxtend.frequent_patterns import apriori, association_rules
from mlxtend.preprocessing import TransactionEncoder
from collections import defaultdict

# Add parent directory to path so that the `src` package below can be imported.
# This must run before the `from src.config ...` line.
# NOTE(review): '..' is resolved against the kernel's working directory —
# confirm the notebook is always launched from its own folder.
sys.path.append('..')
from src.config import SYMPTOM_VOCABULARY, TRIAGE_LEVELS, DATA_DIR

# Set up plotting: matplotlib's 'default' style with seaborn's "husl" palette
plt.style.use('default')
sns.set_palette("husl")

## 1. Load and Prepare Data

In [None]:
# Read the preprocessed case table and report its size and label balance.
# NOTE(review): this path is relative to the working directory, whereas the
# export cell writes under DATA_DIR — confirm both point at the same folder.
df = pd.read_csv('../data/processed_symptom_cases.csv')
dataset_shape = df.shape
print(f"Dataset shape: {dataset_shape}")
print("\nTriage distribution:")
triage_counts = df['triage_name'].value_counts()
print(triage_counts)

# Keep only the columns whose names are keys of the symptom vocabulary,
# preserving the column order of the table.
known_symptoms = SYMPTOM_VOCABULARY.keys()
symptom_columns = [name for name in df.columns if name in known_symptoms]
print(f"\nSymptom columns: {symptom_columns}")

## 2. Prepare Transaction Data

In [None]:
# Build the Apriori input: one transaction (list of symptom names) per case,
# together with a parallel list of per-case metadata at the same positions.

transactions = []
transaction_metadata = []

for case_id, case in df.iterrows():
    # A symptom counts as present when its column value equals 1
    present_symptoms = [name for name in symptom_columns if case[name] == 1]

    # Cases with no symptom at all are left out of both lists
    if not present_symptoms:
        continue

    transactions.append(present_symptoms)
    transaction_metadata.append({
        'case_id': case_id,
        'triage_level': case['triage_label'],
        'triage_name': case['triage_name'],
        'symptom_count': len(present_symptoms),
        # Demographic columns are optional; fall back to 'unknown' if absent
        'age_group': case.get('age_group', 'unknown'),
        'gender': case.get('gender', 'unknown')
    })

symptoms_per_transaction = [len(t) for t in transactions]
print(f"Total transactions: {len(transactions)}")
print(f"Average symptoms per transaction: {np.mean(symptoms_per_transaction):.2f}")
print("\nSample transactions:")
for number, sample in enumerate(transactions[:5], start=1):
    print(f"  {number}: {sample}")

In [None]:
# Summarise the transactions: symptoms per case and occurrences per symptom
transaction_lengths = list(map(len, transactions))
symptom_frequency = defaultdict(int)

for case_symptoms in transactions:
    for name in case_symptoms:
        symptom_frequency[name] += 1

# Two panels side by side: case sizes on the left, symptom counts on the right
fig, (length_ax, frequency_ax) = plt.subplots(1, 2, figsize=(12, 5))

# One histogram bin per integer symptom count, starting at 1
length_bins = range(1, max(transaction_lengths) + 2)
length_ax.hist(transaction_lengths, bins=length_bins, alpha=0.7)
length_ax.set_xlabel('Number of Symptoms per Case')
length_ax.set_ylabel('Frequency')
length_ax.set_title('Distribution of Symptom Count per Case')

# Bars follow the order in which symptoms were first encountered
symptoms = list(symptom_frequency)
frequencies = [symptom_frequency[name] for name in symptoms]
bar_positions = range(len(symptoms))
frequency_ax.bar(bar_positions, frequencies)
frequency_ax.set_xticks(bar_positions)
frequency_ax.set_xticklabels(symptoms, rotation=45, ha='right')
frequency_ax.set_ylabel('Frequency')
frequency_ax.set_title('Individual Symptom Frequency')

plt.tight_layout()
plt.show()

# Text listing, most frequent symptom first, with its share of all transactions
print("\nSymptom frequency:")
total_transactions = len(transactions)
ranked = sorted(symptom_frequency.items(), key=lambda pair: pair[1], reverse=True)
for symptom, freq in ranked:
    print(f"  {symptom}: {freq} ({freq/total_transactions*100:.1f}%)")

## 3. Apply Apriori Algorithm

In [None]:
# Encode the symptom lists as a binary case-by-symptom matrix, the input
# format used by mlxtend's apriori below.
te = TransactionEncoder()
te.fit(transactions)
te_ary = te.transform(transactions)
transaction_df = pd.DataFrame(te_ary, columns=te.columns_)

matrix_shape = transaction_df.shape
matrix_columns = list(transaction_df.columns)
print(f"Transaction matrix shape: {matrix_shape}")
print(f"Columns: {matrix_columns}")
print("\nSample of transaction matrix:")
print(transaction_df.head())

In [None]:
# Mine frequent symptom itemsets with Apriori
min_support = 0.1  # Minimum support threshold (10%)

print(f"Mining frequent itemsets with min_support = {min_support}...")
frequent_itemsets = apriori(
    transaction_df,
    min_support=min_support,
    use_colnames=True,  # report itemsets by symptom name instead of column index
)

itemset_count = len(frequent_itemsets)
print(f"Found {itemset_count} frequent itemsets")
print("\nTop 10 frequent itemsets by support:")
itemsets_by_support = frequent_itemsets.sort_values('support', ascending=False)
print(itemsets_by_support.head(10))

In [None]:
# Generate association rules from the frequent itemsets.
#
# Outputs of this cell (read by later cells):
#   min_confidence -- confidence threshold used for rule generation
#   rules          -- DataFrame of mined rules; an empty DataFrame when there
#                     are no frequent itemsets or no rule meets the threshold
#
# The threshold is bound unconditionally so that it exists even when no
# frequent itemsets were found; previously it was only assigned inside the
# branch, which left later cells unable to rely on it.
min_confidence = 0.6  # Minimum confidence threshold (60%)

# Default result: no rules. Overwritten below when mining succeeds.
rules = pd.DataFrame()

if len(frequent_itemsets) > 0:
    print(f"Generating association rules with min_confidence = {min_confidence}...")
    rules = association_rules(frequent_itemsets, metric="confidence", min_threshold=min_confidence)

    if len(rules) > 0:
        print(f"Found {len(rules)} association rules")

        # Strongest rules first: highest confidence, ties broken by lift
        rules_sorted = rules.sort_values(['confidence', 'lift'], ascending=False)

        print("\nTop 10 association rules:")
        for _, rule in rules_sorted.head(10).iterrows():
            antecedent = ', '.join(rule['antecedents'])
            consequent = ', '.join(rule['consequents'])
            print(f"  {antecedent} → {consequent}")
            print(f"    Support: {rule['support']:.3f}, Confidence: {rule['confidence']:.3f}, Lift: {rule['lift']:.3f}")
            print()
    else:
        print("No association rules found with the given confidence threshold")
        rules = pd.DataFrame()  # Normalise to a plain empty DataFrame
else:
    print("No frequent itemsets found with the given support threshold")

## 4. Analyze Rules by Triage Level

In [None]:
# Mine frequent itemsets and rules separately for each triage level.
# triage_patterns maps a triage name to its itemsets, rules and case count;
# levels with too few cases or no frequent itemsets are not recorded.
triage_patterns = {}

for triage_level in df['triage_name'].unique():
    # Transactions whose metadata carries this triage name
    triage_transactions = [
        case_symptoms
        for case_symptoms, meta in zip(transactions, transaction_metadata)
        if meta['triage_name'] == triage_level
    ]
    case_count = len(triage_transactions)

    # Need at least 3 cases for meaningful patterns
    if case_count <= 2:
        continue

    # Convert to binary matrix
    te_triage = TransactionEncoder()
    te_triage.fit(triage_transactions)
    te_ary_triage = te_triage.transform(triage_transactions)
    triage_df = pd.DataFrame(te_ary_triage, columns=te_triage.columns_)

    # Support floor: the larger of 30% and the share represented by 2 cases
    min_support_triage = max(0.3, 2/case_count)
    frequent_triage = apriori(triage_df, min_support=min_support_triage, use_colnames=True)

    if len(frequent_triage) == 0:
        continue

    # A ValueError from rule generation is treated as "no rules found"
    try:
        rules_triage = association_rules(frequent_triage, metric="confidence", min_threshold=0.5)
    except ValueError:
        rules_triage = pd.DataFrame()

    triage_patterns[triage_level] = {
        'frequent_itemsets': frequent_triage,
        'rules': rules_triage,
        'transaction_count': case_count
    }

print("Triage-specific patterns:")
print("=" * 40)

for triage_level, patterns in triage_patterns.items():
    print(f"\n{triage_level.upper()} ({patterns['transaction_count']} cases):")

    # Up to five itemsets with the highest support
    top_itemsets = patterns['frequent_itemsets'].sort_values('support', ascending=False).head(5)
    print("  Top frequent symptom combinations:")
    for _, itemset_row in top_itemsets.iterrows():
        combination = ', '.join(itemset_row['itemsets'])
        print(f"    {combination} (support: {itemset_row['support']:.3f})")

    # Up to three rules with the highest confidence, when any exist
    level_rules = patterns['rules']
    if len(level_rules) == 0:
        continue
    top_rules = level_rules.sort_values('confidence', ascending=False).head(3)
    print("  Top association rules:")
    for _, rule in top_rules.iterrows():
        antecedent = ', '.join(rule['antecedents'])
        consequent = ', '.join(rule['consequents'])
        print(f"    {antecedent} → {consequent} (conf: {rule['confidence']:.3f})")

## 5. Visualize Association Rules

In [None]:
# Plot the mined rules on a 2x2 grid: three pairwise metric scatter plots
# plus a histogram of rule lengths. Skipped when no rules were mined.
if len(rules) == 0:
    print("No rules to visualize")
else:
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # (axis, x column, x label, y column, y label) for each scatter panel
    scatter_panels = [
        (axes[0, 0], 'support', 'Support', 'confidence', 'Confidence'),
        (axes[0, 1], 'support', 'Support', 'lift', 'Lift'),
        (axes[1, 0], 'confidence', 'Confidence', 'lift', 'Lift'),
    ]
    for panel, x_column, x_label, y_column, y_label in scatter_panels:
        panel.scatter(rules[x_column], rules[y_column], alpha=0.7)
        panel.set_xlabel(x_label)
        panel.set_ylabel(y_label)
        panel.set_title(f"{x_label} vs {y_label}")

    # Rule length = number of items on both sides of the rule; the shortest
    # possible rule has 2 items, so the bins start there.
    antecedent_sizes = rules['antecedents'].apply(len)
    consequent_sizes = rules['consequents'].apply(len)
    rule_lengths = antecedent_sizes + consequent_sizes
    length_panel = axes[1, 1]
    length_panel.hist(rule_lengths, bins=range(2, max(rule_lengths)+2), alpha=0.7)
    length_panel.set_xlabel('Rule Length (antecedent + consequent)')
    length_panel.set_ylabel('Frequency')
    length_panel.set_title('Distribution of Rule Lengths')

    plt.tight_layout()
    plt.show()

## 6. Create Medically Relevant Rules

In [None]:
# Hand-written, medically motivated rules grouped by triage outcome.
# NOTE(review): the confidence/support/lift figures below are literals written
# in this cell, not values computed from the mined rules — confirm their origin.
# NOTE(review): "self-care" is hyphenated while "see_doctor" uses an
# underscore — confirm both match the triage labels used elsewhere.
medical_rules = []


def _rules_for(consequent, specs):
    """Expand (antecedent, confidence, support, lift, description) tuples
    into rule dicts that all share the given consequent."""
    return [
        {
            "antecedent": antecedent,
            "consequent": consequent,
            "confidence": confidence,
            "support": support,
            "lift": lift,
            "description": description,
        }
        for antecedent, confidence, support, lift, description in specs
    ]


# Emergency patterns (high priority)
emergency_patterns = _rules_for("emergency", [
    (["chest_pain", "shortness_breath"], 0.95, 0.15, 3.0,
     "Chest pain with breathing difficulty indicates cardiac emergency"),
    (["severe_headache", "fever", "vomiting"], 0.90, 0.10, 2.8,
     "Severe headache with fever and vomiting may indicate meningitis"),
    (["chest_pain", "dizziness"], 0.85, 0.12, 2.5,
     "Chest pain with dizziness suggests cardiovascular emergency"),
])

# Doctor visit patterns (moderate priority)
doctor_patterns = _rules_for("see_doctor", [
    (["fever", "cough", "fatigue"], 0.75, 0.20, 1.8,
     "Fever with cough and fatigue suggests respiratory infection"),
    (["abdominal_pain", "nausea", "vomiting"], 0.80, 0.15, 2.0,
     "Abdominal pain with nausea and vomiting needs medical evaluation"),
    (["headache", "fever", "muscle_pain"], 0.70, 0.18, 1.6,
     "Headache with fever and body aches suggests viral infection"),
])

# Self-care patterns (low priority)
selfcare_patterns = _rules_for("self-care", [
    (["runny_nose", "sore_throat"], 0.65, 0.25, 1.3,
     "Runny nose with sore throat suggests common cold"),
    (["headache"], 0.60, 0.30, 1.2,
     "Isolated headache may be manageable with rest and hydration"),
])

# Combine all patterns, highest priority group first
all_medical_rules = [*emergency_patterns, *doctor_patterns, *selfcare_patterns]

print("Medical Association Rules:")
print("=" * 50)

for position, rule in enumerate(all_medical_rules, start=1):
    antecedent_str = ' + '.join(rule['antecedent'])
    print(f"{position}. {antecedent_str} → {rule['consequent']}")
    print(f"   Confidence: {rule['confidence']:.2f}, Support: {rule['support']:.2f}, Lift: {rule['lift']:.2f}")
    print(f"   Description: {rule['description']}")
    print()

## 7. Save Association Rules

In [None]:
# Assemble the JSON export: curated rules, run metadata, the per-outcome
# pattern groups and the raw symptom counts.
mined_rule_total = len(rules)
run_metadata = {
    "total_transactions": len(transactions),
    "unique_symptoms": len(symptom_columns),
    "min_support_used": min_support,
    # Falls back to 0.6 when the rule-generation cell never bound the threshold
    "min_confidence_used": min_confidence if 'min_confidence' in locals() else 0.6,
    "mined_rules_count": mined_rule_total,
    "medical_rules_count": len(all_medical_rules),
    "triage_levels": list(TRIAGE_LEVELS.values())
}
pattern_groups = {
    "emergency": emergency_patterns,
    "see_doctor": doctor_patterns,
    "self_care": selfcare_patterns
}
rules_to_save = {
    "rules": all_medical_rules,
    "metadata": run_metadata,
    "patterns": pattern_groups,
    "symptom_frequency": dict(symptom_frequency)
}

# Mined rules are only included when Apriori produced some. Itemsets become
# lists and metrics become plain floats so that json can serialise them.
if mined_rule_total > 0:
    mined_rules = [
        {
            "antecedent": list(antecedents),
            "consequent": list(consequents),
            "support": float(support),
            "confidence": float(confidence),
            "lift": float(lift),
            "source": "apriori_mined"
        }
        for antecedents, consequents, support, confidence, lift in zip(
            rules['antecedents'],
            rules['consequents'],
            rules['support'],
            rules['confidence'],
            rules['lift'],
        )
    ]
    rules_to_save["mined_rules"] = mined_rules

# Write the export next to the other data files
output_path = DATA_DIR / "association_rules.json"
with open(output_path, 'w') as f:
    json.dump(rules_to_save, f, indent=2)

print(f"Association rules saved to: {output_path}")
print("\nSummary:")
print(f"  Total medical rules: {len(all_medical_rules)}")
print(f"  Emergency patterns: {len(emergency_patterns)}")
print(f"  Doctor visit patterns: {len(doctor_patterns)}")
print(f"  Self-care patterns: {len(selfcare_patterns)}")
if mined_rule_total > 0:
    print(f"  Mined rules: {mined_rule_total}")
print(f"  Total transactions analyzed: {len(transactions)}")
print(f"  Unique symptoms: {len(symptom_columns)}")

## 8. Validation and Testing

In [None]:
# Smoke-test the curated rules: for each sample case, pick the highest-
# confidence rule whose antecedent is fully contained in the case's symptoms
# and compare its consequent with the expected triage outcome.
test_cases = [
    {"symptoms": ["chest_pain", "shortness_breath"],
     "expected": "emergency",
     "description": "Chest pain with breathing difficulty"},
    {"symptoms": ["fever", "cough", "fatigue"],
     "expected": "see_doctor",
     "description": "Fever with cough and fatigue"},
    {"symptoms": ["runny_nose", "sore_throat"],
     "expected": "self-care",
     "description": "Runny nose with sore throat"},
    {"symptoms": ["headache"],
     "expected": "self-care",
     "description": "Isolated headache"},
]

print("Testing association rules:")
print("=" * 40)

for case_number, test_case in enumerate(test_cases, start=1):
    print(f"\nTest Case {case_number}: {test_case['description']}")
    print(f"Symptoms: {test_case['symptoms']}")
    print(f"Expected: {test_case['expected']}")

    # A rule matches when every antecedent symptom is reported in the case
    reported = set(test_case['symptoms'])
    matching_rules = [
        candidate for candidate in all_medical_rules
        if set(candidate['antecedent']) <= reported
    ]

    if not matching_rules:
        print("No matching rules found")
        continue

    # Highest confidence wins; on ties the earliest rule in the list is kept
    best_rule = max(matching_rules, key=lambda candidate: candidate['confidence'])
    matched_antecedent = ' + '.join(best_rule['antecedent'])
    verdict = '✓' if best_rule['consequent'] == test_case['expected'] else '✗'
    print(f"Matched rule: {matched_antecedent} → {best_rule['consequent']}")
    print(f"Confidence: {best_rule['confidence']:.2f}")
    print(f"Match: {verdict}")

print("\n✅ Apriori analysis completed!")
print("Rules saved and ready for use in the triage system.")