# Load Libraries

In [None]:
# Data manipulation
import numpy as np
import pandas as pd

# Data science
import math
import scipy.stats as stats
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from statsmodels.stats.multitest import multipletests as mt

# Plots
import matplotlib as mpl
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import ticker, gridspec
from matplotlib import font_manager as fm


# Working with dates
from datetime import date,datetime
import dateutil

# Looping  progress
from tqdm.notebook import tqdm

# Reg expressions
import re

# Pretty table printing
import tabulate

import os
import subprocess

# Misc libraries
from IPython.display import display, HTML
#from IPython.core.display import display, HTML

# Seaborn theme: figure size and font scale applied in one call, then the
# "white" axes style layered on top.
sns.set(font_scale=1.5, rc={'figure.figsize': (11.7, 8.27)})
sns.set_style("white")

# Lift the pandas display limits so whole dataframes are shown
pd.set_option('display.max_rows', 10000)
pd.set_option('display.max_columns', 10000)
pd.set_option('display.max_colwidth', None)

# Report the pandas version in use so the environment can be checked
# against the versions this notebook was written for.
print(f"Pandas version: {pd.__version__}")


# Specify the directory where custom fonts are stored
# NOTE(review): absolute, user-specific path; on other machines the directory
# will probably not exist and no extra fonts will be registered.
font_dir = '/users/lapt3u/.fonts'

# Add fonts from the specified directory to Matplotlib's font manager
font_files = fm.findSystemFonts(fontpaths=[font_dir])
for font_file in font_files:
    fm.fontManager.addfont(font_file)

# Set the default font family to Arial only if a font with that name is
# actually registered; otherwise keep Matplotlib's default family instead of
# triggering a "font family not found" warning on every text draw.
if any(font.name == 'Arial' for font in fm.fontManager.ttflist):
    plt.rcParams['font.family'] = 'Arial'
else:
    print(f"Arial not found (searched {font_dir}); keeping the default font family")

print(plt.rcParams['font.family'])

# Prep Environment

In [None]:
# Project root; every relative path used below is resolved from its
# 'results' sub-directory, which becomes the working directory.
HOME_DIR = "/data/pathogen_ncd"
RESULTS_DIR = os.path.join(HOME_DIR, 'results')
os.chdir(RESULTS_DIR)

# Data Prep

In [None]:
all_res = pd.read_excel('final_res_with_merged_data_04_12_2023.xlsx')

# Partition the results by replication status: rows flagged 'replicated'
# and rows flagged 'could_not'.
rep_status = all_res['rep_stat']
res = all_res[rep_status.eq('replicated')]
cnr = all_res[rep_status.eq('could_not')]

In [None]:
# Working table for the heatmap: replicated results only, restricted to the
# columns the plot needs. Row filter and column selection are done in one
# step and the result is an explicit copy, because later cells add columns to
# `lw` and must not be writing into a slice of `all_res`.
lw = all_res.loc[all_res['rep_stat'] == 'replicated',
                 ['icd', 'disease', 'org', 'ukb_per_dis_bh_fdr_corr_nom_p',
                  'ukb_OR',
                  'std_lev', 'rep_stat']].copy()

In [None]:
# Same column subset as the replicated table, taken from the
# 'could_not' rows.
lcnr_cols = ['icd', 'disease', 'org', 'ukb_per_dis_bh_fdr_corr_nom_p',
             'ukb_OR', 'std_lev', 'rep_stat']
lcnr = cnr[lcnr_cols]

## Convert std lev to int for plotting

In [None]:
# Encode the standard level as an integer code:
#   0 = 'Tier 1', 1 = 'Tier 2', 2 = 'exp_neg', 3 = 'unk'
cond_list = [lw['std_lev'] == 'Tier 1', 
             lw['std_lev']  == 'Tier 2', 
             lw['std_lev'] == 'exp_neg',
             lw['std_lev']  == 'unk']
choices = [0, 1, 2, 3]

# np.select falls back to 0 when no condition matches, which would silently
# label missing or unexpected standard levels as 'Tier 1'. Send them to the
# 'unk' code (3) instead.
lw.loc[:, 'std_lev_int'] = np.select(cond_list, choices, default = 3)

## Generate -log10 of the UKB FDR-adjusted P

In [None]:
# neg_log_ukb = -log10 of the UKB per-disease BH/FDR-corrected p-value column
ukb_fdr_p = lw['ukb_per_dis_bh_fdr_corr_nom_p']
lw['neg_log_ukb'] = np.negative(np.log10(ukb_fdr_p))

## Get Abbreviated Disease Names

In [None]:
dis_dat = pd.read_excel('../dicts/dis_abbrev_dict.xlsx')

# Only the ICD code and its abbreviation are needed from the dictionary
abbrev_lookup = dis_dat[['icd', 'dis_abbrev']]

# Replicated results: attach the abbreviation, then build the
# "[ICD] abbreviation" label.
lw = lw.merge(abbrev_lookup, on = 'icd', how = 'left')
lw.loc[:, 'dis_abbrev_lab'] = "[" + lw['icd'] + "] " + lw['dis_abbrev']

# Same treatment for the 'could_not' results
lcnr = lcnr.merge(abbrev_lookup, on = 'icd', how = 'left')
lcnr.loc[:, 'dis_abbrev_lab'] = "[" + lcnr['icd'] + "] " + lcnr['dis_abbrev']

## Get nice pathogen names

In [None]:
# Tag -> Organism lookup, one row per distinct pair
vir_dat = (
    pd.read_excel('../dicts/antigen_dict.xlsx')
      .loc[:, ['Tag', 'Organism']]
      .drop_duplicates()
)

# Attach the organism name to both result tables by matching their 'org'
# column against the dictionary's 'Tag' column.
lw = lw.merge(vir_dat, how = 'left', left_on = 'org', right_on = 'Tag')
lcnr = lcnr.merge(vir_dat, how = 'left', left_on = 'org', right_on = 'Tag')

# Side-by-side Heatmap

### UKB -Log10(FDR) Max: 
<div style="margin-left: 40px;">
2.5
</div>

### UKB OR Max:
<div style="margin-left: 40px;">     
5
</div>

## Plotting Functions

In [None]:
# Map a value (an odds ratio in this notebook) to a scatter marker size
def value_to_size(val):
    """Return the marker size for `val`, growing with its distance from 1.

    Reads the module-level `size_min`, `size_max`, `MIN_SIZE` and `MAX_SIZE`.
    Values of 1 or more are interpolated from [1, size_max] onto
    [MIN_SIZE, MAX_SIZE]; values below 1 are interpolated from [size_min, 1]
    onto [MAX_SIZE, MIN_SIZE], so smaller values get larger markers.
    NaN input gives NaN back.
    """
    if math.isnan(val):
        return np.nan

    if val >= 1:
        return np.interp(val, [1, size_max], [MIN_SIZE, MAX_SIZE])

    return np.interp(val, [size_min, 1], [MAX_SIZE, MIN_SIZE])

