# Load Libraries

In [None]:
# Data manipulation
import numpy as np
import pandas as pd

# Data science
import math
import scipy.stats as stats
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from statsmodels.stats.multitest import multipletests as mt
import random

# Plots
import matplotlib as mpl
import matplotlib.font_manager as fm
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import ticker
from matplotlib import gridspec
from matplotlib.colors import to_hex
import colorcet as cc
from matplotlib.lines import Line2D

# Working with dates
from datetime import date,datetime
import dateutil

# Looping  progress
from tqdm.notebook import tqdm

# Reg expressions
import re

# Pretty table printing
import tabulate

import os
import subprocess

# Misc libraries
from IPython.display import display, HTML
#from IPython.core.display import display, HTML

# Seaborn theme: white style, 1.5x font scale and a default figure size, all
# applied in a single call.
sns.set(style="white", font_scale=1.5, rc={'figure.figsize': (11.7, 8.27)})

# Set Pandas options so we can see our entire dataframe
pd.options.display.max_rows = 10000
pd.options.display.max_columns = 10000
pd.options.display.max_colwidth = None

# Print package versions so we can confirm we are running the versions
# these figures were produced with.
print(f"Pandas version: {pd.__version__}")

# Directory holding custom fonts (e.g. Arial). Set the FONT_DIR environment
# variable on machines where the fonts live somewhere else.
font_dir = os.environ.get('FONT_DIR', '/users/lapt3u/.fonts')

# Register every font file found under font_dir with Matplotlib's font manager.
font_files = fm.findSystemFonts(fontpaths=[font_dir])
for font_file in font_files:
    fm.fontManager.addfont(font_file)

# Make Arial the default family only when it is actually registered;
# otherwise keep the theme's default family.
available_fonts = {font_entry.name for font_entry in fm.fontManager.ttflist}
if 'Arial' in available_fonts:
    plt.rcParams['font.family'] = 'Arial'
else:
    print(f"Arial not found (searched {font_dir!r}); keeping default font family")

print(plt.rcParams['font.family'])

# Prep Environment

In [None]:
# Project root. Set PATHOGEN_NCD_HOME to run the notebook from a different
# checkout; the default is the original location.
HOME_DIR = os.environ.get('PATHOGEN_NCD_HOME', "/data/pathogen_ncd")

# Every relative path read below (the results workbook, ../dicts/*.xlsx)
# is resolved from the results directory.
os.chdir(f'{HOME_DIR}/results')

# Data prep

In [None]:
# Merged results workbook; split out the rows whose rep_stat is
# 'replicated' and those whose rep_stat is 'could_not'.
all_res = pd.read_excel('final_res_with_merged_data_04_12_2023.xlsx')
res = all_res[all_res['rep_stat'].eq('replicated')]
cnr = all_res[all_res['rep_stat'].eq('could_not')]

## Get Abbreviated Disease Names

In [None]:
dis_dat = pd.read_excel('../dicts/dis_abbrev_dict.xlsx')


def _with_dis_abbrev(frame):
    """Left-join the disease abbreviation on 'icd' and add a '[ICD] abbrev' label column."""
    merged = frame.merge(dis_dat[['icd', 'dis_abbrev']], on='icd', how='left')
    merged['dis_abbrev_lab'] = '[' + merged['icd'] + '] ' + merged['dis_abbrev']
    return merged


# Same enrichment for replicated and could-not-replicate results.
res = _with_dis_abbrev(res)
cnr = _with_dis_abbrev(cnr)

## Get readable organism names

In [None]:
# Organism metadata; its 'Tag' column is joined against our 'org' column.
vir_dat = (
    pd.read_excel('../dicts/antigen_dict.xlsx')
      .loc[:, ['Tag', 'Organism', 'Family', 'Species', 'Baltimore']]
      .drop_duplicates()
)

# Left-join organism metadata onto both result sets.
res, cnr = (
    frame.merge(vir_dat, left_on='org', right_on='Tag', how='left')
    for frame in (res, cnr)
)

