# Load Libraries

In [None]:
# Data manipulation
import numpy as np
import pandas as pd

# Data science
import math
import scipy.stats as stats
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from statsmodels.stats.multitest import multipletests as mt

# Plots
import matplotlib as mpl
import seaborn as sns
import matplotlib.pyplot as plt

# Working with dates
from datetime import date,datetime
import dateutil

# Looping  progress
from tqdm.notebook import tqdm

# Reg expressions
import re

# Pretty table printing
import tabulate

import os
import subprocess

# Misc libraries
from IPython.display import display, HTML
#from IPython.core.display import display, HTML

# Seaborn defaults: figure size, font scale, and the plain "white" style
sns.set(rc = {'figure.figsize': (11.7, 8.27)})
sns.set(font_scale = 1.5)
sns.set_style("white")

# Lift the pandas display limits so whole dataframes are shown
pd.set_option('display.max_rows', 10000)
pd.set_option('display.max_columns', 10000)
pd.set_option('display.max_colwidth', None)

# Report the pandas version in use so a run can be checked against
# the version this notebook is expected to work with.
print(f"Pandas version: {pd.__version__}")

# Prep Environment

In [None]:
# Widen the notebook container (removes the grey side bars)
notebook_css = "<style>.container { width:80% !important; }</style>"
display(HTML(notebook_css))

# The relative paths used below are resolved from this directory.
# NOTE: the path is relative, so re-running this cell moves the working
# directory again.
os.chdir('../../../../results/')

# Load in data

In [None]:
# Read the 'Results' sheet of supplemental dataset 2
dat = pd.read_excel(
    '../manuscript/latest/supplemental_datasets/supplemental_dataset_2.xlsx',
    sheet_name = 'Results',
)

In [None]:
# Map the presentation column headers of the spreadsheet back to the
# short, code-friendly column names used in the rest of the notebook.
human_to_computer_column_dict = {
    'Disease': 'disease',
    'ICD10': 'icd',
    'Organism': 'org',
    'Antibody': 'anti',
    'Pair is Associated': 'pair_is_associated',
    'Standard Level': 'std_lev',
    'Replication Status': 'rep_stat',
    'UKB adj p': 'ukb_per_dis_bh_fdr_corr_nom_p',
    'TNX adj p': 'tnx_per_dis_bh_fdr_corr_p',
    'UKB OR': 'ukb_OR',
    'TNX OR': 'tnx_OR',
    'UKB CI': 'ukb_anti_CI',
    'TNX CI': 'tnx_CI',
    'UKB nCase': 'ukb_nCase',
    'UKB nControl': 'ukb_nControl',
    'TNX nCase': 'tnx_nCase',
    'TNX nControl': 'tnx_nControl',
}

# Columns without an entry in the mapping keep their current name.
dat = dat.rename(columns = human_to_computer_column_dict)

# Map the display names of the organisms back to their code-friendly keys
human_to_computer_org_dict = {
    'BKV': 'bkv',
    'C. trachomatis': 'c_trach',
    'CMV': 'cmv',
    'EBV': 'ebv',
    'H. pylori': 'h_pylor',
    'HBV': 'hbv',
    'HCV': 'hcv',
    'HHV-6': 'hhv_6',
    'HHV-7': 'hhv_7',
    'HIV': 'hiv',
    'HPV-16': 'hpv_16',
    'HPV-18': 'hpv_18',
    'HSV-1': 'hsv_1',
    'HSV-2': 'hsv_2',
    'HTLV-1': 'htlv',
    'JCV': 'jcv',
    'KSHV': 'kshv',
    'MCV': 'mcv',
    'T. gondii': 't_gond',
    'VZV': 'vzv',
}

# Values that are not keys of the mapping are left untouched by replace().
org_display_names = dat.loc[:, 'org']
dat.loc[:, 'org'] = org_display_names.replace(human_to_computer_org_dict)


# Standard level: only these two display values are recoded; any other
# value in the column passes through replace() unchanged.
human_to_computer_std_lev_dict = {'Unknown': 'unk',
                                  'Exp. Negative': 'exp_neg'}
std_lev_display = dat.loc[:, 'std_lev']
dat.loc[:, 'std_lev'] = std_lev_display.replace(human_to_computer_std_lev_dict)