In [None]:
# Function to generate the histogram on the side of each heatmap counting the
# replicated results for each disease row.
def gen_side_hist_new(input_ax, dis_hist, lab_font, y_val_ls, INPUT_HEIGHT = 1, DEBUG = False,
                      X_MAX = 10, X_STEP = 2):
    """Draw a right-to-left horizontal bar chart of per-row counts.

    Parameters
    ----------
    input_ax : matplotlib Axes to draw on.
    dis_hist : table-like object; ``dis_hist['dis']`` gives the bar positions
        on the y-axis and ``dis_hist['count']`` the bar lengths.
    lab_font : fontdict used for the count labels written inside the plot.
    y_val_ls : y positions (data coordinates) at which a row of count labels
        (0, X_STEP, ..., X_MAX) is written.
    INPUT_HEIGHT : bar thickness passed to ``barh``.
    DEBUG : when True, print the position of every count label.
    X_MAX : largest count shown on the x-axis. Defaults to 10, the value that
        used to be hard-coded; raise it if any row has more results.
    X_STEP : spacing between grid lines / count labels (default 2).

    Returns
    -------
    ``input_ax``, after drawing.
    """
    # Tick positions shared by the locator, the formatter and the labels
    x_ticks = list(range(0, X_MAX + 1, X_STEP))

    # Actually plot the disease histogram along y-axis
    input_ax.barh(dis_hist['dis'], width = dis_hist['count'], height = INPUT_HEIGHT,
                  linewidth = 0)

    # Bars grow from right to left: the x-axis runs from X_MAX down to 0
    input_ax.invert_xaxis()
    input_ax.set_xlim([X_MAX, 0])

    # Hide the x-axis ticks and tick labels (the counts are written inside the
    # plot by the loop below).
    input_ax.tick_params(axis="x", bottom = False, top = False, 
                             labelbottom = False, labeltop = False,
                             labelsize = 14)

    input_ax.xaxis.set_major_locator(ticker.FixedLocator(x_ticks))
    input_ax.xaxis.set_major_formatter(ticker.FixedFormatter([str(t) for t in x_ticks]))

    # Hide the y-axis ticks and tick labels
    input_ax.tick_params(axis="y", left = False, right = False, 
                               labelleft = False, labelright = False)

    # Keep only the top and bottom spines
    input_ax.spines['top'].set_visible(True)
    input_ax.spines['bottom'].set_visible(True)
    input_ax.spines['left'].set_visible(False)
    input_ax.spines['right'].set_visible(False)

    # Faint vertical grid line at every tick
    input_ax.grid(visible = True, which = 'major', axis = 'x',
                      color = '#242424', linestyle = '-', linewidth = 1,
                      alpha = 0.05)

    # Write the count next to each grid line, once per requested y position
    for curr_x_val in x_ticks:

        for curr_y_val in y_val_ls:

            # Shift the label slightly along x relative to its grid line
            x_pos = curr_x_val + 0.35
            y_pos = curr_y_val 

            if DEBUG:
                print(f"x: {curr_x_val} -> {x_pos} | y: {curr_y_val} -> {y_pos} | str: {str(curr_x_val)}")

            input_ax.text(x = x_pos, y = y_pos, s = str(curr_x_val),
                              va = 'top', ha = 'center', 
                              zorder = 0.5, transform = input_ax.transData,
                              fontdict = lab_font, 
                              rotation = 90, rotation_mode = 'anchor', transform_rotates_text = True)

    return input_ax

In [None]:
# Function to draw the odds ratio legend or key, where the size of the triangle
# is proportional to the distance from 1.
def draw_horizon_or_legend(input_ax, risk_mark, prot_mark, 
                           RISK_Y_VALS, PROT_Y_VALS,
                           RISK_LAB_Y = 0.6,
                           PROT_LAB_Y = 0.4,
                           MARKER_X_OFFSET = 0.25,
                           X_START = 0.1, ARROW_LEN_X = 0.4, 
                           ARROW_LEN_Y = 0.05,
                           Y_CENTER = 0.5,
                           Y_OFFSET = 0.3,
                           Y_LIM = (0, 1),
                           X_LIM = (-0.1, 0.6),
                           DEBUG = False):
    """Draw the odds-ratio size key: one row of risk markers, one of protective.

    Marker sizes come from ``value_to_size`` so they match the main plot.
    The legend shows ORs 1.01, 3.0, 5 (risk row) and 0.99, 0.65, 0.30
    (protective row), left to right.

    Parameters
    ----------
    input_ax : matplotlib Axes to draw on.
    risk_mark, prot_mark : markers passed to ``scatter`` for the risk (OR > 1)
        and protective (OR < 1) rows.
    RISK_Y_VALS, PROT_Y_VALS : y positions (data coordinates) of the markers
        in each row; one entry per marker.
    RISK_LAB_Y, PROT_LAB_Y : y position (data coordinates) of the OR value
        labels for each row.
    MARKER_X_OFFSET : horizontal spacing between consecutive markers.
    X_START, ARROW_LEN_X, ARROW_LEN_Y, Y_CENTER, Y_OFFSET : geometry of the two
        arrows and their captions, in axes-fraction coordinates.
    Y_LIM, X_LIM : (min, max) axis limits applied to ``input_ax``.
    DEBUG : when True, print the marker coordinates and sizes.

    Returns
    -------
    ``input_ax``, after drawing.
    """
    from matplotlib.patches import  FancyArrowPatch as Arrow
    from matplotlib.patches import ConnectionStyle
    CON_COLOR =  '#4A6990'
    CASE_COLOR = '#bb3d2e'

    # All legend text currently shares one font; the three dicts are kept
    # separate so each can be tuned independently.
    LEGEND_FONT = {
                'family': 'Arial',
                'color':  'black',
                'weight': 'normal',
                'size': 13,  
    }
    OR_FONT = dict(LEGEND_FONT)
    OR_TITLE_FONT = dict(LEGEND_FONT)
    ARROW_TEXT_FONT = dict(LEGEND_FONT)

    #####################################
    #                                   # 
    #    Generate Odds Ratio Legend     #
    #                                   #
    #####################################
    # Define which ORs will be represented in legend
    or_ls = [0.30, 0.65, 0.99, 1.01, 3.0, 5]

    # Protective ORs ordered from closest to 1 to furthest, so both rows
    # grow in marker size from left to right.
    prot_or_ls = [x for x in or_ls if x < 1]
    prot_or_ls.reverse()

    risk_or_ls = [x for x in or_ls if x > 1]

    # Map each OR to the marker size used in the main plot
    prot_size_ls = [value_to_size(y) for y in prot_or_ls]
    risk_size_ls = [value_to_size(y) for y in risk_or_ls]

    risk_size_y_ls = RISK_Y_VALS
    prot_size_y_ls = PROT_Y_VALS

    # Shared x positions for both rows, MARKER_X_OFFSET apart starting at 0
    size_x_ls = [x * MARKER_X_OFFSET for x in range(0, len(prot_or_ls), 1)]

    size_leg_alpha = 0.75

    if DEBUG:
        print(f'[PROT] x_ls: {size_x_ls}, y_ls: {prot_size_y_ls}, s_ls: {prot_size_ls}')
        print(f'[RISK] x_ls: {size_x_ls}, y_ls: {risk_size_y_ls}, s_ls: {risk_size_ls}')

    # Outline-only markers, protective row then risk row
    input_ax.scatter(
                        x = size_x_ls,
                        y = prot_size_y_ls,
                        s = prot_size_ls,
                        facecolors = 'none', 
                        edgecolors = 'black',
                        linewidths = 2,
                        alpha = size_leg_alpha,
                        marker = prot_mark
                    )

    input_ax.scatter(
                        x = size_x_ls,
                        y = risk_size_y_ls,
                        s = risk_size_ls,
                        facecolors = 'none', 
                        edgecolors = 'black',
                        linewidths = 2,
                        alpha = size_leg_alpha,
                        marker = risk_mark
                    )

    # Blank out the tick labels on both axes
    input_ax.set_xticklabels([])
    input_ax.set_yticklabels([])

    input_ax.set_xlim(X_LIM[0], X_LIM[1])
    input_ax.set_ylim(Y_LIM[0], Y_LIM[1])

    # Remove all the outside lines of the plot (the frame)
    for curr_spine in ['top', 'bottom', 'left', 'right']:
        input_ax.spines[curr_spine].set_visible(False)

    # OR value labels: one per marker, each row at its own fixed label height
    for curr_x, curr_val, _ in zip(size_x_ls, risk_or_ls, risk_size_y_ls):
        input_ax.text(x = curr_x, y = RISK_LAB_Y, s = curr_val,
                         ha = 'center', va = 'center', fontdict = OR_FONT)  

    for curr_x, curr_val, _ in zip(size_x_ls, prot_or_ls, prot_size_y_ls):
        input_ax.text(x = curr_x, y = PROT_LAB_Y, s = curr_val,
                         ha = 'center', va = 'center', fontdict = OR_FONT) 

    # Legend title, placed in data coordinates
    input_ax.text(x = -0.25, y = 0.5, s = 'UKB Odds Ratio',
                         ha = 'center', va = 'center', fontdict = OR_TITLE_FONT)

    # Arrow geometry (axes-fraction coordinates). The risk arrow sits above
    # Y_CENTER and rises; the protective arrow mirrors it below.
    risk_ar_st_x = X_START
    risk_ar_end_x = risk_ar_st_x + ARROW_LEN_X
    risk_ar_st_y = Y_CENTER + Y_OFFSET
    risk_ar_end_y = risk_ar_st_y + ARROW_LEN_Y

    mid_x = X_START + (ARROW_LEN_X / 2)

    TEXT_OFFSET_Y = 0.1 
    risk_mid_y = risk_ar_st_y + (ARROW_LEN_Y / 2) + TEXT_OFFSET_Y

    prot_ar_st_x = X_START
    prot_ar_end_x = prot_ar_st_x + ARROW_LEN_X
    prot_ar_st_y = Y_CENTER - Y_OFFSET
    prot_ar_end_y = prot_ar_st_y - ARROW_LEN_Y
    prot_mid_y = prot_ar_st_y - (ARROW_LEN_Y / 2) - TEXT_OFFSET_Y

    risk_ar = Arrow(posA = (risk_ar_st_x, risk_ar_st_y), posB = (risk_ar_end_x, risk_ar_end_y), 
                    arrowstyle = "-|>", connectionstyle = ConnectionStyle("Arc3, rad=0.05"),
                    color = CASE_COLOR, linewidth = 2, mutation_scale = 25, 
                    transform = input_ax.transAxes)

    prot_ar = Arrow(posA = (prot_ar_st_x, prot_ar_st_y), posB = (prot_ar_end_x, prot_ar_end_y), 
                    arrowstyle = "-|>", connectionstyle = ConnectionStyle("Arc3, rad=-0.05"),
                    color = CON_COLOR, linewidth = 2, mutation_scale = 25, 
                    transform = input_ax.transAxes)    

    input_ax.add_patch(risk_ar)
    input_ax.add_patch(prot_ar)

    # Arrow captions, centred on each arrow
    input_ax.text(x = mid_x, 
                    y = risk_mid_y, 
                    s = 'Increasing Risk',
                    ha = 'center', va = 'center',
                    transform = input_ax.transAxes, 
                    fontdict = ARROW_TEXT_FONT)

    input_ax.text(x = mid_x, 
                        y = prot_mid_y, 
                        s = 'Decreasing Risk',
                        ha = 'center', va = 'center',
                        transform = input_ax.transAxes, 
                        fontdict = ARROW_TEXT_FONT)

    return input_ax