In [None]:
# Full organism name (antigen dictionary 'Organism' column) -> short label
# used on the figure axes.
simp_org_name = dict([
    ('Human Polyomavirus BKV', 'BKV'),
    ('Epstein-Barr Virus', 'EBV'),
    ('Human Herpesvirus-7', 'HHV-7'),
    ('Herpes Simplex virus-1', 'HSV-1'),
    ('Herpes Simplex virus-2', 'HSV-2'),
    ('Human Herpesvirus-6', 'HHV-6'),
    ("Kaposi's Sarcoma-Associated Herpesvirus", 'KSHV'),
    ('Human T-Lymphotropic Virus 1', 'HTLV-1'),
    ('Human Immunodeficiency Virus', 'HIV'),
    ('Varicella Zoster Virus', 'VZV'),
    ('Merkel Cell Polyomavirus', 'MCV'),
    ('Human Papillomavirus type-18', 'HPV-18'),
    ('Hepatitis C Virus', 'HCV'),
    ('Human Polyomavirus JCV', 'JCV'),
    ('Human Papillomavirus type-16', 'HPV-16'),
    ('Human Cytomegalovirus', 'CMV'),
    ('Hepatitis B Virus', 'HBV'),
])

# Heatmap of ICD-10 Blocks (risk/protective split)

In [None]:
# Assign every result to its ICD-10 chapter ("block") and a plot label.
# Only chapters A00-O99 are labelled; other codes (and malformed ones) are left
# missing, so they drop out of the groupby counts further down.
all_res['icd_cat'] = all_res['icd'].str[:1]
all_res['icd_site'] = all_res['icd'].str[1:].astype(str)

# (chapter letters, first site, last site, block, description).
# The site is the two digits right after the chapter letter, so 3-character
# codes ('D50') and 4-character codes ('D509') land in the same chapter.
ICD_CHAPTERS = [
    ('AB',  0, 99, 'A00-B99', '[A00-B99] Infectious'),
    ('C',   0, 99, 'C00-D49', '[C00-D49] Neoplasms'),
    ('D',   0, 49, 'C00-D49', '[C00-D49] Neoplasms'),
    ('D',  50, 99, 'D50-D89', '[D50-D89] Blood'),
    ('E',   0, 99, 'E00-E90', '[E00-E90] Endocrine, Nutritional, Metabolic'),
    ('F',   0, 99, 'F00-F99', '[F00-F99] Mental, Behavioral'),
    ('G',   0, 99, 'G00-G99', '[G00-G99] Nervous'),
    ('H',   0, 59, 'H00-H59', '[H00-H59] Eye'),
    ('H',  60, 99, 'H60-H95', '[H60-H95] Ear'),
    ('I',   0, 99, 'I00-I99', '[I00-I99] Circulatory'),
    ('J',   0, 99, 'J00-J99', '[J00-J99] Respiratory'),
    ('K',   0, 99, 'K00-K93', '[K00-K93] Digestive'),
    ('L',   0, 99, 'L00-L99', '[L00-L99] Skin, Subcutaneous'),
    ('M',   0, 99, 'M00-M99', '[M00-M99] Musculoskeletal'),
    ('N',   0, 99, 'N00-N99', '[N00-N99] Genitourinary'),
    ('O',   0, 99, 'O00-O99', '[O00-O99] Pregnancy, Childbirth'),
]


def icd_chapter(icd):
    """Return (block, description) for an ICD-10 code such as 'D50' or 'C509'.

    Returns (None, None) for missing or malformed codes and for codes whose
    chapter is not listed in ICD_CHAPTERS.
    """
    if not isinstance(icd, str) or len(icd) < 3 or not icd[1:3].isdecimal():
        return None, None
    letter, site = icd[0], int(icd[1:3])
    for letters, first_site, last_site, block, descr in ICD_CHAPTERS:
        if letter in letters and first_site <= site <= last_site:
            return block, descr
    return None, None


# Build both columns as object columns in one go. Writing strings into an
# all-NaN float column via .loc is deprecated in recent pandas.
icd_chapters = all_res['icd'].map(icd_chapter)
all_res['icd_block'] = [block for block, _ in icd_chapters]
all_res['icd_descr'] = [descr for _, descr in icd_chapters]