# Replication status: display text -> short code
human_to_computer_rep_dict = {'Did not attempt': 'did_not_attempt',
                              'Replicated': 'replicated',
                              'Failed Replication': 'did_not',
                              'Could not attempt': 'could_not'}
rep_stat_display = dat.loc[:, 'rep_stat']
dat.loc[:, 'rep_stat'] = rep_stat_display.replace(human_to_computer_rep_dict)

In [None]:
# Short codes used for the replication status in the metrics/plots
rep_stat_plot_codes = {
    'did_not_attempt': 'DNA',
    'replicated': 'REP',
    'did_not': 'DNR',
    'could_not': 'CNR',
}

# Plot names for the two recoded standard levels; other values are kept
std_lev_plot_names = {
    'exp_neg': 'Exp Neg',
    'unk': 'Unknown',
}

dat['pl_rep_stat'] = dat['rep_stat'].replace(rep_stat_plot_codes)
dat['pl_std_lev'] = dat['std_lev'].replace(std_lev_plot_names)

## Collect metrics

In [None]:
# Adjusted p-value cutoff below which a UKB test counts as significant
UKB_THRESH = 0.3

# Column layout of the long-format metrics table
MET_COL_LS = ['Result', 'Metric', 'Group', 'Value']
rep_ls = ['REP', 'DNR', 'DNA', 'CNR']
grp_ls = ['Tier 1', 'Tier 2', 'Exp Neg', 'Unknown']

# Total ORG tests, Count: one row for all tests, then one row per group
total_count_rows = [['Total ORG tests', 'Count', 'Total', len(dat)]]
total_count_rows.extend(
    ['Total ORG tests', 'Count', lev_name,
     len(dat.loc[dat['pl_std_lev'] == lev_name, :])]
    for lev_name in ['Tier 1', 'Tier 2', 'Exp Neg', 'Unknown']
)

mets = pd.DataFrame(total_count_rows, columns = MET_COL_LS)


# UKB Sig, Count: number of tests whose UKB adjusted p-value is below
# UKB_THRESH, overall and per standard-level group.
ukb_is_sig = dat['ukb_per_dis_bh_fdr_corr_nom_p'] < UKB_THRESH

ukb_sig_rows = [['UKB Sig', 'Count', 'Total', len(dat.loc[ukb_is_sig, :])]]
ukb_sig_rows.extend(
    ['UKB Sig', 'Count', lev_name,
     len(dat.loc[((dat['pl_std_lev'] == lev_name) & ukb_is_sig), :])]
    for lev_name in ['Tier 1', 'Tier 2', 'Exp Neg', 'Unknown']
)

# DataFrame.append was removed in pandas 2.0; pd.concat is the supported
# replacement. As with append's default, the row labels of the two frames
# are kept as they are (no re-indexing).
mets = pd.concat([mets, pd.DataFrame(ukb_sig_rows, columns = MET_COL_LS)])


grp_ls = ['Tier 1', 'Tier 2', 'Exp Neg', 'Unknown']

# UKB Sig, Overlap / Percent: per group, the share of all tests that were
# UKB significant, as a "sig | total" string and as a fraction.
over_pct_ls = []
for curr_grp in grp_ls:
    grp_counts = mets.loc[((mets['Metric'] == 'Count') &
                           (mets['Group'] == curr_grp)), :]

    tot_tests = grp_counts.loc[grp_counts['Result'] == 'Total ORG tests',
                               'Value'].tolist()[0]
    sig_tests = grp_counts.loc[grp_counts['Result'] == 'UKB Sig',
                               'Value'].tolist()[0]

    curr_overlap = f"{sig_tests} | {tot_tests}"

    # A group without any tests has no defined percentage; store NaN
    # rather than stopping on a ZeroDivisionError.
    curr_percent = sig_tests / tot_tests if tot_tests else float('nan')

    over_pct_ls.append(['UKB Sig', 'Overlap', curr_grp, curr_overlap])
    over_pct_ls.append(['UKB Sig', 'Percent', curr_grp, curr_percent])

# DataFrame.append was removed in pandas 2.0; pd.concat keeps the same
# rows and row labels that append produced.
mets = pd.concat([mets, pd.DataFrame(over_pct_ls, columns = MET_COL_LS)])

