In [None]:
# Standard library imports
import os
import warnings
from collections import defaultdict
from functools import cmp_to_key, reduce
from math import ceil

# Third-party library imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import colors as mcolors
from matplotlib.axes._axes import _log as matplotlib_axes_logger
from matplotlib_venn import venn3

from adjustText import adjust_text

# SciPy and related imports
from scipy.interpolate import interp1d
from scipy.stats import ttest_rel, false_discovery_control

# Progress bar handling
from tqdm.notebook import tqdm


In [None]:
# Notebook-wide configuration: silence warnings/log noise and set plotting defaults.
warnings.filterwarnings("ignore")
matplotlib_axes_logger.setLevel('ERROR')
np.seterr(divide='ignore', invalid='ignore')

# Colours of the active matplotlib property cycle, in cycle order.
def_color = [entry['color'] for entry in plt.rcParams['axes.prop_cycle']]

# All rcParams are set in one place; keys are the dotted equivalents of plt.rc(group, name=...).
plt.rcParams.update({
    'figure.facecolor': 'white',
    'font.size': 14,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'axes.titlesize': 12,
    'figure.dpi': 100,
    'svg.fonttype': 'none',
})

# Can we make a reference cohort for the disease that is also matched to age and BMI??

# General functions

In [None]:
def remove_top_right_frame(axes):
    """Hide the right and top spines of every Axes in ``axes``.

    Parameters
    ----------
    axes : iterable
        Objects exposing a ``spines`` mapping with 'right' and 'top' entries.
    """
    for ax in axes:
        for side in ('right', 'top'):
            ax.spines[side].set_visible(False)

def get_test_dataframe(path,file):
    """Load one per-test CSV and index it by the midpoint of each week bin.

    Parameters
    ----------
    path : str
        Directory containing the file.
    file : str
        File name inside ``path``.

    Returns
    -------
    pandas.DataFrame
        The CSV content without the 'week' column, indexed by bin midpoint,
        with 'age_50'/'age50' renamed to 'age'. When ``path`` contains the
        substring '1.w' and fewer than 140 rows were read, NaN rows are
        inserted for the missing midpoints of the grid -59.5, -58.5, ..., 79.5.
    """
    raw = pd.read_csv(os.path.join(path, file))

    # A 'week' label looks like "[a,b)"; strip the brackets and average the two integers.
    midpoints = []
    for label in raw['week']:
        edges = np.array(label.replace('[', '').replace(')', '').split(','))
        midpoints.append(edges.astype(int).mean())
    raw.index = np.array(midpoints)

    tdf = raw.drop(columns=['week']).rename(columns={"age_50": "age", "age50": "age"})

    # NOTE(review): the directories used in this notebook are named like
    # 'pregnancy.1w', which does not contain '1.w' -- confirm whether this
    # padding branch is meant to run for the weekly data.
    if '1.w' not in path or len(tdf) >= 140:
        return tdf

    weekly_grid = np.arange(-59.5, 80.5, 1)
    missing_index = np.sort(list(set(weekly_grid) - set(tdf.index)))

    # Insert the missing weeks as NaN columns of the transposed frame, each at
    # its position on the weekly grid, then transpose back.
    transposed = tdf.T
    for week in missing_index:
        position = int(np.argwhere(weekly_grid == week)[0][0])
        transposed.insert(position, week, np.nan)
    return transposed.T

def filter_tests_from_data(data, field="value_at_quantile", min_n=50, min_avg_n=80, max_avg_noise=0.5, max_noise=0.7, min_weeks=-60, max_weeks=80, res=1):
    tests = set()
    num_points = (max_weeks - min_weeks) / res
    for k in data:
        df = data[k][field]
        noise = (df['error'] / df['value'].var(ddof=1) ** 0.5 )
        is_noisy = (noise.mean() > max_avg_noise or noise.max() > max_noise) or noise.isna().any()
        enough_n_vals = df["n"].min() > min_n and df["n"].mean() > min_avg_n
        n_points = len(df.index)
        if (n_points == num_points and enough_n_vals and not is_noisy):
            tests.add(k)
    return tests

# Loading metadata
* 'metadata' - Test names in various forms, division into groups, and units
* 'labnorm'  - Reference values for tests

In [None]:
# groups        = ['Liver','Renal','Musculoskeletal','Metabolism','RBCs','Coagulation','Immune','Endocrine']
# subgroups_main= ['Liver','Renal I','Renal II','Musculoskeletal','Metabolism','RBCs I','RBCs II','Coagulation','Immune I','Immune II','Endocrine']
# subgroups_si  = ['Liver','Renal I','Renal II','Musculoskeletal','Metabolism','RBCs I','RBCs II','RBCs III','Coagulation','Immune I','Immune II','Immune III','Endocrine I','Endocrine II','Endocrine III']
# Test metadata (indexed by its first column) and the LabNorm reference table.
metadata = pd.read_csv(os.path.join('..', 'Clalit Data', 'Metadata.csv'), index_col=0)
labnorm = pd.read_csv(os.path.join('..', 'Clalit Data', 'LabNorm.csv'))

# Tests for which a LabNorm name is defined in the metadata.
labnorm_tests = list(metadata.index[metadata['LabNorm name'].notna()])

# Directory with one file per lab test; keep regular files only.
path = os.path.join('..', 'Clalit Data', 'pregnancy.1w')
files = [entry for entry in os.listdir(path) if os.path.isfile(os.path.join(path, entry))]

# Test key = file name up to the first dot.
all_tests = [entry.split('.')[0] for entry in files]

# We removed duplicate tests
#all_tests     = np.unique(list(set(all_tests)- set(['POTASSIUM_BLOOD','SODIUM_BLOOD','RDW_CV'])))

# Keep only tests that have reference values, sorted by key.
tests = sorted(set(labnorm_tests) & set(all_tests))

In [None]:
def get_age_matched_reference(test,dfs=None, age_series=None):
    """Build the LabNorm reference distribution for ``test`` at the cohort's median age.

    Parameters
    ----------
    test : str
        Test key (row label of ``metadata``).
    dfs : dict, optional
        Maps test key to a table with an 'age' column; when given it takes
        precedence over ``age_series``.
    age_series : array-like, optional
        Ages to take the (NaN-ignoring) median of when ``dfs`` is not given.

    Returns
    -------
    dict
        'quantile': interpolator from reference quantile to value,
        'mean' and 'sd': moments of the reference distribution, weighted by
        the spacing between consecutive quantiles.

    Raises
    ------
    ValueError
        If neither ``dfs`` nor ``age_series`` is provided.
    """
    if dfs is not None:
        age_series = dfs[test]['age']
    if age_series is None:
        raise ValueError("either `dfs` or `age_series` must be provided")
    age = int(np.nanmedian(age_series))

    lab_name = metadata.loc[test]['LabNorm name']
    selected = (labnorm.age == age) & (labnorm.lab == lab_name)
    # ffill()/bfill() replace the deprecated fillna(method=...) calls.
    tdf = labnorm.loc[selected].ffill().bfill().copy()

    # Replace negative reference values by the smallest positive one. A single
    # .loc assignment is used because chained assignment
    # (tdf['value'].loc[...] = ...) does not reliably write back to tdf.
    smallest_positive = tdf.loc[tdf['value'] > 0, 'value'].min()
    tdf.loc[tdf['value'] < 0, 'value'] = smallest_positive

    # Probability mass of each quantile interval, attached to its upper edge.
    quantile_steps = np.diff(tdf['quantile'])
    prob = quantile_steps / np.sum(quantile_steps)
    upper_values = tdf['value'].iloc[1:]
    mu_ = (upper_values * prob).sum()
    sd_ = np.sqrt((prob * ((upper_values - mu_) ** 2)).sum())

    q_to_val = interp1d(tdf['quantile'], tdf['value'])
    return {'quantile':q_to_val,'mean':mu_,'sd':sd_}

def get_value_stats(test,dfs):
    """Return per-timepoint count, mean value and standard error for ``test``.

    Columns are read as 'val_n'/'val_mean'/'val_sd' when present, otherwise
    as 'n'/'mean'/'sd'.

    Returns
    -------
    pandas.DataFrame
        Columns 'n', 'value', 'error', indexed like ``dfs[test]``.
    """
    tdf = dfs[test]
    n   = tdf['val_n'] if "val_n" in tdf.columns else tdf['n']
    v   = tdf['val_mean'] if "val_mean" in tdf.columns else tdf['mean']
    sd  = tdf['val_sd'] if "val_sd" in tdf.columns else tdf['sd']
    # Standard error of the mean: sd / sqrt(n). This matches get_quantile_stats
    # and the inverse used in get_stats_from_dataset (sd = error * sqrt(n)).
    e   = sd/np.sqrt(n)
    return pd.DataFrame(np.array([n,v,e]).T,index=tdf.index,columns=['n','value','error'])

