# Step 4: Visualization Dashboard

This notebook creates comprehensive visualizations of the AF cohort including:
- Cohort flow diagram
- Demographics charts
- AF episode characteristics
- Medication usage patterns
- Outcomes comparisons
- Temporal trends

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import sys
sys.path.append('..')
from config import *

# Set style
sns.set_style('whitegrid')
sns.set_palette('Set2')
plt.rcParams['figure.figsize'] = (14, 8)
plt.rcParams['font.size'] = 11

%matplotlib inline

## Load Data

In [None]:
# Read both cohort tables; the same timestamp columns are parsed as datetimes
# in each file so they can be used in datetime arithmetic later on.
timestamp_columns = ['af_start', 'af_end', 'icu_intime', 'icu_outtime']

af_cohort = pd.read_csv(
    f'../{COHORT_OUTPUT_PATH}/af_cohort_complete.csv',
    parse_dates=timestamp_columns,
)
patient_level = pd.read_csv(
    f'../{COHORT_OUTPUT_PATH}/af_cohort_patient_level.csv',
    parse_dates=timestamp_columns,
)

# One row per episode in af_cohort, one row per patient in patient_level
# (as reported by the message below).
print(f"Loaded {len(af_cohort):,} episodes from {len(patient_level):,} patients")

## 1. Cohort Flow Diagram

In [None]:
# Cohort flow funnel (Plotly): episodes -> ICU stays -> admissions -> patients

# Size of the cohort at each level of aggregation
n_episodes = len(af_cohort)
n_stays = af_cohort['stay_id'].nunique()
n_admissions = af_cohort['hadm_id'].nunique()
n_patients = af_cohort['subject_id'].nunique()

# Build the figure in one step; each bar shows its count and its share of the
# first stage ("percent initial").
fig = go.Figure(go.Funnel(
    y = ["AF Episodes Detected", "Unique ICU Stays", "Unique Admissions", "Unique Patients"],
    x = [n_episodes, n_stays, n_admissions, n_patients],
    textinfo = "value+percent initial",
    marker = {"color": ["royalblue", "lightblue", "lightgreen", "lightcoral"]}
))

fig.update_layout(
    title="Cohort Flow: From AF Episodes to Unique Patients",
    height=500
)

# Persist the interactive figure as standalone HTML, then render it inline
fig.write_html(f"../{FIGURES_OUTPUT_PATH}/cohort_flow.html")
fig.show()
print(f"Saved to {FIGURES_OUTPUT_PATH}/cohort_flow.html")

## 2. Demographics Visualization

In [None]:
# Four-panel demographics overview built from the patient-level table
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
ax_age, ax_gender = axes[0, 0], axes[0, 1]
ax_race, ax_icu = axes[1, 0], axes[1, 1]

# --- Age: histogram with the median marked by a dashed red line ---
median_age = patient_level['age'].median()
ax_age.hist(patient_level['age'].dropna(), bins=30, color='skyblue', edgecolor='black')
ax_age.axvline(median_age, color='red', linestyle='--', linewidth=2,
               label=f'Median: {median_age:.0f}')
ax_age.set(xlabel='Age (years)', ylabel='Number of Patients', title='Age Distribution')
ax_age.legend()

# --- Gender: pie chart of category shares ---
gender_counts = patient_level['gender'].value_counts()
ax_gender.pie(
    gender_counts.values,
    labels=gender_counts.index,
    autopct='%1.1f%%',
    startangle=90,
    colors=['lightcoral', 'lightskyblue'],
)
ax_gender.set_title('Gender Distribution')

# --- Race: five most frequent categories, most frequent at the top ---
race_counts = patient_level['race'].value_counts().head(5)
race_positions = range(len(race_counts))
ax_race.barh(race_positions, race_counts.values, color='mediumseagreen')
ax_race.set_yticks(race_positions)
# Labels are truncated to 30 characters to keep the axis readable
ax_race.set_yticklabels([name[:30] for name in race_counts.index])
ax_race.set(xlabel='Number of Patients', title='Top 5 Race Categories')
ax_race.invert_yaxis()

# --- ICU type: six most frequent first care units ---
icu_counts = patient_level['first_careunit'].value_counts().head(6)
icu_positions = range(len(icu_counts))
ax_icu.bar(icu_positions, icu_counts.values, color='coral')
ax_icu.set_xticks(icu_positions)
# Labels are truncated to 15 characters and slanted so they do not overlap
ax_icu.set_xticklabels([unit[:15] for unit in icu_counts.index], rotation=45, ha='right')
ax_icu.set(ylabel='Number of Patients', title='Top ICU Types')

