# Load Libraries

In [None]:
# Data manipulation
import numpy as np
import pandas as pd

# Data science
import math
import scipy.stats as stats
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from statsmodels.stats.multitest import multipletests as mt

# Plots
import matplotlib as mpl
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

# Working with dates
from datetime import date,datetime
import dateutil

# Progress bars for loops
from tqdm.notebook import tqdm

# Reg expressions
import re

# Pretty table printing
import tabulate

import os
import subprocess

# Misc libraries
from IPython.display import display, HTML
#from IPython.core.display import display, HTML

# Seaborn theme for every plot in this notebook: white background, fonts
# scaled 1.5x, and a default figure size of 11.7 x 8.27 inches.
sns.set(style = "white", font_scale = 1.5,
        rc = {'figure.figsize': (11.7, 8.27)})

# Let pandas print large dataframes in full (up to 10,000 rows/columns)
# without clipping long cell contents.
pd.set_option('display.max_rows', 10000,
              'display.max_columns', 10000,
              'display.max_colwidth', None)

# Report the pandas version so results can be tied to a known environment.
print(f"Pandas version: {pd.__version__}")

# Prep Environment

In [None]:
# Widen the Jupyter notebook container to 90% of the window, removing the
# wide grey margins on either side of the cells.
notebook_css = "<style>.container { width:90% !important; }</style>"
display(HTML(notebook_css))

# Work from the project's results/ directory; the relative paths used by the
# later cells are resolved from here.
results_dir = '../../../../results/'
os.chdir(results_dir)

# Make 4 panel plot

## Plotting functions

In [None]:
# Description:
# Generates a swarm plot for both case and control data
#
# Inputs:
# input_df: Input data that must have at least DISEASE_COL, ORG_COL
# input_case_ax: Axis object for case plot
# input_con_ax: Axis object for control plot
# DISEASE_COL: Column in input_df where disease status can be found, must
#              be encoded as 'case' or 'control'.
# ORG_COL: Column in input_df where VIRTUS normalized hit rates for a particular
#          pathogen can be found
# CASE_X_AXIS_LABEL: X-axis label to go under case plot
# CONS_X_AXIS_LABEL: X-axis label to go under control plot
# CON_COLOR: Color for control swarm plot markers (any Matplotlib color spec,
#            e.g. an RGB tuple, a hex string, or a named color)
# CASE_COLOR: Color for case swarm plot markers (any Matplotlib color spec)
# SWARM_MARKER_SIZE: How large the swarm plot markers should be.
# SWARM_ALPHA: Opacity of the swarm plot markers
# Y_BASELINE: Y-axis lower limit for plot
# Y_TICK_LABEL_FD: Font size to use for Y-tick labels
# X_TICK_LABEL_FD: Font size to use for label indicating if plot above
#                  is a case or control plot.
# ANNOT_FD: Font size to output Mann-Whitney P-value label in
# ANNOT_LW: Linewidth to draw lines showing what was compared with the
#           Mann-Whitney test
# DEBUG: Boolean flag whether to print debug messages.

# Outputs:
#   Returns list of axes
#     input_case_ax: Axes containing drawn swarmplot for case data
#     input_con_ax:  Axes containing drawn swarmplot for control data

from matplotlib.lines import Line2D
import matplotlib.transforms as transforms
from scipy.stats import mannwhitneyu as mwu
from scipy.stats import ttest_ind as ttest
from matplotlib.patches import Polygon as poly

def _style_swarm_ax(ax, y_baseline, y_tick_size, show_y_axis):
    """Apply the styling shared by the case and control swarm panels.

    Sets the lower y-limit to ``y_baseline``, hides x ticks/tick labels and
    both axis titles, and removes the top/right spines. The y-axis (left
    spine, ticks and labels) is kept only when ``show_y_axis`` is True.
    """
    ax.set_ylim(bottom = y_baseline)
    ax.tick_params(axis = "x",
                   top = False, labeltop = False,
                   bottom = False, labelbottom = False)
    ax.set_xlabel('')
    ax.set_ylabel('')

    hidden_spines = ['top', 'right'] if show_y_axis else ['top', 'right', 'left']
    ax.spines[hidden_spines].set_visible(False)

    ax.tick_params(axis = "y",
                   left = show_y_axis, labelleft = show_y_axis,
                   right = False, labelright = False,
                   labelsize = y_tick_size)


