# Load Libraries

In [None]:
# Data manipulation
import numpy as np
import pandas as pd

# Data science
import math
import scipy.stats as stats
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from statsmodels.stats.multitest import multipletests as mt

# Plots
import matplotlib as mpl
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

# Working with dates
from datetime import date,datetime
import dateutil

# Looping  progress
from tqdm.notebook import tqdm

# Reg expressions
import re

# Pretty table printing
import tabulate

import os
import subprocess

# Misc libraries
from IPython.display import display, HTML
#from IPython.core.display import display, HTML

# Seaborn theme: 11.7 x 8.27 inch figures, 1.5x font scaling and a plain
# white background. The three calls are applied in this order.
sns.set(rc={'figure.figsize': (11.7, 8.27)})
sns.set(font_scale=1.5)
sns.set_style("white")

# Pandas display: show (effectively) every row and column, and never
# shorten cell contents, so whole dataframes are visible in the notebook.
for _opt_name, _opt_value in (
    ('display.max_rows', 10000),
    ('display.max_columns', 10000),
    ('display.max_colwidth', None),
):
    pd.set_option(_opt_name, _opt_value)

# Report the installed pandas version so environment mismatches are easy to spot.
print(f"Pandas version: {pd.__version__}")

# Prep Environment

In [None]:
# Widen the notebook's main container to 90% of the window.
display(HTML("<style>.container { width:90% !important; }</style>"))