In [None]:
# Flag each association as risk (UKB OR > 1) vs protective; a missing OR
# compares False and so counts as not-risk.
all_res['is_risk'] = all_res['ukb_OR'] > 1

# Keep only the replicated and could-not-replicate results.
dat = all_res.loc[all_res['rep_stat'].isin(['could_not', 'replicated']), :]

# Number of associations per organism x ICD-10 chapter x replication status.
pl_dat = (
    dat.groupby(['org', 'icd_descr', 'rep_stat'])
       .size()
       .reset_index(name='count')
)

# Push in the nicer org names
pl_dat = pl_dat.merge(vir_dat, left_on='org', right_on='Tag', how='left')

# Use plain column assignment so the int cast really changes the dtype.
# Since pandas 2.0, `.loc[:, col] = ...` writes values in place and keeps float64.
pl_dat['Baltimore'] = pl_dat['Baltimore'].fillna(0).astype(int)

pl_dat['simple_name'] = pl_dat['Organism'].replace(simp_org_name)

In [None]:
# For every (org, ICD chapter) row of pl_dat, count the *replicated*
# associations that are risk (is_risk True) vs protective (is_risk False).
# Pairs with no replicated association get 0 / 0.
replicated = dat.loc[dat['rep_stat'] == 'replicated', ['org', 'icd_descr', 'is_risk']]
risk_protect = (
    replicated
    .assign(risk_cnt=replicated['is_risk'].eq(True).astype(int),
            protect_cnt=replicated['is_risk'].eq(False).astype(int))
    .groupby(['org', 'icd_descr'], as_index=False)[['risk_cnt', 'protect_cnt']]
    .sum()
)

# Drop any counts left over from a previous run so re-executing this cell
# does not produce duplicate columns.
pl_dat = (
    pl_dat.drop(columns=['risk_cnt', 'protect_cnt'], errors='ignore')
          .merge(risk_protect, on=['org', 'icd_descr'], how='left')
          .fillna({'risk_cnt': 0, 'protect_cnt': 0})
          .astype({'risk_cnt': int, 'protect_cnt': int})
)

# Non-viral organisms, identified by antigen tag, get italic species labels.
# Mathtext ignores literal spaces, so they are escaped with a backslash.
MICROBE_TAGS = ['c_trach', 'h_pylor', 't_gond']
MICROBE_LABELS = {
    'Toxoplasma gondii': r'$\it{T.\ gondii}$',
    'Helicobacter pylori': r'$\it{H.\ pylori}$',
    'Chlamydia trachomatis': r'$\it{C.\ trachomatis}$',
}
is_microbe = pl_dat['Tag'].isin(MICROBE_TAGS)

# Everything that is not one of the microbe tags (including unmatched tags) is a virus.
viruses = pl_dat.loc[~is_microbe].copy()
viruses['formatted_name'] = viruses['simple_name']
viruses['org_type'] = 'virus'

microbes = pl_dat.loc[is_microbe].copy()
microbes['formatted_name'] = microbes['Organism'].map(MICROBE_LABELS)
unlabelled = microbes.loc[microbes['formatted_name'].isna(), 'Organism'].unique()
if len(unlabelled) > 0:
    raise ValueError(f'No italic label defined for microbe(s): {list(unlabelled)}')
microbes['org_type'] = 'microbe'

fin_pl = pd.concat([microbes, viruses])

# .copy(): later cells modify `rep` column-wise, which on a slice of fin_pl
# would raise SettingWithCopyWarning.
rep = fin_pl.loc[fin_pl['rep_stat'] == 'replicated', :].copy()

# Collect counts and order pathogens

In [None]:
# Per-organism totals across all ICD-10 chapters.
tot_df = (
    rep.groupby('org')
       .agg({'count': 'sum', 'risk_cnt': 'sum', 'protect_cnt': 'sum'})
       .reset_index(drop=False)
)