def gen_swarm_axs(input_df, input_case_ax, input_con_ax,
                 DISEASE_COL = 'sample_disease', ORG_COL = 'cmv',
                 CASE_X_AXIS_LABEL = 'SLE Cases', 
                 CONS_X_AXIS_LABEL = 'Controls',
                 CON_COLOR =  sns.color_palette()[0],
                 CASE_COLOR = sns.color_palette()[1],
                 SWARM_MARKER_SIZE = 4, SWARM_ALPHA = 0.7, 
                 Y_BASELINE = -0.0001,
                 Y_TICK_LABEL_FD = {'size' : 11},
                 X_TICK_LABEL_FD = {'size' : 11},   
                 ANNOT_FD = {'size' : 14},
                 ANNOT_LW = 1,
                 DEBUG = False):
    """Draw case/control swarm plots side by side with a Mann-Whitney bracket.

    Rows of ``input_df`` whose ``DISEASE_COL`` equals ``'case'`` are drawn on
    ``input_case_ax``; rows equal to ``'control'`` are drawn on
    ``input_con_ax``. Any other value is ignored. Each panel gets a label
    below it with the group name and sample count. A bracket above the pair
    shows the two-sided Mann-Whitney U p-value comparing ``ORG_COL`` between
    the groups.

    The labels, bracket and p-value are added to the figure that owns
    ``input_case_ax``. Both axes are assumed to be on that figure. The
    bracket sits at a fixed figure-fraction height (0.87), not relative to
    the axes.

    Parameters
    ----------
    input_df : DataFrame with at least ``DISEASE_COL`` and ``ORG_COL``.
    input_case_ax, input_con_ax : Axes to draw the case / control swarms on.
    DISEASE_COL : column holding ``'case'`` / ``'control'``.
    ORG_COL : column holding the numeric values to plot and test.
    CASE_X_AXIS_LABEL, CONS_X_AXIS_LABEL : group names shown under the panels.
    CON_COLOR, CASE_COLOR : marker face/edge colors.
    SWARM_MARKER_SIZE, SWARM_ALPHA : marker size and opacity.
    Y_BASELINE : lower y-limit applied to both panels.
    Y_TICK_LABEL_FD : font dict; only ``'size'`` is used, for y tick labels.
    X_TICK_LABEL_FD : font dict for the group labels under the panels.
    ANNOT_FD : font dict for the p-value text.
    ANNOT_LW : line width of the comparison bracket.
    DEBUG : print the computed annotation positions and p-value.

    Returns
    -------
    list
        ``[input_case_ax, input_con_ax]`` (the same axes objects, drawn on).

    Raises
    ------
    ValueError
        If ``input_df`` has no case rows or no control rows.
    """
    # Annotations use figure coordinates; take the figure that owns the axes
    # rather than relying on a notebook-global ``fig``.
    fig = input_case_ax.figure

    # Split the data by disease status
    cases = input_df.loc[input_df[DISEASE_COL] == 'case', :]
    cons = input_df.loc[input_df[DISEASE_COL] == 'control', :]

    if cases.empty or cons.empty:
        raise ValueError(
            f"gen_swarm_axs needs both 'case' and 'control' rows in "
            f"{DISEASE_COL!r}; got {len(cases)} cases and {len(cons)} controls")

    # Group labels with sample counts, e.g. "SLE Cases\nn = 42"
    case_label = f"{CASE_X_AXIS_LABEL}\nn = {len(cases)}"
    con_label = f"{CONS_X_AXIS_LABEL}\nn = {len(cons)}"

    # Draw BOTH swarms before fixing any y-limits. set_ylim() switches
    # y-autoscaling off (for sharey siblings too). If the limits were fixed
    # after only the case swarm, control points above the largest case value
    # could be clipped.
    sns.swarmplot(x = DISEASE_COL, y = ORG_COL,
                  data = cases,
                  color = CASE_COLOR,
                  dodge = True,
                  size = SWARM_MARKER_SIZE,
                  linewidth = 1,
                  edgecolor = CASE_COLOR,
                  alpha = SWARM_ALPHA,
                  ax = input_case_ax)
    sns.swarmplot(x = DISEASE_COL, y = ORG_COL,
                  data = cons,
                  color = CON_COLOR,
                  size = SWARM_MARKER_SIZE,
                  linewidth = 1,
                  edgecolor = CON_COLOR,
                  alpha = SWARM_ALPHA,
                  ax = input_con_ax)

    # The case panel keeps its y-axis; the control panel hides it.
    _style_swarm_ax(input_case_ax, Y_BASELINE, Y_TICK_LABEL_FD['size'], show_y_axis = True)
    _style_swarm_ax(input_con_ax, Y_BASELINE, Y_TICK_LABEL_FD['size'], show_y_axis = False)

    # Two-sided Mann-Whitney U test. Casting to float tolerates object-dtype
    # columns, e.g. ones produced by transposing a mixed-type table.
    mw_p = mwu(x = cases[ORG_COL].astype(float), y = cons[ORG_COL].astype(float),
               alternative = 'two-sided')[1]

    # Group labels centred just below each panel (axes-fraction coordinates)
    for ax, label in ((input_case_ax, case_label), (input_con_ax, con_label)):
        fig.text(0.5, -0.025, label,
                 ha = 'center', va = 'top', fontdict = X_TICK_LABEL_FD,
                 transform = ax.transAxes)

    # Horizontal centre of each panel, converted to figure-fraction coordinates
    to_fig = fig.transFigure.inverted()
    case_center_x = to_fig.transform(input_case_ax.transAxes.transform((0.5, 0.5)))[0]
    con_center_x = to_fig.transform(input_con_ax.transAxes.transform((0.5, 0.5)))[0]
    mid_x = case_center_x + ((con_center_x - case_center_x) / 2)

    # Comparison bracket (figure-fraction coordinates). Two verticals rise
    # from y_bot to y_top above each panel centre. A horizontal bar joins
    # them, with a short upward nub at the midpoint.
    y_bot = 0.87
    nub_len = 0.025
    y_top = y_bot + nub_len

    x_points = [case_center_x, case_center_x, mid_x, mid_x, mid_x, con_center_x, con_center_x]
    y_points = [y_bot, y_top, y_top, y_top + (nub_len / 2), y_top, y_top, y_bot]

    line = Line2D(xdata = x_points, ydata = y_points,
                  linewidth = ANNOT_LW,
                  linestyle = 'solid',
                  color = 'black',
                  clip_on = False,
                  transform = fig.transFigure)
    fig.add_artist(line)

    # p-value text just above the nub
    fig.text(x = mid_x, y = 1.02 * (y_top + nub_len / 2),
             s = f"Mann-Whitney p = {mw_p:.2g}",
             fontdict = ANNOT_FD,
             ha = 'center')

    if DEBUG:
        print(f'case_center_x: {case_center_x}, con_center_x: {con_center_x}, '
              f'mid_x: {mid_x}, mw_p: {mw_p}')

    return [input_case_ax, input_con_ax]