plt.tight_layout()
plt.savefig(f'../{FIGURES_OUTPUT_PATH}/demographics.png', dpi=300, bbox_inches='tight')
plt.show()
print(f"Saved to {FIGURES_OUTPUT_PATH}/demographics.png")

## 3. AF Episode Characteristics

In [None]:
# Four-panel overview of AF episodes: duration (log scale), duration categories,
# episodes per ICU stay, and timing of the first episode after ICU admission.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# --- AF duration distribution (log scale) ---
# Only strictly positive durations are kept, since log10 is undefined otherwise.
af_hours_clean = af_cohort['af_hours'][af_cohort['af_hours'] > 0]
axes[0, 0].hist(np.log10(af_hours_clean), bins=40, color='steelblue', edgecolor='black')
axes[0, 0].set_xlabel('AF Duration (log10 hours)')
axes[0, 0].set_ylabel('Number of Episodes')
axes[0, 0].set_title('AF Episode Duration Distribution')
# The median line is drawn in log10 space to match the x-axis; the legend shows raw hours.
axes[0, 0].axvline(np.log10(af_hours_clean.median()), color='red', linestyle='--',
                   label=f'Median: {af_hours_clean.median():.1f}h')
axes[0, 0].legend()

# --- Duration categories ---
# pd.cut uses right-closed intervals by default: (0, 1], (1, 6], (6, 24], (24, 48], (48, inf).
# Durations of 0 or below fall outside every bin and are not counted.
duration_cats = pd.cut(af_cohort['af_hours'],
                       bins=[0, 1, 6, 24, 48, np.inf],
                       labels=['<1h', '1-6h', '6-24h', '24-48h', '>48h'])
duration_counts = duration_cats.value_counts().sort_index()
axes[0, 1].bar(range(len(duration_counts)), duration_counts.values, color='teal')
axes[0, 1].set_xticks(range(len(duration_counts)))
axes[0, 1].set_xticklabels(duration_counts.index)
axes[0, 1].set_ylabel('Number of Episodes')
axes[0, 1].set_title('AF Duration Categories')
axes[0, 1].set_xlabel('Duration')

# --- Episodes per stay ---
# The largest episode_number within a stay is used as that stay's episode count
# (assumes episode_number counts up from 1 within each stay_id -- TODO confirm).
MAX_EPISODES_SHOWN = 19  # widest x-axis extent; larger counts share the final bar
episodes_per_stay = af_cohort.groupby('stay_id')['episode_number'].max()
# Clip rather than truncate the bin edges: values beyond the last edge would be
# dropped from the histogram without warning, under-counting the ICU stays.
episodes_capped = episodes_per_stay.clip(upper=MAX_EPISODES_SHOWN)
# Integer edges 1, 2, ..., max+1 give every episode count its own bar.
episode_bin_edges = range(1, int(episodes_capped.max()) + 2)
axes[1, 0].hist(episodes_capped, bins=episode_bin_edges,
                color='purple', edgecolor='black', alpha=0.7)
episodes_xlabel = 'Number of AF Episodes per ICU Stay'
if (episodes_per_stay > MAX_EPISODES_SHOWN).any():
    # Make the overflow bar explicit whenever clipping actually happened
    episodes_xlabel += f' ({MAX_EPISODES_SHOWN} = {MAX_EPISODES_SHOWN} or more)'
axes[1, 0].set_xlabel(episodes_xlabel)
axes[1, 0].set_ylabel('Number of ICU Stays')
axes[1, 0].set_title('Episodes per ICU Stay')

# --- Time from ICU admit to AF ---
# Adds an 'hours_from_admit' column to af_cohort (hours between ICU admission and AF onset).
af_cohort['hours_from_admit'] = (af_cohort['af_start'] - af_cohort['icu_intime']).dt.total_seconds() / 3600
first_eps = af_cohort[af_cohort['episode_number'] == 1]
# Keep onsets at or after admission and within the first week (168 h)
hours_clean = first_eps['hours_from_admit'][(first_eps['hours_from_admit'] >= 0) &
                                              (first_eps['hours_from_admit'] <= 168)]  # Cap at 1 week