def _mets_count(result, group):
    """Return the 'Count' value already stored in ``mets`` for one
    result/group pair (raises IndexError if no such row exists)."""
    picked = mets.loc[((mets['Result'] == result) &
                       (mets['Metric'] == 'Count') &
                       (mets['Group'] == group)), 'Value']
    return picked.tolist()[0]


# For every replication status, record Count / Overlap / Percent for the
# 'Total' group first and then for each standard-level group.
met_ls = []
for curr_rep in rep_ls:

    curr_dat = dat.loc[dat['pl_rep_stat'] == curr_rep, :]

    # For did not attempt, the denominator is all ORG tests, not just the
    # UKB-significant ones.
    denom_result = 'Total ORG tests' if curr_rep == 'DNA' else 'UKB Sig'

    for curr_grp in ['Total'] + grp_ls:

        if curr_grp == 'Total':
            curr_count = len(curr_dat)
        else:
            curr_count = len(curr_dat.loc[curr_dat['pl_std_lev'] == curr_grp, :])

        denom = _mets_count(denom_result, curr_grp)

        curr_overlap = f"{curr_count} | {denom}"

        # A zero denominator (e.g. a group with no UKB-significant tests)
        # has no defined percentage; store NaN instead of raising
        # ZeroDivisionError.
        curr_percent = curr_count / denom if denom else float('nan')

        met_ls.append([curr_rep, 'Count', curr_grp, curr_count])
        met_ls.append([curr_rep, 'Overlap', curr_grp, curr_overlap])
        met_ls.append([curr_rep, 'Percent', curr_grp, curr_percent])

# DataFrame.append was removed in pandas 2.0; pd.concat keeps the same
# rows and row labels that append produced.
mets = pd.concat([mets, pd.DataFrame(met_ls, columns = MET_COL_LS)])

# From here on `res` holds a copy of the per-pair results and `dat`
# holds a copy of the metrics table.
res, dat = dat.copy(deep = True), mets.copy(deep = True)

In [None]:
# Express the 'Percent' rows as percentages (0-100) rather than fractions.
# NOTE: this rescales in place, so running the cell twice multiplies twice.
is_percent_row = dat['Metric'] == 'Percent'
dat.loc[is_percent_row, 'Value'] = dat.loc[is_percent_row, 'Value'] * 100

In [None]:
# Custom sort order for the groups: each group name maps to its rank
sort_dict = {
    grp_name: rank
    for rank, grp_name in enumerate(['Tier 1', 'Tier 2', 'Unknown', 'Exp Neg'])
}

In [None]:
# Total number of tests per group, as a {group: count} dict
is_total_count = ((dat['Result'] == 'Total ORG tests') &
                  (dat['Metric'] == 'Count'))

tmp = dat.loc[is_total_count, ['Group', 'Value']].set_index('Group')

tots = tmp['Value'].to_dict()

# tots: {'Total': 8616, 'Tier 1': 8, 'Tier 2': 83, 'Exp Neg': 88, 'Unknown': 8437}

## Split data into UKB Res and TNX res

In [None]:
# UKB rows ('UKB Sig', 'DNA'): measured against the UKB tests.
# TNX rows ('CNR', 'DNR', 'REP'): replication outcomes of the UKB tests.
is_ukb_result = dat['Result'].isin(['UKB Sig', 'DNA'])
is_tnx_result = dat['Result'].isin(['CNR', 'DNR', 'REP'])

ukb = dat[is_ukb_result]
tnx = dat[is_tnx_result]

# Build the plot

In [None]:
def _group_percent_rows(frame):
    """Keep the per-group 'Percent' rows of ``frame`` (no 'Total' group,
    no 'Total ORG tests' result) and order them by ``sort_dict``."""
    keep = ((frame['Metric'] == 'Percent') &
            (frame['Group'] != "Total") &
            (frame['Result'] != "Total ORG tests"))
    return frame.loc[keep, :].sort_values(by = ['Group'],
                                          key = lambda grp: grp.map(sort_dict))


ukb_percs = _group_percent_rows(ukb)
tnx_percs = _group_percent_rows(tnx)


## Split into our 2 categories

In [None]:
# Separate the UKB percent rows by result type
ukb_dna = ukb_percs[ukb_percs['Result'] == 'DNA']
ukb_sig = ukb_percs[ukb_percs['Result'] == 'UKB Sig']