In [None]:
# Function to draw the actual heatmaps themselves.
def plot_part_heatmap(input_ax, input_dat, 
                      rep_alpha, palette, 
                      risk_mark, prot_mark,
                      HEAT_X_TICK_LABEL_ROTATION = 30,
                      HEAT_Y_TICK_LABEL_FONT_SIZE = 14,
                      HEAT_X_TICK_LABEL_FONT_SIZE = 10,
                      HEAT_X_TICK_LABEL_VA = 'baseline',
                      HEAT_X_TICK_LABEL_HA = 'left',
                      HEAT_X_TICK_LABEL_ROTATION_MODE = 'default',
                      TRANSFORM_ROTATES_TEXT_BOOL = False,
                      annotate_top = False,
                      vmin = None,
                      vmax = None):
    """Draw one heatmap panel: a triangle per disease/organism pair plus a dot.

    Each row of ``input_dat`` becomes a marker at (organism, disease). Rows
    with ``ukb_OR`` above 1 use ``risk_mark``, rows below 1 use ``prot_mark``;
    marker size comes from ``value_to_size(ukb_OR)`` and colour from
    ``neg_log_ukb``. A fixed-size dot coloured by ``std_lev_int`` is drawn on
    top of every marker.

    Reads the module-level names ``corr_x_to_num`` (organism -> x index),
    ``value_to_size``, ``t1_col``, ``t2_col``, ``neg_col``, ``unk_col`` and
    ``STD_MARKER_SIZE``.

    Parameters
    ----------
    input_ax : matplotlib Axes to draw on.
    input_dat : DataFrame with columns 'dis_abbrev_lab', 'Organism',
        'neg_log_ukb', 'ukb_OR' and 'std_lev_int'. Diseases are placed top to
        bottom in order of first appearance.
    rep_alpha : alpha of the triangle markers.
    palette : colormap passed to ``scatter`` as ``cmap``.
    risk_mark, prot_mark : markers for OR > 1 and OR < 1 rows.
    HEAT_* / TRANSFORM_ROTATES_TEXT_BOOL : tick-label styling, forwarded to
        ``set_xticklabels`` / ``set_yticklabels``.
    annotate_top : show organism tick labels along the top when True; hide
        all x tick labels otherwise.
    vmin, vmax : optional colour-scale limits applied to both triangle
        scatters. With the default ``None`` each scatter autoscales to its own
        data, as before.
        NOTE(review): with autoscaling the protective and risk triangles do
        not share one colour scale unless the caller normalises elsewhere -
        confirm, or pass vmin/vmax.

    Returns
    -------
    ``input_ax``, after drawing.
    """
    # Diseases in order of first appearance, reversed so the first disease
    # gets the highest y index (top of the panel).
    dis_list = input_dat.loc[:, 'dis_abbrev_lab'].unique().tolist()
    dis_list.reverse()
    local_y_to_num = {curr_dis: idx for idx, curr_dis in enumerate(dis_list)}

    def _rows_for_direction(drop_rows):
        """Return plot-ready rows after blanking the rows `drop_rows` selects.

        `drop_rows` takes the 'ukb_OR' column and returns a boolean mask of
        the rows belonging to the *other* direction.
        """
        sub = input_dat.loc[:, ['dis_abbrev_lab', 'Organism', 'neg_log_ukb', 'ukb_OR']].copy(deep = True)

        # Blank both the p-value and the OR for rows of the other direction
        sub.loc[drop_rows(sub['ukb_OR']), ['neg_log_ukb', 'ukb_OR']] = np.nan

        sub['org_num'] = [corr_x_to_num[x] for x in sub.loc[:, 'Organism']]
        sub['dis_num'] = [local_y_to_num[x] for x in sub.loc[:, 'dis_abbrev_lab']]
        sub['size'] = sub.loc[:, 'ukb_OR'].apply(value_to_size)

        # Keep only rows that still have both a colour value and an OR
        return sub.loc[(sub['neg_log_ukb'].notna()) & (sub['ukb_OR'].notna()), :]

    # Protective pairs: drop everything with OR > 1
    protect = _rows_for_direction(lambda or_col: or_col > 1)

    # Risk pairs: drop everything with OR < 1
    risk = _rows_for_direction(lambda or_col: or_col < 1)

    #####################################
    #                                   # 
    #      Plotting main heatmap        #
    #                                   #
    #####################################
    # Protective
    input_ax.scatter(
            x = protect['org_num'],
            y = protect['dis_num'],
            s = protect['size'],
            c = protect['neg_log_ukb'],
            cmap = palette,
            vmin = vmin,
            vmax = vmax,
            alpha = rep_alpha,
            zorder = 2.1,
            marker = prot_mark
        )
    # Risk
    input_ax.scatter(
            x = risk['org_num'],
            y = risk['dis_num'],
            s = risk['size'],
            c = risk['neg_log_ukb'],
            cmap = palette,
            vmin = vmin,
            vmax = vmax,
            alpha = rep_alpha,
            zorder = 2.1,
            marker = risk_mark
        )

    #####################################
    #                                   # 
    #    Manipulating Heatmap axes      #
    #                                   #
    #####################################

    # One y tick per disease index (0, 1, 2, ...), labelled with the disease
    input_ax.set_yticks(list(local_y_to_num.values()))
    input_ax.set_yticklabels(list(local_y_to_num), 
                       va = 'center', 
                       fontsize = HEAT_Y_TICK_LABEL_FONT_SIZE)

    # One x tick per organism index, labelled with the organism
    input_ax.set_xticks(list(corr_x_to_num.values()))
    input_ax.set_xticklabels(list(corr_x_to_num), 
                       fontsize = HEAT_X_TICK_LABEL_FONT_SIZE,
                       rotation = HEAT_X_TICK_LABEL_ROTATION, 
                       va = HEAT_X_TICK_LABEL_VA,
                       ha = HEAT_X_TICK_LABEL_HA,
                       rotation_mode = HEAT_X_TICK_LABEL_ROTATION_MODE, 
                       transform_rotates_text = TRANSFORM_ROTATES_TEXT_BOOL)

    # Vertical grid line per organism, horizontal grid line per disease
    input_ax.grid(visible = True, which = 'major', axis = 'x')
    input_ax.grid(visible = True, which = 'both', axis = 'y')

    # Leave one full index of padding on every side
    input_ax.set_ylim([-1, max(local_y_to_num.values()) + 1])
    input_ax.set_xlim([-1, max(corr_x_to_num.values()) + 1])

    # Disease labels and ticks go on the right-hand side
    input_ax.tick_params(axis="y", left = False, right = True, 
                   labelleft = False, labelright = True)

    if annotate_top:
        # Organism labels and ticks along the top
        input_ax.tick_params(axis="x", bottom = False, top = True, 
                       labelbottom = False, labeltop = True)
    else:
        # No x annotation at all
        input_ax.tick_params(axis="x", bottom = False, top = False, 
                       labelbottom = False, labeltop = False)        

    #####################################
    #                                   # 
    # Adding Standard Level Indicators  #
    #                                   #
    #####################################

    combo = pd.concat([protect, risk])

    combo = combo.merge(input_dat.loc[:, ['dis_abbrev_lab', 'Organism', 'std_lev_int']],
                    how = 'left', on = ['dis_abbrev_lab', 'Organism'])

    # Integer standard level -> colour; anything other than 0/1/2 (including
    # missing values) is drawn in the "unknown" colour.
    std_col_by_code = {0: t1_col, 1: t2_col, 2: neg_col}
    local_std_cols = [std_col_by_code.get(curr_x, unk_col)
                      for curr_x in combo['std_lev_int'].values]

    # Standard level dots: default circle marker, one fixed size, drawn above
    # the triangles (higher zorder).
    input_ax.scatter(
        x = combo.loc[:, 'org_num'],
        y = combo.loc[:, 'dis_num'],
        s = STD_MARKER_SIZE, 
        c = local_std_cols,  
        zorder = 2.5
    )

    return input_ax