axes[1, 1].hist(hours_clean, bins=30, color='orange', edgecolor='black')
axes[1, 1].set_xlabel('Hours from ICU Admission to First AF')
axes[1, 1].set_ylabel('Number of Patients')
axes[1, 1].set_title('Timing of First AF Episode')
axes[1, 1].axvline(hours_clean.median(), color='red', linestyle='--',
                  label=f'Median: {hours_clean.median():.1f}h')
axes[1, 1].legend()

plt.tight_layout()
plt.savefig(f'../{FIGURES_OUTPUT_PATH}/af_characteristics.png', dpi=300, bbox_inches='tight')
plt.show()
print(f"Saved to {FIGURES_OUTPUT_PATH}/af_characteristics.png")

## 4. Medication Usage

In [None]:
# Two-panel medication figure: episode-level treatment combinations (pie) and
# patient-level exposure counts (grouped bars).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Medication combinations (episode-level)
# Build explicit boolean masks first. `~` is only a logical NOT on bool dtype:
# on 0/1 integer flags it is a bitwise NOT (giving -1/-2, so the "Neither" count
# would go negative), and on a column containing NaN it raises. `.eq(True)` yields
# a proper boolean Series for bool and 0/1 data alike; missing values count as
# "not received".
episode_aa = af_cohort['received_antiarrhythmic'].eq(True)
episode_rate = af_cohort['received_rate_control'].eq(True)

both = episode_aa & episode_rate
aa_only = episode_aa & ~episode_rate
rate_only = ~episode_aa & episode_rate
neither = ~episode_aa & ~episode_rate

med_combos = pd.Series({
    'Both': both.sum(),
    'AA only': aa_only.sum(),
    'Rate only': rate_only.sum(),
    'Neither': neither.sum()
})

colors_combo = ['#ff9999', '#66b3ff', '#99ff99', '#ffcc99']
axes[0].pie(med_combos.values, labels=med_combos.index, autopct='%1.1f%%',
            startangle=90, colors=colors_combo)
axes[0].set_title('Medication Combinations\n(Episode-Level)')

# Patient-level medication exposure
n_patients_total = len(patient_level)
n_patients_aa = patient_level['received_antiarrhythmic'].eq(True).sum()
n_patients_rate = patient_level['received_rate_control'].eq(True).sum()
patient_meds = pd.DataFrame({
    'Antiarrhythmic': [n_patients_aa, n_patients_total - n_patients_aa],
    'Rate Control': [n_patients_rate, n_patients_total - n_patients_rate]
}, index=['Received', 'Not Received'])

# One colour per column (medication type); bars are grouped by the index
patient_meds.plot(kind='bar', ax=axes[1], color=['#ff9999', '#66b3ff'])
axes[1].set_title('Medication Exposure (Patient-Level)')
axes[1].set_ylabel('Number of Patients')
axes[1].set_xlabel('')
axes[1].legend(title='Medication Type')
# pandas rotates bar tick labels by default; reset them to horizontal
axes[1].set_xticklabels(axes[1].get_xticklabels(), rotation=0)

plt.tight_layout()
plt.savefig(f'../{FIGURES_OUTPUT_PATH}/medication_usage.png', dpi=300, bbox_inches='tight')
plt.show()
print(f"Saved to {FIGURES_OUTPUT_PATH}/medication_usage.png")

## 5. Outcomes Comparison

In [None]:
# Four-panel outcomes figure: mortality and ICU length of stay by antiarrhythmic
# exposure, overall mortality rates, and the SOFA score distribution.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Explicit masks for the two exposure groups. Rows with a missing flag belong to
# neither group (the same rows a groupby on the flag would drop).
got_aa = patient_level['received_antiarrhythmic'].eq(True)
no_aa = patient_level['received_antiarrhythmic'].eq(False)

# Mortality by medication exposure
# Each group is computed explicitly so there are always exactly two values in a
# fixed order, even if one exposure group is empty (its rate is then NaN and no
# bar is drawn) -- a groupby would return a single value and mismatch the two
# bar positions.
mort_by_aa = pd.Series({
    'No AA': patient_level.loc[no_aa, 'hospital_expire_flag'].mean() * 100,
    'Received AA': patient_level.loc[got_aa, 'hospital_expire_flag'].mean() * 100
})
axes[0, 0].bar([0, 1], mort_by_aa.values, color=['lightcoral', 'lightblue'])
axes[0, 0].set_xticks([0, 1])
axes[0, 0].set_xticklabels(mort_by_aa.index)
axes[0, 0].set_ylabel('Hospital Mortality (%)')
axes[0, 0].set_title('Hospital Mortality by Antiarrhythmic Use')
# Leave 20% headroom above the tallest bar; skip when there is no finite,
# positive maximum because axis limits cannot be NaN.
mort_top = mort_by_aa.max()  # pandas max ignores NaN
if pd.notna(mort_top) and mort_top > 0:
    axes[0, 0].set_ylim([0, mort_top * 1.2])

