# In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import matplotlib.gridspec as gridspec
from statsmodels.distributions.empirical_distribution import ECDF
from matplotlib.colors import LinearSegmentedColormap

# Apply one shared visual theme (grid style, viridis colour cycle, default
# figure size and font size) to every figure this module produces.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("viridis")
plt.rcParams.update({'figure.figsize': [12, 8], 'font.size': 12})

def convert_to_date(date_str):
    """Convert a ``MM/DD/YYYY`` date string to a pandas Timestamp.

    Parameters:
    date_str: Value to parse, normally a string such as ``"01/31/2020"``.

    Returns:
    pd.Timestamp: The parsed date, or ``pd.NaT`` when the value is missing
    or does not match the ``%m/%d/%Y`` format.
    """
    try:
        parsed = pd.to_datetime(date_str, format='%m/%d/%Y')
    except (ValueError, TypeError, OverflowError):
        # Only parse failures are mapped to a missing date; a bare `except`
        # would also swallow KeyboardInterrupt and genuine programming errors.
        return pd.NaT
    # pd.to_datetime(None) returns None instead of NaT; normalise so callers
    # always get a datetime-like missing value.
    return pd.NaT if parsed is None else parsed

def analyze_medication_adherence(medication_code, data, random_state=None):
    """
    Analyze medication adherence patterns with enhanced visualizations

    Refill intervals (days between consecutive dispensings of the same
    medication for the same patient) are computed. One interval per patient
    is sampled, and that sample drives four outputs:
    - an ECDF plot,
    - a log-interval density plot,
    - a silhouette analysis that picks the number of clusters,
    - a KMeans clustering whose per-cluster median interval is attached
      to every refill row of the patient.

    Parameters:
    medication_code (str): The ATC code to analyze
    data (DataFrame): Preprocessed medication data with at least the
        columns 'pnr', 'eksd' (datetime) and 'ATC'
    random_state (int | None): Seed for the one-interval-per-patient sample.
        None (the default) keeps sampling unseeded.

    Returns:
    tuple: (final_data, (ecdf_fig, density_fig, silhouette_fig, boxplot_fig))
        final_data holds every refill row of the medication plus the
        columns 'prev_eksd', 'event_interval', 'cluster' and
        'median_interval'.

    Raises:
    ValueError: If no patient has at least two dispensings of
        ``medication_code``, so no refill interval exists.
    """
    # Filter data for the specific medication
    med_data = data[data['ATC'] == medication_code].copy()

    # Sort chronologically within each patient and look up the previous dispensing
    med_data = med_data.sort_values(['pnr', 'eksd'])
    med_data['prev_eksd'] = med_data.groupby('pnr')['eksd'].shift(1)

    # A patient's first dispensing has no predecessor, so it yields no interval
    med_data = med_data.dropna(subset=['prev_eksd'])
    if med_data.empty:
        raise ValueError(
            f"Cannot analyze {medication_code!r}: no patient has two or more "
            "prescriptions, so there are no refill intervals."
        )

    # Calculate interval between prescriptions (days)
    med_data['event_interval'] = (med_data['eksd'] - med_data['prev_eksd']).dt.days

    # Sample one interval per patient so frequent refillers do not dominate
    sample_data = (
        med_data.groupby('pnr')
        .sample(n=1, random_state=random_state)
        .loc[:, ['pnr', 'eksd', 'prev_eksd', 'event_interval']]
        .reset_index(drop=True)
    )

    # ECDF table. The patient id stays attached to each interval because the
    # cluster assignments computed from this table are merged back onto
    # med_data by 'pnr' further down.
    ecdf_df = (
        sample_data[['pnr', 'event_interval']]
        .rename(columns={'event_interval': 'interval'})
        .sort_values('interval', kind='mergesort')
        .reset_index(drop=True)
    )
    ecdf = ECDF(ecdf_df['interval'].to_numpy())
    ecdf_df['cumulative_prob'] = ecdf(ecdf_df['interval'].to_numpy())

    # Visualisations
    ecdf_fig = create_enhanced_ecdf_plot(ecdf_df, medication_code)
    density_fig = create_log_interval_density_plot(sample_data, medication_code)
    silhouette_fig, optimal_clusters = create_silhouette_analysis(sample_data, medication_code)
    boxplot_fig = create_prescription_boxplot(med_data, medication_code)

    # Perform clustering on intervals
    clustered_data = perform_kmeans_clustering(ecdf_df, optimal_clusters)

    # Attach each patient's cluster and cluster-median interval to all their refill rows
    final_data = med_data.merge(
        clustered_data[['pnr', 'cluster', 'median_interval']],
        on='pnr',
        how='left'
    )

    # Every patient in med_data was sampled, so gaps are not expected; this
    # fallback to the most common cluster is kept defensively.
    most_common_cluster = clustered_data['cluster'].value_counts().index[0]
    median_interval = clustered_data.loc[
        clustered_data['cluster'] == most_common_cluster, 'median_interval'
    ].iloc[0]

    final_data['cluster'] = final_data['cluster'].fillna(most_common_cluster)
    final_data['median_interval'] = final_data['median_interval'].fillna(median_interval)

    return final_data, (ecdf_fig, density_fig, silhouette_fig, boxplot_fig)