In [None]:
# Draws the pathogen grouping lines at the top of plot
#
# Parameters controlling top polygon sizes
# NUB: Line that comes up from top x-axis
# ANGLE: Angle to draw line from nub up to the right, this should match
#        the x-axis tick label rotation
# UPPER_LINE_LEN: Length of line for top part of polygon
# LOWER_LINE_LEN: Length of line for lower portion of polygon
def draw_groups(input_ax, ANGLE = 45, NUB = 0.4, 
                UPPER_LINE_LEN = 4.5, LOWER_LINE_LEN = 3):
    """Draw the four pathogen group outlines and labels above `input_ax`.

    Delegates each group to ``plot_and_lab_groups_rec``. The top-level groups
    ('Microbes', 'Viruses') use the bold font, UPPER_LINE_LEN and a tinted
    fill; the sub-groups ('Papilloma', 'Herpes') use the normal font,
    LOWER_LINE_LEN and no fill. Returns ``input_ax``.

    NOTE(review): the x index ranges below are hard-coded; confirm they match
    the organism -> x index mapping used when the heatmap is drawn.
    """
    paired = sns.color_palette("Paired")

    # Fonts for top level label and lower level labels
    heading_font = { 'size': 13, 'fontweight' : 'bold'}
    sub_font = { 'size': 11, 'fontweight' : 'normal'}

    # One entry per group, drawn in this order:
    # (label, first x, last x, outline colour, fill colour, label font,
    #  line length, non-default nudges)
    group_specs = [
        ('Microbes',  0,  2, paired[7], paired[6], heading_font, UPPER_LINE_LEN, {'left_nudge': 1}),
        ('Viruses',   3, 14, paired[1], paired[0], heading_font, UPPER_LINE_LEN, {'right_nudge': 1}),
        ('Papilloma', 7,  8, paired[8], "none",    sub_font,     LOWER_LINE_LEN, {}),
        ('Herpes',    9, 14, paired[2], "none",    sub_font,     LOWER_LINE_LEN, {'right_nudge': 1}),
    ]

    for label, first_x, last_x, edge_col, fill_col, font, length, nudges in group_specs:
        group_style = dict(boxstyle = 'square', pad = 0.3, outline_col = edge_col,
                           label_fill = 'white', box_fill = fill_col, linewidth = 2,
                           lab_font = font)

        plot_and_lab_groups_rec(group_name = label, grp_start = first_x, grp_end = last_x,
                                nub_len = NUB, input_style = group_style,
                                line_ang = ANGLE, line_len = length,
                                draw_line = True, in_ax = input_ax,
                                **nudges)

    return input_ax

In [None]:
# Function to draw the legend or key for the standard dots at the center of the
# triangles
def draw_std_markers_key_w_exp_neg(input_ax, col_ls, lab_ls, legend_marker_size, 
                                   legend_x_axis_label_fs):
    """Draw a 2 x 2 key of standard-level dots, each with a label beneath it.

    Parameters
    ----------
    input_ax : matplotlib Axes to draw on.
    col_ls : four dot colours, in the order bottom-left, bottom-right,
        top-left, top-right.
    lab_ls : four labels, in the same order as ``col_ls``.
    legend_marker_size : dot size passed to ``scatter``.
    legend_x_axis_label_fs : font size of the "Standard Level" axis label.

    Reads the module-level ``STD_MARKER_KEY_LABEL_Y_OFFSET`` (distance each
    label is placed below its dot). Returns ``input_ax``.
    """
    # Dot positions: two rows (y = 0 and y = 1) by two columns
    std_y = [0, 0, 1, 1]
    std_x = [0.25, 0.75, 0.25, 0.75]

    input_ax.scatter(
            x = std_x,
            y = std_y,
            s = legend_marker_size, 
            c = col_ls
    )

    # Turn off y tick marks
    input_ax.yaxis.set_major_locator(plt.NullLocator())

    # Put each label just below its dot
    for curr_ind in range(len(std_x)):
        input_ax.text(x = std_x[curr_ind],
                      y = std_y[curr_ind] - (STD_MARKER_KEY_LABEL_Y_OFFSET),
                      s = lab_ls[curr_ind], 
                      ha = 'center', va = 'center', fontsize = 12)

    # Add a little extra padding above and below the dots
    input_ax.set_ylim([-0.5, 1.5])

    # Blank out the x tick labels
    input_ax.set_xticklabels([])

    # Remove all the outside lines of the plot (the frame)
    for curr_spine in ['top', 'bottom', 'left', 'right']:
        input_ax.spines[curr_spine].set_visible(False)

    # Label this legend
    input_ax.set_xlabel("Standard Level", fontsize = legend_x_axis_label_fs)

    return input_ax

