# 02 - Exploratory Data Analysis: Diagnoses

## Objective
Analyze the diagnoses (ICD-9 and ICD-10 codes) assigned to patients during hospital admissions and explore how they relate to patient outcomes such as in-hospital mortality and length of stay.

## Dataset
- **Source:** MIMIC-IV Clinical Database Demo v2.2
- **Tables:** 
  - `diagnoses_icd.csv` - ICD diagnoses per admission
  - `d_icd_diagnoses.csv` - ICD code dictionary (descriptions)
  - `patients.csv`, `admissions.csv` - Patient demographics and admission details (used for joins)

---

## 1. Setup & Imports

In [None]:
# Standard library
import warnings

# Data manipulation
import pandas as pd
import numpy as np

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Display settings
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 100)
pd.set_option('display.max_colwidth', 100)
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Warnings: silence only FutureWarning (library deprecation noise). A blanket
# filterwarnings('ignore') would also hide pandas' SettingWithCopyWarning and
# DtypeWarning, which point at genuine data problems in an analysis like this.
warnings.filterwarnings('ignore', category=FutureWarning)

print(" Libraries imported successfully")

## 2. Load Data

In [None]:
# Define data paths
DATA_PATH = '../../../../mimic-iv-demo/mimic-iv-clinical-database-demo-2.2/hosp/'

# icd_code is read as text in both ICD tables. ICD-9 codes can carry leading zeros
# (e.g. '0389') and the column mixes digits with letters ('V5861', 'I10'). With type
# inference, pandas may parse all-numeric chunks as integers. That drops the zeros
# and breaks the (icd_code, icd_version) join with the dictionary.

# Load diagnoses data
diagnoses = pd.read_csv(DATA_PATH + 'diagnoses_icd.csv', dtype={'icd_code': str})
print(f" Diagnoses data loaded: {diagnoses.shape[0]} rows, {diagnoses.shape[1]} columns")

# Load ICD dictionary
icd_dict = pd.read_csv(DATA_PATH + 'd_icd_diagnoses.csv', dtype={'icd_code': str})
print(f" ICD dictionary loaded: {icd_dict.shape[0]} rows, {icd_dict.shape[1]} columns")

# Load patients and admissions for joins
patients = pd.read_csv(DATA_PATH + 'patients.csv')
admissions = pd.read_csv(DATA_PATH + 'admissions.csv')
print(f" Patients data loaded: {patients.shape[0]} rows")
print(f" Admissions data loaded: {admissions.shape[0]} rows")

## 3. Initial Data Exploration

### 3.1 Diagnoses Table Structure

In [None]:
# Preview: the first 10 rows of the raw diagnoses table
print("First rows of diagnoses table:")
diagnoses.head(n=10)

In [None]:
# Column dtypes, non-null counts and memory usage (buf=None prints to stdout)
diagnoses.info(buf=None)

In [None]:
# Null count for every column of the diagnoses table
print("Missing values in diagnoses:")
null_counts = diagnoses.isna().sum()
print(null_counts)

In [None]:
# Share of ICD-9 vs ICD-10 codes, as raw counts and as percentages
version_col = diagnoses['icd_version']
print("ICD Version Distribution:")
print(version_col.value_counts())
print("\nPercentage:")
print(version_col.value_counts(normalize=True).mul(100))

### 3.2 ICD Dictionary Structure

In [None]:
# Preview: the first 10 entries of the code -> description dictionary
print("ICD Dictionary Sample:")
icd_dict.head(n=10)

In [None]:
# Column dtypes, non-null counts and memory usage of the ICD dictionary
icd_dict.info(buf=None)

## 4. Join Diagnoses with Descriptions

In [None]:
# Attach the human-readable long_title to every diagnosis row. The lookup is keyed
# on both icd_code and icd_version; a left join keeps diagnoses with no match
# (their long_title becomes NaN).
merge_keys = ['icd_code', 'icd_version']
title_lookup = icd_dict[merge_keys + ['long_title']]
diagnoses_full = pd.merge(diagnoses, title_lookup, how='left', on=merge_keys)

merged_rows, merged_cols = diagnoses_full.shape
print(f" Merged diagnoses with descriptions: {merged_rows} rows, {merged_cols} columns")
untitled_rows = diagnoses_full['long_title'].isna().sum()
print(f"\nRows without description: {untitled_rows}")

# Display sample
diagnoses_full.head(10)

## 5. Diagnosis Frequency Analysis