def create_enhanced_ecdf_plot(ecdf_df, medication_code):
    """Draw the ECDF of prescription intervals as two side-by-side panels.

    The left panel shows the complete ECDF. The right panel keeps only the
    points whose cumulative probability is at most 0.8 and marks that
    cut-off with a dashed red line. Points are coloured by cumulative
    probability, with a colour bar next to each panel.

    Parameters:
    ecdf_df (DataFrame): Must contain 'interval' and 'cumulative_prob' columns.
    medication_code (str): Used in the panel titles.

    Returns:
    matplotlib.figure.Figure: The two-panel figure.
    """
    fig = plt.figure(figsize=(12, 6))

    # Viridis sampled into a continuous gradient for the point colours
    gradient = LinearSegmentedColormap.from_list(
        'custom_cmap', sns.color_palette("viridis", 256)
    )

    # (subplot position, rows to draw, title, draw the 0.8 reference line?)
    panel_specs = [
        (1, ecdf_df, f"Complete ECDF: {medication_code}", False),
        (2, ecdf_df[ecdf_df['cumulative_prob'] <= 0.8].copy(),
         f"80% ECDF: {medication_code}", True),
    ]

    for position, frame, title, mark_cutoff in panel_specs:
        ax = fig.add_subplot(1, 2, position)
        xs = frame['interval']
        ys = frame['cumulative_prob']
        dots = ax.scatter(xs, ys, c=ys, cmap=gradient, alpha=0.7, s=30)
        ax.plot(xs, ys, color='#333333', alpha=0.5)
        if mark_cutoff:
            ax.axhline(y=0.8, color='red', linestyle='--', alpha=0.7)
        ax.set_title(title, fontweight='bold')
        ax.set_xlabel("Prescription Interval (days)")
        ax.set_ylabel("Cumulative Probability")
        fig.colorbar(dots, ax=ax, label='Probability')

    fig.tight_layout()
    return fig

def _draw_log_density(ax, log_intervals):
    """Draw KDE, rug, quartile markers and a summary box for non-empty log intervals."""
    # Create density plot
    sns.kdeplot(
        log_intervals,
        fill=True,
        color="darkblue",
        alpha=0.7,
        ax=ax
    )

    # Add a rugplot to show the actual data points
    sns.rugplot(
        log_intervals,
        color="darkred",
        alpha=0.5,
        ax=ax
    )

    # Vertical dashed lines for Q1 / median / Q3, labelled near the top of the axes
    quartiles = np.percentile(log_intervals, [25, 50, 75])
    colors = ['#ff7f0e', '#2ca02c', '#ff7f0e']
    labels = ['Q1', 'Median', 'Q3']
    label_y = ax.get_ylim()[1] * 0.95

    for value, color, label in zip(quartiles, colors, labels):
        ax.axvline(value, color=color, linestyle='--', alpha=0.8, linewidth=1.5)
        ax.text(value, label_y, f' {label}',
                color=color, fontweight='bold', ha='left', va='top')

    # Annotate with statistics (computed on the log scale)
    stats_text = (
        f"Mean: {log_intervals.mean():.2f}\n"
        f"Median: {log_intervals.median():.2f}\n"
        f"Std Dev: {log_intervals.std():.2f}"
    )

    ax.text(
        0.95, 0.95, stats_text,
        transform=ax.transAxes,
        fontsize=10,
        verticalalignment='top',
        horizontalalignment='right',
        bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
    )