def get_quantile_stats(test,dfs, include_age=False):
    """Return per-timepoint count, mean quantile and its standard error for ``test``.

    Columns are read as 'val_n'/'qval_mean'/'qval_sd' when present, otherwise
    as 'n'/'qmean'/'qsd'. With ``include_age`` an 'age' column is appended,
    read from 'age' when present, otherwise from 'age50'.

    Returns
    -------
    pandas.DataFrame
        Columns 'n', 'value', 'error' (and 'age'), indexed like ``dfs[test]``.
    """
    tdf = dfs[test]

    def pick(preferred, fallback):
        # Column lookup with a fallback name.
        return tdf[preferred] if preferred in tdf.columns else tdf[fallback]

    counts = pick('val_n', 'n')
    columns = {
        'n': counts,
        'value': pick('qval_mean', 'qmean'),
        'error': pick('qval_sd', 'qsd') / np.sqrt(counts),
    }
    if include_age:
        columns['age'] = pick('age', 'age50')

    stacked = np.array(list(columns.values())).T
    return pd.DataFrame(stacked, index=tdf.index, columns=list(columns))

def get_value_at_quantile_stats(test,dfs):
    """Translate the mean quantile of ``test`` into reference-scale values.

    The quantile statistics from ``get_quantile_stats`` are mapped through the
    age-matched reference quantile function. The error is half the distance
    between the values at (quantile + error) and (quantile - error).

    Returns
    -------
    pandas.DataFrame
        Columns 'n', 'value', 'error'. For tests that are not in
        ``labnorm_tests`` the frame has the same index but no data.
    """
    tdf = get_quantile_stats(test,dfs)
    if test not in labnorm_tests:
        return pd.DataFrame(index=tdf.index,columns=['n','value','error'])

    ref = get_age_matched_reference(test,dfs)
    to_value = ref['quantile']
    v = to_value(tdf['value'])
    # Clip BOTH error bounds to [0, 1]: previously only the upper bound was
    # clipped, so a quantile minus its error could go below 0 and hit the
    # interpolator outside its range.
    e_p = to_value((tdf['value'] + tdf['error']).clip(0, 1))
    e_m = to_value((tdf['value'] - tdf['error']).clip(0, 1))
    e = (e_p - e_m) / 2
    return pd.DataFrame(np.array([tdf['n'],v,e]).T,index=tdf.index,columns=['n','value','error'])


def get_test_data(test,dfs, age_in_quantile=False):
    """Collect the three summary tables of ``test`` in one dict.

    Returns
    -------
    dict
        'mean' from ``get_value_stats``, 'quantile' from ``get_quantile_stats``
        (with the age column when ``age_in_quantile`` is true) and
        'value_at_quantile' from ``get_value_at_quantile_stats``.
    """
    summary = {}
    summary['mean'] = get_value_stats(test, dfs)
    summary['quantile'] = get_quantile_stats(test, dfs, age_in_quantile)
    summary['value_at_quantile'] = get_value_at_quantile_stats(test, dfs)
    return summary