### 5.1 Most Common Diagnoses

In [None]:
# Frequency of each (code, description) pair across all diagnosis rows, most
# common first. groupby drops rows whose long_title is NaN (no dictionary match).
diagnosis_counts = (
    diagnoses_full
    .groupby(['icd_code', 'long_title'])
    .size()
    .reset_index(name='count')
    .sort_values('count', ascending=False)
)

print("Top 20 Most Common Diagnoses:")
print(diagnosis_counts.head(20))

In [None]:
# Horizontal bar chart of the 15 most frequent diagnoses
fig, ax = plt.subplots(figsize=(14, 8))
# .copy(): head() may return a view of diagnosis_counts, and adding a column to a
# view is pandas' chained-assignment (SettingWithCopy) pattern.
top_15 = diagnosis_counts.head(15).copy()

# Shorten titles for display; append '...' only to titles that were actually cut
top_15['short_title'] = top_15['long_title'].where(
    top_15['long_title'].str.len() <= 60,
    top_15['long_title'].str[:60] + '...',
)

ax.barh(range(len(top_15)), top_15['count'], color='steelblue', edgecolor='black')
ax.set_yticks(range(len(top_15)))
ax.set_yticklabels(top_15['short_title'])
ax.set_xlabel('Number of Occurrences', fontsize=12)
ax.set_title('Top 15 Most Common Diagnoses', fontsize=14, fontweight='bold')
ax.invert_yaxis()  # most frequent diagnosis at the top

# Add count labels just to the right of each bar
for i, v in enumerate(top_15['count']):
    ax.text(v + 1, i, str(v), va='center', fontsize=10)

plt.tight_layout()
plt.show()

### 5.2 Diagnoses per Admission

In [None]:
# Number of diagnosis rows recorded for each hospital admission (hadm_id)
diagnoses_per_admission = (
    diagnoses.groupby('hadm_id').size().reset_index(name='num_diagnoses')
)
n_diag = diagnoses_per_admission['num_diagnoses']

print("Diagnoses per Admission Statistics:")
print(n_diag.describe())

print(f"\nMax diagnoses in single admission: {n_diag.max()}")
print(f"Admissions with 10+ diagnoses: {n_diag.ge(10).sum()}")

In [None]:
# Side-by-side histogram and boxplot of diagnoses per admission
counts_per_adm = diagnoses_per_admission['num_diagnoses']
mean_count = counts_per_adm.mean()

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
hist_ax, box_ax = axes

# Left: histogram with the mean marked as a dashed line
hist_ax.hist(counts_per_adm, bins=30, edgecolor='black', alpha=0.7, color='coral')
hist_ax.set_xlabel('Number of Diagnoses')
hist_ax.set_ylabel('Frequency')
hist_ax.set_title('Distribution of Diagnoses per Admission')
hist_ax.axvline(mean_count, color='red', linestyle='--', label=f'Mean: {mean_count:.1f}')
hist_ax.legend()

# Right: boxplot (vertical orientation is matplotlib's default)
box_ax.boxplot(counts_per_adm)
box_ax.set_ylabel('Number of Diagnoses')
box_ax.set_title('Diagnoses per Admission (Boxplot)')
box_ax.set_xticklabels(['All Admissions'])

plt.tight_layout()
plt.show()

### 5.3 Diagnoses per Patient

In [None]:
# Count distinct diagnoses per patient across all of their admissions.
# A diagnosis is identified by (icd_code, icd_version), the same key used to join
# the ICD dictionary. An ICD-9 code and an ICD-10 code that share a code string
# are therefore counted as two different diagnoses. Missing codes are ignored.
diagnoses_per_patient = (
    diagnoses[['subject_id', 'icd_code', 'icd_version']]
    .dropna(subset=['icd_code'])
    .drop_duplicates()
    .groupby('subject_id')
    .size()
    .reset_index(name='unique_diagnoses')
)

print("Unique Diagnoses per Patient Statistics:")
print(diagnoses_per_patient['unique_diagnoses'].describe())

print(f"\nPatients with 20+ unique diagnoses: {(diagnoses_per_patient['unique_diagnoses'] >= 20).sum()}")

## 6. Comorbidities Analysis

### 6.1 Common Comorbidity Pairs

In [None]:
# Find the most common pairs of diagnoses recorded in the same admission (comorbidities)
from collections import Counter
from itertools import combinations