In [None]:
# Description:
# Generates a bar plot for both diff exp genes and non-diff exp
# genes for each GSE ID tested
#
# Inputs:
# input_df: Input data that must have at least a 'disease' column (its
#           values are compared against DISEASE_NAME), GSE_ID column,
#           gene_type column (DEG/unchanged), clean_cell column,
#           neg_log_corr_p column, Enrichment column, and P_val column.
# input_ax: Axis to draw RELI result bar plots to
# DISEASE_NAME: Value of input_df['disease'] to plot ('sle' or 'uc')
# X_AXIS_LABEL: Label to go under all plots shown in input_ax (such as disease name) 
# Y_MAX: Max Y value to help in setting the y-ticks
# NON_FC: Color for unchanged bars 
# DIFF_FC: Color for differentially expressed bars
# Y_TICK_LABEL_FD: Fontsize for Y tick labels
# X_AXIS_TITLE_FD: Fontsize for X_AXIS_LABEL text
# X_TICK_LABEL_FD: Fontsize for GSE ID and Cell type labels on X-axis
# LEGEND_FONT_SIZE: Fontsize for text in legend
# GSE_LAB_Y: Y value to place the GSE ID and Cell type label at
# DIS_LAB_Y: Y value to place the X_AXIS_LABEL at
# LEG_Y: Y value to place the legend
# LEG_X: X value to place the legend
# SIG_LW: Linewidth for the significance threshold line (p = 0.05) drawn on the plot
# ADD_LEGEND: Boolean flag whether to add a legend to the plot
# DEBUG: Boolean flag whether to print debug messages.

# Outputs:
#   Returns input_ax that we have drawn on

import matplotlib.transforms as transforms
from matplotlib.patches import Patch
from matplotlib.ticker import FormatStrFormatter