# Work from the results directory. The directory the notebook started in is
# remembered on first run, so re-running this cell does not climb another
# four levels relative to the (already changed) working directory.
_NOTEBOOK_DIR = globals().get('_NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.join(_NOTEBOOK_DIR, '../../../../results/'))

# Disease breakout

In [None]:
# Load the 'Results' sheet of supplemental dataset 2.
SUPP_DATASET_2_PATH = '../manuscript/latest/supplemental_datasets/supplemental_dataset_2.xlsx'
res = pd.read_excel(SUPP_DATASET_2_PATH, sheet_name='Results')

In [None]:
# Map the human-friendly column headers of the supplemental dataset back to
# the short, code-friendly names used throughout this notebook.
human_to_computer_column_dict = {
    'Disease': 'disease',
    'ICD10': 'icd',
    'Organism': 'org',
    'Antibody': 'anti',
    'Pair is Associated': 'pair_is_associated',
    'Standard Level': 'std_lev',
    'Replication Status': 'rep_stat',
    'UKB adj p': 'ukb_per_dis_bh_fdr_corr_nom_p',
    'TNX adj p': 'tnx_per_dis_bh_fdr_corr_p',
    'UKB OR': 'ukb_OR',
    'TNX OR': 'tnx_OR',
    'UKB CI': 'ukb_anti_CI',
    'TNX CI': 'tnx_CI',
    'UKB nCase': 'ukb_nCase',
    'UKB nControl': 'ukb_nControl',
    'TNX nCase': 'tnx_nCase',
    'TNX nControl': 'tnx_nControl',
}

# Display names of organisms -> short lowercase codes.
human_to_computer_org_dict = {
    'BKV': 'bkv',
    'C. trachomatis': 'c_trach',
    'CMV': 'cmv',
    'EBV': 'ebv',
    'H. pylori': 'h_pylor',
    'HBV': 'hbv',
    'HCV': 'hcv',
    'HHV-6': 'hhv_6',
    'HHV-7': 'hhv_7',
    'HIV': 'hiv',
    'HPV-16': 'hpv_16',
    'HPV-18': 'hpv_18',
    'HSV-1': 'hsv_1',
    'HSV-2': 'hsv_2',
    'HTLV-1': 'htlv',
    'JCV': 'jcv',
    'KSHV': 'kshv',
    'MCV': 'mcv',
    'T. gondii': 't_gond',
    'VZV': 'vzv',
}

# Rename the headers, then recode the organism column; names missing from the
# mapping are left unchanged.
res = (
    res.rename(columns=human_to_computer_column_dict)
       .assign(org=lambda frame: frame['org'].replace(human_to_computer_org_dict))
)

In [None]:
# Split each ICD-10 code into its leading category letter and its last two
# characters (the "site"), kept both as an int and as a zero-padded string.
#
# Plain column assignment is used instead of `res.loc[:, col] = ...`: since
# pandas 2.0, a full-column .loc assignment into an existing column writes the
# values in place and keeps that column's dtype, which silently turned the
# astype(int) result back into object dtype.
_icd_tail = res['icd'].str[-2:]

res['icd_cat'] = res['icd'].str[0]
res['icd_site'] = _icd_tail.astype(int)
res['icd_site_str'] = _icd_tail.str.rjust(width = 2, fillchar = '0')

## Massage Data

In [None]:
# Columns carried into the plotting frame `lw`.
_lw_columns = [
    'icd', 'disease', 'org',
    'ukb_per_dis_bh_fdr_corr_nom_p', 'tnx_per_dis_bh_fdr_corr_p',
    'ukb_OR', 'tnx_OR',
    'std_lev', 'rep_stat',
    'icd_cat', 'icd_site', 'icd_site_str',
]

# Shorten the two adjusted-p-value column names.
_lw_short_names = {
    'ukb_per_dis_bh_fdr_corr_nom_p': 'ukb_p',
    'tnx_per_dis_bh_fdr_corr_p': 'tnx_p',
}

lw = res.loc[:, _lw_columns].rename(columns=_lw_short_names)

### Convert Standard Level to Integer Codes for Plotting

In [None]:
# Integer code per standard level, for plotting:
# 'Tier 1' -> 0, 'Tier 2' -> 1, 'unk' -> 2.
# NOTE(review): np.select falls back to its default (0) for any other label,
# so an unexpected level is encoded the same as 'Tier 1'.
STD_LEV_CODES = {'Tier 1': 0, 'Tier 2': 1, 'unk': 2}

lw.loc[:, 'std_lev_int'] = np.select(
    [lw['std_lev'] == level for level in STD_LEV_CODES],
    list(STD_LEV_CODES.values()),
)

### Generate -log10 Adjusted P-values (UKB and TNX)

In [None]:
# Cut-offs on the adjusted p-values; a test is flagged significant later in
# the notebook when its p-value is strictly below its source's threshold.
UKB_THRESH = 0.3
TNX_THRESH = 0.01

# -log10 of each source's adjusted p-value (bar heights in the figure).
for _src in ('ukb', 'tnx'):
    lw.loc[:, f'neg_log_{_src}'] = -np.log10(lw.loc[:, f'{_src}_p'])

### Get Abbreviated Disease Names

In [None]:
# Attach each disease's abbreviated name, looked up by ICD code; diseases
# without an entry get NaN.
dis_dict = pd.read_excel('../dicts/dis_abbrev_dict.xlsx')
lw = lw.merge(dis_dict[['icd', 'dis_abbrev']], on='icd', how='left')

# Plot label of the form "[ICD] abbreviation".
lw.loc[:, 'dis_abbrev_lab'] = '[' + lw['icd'] + '] ' + lw['dis_abbrev']

### Get nice organism names

In [None]:
# Attach each organism's full name and abbreviation (matched on the short
# organism code) from the antigen dictionary.
vir_dat = pd.read_excel('../dicts/antigen_dict.xlsx')
vir_dat = vir_dat[['Tag', 'Organism', 'Abbrev']].drop_duplicates()

lw = pd.merge(lw, vir_dat,
         left_on = 'org',
         right_on = 'Tag',
         how = 'left')

# Remove non-breaking spaces (U+00A0) and surrounding whitespace.
lw['Abbrev'] = lw['Abbrev'].str.replace('\xa0', '', regex = False).str.strip()

# Species names are italicised with matplotlib mathtext. Raw strings keep the
# `\it` command literal: in a normal string '\i' is an invalid escape sequence
# (SyntaxWarning on Python 3.12+). Both 'H.pylori' and 'H. pylori' spellings
# are accepted so either form maps to the label used in `ordered_orgs`.
ITALIC_ORG_NAMES = {
    'T. gondii': r'$\it{T. gondii}$',
    'H.pylori': r'$\it{H. pylori}$',
    'H. pylori': r'$\it{H. pylori}$',
    'C. trachomatis': r'$\it{C. trachomatis}$',
}

# Hyphens are dropped, presumably so abbreviations like 'HPV-16' match the
# hyphen-free names in `ordered_orgs` (e.g. 'HPV16'); NaN stays NaN.
lw['corr_abbrev'] = (
    lw['Abbrev']
    .replace(ITALIC_ORG_NAMES)
    .str.replace('-', '', regex = False)
)

### Mark significant tests

In [None]:
# Significance: adjusted p strictly below the source-specific threshold
# (NaN p-values compare False, i.e. not significant).
lw['ukb_sig'] = lw['ukb_p'] < UKB_THRESH
lw['tnx_sig'] = lw['tnx_p'] < TNX_THRESH

# Direction: an odds ratio above OR_THRESH marks a risk association;
# anything else (including NaN) is treated as not-risk.
OR_THRESH = 1
lw['ukb_OR_risk'] = lw['ukb_OR'] > OR_THRESH
lw['tnx_OR_risk'] = lw['tnx_OR'] > OR_THRESH

# UKB bars are always drawn.
lw['ukb_show'] = True

# A TNX bar is drawn only when UKB is significant and the pair was not
# marked as impossible to replicate ('could_not').
lw['tnx_show'] = lw['ukb_sig'] & lw['rep_stat'].ne('could_not')

# Tier 1 Hepatitis Horizontal Funnel 

In [None]:
# One colour per data source; each is exposed under a long and a short name.
UKB_COLOR = UKB_COL = '#5b9bd5'  # UK Biobank
TNX_COLOR = TNX_COL = '#f4b183'  # TriNetX

In [None]:
# Left-to-right order of pathogens on the x-axis, grouped by lineage.
# HIV keeps its plain 'HIV' label here; the plotting cell adds an extra marker
# next to its tick label instead.
_ORG_DISPLAY_ORDER = [
    # Bacteria
    r'$\it{C. trachomatis}$',
    r'$\it{H. pylori}$',
    # Alveolata
    r'$\it{T. gondii}$',
    # Riboviria - HBV more closely related to HIV and HTLV-1 than HCV at least by phylogeny
    'HBV',
    'HCV',
    'HIV',
    'HTLV1',
    # Papovaviricetes
    'BKV',
    'JCV',
    'MCV',
    'HPV16',
    'HPV18',
    # Herpes
    'CMV',
    'EBV',
    'HSV1',
    'HSV2',
    'HHV6',
    'HHV7',
    'VZV',
    'KSHV/HHV8',
]

# (label, position) pairs, plus the bare label list used by the plot loop.
ordered_orgs = [(label, pos) for pos, label in enumerate(_ORG_DISPLAY_ORDER)]
ordered_org_ls = [label for label, _pos in ordered_orgs]

In [None]:
# prep data ####
# Configuration for the Tier 1 viral hepatitis panel.
top_tier = "Tier 1"

out_dir = '../manuscript/figures/fig_2'

top_dis = 'B19'
top_title = 'Unspecified viral hepatitis [B19]'
fn = f"{out_dir}/fig_2_viral_hepatitis_B19_pub.pdf"

input_df = lw.copy(deep = True)

# Rows for the single disease being plotted.
curr_dis_dat = input_df.loc[input_df['icd'] == top_dis, :].copy(deep = True)

# Long format: one row per (organism, metric); the plot loop looks each
# metric up by its 'variable' name.
melty = curr_dis_dat.melt(id_vars = ['icd', 'dis_abbrev_lab', 'Organism',
                                 'Abbrev', 'corr_abbrev', 'std_lev'],
        value_vars = ['ukb_p', 'tnx_p',
                      'neg_log_ukb', 'neg_log_tnx',
                      'ukb_OR', 'tnx_OR',
                      'ukb_sig', 'tnx_sig',
                      'ukb_OR_risk', 'tnx_OR_risk', 'ukb_show', 'tnx_show']
        )


###################################################
#                                                 #
#    Fix issue where HIV not tested in Tier 1's   #
#                                                 #
###################################################

# Placeholder value for each melted variable of a synthetic, non-significant
# HIV entry (p = 1, -log10 p = 0, OR = 1, no flags, shown only on UKB).
# Keys are exact variable names, so 'ukb_OR_risk' cannot be mistaken for an
# '_OR' variable by a substring test.
HIV_PLACEHOLDER_VALUES = {
    'ukb_p': 1, 'tnx_p': 1,
    'neg_log_ukb': 0, 'neg_log_tnx': 0,
    'ukb_OR': 1, 'tnx_OR': 1,
    'ukb_sig': False, 'tnx_sig': False,
    'ukb_OR_risk': False, 'tnx_OR_risk': False,
    'ukb_show': True, 'tnx_show': False,
}

hiv_missing = 'hiv' not in set(curr_dis_dat['org'])

# Only synthesise HIV rows when the disease has none, so real HIV results are
# never duplicated and the bar counts below stay consistent with the data.
if hiv_missing:
    curr_top_row = melty.iloc[0]
    for curr_var in melty.loc[:, 'variable'].unique().tolist():
        melty.loc[len(melty)] = [
            curr_top_row['icd'], curr_top_row['dis_abbrev_lab'],
            'Human Immunodeficiency Virus', 'HIV', 'HIV', 'exp_neg',
            curr_var, HIV_PLACEHOLDER_VALUES[curr_var],
        ]

top_dat = melty.copy()

# Number of bars per panel, used to size the gridspec columns; +1 reserves a
# slot for the HIV placeholder.
# NOTE(review): the placeholder has tnx_show=False, yet the TNX panel also
# gets +1 column; kept to preserve the existing layout - confirm intent.
top_ukb_n = int((curr_dis_dat['ukb_show'] == True).sum()) + int(hiv_missing)
top_tnx_n = int((curr_dis_dat['tnx_show'] == True).sum()) + int(hiv_missing)

In [None]:
# Figure Settings ####
from matplotlib import gridspec

# Canvas size in inches.
WIDTH, HEIGHT = 14, 4
fig = plt.figure(figsize = (WIDTH, HEIGHT), facecolor = 'white')

# Funnel Settings #####
EDGE_WIDTH = 1
UKB_BAR_WIDTH, TNX_BAR_WIDTH = 0.75, 0.65

spacer = 1          # empty gridspec columns between the two panels
X_MARGIN = 0.01
HIV_TICK_TEXT = r"$\boxplus$"

# Y-axis tick positions and their string labels.
top_ukb_ticks = list(range(6))
top_ukb_labs = [str(tick) for tick in top_ukb_ticks]
top_tnx_ticks = list(range(0, 301, 100))
top_tnx_labs = [str(tick) for tick in top_tnx_ticks]

# Setup top and plot ####
# One gridspec column per bar: UKB panel, spacer, then TNX panel.
top_tot_cols = top_ukb_n + top_tnx_n + spacer

gs = fig.add_gridspec(nrows = 1, ncols = top_tot_cols,
                      wspace = 2, hspace = 0.65)

tnx_first_col = top_ukb_n + spacer
top_ukb_ax = fig.add_subplot(gs[0, :top_ukb_n])
top_tnx_ax = fig.add_subplot(gs[0, tnx_first_col:tnx_first_col + top_tnx_n])

org_ls = input_df['Abbrev'].unique().tolist()

PROTECT_HATCH = None

# Bold star fonts marking significant bars, one per source colour.
UKB_SIG_STAR_FONT_DICT, TNX_SIG_STAR_FONT_DICT = (
    {'family': 'DejaVu Sans', 'color': star_color, 'weight': 'bold', 'size': 14}
    for star_color in (UKB_COL, TNX_COL)
)

STAR_BUMP = 1.02    # star y = bar height * STAR_BUMP

# Per-bar properties collected by the plot loop, one list each.
(ukb_x_ls, ukb_h_ls, ukb_w_ls, ukb_fc_ls,
 ukb_ec_ls, ukb_lw_ls, ukb_hatch_ls) = ([] for _ in range(7))

(tnx_x_ls, tnx_h_ls, tnx_w_ls, tnx_fc_ls,
 tnx_ec_ls, tnx_lw_ls, tnx_hatch_ls) = ([] for _ in range(7))

# Top Plot ####

def _first_value(org_dat, var_name):
    """Return the first melted 'value' whose 'variable' equals var_name."""
    return org_dat.loc[org_dat['variable'] == var_name, 'value'].values[0]


def _is_true(org_dat, var_name):
    """Return whether the melted flag var_name compares equal to True."""
    return bool(_first_value(org_dat, var_name) == True)


def _funnel_bar_style(color, is_sig, is_risk, protect_hatch):
    """Return (facecolor, edgecolor, hatch) for one funnel bar.

    Significant bars are filled with `color` and have no edge; non-significant
    bars are outlined in `color`. `protect_hatch` is applied only to
    non-significant bars whose OR risk flag is False, for both data sources.
    """
    if is_sig:
        return color, 'none', None
    return 'none', color, (None if is_risk else protect_hatch)


# Per-source config: melted-variable prefix, target axes, colour, bar width,
# star font, and the per-bar property lists filled below.
_FUNNEL_SOURCES = (
    ('ukb', top_ukb_ax, UKB_COL, UKB_BAR_WIDTH, UKB_SIG_STAR_FONT_DICT,
     (ukb_x_ls, ukb_h_ls, ukb_w_ls, ukb_fc_ls, ukb_ec_ls, ukb_lw_ls, ukb_hatch_ls)),
    ('tnx', top_tnx_ax, TNX_COL, TNX_BAR_WIDTH, TNX_SIG_STAR_FONT_DICT,
     (tnx_x_ls, tnx_h_ls, tnx_w_ls, tnx_fc_ls, tnx_ec_ls, tnx_lw_ls, tnx_hatch_ls)),
)

for curr_org in tqdm(ordered_org_ls, desc = top_dis):

    curr_org_dat = top_dat.loc[top_dat['corr_abbrev'] == curr_org, :]

    # A pathogen with no rows for this disease simply gets no bar, rather than
    # raising IndexError on .values[0].
    if curr_org_dat.empty:
        continue

    for (src, ax, color, bar_width, star_font,
         (x_ls, h_ls, w_ls, fc_ls, ec_ls, lw_ls, hatch_ls)) in _FUNNEL_SOURCES:

        if not _is_true(curr_org_dat, f'{src}_show'):
            continue

        # Bar height is -log10(adjusted p); a missing value plots as 0.
        height = _first_value(curr_org_dat, f'neg_log_{src}')
        if pd.isna(height):
            height = 0

        is_sig = _is_true(curr_org_dat, f'{src}_sig')
        is_risk = _is_true(curr_org_dat, f'{src}_OR_risk')
        face, edge, hatch = _funnel_bar_style(color, is_sig, is_risk, PROTECT_HATCH)

        x_ls.append(curr_org)
        h_ls.append(height)
        w_ls.append(bar_width)
        fc_ls.append(face)
        ec_ls.append(edge)
        lw_ls.append(EDGE_WIDTH)
        hatch_ls.append(hatch)

        # Star just above each significant bar.
        if is_sig:
            ax.text(x = curr_org,
                    y = height * STAR_BUMP,
                    s = '*',
                    ha = 'center',
                    fontdict = star_font)

top_ukb_ax.bar(x = ukb_x_ls, height = ukb_h_ls, width = ukb_w_ls,
               color = ukb_fc_ls,  ec = ukb_ec_ls, linewidth = ukb_lw_ls,
               hatch = ukb_hatch_ls, align = 'center')

top_tnx_ax.bar(x = tnx_x_ls, height = tnx_h_ls, width = tnx_w_ls,
               color = tnx_fc_ls,  ec = tnx_ec_ls, linewidth = tnx_lw_ls,
               hatch = tnx_hatch_ls, align = 'center')



                

def _dejavu_font(color, size, weight = 'normal'):
    """Build a matplotlib fontdict using DejaVu Sans."""
    return {'family': 'DejaVu Sans', 'color': color, 'weight': weight, 'size': size}


# Panel-header fonts, coloured by data source.
ukb_label_font = _dejavu_font(UKB_COL, 16)
tnx_label_font = _dejavu_font(TNX_COL, 16)

ylabel_font = _dejavu_font('black', 12)

# Disease title (figure suptitle) placement.
dis_label_x, dis_label_y = 0.1, 1.02
dis_label_ha, dis_label_va = 'left', 'center'
dis_label_size = 20
dis_label_font = _dejavu_font('black', 14)

# X tick labels.
x_tick_label, x_tick_rotation, x_tick_interval = 13, 90, 1
x_tick_ha, x_tick_weight = 'center', 'normal'

# Y tick labels.
y_tick_label, y_tick_rotation, y_tick_interval = 11, 0, 1

# Panel-header placement (axes coordinates).
test_label_x, test_label_y, test_label_ha = 0.0, 1.02, 'left'

# Extra marker drawn next to the HIV tick label.
HIV_TICK_OFFSET = -1.20
HIV_X_OFFSET = -0.05
hiv_tick_font = _dejavu_font('black', 14)

# Axis titles.
X_TITLE_Y_POS = -0.5
X_TITLE_FONT = _dejavu_font('black', 15)
Y_TITLE_FONT = _dejavu_font('black', 15)

# Panel headers naming each data source, placed in axes coordinates.
for _ax, _header, _font in (
    (top_ukb_ax, "UK Biobank: Discovery", ukb_label_font),
    (top_tnx_ax, "TriNetX: Replication", tnx_label_font),
):
    _ax.text(x=test_label_x, y=test_label_y, s=_header,
             va='bottom', ha=test_label_ha,
             transform=_ax.transAxes, fontdict=_font)

# X-axis titles are drawn as text (not set_xlabel) so both sit at the same
# height even though the UKB panel has longer organism tick labels.
UKB_X_AXIS_TITLE_TEXT = 'Pathogens with Serology Data in UKB'
TNX_X_AXIS_TITLE_TEXT = 'Pathogens with \nSignif. Result in UKB'

for _ax, _title in (
    (top_ukb_ax, UKB_X_AXIS_TITLE_TEXT),
    (top_tnx_ax, TNX_X_AXIS_TITLE_TEXT),
):
    _ax.text(x=0.5, y=X_TITLE_Y_POS, s=_title, ha='center',
             fontdict=X_TITLE_FONT, transform=_ax.transAxes)

# Shared y-axis title, shown on the left (UKB) panel only.
top_ukb_ax.set_ylabel('-Log10(FDR)', fontdict=Y_TITLE_FONT)


# Y-axis tick positions. `labels` is passed by keyword: in matplotlib < 3.5 the
# second parameter of Axis.set_ticks was `minor`, not `labels`. The
# FormatStrFormatter set below replaces these labels with '%0.1f' text.
top_ukb_ax.yaxis.set_ticks(top_ukb_ticks, labels = top_ukb_labs)
top_tnx_ax.yaxis.set_ticks(top_tnx_ticks, labels = top_tnx_labs)

for curr_ax in [top_ukb_ax, top_tnx_ax]:

    # Show one decimal on y tick labels.
    curr_ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%0.1f'))

    # Y-axis tick labeling
    curr_ax.tick_params(axis = 'y', which = 'both',
                       left = True, labelleft = True,
                       right = False, labelright = False,
                       labelsize = y_tick_label, rotation = y_tick_rotation)

    # X-axis tick labeling
    curr_ax.tick_params(axis = 'x', which = 'both',
                       bottom = True, labelbottom = True,
                       top = False, labeltop = False,
                       labelsize = x_tick_label, rotation = x_tick_rotation)

    # Set margins so bars on x-axis start closer to y-axis
    curr_ax.margins(x = X_MARGIN)

    # Draw an extra marker next to the HIV tick label, positioned from the
    # label's public (x, y) position.
    for curr_lab in curr_ax.get_xticklabels():
        if 'HIV' in curr_lab.get_text():
            hiv_x, hiv_y = curr_lab.get_position()
            curr_ax.text(x = hiv_x + HIV_X_OFFSET, y = hiv_y + HIV_TICK_OFFSET,
                         s = HIV_TICK_TEXT,
                         ha = 'center',
                         fontdict = hiv_tick_font)

    # Keep only the left and bottom spines; anchor the axes to the right.
    for side, visible in (('top', False), ('right', False),
                          ('bottom', True), ('left', True)):
        curr_ax.spines[side].set_visible(visible)
    curr_ax.set_anchor('E')

# The TNX panel uses a wider x margin than X_MARGIN.
top_tnx_ax.margins(x = 0.075)

# Disease labels
fig.suptitle(t = top_title,
             x = dis_label_x, y = dis_label_y,
             ha = dis_label_ha,
             va = dis_label_va,
             fontsize = dis_label_size)


In [None]:
# Write the figure as a PDF, creating the output directory first if needed.
os.makedirs(os.path.dirname(fn) or '.', exist_ok = True)
fig.savefig(fn, format = 'pdf', dpi = 600, bbox_inches="tight")