# Sorted, de-duplicated list of diagnosis descriptions per admission.
# - Rows without a description are dropped: sorting a list that mixes NaN (a float)
#   with strings raises TypeError.
# - Duplicates are removed so an admission never yields a pair of a title with itself.
admission_diagnoses = (
    diagnoses_full
    .dropna(subset=['long_title'])
    .groupby('hadm_id')['long_title']
    .apply(lambda titles: sorted(set(titles)))
    .reset_index()
)

# Count each unordered pair once per admission. combinations() over a sorted list
# emits pairs in a canonical (alphabetical) order, so (A, B) and (B, A) are never
# counted separately. Admissions with fewer than two titles contribute nothing.
pair_counts = Counter(
    pair
    for titles in admission_diagnoses['long_title']
    for pair in combinations(titles, 2)
)
top_pairs = pair_counts.most_common(15)

print("Top 15 Most Common Comorbidity Pairs:")
for i, (pair, count) in enumerate(top_pairs, 1):
    diag1_short = pair[0][:50] + '...' if len(pair[0]) > 50 else pair[0]
    diag2_short = pair[1][:50] + '...' if len(pair[1]) > 50 else pair[1]
    print(f"{i}. [{count}x] {diag1_short} + {diag2_short}")

## 7. Diagnosis Analysis by Patient Outcomes

### 7.1 Join with Admissions and Patient Data

In [None]:
# Length of stay in days, derived from the admission/discharge timestamps
admissions['admittime'] = pd.to_datetime(admissions['admittime'])
admissions['dischtime'] = pd.to_datetime(admissions['dischtime'])
stay_duration = admissions['dischtime'] - admissions['admittime']
admissions['los_days'] = stay_duration.dt.total_seconds() / 86400  # seconds per day

# Attach admission-level outcomes, then patient demographics. subject_id already
# comes from diagnoses_full, so it is not requested from admissions.
admission_cols = ['hadm_id', 'admission_type', 'hospital_expire_flag', 'los_days']
patient_cols = ['subject_id', 'gender', 'anchor_age']
diag_complete = (
    diagnoses_full
    .merge(admissions[admission_cols], on='hadm_id', how='left')
    .merge(patients[patient_cols], on='subject_id', how='left')
)

print(f" Complete dataset created: {diag_complete.shape[0]} rows, {diag_complete.shape[1]} columns")
diag_complete.head()

### 7.2 Diagnoses Associated with Mortality

In [None]:
# In-hospital mortality per diagnosis: deaths, number of diagnosis rows with a
# known outcome, and the death rate expressed as a percentage
mortality_by_diagnosis = (
    diag_complete
    .groupby(['icd_code', 'long_title'])
    .agg(
        deaths=('hospital_expire_flag', 'sum'),
        total_cases=('hospital_expire_flag', 'count'),
        mortality_rate=('hospital_expire_flag', 'mean'),
    )
    .reset_index()
)
mortality_by_diagnosis['mortality_rate'] = mortality_by_diagnosis['mortality_rate'].mul(100)

# Keep diagnoses with at least 5 cases so a handful of deaths cannot dominate the
# rate, then rank from highest to lowest mortality
has_enough_cases = mortality_by_diagnosis['total_cases'] >= 5
mortality_by_diagnosis_filtered = mortality_by_diagnosis.loc[has_enough_cases].sort_values(
    'mortality_rate', ascending=False
)

print("Top 15 Diagnoses with Highest Mortality Rate (min 5 cases):")
print(mortality_by_diagnosis_filtered.head(15))

In [None]:
# Horizontal bar chart of the 15 diagnoses with the highest mortality rate
fig, ax = plt.subplots(figsize=(14, 8))
# .copy() so short_title is added to an independent frame rather than a possible
# view of mortality_by_diagnosis_filtered (avoids chained assignment)
top_mortality = mortality_by_diagnosis_filtered.head(15).copy()

# Shorten titles; append '...' only to titles that were actually cut
top_mortality['short_title'] = top_mortality['long_title'].where(
    top_mortality['long_title'].str.len() <= 60,
    top_mortality['long_title'].str[:60] + '...',
)

# Shade each bar by its rate: darker red means higher mortality
colors = plt.cm.Reds(top_mortality['mortality_rate'] / 100)
ax.barh(range(len(top_mortality)), top_mortality['mortality_rate'], color=colors, edgecolor='black')
ax.set_yticks(range(len(top_mortality)))
ax.set_yticklabels(top_mortality['short_title'])
ax.set_xlabel('Mortality Rate (%)', fontsize=12)
ax.set_title('Top 15 Diagnoses with Highest Mortality Rate', fontsize=14, fontweight='bold')
ax.invert_yaxis()  # highest rate at the top