def gen_reli_axs(input_df, input_ax,
                 DISEASE_NAME = 'sle',
                 X_AXIS_LABEL = 'SLE Genetic Risk Loci',
                 Y_MAX = 10,
                 NON_FC =  sns.color_palette()[0],
                 DIFF_FC = sns.color_palette()[1],
                 Y_TICK_LABEL_FD = {'size' : 11},
                 X_AXIS_TITLE_FD = {'size' : 14},
                 X_TICK_LABEL_FD = {'size' : 11},   
                 LEGEND_FONT_SIZE = 14,
                 GSE_LAB_Y = -0.05,
                 DIS_LAB_Y = -0.15,
                 LEG_Y = 1,
                 LEG_X = 1, 
                 SIG_LW = 1,
                 ADD_LEGEND = False,
                 DEBUG = False):
    """Draw paired RELI bars (DEG vs unchanged genes) for each GSE study of one disease.

    Only rows whose ``disease`` column equals ``DISEASE_NAME`` are used. For
    each GSE_ID among those rows, in order of first appearance, two bars are
    drawn on ``input_ax``:
      * the ``neg_log_corr_p`` of the GSE's first ``gene_type == 'DEG'`` row
        (face colour ``DIFF_FC``);
      * the ``neg_log_corr_p`` of its first ``gene_type == 'unchanged'`` row
        (face colour ``NON_FC``).
    Each pair is labelled below the axis with the GSE ID and the DEG row's
    ``clean_cell`` value. The whole axis is captioned with ``X_AXIS_LABEL``.
    A dashed line marks -log10(0.05).

    Y ticks are placed at 0, 2, ..., up to ``Y_MAX`` (``Y_MAX`` must be an
    int). ``GSE_LAB_Y`` and ``DIS_LAB_Y`` are in axes-fraction units.
    ``LEG_X``/``LEG_Y`` anchor the legend, which is drawn only if
    ``ADD_LEGEND`` is set.

    Returns
    -------
    The same ``input_ax``, drawn on.

    Raises
    ------
    IndexError
        If a selected GSE_ID has no 'DEG' row or no 'unchanged' row.
    """
    # Bar appearance and spacing (x positions are in data units)
    DIFF_EC = 'black'
    NON_EC = 'black'
    BAR_WIDTH = 1
    BAR_SEP = 0.5   # gap between the DEG bar and the unchanged bar of one GSE
    GSE_SEP = 0.5   # gap between consecutive GSE pairs

    # Grab the data for the disease of interest
    filt_dat = input_df.loc[input_df['disease'] == DISEASE_NAME, :]

    # Take the GSE IDs from the filtered data. Taking them from input_df
    # would pick up other diseases' GSEs, whose empty subsets fail below.
    filt_gse = filt_dat['GSE_ID'].unique().tolist()

    # (GSE ID, cell type, x centre of the bar pair) for the labels drawn later
    lab_ls = []

    curr_x = 0
    for curr_gse in filt_gse:
        curr_gse_dat = filt_dat.loc[filt_dat['GSE_ID'] == curr_gse, :]

        diff_dat = curr_gse_dat.loc[curr_gse_dat['gene_type'] == 'DEG', :]
        non_dat = curr_gse_dat.loc[curr_gse_dat['gene_type'] == 'unchanged', :]

        curr_cell = diff_dat['clean_cell'].tolist()[0]
        diff_neg_p = diff_dat['neg_log_corr_p'].tolist()[0]
        non_neg_p = non_dat['neg_log_corr_p'].tolist()[0]

        if DEBUG:
            print(f"[gse: {curr_gse}] x: {curr_x}, h: {diff_neg_p}")

        # Differentially expressed genes bar
        input_ax.bar(
               x      = curr_x,
               height = diff_neg_p,
               width  = BAR_WIDTH,
               align  = 'center',
               fc     = DIFF_FC,
               ec     = DIFF_EC
               )

        # The label goes midway between this bar's centre and the next one's
        mid_gse_x = curr_x + (BAR_WIDTH / 2) + (BAR_SEP / 2)
        curr_x = curr_x + BAR_WIDTH + BAR_SEP

        lab_ls.append([curr_gse, curr_cell, mid_gse_x])

        if DEBUG:
            print(f"[gse: {curr_gse}] x: {curr_x}, h: {non_neg_p}")

        # Unchanged genes bar
        input_ax.bar(
               x      = curr_x,
               height = non_neg_p,
               width  = BAR_WIDTH,
               align  = 'center',
               fc     = NON_FC,
               ec     = NON_EC
           )

        # Advance past this bar and the gap before the next GSE pair
        curr_x = curr_x + BAR_WIDTH + GSE_SEP

        if DEBUG:
            print(f"[gse: {curr_gse}] x: {curr_x}")
            print(f"[gse: {curr_gse}]GSE_lab_x: {mid_gse_x}")
            print("===========================================================")

    # Remove x-axis label and tick labels (replaced by the custom labels below)
    input_ax.set_xlabel('')
    input_ax.tick_params(axis = "x", 
                        top = False, labeltop = False,
                        bottom = False, labelbottom = False)

    # Remove the y-axis label, put ticks at 0, 2, ..., Y_MAX, and drop the
    # top and right spines
    input_ax.set_ylabel('')
    input_ax.set_yticks(ticks = range(0, Y_MAX + 1, 2))
    input_ax.tick_params(axis = "y", 
                            left = True, labelleft = True,
                            right = False, labelright = False,
                            labelsize = Y_TICK_LABEL_FD['size'])
    input_ax.spines[['top', 'right']].set_visible(False)

    # GSE ID / cell type labels: x in data units, y in axes fraction
    trans = transforms.blended_transform_factory(input_ax.transData, input_ax.transAxes)
    for curr_gse, curr_cell, x_val in lab_ls:
        input_ax.text(x = x_val,
                      y = GSE_LAB_Y,
                      s = f"{curr_gse}\n[{curr_cell}]",
                      ha = 'center', 
                      va = 'bottom',
                      fontdict = X_TICK_LABEL_FD,
                      transform = trans)

    # Caption for the whole axis, centred below the GSE labels
    input_ax.text(x  = 0.5, 
                y  = DIS_LAB_Y,
                s  = X_AXIS_LABEL,
                ha = 'center',
                va = 'bottom',
                fontdict = X_AXIS_TITLE_FD,
                transform = input_ax.transAxes)

    # Add legend
    if ADD_LEGEND:
        legend_items = [Patch(facecolor = DIFF_FC, edgecolor = DIFF_EC, label = 'Diff Exp Genes'),
                        Patch(facecolor = NON_FC, edgecolor = NON_EC, label = 'Unchanged Genes')]

        input_ax.legend(handles = legend_items, fontsize = LEGEND_FONT_SIZE,
                        loc = 'center', bbox_to_anchor = (LEG_X, LEG_Y),
                        ncol = 1, fancybox = True, shadow = False)

    # Dashed p = 0.05 threshold across the full axis width. axhline's
    # xmin/xmax are axes fractions (default 0..1), so data-space x-limits
    # must not be passed for them.
    thresh = -(np.log10(0.05))
    input_ax.axhline(y = thresh, linewidth = SIG_LW, linestyle = '--', color = 'black')

    input_ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))

    return input_ax