# One "All Diseases" row per organism. It starts as a copy of the organism's
# first replicated row (names, family, ...) with the chapter label and counts
# overwritten. DataFrame.append was removed in pandas 2.0, so rows are
# collected in a list and the frame is built once.
total_rows = []
for curr_org, curr_tot_cnt, curr_tot_risk, curr_tot_prot in (
        tot_df[['org', 'count', 'risk_cnt', 'protect_cnt']].itertuples(index=False)):
    curr_org_info = rep.loc[rep['org'] == curr_org, :].iloc[0].copy(deep=True)
    curr_org_info['icd_descr'] = '[A00 - O99] All Diseases'
    curr_org_info['count'] = curr_tot_cnt
    curr_org_info['risk_cnt'] = curr_tot_risk
    curr_org_info['protect_cnt'] = curr_tot_prot
    total_rows.append(curr_org_info)

tots = pd.DataFrame(total_rows, columns=rep.columns).reset_index(drop=True)

tots['neg_protect_cnt'] = tots['protect_cnt'] * -1

# Manually curated (label, rank) order of organisms, grouped taxonomically.
# NOTE(review): ordered_orgs / corr_x_names / ordered_orgs_df are not used by
# the plotting cells below, which rank organisms by ICD-chapter coverage
# instead. Some labels here (e.g. 'KSHV/HHV8') also differ from the
# formatted_name values ('KSHV').
ordered_orgs = [

                # Bacteria
                (r'$\it{C. trachomatis}$', 0),
                (r'$\it{H. pylori}$', 1),

                # Alveolata
                (r'$\it{T. gondii}$', 2),

                # Riboviria - HBV more closely related to HIV and HTLV-1 than HCV at least by phylogeny
                ('HBV', 3),
                ('HCV', 4),
                ('HIV', 5),
                ('HTLV1', 6),

                # Papovaviricetes
                ('BKV', 7),
                ('JCV', 8),
                ('MCV', 9),
                ('HPV16', 10),
                ('HPV18', 11),

                # Herpes
                ('CMV', 12),
                ('EBV', 13),
                ('HSV1', 14),
                ('HSV2', 15),
                ('HHV6', 16),
                ('HHV7', 17),
                ('VZV', 18),
                ('KSHV/HHV8', 19),

                ]

corr_x_names = [name for name, _ in ordered_orgs]
ordered_orgs_df = pd.DataFrame(ordered_orgs, columns=['formatted_name', 'enum_number'])

# Drop hyphens from organism labels ('HSV-1' -> 'HSV1'); literal match, not regex.
# assign() returns new frames, so no SettingWithCopyWarning on slices.
rep = rep.assign(formatted_name=rep['formatted_name'].str.replace('-', '', regex=False))
tots = tots.assign(formatted_name=tots['formatted_name'].str.replace('-', '', regex=False))


def _bold_icd_descr(label):
    """Wrap the description after '] ' in mathtext bold; leave the code range plain.

    Mathtext ignores literal spaces, so spaces in the description are escaped
    with a backslash. Labels that are already bolded, lack '] ', or are not
    strings are returned unchanged, which makes re-running the cell safe.
    """
    if not isinstance(label, str):
        return label
    code, sep, descr = label.partition('] ')
    if not sep or descr.startswith('$'):
        return label
    return code + sep + r'$\bf{' + descr.replace(' ', r'\ ') + '}$'


# Per Leah, bold the ICD description but not the code range
rep = rep.assign(icd_descr=rep['icd_descr'].map(_bold_icd_descr))
tots = tots.assign(icd_descr=tots['icd_descr'].map(_bold_icd_descr))


## Order organisms by number of ICD-10 chapters with replicated associations

In [None]:
# Also, feel like I'm coding in circles
# Probably is a one-liner to do this with pandas but my brain is fried
uniq_icds = rep.loc[:, 'icd_descr'].unique().tolist()
uniq_orgs = rep.loc[:, 'org'].unique().tolist()


order_ls = []
for curr_org in tqdm(uniq_orgs):
    curr_org_total = 0
    for curr_icd in uniq_icds:
        
        curr_dat = rep.loc[((rep['org'] == curr_org) & (rep['icd_descr'] == curr_icd)), 'count']
        
        if len(curr_dat) > 0:
            curr_val = curr_dat.values.tolist()[0]
        
            if curr_val > 0:
                curr_org_total = curr_org_total + 1
            
    
    order_ls.append([curr_org, curr_org_total])
    
    