# Add rate labels just to the right of each bar
for i, v in enumerate(top_mortality['mortality_rate']):
    ax.text(v + 1, i, f'{v:.1f}%', va='center', fontsize=10)

plt.tight_layout()
plt.show()

### 7.3 Diagnoses Associated with Length of Stay

In [None]:
# Mean and median length of stay (days) per diagnosis, plus how many rows had a LOS
los_by_diagnosis = (
    diag_complete
    .groupby(['icd_code', 'long_title'])
    .agg(
        mean_los=('los_days', 'mean'),
        median_los=('los_days', 'median'),
        count=('los_days', 'count'),
    )
    .reset_index()
)

# Keep diagnoses seen at least 5 times, longest average stay first
los_by_diagnosis_filtered = los_by_diagnosis.loc[los_by_diagnosis['count'] >= 5].sort_values(
    'mean_los', ascending=False
)

print("Top 15 Diagnoses with Longest Average Length of Stay (min 5 cases):")
print(los_by_diagnosis_filtered.head(15))

In [None]:
# Horizontal bar chart of the 15 diagnoses with the longest average stay
fig, ax = plt.subplots(figsize=(14, 8))
# .copy() so short_title is added to an independent frame rather than a possible
# view of los_by_diagnosis_filtered (avoids chained assignment)
top_los = los_by_diagnosis_filtered.head(15).copy()

# Shorten titles; append '...' only to titles that were actually cut
top_los['short_title'] = top_los['long_title'].where(
    top_los['long_title'].str.len() <= 60,
    top_los['long_title'].str[:60] + '...',
)

ax.barh(range(len(top_los)), top_los['mean_los'], color='darkorange', edgecolor='black')
ax.set_yticks(range(len(top_los)))
ax.set_yticklabels(top_los['short_title'])
ax.set_xlabel('Average Length of Stay (days)', fontsize=12)
ax.set_title('Top 15 Diagnoses with Longest Average Hospital Stay', fontsize=14, fontweight='bold')
ax.invert_yaxis()  # longest stay at the top

# Add LOS labels just to the right of each bar
for i, v in enumerate(top_los['mean_los']):
    ax.text(v + 0.5, i, f'{v:.1f}d', va='center', fontsize=10)

plt.tight_layout()
plt.show()

## 8. Diagnosis Categories Analysis

### 8.1 ICD Code Categories (First Character)

In [None]:
# Group diagnoses into broad categories.
# The first character of a code means different things in the two coding systems.
# In ICD-9, 'E' and 'V' mark external-cause and supplementary codes, and digits
# start numeric ranges. In ICD-10, 'E' is endocrine and 'V' is transport accidents.
# Counts are therefore kept separate per ICD version, and a chapter label is derived
# from each system's own code ranges.

diag_complete['icd_category'] = diag_complete['icd_code'].str[0]

# ICD-9-CM chapters as (last 3-digit prefix in the chapter, chapter name), ascending.
# A leading digit alone cannot identify the chapter (e.g. 390-459 circulatory and
# 460-519 respiratory both start with '4'), so the full 3-digit prefix is used.
icd9_chapters = [
    (139, 'Infectious Diseases'),
    (239, 'Neoplasms'),
    (279, 'Endocrine/Metabolic'),
    (289, 'Blood Diseases'),
    (319, 'Mental Disorders'),
    (389, 'Nervous System & Sense Organs'),
    (459, 'Circulatory System'),
    (519, 'Respiratory System'),
    (579, 'Digestive System'),
    (629, 'Genitourinary System'),
    (679, 'Pregnancy & Childbirth'),
    (709, 'Skin & Subcutaneous Tissue'),
    (739, 'Musculoskeletal System'),
    (759, 'Congenital Anomalies'),
    (779, 'Perinatal Conditions'),
    (799, 'Symptoms & Ill-defined Conditions'),
    (999, 'Injury & Poisoning'),
]