## Get VIRTUS data

### SLE

In [None]:
# File containing VIRTUS results for all SLE blood and B-cell samples
sle_virt_dat = '../virtus/results/sle/all_blood_cleaned_data_sle.csv'

# Load the table and drop the unnamed leading column ('Unnamed: 0').
# Presumably this is an index written by an earlier `to_csv`; verify.
dat = pd.read_csv(sle_virt_dat, sep = ',')
dat = dat.drop('Unnamed: 0', axis = 1)

# Transpose so each organism becomes a column, named from the 'org_name' row
tdat = dat.transpose()
tdat.columns = tdat.loc['org_name', :]

# Keep the EBV type 1 values and disease status. Drop the 'org_name' and
# 'org_grp' rows (presumably the only non-sample rows).
ebv = tdat.loc[:, ['ebv_type1', 'sample_disease']]
ebv = ebv.drop(['org_name', 'org_grp'], axis = 0)

# After the transpose the values sit in an object-dtype column, so replace
# the column outright. With pandas >= 2.0, `ebv.loc[:, col] = ...` writes
# the floats back into the existing object column and leaves its dtype as
# object.
ebv['ebv_type1'] = ebv['ebv_type1'].astype(float)

### UC

In [None]:
# File containing VIRTUS results for all UC samples
uc_virt_dat = '../virtus/results/cleaned_uc_only_data_no_cov_filter.tsv'

uc = pd.read_csv(uc_virt_dat, sep = '\t')

# Transpose so each organism becomes a column, named from the 'org_name' row
tuc = uc.transpose()
tuc.columns = tuc.loc['org_name', :]

# Keep the CMV values plus sample metadata. Drop the 'org_name' and
# 'org_grp' rows (presumably the only non-sample rows).
cmv = tuc.loc[:, ['cmv', 'sample_disease', 'is_ped', 'sample_gse']]
cmv = cmv.drop(['org_name', 'org_grp'], axis = 0)

# After the transpose the values sit in an object-dtype column, so replace
# the column outright. With pandas >= 2.0, `cmv.loc[:, col] = ...` writes
# the floats back into the existing object column and leaves its dtype as
# object.
cmv['cmv'] = cmv['cmv'].astype(float)

## Grab RELI Results

In [None]:
reli_dat_dir = "../manuscript/reli_stuff"
reli_xlsx = f"{reli_dat_dir}/2023-06-10_RELI_UC_and_SLE_cleaned_results.xlsx"

# Read the SLE and UC RELI result sheets in one pass over the workbook
reli_sheets = pd.read_excel(reli_xlsx, sheet_name = ['EBV_SLE', 'CMV_UC'])
sle_reli = reli_sheets['EBV_SLE']
uc_reli = reli_sheets['CMV_UC']


def _sig_reli_rows(reli_df, disease, max_corr_p = 0.01, min_enrich = 2):
    """Return every row of each GSE_ID that passes the RELI significance cutoffs.

    A GSE_ID passes when at least one of its rows has
    ``Corrected_P_val < max_corr_p`` and ``Enrichment >= min_enrich``. All
    rows of a passing GSE_ID are kept, so that the rows for its other gene
    sets can still be plotted next to it. The result is a copy with a
    ``disease`` column set to ``disease``.
    """
    passes = ((reli_df['Corrected_P_val'] < max_corr_p) &
              (reli_df['Enrichment'] >= min_enrich))
    sig_gses = reli_df.loc[passes, 'GSE_ID'].tolist()
    sig = reli_df.loc[reli_df['GSE_ID'].isin(sig_gses), :].copy(deep = True)
    sig['disease'] = disease
    return sig