In [None]:
# Function to draw the pathogen grouping and labels. Called by draw_groups
# 
# group_name: Name of group to go in label
# grp_start: X value for first member of this group
# grp_end: X value for last member of this group
# nub_len: Length of nub to go straight up before angling for group line (like tick line)
# in_ax: Axes to draw label and line on
# line_ang: Angle to draw group end line at
# line_len: Length of group end line
# input_style: dictionary with details on how to style the polygon and label
# left_nudge: How much to move our sep line over to left (default is 0.5 so its halfway between orgs)
# right_nudge: Same as above but to right
# draw_line: Boolean if you want to actually draw lines instead of just put label
# debug: Currently unused; kept so existing callers keep working
def plot_and_lab_groups_rec(group_name, grp_start, grp_end, nub_len, in_ax,
                            line_ang, line_len,
                            input_style, left_nudge = 0.5, right_nudge = 0.5, 
                            draw_line = True, debug = True):
    """Draw one group outline above the top of `in_ax` and label it.

    The outline rises `nub_len` straight up from the top edge of the axes at
    both ends of the group, slants up at `line_ang` degrees for `line_len`,
    and is closed across the top. All x coordinates are clamped to the current
    x limits. The label is centred on the top edge of the outline.

    `input_style` must provide the keys 'boxstyle', 'pad', 'label_fill',
    'outline_col', 'linewidth', 'lab_font' and (when `draw_line` is True)
    'box_fill'. Returns None.
    """
    # Imported here so the function does not depend on another cell having
    # imported these names first.
    from matplotlib.collections import PatchCollection
    from matplotlib.lines import Line2D
    from matplotlib.patches import Polygon

    # Data-space y of the top edge of the axes (axes-fraction point (0, 1))
    _, y_hi = in_ax.transLimits.inverted().transform((0,1))

    # Horizontal / vertical extent of the slanted segment
    slant_dx = math.cos(math.radians(line_ang)) * line_len
    slant_dy = math.sin(math.radians(line_ang)) * line_len

    # Unclamped x values for the four corners
    raw_left_x = grp_start - left_nudge
    raw_right_x = grp_end + right_nudge

    # Keep every corner inside the current x limits
    x_lim_lo, x_lim_hi = in_ax.get_xlim()
    bot_left_x = max(raw_left_x, x_lim_lo)
    top_left_x = max(raw_left_x + slant_dx, x_lim_lo)
    bot_right_x = min(raw_right_x, x_lim_hi)
    top_right_x = min(raw_right_x + slant_dx, x_lim_hi)

    # y values for lines
    y_bot = y_hi
    y_nub = y_hi + nub_len
    y_top = y_nub + slant_dy

    lab_mid_x = (top_left_x + top_right_x) / 2

    # Outline vertices, walking left side up, across the top, right side down
    x_points = [bot_left_x, bot_left_x, top_left_x, top_right_x, bot_right_x, bot_right_x]
    y_points = [y_bot, y_nub, y_top, y_top, y_nub, y_bot]

    if draw_line:
        # Coloured outline; clipping is off because it sits outside the axes
        outline = Line2D(xdata = x_points, ydata = y_points,
                  linewidth = 2, 
                  linestyle = 'solid', 
                  color = input_style['outline_col'],
                  transform = in_ax.transData,
                  clip_on = False)
        in_ax.add_line(outline)

        # Faint fill behind the outline
        fill_patch = Polygon(xy = list(zip(x_points, y_points)), closed = True)
        fill = PatchCollection([fill_patch], fc = input_style['box_fill'], ec = 'none',
                               alpha = 0.05, clip_on = False)
        in_ax.add_collection(fill)

    lab_style = dict(boxstyle = input_style['boxstyle'], 
                        pad = input_style['pad'],
                        fc = input_style['label_fill'],
                        ec = input_style['outline_col'],
                        linewidth = input_style['linewidth']
                    )

    # Add our group label
    in_ax.text(lab_mid_x, y_top, group_name, 
               bbox = lab_style, 
               ha = 'center', 
               va = 'center', 
               fontdict = input_style['lab_font'],
               transform = in_ax.transData)

## Define our Triangle markers

In [None]:
# We have to define our own triangle markers which we will use on our heatmap.
# Two versions are built below: matplotlib Path objects first, then plain
# vertex arrays. The arrays are assigned last, so they are the markers that
# RISK_MARKER / PROT_MARKER actually hold after this cell runs.
from matplotlib import path as mpath

# a is the length of 1 side of triangle
a = 1
# Height of an equilateral triangle with side a is a * sqrt(3) / 2
h = ((a * np.sqrt(3)) / 2)
h_x = a / 2

left_x = 0 - h_x
right_x = 0 + h_x
center_x = 0

# Upward-pointing (risk) triangle: base below the origin, apex above
risk_left_y = 0 - (h / 2)
risk_right_y = 0 - (h / 2)
risk_center_y = 0 + h


# Downward-pointing (protective) triangle: mirror image of the risk triangle
prot_left_y = 0 + (h / 2)
prot_right_y = 0 + (h / 2)
prot_center_y = 0 - h

RISK_MARKER = mpath.Path(vertices = [(left_x, risk_left_y), (right_x, risk_right_y), (center_x, risk_center_y), (left_x, risk_left_y)], 
           codes = [mpath.Path.MOVETO, mpath.Path.LINETO, mpath.Path.LINETO, mpath.Path.CLOSEPOLY])

PROT_MARKER = mpath.Path(vertices = [(left_x, prot_left_y), (right_x, prot_right_y), (center_x, prot_center_y), (left_x, prot_left_y)],
           codes = [mpath.Path.MOVETO, mpath.Path.LINETO, mpath.Path.LINETO, mpath.Path.CLOSEPOLY])


# The vertex arrays below REPLACE the Path markers defined above.
# NOTE(review): these triangles have base 1 and height 1, so they are
# isosceles rather than equilateral - confirm which shape is intended.

# Vertices of the triangle pointing upwards (closed by repeating the apex)
RISK_MARKER = np.array([
    [0, 0.5],
    [-0.5, -0.5],
    [0.5, -0.5],
    [0, 0.5]
])

# Vertices of the triangle pointing downwards (closed by repeating the apex)
PROT_MARKER = np.array([
    [0, -0.5],
    [-0.5, 0.5],
    [0.5, 0.5],
    [0, -0.5]
])

## Making pathogen names simpler and ordering them

In [None]:
# Full organism name -> short label.
# NOTE(review): 'HTLV-1' and 'KSHV' here do not match the 'HTLV1' and
# 'KSHV/HHV8' labels used in ordered_orgs below - confirm which spelling the
# downstream lookup expects.
simp_org_name = dict([
    ('Human Polyomavirus BKV', 'BKV'),
    ('Epstein-Barr Virus', 'EBV'),
    ('Human Herpesvirus-7', 'HHV7'),
    ('Herpes Simplex virus-1', 'HSV1'),
    ('Herpes Simplex virus-2', 'HSV2'),
    ('Human Herpesvirus-6', 'HHV6'),
    ("Kaposi's Sarcoma-Associated Herpesvirus", 'KSHV'),
    ('Human T-Lymphotropic Virus 1', 'HTLV-1'),
    ('Human Immunodeficiency Virus', 'HIV'),
    ('Varicella Zoster Virus', 'VZV'),
    ('Merkel Cell Polyomavirus', 'MCV'),
    ('Human Papillomavirus type-18', 'HPV18'),
    ('Hepatitis C Virus', 'HCV'),
    ('Human Polyomavirus JCV', 'JCV'),
    ('Human Papillomavirus type-16', 'HPV16'),
    ('Human Cytomegalovirus', 'CMV'),
    ('Hepatitis B Virus', 'HBV'),
])

# Organism labels in display order; the position in this list becomes the
# index stored alongside each label in ordered_orgs.
_org_display_order = [
    # Bacteria
    '$\\it{C. trachomatis}$',
    '$\\it{H. pylori}$',

    # Alveolata
    '$\\it{T. gondii}$',

    # Riboviria - HBV more closely related to HIV and HTLV-1 than HCV at least by phylo
    'HBV',
    'HCV',
    'HIV',
    'HTLV1',

    # Papovaviricetes
    'BKV',
    'JCV',
    'MCV',
    'HPV16',
    'HPV18',

    # Herpes
    'CMV',
    'EBV',
    'HSV1',
    'HSV2',
    'HHV6',
    'HHV7',
    'VZV',
    'KSHV/HHV8',
]

# List of (label, index) tuples, index running 0..19
ordered_orgs = [(org_label, org_idx) for org_idx, org_label in enumerate(_org_display_order)]

## Prep for plotting, setting plot params

In [None]:
# Min neg_log_ukb: 0.524
# Max neg_log_ukb: 5.874
# Num neg_log_ukb > 2.5: 8

# test
from matplotlib.lines import Line2D
import matplotlib.transforms as transforms
import matplotlib.patches as mpatches

#####################################
#                                   #
#     Cap Extreme Values            #
#                                   #
#####################################
# Cap neg_log_ukb at MAX_P_VAL so a few very large values do not dominate the
# colour scale (the colorbar labels its top tick "2.5+").
# NOTE: this modifies `lw` in place; re-running the cell is harmless because
# already-capped values are left untouched.
MAX_P_VAL = 2.5
lw.loc[lw['neg_log_ukb'] > MAX_P_VAL, 'neg_log_ukb'] = MAX_P_VAL

# Cap the UKB odds ratio at MAX_OR (per the original note, only one value
# exceeds it)
MAX_OR = 5
lw.loc[lw['ukb_OR'] > MAX_OR, 'ukb_OR'] = MAX_OR

# Default figure size (inches) for every figure created after this point
HEIGHT = 30
WIDTH = 12
plt.rcParams["figure.figsize"] = (WIDTH, HEIGHT)

# How much to spread the smallest hexagon (smallest OR)
# and the largest hexagon (largest OR)
SIZE_SCALE = 1400

MIN_SIZE = 200
MAX_SIZE = MIN_SIZE + SIZE_SCALE