def create_log_interval_density_plot(sample_data, medication_code):
    """Create a density plot of log-transformed prescription intervals.

    Intervals <= 0 are excluded because their logarithm is undefined. When
    no positive interval remains, an explanatory placeholder is drawn
    instead of raising (np.percentile fails on an empty array).

    Parameters:
    sample_data (DataFrame): Must contain an 'event_interval' column (days).
    medication_code (str): Used in the plot title.

    Returns:
    matplotlib.figure.Figure: The density figure.
    """
    valid_intervals = sample_data.loc[sample_data['event_interval'] > 0, 'event_interval']

    fig, ax = plt.subplots(figsize=(10, 6))

    if valid_intervals.empty:
        ax.text(
            0.5, 0.5, "No positive prescription intervals to plot",
            transform=ax.transAxes, ha='center', va='center'
        )
    else:
        _draw_log_density(ax, np.log(valid_intervals))

    ax.set_title(f"Log(Prescription Interval) Density: {medication_code}", fontweight='bold')
    ax.set_xlabel("Log(Prescription Interval)")
    ax.set_ylabel("Density")
    ax.grid(True, alpha=0.3)

    return fig

def create_silhouette_analysis(sample_data, medication_code, max_clusters=9, random_state=42):
    """Choose the number of interval clusters by silhouette score and plot it.

    KMeans is fitted to log(event_interval) for each candidate k from 2 up
    to ``max_clusters``, and the k with the highest mean silhouette score is
    chosen. The candidate range is capped by the data (see below).
    Intervals <= 0 are excluded because their logarithm is undefined.

    Parameters:
    sample_data (DataFrame): Must contain an 'event_interval' column (days).
    medication_code (str): Used in the plot titles.
    max_clusters (int): Largest k to try (default 9, i.e. k = 2..9).
    random_state (int | None): Seed for KMeans and for the vertical jitter
        in the cluster scatter. A local generator is used, so the global
        NumPy random state is left untouched.

    Returns:
    tuple: (matplotlib Figure, optimal number of clusters). The optimal
    number is 1 when there are too few distinct positive intervals for any
    silhouette score to be computed.
    """
    # Prepare log intervals data as a single-feature column
    valid_data = sample_data[sample_data['event_interval'] > 0]
    log_intervals = np.log(valid_data['event_interval'].to_numpy(dtype=float)).reshape(-1, 1)
    n_samples = log_intervals.shape[0]

    # silhouette_score needs 2 <= n_labels <= n_samples - 1, and KMeans cannot
    # form more clusters than there are distinct values, so cap k accordingly.
    n_distinct = np.unique(log_intervals).size
    range_n_clusters = range(2, min(max_clusters, n_distinct, n_samples - 1) + 1)

    silhouette_scores = []
    for n_clusters in range_n_clusters:
        kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=10)
        cluster_labels = kmeans.fit_predict(log_intervals)
        silhouette_scores.append(silhouette_score(log_intervals, cluster_labels))

    if silhouette_scores:
        best_idx = int(np.argmax(silhouette_scores))
        optimal_clusters = range_n_clusters[best_idx]
    else:
        best_idx = None
        optimal_clusters = 1

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Silhouette scores plot
    if best_idx is not None:
        ax1.plot(list(range_n_clusters), silhouette_scores, 'o-', linewidth=2, markersize=8)
        ax1.axvline(x=optimal_clusters, color='red', linestyle='--', alpha=0.7)
        ax1.text(
            optimal_clusters,
            silhouette_scores[best_idx],
            f' Optimal: {optimal_clusters}',
            color='red',
            fontweight='bold'
        )
    else:
        ax1.text(
            0.5, 0.5, 'Too few distinct positive intervals\nfor silhouette analysis',
            transform=ax1.transAxes, ha='center', va='center'
        )
    ax1.set_title(f'Silhouette Analysis: {medication_code}', fontweight='bold')
    ax1.set_xlabel('Number of Clusters')
    ax1.set_ylabel('Silhouette Score')
    ax1.grid(True, alpha=0.3)

    # Plot the chosen clustering: x = log interval, y = cluster index + jitter
    if n_samples > 0:
        kmeans = KMeans(n_clusters=optimal_clusters, random_state=random_state, n_init=10)
        cluster_labels = kmeans.fit_predict(log_intervals)
        colors = sns.color_palette("viridis", optimal_clusters)
        jitter_rng = np.random.default_rng(random_state)

        for i in range(optimal_clusters):
            cluster_points = log_intervals[cluster_labels == i]
            ax2.scatter(
                cluster_points,
                jitter_rng.normal(i, 0.1, size=cluster_points.shape[0]),
                c=[colors[i]],
                alpha=0.7,
                s=30,
                label=f'Cluster {i+1}'
            )
        ax2.legend(loc='upper right')

    ax2.set_title(f'Optimal Clustering (k={optimal_clusters}): {medication_code}', fontweight='bold')
    ax2.set_xlabel('Log(Prescription Interval)')
    ax2.set_yticks(range(optimal_clusters))
    ax2.set_yticklabels([f'Cluster {i+1}' for i in range(optimal_clusters)])

    fig.tight_layout()
    return fig, optimal_clusters