order_df = pd.DataFrame(order_ls, columns = ['org', 'tot_icd_blocks'])
order_df = order_df.sort_values('tot_icd_blocks', ascending = False)

ordered_org_ls = order_df['org'].tolist()

fin_rep = pd.DataFrame(columns = rep.columns)
for curr_org in tqdm(ordered_org_ls):
    curr_dat = rep.loc[rep['org'] == curr_org, :]
    fin_rep = pd.concat([fin_rep, curr_dat], axis = 0)
    
    
rep = fin_rep.copy(deep = True)

new_order = pd.DataFrame(list(enumerate(rep.loc[:, 'org'].unique().tolist())), columns = ['enum_number', 'org'])
new_order = new_order.merge(fin_rep.loc[:, ['org', 'formatted_name']], how = 'left',
                            on = 'org')
new_order = new_order.drop('org', axis = 1)
new_order = new_order.drop_duplicates()

In [None]:
# plot
from matplotlib import gridspec

ALPHA = 1

rep_piv = rep.pivot(index = 'formatted_name', columns = 'icd_descr', values = 'count')
rep_piv = rep_piv.fillna(0)
rep_piv = rep_piv.astype(int)
rep_piv = rep_piv.merge(new_order, left_index = True, right_on = 'formatted_name', how = 'left')
rep_piv = rep_piv.set_index('formatted_name')
rep_piv = rep_piv.sort_values('enum_number')
rep_piv.columns.name = 'icd_descr'
rep_piv = rep_piv.drop('enum_number', axis = 1)

rep_piv_str = pd.DataFrame('', columns = rep_piv.columns, index = rep_piv.index)

rep_piv_combos = rep_piv_str.stack().reset_index().loc[:, 
                                           ['formatted_name', 'icd_descr']].values.tolist()

for curr_org, curr_dis in tqdm(rep_piv_combos):
    ret = rep.loc[((rep['formatted_name'] == curr_org) &
         (rep['icd_descr'] == curr_dis)), :]

    if len(ret) == 0:
        rep_piv_str.loc[curr_org, curr_dis] = '0'

    else:
        curr_risk = ret['risk_cnt'].sum()
        curr_prot = ret['protect_cnt'].sum()
        curr_tot  = ret['count'].sum()


        rep_piv_str.loc[curr_org, curr_dis] = f'{curr_tot}\n[{curr_prot}|{curr_risk}]'


fig = plt.figure(figsize=(15, 15), facecolor = 'white')
gs = gridspec.GridSpec(nrows = 1, ncols = 19)

heat_ax = fig.add_subplot(gs[:, :])

sns.heatmap(data = rep_piv, cmap = 'flare', mask = (rep_piv == 0),
               annot = rep_piv_str, fmt = '',
               linewidths = 2, linecolor = (0, 0, 0, ALPHA),
               vmin = 1, ax = heat_ax, cbar = False, annot_kws={'size': 15})

heat_ax.set_ylabel('')
heat_ax.set_xlabel('')



LEGEND_X = 1.25
LEGEND_Y = -0.5
LEGEND_TEXT = 'Replicated Results\n[Protective | Risk]'

LEGEND_FONT = {
        'family': 'Arial',
        'color':  'black',
        'weight': 'normal',
        'size': 15,
}

LEGEND_BOX = {
                'boxstyle' : 'square',
                'fc' : 'white',
                'ec' : 'black'
}

heat_ax.text(x = LEGEND_X, 
             y = LEGEND_Y, 
             s = LEGEND_TEXT,
             fontdict = LEGEND_FONT,
             ha = 'center', va = 'center'),
             #bbox = LEGEND_BOX)


In [None]:
# Write the figure as SVG into the manuscript figures folder, creating the
# folder first if it does not exist yet.
out_dir = f'{HOME_DIR}/manuscript/figures'
os.makedirs(out_dir, exist_ok=True)
fn = f"{out_dir}/Figure_S5.svg"
fig.savefig(fn, format='svg', dpi=600, bbox_inches="tight")