# Matt: a reasonable cutoff would be corrected P < 0.01 and 2-fold change or more.
# Filter our results based on Matt's recommendation
sig_reli_sle = _sig_reli_rows(sle_reli, 'sle')
sig_reli_uc = _sig_reli_rows(uc_reli, 'uc')

# Push all RELI results into a single dataframe. ignore_index prevents
# duplicate row labels, since both sheets start their index at 0.
reli_fin = pd.concat([sig_reli_sle, sig_reli_uc], ignore_index = True)

# -log10(corrected p) is the bar height used by the RELI plots. Sort so the
# most significant GSE of each disease comes first.
reli_fin['neg_log_corr_p'] = -(np.log10(reli_fin['Corrected_P_val']))
reli_fin = reli_fin.sort_values(['disease', 'neg_log_corr_p'], ascending = [True, False])
gse_order = reli_fin['GSE_ID'].drop_duplicates().tolist()

# Clean up the cell types for display; labels not listed are kept as-is
reli_fin['clean_cell'] = reli_fin.loc[:, 'Cell_type'].replace({
                                                        'monocyte' : 'Monocyte',
                                                        'dendritic_cell' : 'Dendritic',
                                                        'peripheral_blood_mononuclear_cell' : 'PBMC',
                                                        'B-lymphocyte' : 'B-lymphocyte'
})

# Split them back out by disease
fin_reli_sle = reli_fin.loc[reli_fin['disease'] == 'sle', :].copy(deep = True)
fin_reli_uc  = reli_fin.loc[reli_fin['disease'] == 'uc', :].copy(deep = True)

## Start plotting

In [None]:
# plotting
CURR_DEBUG = False

# Figure size in inches (11 x 8.5, landscape letter)
FIG_W, FIG_H = 11, 8.5
fig = plt.figure(figsize = (FIG_W, FIG_H), facecolor = 'w')

# GridSpec spacing between neighbouring grid cells
WSPACE = HSPACE = 1

# Layout sizes, all measured in grid cells
STATUS_SEP_W = 1    # columns between a case and a control swarm plot
TEST_SEP_W = 6      # columns between the SLE and UC panels
SWARM_WIDTH = 10    # columns per swarm plot
SWARM_HEIGHT = 10   # rows for the swarm plots
H_SEP = 2           # rows between the swarm plots and the RELI bar plots

# Grid of 22 rows x 48 columns. The columns hold two
# (case | gap | control) panels plus the gap between the SLE and UC panels.
FIG_GRID_H = 22
FIG_GRID_W = 2 * (SWARM_WIDTH + STATUS_SEP_W + SWARM_WIDTH) + TEST_SEP_W

fig_grid = plt.GridSpec(FIG_GRID_H, FIG_GRID_W, figure = fig,
                        wspace = WSPACE, hspace = HSPACE)

############################################
#                                          #
#         Some parameters                  #
#                                          #
############################################
# RELI bar colours: seaborn default-palette entries 4 (DEG) and 7 (unchanged)
RELI_DIFF_COLOR, RELI_UNDIFF_COLOR = (sns.color_palette()[i] for i in (4, 7))

# VIRTUS swarm marker colours
VIRTUS_CASE_COLOR, VIRTUS_CON_COLOR = '#bb3d2e', 'black'

# Font dictionaries (point sizes), annotation line width, legend font size
Y_AXIS_TITLE_FD = dict(size = 12)
Y_TICK_LABEL_FD = dict(size = 11)
X_AXIS_TITLE_FD = dict(size = 12)
X_TICK_LABEL_FD = dict(size = 11)
ANNOT_FD = dict(size = 10)
ANNOT_LW = 1
LEGEND_FONT_SIZE = 10

# Figure-fraction x offsets for the left-column / right-column y-axis titles
LEFT_Y_AXIS_TITLE_OFFSET, RIGHT_Y_AXIS_TITLE_OFFSET = 0.0625, 0.0025

# Axes-fraction y positions of the GSE/cell labels and the RELI caption
GSE_LAB_Y, DIS_LAB_Y = -0.125, -0.275

# Legend anchor point (axes fraction)
LEG_X, LEG_Y = 0.7, 0.85

# Axis titles and captions
EBV_VIRT_TEXT = 'EBV Gene Expression Level'
CMV_VIRT_TEXT = 'CMV Gene Expression Level'
RELI_Y_TEXT = '-Log10(Corrected RELI p-value)'

EBV_RELI_TEXT = 'SLE Genetic Risk Loci overlap with\nEBV-induced gene expression changes'
CMV_RELI_TEXT = 'UC Genetic Risk Loci overlap with\nCMV-induced gene expression changes'