def perform_kmeans_clustering(ecdf_df, n_clusters, random_state=42):
    """Cluster prescription intervals with KMeans and attach each cluster's median.

    Parameters:
    ecdf_df (DataFrame): Must contain an 'interval' column. All other
        columns (e.g. 'pnr', 'cumulative_prob') are carried through.
    n_clusters (int): Number of KMeans clusters. KMeans rejects more
        clusters than there are rows.
    random_state (int | None): KMeans seed (default 42).

    Returns:
    DataFrame: A copy of ``ecdf_df`` with two extra columns:
    'cluster' (KMeans label) and 'median_interval' (the median 'interval'
    of that row's cluster). The input frame is not modified.
    """
    # Work on a copy so the caller's frame does not silently gain a 'cluster' column
    clustered_data = ecdf_df.copy()
    intervals = clustered_data['interval'].to_numpy().reshape(-1, 1)

    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=10)
    clustered_data['cluster'] = kmeans.fit_predict(intervals)

    # Broadcast each cluster's median interval back onto its rows
    clustered_data['median_interval'] = (
        clustered_data.groupby('cluster')['interval'].transform('median')
    )

    return clustered_data

def create_prescription_boxplot(med_data, medication_code, min_patients=50):
    """Box plot of refill intervals grouped by per-patient sequence number.

    Rows are numbered 1, 2, 3, ... within each patient in 'eksd' order.
    When ``med_data`` comes from analyze_medication_adherence, each
    patient's first dispensing has already been dropped, so number 1 is the
    first refill interval. Only sequence numbers with at least
    ``min_patients`` rows are plotted, so each box rests on reasonable
    statistics.

    Parameters:
    med_data (DataFrame): Must contain 'pnr', 'eksd' and 'event_interval'.
    medication_code (str): Used in the plot title.
    min_patients (int): Minimum rows per sequence number (default 50).

    Returns:
    matplotlib.figure.Figure: The box plot figure. A placeholder message is
    drawn instead when no sequence number reaches ``min_patients``.
    """
    # Number each patient's rows chronologically (sort_values returns a copy)
    med_data = med_data.sort_values(['pnr', 'eksd'])
    med_data['prescription_number'] = med_data.groupby('pnr').cumcount() + 1

    prescription_counts = med_data['prescription_number'].value_counts()
    # value_counts orders by frequency, and ties come in arbitrary order. Sort
    # numerically so the box positions, the explicit `order`, and the n= labels
    # below all refer to the same sequence numbers.
    valid_prescriptions = sorted(prescription_counts[prescription_counts >= min_patients].index)
    order = [str(number) for number in valid_prescriptions]

    plot_data = med_data[med_data['prescription_number'].isin(valid_prescriptions)].copy()
    plot_data['prescription_number'] = plot_data['prescription_number'].astype(str)

    # Reference level: median over patients of each patient's median interval
    overall_median = med_data.groupby('pnr')['event_interval'].median().median()

    fig, ax = plt.subplots(figsize=(12, 7))

    if order:
        palette = sns.color_palette("viridis", len(order))
        sns.boxplot(
            x='prescription_number',
            y='event_interval',
            data=plot_data,
            order=order,
            palette=palette,
            showfliers=False,  # Hide outliers for cleaner visualization
            ax=ax
        )

        # Jittered points show the underlying distribution
        sns.stripplot(
            x='prescription_number',
            y='event_interval',
            data=plot_data,
            order=order,
            color='black',
            size=3,
            alpha=0.2,
            jitter=True,
            ax=ax
        )

        # Reference line for overall median
        ax.axhline(
            y=overall_median,
            color='red',
            linestyle='--',
            alpha=0.7,
            linewidth=2
        )
        ax.text(
            len(order) - 1,
            overall_median * 1.05,
            f'Overall Median: {overall_median:.1f} days',
            color='red',
            fontweight='bold',
            ha='right'
        )

        # Counts above each box; plot_data holds exactly the rows counted here
        label_y = ax.get_ylim()[1] * 0.95
        for position, number in enumerate(valid_prescriptions):
            ax.text(
                position,
                label_y,
                f'n={prescription_counts.loc[number]}',
                ha='center',
                fontweight='bold',
                fontsize=9
            )
    else:
        ax.text(
            0.5, 0.5,
            f'No prescription number has at least {min_patients} patients',
            transform=ax.transAxes, ha='center', va='center'
        )

    # Enhance labels and title
    ax.set_title(f'Prescription Intervals by Prescription Number: {medication_code}', fontweight='bold')
    ax.set_xlabel('Prescription Number', fontweight='bold')
    ax.set_ylabel('Interval (days)', fontweight='bold')
    ax.grid(True, axis='y', alpha=0.3)

    return fig