def get_stats_from_dataset(test,data,stat,period=(-60,80),merge_timepoints = False):
    """Extract the time series of one statistic, optionally pooling adjacent timepoints.

    Parameters
    ----------
    test : str
        Key into ``data``.
    data : dict
        ``data[test][stat]`` is a table with 'n', 'value' and 'error' columns.
    stat : str
        Which table of the test to use.
    period : sequence of two numbers
        Inclusive index-label range kept in the output.
    merge_timepoints : bool
        When true, consecutive rows are pooled into bins of about ``split_k``
        rows, where ``split_k`` grows with the mean squared ratio of the error
        to 15% of the value's standard deviation, capped at 7.

    Returns
    -------
    tuple or None
        ``(x, value, error, n)`` restricted to ``period``; ``None`` when the
        table contains only NaN.
    """
    df = data[test][stat].copy()
    if df.isna().all().all():
        return None
    # Recover the standard deviation from the standard error.
    df['sd'] = df['error']*np.sqrt(df['n'])

    if merge_timepoints:
        noise_ratio = np.nanmean((df['error']/(df['value'].std() * 0.15))**2)
        # A NaN ratio (e.g. undefined std) means there is nothing to base the
        # pooling on, so keep the original resolution; an infinite ratio is
        # capped at 7 like any other large value.
        split_k = 1 if np.isnan(noise_ratio) else int(min(np.ceil(noise_ratio), 7))

        if split_k > 1:
            # At least one bin, even when there are fewer rows than split_k.
            n_bins = max(1, len(df) // split_k)
            # Split row positions rather than the frame itself.
            bins = [df.iloc[rows] for rows in np.array_split(np.arange(len(df)), n_bins)]
            x  = np.array([b.index.to_numpy().mean() for b in bins])
            n  = np.array([b['n'].sum() for b in bins])
            # n-weighted mean of the values and n-weighted pooled variance.
            v  = np.array([(b['value']*b['n']).sum()/b['n'].sum() for b in bins])
            sd = np.array([np.sqrt((b['sd']**2*b['n']).sum()/b['n'].sum()) for b in bins])
            e  = sd/np.sqrt(n)
            df = pd.DataFrame(np.array([n,v,e]).T,columns=['n','value','error'],index=x)

    df = df.loc[period[0]:period[1]]
    return df.index, df['value'], df['error'], df['n']

def plot_qunatile_change(ax, test, data, c='k', decorate=True, is_series=False, use_quantiles=False, text_color=None, alpha=None, fontsize=8):
    """Draw one test's time course on ``ax`` and apply the shared axis styling.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    test : str
        Test key; used to look up the data and, when decorating, the
        'Nice name' in ``metadata``.
    data : dict or pandas.Series
        A dataset dict as consumed by ``get_stats_from_dataset``, or, when
        ``is_series`` is true, a Series plotted as-is against its index.
    c : colour
        Line colour.
    decorate : bool
        Shade weeks -38..0 and write the test's nice name above the axes.
    is_series : bool
        Treat ``data`` as a ready-to-plot Series.
    use_quantiles : bool
        Plot the 'quantile' table instead of 'value_at_quantile'.
    text_color : colour, optional
        Colour of the name annotation (black when None).
    alpha : float, optional
        Line transparency, applied to both plain lines and error-bar curves.
    fontsize : int
        Font size of the name annotation.
    """
    if is_series:
        ax.plot(data.index, data.values, c=c, lw=1, alpha=alpha)
    else:
        stat = 'quantile' if use_quantiles else 'value_at_quantile'
        r = get_stats_from_dataset(test, data, stat)
        # Draw only when data exist: the error-bar call used to sit outside
        # this guard for 'value_at_quantile' and failed on undefined x/v/e.
        if r is not None:
            x, v, e, n = r
            ax.errorbar(x, v, e, c=c, lw=1, alpha=alpha)

    ax.set_xticks(np.arange(-50,100,50))
    ax.set_xlabel('Weeks\npostpartum',fontsize=6)
    ax.set_xlim([-60,80])

    ax.set_facecolor('w')
    remove_top_right_frame([ax])
    ax.tick_params(axis='both', which='major', labelsize=6)
    if decorate:
        # Shaded band between week -38 and week 0.
        ax.axvspan(-38,0,color='dimgray',alpha=0.1,zorder=-20)
        ax.annotate(metadata.loc[test]['Nice name'],(0.05,1.05),xycoords='axes fraction',fontsize=fontsize,color='k' if text_color is None else text_color)

def data_from_path(path, age_in_quantile=False):
    """Load every test file in ``path`` and derive its summary tables.

    Returns
    -------
    tuple
        ``(dfs, data)``: ``dfs`` maps test key (file name up to the first dot)
        to the table from ``get_test_dataframe``; ``data`` maps the same keys
        to the dict from ``get_test_data``.
    """
    files = [entry for entry in os.listdir(path) if os.path.isfile(os.path.join(path, entry))]
    all_tests = [entry.split('.')[0] for entry in files]

    dfs = {}
    for entry in tqdm(files):
        dfs[entry.split('.')[0]] = get_test_dataframe(path, entry)

    data = {}
    for name in tqdm(all_tests):
        data[name] = get_test_data(name, dfs, age_in_quantile)
    return dfs,data

In [None]:
# Load each cohort directory into (raw tables, summary tables).
# The second positional argument (True) asks for the age column in the 'quantile' table.
# NOTE(review): 'hnm' presumably denotes the reference (uncomplicated) cohort at 1-week
# and 4-week resolution -- confirm against the data documentation.
dfs_hnm1,data_hnm1 = data_from_path('../Clalit Data/pregnancy.1w/')
dfs_hnm4,data_hnm4 = data_from_path('../Clalit Data/pregnancy.4w/')
dfs_ecl ,data_ecl  = data_from_path('../Clalit Data/pregnancy.pre-eclampsia.4w/', True)
dfs_gd  ,data_gd   = data_from_path('../Clalit Data/pregnancy.gdm.4w/', True)
dfs_pph ,data_pph  = data_from_path('../Clalit Data/pregnancy.postpartum_hemorrhage.4w/', True)

## Plotting volcano plots: p-value and cohen's d (effect size)

### Find noisy tests

In [None]:
# Short names of the tests judged unreliable in each complication cohort.
noisy_tests = {"ecl":set(), "pph": set(), "gd": set()}

for test in tests:
    # The tuple order (ecl, pph, gd) must match the key order of noisy_tests.
    for complication, complication_df in zip(noisy_tests.keys(), (data_ecl[test]["value_at_quantile"],data_pph[test]["value_at_quantile"],data_gd[test]["value_at_quantile"])):
        # Relative error: standard error divided by the sample standard deviation of the values.
        noise = (complication_df['error'] / complication_df['value'].var(ddof=1) ** 0.5 )
        is_stable = not (noise.mean() > 1.0 or noise.max() > 1.3)
        enough_points = len(noise) == 35  # Number of data points from -60 weeks to 80 weeks at 4 weeks resolution
        enough_n_vals = complication_df["n"].min() > 10 and complication_df["n"].mean() > 90
        # A test is flagged as noisy if it fails any of the three checks.
        if not (is_stable and enough_points and enough_n_vals):
            noisy_tests[complication].add(metadata.loc[test, "Short name"])
        # Report tests rejected for instability alone.
        if not is_stable and enough_n_vals and enough_points:
            print(complication, test)
A, B, C = noisy_tests.values()
set_labels = noisy_tests.keys()
# Region id -> members. '-' binds tighter than '&', so `A & B - C` is A & (B - C),
# which equals (A & B) - C.
venn_labels = {'100': A - B - C, '010': B - A - C, '110': A & B - C,
                   '001': C - A - B, '101': A & C - B, '011': B & C - A, '111': A & B & C}
diagram = venn3(subsets=(len(A - B - C), len(B - A - C), len(A & B - C),
               len(C - A - B), len(A & C - B), len(B & C - A), len(A & B & C)),
      set_labels=set_labels,
      alpha=0.5,
      set_colors=('skyblue', 'lightgreen', 'lightcoral')
               )

# Replace the default region counts by the member names, one per line.
for k in venn_labels:
    label = diagram.get_label_by_id(k)
    if label is not None:
        label.set(text="\n".join(venn_labels[k]), size=10)
plt.show()

for k, v in noisy_tests.items():
    print(f"{k}: {', '.join(v)}")

# Union of the noisy short names over all cohorts, mapped back to metadata index keys.
tests_to_remove = reduce(lambda x, y: x | y, noisy_tests.values())
tests_to_remove = list(metadata[metadata["Short name"].isin(tests_to_remove)].index)
#tests_to_remove += ["LUC", "PT_perc", "TBIL", "BILIRUBIN_TOTAL", "MPV", ""]
# NOTE(review): tests_to_remove holds metadata index keys at this point, while "CH" looks
# like a short name -- confirm that it matches an entry of `tests`.
tests_to_exclude = ["CH"]#metadata[metadata["Short name"].isin(["ACR", "Cl", "Mg", "SG", "UMA", "uCr", "CK", "LDH", "P", "VitD3", "DBIL", "GLB", "IBIL", "AMYL", "BMI", "HDL ratio", "APTT-sec", "PDW", "MPXI", "PCT", "Sed Rate", "CH", "FA", "Hct/Hgb", "FER", "HDW", "Iron", "RDW SD", "E2", "FT3", "FT4", "LH", "TSH"])].index
tests_to_remove += tests_to_exclude
# HbA1c is kept for the volcano analysis even if it was flagged as noisy.
if "HEMOGLOBIN_A1C_CALCULATED" in tests_to_remove:
    tests_to_remove.remove("HEMOGLOBIN_A1C_CALCULATED")

# set() difference: the resulting order is arbitrary.
tests_for_volcanos = list(set(tests) - set(tests_to_remove))

@cmp_to_key
def compare_tests_alphabetically(test1, test2):
    """Three-way comparison of two tests by (Group, Nice name) from ``metadata``.

    Returns -1, 0 or 1. Wrapped with ``cmp_to_key`` so the name can be passed
    directly as ``key=`` to ``sorted``.
    """
    row1 = metadata.loc[test1]
    row2 = metadata.loc[test2]
    key1 = (row1["Group"], row1['Nice name'])
    key2 = (row2["Group"], row2['Nice name'])
    # The equality test used to be written with '>' so 0 was never returned;
    # tuple comparison gives a proper -1 / 0 / 1 result.
    return (key1 > key2) - (key1 < key2)

# Alternative ordering (by group first, then name):
#tests_ordered_short_name = sorted(tests, key=compare_tests_alphabetically)
# Despite the variable name, the sort key is the 'Nice name' column.
tests_ordered_short_name = sorted(tests, key=lambda test_name: metadata.loc[test_name]['Nice name'])
# Number of tests that enter the volcano analysis.
print(len(tests_for_volcanos))


# group_color = {'Coagulation': def_color[4],
#                'Endocrine': def_color[2],
#                'Immune': def_color[0],
#                'Liver': def_color[5],
#                'Metabolism': def_color[1],
#                'Musculoskeletal': def_color[7],
#                'RBCs': def_color[3],
#                'Renal': def_color[6]}




### Volcano grid and Venn

In [None]:
# Stage names, in the order used for the rows of the volcano grid.
titles = ["Pre-conception", "Gestation", "Postpartum"]

# 3x3 volcano grid (one row per stage, one column per complication).
fig_cohensd, axes_cohensd = plt.subplots(nrows=3, ncols=3, figsize=(23, 12.5))
# One Venn diagram per complication.
fig_cohensd_venn, axes_cohensd_venn = plt.subplots(nrows=1, ncols=3, figsize=(23, 5))

In [None]:
def find_min_lim_ecl(effect_sizes_all_stages, p_vals_all_stages, all_tests):
    """Smallest effect size and smallest p-value score over the blood-pressure tests.

    Only the columns whose test has short name 'BPs' or 'BPd' are considered,
    across all stages (rows); every other entry is treated as +inf.

    Parameters
    ----------
    effect_sizes_all_stages, p_vals_all_stages : numpy.ndarray
        Arrays indexed as [stage, test position].
    all_tests : sequence of str
        Test keys, in the column order of the arrays.

    Returns
    -------
    tuple
        ``(min effect size, min p-value score)`` over the selected entries.
    """
    test_series = pd.Series(all_tests)
    meta_flat = metadata.reset_index()
    bp_keys = meta_flat.loc[meta_flat["Short name"].isin(["BPs", "BPd"]), "Test"]
    selected_cols = test_series[test_series.isin(bp_keys)].index
    selector = np.ix_(np.arange(effect_sizes_all_stages.shape[0]), selected_cols)

    def masked_min(values):
        # Minimum over the selected entries only.
        masked = np.full(values.shape, np.inf)
        masked[selector] = values[selector]
        return masked.min()

    return masked_min(effect_sizes_all_stages), masked_min(p_vals_all_stages)

In [None]:
# Pre-eclampsia column of the volcano grid: for each stage (row), compare the
# complication cohort with the 4-week reference cohort on every test.
lower_bound = -np.inf
pre_ecl_cohensd_sets = []
all_effect_sizes = []
all_l_p_vals = []
panel_indices = ["A", "D", "G"]

# Stage windows in weeks: [-inf, -38), [-38, 0), [0, inf).
for i, upper_bound in enumerate((-38, 0, np.inf)):
    annotations = []
    ax = axes_cohensd[i, 0]
    p_vals = []
    effect_sizes = []
    if i == 0:
        ax.set_title("Pre-eclampsia", fontsize=18, transform=ax.transAxes, y=1.15, verticalalignment="top", fontweight="bold")
    elif i == 2:
        ax.set_xlabel("Effect Size (Cohen's d)", fontsize=14)
    for test in tests_for_volcanos:
        idx_ecl, v_ecl, _, _ = get_stats_from_dataset(test,data_ecl,'quantile')
        idx_bmi2, v_bmi2, _, _ = get_stats_from_dataset(test,data_hnm4,'quantile')
        # Timepoints present in both cohorts and inside the current stage window.
        indices = idx_ecl.intersection(idx_bmi2)
        indices = indices[(indices < upper_bound) & (indices >= lower_bound)]
        t_stats = ttest_rel(v_ecl[indices], v_bmi2[indices])
        p_vals.append(t_stats.pvalue)
        # Cohen's d for paired samples: t / sqrt(number of pairs)
        effect_size = t_stats.statistic / len(indices) ** 0.5
        effect_sizes.append(effect_size)
    # -log10 of the FDR-adjusted p-values.
    l_p_vals = -np.log10(np.array(false_discovery_control(p_vals)))
    effect_sizes = np.array(effect_sizes)
    all_effect_sizes.append(effect_sizes)
    all_l_p_vals.append(l_p_vals)
    # Highlighted tests: |d| >= 2.93 and -log10(adjusted p) > 1.3 (about p < 0.05).
    annotated_tests_inds = np.argwhere((np.abs(effect_sizes) >= 2.93) & (l_p_vals > 1.3)).squeeze().reshape(-1)
    negative_effect = np.argwhere(effect_sizes < 0)
    positive_effect = np.argwhere(effect_sizes > 0)
    negative_test_inds = np.intersect1d(annotated_tests_inds, negative_effect)
    positive_test_inds = np.intersect1d(annotated_tests_inds, positive_effect)

    colors = np.full(len(tests_for_volcanos), "grey")
    colors[annotated_tests_inds] = "r"
    ax.scatter(effect_sizes, l_p_vals, marker="o", s=14, c=colors)
    ax.set_ylim(0)
    # Make the x axis symmetric around zero.
    xmin, xmax = ax.get_xlim()
    ax.set_xlim(-max(abs(xmin), abs(xmax)), max(abs(xmin), abs(xmax)))
    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)
    title = titles[i]
    # The next stage starts where this one ends.
    lower_bound = upper_bound
    for ind in annotated_tests_inds:
        annotations.append(ax.text(s=metadata.loc[tests_for_volcanos[ind]]["Short name"], x=effect_sizes[ind], y=l_p_vals[ind], fontfamily='sans-serif',  fontsize=14))
    adjust_text(annotations, arrowprops=dict(linestyle="--", color='k', lw=0.5), force_text=(1., 1.), ax=ax)
    # Label each highlighted test with the direction of its effect.
    ecl_sig_tests = {k: "↓" for k in metadata.loc[np.array(tests_for_volcanos)[negative_test_inds]]["Short name"]}
    ecl_sig_tests.update({k: "↑" for k in metadata.loc[np.array(tests_for_volcanos)[positive_test_inds]]["Short name"]})
    ecl_sig_tests = set(f"{v}{k}" for k,v in ecl_sig_tests.items())
    # The denominator is the number of tests actually analysed (tests_for_volcanos).
    print(f"Total significant tests: {len(ecl_sig_tests)}/{len(tests_for_volcanos)}. Tests: {', '.join(sorted(list(ecl_sig_tests)))}")
    pre_ecl_cohensd_sets.append(ecl_sig_tests)
    ax.tick_params(labelsize=12)
    ax.set_ylabel("-log p-value", fontsize=12)
    # Stage name to the left of the row, panel letter above the axes.
    ax.annotate(titles[i], xycoords='axes fraction', xy=(-0.43,0.5), va="center", ha="center", fontsize=16, color="k", fontweight='bold')
    ax.annotate(panel_indices[i], xycoords='axes fraction', xy=(-0.1,1.1), ha="left", fontsize=16, color="k", fontweight='bold')