# Observed odds-ratio range after capping. Series.min()/max() skip NaN,
# whereas the builtin min()/max() over a raw array give an order-dependent
# result as soon as a NaN is present.
size_min = lw['ukb_OR'].min()
size_max = lw['ukb_OR'].max()

#####################################
#                                   #
#         Splitting ICDs            #
#                                   #
#####################################
# ICD codes that define the three data subsets (std_dat, left_dat, right_dat)
# carved out of reorg_lw further down. Codes are written as whitespace-
# separated strings, one chapter letter per line, and split into lists.
std_icds = 'A60 B00 B02 B19 B24 B27'.split()

left_icds = (
    'C90 C91 '
    'D50 D64 D75 D86 '
    'E04 E10 E78 '
    'F39 '
    'G20 G35 G43 G50 G91 '
    'H43 H47 H61 H66 '
    'I12 I26 I27 I31 I47 I65 I71 I72 I78 I83 I95 '
    'J03 J06 J20 J22 J30 J44 J45 J93 J94 J96 J98'
).split()

right_icds = (
    'K01 K05 K21 K22 K25 K26 K27 K29 K31 K44 K51 '
    'K52 K57 K58 K59 K65 K70 K74 K76 K83 K90 K92 '
    'L01 L03 L29 L43 L57 L70 L82 L84 L91 '
    'M05 M17 M20 M22 M25 M32 M47 M54 M67 M77 M79 M81 M86 M94 '
    'N13 N30 N32 N34 N39 N47 N48 N70 N81 N89 N93 '
    'O03'
).split()

#####################################
#                                   #
#         Setting up Colors         #
#                                   #
#####################################
# Seaborn palette used to colour the hexagons by neg_log_ukb
PALETTE = "flare"

# Observed range of neg_log_ukb. Series.min()/max() skip NaN, whereas the
# builtin min()/max() over a raw array give an order-dependent result as soon
# as a NaN is present.
min_color = lw['neg_log_ukb'].min()
max_color = lw['neg_log_ukb'].max()

# Create the palette as a continuous matplotlib colormap
our_pal = sns.color_palette(PALETTE, as_cmap = True)

# Standard-level marker colours; the legend built later lists them in this
# order under the labels 'Tier 1', 'Tier 2', 'Expected\nNegative', 'Unknown'
t1_col  = '#FDD017'
t2_col  = '#00FFFF'
neg_col = '#242424'
unk_col = '#299d47'

#####################################
#                                   #
#         Collecting Data           #
#                                   #
#####################################

# Using the UKB OR for the size of the hexagon
size = lw['ukb_OR']

# Standard level marker size (in the heatmap and in the legend)
STD_MARKER_SIZE = 25
LEGEND_STD_MARKER_SIZE = 400

# Passed to the heatmap helper as rep_alpha
REP_ALPHA = 0.6

# Font size of the x-axis labels under the legends
LEGEND_X_AXIS_LABEL_FONTSIZE = 12

# Font properties handed to the side histograms as lab_font
SIDE_HIST_GRID_LAB_FONT = {'fontsize' : 10, 'fontweight' : 'bold',
                           'alpha' : 0.75, 'color' : 'black'}

#####################################
#                                   #
#         Prepping Data             #
#                                   #
#####################################
reorg_lw = prep_data(input_dat = lw, org_name_map = simp_org_name,
                     rep_res = res, cnr_res = cnr)

# X will be our nice organism names, Y our abbreviated disease names
x = reorg_lw['Organism']
y = reorg_lw['dis_abbrev_lab']

# The heatmap is indexed 0, 1, 2, ... so every disease / organism label needs
# an integer position, e.g.
# {'[B02] Herpes Zoster': 0,
#  '[B19] Unspec Viral Hepatits': 1, ...}

# Disease labels in reverse alphabetical order
y_names = sorted(set(y), reverse = True)
y_to_num = {label: idx for idx, label in enumerate(y_names)}

# Organism names in order of first appearance
x_names = x.unique().tolist()
x_to_num = {label: idx for idx, label in enumerate(x_names)}

# Binomial names are shown in italics (mathtext). This single mapping drives
# both the axis labels and the 'Organism' column so the two cannot drift apart.
ITALIC_ORG_NAMES = {
    'Toxoplasma gondii' : '$\\it{T. gondii}$',
    'Helicobacter pylori' : '$\\it{H. pylori}$',
    'Chlamydia trachomatis' : '$\\it{C. trachomatis}$',
}
corr_x_names = [ITALIC_ORG_NAMES.get(org_name, org_name) for org_name in x_names]

# Position of every organism label in the manually specified order
# (first entry wins if a label were ever listed twice)
org_rank = {}
for org_label, org_pos in ordered_orgs:
    org_rank.setdefault(org_label, org_pos)

# Fail with a readable message instead of a bare IndexError when the data
# contain an organism that ordered_orgs does not list
missing_orgs = [org_name for org_name in corr_x_names if org_name not in org_rank]
if missing_orgs:
    raise ValueError(f"Organisms missing from ordered_orgs: {missing_orgs}")

corr_x_names = sorted(corr_x_names, key = lambda org_name: org_rank[org_name])
corr_x_to_num = {label: idx for idx, label in enumerate(corr_x_names)}

# Apply the same italic labels to the data itself
reorg_lw.loc[:, 'Organism'] = reorg_lw.loc[:, 'Organism'].replace(ITALIC_ORG_NAMES)

# Carve the prepared table into the three subsets by ICD code, each sorted
# by code
std_dat = reorg_lw.loc[reorg_lw['icd'].isin(std_icds), :].sort_values('icd')
left_dat = reorg_lw.loc[reorg_lw['icd'].isin(left_icds), :].sort_values('icd')
right_dat = reorg_lw.loc[reorg_lw['icd'].isin(right_icds), :].sort_values('icd')

# Translate every row's integer standard level into its marker colour:
# 0 -> t1_col, 1 -> t2_col, 2 -> neg_col, anything else -> unk_col
std_color = reorg_lw['std_lev_int'].values

std_cols = [
    t1_col if std_level == 0 else
    t2_col if std_level == 1 else
    neg_col if std_level == 2 else
    unk_col
    for std_level in std_color
]


# Disease label -> row index, following the order in which the labels appear
# in `lw`, reversed. (The original author notes this assumes `lw` is already
# ordered properly.)
input_dat = lw.copy(deep = True)
dis_list = input_dat['dis_abbrev_lab'].unique().tolist()[::-1]

local_y_to_num = {dis_label: row_idx for row_idx, dis_label in enumerate(dis_list)}


### General settings

In [None]:
# General settings shared by the plotting cells below, grouped by the helper
# that receives them.

DEBUG = False

# --- Figure ------------------------------------------------------------------
# Figure height in inches (overrides the value set in the prep cell)
HEIGHT = 24

# Captions written at the bottom-left of each figure
LEFT_FIG_TEXT = "Figure 3a | Overview of all replicated pathogen-disease pairs (Tier 1 controls, C00 - J99)"
RIGHT_FIG_TEXT = "Figure 3b | Overview of all replicated pathogen-disease pairs (K00 - O99)"

# --- Odds-ratio legend (draw_horizon_or_legend) --------------------------------
risk_y_vals, prot_y_vals = [0.85, 0.945, 1.0], [0.2, 0.105, 0.05]
risk_lab_y, prot_lab_y = 0.65, 0.375

size_x_lim, size_y_lim = (-0.1, 0.6), (-0.5, 1.5)

MARKER_X_OFFSET = 0.25
X_START = 0.1
ARROW_LEN_X, ARROW_LEN_Y = 0.8, 0.1
Y_CENTER, Y_OFFSET = 0.5, 0.35

# Not passed to any call in the plotting cells below; kept in case a helper
# reads them as globals
risk_y_val, prot_y_val = 1.1, 0.05
OR_LAB_Y_OFFSET = 0.35
STD_MARKER_KEY_LABEL_Y_OFFSET = 0.45

# --- Standard-level key (draw_std_markers_key_w_exp_neg) -----------------------
LEGEND_STD_MARKER_SIZE = 400

# --- Organism group annotation (draw_groups) -----------------------------------
NUB, ANGLE = 0.4, 50
UPPER_LINE_LEN, LOWER_LINE_LEN = 4.5, 3