# ICU LOS by medication
los_aa = patient_level.loc[got_aa, 'icu_los_days'].dropna()
los_no_aa = patient_level.loc[no_aa, 'icu_los_days'].dropna()
# Tick labels are set separately instead of via boxplot(labels=...), which is
# deprecated since Matplotlib 3.9; this form works on old and new versions.
axes[0, 1].boxplot([los_no_aa, los_aa])
axes[0, 1].set_xticklabels(['No AA', 'Received AA'])
axes[0, 1].set_ylabel('ICU Length of Stay (days)')
axes[0, 1].set_title('ICU LOS by Antiarrhythmic Use')
axes[0, 1].set_ylim([0, patient_level['icu_los_days'].quantile(0.95)])  # Cap at 95th percentile

# Overall outcomes
outcomes = pd.Series({
    'Hospital\nMortality': patient_level['hospital_expire_flag'].mean() * 100,
    '30-day\nMortality': patient_level['mortality_30day'].mean() * 100
})
axes[1, 0].bar(range(len(outcomes)), outcomes.values, color=['#ff6b6b', '#ee5a6f'])
axes[1, 0].set_xticks(range(len(outcomes)))
axes[1, 0].set_xticklabels(outcomes.index)
axes[1, 0].set_ylabel('Mortality (%)')
axes[1, 0].set_title('Overall Mortality Rates')
outcomes_top = outcomes.max()
if pd.notna(outcomes_top) and outcomes_top > 0:
    axes[1, 0].set_ylim([0, outcomes_top * 1.2])

# SOFA score distribution (patients with a recorded score only)
sofa_available = patient_level['sofa_24hours'].dropna()
axes[1, 1].hist(sofa_available, bins=20, color='mediumpurple', edgecolor='black')
axes[1, 1].axvline(sofa_available.median(), color='red', linestyle='--',
                  label=f'Median: {sofa_available.median():.0f}')
axes[1, 1].set_xlabel('SOFA Score (first 24h)')
axes[1, 1].set_ylabel('Number of Patients')
axes[1, 1].set_title('SOFA Score Distribution')
axes[1, 1].legend()

plt.tight_layout()
plt.savefig(f'../{FIGURES_OUTPUT_PATH}/outcomes.png', dpi=300, bbox_inches='tight')
plt.show()
print(f"Saved to {FIGURES_OUTPUT_PATH}/outcomes.png")

## 6. Interactive Plotly Dashboard

In [None]:
# Create interactive scatter plot: AF duration vs LOS, colored by mortality
#
# Plotly Express only treats `color` as categorical (and only then applies
# color_discrete_map) when the column is non-numeric. A 0/1 integer flag would
# get a continuous colour scale and the True/False map below would be ignored,
# so the flag is coerced to bool on a copy; patient_level itself is untouched.
# NOTE(review): a missing flag becomes False (shown as a survivor) -- confirm
# the column has no missing values.
scatter_data = patient_level.assign(
    hospital_expire_flag=patient_level['hospital_expire_flag'].eq(True)
)

fig = px.scatter(scatter_data,
                 x='af_hours',
                 y='icu_los_days',
                 color='hospital_expire_flag',
                 hover_data=['age', 'gender', 'sofa_24hours'],
                 labels={'af_hours': 'AF Duration (hours)',
                        'icu_los_days': 'ICU Length of Stay (days)',
                        'hospital_expire_flag': 'Hospital Death'},
                 title='AF Duration vs ICU Length of Stay',
                 color_discrete_map={True: 'red', False: 'blue'})

fig.update_layout(height=600)
# Persist the interactive figure as standalone HTML, then render it inline
fig.write_html(f"../{FIGURES_OUTPUT_PATH}/interactive_scatter.html")
fig.show()
print(f"Saved to {FIGURES_OUTPUT_PATH}/interactive_scatter.html")

## 7. Summary Dashboard

In [None]:
# Create a comprehensive summary figure
# Layout: 3x3 grid -- row 0 is a full-width key-metrics text box, row 1 holds
# three small charts, row 2 is a full-width mortality comparison.
fig = plt.figure(figsize=(16, 10))
gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