all_effect_sizes = np.array(all_effect_sizes)
all_l_p_vals = np.array(all_l_p_vals)
print(find_min_lim_ecl(all_effect_sizes, all_l_p_vals, tests_for_volcanos))


# Venn diagram of the highlighted tests per stage (pre-eclampsia panel).
venn_ax = axes_cohensd_venn[0]
A, B, C = (pre_ecl_cohensd_sets[stage] for stage in range(3))

# Region id -> members; the insertion order is the order venn3 expects for `subsets`.
venn_labels = {
    '100': A - B - C,
    '010': B - A - C,
    '110': A & B - C,
    '001': C - A - B,
    '101': A & C - B,
    '011': B & C - A,
    '111': A & B & C,
}
diagram = venn3(
    subsets=tuple(len(members) for members in venn_labels.values()),
    set_labels=titles,
    alpha=0.5,
    ax=venn_ax,
    set_colors=('skyblue', 'lightgreen', 'lightcoral'),
)

# Show the member names instead of the counts in every region that is drawn.
for k, members in venn_labels.items():
    label = diagram.get_label_by_id(k)
    if label is None:
        continue
    label_text = list(members)
    label.set(text="\n".join(label_text), fontsize=14)

# Emphasise the three set titles.
set_annotations = [child for child in venn_ax.get_children()
                   if isinstance(child, plt.Text) and child.get_text() in titles]
for set_annotation in set_annotations:
    set_annotation.set_size(16)
    set_annotation.set_weight("bold")

venn_ax.annotate('J', xycoords='axes fraction', xy=(0,1.05), ha="left", fontsize=16, color="k", fontweight='bold')

#plt.title("PPH significant lab tests, cohen's d")
#plt.savefig("venn_Pre-eclampsia_cohens_d.svg")

In [None]:
def find_min_lim_pph(effect_sizes_all_stages, p_vals_all_stages, all_tests=tests_for_volcanos):
    """Smallest absolute effect size and smallest p-value score for the 'PLT' test.

    Only rows 0 and 2 of the arrays are considered, and only the columns whose
    test has short name 'PLT'; every other entry is treated as +inf.

    Parameters
    ----------
    effect_sizes_all_stages, p_vals_all_stages : numpy.ndarray
        Arrays indexed as [stage, test position].
    all_tests : sequence of str
        Test keys, in the column order of the arrays.

    Returns
    -------
    tuple
        ``(min |effect size|, min p-value score)`` over the selected entries.
    """
    test_series = pd.Series(all_tests)
    meta_flat = metadata.reset_index()
    plt_keys = meta_flat.loc[meta_flat["Short name"].isin(["PLT"]), "Test"]
    selected_cols = test_series[test_series.isin(plt_keys)].index
    selector = np.ix_(np.array([0, 2]), selected_cols)

    abs_effects = np.full(effect_sizes_all_stages.shape, np.inf)
    abs_effects[selector] = np.abs(effect_sizes_all_stages[selector])

    p_scores = np.full(p_vals_all_stages.shape, np.inf)
    p_scores[selector] = p_vals_all_stages[selector]

    return abs_effects.min(), p_scores.min()

In [None]:
# Postpartum-hemorrhage column of the volcano grid: for each stage (row), compare
# the complication cohort with the 4-week reference cohort on every test.
lower_bound = -np.inf
pph_cohensd_sets = []
all_effect_sizes = []
all_l_p_vals = []
panel_titles = ["B", "E", "H"]


# Stage windows in weeks: [-inf, -38), [-38, 0), [0, inf).
for i, upper_bound in enumerate((-38, 0, np.inf)):
    pph_valid_tests_t_test = set()
    annotations = []
    ax = axes_cohensd[i, 1]
    p_vals = []
    effect_sizes = []
    if i == 0:
        ax.set_title("Postpartum Hemorrhage", fontsize=18, transform=ax.transAxes, y=1.15, verticalalignment="top", fontweight="bold")
    elif i == 2:
        ax.set_xlabel("Effect Size (Cohen's d)", fontsize=14)
    for test in tests_for_volcanos:
        idx_pph, v_pph, _, _ = get_stats_from_dataset(test,data_pph,'quantile')
        idx_bmi2, v_bmi2, _, _ = get_stats_from_dataset(test,data_hnm4,'quantile')
        # Timepoints present in both cohorts and inside the current stage window.
        indices = idx_pph.intersection(idx_bmi2)
        indices = indices[(indices < upper_bound) & (indices >= lower_bound)]
        t_stats = ttest_rel(v_pph[indices], v_bmi2[indices])
        p_vals.append(t_stats.pvalue)
        # Cohen's d for paired samples: t / sqrt(number of pairs)
        effect_size = t_stats.statistic / len(indices) ** 0.5
        effect_sizes.append(effect_size)
    # -log10 of the FDR-adjusted p-values.
    l_p_vals = -np.log10(np.array(false_discovery_control(p_vals)))
    effect_sizes = np.array(effect_sizes)
    all_effect_sizes.append(effect_sizes)
    all_l_p_vals.append(l_p_vals)
    # Highlighted tests: |d| >= 1.3 and -log10(adjusted p) > 1.31.
    annotated_tests_inds = np.argwhere((np.abs(effect_sizes) >= 1.3) & (l_p_vals > 1.31)).squeeze().reshape(-1)
    negative_effect = np.argwhere(effect_sizes < 0)
    positive_effect = np.argwhere(effect_sizes > 0)
    negative_test_inds = np.intersect1d(annotated_tests_inds, negative_effect)
    positive_test_inds = np.intersect1d(annotated_tests_inds, positive_effect)

    colors = np.full(len(tests_for_volcanos), "grey")
    colors[annotated_tests_inds] = "r"
    ax.scatter(effect_sizes, l_p_vals, marker="o", s=10, c=colors)
    ax.set_ylim(0)
    # Make the x axis symmetric around zero.
    xmin, xmax = ax.get_xlim()
    ax.set_xlim(-max(abs(xmin), abs(xmax)), max(abs(xmin), abs(xmax)))
    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)
    title = titles[i]
    # The next stage starts where this one ends.
    lower_bound = upper_bound
    for ind in annotated_tests_inds:
        annotations.append(ax.text(s=metadata.loc[tests_for_volcanos[ind]]["Short name"], x=effect_sizes[ind], y=l_p_vals[ind], fontfamily='sans-serif',  fontsize=14))
        pph_valid_tests_t_test.add(metadata.loc[tests_for_volcanos[ind]]["Short name"])
    adjust_text(annotations, arrowprops=dict(linestyle="--", color='k', lw=0.5), force_text=(1., 1.), ax=ax)
    # Label each highlighted test with the direction of its effect.
    pph_sig_tests = {k: "↓" for k in metadata.loc[np.array(tests_for_volcanos)[negative_test_inds]]["Short name"]}
    pph_sig_tests.update({k: "↑" for k in metadata.loc[np.array(tests_for_volcanos)[positive_test_inds]]["Short name"]})
    pph_sig_tests = set(f"{v}{k}" for k,v in pph_sig_tests.items())
    # The denominator is the number of tests actually analysed (tests_for_volcanos).
    print(f"Total significant tests: {len(pph_sig_tests)}/{len(tests_for_volcanos)}. Tests: {', '.join(sorted(list(pph_sig_tests)))}")
    pph_cohensd_sets.append(pph_sig_tests)
    ax.annotate(panel_titles[i], xycoords='axes fraction', xy=(-0.1,1.1), ha="left", fontsize=16, color="k", fontweight='bold')
    ax.tick_params(labelsize=12)


print(find_min_lim_pph(np.array(all_effect_sizes), np.array(all_l_p_vals)))


# Venn diagram of the highlighted tests per stage (postpartum-hemorrhage panel).
venn_ax = axes_cohensd_venn[1]
A, B, C = (pph_cohensd_sets[stage] for stage in range(3))

