In [None]:
import numpy as np
import pandas as pd
from scipy.stats import ttest_ind
import seaborn as sns
import matplotlib.pyplot as plt


# Read the two input tables in one pass. Both paths are the placeholder
# '.csv' in this file.
# NOTE(review): the analysis further down reads `controls_survived` and
# `norepi_expired`, neither of which is assigned in the visible code --
# presumably they are meant to come from these two frames; confirm.
dataframe_1, dataframe_2 = (
    pd.read_csv(csv_path) for csv_path in ('.csv', '.csv')
)


# Columns compared between the two groups. The `...` entry is a literal
# Ellipsis placeholder: the remaining column names still have to be filled in.
features = [
    'age',
    ...,
]

# Compare every feature between the two groups: a Welch (unequal-variance)
# t-test with a Bonferroni-adjusted p-value, plus the absolute Cohen's d
# computed from the pooled standard deviation.
m = len(features)  # number of tests, used as the Bonferroni multiplier
rows = []
for feat in features:
    # Missing values are dropped per feature and per group.
    survived_vals = controls_survived[feat].dropna()
    expired_vals = norepi_expired[feat].dropna()

    # Only the p-value of the test is used below.
    _, p_raw = ttest_ind(survived_vals, expired_vals, equal_var=False)

    # Pooled SD from the two sample standard deviations (ddof=1).
    n_surv, n_exp = len(survived_vals), len(expired_vals)
    sd_surv = survived_vals.std(ddof=1)
    sd_exp = expired_vals.std(ddof=1)
    pooled_sd = np.sqrt(
        ((n_surv - 1) * sd_surv**2 + (n_exp - 1) * sd_exp**2)
        / (n_surv + n_exp - 2)
    )
    effect = (survived_vals.mean() - expired_vals.mean()) / pooled_sd

    # Bonferroni correction, capped at 1.0; the sign of d is discarded.
    rows.append((feat, min(p_raw * m, 1.0), abs(effect)))

res = pd.DataFrame(rows, columns=['feature', 'p_bonf', 'd_abs'])
res['significant'] = res['p_bonf'].lt(0.05)

# Keep only the significant features, largest effect size first.
top = res.loc[res['significant']].sort_values('d_abs', ascending=False)
print("\nSignificant features sorted by effect size:")
print(top)

# Prepare data for plotting: the significant features, in effect-size order.
features_to_plot = top['feature'].tolist()

# The group labels double as x tick labels and as palette keys, so they are
# spelled once here to keep the 'Group' values and both palettes in sync.
EXPIRED_LABEL = 'Expired\nFluids\nNorepi'
SURVIVED_LABEL = 'Survived\nFluids'

# Tag each cohort (this adds a 'Group' column to the source frames) and
# stack them into one long table for seaborn.
norepi_expired['Group'] = EXPIRED_LABEL
controls_survived['Group'] = SURVIVED_LABEL
# ignore_index gives the stacked frame a fresh 0..n-1 index; without it each
# cohort keeps its own row labels, which can repeat across the two frames.
combined_df = pd.concat(
    [norepi_expired, controls_survived], axis=0, ignore_index=True
)

# Create figure: one panel per significant feature.
sns.set_style("ticks")
n_features = len(features_to_plot)
# plt.subplots rejects a zero-column grid, so when no feature is significant
# a single empty panel is created instead of raising.
n_cols = max(n_features, 1)
fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 5), squeeze=False)

# squeeze=False always yields a 2-D array of Axes; flattening it means the
# one-panel and many-panel cases are iterated the same way.
axes = axes.ravel()

# Color palettes: lighter fills for the violins, stronger tones for the points.
box_palette = {EXPIRED_LABEL: '#f687ce', SURVIVED_LABEL: '#bdd996'}
jitter_palette = {EXPIRED_LABEL: '#ef51ba', SURVIVED_LABEL: '#5a9010'}

# Keyword arguments shared by every panel, collected once outside the loop.
# NOTE(review): newer seaborn releases rename `scale` to `density_norm` --
# confirm against the installed version before changing it.
violin_style = dict(
    hue='Group',
    palette=box_palette,
    legend=False,
    inner='box',
    scale='width',
    linewidth=1.5,
    saturation=1,
    cut=2,
)
strip_style = dict(
    hue='Group',
    palette=jitter_palette,
    legend=False,
    size=4,
    jitter=0.05,
    dodge=False,
    alpha=0.2,
)
label_font = dict(fontweight='bold', fontsize=20)

# Ordered (cutoff, text) pairs: the first cutoff the p-value falls below wins.
p_value_labels = ((0.005, "p < 0.005"), (0.05, "p < 0.05"))

# Draw one violin + strip panel per significant feature.
for ax, feat in zip(axes, features_to_plot):
    sns.violinplot(data=combined_df, x='Group', y=feat, ax=ax, **violin_style)

    # Outline the violins now, before the strip plot adds its own collections.
    for violin in ax.collections:
        violin.set_edgecolor('black')

    # Overlay the individual observations as faint jittered points.
    sns.stripplot(data=combined_df, x='Group', y=feat, ax=ax, **strip_style)

    ax.grid(False)

    # Large, bold tick labels on both axes.
    ax.yaxis.set_tick_params(labelsize=20)
    ax.xaxis.set_tick_params(labelsize=20)
    for tick_label in [*ax.get_yticklabels(), *ax.get_xticklabels()]:
        tick_label.set_fontweight('bold')

    # Show all four spines in solid black.
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_color('black')
        spine.set_linewidth(1.5)

    # Fetch this feature's statistics once (first matching row of `top`).
    feat_stats = top.loc[top['feature'] == feat].iloc[0]
    pval = feat_stats['p_bonf']
    dval = feat_stats['d_abs']

    # Bucket the corrected p-value; fall back to the numeric value otherwise.
    for cutoff, text in p_value_labels:
        if pval < cutoff:
            p_text = text
            break
    else:
        p_text = f"p = {pval:.2f}"

    # Annotate the panel (axes coordinates) with the p-value and effect size.
    ax.text(
        0.5,
        0.85,
        f"{p_text}\nd = {dval:.2f}",
        transform=ax.transAxes,
        ha='center',
        va='center',
        fontsize=20,
        fontweight='bold',
        bbox=dict(facecolor='white', alpha=0.6, edgecolor='none'),
    )

    # Blank title and x label; the feature name goes on the y axis.
    ax.set_title('', **label_font)
    ax.set_xlabel('')
    ax.set_ylabel(feat, **label_font)

    # Restyle whatever patch artists and line artists the panel holds.
    for artist in ax.artists:
        artist.set_edgecolor('black')
        artist.set_linewidth(1.5)

    for line in ax.lines:
        line.set_color('black')
        line.set_linewidth(1.5)

# Finalise the layout, write the figure to disk, then display it.
output_png = '/home/pkris25/g_journal_april_2024/figure_1/HRV_expired_survied_violinplots.png'
plt.tight_layout()
plt.savefig(output_png, dpi=300, bbox_inches='tight')
plt.show()

# Report how many features were significant, then list each one (in the same
# effect-size order as `top`) with its corrected p-value and effect size.
print(f"\nNumber of significant features: {len(features_to_plot)}")
for rank, entry in enumerate(top.itertuples(index=False), start=1):
    print(f"\n{rank}. {entry.feature}:")
    print(f"   p-value (Bonferroni): {entry.p_bonf:.2e}")
    print(f"   Effect size (d): {entry.d_abs:.3f}")