# ICD-10-CM chapters by first letter. D and H are split within the letter and are
# refined in icd_chapter(): D50-D89 is blood, H60-H95 is ear.
icd10_chapters = {
    'A': 'Infectious Diseases', 'B': 'Infectious Diseases',
    'C': 'Neoplasms', 'D': 'Neoplasms',
    'E': 'Endocrine/Metabolic',
    'F': 'Mental Disorders',
    'G': 'Nervous System',
    'H': 'Eye & Adnexa',
    'I': 'Circulatory System',
    'J': 'Respiratory System',
    'K': 'Digestive System',
    'L': 'Skin & Subcutaneous Tissue',
    'M': 'Musculoskeletal System',
    'N': 'Genitourinary System',
    'O': 'Pregnancy & Childbirth',
    'P': 'Perinatal Conditions',
    'Q': 'Congenital Anomalies',
    'R': 'Symptoms & Ill-defined Conditions',
    'S': 'Injury & Poisoning', 'T': 'Injury & Poisoning',
    'U': 'Special Purpose Codes',
    'V': 'External Causes', 'W': 'External Causes',
    'X': 'External Causes', 'Y': 'External Causes',
    'Z': 'Health Status & Services',
}


def icd_chapter(code, version):
    """Return the chapter name of an ICD diagnosis code.

    code    -- ICD code string, e.g. '4019' (ICD-9) or 'I10' (ICD-10).
    version -- 9 or 10, as stored in icd_version.
    Returns 'Unknown' for missing, malformed or unrecognised codes.
    """
    if not isinstance(code, str) or not code:
        return 'Unknown'
    first = code[0].upper()
    if version == 9:
        if first == 'V':
            return 'Supplementary Classification'
        if first == 'E':
            return 'External Causes'
        try:
            prefix = int(code[:3])
        except ValueError:
            return 'Unknown'
        for upper, name in icd9_chapters:
            if prefix <= upper:
                return name
        return 'Unknown'
    if version == 10:
        two_digits = code[1:3]
        if first == 'D' and two_digits.isdigit() and int(two_digits) >= 50:
            return 'Blood Diseases'
        if first == 'H' and two_digits.isdigit() and int(two_digits) >= 60:
            return 'Ear & Mastoid'
        return icd10_chapters.get(first, 'Unknown')
    return 'Unknown'


diag_complete['icd_chapter'] = [
    icd_chapter(code, version)
    for code, version in zip(diag_complete['icd_code'], diag_complete['icd_version'])
]

# Count by first character, separately per ICD version (labels like 'ICD-9: 4', 'ICD-10: E')
category_counts = (
    'ICD-' + diag_complete['icd_version'].astype(str) + ': ' + diag_complete['icd_category']
).value_counts()
print("Diagnosis Distribution by ICD Category:")
print(category_counts)

# Version-independent view by clinical chapter
chapter_counts = diag_complete['icd_chapter'].value_counts()
print("\nDiagnosis Distribution by ICD Chapter:")
print(chapter_counts)

In [None]:
# Bar chart of the 15 most frequent ICD categories
top_categories = category_counts.head(15)

fig, ax = plt.subplots(figsize=(12, 6))
top_categories.plot.bar(ax=ax, color='teal', edgecolor='black')
ax.set_title('Diagnosis Distribution by ICD Category (First Character)', fontsize=14, fontweight='bold')
ax.set(xlabel='ICD Category', ylabel='Count')
ax.tick_params(axis='x', rotation=45)
plt.tight_layout()
plt.show()

## 9. Demographics and Diagnoses

### 9.1 Diagnoses by Gender

In [None]:
def _top_titles_for(sex, k=10):
    """Return the k most frequent diagnosis titles among rows with the given gender."""
    rows = diag_complete.loc[diag_complete['gender'] == sex]
    return rows.groupby('long_title').size().sort_values(ascending=False).head(k)


# Most frequent diagnoses, split by patient gender
top_diag_female = _top_titles_for('F')
top_diag_male = _top_titles_for('M')

print("Top 10 Diagnoses - FEMALE:")
print(top_diag_female)
print("\nTop 10 Diagnoses - MALE:")
print(top_diag_male)

### 9.2 Diagnoses by Age Group

In [None]:
# Bucket anchor_age into groups. pd.cut bins are right-inclusive: (0, 18] is
# labelled '0-18' and (80, 100] is '80+'. Ages of exactly 0 or above 100 fall
# outside every bin and become NaN.
age_edges = [0, 18, 35, 50, 65, 80, 100]
age_labels = ['0-18', '19-35', '36-50', '51-65', '66-80', '80+']
diag_complete['age_group'] = pd.cut(diag_complete['anchor_age'], bins=age_edges, labels=age_labels)