# --- Heatmap tick labels (plot_part_heatmap) -----------------------------------
TRANSFORM_ROTATES_TEXT_BOOL = False

HEAT_Y_TICK_LABEL_FONT_SIZE, HEAT_X_TICK_LABEL_FONT_SIZE = 14, 10
HEAT_X_TICK_LABEL_ROTATION = 30
HEAT_X_TICK_LABEL_VA, HEAT_X_TICK_LABEL_HA = 'baseline', 'left'
HEAT_X_TICK_LABEL_ROTATION_MODE = 'default'

## Left plot

In [None]:
# Left-hand figure: sizes, figure object and layout grid

# --- Layout constants (all in grid cells) ---
STD_SPACER = 0
MAIN_SPACER_H = 0
STD_SPACER_H = 1
STD_SPACER_W = 0

# Width of histogram, and separation between histogram and legend
SIDE_HIST_W = 3
HIST_LEGEND_SEP_W = 0

# Legend panels
STD_LEGEND_W = 5
SIZE_LEGEND_W = 15
CBAR_H = 3
LEGEND_SPACER_H = 1
LEGEND_SPACER_W = 2

# --- Number of disease rows per panel ---
n_std_dis = 6
n_left_dis = len(left_dat['icd'].unique())
n_right_dis = len(right_dat['icd'].unique())

left_plot_n = len(std_icds) + len(left_icds) + 1
right_plot_n = len(right_icds)

# --- Figure ---
WIDTH = 12
fig_w, fig_h = WIDTH, HEIGHT

fig = plt.figure(figsize = (fig_w, fig_h), facecolor = 'w')

# --- Grid ---
# NOTE(review): the grid height comes from the number of right-panel ICD codes
# even though this is the left figure -- presumably so both figures use the
# same row height; confirm.
FIG_GRID_H = right_plot_n
FIG_GRID_W = 40

fig_grid = plt.GridSpec(nrows = FIG_GRID_H, ncols = FIG_GRID_W,
                        hspace = 1.4, wspace = 0,
                        figure = fig)

print(f'FIG_H: {fig_h}, FIG_W: {fig_w}\nSIZE_LEGEND_W: {SIZE_LEGEND_W}, STD_LEGEND_W: {STD_LEGEND_W}, LEGEND_SPACER_W: {LEGEND_SPACER_W}')


# Carve the grid into the seven axes of the left figure. Row / column spans
# shared by several axes are named once so paired axes always line up.

# Columns: side histogram on the left, heatmap on everything to its right
side_hist_cols = slice(None, SIDE_HIST_W + HIST_LEGEND_SEP_W)
heat_cols = slice(SIDE_HIST_W + HIST_LEGEND_SEP_W, None)

# Rows: standard panel on top, main panel directly below it
std_rows = slice(None, n_std_dis)
main_rows = slice(n_std_dis + MAIN_SPACER_H, n_std_dis + STD_SPACER + n_left_dis)

# Heatmaps (the main heatmap shares its x-axis with the standard one)
std_ax = fig.add_subplot(fig_grid[std_rows, heat_cols])
main_ax = fig.add_subplot(fig_grid[main_rows, heat_cols], sharex = std_ax)

# Side histograms, each sharing the y-axis of the heatmap next to it
std_side_hist_ax = fig.add_subplot(fig_grid[std_rows, side_hist_cols],
                                   sharey = std_ax)
main_side_hist_ax = fig.add_subplot(fig_grid[main_rows, side_hist_cols],
                                    sharey = main_ax)

# Odds-ratio (marker size) legend: bottom-left
size_ax = fig.add_subplot(
    fig_grid[
                (n_std_dis + STD_SPACER_H + n_left_dis + STD_SPACER_H) : FIG_GRID_H,
                : SIZE_LEGEND_W
            ])

# Standard-level key: bottom-right, above the colorbar
std_leg_ax = fig.add_subplot(
    fig_grid[
                (n_std_dis + STD_SPACER_H + n_left_dis) : (FIG_GRID_H - CBAR_H - LEGEND_SPACER_H),
                (FIG_GRID_W - SIZE_LEGEND_W) :
            ])

# Colorbar: a single grid row, to the right of the odds-ratio legend
col_ax = fig.add_subplot(
    fig_grid[
                FIG_GRID_H - CBAR_H,
                (SIZE_LEGEND_W + LEGEND_SPACER_W) :
            ])


# Arguments every heatmap panel of this figure receives
shared_heat_kwargs = dict(rep_alpha = REP_ALPHA, palette = our_pal,
                          risk_mark = RISK_MARKER, prot_mark = PROT_MARKER)

# Arguments every side histogram of this figure receives
shared_hist_kwargs = dict(lab_font = SIDE_HIST_GRID_LAB_FONT, DEBUG = False)

#####################################
#                                   #
#     Plotting Std Matrix           #
#                                   #
#####################################
# Top panel; the only one that also receives the tick-label settings
plot_part_heatmap(input_ax = std_ax, input_dat = std_dat,
                  annotate_top = True,
                  HEAT_Y_TICK_LABEL_FONT_SIZE = HEAT_Y_TICK_LABEL_FONT_SIZE,
                  HEAT_X_TICK_LABEL_FONT_SIZE = HEAT_X_TICK_LABEL_FONT_SIZE,
                  HEAT_X_TICK_LABEL_ROTATION = HEAT_X_TICK_LABEL_ROTATION,
                  HEAT_X_TICK_LABEL_ROTATION_MODE = HEAT_X_TICK_LABEL_ROTATION_MODE,
                  HEAT_X_TICK_LABEL_VA = HEAT_X_TICK_LABEL_VA,
                  HEAT_X_TICK_LABEL_HA = HEAT_X_TICK_LABEL_HA,
                  TRANSFORM_ROTATES_TEXT_BOOL = TRANSFORM_ROTATES_TEXT_BOOL,
                  **shared_heat_kwargs)

#####################################
#                                   #
#     Plotting Main Matrix          #
#                                   #
#####################################
plot_part_heatmap(input_ax = main_ax, input_dat = left_dat,
                  annotate_top = False,
                  **shared_heat_kwargs)

#####################################
#                                   #
#       Std Histogram               #
#                                   #
#####################################
# Rows per disease label in the standard panel, labels in descending order
dis_hist = std_dat.groupby('dis_abbrev_lab').size().reset_index()
dis_hist.columns = ['dis', 'count']
dis_hist = dis_hist.sort_values(by = 'dis', ascending = False)

gen_side_hist_new(input_ax = std_side_hist_ax,
                  dis_hist = dis_hist,
                  y_val_ls = [-1.35, 6.3], INPUT_HEIGHT = 0.8,
                  **shared_hist_kwargs)

#####################################
#                                   #
#       Main Histogram              #
#                                   #
#####################################
# Same counts for the main (left_dat) panel
dis_hist = left_dat.groupby('dis_abbrev_lab').size().reset_index()
dis_hist.columns = ['dis', 'count']
dis_hist = dis_hist.sort_values(by = 'dis', ascending = False)

gen_side_hist_new(input_ax = main_side_hist_ax,
                  dis_hist = dis_hist,
                  y_val_ls = [-1.35], INPUT_HEIGHT = 0.80,
                  **shared_hist_kwargs)

# Group annotation, drawn on the top panel
draw_groups(input_ax = std_ax, NUB = NUB, ANGLE = ANGLE,
            UPPER_LINE_LEN = UPPER_LINE_LEN,
            LOWER_LINE_LEN = LOWER_LINE_LEN)


#####################################
#                                   #
#    Generate Odds Ratio Legend     #
#                                   #
#####################################
# Geometry of the legend, taken from the general settings
or_legend_kwargs = {
    'RISK_Y_VALS' : risk_y_vals,
    'PROT_Y_VALS' : prot_y_vals,
    'RISK_LAB_Y' : risk_lab_y,
    'PROT_LAB_Y' : prot_lab_y,
    'X_START' : X_START,
    'ARROW_LEN_X' : ARROW_LEN_X,
    'ARROW_LEN_Y' : ARROW_LEN_Y,
    'Y_CENTER' : Y_CENTER,
    'Y_OFFSET' : Y_OFFSET,
    'X_LIM' : size_x_lim,
    'Y_LIM' : size_y_lim,
    'MARKER_X_OFFSET' : MARKER_X_OFFSET,
    'DEBUG' : False,
}