## Build the data labels

### UKB

In [None]:
# Only the 'UKB Sig' result is labelled in the plot
plt_order = ['UKB Sig']
ukb_x_labs = ukb_percs['Group'].unique().tolist()

# Rows holding the "numerator | denominator" strings
ukb_overlaps = ukb[ukb['Metric'] == 'Overlap']

# One label per (result, group), in plot order, with thousands separators
# added to both numbers.
ukb_data_labs = []
for res_name in plt_order:
    for grp_name in ukb_x_labs:
        raw_overlap = ukb_overlaps[((ukb_overlaps['Result'] == res_name) &
                                    (ukb_overlaps['Group'] == grp_name))]['Value'].values[0]

        overlap_parts = raw_overlap.split("|")
        ukb_data_labs.append(
            f"{int(overlap_parts[0]):,} | {int(overlap_parts[1]):,}"
        )

### TNX

In [None]:
# TNX labels: only the replicated ('REP') share is plotted
plt_order = ['REP']

tnx_x_labs = tnx_percs['Group'].unique().tolist()

tnx_counts = tnx.loc[tnx['Metric'] == 'Count', :]

# Per group: the percent of attempted replications (REP + DNR) that
# replicated, and its "replicated | attempted" label with thousands
# separators.
tnx_rep_rows = []
tnx_data_labs = []
for curr_x in tnx_x_labs:
    grp_counts = tnx_counts.loc[tnx_counts['Group'] == curr_x, :]

    n_dnr = int(grp_counts.loc[grp_counts['Result'] == 'DNR', 'Value'].values[0])
    n_rep = int(grp_counts.loc[grp_counts['Result'] == 'REP', 'Value'].values[0])
    n_attempted = n_dnr + n_rep

    # With no attempted replication in a group the percentage is undefined
    # (the original code raised ZeroDivisionError here). Use a zero-height
    # bar; the labelling loop in the plot cell skips zero-height bars.
    rep_perc = (n_rep / n_attempted) * 100 if n_attempted else 0.0

    tnx_data_labs.append(f"{n_rep:,} | {n_attempted:,}")
    tnx_rep_rows.append(['REP', 'Percent', curr_x, rep_perc])

# Build the frame once instead of growing it row by row
tnx_rep = pd.DataFrame(tnx_rep_rows,
                       columns = ['Result', 'Metric', 'Group', 'Value'])

## Build the plot

In [None]:
# Plot
import matplotlib.patches as patches
from matplotlib.patches import FancyArrowPatch as Arrow
import matplotlib.lines as lines

# Figure size passed to plt.figure
FIG_W = 14
FIG_H = 6

# Bar fill colours of the two panels
UKB_COLOR = '#5b9bd5'
TNX_COLOR = '#f4b183'

# Font sizes for bar labels and tick labels
TEXT_LABEL_FONT_SIZE = 13
X_TICK_LABEL_FONT_SIZE = 13
Y_TICK_LABEL_FONT_SIZE = 13

# Dashed outline / connector styling
DASHED_LWD = 4
DASHED_LINE_ALPHA = 0.95

# Bar widths
BAR_W = 0.95
TNX_BAR_W = 0.95

# One row with two panels: UKB on the left, TNX on the right
fig = plt.figure(figsize = (FIG_W, FIG_H), facecolor = 'white')
gs = fig.add_gridspec(nrows = 1, ncols = 2, wspace = 0.50)
ukb_ax, tnx_ax = (fig.add_subplot(panel_spec)
                  for panel_spec in (gs[0, 0:1], gs[0, 1:]))



# Styling shared by both panels: bars start at zero and get a thick
# white edge.
bar_style = dict(bottom = 0, edgecolor = 'white', linewidth = 5)

# Left panel: percent of tests that were UKB significant, per group
ukb_ax.bar(ukb_x_labs,
           height = ukb_sig['Value'],
           width = BAR_W,
           color = UKB_COLOR,
           label = 'UKB Significant',
           **bar_style)

# Right panel: percent replicated, per group
tnx_ax.bar(tnx_x_labs,
           height = tnx_rep['Value'],
           width = TNX_BAR_W,
           color = TNX_COLOR,
           label = 'Replicated',
           **bar_style)