############################################
#                                          #
#         VIRTUS plots                     #
#                                          #
############################################

# Grid-column spans of the four swarm plots, left to right:
# [SLE case][gap][SLE control][panel gap][UC case][gap][UC control]
sle_case_swarm_st = 0
sle_case_swarm_end = sle_case_swarm_st + SWARM_WIDTH
sle_con_swarm_st = sle_case_swarm_end + STATUS_SEP_W
sle_con_swarm_end = sle_con_swarm_st + SWARM_WIDTH

# Columns spanned by one full (case + gap + control) disease panel
tot_pl_width = sle_con_swarm_end

uc_case_swarm_st = sle_con_swarm_end + TEST_SEP_W
uc_case_swarm_end = uc_case_swarm_st + SWARM_WIDTH
uc_con_swarm_st = uc_case_swarm_end + STATUS_SEP_W
uc_con_swarm_end = uc_con_swarm_st + SWARM_WIDTH

# Half-width of each swarm span (every span is SWARM_WIDTH columns wide)
sle_left_lab_x = sle_right_lab_x = uc_left_lab_x = uc_right_lab_x = SWARM_WIDTH / 2

# Top row of the grid: each control axes shares its y-axis with its case axes
swarm_rows = slice(None, SWARM_HEIGHT)
sle_case_swarm_ax = fig.add_subplot(fig_grid[swarm_rows, sle_case_swarm_st:sle_case_swarm_end])
sle_con_swarm_ax = fig.add_subplot(fig_grid[swarm_rows, sle_con_swarm_st:sle_con_swarm_end],
                                   sharey = sle_case_swarm_ax)
uc_case_swarm_ax = fig.add_subplot(fig_grid[swarm_rows, uc_case_swarm_st:uc_case_swarm_end])
uc_con_swarm_ax = fig.add_subplot(fig_grid[swarm_rows, uc_con_swarm_st:uc_con_swarm_end],
                                  sharey = uc_case_swarm_ax)

# Settings common to both VIRTUS swarm panels
swarm_common = dict(DISEASE_COL = 'sample_disease',
                    CONS_X_AXIS_LABEL = 'Controls',
                    CON_COLOR = VIRTUS_CON_COLOR,
                    CASE_COLOR = VIRTUS_CASE_COLOR,
                    SWARM_MARKER_SIZE = 4,
                    SWARM_ALPHA = 0.7,
                    Y_TICK_LABEL_FD = Y_TICK_LABEL_FD,
                    X_TICK_LABEL_FD = X_TICK_LABEL_FD,
                    ANNOT_FD = ANNOT_FD,
                    ANNOT_LW = ANNOT_LW,
                    DEBUG = CURR_DEBUG)

# SLE: EBV type 1 in SLE cases vs controls
sle_case_swarm_ax, sle_con_swarm_ax = gen_swarm_axs(
    ebv, sle_case_swarm_ax, sle_con_swarm_ax,
    ORG_COL = 'ebv_type1',
    CASE_X_AXIS_LABEL = 'SLE Cases',
    Y_BASELINE = -0.0001,
    **swarm_common)

# UC (adults and peds): CMV in UC cases vs controls
uc_case_swarm_ax, uc_con_swarm_ax = gen_swarm_axs(
    cmv, uc_case_swarm_ax, uc_con_swarm_ax,
    ORG_COL = 'cmv',
    CASE_X_AXIS_LABEL = 'UC Cases',
    Y_BASELINE = -0.001,
    **swarm_common)
    

    
############################################
#                                          #
#         RELI plots                       #
#                                          #
############################################
# Re-set here so the RELI plots have their own debug switch
CURR_DEBUG = False

# Grid-column spans of the two RELI bar plots. Each is as wide as one full
# VIRTUS panel, and they are separated by the same gap as the VIRTUS panels.
PLOT_SEP_W = TEST_SEP_W

sle_reli_st = 0
sle_reli_end = sle_reli_st + tot_pl_width

uc_reli_st = sle_reli_end + PLOT_SEP_W
uc_reli_end = uc_reli_st + tot_pl_width

# The bar plots fill every grid row below the swarm plots and the H_SEP gap
RELI_ROW_ST = SWARM_HEIGHT + H_SEP
RELI_HEIGHT = FIG_GRID_H - RELI_ROW_ST

sle_bar_ax = fig.add_subplot(fig_grid[RELI_ROW_ST:, sle_reli_st:sle_reli_end])
uc_bar_ax = fig.add_subplot(fig_grid[RELI_ROW_ST:, uc_reli_st:uc_reli_end])