draw_horizon_or_legend(input_ax = size_ax,
                       risk_mark = RISK_MARKER, prot_mark = PROT_MARKER,
                       **or_legend_kwargs)

#####################################
#                                   #
#    Standard Level Legend          #
#                                   #
#####################################
# One (colour, label) pair per standard level, in legend order
std_legend_entries = [
    (t1_col, 'Tier 1'),
    (t2_col, 'Tier 2'),
    (neg_col, 'Expected\nNegative'),
    (unk_col, 'Unknown'),
]
std_col_ls = [entry_col for entry_col, _ in std_legend_entries]
std_lab_ls = [entry_lab for _, entry_lab in std_legend_entries]

draw_std_markers_key_w_exp_neg(input_ax = std_leg_ax,
                               col_ls = std_col_ls, lab_ls = std_lab_ls,
                               legend_marker_size = LEGEND_STD_MARKER_SIZE,
                               legend_x_axis_label_fs = LEGEND_X_AXIS_LABEL_FONTSIZE)


#####################################
#                                   #
#   -Log10(p-val) Colorbar Legend   #
#                                   #
#####################################
# our_pal was created with as_cmap = True, so it already is a matplotlib
# colormap and can be used directly. (The former plt.cm.get_cmap(...) call
# only returned it unchanged, and matplotlib.cm.get_cmap was deprecated in
# Matplotlib 3.7 and removed in 3.9.)
cbar_map = our_pal

# Normalize the colorbar from 0.5 up to the largest neg_log_ukb in the data
norm = plt.Normalize(0.5, reorg_lw.loc[:, 'neg_log_ukb'].max())

# Combine the colormap and the normalization
mapper = mpl.cm.ScalarMappable(norm = norm, cmap = cbar_map)

# Draw the colorbar into its pre-allocated axis
cbar = fig.colorbar(mapper, cax = col_ax,
                    orientation = 'horizontal',
                    aspect = 1, ticks = [0.5, 1.0, 1.5, 2, 2.5],
                    ticklocation = 'bottom')

# The top tick reads "2.5+" because neg_log_ukb was capped earlier
cbar_tick_labs = ['0.5', '1.0', '1.5', '2.0', '2.5+']
cbar.set_ticklabels(cbar_tick_labs)

cbar.ax.set_xlabel("-Log10(UKB FDR)", fontsize = LEGEND_X_AXIS_LABEL_FONTSIZE,
                   labelpad = 14)

# Keep col_ax pointing at the axis the colorbar lives in
col_ax = cbar.ax

# Move the legends over to the right a bit. Each axis keeps its y-position,
# width and height; only its left edge is shifted by X_PUSH (in the units
# returned by Axes.get_position()).
#   size_ax    - odds-ratio legend
#   std_leg_ax - standard-level key
#   col_ax     - colorbar (pushed further than the other two)
for legend_ax, X_PUSH in ((size_ax, 0.125), (std_leg_ax, 0.125), (col_ax, 0.2)):
    curr_pos = legend_ax.get_position()

    # Bbox.bounds is (x0, y0, width, height)
    x0, y0, w, h = curr_pos.bounds

    new_x0 = x0 + X_PUSH
    legend_ax.set_position([new_x0, y0, w, h])

# Figure caption
fig.text(x = 0.121, y = 0.080, s = LEFT_FIG_TEXT, ha = 'left')

In [None]:
# Directory that receives the manuscript figures.
# (The closing quote of this f-string was missing, which made the whole cell
# a SyntaxError.)
out_dir = f'{HOME_DIR}/manuscript/figures'

# Save the left-hand figure (the current `fig`) as an SVG
fn = f"{out_dir}/Figure_3a.svg"
fig.savefig(fn, format = 'svg', dpi = 600, bbox_inches="tight")

## Right plot

In [None]:
# Right-hand figure (Figure 3b): settings, figure object and layout grid

# Tick-label and draw_groups settings; these repeat the values assigned in
# the "General settings" cell
HEAT_Y_TICK_LABEL_FONT_SIZE, HEAT_X_TICK_LABEL_FONT_SIZE = 14, 10
HEAT_X_TICK_LABEL_ROTATION = 30
HEAT_X_TICK_LABEL_VA, HEAT_X_TICK_LABEL_HA = 'baseline', 'left'
HEAT_X_TICK_LABEL_ROTATION_MODE = 'default'

NUB, ANGLE = 0.4, 50
UPPER_LINE_LEN, LOWER_LINE_LEN = 4.5, 3

# --- Layout constants (in grid cells) ---
STD_SPACER = 0

# Width of histogram, and separation between histogram and legend
SIDE_HIST_W = 3
HIST_LEGEND_SEP_W = 0

TOT_SIDE_W = SIDE_HIST_W + HIST_LEGEND_SEP_W

# --- Number of disease rows per panel ---
n_std_dis = 6
n_left_dis = len(left_dat['icd'].unique())
n_right_dis = len(right_dat['icd'].unique())

# --- Figure ---
WIDTH, HEIGHT = 12, 24
fig_w, fig_h = WIDTH, HEIGHT

fig = plt.figure(figsize = (fig_w, fig_h), facecolor = 'w')

# --- Grid: one row per right-panel ICD code present in the data ---
FIG_GRID_H = n_right_dis
FIG_GRID_W = 40

fig_grid = plt.GridSpec(nrows = FIG_GRID_H, ncols = FIG_GRID_W,
                        hspace = 0.8, wspace = 0,
                        figure = fig)



# Both axes of the right figure start below the first n_std_dis (+ STD_SPACER)
# grid rows and run to the bottom of the grid
main_rows = slice(n_std_dis + STD_SPACER, None)

# Heatmap: everything to the right of the side-histogram columns.
# NOTE(review): std_ax is not created in this cell -- it is the top heatmap
# axis left over from the left-plot cell, i.e. an axis of a different figure.
# This cell therefore only runs after the left plot has been built; confirm
# that sharing the x-axis across figures is intended.
main_ax = fig.add_subplot(fig_grid[main_rows, TOT_SIDE_W:],
                          sharex = std_ax)

# Side histogram: the leftmost columns, sharing the heatmap's y-axis
main_side_hist_ax = fig.add_subplot(fig_grid[main_rows, HIST_LEGEND_SEP_W : TOT_SIDE_W],
                                    sharey = main_ax)



        

#####################################
#                                   #
#     Plotting Main Matrix          #
#                                   #
#####################################
# The right figure has a single heatmap panel, so it gets the top annotation
right_heat_kwargs = dict(rep_alpha = REP_ALPHA, palette = our_pal,
                         risk_mark = RISK_MARKER, prot_mark = PROT_MARKER,
                         annotate_top = True)
plot_part_heatmap(input_ax = main_ax, input_dat = right_dat, **right_heat_kwargs)

# Group annotation on the same panel
group_kwargs = dict(NUB = NUB, ANGLE = ANGLE,
                    UPPER_LINE_LEN = UPPER_LINE_LEN,
                    LOWER_LINE_LEN = LOWER_LINE_LEN)
draw_groups(input_ax = main_ax, **group_kwargs)

#####################################
#                                   #
#       Main Histogram              #
#                                   #
#####################################
# Rows per disease label in right_dat, labels in descending order
dis_hist = right_dat.groupby('dis_abbrev_lab').size().reset_index()
dis_hist.columns = ['dis', 'count']
dis_hist = dis_hist.sort_values(by = 'dis', ascending = False)

gen_side_hist_new(input_ax = main_side_hist_ax,
                  dis_hist = dis_hist,
                  lab_font = SIDE_HIST_GRID_LAB_FONT,
                  y_val_ls = [-1.35, 57.3],
                  INPUT_HEIGHT = 0.80,
                  DEBUG = False)

# Figure caption
fig.text(x = 0.121, y = 0.080, s = RIGHT_FIG_TEXT, ha = 'left')

In [None]:
# Directory that receives the manuscript figures.
# (The closing quote of this f-string was missing, which made the whole cell
# a SyntaxError.)
out_dir = f'{HOME_DIR}/manuscript/figures'

# Save the right-hand figure (the current `fig`) as an SVG
fn = f"{out_dir}/Figure_3b.svg"
fig.savefig(fn, format = 'svg', dpi = 600, bbox_inches="tight")