def _label_bars(ax, rects, labels, y_pad = 2):
    """Write one text label above each bar.

    ax     : axes the labels are drawn on
    rects  : bar rectangles, paired in order with ``labels``
    labels : label strings
    y_pad  : vertical gap, in data units, between bar top and label
    """
    for curr_rect, curr_lab in zip(rects, labels):
        curr_height = curr_rect.get_height()

        # A bar of height 0 doesn't actually show up in the plot, so it
        # gets no label.
        if curr_height == 0:
            continue

        # Centre the label horizontally on the bar and anchor its bottom
        # edge y_pad above the bar top.
        curr_x = curr_rect.get_x() + (curr_rect.get_width() / 2)
        curr_y = curr_height + y_pad

        ax.text(curr_x, curr_y, curr_lab,
                color = 'black',
                ha = "center", va = "bottom",
                fontsize = TEXT_LABEL_FONT_SIZE)


# get our rectangles
ukb_rects = ukb_ax.patches
tnx_rects = tnx_ax.patches

# Add the data labels
# https://stackoverflow.com/a/28931750
_label_bars(ukb_ax, ukb_rects, ukb_data_labs)
_label_bars(tnx_ax, tnx_rects, tnx_data_labs)
                
# Same cosmetics on both panels: no top/right spines, no x tick marks,
# outward y ticks on the left only.
for curr_ax in [ukb_ax, tnx_ax]:
    for spine_side in ('top', 'right'):
        curr_ax.spines[spine_side].set_visible(False)

    curr_ax.tick_params(axis = "x",
                        bottom = False, top = False,
                        labelbottom = True, labeltop = False,
                        labelsize = X_TICK_LABEL_FONT_SIZE)

    curr_ax.tick_params(axis = "y",
                        left = True, right = False,
                        labelleft = True, labelright = False,
                        labelsize = Y_TICK_LABEL_FONT_SIZE,
                        direction = 'out')

# Draw the dashed green outline around the signif bars in UKB.
# The outline is a staircase following the bar tops: bar i sits at
# x = i on the categorical axis, and each step spans i - 0.5 .. i + 0.5.
heights = [p.get_height() for p in ukb_ax.patches]

# Lift the outline slightly above the bar tops
TOP_FUDGE = 1.005

# Build the path from however many bars were drawn instead of assuming
# exactly four of them.
outline_pts = [[-0.5, 0]]
for bar_idx, bar_height in enumerate(heights):
    step_top = bar_height * TOP_FUDGE
    outline_pts.append([bar_idx - 0.5, step_top])
    outline_pts.append([bar_idx + 0.5, step_top])
outline_pts.append([len(heights) - 0.5, 0])

coord_path = np.array(outline_pts)

path_poly = patches.Polygon(coord_path,
                            alpha = 1,
                            linestyle = "dashed",
                            linewidth = DASHED_LWD,
                            ec = '#55a868',
                            fc = 'None')

ukb_ax.add_patch(path_poly)



# Curved dashed connector from the UKB panel to the TNX panel, drawn as
# two patches (a dashed line without a head, plus a filled head).
# Using workaround from: https://github.com/matplotlib/matplotlib/issues/17284#issuecomment-772820638
# Coordinates are in ukb_ax axes units (see the transform below).
ARROW_ST_X = 0.4
ARROW_END_X = 1.25

ARROW_ST_Y = 0.75
ARROR_END_Y = 0.75  # (sic) existing name kept as-is

arrow_start = (ARROW_ST_X, ARROW_ST_Y)
arrow_end = (ARROW_END_X, ARROR_END_Y)

# Dashed body of the connector
arrow_line = Arrow(
    arrow_start, arrow_end,
    arrowstyle = '-',
    shrinkA = 0, shrinkB = 5,
    connectionstyle = "arc3, rad = -0.25",
    linestyle = "dashed",
    linewidth = DASHED_LWD,
    ec = '#55a868',
    fc = '#55a868',
    transform = ukb_ax.transAxes,
)

# Filled head on the same curve, drawn with zero line width
ar_style = patches.ArrowStyle.CurveFilledB(angleA = 0)
arrow_head = Arrow(
    arrow_start, arrow_end,
    arrowstyle = ar_style,
    shrinkA = 0, shrinkB = 0,
    connectionstyle = "arc3, rad = -0.25",
    linestyle = "solid",
    linewidth = 0,
    ec = None,
    fc = '#55a868', mutation_scale = 50,
    transform = ukb_ax.transAxes,
)