# Region id -> members; the insertion order is the order venn3 expects for `subsets`.
venn_labels = {
    '100': A - B - C,
    '010': B - A - C,
    '110': A & B - C,
    '001': C - A - B,
    '101': A & C - B,
    '011': B & C - A,
    '111': A & B & C,
}
diagram = venn3(
    subsets=tuple(len(members) for members in venn_labels.values()),
    set_labels=titles,
    alpha=0.5,
    ax=venn_ax,
    set_colors=('skyblue', 'lightgreen', 'lightcoral'),
)

# Show the member names instead of the counts in every region that is drawn.
for k, members in venn_labels.items():
    label = diagram.get_label_by_id(k)
    if label is None:
        continue
    label_text = list(members)
    label.set(text="\n".join(label_text), fontsize=14)

# Emphasise the three set titles.
set_annotations = [child for child in venn_ax.get_children()
                   if isinstance(child, plt.Text) and child.get_text() in titles]
for set_annotation in set_annotations:
    set_annotation.set_size(16)
    set_annotation.set_weight("bold")

venn_ax.annotate('K', xycoords='axes fraction', xy=(0,1.05), ha="left", fontsize=16, color="k", fontweight='bold')

In [None]:
def find_min_lim_gd(effect_sizes_all_stages, p_vals_all_stages, all_tests):
    """Smallest absolute effect size and smallest p-value score over the glycemic tests.

    Only the columns whose test has short name 'Glucose' or 'HbA1c' are
    considered, across all stages (rows); the HbA1c entry of row 2 is
    excluded. Every other entry is treated as +inf.

    Parameters
    ----------
    effect_sizes_all_stages, p_vals_all_stages : numpy.ndarray
        Arrays indexed as [stage, test position].
    all_tests : sequence of str
        Test keys, in the column order of the arrays.

    Returns
    -------
    tuple
        ``(min |masked effect size|, min p-value score)``.
    """
    test_series = pd.Series(all_tests)
    meta_flat = metadata.reset_index()
    short_names = meta_flat["Short name"]

    glycemic_keys = meta_flat.loc[short_names.isin(["Glucose", "HbA1c"]), "Test"]
    selected_cols = test_series[test_series.isin(glycemic_keys)].index
    # .item() requires exactly one match, both in the metadata and in all_tests.
    hba1c_key = meta_flat.loc[short_names == "HbA1c", "Test"].item()
    hba1c_col = test_series[test_series == hba1c_key].index.item()
    selector = np.ix_(np.arange(effect_sizes_all_stages.shape[0]), selected_cols)

    def restricted(values):
        # Copy of the selected entries on an all-inf background, without HbA1c in row 2.
        masked = np.full(values.shape, np.inf)
        masked[selector] = values[selector]
        masked[2, hba1c_col] = np.inf
        return masked

    return np.abs(restricted(effect_sizes_all_stages)).min(), restricted(p_vals_all_stages).min()

In [None]:
# Gestational-diabetes column of the volcano grid: for each stage (row), compare
# the complication cohort with the 4-week reference cohort on every test.
lower_bound = -np.inf
gd_cohensd_sets = []
all_effect_sizes = []
all_l_p_vals = []
panel_titles = ["C", "F", "I"]

# Stage windows in weeks: [-inf, -38), [-38, 0), [0, inf).
for i, upper_bound in enumerate((-38, 0, np.inf)):
    annotations = []
    ax = axes_cohensd[i, 2]
    p_vals = []
    effect_sizes = []
    if i == 0:
        ax.set_title("Gestational Diabetes", fontsize=18, transform=ax.transAxes, y=1.15, verticalalignment="top", fontweight="bold")
    elif i == 2:
        ax.set_xlabel("Effect Size (Cohen's d)", fontsize=14)
    for test in tests_for_volcanos:
        idx_gd, v_gd, _, _ = get_stats_from_dataset(test,data_gd,'quantile')
        idx_bmi2, v_bmi2, _, _ = get_stats_from_dataset(test,data_hnm4,'quantile')
        # Timepoints present in both cohorts and inside the current stage window.
        indices = idx_gd.intersection(idx_bmi2)
        indices = indices[(indices < upper_bound) & (indices >= lower_bound)] 
        # NOTE(review): n_vals and sum_ranks are not used anywhere in this cell -- they look
        # like leftovers from a rank-based test; confirm before removing.
        n_vals = len(indices)
        sum_ranks = n_vals * (n_vals + 1) / 2
        t_stats = ttest_rel(v_gd[indices], v_bmi2[indices])
        p_vals.append(t_stats.pvalue)
        # Cohen's d for paired samples: t / sqrt(number of pairs)
        effect_size = t_stats.statistic / len(indices) ** 0.5
        # NOTE(review): debug output; "BPd" is used as a short name elsewhere in this
        # notebook, so this comparison against a test key may never be true -- confirm.
        if test == "BPd":
            print(f"{(v_gd[indices]- v_bmi2[indices]).abs().mean() :.3f}")
        effect_sizes.append(effect_size)
    # -log10 of the FDR-adjusted p-values.
    l_p_vals = -np.log10(np.array(false_discovery_control(p_vals)))
    effect_sizes = np.array(effect_sizes)
    negative_effect = np.argwhere(effect_sizes < 0)
    positive_effect = np.argwhere(effect_sizes > 0)
    # Highlighted tests: |d| >= 2.93 and -log10(adjusted p) >= 1.31.
    annotated_tests_inds = np.argwhere((np.abs(effect_sizes) >= 2.93) & (l_p_vals >= 1.31)).squeeze().reshape(-1)
    negative_test_inds = np.intersect1d(annotated_tests_inds, negative_effect)
    positive_test_inds = np.intersect1d(annotated_tests_inds, positive_effect)
    colors = np.full(len(tests_for_volcanos), "grey")
    colors[annotated_tests_inds] = "r"
    ax.scatter(effect_sizes, l_p_vals, marker="o", s=10, c=colors)
    ax.set_ylim(0)
    # Make the x axis symmetric around zero.
    xmin, xmax = ax.get_xlim()
    ax.set_xlim(-max(abs(xmin), abs(xmax)), max(abs(xmin), abs(xmax)))
    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)
    # ax.set_title(titles[i], fontsize=16)
    # The next stage starts where this one ends.
    lower_bound = upper_bound
    for ind in annotated_tests_inds:
        annotations.append(ax.text(s=metadata.loc[tests_for_volcanos[ind]]["Short name"], x=effect_sizes[ind], y=l_p_vals[ind], fontfamily='sans-serif', fontsize=14))
    adjust_text(annotations, arrowprops=dict(linestyle="--", color='k', lw=0.5), force_text=(1., 1.), ax=ax)
    # Label each highlighted test with the direction of its effect.
    gd_sig_tests = {k: "↓" for k in metadata.loc[np.array(tests_for_volcanos)[negative_test_inds]]["Short name"]}
    gd_sig_tests.update({k: "↑" for k in metadata.loc[np.array(tests_for_volcanos)[positive_test_inds]]["Short name"]})
    gd_sig_tests = set(f"{v}{k}" for k,v in gd_sig_tests.items())
    print(f"Total significant tests: {len(gd_sig_tests)}/{len(tests_for_volcanos)}. Tests: {', '.join(sorted(gd_sig_tests))}")
    gd_cohensd_sets.append(gd_sig_tests)
    all_effect_sizes.append(effect_sizes)
    all_l_p_vals.append(l_p_vals)
    ax.tick_params(labelsize=12)
    ax.annotate(panel_titles[i], xycoords='axes fraction', xy=(-0.1,1.1), ha="left", fontsize=16, color="k", fontweight='bold')

#plt.suptitle("Volcano plot: Lab tests for GD vs BMI 25-30, cohen's d")

# Stack the per-stage results into [stage, test] arrays for the summary below.
all_effect_sizes = np.array(all_effect_sizes)
all_l_p_vals = np.array(all_l_p_vals)
print(find_min_lim_gd(all_effect_sizes, all_l_p_vals, tests_for_volcanos))
      
# Venn diagram of the highlighted tests per stage (gestational-diabetes panel).
venn_ax = axes_cohensd_venn[2]
A, B, C = (gd_cohensd_sets[stage] for stage in range(3))

# Region id -> members; the insertion order is the order venn3 expects for `subsets`.
venn_labels = {
    '100': A - B - C,
    '010': B - A - C,
    '110': A & B - C,
    '001': C - A - B,
    '101': A & C - B,
    '011': B & C - A,
    '111': A & B & C,
}
diagram = venn3(
    subsets=tuple(len(members) for members in venn_labels.values()),
    set_labels=titles,
    ax=venn_ax,
    alpha=0.5,
    set_colors=('skyblue', 'lightgreen', 'lightcoral'),
)

# Show the member names instead of the counts in every region that is drawn.
for k, members in venn_labels.items():
    label = diagram.get_label_by_id(k)
    if label is None:
        continue
    label.set(text="\n".join(members), fontsize=14)

# Emphasise the three set titles.
set_annotations = [child for child in venn_ax.get_children()
                   if isinstance(child, plt.Text) and child.get_text() in titles]
for set_annotation in set_annotations:
    set_annotation.set_size(16)
    set_annotation.set_weight("bold")

venn_ax.annotate('L', xycoords='axes fraction', xy=(0,1.05), ha="left", fontsize=16, color="k", fontweight='bold')