# Title
fig.suptitle('MIMIC-IV Atrial Fibrillation Cohort Summary Dashboard', fontsize=16, fontweight='bold')

# Key metrics text box (axes hidden; only the text is drawn)
ax_text = fig.add_subplot(gs[0, :])
ax_text.axis('off')
metrics_text = f"""
Total Patients: {len(patient_level):,}
Total AF Episodes: {len(af_cohort):,}
Median Age: {patient_level['age'].median():.0f} years
Female: {(patient_level['gender']=='F').sum()}/{len(patient_level)} ({(patient_level['gender']=='F').mean()*100:.1f}%)

Median AF Duration: {af_cohort['af_hours'].median():.1f} hours
Antiarrhythmic Use: {patient_level['received_antiarrhythmic'].sum():,} ({patient_level['received_antiarrhythmic'].mean()*100:.1f}%)
Rate Control Use: {patient_level['received_rate_control'].sum():,} ({patient_level['received_rate_control'].mean()*100:.1f}%)

Hospital Mortality: {patient_level['hospital_expire_flag'].sum():,} ({patient_level['hospital_expire_flag'].mean()*100:.1f}%)
30-day Mortality: {patient_level['mortality_30day'].sum():,} ({patient_level['mortality_30day'].mean()*100:.1f}%)
Median ICU LOS: {patient_level['icu_los_days'].median():.1f} days
"""
ax_text.text(0.1, 0.5, metrics_text, fontsize=12, family='monospace',
            verticalalignment='center', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

# Age histogram
ax1 = fig.add_subplot(gs[1, 0])
ax1.hist(patient_level['age'].dropna(), bins=20, color='skyblue', edgecolor='black')
ax1.set_xlabel('Age')
ax1.set_title('Age Distribution')

# AF duration (episode counts per duration category, in bin order)
ax2 = fig.add_subplot(gs[1, 1])
duration_cats = pd.cut(af_cohort['af_hours'], bins=[0, 1, 6, 24, 48, np.inf],
                       labels=['<1h', '1-6h', '6-24h', '24-48h', '>48h']).value_counts().sort_index()
ax2.bar(range(len(duration_cats)), duration_cats.values, color='teal')
ax2.set_xticks(range(len(duration_cats)))
ax2.set_xticklabels(duration_cats.index, rotation=45)
ax2.set_title('AF Duration')

# Medications
# Plotted from a Series so the colour list applies per bar. With a one-column
# DataFrame pandas assigns colours per column, so both bars would get the first colour.
ax3 = fig.add_subplot(gs[1, 2])
med_data = pd.Series({
    'AA': patient_level['received_antiarrhythmic'].sum(),
    'Rate': patient_level['received_rate_control'].sum()
})
med_data.plot(kind='barh', ax=ax3, color=['#ff9999', '#66b3ff'])
ax3.set_xlabel('Number of Patients')
ax3.set_title('Medication Exposure')

# Outcomes by medication (Series for the same per-bar colour reason as above)
ax4 = fig.add_subplot(gs[2, :])
mort_data = pd.Series({
    'No AA': patient_level[patient_level['received_antiarrhythmic']==False]['hospital_expire_flag'].mean()*100,
    'Received AA': patient_level[patient_level['received_antiarrhythmic']==True]['hospital_expire_flag'].mean()*100
})
# rot=0 keeps the group labels horizontal (pandas rotates bar labels by default)
mort_data.plot(kind='bar', ax=ax4, color=['lightcoral', 'lightblue'], rot=0)
ax4.set_ylabel('Hospital Mortality (%)')
ax4.set_title('Mortality by Antiarrhythmic Use')

plt.savefig(f'../{FIGURES_OUTPUT_PATH}/summary_dashboard.png', dpi=300, bbox_inches='tight')
plt.show()
print(f"Saved to {FIGURES_OUTPUT_PATH}/summary_dashboard.png")

## Summary

This notebook has created comprehensive visualizations of the AF cohort. All figures are saved to the `outputs/figures` directory: the matplotlib figures as static PNG files and the Plotly figures as interactive HTML files.

### Files Created:
- cohort_flow.html - Interactive cohort flow diagram
- demographics.png - Demographics overview
- af_characteristics.png - AF episode characteristics
- medication_usage.png - Medication usage patterns
- outcomes.png - Outcomes analysis
- interactive_scatter.html - Interactive scatter plot
- summary_dashboard.png - Comprehensive summary dashboard