# Caption: horizontally midway between the two ends, 0.15 above the start
ar_text_x = ARROW_ST_X + ((ARROW_END_X - ARROW_ST_X) / 2)
ar_text_y = ARROW_ST_Y + .15
ar_text_va = 'center'
ar_text_ha = 'center'

ar_text_fd = {'family': 'DejaVu Sans',
              'color':  'black',
              'weight': 'normal',
              'size': 14}

ar_text = ukb_ax.text(x = ar_text_x,
                      y = ar_text_y,
                      s = "Test significant pairs\nfor replication",
                      ha = ar_text_ha,
                      va = ar_text_va,
                      fontdict = ar_text_fd,
                      transform = ukb_ax.transAxes)

# Both patches are attached to the figure rather than to an axes
fig.patches.extend([arrow_head, arrow_line])



# Position of the data-source titles, in axes coordinates of each panel
test_label_x = 0.0
test_label_y = 1.08
test_label_ha = 'left'

# The three font dicts differ only in colour or size
_base_font = {
    'family': 'DejaVu Sans',
    'color':  'black',
    'weight': 'normal',
    'size': 16,
}
ukb_label_font = {**_base_font, 'color': UKB_COLOR}
tnx_label_font = {**_base_font, 'color': TNX_COLOR}
Y_TITLE_FONT = {**_base_font, 'size': 15}

# Datasource labels, one per panel, coloured like that panel's bars
ukb_data_lab = ukb_ax.text(x = test_label_x, y = test_label_y,
                           s = "UK Biobank: Discovery",
                           va = 'bottom', ha = test_label_ha,
                           transform = ukb_ax.transAxes,
                           fontdict = ukb_label_font)

tnx_data_lab = tnx_ax.text(x = test_label_x, y = test_label_y,
                           s = "TriNetX: Replication",
                           va = 'bottom', ha = test_label_ha,
                           transform = tnx_ax.transAxes,
                           fontdict = tnx_label_font)

# Y-axis titles
for panel_ax, panel_ylabel in ((ukb_ax, '% Sig Disease-Pathogen of All Pairs Tested'),
                               (tnx_ax, '% UKB Sig Pairs with TNX Data')):
    panel_ax.set_ylabel(panel_ylabel, fontdict = Y_TITLE_FONT)


TIER_1_TEXT = 'Tier 1\nPositives'
TIER_2_TEXT = 'Tier 2\nPositives'
EXP_NEG_TEXT = 'Expected\nNegatives'

# Group name -> two-line tick label; groups not listed keep their text
_tick_text_by_group = {
    'Tier 1': TIER_1_TEXT,
    'Tier 2': TIER_2_TEXT,
    'Exp Neg': EXP_NEG_TEXT,
}


def _relabel_xticks(ax):
    """Swap the group names on the x tick labels of ``ax`` for their
    two-line display text, then pin the tick positions and re-apply the
    labels. Returns (tick label objects, tick locations)."""
    tick_labels = ax.get_xticklabels()
    for curr_tick in tick_labels:
        shown_text = curr_tick.get_text()
        if shown_text in _tick_text_by_group:
            curr_tick.set_text(_tick_text_by_group[shown_text])

    tick_locs = ax.get_xticks()
    ax.set_xticks(tick_locs)
    ax.set_xticklabels(tick_labels)
    return tick_labels, tick_locs


ukb_ticks, ukb_ticks_loc = _relabel_xticks(ukb_ax)
tnx_ticks, tnx_ticks_loc = _relabel_xticks(tnx_ax)

# Same y-range on both panels
for panel_ax in (ukb_ax, tnx_ax):
    panel_ax.set_ylim(0, 102)

In [None]:
# Write the figure as a PDF, cropped tightly around its contents
out_dir = '../manuscript/figures/fig_3'

# savefig does not create missing directories, so make sure the output
# directory exists first (no-op when it is already there).
os.makedirs(out_dir, exist_ok = True)

fn = f"{out_dir}/fig_3_pub.pdf"
fig.savefig(fn, format = 'pdf', dpi = 600, bbox_inches = "tight")