In [None]:
# Final spacing of the 3x3 volcano grid, then export as SVG with a tight bounding box.
#fig_cohensd.set_size_inches(18,12)
fig_cohensd.subplots_adjust(hspace=0.4, wspace=0.2)
fig_cohensd.savefig("volcanos_cohensd.svg", bbox_inches='tight')

In [None]:
# Same spacing for the row of Venn diagrams; saved without a tight bounding box.
fig_cohensd_venn.subplots_adjust(hspace=0.4, wspace=0.2)
fig_cohensd_venn.savefig("venn_cohensd.svg")

## Premature figures similar to figure 5: value at quantiles and quantiles, in healthy vs complication

In [None]:
# Bucket the name-ordered tests by their metadata group (groups appear in first-seen order).
group_to_test = defaultdict(list)
for test in tests_ordered_short_name:
    group = metadata.loc[test]["Group"]
    group_to_test[group].append(test)

# Layout constants (inches).
subfig_ncols = 6
subplot_inches = 2.
base_height = 0.8
base_width = 0.5

# Rows per group = ceil(n_tests / ncols), computed in integer arithmetic.
rows_per_group = np.array([(len(members) - 1) // subfig_ncols + 1 for members in group_to_test.values()])

subfig_heights = rows_per_group * subplot_inches + base_height
fig_height = subfig_heights.sum()

fig = plt.figure(figsize=(subfig_ncols * (subplot_inches + base_height) + base_width , fig_height))
# One sub-figure per group, with height proportional to its number of rows.
subfigs = fig.subfigures(nrows=len(group_to_test.keys()), ncols=1, squeeze=True, height_ratios=subfig_heights / fig_height)
i = -1

for i, group in enumerate(group_to_test.keys()):
    # Tests of this group, in plotting order; has to have length greater than 0.
    subfig_tests = tuple(name for name in tests_ordered_short_name if metadata.loc[name]["Group"] == group)
    subfig = subfigs[i]
    subfig.suptitle(f"{group}:", fontsize=12, x=0, ha="left")
    # Subtracting 1 before the division avoids an empty extra row when the count divides evenly.
    subplot_nrows = (len(subfig_tests) - 1)//subfig_ncols + 1
    axes = subfig.subplots(nrows=subplot_nrows, ncols=subfig_ncols)
    flat_axes = axes.flatten()
    for j, test in enumerate(subfig_tests):
        ax = flat_axes[j]
        # Complication cohort in red, reference cohort in black on the same axes.
        plot_qunatile_change(ax, test, data_ecl, c='r', decorate=True)
        plot_qunatile_change(ax,test,data_hnm4,c='k', decorate=False, alpha=0.5)
        ax.grid(False)
        # Keep x ticks and label only on the last subfig_ncols panels of the group.
        if j < len(subfig_tests) - subfig_ncols:
            ax.set_xticks([])
            ax.set_xlabel('')
    # Drop the unused axes at the end of the last row.
    axes_to_remove = flat_axes[j + 1:]
    for ax in axes_to_remove:
        subfig.delaxes(ax)
fig.legend(['Pregnancy','Pre-Ecl','Healthy'],loc='upper left', bbox_to_anchor=(0, 0))
plt.savefig('ECL_vs_healthy_value_at_quantile.pdf',bbox_inches = 'tight',pad_inches=0.1);
plt.show()

In [None]:
# Pre-eclampsia vs. healthy cohort, quantiles (use_quantiles=True): one sub-figure per test
# group, each holding a grid of `subfig_ncols` columns with one panel per test.
group_to_test = defaultdict(list)
for test in tests_ordered_short_name:
    group_to_test[metadata.loc[test]["Group"]].append(test)

subfig_ncols = 6
subplot_inches = 2.
base_height = 0.8
base_width = 0.5

# Rows needed per group: ceil(n_tests / subfig_ncols) in integer arithmetic.
# Subtracting 1 before the floor division avoids a blank row when n_tests is an exact multiple.
rows_per_group = np.array([(len(tests) - 1) // subfig_ncols + 1 for tests in group_to_test.values()])
subfig_heights = rows_per_group * subplot_inches + base_height
fig_height = subfig_heights.sum()

fig = plt.figure(figsize=(subfig_ncols * (subplot_inches + base_height) + base_width, fig_height))
# squeeze=False always yields a 2-D array of sub-figures, so taking column 0 also works when
# there is a single group (squeeze=True would return a bare SubFigure that cannot be indexed).
subfigs = fig.subfigures(nrows=len(group_to_test), ncols=1, squeeze=False,
                         height_ratios=subfig_heights / fig_height)[:, 0]

# group_to_test already lists each group's tests in tests_ordered_short_name order,
# so the metadata does not need to be filtered again for every group.
for subfig, (group, subfig_tests) in zip(subfigs, group_to_test.items()):
    subfig.suptitle(f"{group}:", fontsize=12, x=0, ha="left")
    subplot_nrows = (len(subfig_tests) - 1) // subfig_ncols + 1
    axes = subfig.subplots(nrows=subplot_nrows, ncols=subfig_ncols).flatten()
    for j, (ax, test) in enumerate(zip(axes, subfig_tests)):
        plot_qunatile_change(ax, test, data_ecl, c='r', decorate=True, use_quantiles=True)
        #plot_qunatile_change(ax,test,data_bmi1,c='grey', decorate=False, alpha=0.5, use_quantiles=True)
        plot_qunatile_change(ax, test, data_hnm4, c='k', decorate=False, alpha=0.5, use_quantiles=True)
        ax.grid(False)
        # Panels that have another panel directly below them lose their x ticks and label.
        if j < len(subfig_tests) - subfig_ncols:
            ax.set_xticks([])
            ax.set_xlabel('')
        # if j % subfig_ncols != 0:
        #     ax.set_ylabel('')
    # Drop the unused axes at the end of the last row.
    for ax in axes[len(subfig_tests):]:
        subfig.delaxes(ax)
# NOTE(review): the labels are paired with artists by drawing order, so they depend on
# what plot_qunatile_change draws first - confirm if that helper changes.
fig.legend(['Pregnancy','Pre-Ecl','Healthy'],loc='upper left', bbox_to_anchor=(0, 0))
fig.savefig('ECL_vs_healthy_quantiles.pdf', bbox_inches='tight', pad_inches=0.1)
plt.show()

In [None]:
# Postpartum hemorrhage vs. healthy cohort, value at quantile: one sub-figure per test group,
# each holding a grid of `subfig_ncols` columns with one panel per test.
group_to_test = defaultdict(list)
for test in tests_ordered_short_name:
    group_to_test[metadata.loc[test]["Group"]].append(test)

subfig_ncols = 6
subplot_inches = 2.
base_height = 0.8
base_width = 0.5

# Rows needed per group: ceil(n_tests / subfig_ncols) in integer arithmetic.
# Subtracting 1 before the floor division avoids a blank row when n_tests is an exact multiple.
rows_per_group = np.array([(len(tests) - 1) // subfig_ncols + 1 for tests in group_to_test.values()])
subfig_heights = rows_per_group * subplot_inches + base_height
fig_height = subfig_heights.sum()

fig = plt.figure(figsize=(subfig_ncols * (subplot_inches + base_height) + base_width, fig_height))
# squeeze=False always yields a 2-D array of sub-figures, so taking column 0 also works when
# there is a single group (squeeze=True would return a bare SubFigure that cannot be indexed).
subfigs = fig.subfigures(nrows=len(group_to_test), ncols=1, squeeze=False,
                         height_ratios=subfig_heights / fig_height)[:, 0]

# group_to_test already lists each group's tests in tests_ordered_short_name order,
# so the metadata does not need to be filtered again for every group.
for subfig, (group, subfig_tests) in zip(subfigs, group_to_test.items()):
    subfig.suptitle(f"{group}:", fontsize=12, x=0, ha="left")
    subplot_nrows = (len(subfig_tests) - 1) // subfig_ncols + 1
    axes = subfig.subplots(nrows=subplot_nrows, ncols=subfig_ncols).flatten()
    for j, (ax, test) in enumerate(zip(axes, subfig_tests)):
        plot_qunatile_change(ax, test, data_pph, c='r', decorate=True)
        #plot_qunatile_change(ax,test,data_bmi1,c='grey', decorate=False, alpha=0.5)
        plot_qunatile_change(ax, test, data_hnm4, c='k', decorate=False, alpha=0.5)
        ax.grid(False)
        # Panels that have another panel directly below them lose their x ticks and label.
        if j < len(subfig_tests) - subfig_ncols:
            ax.set_xticks([])
            ax.set_xlabel('')
        # if j % subfig_ncols != 0:
        #     ax.set_ylabel('')
    # Drop the unused axes at the end of the last row.
    for ax in axes[len(subfig_tests):]:
        subfig.delaxes(ax)
# NOTE(review): the labels are paired with artists by drawing order, so they depend on
# what plot_qunatile_change draws first - confirm if that helper changes.
fig.legend(['Pregnancy','PPH','Healthy'],loc='upper left', bbox_to_anchor=(0, 0))
fig.savefig('PPH_vs_healthy_value_at_quantile.pdf', bbox_inches='tight', pad_inches=0.1)
plt.show()

In [None]:
# Postpartum hemorrhage vs. healthy cohort, quantiles (use_quantiles=True): one sub-figure per
# test group, each holding a grid of `subfig_ncols` columns with one panel per test.
group_to_test = defaultdict(list)
for test in tests_ordered_short_name:
    group_to_test[metadata.loc[test]["Group"]].append(test)

subfig_ncols = 6
subplot_inches = 2.
base_height = 0.8
base_width = 0.5

# Rows needed per group: ceil(n_tests / subfig_ncols) in integer arithmetic.
# Subtracting 1 before the floor division avoids a blank row when n_tests is an exact multiple.
rows_per_group = np.array([(len(tests) - 1) // subfig_ncols + 1 for tests in group_to_test.values()])
subfig_heights = rows_per_group * subplot_inches + base_height
fig_height = subfig_heights.sum()

fig = plt.figure(figsize=(subfig_ncols * (subplot_inches + base_height) + base_width, fig_height))
# squeeze=False always yields a 2-D array of sub-figures, so taking column 0 also works when
# there is a single group (squeeze=True would return a bare SubFigure that cannot be indexed).
subfigs = fig.subfigures(nrows=len(group_to_test), ncols=1, squeeze=False,
                         height_ratios=subfig_heights / fig_height)[:, 0]

# group_to_test already lists each group's tests in tests_ordered_short_name order,
# so the metadata does not need to be filtered again for every group.
for subfig, (group, subfig_tests) in zip(subfigs, group_to_test.items()):
    subfig.suptitle(f"{group}:", fontsize=12, x=0, ha="left")
    subplot_nrows = (len(subfig_tests) - 1) // subfig_ncols + 1
    axes = subfig.subplots(nrows=subplot_nrows, ncols=subfig_ncols).flatten()
    for j, (ax, test) in enumerate(zip(axes, subfig_tests)):
        plot_qunatile_change(ax, test, data_pph, c='r', decorate=True, use_quantiles=True)
        #plot_qunatile_change(ax,test,data_bmi1,c='grey', decorate=False, alpha=0.5, use_quantiles=True)
        plot_qunatile_change(ax, test, data_hnm4, c='k', decorate=False, alpha=0.5, use_quantiles=True)
        ax.grid(False)
        # Panels that have another panel directly below them lose their x ticks and label.
        if j < len(subfig_tests) - subfig_ncols:
            ax.set_xticks([])
            ax.set_xlabel('')
        # if j % subfig_ncols != 0:
        #     ax.set_ylabel('')
    # Drop the unused axes at the end of the last row.
    for ax in axes[len(subfig_tests):]:
        subfig.delaxes(ax)
# NOTE(review): the labels are paired with artists by drawing order, so they depend on
# what plot_qunatile_change draws first - confirm if that helper changes.
fig.legend(['Pregnancy','PPH','Healthy'],loc='upper left', bbox_to_anchor=(0, 0))
fig.savefig('PPH_vs_healthy_quantiles.pdf', bbox_inches='tight', pad_inches=0.1)
plt.show()

In [None]:
# Gestational diabetes vs. healthy cohort, value at quantile: one sub-figure per test group,
# each holding a grid of `subfig_ncols` columns with one panel per test.
group_to_test = defaultdict(list)
for test in tests_ordered_short_name:
    group_to_test[metadata.loc[test]["Group"]].append(test)

subfig_ncols = 6
subplot_inches = 2.
base_height = 0.8
base_width = 0.5

# Rows needed per group: ceil(n_tests / subfig_ncols) in integer arithmetic.
# Subtracting 1 before the floor division avoids a blank row when n_tests is an exact multiple.
rows_per_group = np.array([(len(tests) - 1) // subfig_ncols + 1 for tests in group_to_test.values()])
subfig_heights = rows_per_group * subplot_inches + base_height
fig_height = subfig_heights.sum()

fig = plt.figure(figsize=(subfig_ncols * (subplot_inches + base_height) + base_width, fig_height))
# squeeze=False always yields a 2-D array of sub-figures, so taking column 0 also works when
# there is a single group (squeeze=True would return a bare SubFigure that cannot be indexed).
subfigs = fig.subfigures(nrows=len(group_to_test), ncols=1, squeeze=False,
                         height_ratios=subfig_heights / fig_height)[:, 0]

# group_to_test already lists each group's tests in tests_ordered_short_name order,
# so the metadata does not need to be filtered again for every group.
for subfig, (group, subfig_tests) in zip(subfigs, group_to_test.items()):
    subfig.suptitle(f"{group}:", fontsize=12, x=0, ha="left")
    subplot_nrows = (len(subfig_tests) - 1) // subfig_ncols + 1
    axes = subfig.subplots(nrows=subplot_nrows, ncols=subfig_ncols).flatten()
    for j, (ax, test) in enumerate(zip(axes, subfig_tests)):
        plot_qunatile_change(ax, test, data_gd, c='r', decorate=True)
        #plot_qunatile_change(ax,test,data_bmi1,c='grey', decorate=False, alpha=0.5)
        plot_qunatile_change(ax, test, data_hnm4, c='k', decorate=False, alpha=0.5)
        ax.grid(False)
        # Panels that have another panel directly below them lose their x ticks and label.
        if j < len(subfig_tests) - subfig_ncols:
            ax.set_xticks([])
            ax.set_xlabel('')
        # if j % subfig_ncols != 0:
        #     ax.set_ylabel('')
    # Drop the unused axes at the end of the last row.
    for ax in axes[len(subfig_tests):]:
        subfig.delaxes(ax)
# NOTE(review): the labels are paired with artists by drawing order, so they depend on
# what plot_qunatile_change draws first - confirm if that helper changes.
fig.legend(['Pregnancy','GD','Healthy'],loc='upper left', bbox_to_anchor=(0, 0))
fig.savefig('GD_vs_healthy_value_at_quantile.pdf', bbox_inches='tight', pad_inches=0.1)
plt.show()

In [None]:
# Gestational diabetes vs. healthy cohort, quantiles (use_quantiles=True): one sub-figure per
# test group, each holding a grid of `subfig_ncols` columns with one panel per test.
group_to_test = defaultdict(list)
for test in tests_ordered_short_name:
    group_to_test[metadata.loc[test]["Group"]].append(test)

subfig_ncols = 6
subplot_inches = 2.
base_height = 0.8
base_width = 0.5

# Rows needed per group: ceil(n_tests / subfig_ncols) in integer arithmetic.
# Subtracting 1 before the floor division avoids a blank row when n_tests is an exact multiple.
rows_per_group = np.array([(len(tests) - 1) // subfig_ncols + 1 for tests in group_to_test.values()])
subfig_heights = rows_per_group * subplot_inches + base_height
fig_height = subfig_heights.sum()

fig = plt.figure(figsize=(subfig_ncols * (subplot_inches + base_height) + base_width, fig_height))
# squeeze=False always yields a 2-D array of sub-figures, so taking column 0 also works when
# there is a single group (squeeze=True would return a bare SubFigure that cannot be indexed).
subfigs = fig.subfigures(nrows=len(group_to_test), ncols=1, squeeze=False,
                         height_ratios=subfig_heights / fig_height)[:, 0]

# group_to_test already lists each group's tests in tests_ordered_short_name order,
# so the metadata does not need to be filtered again for every group.
for subfig, (group, subfig_tests) in zip(subfigs, group_to_test.items()):
    subfig.suptitle(f"{group}:", fontsize=12, x=0, ha="left")
    subplot_nrows = (len(subfig_tests) - 1) // subfig_ncols + 1
    axes = subfig.subplots(nrows=subplot_nrows, ncols=subfig_ncols).flatten()
    for j, (ax, test) in enumerate(zip(axes, subfig_tests)):
        plot_qunatile_change(ax, test, data_gd, c='r', decorate=True, use_quantiles=True)
        #plot_qunatile_change(ax,test,data_bmi1,c='grey', decorate=False, alpha=0.5, use_quantiles=True)
        plot_qunatile_change(ax, test, data_hnm4, c='k', decorate=False, alpha=0.5, use_quantiles=True)
        ax.grid(False)
        # Panels that have another panel directly below them lose their x ticks and label.
        if j < len(subfig_tests) - subfig_ncols:
            ax.set_xticks([])
            ax.set_xlabel('')
        # if j % subfig_ncols != 0:
        #     ax.set_ylabel('')
    # Drop the unused axes at the end of the last row.
    for ax in axes[len(subfig_tests):]:
        subfig.delaxes(ax)
# NOTE(review): the labels are paired with artists by drawing order, so they depend on
# what plot_qunatile_change draws first - confirm if that helper changes.
fig.legend(['Pregnancy','GD','Healthy'],loc='upper left', bbox_to_anchor=(0, 0))
fig.savefig('GD_vs_healthy_quantiles.pdf', bbox_inches='tight', pad_inches=0.1)
plt.show()

## Venn intersecting the 3 periods (preconception, gestation, postpartum). See above for the Venn diagram intersecting complications per period

In [None]:
# One Venn diagram per period (one axes each), intersecting the Pre-ECL, PPH and GD sets.
fig, axs = plt.subplots(ncols=3, figsize=(21,5))
titles = ["Pre-conception", "Gestation", "Postpartum"]
set_labels = ["Pre-ECL", "PPH", "GD"]
for i, ax in enumerate(axs):
    A = pre_ecl_cohensd_sets[i]
    B = pph_cohensd_sets[i]
    C = gd_cohensd_sets[i]

    # Region id -> members. The ids follow the venn3 convention ('100' = only A, '110' = A and B
    # but not C, ...), and the insertion order below is the order venn3 expects for `subsets`.
    venn_labels = {'100': A - B - C, '010': B - A - C, '110': (A & B) - C,
                   '001': C - A - B, '101': (A & C) - B, '011': (B & C) - A, '111': A & B & C}
    # Region sizes come from the same dict, and the set names from `set_labels`, so the
    # diagram and the title look-up further down cannot drift apart.
    diagram = venn3(subsets=tuple(len(members) for members in venn_labels.values()),
                    set_labels=tuple(set_labels),
                    alpha=0.5,
                    set_colors=('navy', 'darkviolet', 'chocolate'),
                    ax=ax)

    # Replace the default count in every region with the names of its members.
    for region_id, members in venn_labels.items():
        label = diagram.get_label_by_id(region_id)
        # venn3 returns None for a region that is not drawn.
        if label is not None:
            # Sets have no stable iteration order; sort so the figure is identical on every run.
            label.set(text="\n".join(sorted(members)), size=11)

    # Enlarge the set names (the Text artists whose text is one of `set_labels`).
    set_annotations = [obj for obj in ax.get_children() if isinstance(obj, plt.Text) and obj.get_text() in set_labels]
    for set_annotation in set_annotations:
        set_annotation.set_size(14)

    ax.set_title(titles[i], fontsize=16)

#plt.suptitle("Significant tests between PPH,GD,pre-Eclampsia by period, cohen's d", fontsize=12)
#plt.savefig("venn_intersection_cohensd.pdf")
plt.show()

## Figure 6(6S) - compare all value-at-quantiles (quantiles) for all complications and healthy

In [None]:
def format_y_axis(axes):
    """Reduce the y axis to two ticks with compact, three-significant-digit labels.

    The ticks sit at 20% and 80% of the current y limits. When the upper tick is at
    most 0.1 the labels are multiplied by 100 and a ``x10^-2`` marker is annotated at
    the top-left corner of the axes; when it is at most 1 the labels are multiplied
    by 10 with a ``x10^-1`` marker; otherwise the tick values are shown unscaled and
    no marker is added. The y tick label size is set to 14.

    Parameters
    ----------
    axes : object providing ``get_ylim``, ``annotate``, ``set_yticks`` and ``tick_params``
        The axes to format (modified in place).
    """
    lower, upper = axes.get_ylim()
    span = upper - lower
    min_y_tick = span * 0.2 + lower
    max_y_tick = span * 0.8 + lower

    # Pick the label multiplier together with the exponent marker that explains it.
    if max_y_tick <= 0.1:
        multiplier, marker = 100, r"$\times 10^{-2}$"
    elif max_y_tick <= 1:
        multiplier, marker = 10, r"$\times 10^{-1}$"
    else:
        multiplier, marker = None, None

    if marker is None:
        values_for_label = (min_y_tick, max_y_tick)
    else:
        axes.annotate(marker, xycoords='axes fraction', xy=(0,1), ha="right", fontsize=12, color="k")
        values_for_label = (multiplier * min_y_tick, multiplier * max_y_tick)

    # '#.3g' keeps trailing zeros, so every label shows three significant digits.
    tick_labels = [f"{value:#.3g}" for value in values_for_label]
    axes.set_yticks([min_y_tick, max_y_tick], tick_labels)
    axes.tick_params(axis='y', labelsize=14)

def f(use_cohensd, use_quantiles, is_pdf=False):
    """Draw one panel per test, overlaying the three complication cohorts and the healthy cohort.

    Tests come from ``tests_for_volcanos``, sorted by their 'Nice name' and arranged
    group by group (in a fixed group order) on a 7-column grid. Each panel receives
    the pre-eclampsia, PPH, GD and healthy data via ``plot_qunatile_change``, and is
    tagged with "Pe" / "PH" / "GD" when the test's short name appears in the
    corresponding significance sets of any period.

    Parameters
    ----------
    use_cohensd : bool
        If true, use the ``*_cohensd_sets`` significance sets; otherwise the ``*_fc`` ones.
    use_quantiles : bool
        Forwarded to ``plot_qunatile_change`` as ``use_quantiles``.
    is_pdf : bool, optional
        Only selects the file suffix for the savefig call, which is currently commented out.
    """
    # Only the selected branch is evaluated, so the other family of sets need not exist.
    effect_size_type, ecl, gd, pph = (
        ("cohens_d", pre_ecl_cohensd_sets, gd_cohensd_sets, pph_cohensd_sets)
        if use_cohensd
        else ("mean_diff", pre_ecl_fc, gd_fc, pph_fc)
    )
    plt.rcParams.update({'font.size': 14})

    def union_without_arrow(period_sets):
        # Union all periods, then drop the first character of every entry
        # (the arrow marking enrichment or depletion).
        return {entry[1:] for entry in reduce(lambda left, right: left | right, period_sets)}

    pre_ecl_all_time_set = union_without_arrow(ecl)
    pph_all_time_set = union_without_arrow(pph)
    gd_all_time_set = union_without_arrow(gd)

    # Bucket the tests (alphabetical by nice name) per group ...
    tests_by_group = defaultdict(list)
    for test in sorted(tests_for_volcanos, key=lambda test_name: metadata.loc[test_name]['Nice name']):
        tests_by_group[metadata.loc[test]["Group"]].append(test)
    # ... and keep only the groups below, in this display order (a missing group yields no panels).
    group_order = ["Renal",  "Metabolism", "RBCs", "Coagulation", "Immune", "Liver", "Musculoskeletal", "Endocrine"]
    group_to_test = {group: tests_by_group[group] for group in group_order}

    ncols = 7
    subplot_inches = 3
    total_tests = sum(len(group_tests) for group_tests in group_to_test.values())
    nrows = ceil(total_tests / ncols)
    fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=(subplot_inches * ncols, subplot_inches * nrows))
    fig.set_facecolor("none")

    flat_axes = axs.flatten()
    panels = [(group, test) for group, group_tests in group_to_test.items() for test in group_tests]
    header_background = (*mcolors.to_rgb(mcolors.CSS4_COLORS["aliceblue"]), 1.0)
    cur_group = ""

    for i, (ax, (group, test)) in enumerate(zip(flat_axes, panels)):
        # The first panel of every group carries the group header.
        if cur_group != group:
            ax.annotate(f"{group} ➡", fontsize=17, xycoords='axes fraction', xy=(0.,1.29), ha="left", fontweight="bold", backgroundcolor=header_background)
            cur_group = group

        plot_qunatile_change(ax, test, data_ecl, c='navy', decorate=True, fontsize=14, use_quantiles=use_quantiles)
        plot_qunatile_change(ax, test, data_pph, c='darkviolet', decorate=False, use_quantiles=use_quantiles)
        plot_qunatile_change(ax, test, data_gd, c='chocolate', decorate=False, use_quantiles=use_quantiles)
        plot_qunatile_change(ax, test, data_hnm4, c='limegreen', decorate=False, alpha=0.5, use_quantiles=use_quantiles)

        ax.spines.right.set_visible(False)
        ax.spines.top.set_visible(False)
        ax.grid(False)
        format_y_axis(ax)
        ax.set_facecolor("none")

        # Significance tags: each one is anchored to the top-left of the previous annotation.
        short_name = metadata.loc[test, "Short name"]
        pre_ecl_text = "Pe" if short_name in pre_ecl_all_time_set else ""
        pph_text = "PH" if short_name in pph_all_time_set else ""
        gd_text = "GD" if short_name in gd_all_time_set else ""
        pre_ecl_tag = ax.annotate(pre_ecl_text, (1.0,1.05), xycoords='axes fraction', fontsize=12, color='navy')
        pph_tag = ax.annotate(pph_text, (0,1), xycoords=pre_ecl_tag, fontsize=12, color='darkviolet', verticalalignment="bottom")
        ax.annotate(gd_text, (0,1), xycoords=pph_tag, fontsize=12, color='chocolate', verticalalignment="bottom")

        # Only the last `ncols` panels keep their x ticks and label.
        if total_tests - i > ncols:
            ax.set_xlabel('')
            ax.set_xticks([])
        else:
            ax.tick_params(axis='x', labelsize=12)
            ax.set_xlabel(ax.get_xlabel(), fontsize=12)
        ax.set_box_aspect(1)

    # Remove the grid cells left over after the last test.
    for ax in flat_axes[total_tests:]:
        fig.delaxes(ax)

    fig.legend(['Pregnancy','Pre-eclampsia','Postpartum Hemorrhage', 'Gestational Diabetes', 'Healthy'],loc='lower right', bbox_to_anchor=(1, 0))
    #fig.subplots_adjust(wspace=0.1, hspace=0.1)
    fig.set_layout_engine("tight")
    # effect_size_type, data_type and suffix only feed the (disabled) savefig call below.
    data_type = "quantiles" if use_quantiles else "value_at_quantile"
    suffix = "pdf" if is_pdf else "svg"
    #plt.savefig(f'significant_{data_type}_{effect_size_type}_effect_size_tight.{suffix}',bbox_inches = 'tight')
    plt.show()

In [None]:
# Cohen's d significance tags: first the quantile view, then the value-at-quantile view.
f(use_cohensd=True, use_quantiles=True)
f(use_cohensd=True, use_quantiles=False)
# Other variants, currently disabled:
# f(True, False, True)
# f(False, True)
# f(False, False)
# f(False, False, True)