def analyze_medication_data(medication_data):
    """Run the adherence analysis for every ATC code and save its figures.

    The raw events are preprocessed once. Each distinct ATC code is then
    analysed, its four figures are written as 300-dpi PNGs named after the
    code, and the figures are closed.

    Parameters:
    medication_data (DataFrame): Raw medication events (see preprocess_data).

    Returns:
    dict: Maps each ATC code to {'data': DataFrame, 'figures': tuple of the
    four (already closed) figures}.
    """
    prepared = preprocess_data(medication_data)

    results = {}
    for atc_code in prepared['ATC'].unique():
        adherence_df, figure_set = analyze_medication_adherence(atc_code, prepared)
        results[atc_code] = {'data': adherence_df, 'figures': figure_set}

        ecdf_fig, density_fig, silhouette_fig, boxplot_fig = figure_set
        outputs = (
            (ecdf_fig, f'{atc_code}_ecdf_analysis.png'),
            (density_fig, f'{atc_code}_log_interval_density.png'),
            (silhouette_fig, f'{atc_code}_silhouette_analysis.png'),
            (boxplot_fig, f'{atc_code}_prescription_boxplot.png'),
        )

        # Save everything first, then release the figures
        for figure, filename in outputs:
            figure.savefig(filename, dpi=300, bbox_inches='tight')
        for figure, _ in outputs:
            plt.close(figure)

    return results

def preprocess_data(data):
    """Standardise column names and parse dispensing dates.

    Parameters:
    data (DataFrame): Raw medication events with exactly five columns, in
        this order: patient id, dispensing date (``MM/DD/YYYY`` strings),
        daily dose, ATC code, original duration.

    Returns:
    DataFrame: A copy of ``data`` with the columns renamed to
    ``["pnr", "eksd", "perday", "ATC", "dur_original"]`` and 'eksd' parsed
    to datetimes. Unparseable dates become NaT. The input is not modified.

    Raises:
    ValueError: If ``data`` does not have exactly five columns.
    """
    expected_columns = ["pnr", "eksd", "perday", "ATC", "dur_original"]
    if data.shape[1] != len(expected_columns):
        raise ValueError(
            f"Expected {len(expected_columns)} columns {expected_columns}, "
            f"got {data.shape[1]}: {list(data.columns)}"
        )

    processed = data.copy()
    processed.columns = expected_columns
    # One vectorised parse instead of a per-row apply. errors='coerce' maps
    # unparseable values to NaT, matching convert_to_date's fallback.
    processed['eksd'] = pd.to_datetime(processed['eksd'], format='%m/%d/%Y', errors='coerce')
    return processed