if CURR_DEBUG:
    print(f"[SLE case swarm]: Height = {SWARM_HEIGHT}, grid_st = {sle_case_swarm_st}, grid_end = {sle_case_swarm_end}")
    print(f"[SLE con swarm]: Height = {SWARM_HEIGHT}, grid_st = {sle_con_swarm_st}, grid_end = {sle_con_swarm_end}")
    print(f"[UC case swarm]: Height = {SWARM_HEIGHT}, grid_st = {uc_case_swarm_st}, grid_end = {uc_case_swarm_end}")
    print(f"[UC con swarm]: Height = {SWARM_HEIGHT}, grid_st = {uc_con_swarm_st}, grid_end = {uc_con_swarm_end}")
    # Report the bar plots' own row count, not the swarm height
    print(f"[SLE bar]: Height = {RELI_HEIGHT}, grid_st = {sle_reli_st}, grid_end = {sle_reli_end}")
    print(f"[UC bar]: Height = {RELI_HEIGHT}, grid_st = {uc_reli_st}, grid_end = {uc_reli_end}")

    print(FIG_GRID_W)

# Settings common to both RELI bar panels
reli_common = dict(NON_FC = RELI_UNDIFF_COLOR,
                   DIFF_FC = RELI_DIFF_COLOR,
                   DEBUG = CURR_DEBUG,
                   Y_TICK_LABEL_FD = Y_TICK_LABEL_FD,
                   X_AXIS_TITLE_FD = X_AXIS_TITLE_FD,
                   X_TICK_LABEL_FD = X_TICK_LABEL_FD,
                   LEGEND_FONT_SIZE = LEGEND_FONT_SIZE,
                   GSE_LAB_Y = GSE_LAB_Y,
                   DIS_LAB_Y = DIS_LAB_Y,
                   SIG_LW = 2,
                   LEG_Y = LEG_Y,
                   LEG_X = LEG_X)

# SLE RELI results; this panel carries the shared legend
sle_bar_ax = gen_reli_axs(input_df = fin_reli_sle,
                          input_ax = sle_bar_ax,
                          DISEASE_NAME = 'sle',
                          X_AXIS_LABEL = EBV_RELI_TEXT,
                          Y_MAX = 6,
                          ADD_LEGEND = True,
                          **reli_common)

# UC RELI results
uc_bar_ax = gen_reli_axs(input_df = fin_reli_uc,
                         input_ax = uc_bar_ax,
                         DISEASE_NAME = 'uc',
                         X_AXIS_LABEL = CMV_RELI_TEXT,
                         Y_MAX = 10,
                         ADD_LEGEND = False,
                         **reli_common)




# Shared y-axis titles. Each is rotated 90 degrees and placed at a fixed
# figure-fraction x. Its y is the vertical midpoint of the given axes,
# converted to figure fraction.
y_axis_titles = [
    (0 + LEFT_Y_AXIS_TITLE_OFFSET, sle_case_swarm_ax, EBV_VIRT_TEXT),
    (0.5 - RIGHT_Y_AXIS_TITLE_OFFSET, uc_con_swarm_ax, CMV_VIRT_TEXT),
    (0 + LEFT_Y_AXIS_TITLE_OFFSET, sle_bar_ax, RELI_Y_TEXT),
    (0.5 - RIGHT_Y_AXIS_TITLE_OFFSET, uc_bar_ax, RELI_Y_TEXT),
]
for title_x, anchor_ax, title in y_axis_titles:
    mid_y = fig.transFigure.inverted().transform(anchor_ax.transAxes.transform((0, 0.5)))[1]
    fig.text(x = title_x, y = mid_y,
             s = title,
             fontdict = Y_AXIS_TITLE_FD,
             ha = 'center', va = 'center',
             rotation = 90)

# Figure-fraction anchors for the panel letters
HALF_Y = 0.475
TOP_Y  = 0.95

LEFT_X = 0.025
HALF_X = 0.5

# Add panel letter labels: A/B on the top row, C/D on the bottom row
panel_font = {'size' : 20, 'weight' : 'bold'}

panel_letters = [(LEFT_X, TOP_Y, 'A'), (HALF_X, TOP_Y, 'B'),
                 (LEFT_X, HALF_Y, 'C'), (HALF_X, HALF_Y, 'D')]
for letter_x, letter_y, letter in panel_letters:
    fig.text(x = letter_x, y = letter_y, s = letter, fontdict = panel_font)

In [None]:
# Save the finished 4-panel figure as a publication PDF. Create the output
# directory first so savefig does not fail on a fresh checkout.
out_dir = '../manuscript/figures/fig_5'
os.makedirs(out_dir, exist_ok = True)

fn = f"{out_dir}/fig_5_pub.pdf"
# bbox_inches="tight" fits the saved area to the drawn artists, so the text
# placed outside the axes (group labels, captions) is kept.
fig.savefig(fn, format = 'pdf', dpi = 600, bbox_inches="tight")