In [None]:
%load_ext autoreload
%autoreload 2

import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from pycaret.classification import load_model, predict_model
from scipy.stats import entropy

In [None]:
def _scores_to_float(series):
    """Return *series* as floats, accepting numbers or strings with a decimal comma (e.g. '0,75')."""
    if pd.api.types.is_numeric_dtype(series):
        # Already numeric (e.g. converted by an earlier call): the .str accessor would fail here.
        return series.astype(float)
    # 'string' dtype keeps missing values as <NA> (-> NaN) instead of the literal text 'nan'.
    return series.astype('string').str.replace(',', '.', regex=False).astype(float)


def _plot_site_medians(site_medians, score_cols, output_file):
    """Draw stacked per-site median scores plus a median-entropy line and save the figure."""
    # 'seaborn-darkgrid' was renamed 'seaborn-v0_8-darkgrid' in matplotlib 3.6 and removed in 3.8;
    # use whichever exists. An empty list leaves the current style unchanged.
    preferred = ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid')
    style = [name for name in preferred if name in plt.style.available][:1]

    colors = ['#4B0082', '#228B22', '#B8860B']  # Colors for CT, PCM, PDLC
    x = np.arange(len(site_medians.index))  # one x position per site

    # Apply the style to this figure only instead of mutating global rcParams.
    with plt.style.context(style):
        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            # Stacked bars: each class is drawn on top of the previous ones.
            bottom = np.zeros(len(site_medians))
            for color, col in zip(colors, score_cols):
                heights = site_medians[col].to_numpy(dtype=float)
                ax.bar(x, heights, bottom=bottom,
                       label=col.replace('prediction_score_', ''),
                       color=color, alpha=0.7)
                bottom = bottom + heights

            # Median entropy on a secondary y-axis.
            ax2 = ax.twinx()
            ax2.plot(x, site_medians['entropy'].to_numpy(dtype=float),
                     color='red', linewidth=2, label='Entropy',
                     marker='o')

            ax.set_xlabel('Sites', fontsize=12, fontweight='bold')
            ax.set_ylabel('Probability Distribution (Median)', fontsize=12, fontweight='bold')
            ax2.set_ylabel('Entropy (bits)', fontsize=12, fontweight='bold', color='red')

            ax.set_xticks(x)
            ax.set_xticklabels(site_medians.index, rotation=45, ha='right')

            # Single legend combining bar and line handles from both axes.
            lines1, labels1 = ax.get_legend_handles_labels()
            lines2, labels2 = ax2.get_legend_handles_labels()
            ax.legend(lines1 + lines2, labels1 + labels2,
                      loc='upper right', bbox_to_anchor=(1.15, 1))

            # Leave room for the rotated site labels.
            fig.subplots_adjust(bottom=0.2)
            fig.savefig(output_file, bbox_inches='tight')
        finally:
            # Close this specific figure even if drawing or saving fails.
            plt.close(fig)


def plot_site_entropy_distribution_median(df, output_file='site_entropy_distribution_median.png'):
    """
    Plot per-site median class probabilities together with median prediction entropy.

    For every sample, the Shannon entropy (base 2, in bits) of its three class
    scores is computed. Samples are then grouped by ``Site``. A stacked bar chart
    shows each site's median CT/PCM/PDLC score, and a line on a secondary axis
    shows the site's median entropy. The figure is written to ``output_file``,
    and a per-site summary table is printed and returned.

    The stacked bars are medians taken independently per class, so a bar's total
    height need not equal 1.

    Args:
        df: DataFrame with a ``Site`` column and the score columns
            ``prediction_score_CT``, ``prediction_score_PCM`` and
            ``prediction_score_PDLC``. Scores may be numeric or strings that
            use a decimal comma (e.g. ``'0,75'``). The caller's DataFrame is
            not modified.
        output_file: Path of the image to write. matplotlib infers the format
            from the file extension.

    Returns:
        DataFrame indexed by site with columns ``median_CT``, ``median_PCM``,
        ``median_PDLC``, ``median_entropy``, ``n_samples``, ``mean_entropy``
        and ``std_entropy``, rounded to 3 decimals.

    Raises:
        KeyError: if ``Site`` or any score column is missing.
        ValueError: if ``df`` has no rows or a score cannot be parsed as a float.
    """
    score_cols = ['prediction_score_CT', 'prediction_score_PCM', 'prediction_score_PDLC']

    missing = [col for col in ['Site', *score_cols] if col not in df.columns]
    if missing:
        raise KeyError(f"Missing required columns: {missing}")
    if df.empty:
        raise ValueError("df has no rows; nothing to plot")

    # Work on a copy so repeated calls (and the caller's frame) are unaffected.
    df = df.copy()
    for col in score_cols:
        df[col] = _scores_to_float(df[col])

    # Per-sample entropy in bits, vectorized over rows. scipy normalizes each
    # row to sum to 1 before computing it.
    df['entropy'] = entropy(df[score_cols].to_numpy(dtype=float), base=2, axis=1)

    # Median scores and median entropy per site.
    site_medians = df.groupby('Site')[score_cols + ['entropy']].median()

    _plot_site_medians(site_medians, score_cols, output_file)

    # Print full per-site statistics.
    print("\nEstadísticas por sitio:")
    entropy_by_site = df.groupby('Site')['entropy']
    stats = pd.DataFrame({
        'median_CT': site_medians['prediction_score_CT'],
        'median_PCM': site_medians['prediction_score_PCM'],
        'median_PDLC': site_medians['prediction_score_PDLC'],
        'median_entropy': site_medians['entropy'],
        'n_samples': df.groupby('Site').size(),
        'mean_entropy': entropy_by_site.mean(),
        'std_entropy': entropy_by_site.std(),
    }).round(3)

    print(stats)

    return stats