def create_combined_dashboard(medication_code, data, panel_dpi=150):
    """Assemble the four adherence figures for one medication into a 2x2 dashboard.

    The analysis is re-run on ``data`` (raw, un-preprocessed events). Because
    the per-patient sample is drawn afresh, panels may differ from
    separately saved figures. Each source figure is rasterised to PNG and
    shown as an image panel. This preserves everything drawn in it (titles,
    colour bars, KDE fills, box plots, annotations), which copying
    individual lines and scatter points did not. The source figures are
    closed afterwards so they do not accumulate in pyplot.

    Parameters:
    medication_code (str): The ATC code to analyze.
    data (DataFrame): Raw medication events (see preprocess_data).
    panel_dpi (int): Resolution used to rasterise each source figure.

    Returns:
    matplotlib.figure.Figure: The dashboard figure.
    """
    import io

    _, source_figures = analyze_medication_adherence(
        medication_code,
        preprocess_data(data)
    )

    dashboard = plt.figure(figsize=(20, 16))
    gs = gridspec.GridSpec(2, 2, figure=dashboard, height_ratios=[1, 1], width_ratios=[1, 1])

    dashboard.suptitle(f'Medication Adherence Analysis Dashboard: {medication_code}',
                       fontsize=20, fontweight='bold', y=0.98)

    try:
        for i, source_fig in enumerate(source_figures):
            row, col = divmod(i, 2)
            ax = dashboard.add_subplot(gs[row, col])

            # Render the complete source figure and display it as an image
            buffer = io.BytesIO()
            source_fig.savefig(buffer, format='png', dpi=panel_dpi, bbox_inches='tight')
            buffer.seek(0)
            ax.imshow(plt.imread(buffer, format='png'))
            ax.axis('off')
    finally:
        for source_fig in source_figures:
            plt.close(source_fig)

    dashboard.tight_layout()
    dashboard.subplots_adjust(top=0.93)

    return dashboard

# Example usage
if __name__ == "__main__":
    # Create sample data
    np.random.seed(42)
    max_events = 2000

    # Generate patient IDs with realistic repetition patterns: each of the
    # 200 patients gets 1-9 dispensings.
    events_per_patient = np.random.randint(1, 10, 200)
    patient_ids = np.repeat(np.arange(1000, 1200), events_per_patient)[:max_events]
    # The repeat yields at most 1800 ids (about 1000 on average), so every
    # other column is sized from the ids actually generated. Sizing them with
    # the fixed constant made the DataFrame constructor fail on unequal
    # column lengths.
    sample_size = len(patient_ids)

    # Generate dispensing dates within 1000 days of the base date
    base_date = pd.Timestamp('2020-01-01')
    day_offsets = pd.to_timedelta(np.random.randint(0, 1000, sample_size), unit='D')

    # Create dataframe
    med_events = pd.DataFrame({
        'pnr': patient_ids,
        'eksd': base_date + day_offsets,
        'perday': np.random.uniform(1, 3, sample_size),
        'ATC': np.random.choice(['medA', 'medB', 'medC'], sample_size, p=[0.4, 0.4, 0.2]),
        'dur_original': np.random.randint(30, 180, sample_size),
    })

    # Sort while 'eksd' is still a real date (sorting MM/DD/YYYY strings
    # orders them by month, not chronologically). Then format the dates the
    # way preprocess_data expects.
    med_events = med_events.sort_values(['pnr', 'eksd']).reset_index(drop=True)
    med_events['eksd'] = med_events['eksd'].dt.strftime('%m/%d/%Y')

    # Run the analysis
    results = analyze_medication_data(med_events)

    # Create combined dashboards
    for code in results:
        dashboard = create_combined_dashboard(code, med_events)
        dashboard.savefig(f'{code}_dashboard.png', dpi=300, bbox_inches='tight')
        plt.close(dashboard)

    print("Analysis complete. Visualization files saved.")