# Diagnosis rows per age group, in age order
age_group_counts = diag_complete['age_group'].value_counts().sort_index()
print("Diagnoses by Age Group:")
print(age_group_counts)

# Visualize
fig, ax = plt.subplots(figsize=(10, 6))
age_group_counts.plot.bar(ax=ax, color='mediumpurple', edgecolor='black')
ax.set_title('Number of Diagnoses by Age Group', fontsize=14, fontweight='bold')
ax.set(xlabel='Age Group', ylabel='Number of Diagnoses')
ax.tick_params(axis='x', rotation=45)
plt.tight_layout()
plt.show()

## 10. Admission Type vs Diagnoses

In [None]:
# Diagnoses per admission, then summarised (mean/median/count) by admission type
diag_by_admission_type = (
    diag_complete
    .groupby(['hadm_id', 'admission_type'])
    .size()
    .reset_index(name='num_diagnoses')
)
avg_diag_by_type = (
    diag_by_admission_type
    .groupby('admission_type')['num_diagnoses']
    .agg(['mean', 'median', 'count'])
)
ranked_means = avg_diag_by_type['mean'].sort_values(ascending=False)

print("Average Number of Diagnoses by Admission Type:")
print(avg_diag_by_type.sort_values('mean', ascending=False))

# Visualize
fig, ax = plt.subplots(figsize=(10, 6))
ranked_means.plot.bar(ax=ax, color='coral', edgecolor='black')
ax.set_title('Average Number of Diagnoses by Admission Type', fontsize=14, fontweight='bold')
ax.set(xlabel='Admission Type', ylabel='Average Number of Diagnoses')
ax.tick_params(axis='x', rotation=45)
plt.tight_layout()
plt.show()

## 11. Key Findings & Summary

In [None]:
def _shorten(text, width=70):
    """Return *text* cut to *width* characters, adding '...' only if something was cut."""
    text = str(text)
    return text if len(text) <= width else text[:width] + '...'


def _pct(part, whole):
    """Return *part* as a percentage of *whole*, or 0.0 when *whole* is 0."""
    return 100 * part / whole if whole else 0.0


print("="*80)
print("KEY FINDINGS - DIAGNOSES EXPLORATORY ANALYSIS")
print("="*80)

n_records = len(diagnoses)
n_icd9 = (diagnoses['icd_version'] == 9).sum()
n_icd10 = (diagnoses['icd_version'] == 10).sum()
# A code is identified by (icd_code, icd_version), matching the dictionary join key
n_unique_codes = len(diagnoses[['icd_code', 'icd_version']].dropna().drop_duplicates())

print("\n DATASET OVERVIEW:")
print(f"  • Total diagnosis records: {n_records}")
print(f"  • Unique ICD codes: {n_unique_codes}")
print(f"  • Unique admissions with diagnoses: {diagnoses['hadm_id'].nunique()}")
print(f"  • ICD-9 codes: {n_icd9} ({_pct(n_icd9, n_records):.1f}%)")
print(f"  • ICD-10 codes: {n_icd10} ({_pct(n_icd10, n_records):.1f}%)")

print("\n DIAGNOSIS PATTERNS:")
print(f"  • Average diagnoses per admission: {diagnoses_per_admission['num_diagnoses'].mean():.2f}")
print(f"  • Median diagnoses per admission: {diagnoses_per_admission['num_diagnoses'].median():.0f}")
print(f"  • Max diagnoses in single admission: {diagnoses_per_admission['num_diagnoses'].max()}")

print("\n MOST COMMON DIAGNOSES:")
# Rank with enumerate(): after sort_values the DataFrame index holds each row's
# pre-sort position, not its rank in this list.
top_3_diag = diagnosis_counts.head(3)
for rank, (_, row) in enumerate(top_3_diag.iterrows(), start=1):
    print(f"  {rank}. {_shorten(row['long_title'])} ({row['count']} cases)")

print("\n HIGH MORTALITY DIAGNOSES:")
high_mortality = mortality_by_diagnosis_filtered.head(3)
for _, row in high_mortality.iterrows():
    print(f"  • {_shorten(row['long_title'])} ({row['mortality_rate']:.1f}% mortality)")

print("\n⏱️ LONGEST HOSPITAL STAYS:")
long_stays = los_by_diagnosis_filtered.head(3)
for _, row in long_stays.iterrows():
    print(f"  • {_shorten(row['long_title'])} ({row['mean_los']:.1f} days avg)")

print("\n" + "="*80)
print("="*80)