# Association Rules Mining for Medical Data

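This notebook uses the Apriori algorithm to mine association rules from the MIMIC-3 dataset, looking for medical patterns of the form "(high BMI ∧ hypertension) ⇒ increased risk of diabetes". Each rule is scored with three metrics: support (how often the items occur together), confidence (how often the consequent appears when the antecedent is present), and lift (confidence relative to the consequent's baseline frequency).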

## Objectives
1. Load and preprocess MIMIC-3 diagnosis data
2. Prepare data for association rules mining
3. Apply the Apriori algorithm
4. Discover medical associations and comorbidities
5. Evaluate rules using Support, Confidence, and Lift metrics
6. Save discovered rules for clinical interpretation
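
To make the three metrics in objective 5 concrete, here is a minimal sketch in plain Python. The transactions and item names are invented for illustration and do not come from MIMIC-3. The helper functions `support`, `confidence`, and `lift` are hypothetical names, not part of mlxtend.

```python
# Toy transactions: each set is one hypothetical patient's recorded conditions.
transactions = [
    {"high_bmi", "hypertension", "diabetes"},
    {"high_bmi", "hypertension"},
    {"hypertension", "diabetes"},
    {"high_bmi", "hypertension", "diabetes"},
    {"asthma"},
]


def support(itemset, transactions):
    """Fraction of transactions that contain every item in `itemset`."""
    itemset = set(itemset)
    return sum(itemset <= t for t in transactions) / len(transactions)


def confidence(antecedent, consequent, transactions):
    """Estimate of P(consequent | antecedent) from the transactions."""
    both = set(antecedent) | set(consequent)
    return support(both, transactions) / support(antecedent, transactions)


def lift(antecedent, consequent, transactions):
    """Confidence divided by the consequent's baseline support."""
    return confidence(antecedent, consequent, transactions) / support(consequent, transactions)


antecedent = {"high_bmi", "hypertension"}
consequent = {"diabetes"}

rule_support = support(antecedent | consequent, transactions)
rule_confidence = confidence(antecedent, consequent, transactions)
rule_lift = lift(antecedent, consequent, transactions)

print(f"support={rule_support:.3f} confidence={rule_confidence:.3f} lift={rule_lift:.3f}")
# → support=0.400 confidence=0.667 lift=1.111
```

A lift above 1 means the antecedent and consequent co-occur more often than they would if they were independent. A lift near 1 suggests the rule adds little beyond the consequent's base rate.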


In [6]:
# Import necessary libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from mlxtend.frequent_patterns import apriori, association_rules
from mlxtend.preprocessing import TransactionEncoder
import warnings

# Silence library warnings to keep the notebook output readable
warnings.filterwarnings('ignore')

# Set plotting style
plt.style.use('default')
sns.set_palette('viridis')

print("Libraries imported successfully!")


Libraries imported successfully!


## 1. Load and Explore MIMIC-3 Data


In [7]:
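# Load MIMIC-3 data
# ICD-9 codes are identifiers, so read them as strings to keep leading zeros
icd_dtype = {'icd9_code': str}
try:
    # Try to load preprocessed data first
    diagnoses = pd.read_csv('../src/data/processed/mimic3_diagnoses.csv', dtype=icd_dtype)
    print(f"Loaded preprocessed diagnoses data: {diagnoses.shape}")
except FileNotFoundError:
    # Load raw MIMIC-3 data if preprocessed not available
    print("Preprocessed data not found. Loading raw MIMIC-3 data...")
    diagnoses = pd.read_csv('../MIMIC-3/DIAGNOSES_ICD.csv', dtype=icd_dtype)
    icd_descriptions = pd.read_csv('../MIMIC-3/D_ICD_DIAGNOSES.csv', dtype=icd_dtype)
    
    # Merge with ICD descriptions; drop the lookup table's own row_id (if present)
    # so the merge does not produce row_id_x / row_id_y columns
    icd_descriptions = icd_descriptions.drop(columns='row_id', errors='ignore')
    diagnoses = pd.merge(diagnoses, icd_descriptions, on='icd9_code', how='left')
    print(f"Loaded raw diagnoses data: {diagnoses.shape}")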

# Display basic information
print(f"Dataset shape: {diagnoses.shape}")
print(f"Columns: {diagnoses.columns.tolist()}")
diagnoses.head()


Loaded preprocessed diagnoses data: (1761, 5)
Dataset shape: (1761, 5)
Columns: ['row_id', 'subject_id', 'hadm_id', 'seq_num', 'icd9_code']


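|   | row_id | subject_id | hadm_id | seq_num | icd9_code |
|---|--------|------------|---------|---------|-----------|
| 0 | 112344 | 10006 | 142345 | 1 | 99591 |
| 1 | 112345 | 10006 | 142345 | 2 | 99662 |
| 2 | 112346 | 10006 | 142345 | 3 | 5672 |
| 3 | 112347 | 10006 | 142345 | 4 | 40391 |
| 4 | 112348 | 10006 | 142345 | 5 | 42731 |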
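
One thing to watch when loading ICD-9 codes: they are identifiers rather than numbers, and some begin with a zero. If every value in the column looks numeric, `pd.read_csv` infers an integer type and the leading zero is lost. The sketch below shows the effect on a small in-memory CSV. The rows are invented toy data, and whether the issue affects the file loaded above depends on its contents (a column that also holds codes starting with a letter is read as strings anyway).

```python
import io

import pandas as pd

# Toy CSV with invented rows; the first code starts with a zero.
csv_text = (
    "subject_id,hadm_id,seq_num,icd9_code\n"
    "1,100,1,0389\n"
    "1,100,2,4019\n"
    "2,200,1,25000\n"
)

# Default type inference: the all-numeric column becomes integers.
as_numbers = pd.read_csv(io.StringIO(csv_text))

# Explicit string dtype: codes are kept exactly as written.
as_strings = pd.read_csv(io.StringIO(csv_text), dtype={"icd9_code": str})

print(as_numbers["icd9_code"].tolist())  # → [389, 4019, 25000]
print(as_strings["icd9_code"].tolist())  # → ['0389', '4019', '25000']
```

Reading the column as strings also keeps its type consistent across tables, which matters when merging on `icd9_code`.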


In [8]:
# Check data quality
print("Data quality check:")
print(f"Missing values: {diagnoses.isnull().sum().sum()}")
print(f"Unique patients: {diagnoses['subject_id'].nunique()}")
print(f"Unique diagnoses: {diagnoses['icd9_code'].nunique()}")

# Check most common diagnoses
print("\nTop 10 most common diagnoses:")
diagnosis_counts = diagnoses['icd9_code'].value_counts().head(10)
print(diagnosis_counts)

# Check if we have diagnosis descriptions
if 'short_title' in diagnoses.columns:
    print("\nSample diagnosis descriptions:")
    sample_diagnoses = diagnoses[['icd9_code', 'short_title']].drop_duplicates().head(10)
    print(sample_diagnoses)
else:
    print("No diagnosis descriptions available")


Data quality check:
Missing values: 0
Unique patients: 100
Unique diagnoses: 581

Top 10 most common diagnoses:
icd9_code
4019     53
42731    48
5849     45
4280     39
51881    31
25000    31
2724     29
5990     27
486      26
2859     25
Name: count, dtype: int64
No diagnosis descriptions available
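
The table above has one row per diagnosis, while association rule mining works on transactions. As a rough sketch of one way to bridge the two, the snippet below groups a toy long-format table into one basket per admission and builds a boolean one-hot table, which is the input shape that mlxtend's `apriori` accepts. Treating each `hadm_id` as a transaction is an assumption made here for illustration; grouping by `subject_id` is an equally plausible choice. The toy rows are invented.

```python
import pandas as pd

# Toy long-format diagnoses: one row per (admission, code) pair.
toy = pd.DataFrame({
    "hadm_id":   [1, 1, 1, 2, 2, 3],
    "icd9_code": ["4019", "42731", "4280", "4019", "25000", "4019"],
})

# One basket (set of codes) per admission.
baskets = toy.groupby("hadm_id")["icd9_code"].apply(set)

# One-hot table: rows are admissions, columns are codes,
# True where the code was recorded for that admission.
onehot = pd.crosstab(toy["hadm_id"], toy["icd9_code"]).astype(bool)

# Support of each single code = share of admissions in which it appears.
item_support = onehot.mean().sort_values(ascending=False)

print(baskets)
print(item_support)
```

In this toy table, code `4019` appears in all three admissions (support 1.0) and every other code appears in one of three (support of about 0.33).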
