## Comparative analysis of Cas9 virus screens

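This notebook loads Cas9 screening MAGeCK results and plots a heat map of hit genes found in at least one screen.
Note that some screens (e.g. unpublished ones; see `exclude_list` and `toSkip` below) are excluded from the comparison, and that only hit genes present in all remaining screens are plotted.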

### parameters

In [None]:
########
# data #
########
from pathlib import Path
# Folder scanned for MAGeCK output: every sub-folder containing an "rra" directory is treated as one screen.
mageck_out_dir = Path.cwd().joinpath("input", "CRISPRflow_out")
# Supplementary table of Matia et al. (Vaccinia); its processed FDR values are read directly further down.
matia_et_al_path = "input/journal.ppat.1010800.s010.xlsx"
# Unpublished screens to exclude (passed to get_df_plot() as `exclude_list`).
exclude_list = [
    "KSHV_CRISPRi",
    "Broeckel_SARS-CoV1_",
    "Flather_MERS_CMK_",
    "AT_SARS2_omicron_BA1_",
    "Wang_NL63_",
]

#######################################
# identify hit genes from all screens #
#######################################
import copy,operator
# Column of the gene summary tables that is used to call hits.
metric = "pos|fdr"
# Cut-off applied to the *transformed* metric.
# With transform_method = "-log10", a threshold of 2 corresponds to FDR < 0.01.
threshold = 2
# Transformation applied to the metric before it is compared with the threshold: "-log10" or "none".
transform_method = "-log10"
# Comparison used as `myoperater(transformed_metric, threshold)`:
# operator.gt keeps values greater than the threshold, operator.lt values less than it.
# (No copy is needed: copy.deepcopy() returns builtin functions unchanged.)
myoperater = operator.gt

############
# plotting #
############
# Set to False to skip the clustering grid search in the last cell of the notebook.
grid_search = True
# Name of the figure folder; it is created below ./output in the import cell.
plot_outdir = 'plots_output'
import plotly.express as px
qualitative_palettes = px.colors.qualitative
# Functional categories: the Plotly palette extended with the last three T10 colours.
funcAnno_types_color_pal = qualitative_palettes.Plotly + qualitative_palettes.T10[7:10]
# Cell types (10 cell types).
cell_types_color_pal = qualitative_palettes.Plotly
# Virus families (10 families).
fam_color_pal = qualitative_palettes.T10

# Screen names to be shown in the plots, format: {screen_name: "plotting_screen_name"}.
# The keys are looked up directly with the column names of df_plot when the columns are renamed,
# so every plotted screen needs an entry here.
# NOTE(review): some keys look truncated (e.g. '..._for_tw', 'U87MG_ReoT3D_U87M'); presumably they mirror
# the names produced by get_df_plot() -- confirm before editing them.
screen_name_mapping = {
 'Huh7_229E_R1-R2_229E_infected_vs_Control_for_tw': "Baggen et al. 229E",
 'Huh7_SARS-CoV-2_R1-R2_SARS-CoV-2-low-stringency_vs_Control_for_tw': "Baggen et al. SARS-CoV-2",
 'A549_SARS-CoV-2_R1-R2_SARS-CoV-2-MOI-0.01_vs_SARS-CoV-2-control': "Daniloski et al. SARS-CoV-2",
 'HeLa_Diep_EV': "Diep et al. EV",
 'HeLa_Diep_RV': "Diep et al. RV",
 'HAP1_Flint_EBOV': "Flint et al. EBOV",
 'A549_Han_InfluenzaA': "Han et al. Influenza A",
 'HFF_HCMV_R1_CMV-surviving_vs_CMV-t0': "Hein et al. HCMV",
 'HAP1_Hoffmann_YFV': "Hoffmann et al. YFV",
 'HAP1_Hoffmann_ZIKV': "Hoffmann et al. ZIKV",
 'A549_Liu_LCMV': "Liu et al. LCMV",
 'Huh7.5.1_Marceau_DENV': "Marceau et al. DENV",
 'Huh7.5.1_Marceau_HCV': "Marceau et al. HCV",
 'Huh7.5.1_Ooi_DENV1_276RKI': "Ooi et al. DENV-1-276RKI",
 'Huh7.5.1_Ooi_DENV2_429557': "Ooi et al. DENV-2-429557",
 'Huh7.5.1_Ooi_DENV3_Philippines-H871856': "Ooi et al. DENV-3-PHL/H871856",
 'Huh7.5.1_Ooi_DENV4_BC287-97': "Ooi et al. DENV-4-BC287/97",
 'Huh7.5_229E_A-B-C_229E_infected_vs_229E_Control': "Schneider et al. 229E 33C",
 'Huh7.5_NL63_A-B-C_NL63_infected_vs_NL63_Control': "Schneider et al. NL63 33C",
 'Huh7.5_OC43_A-B-C_OC43_infected_vs_OC43_Control': "Schneider et al. OC43 33C",
 'Huh7.5_SARS-CoV-2-33_A-B-C_SARS-CoV-2_infected_vs_SARS-CoV-2_Control': "Schneider et al. SARS-CoV-2 33C",
 'Huh7.5_SARS-CoV-2-37_A-B-C_SARS-CoV-2_infected_vs_SARS-CoV-2_Control': "Schneider et al. SARS-CoV-2 37C",
 # label fixed: no trailing space, consistent with "This study RSV s2"
 'A549_Sunshine_RSV_screen1': "This study RSV s1",
 'A549_Sunshine_RSV_screen2': "This study RSV s2",
 'Huh7.5.1_Wang_229E': "Wang et al. 229E",
 'Huh7.5.1_Wang_OC43': "Wang et al. OC43",
 'Huh7.5.1_Wang_SARS-CoV2': "Wang et al. SARS-CoV-2",
 'HeLa_Matia_Vaccinia': "Matia et al. Vaccinia",
 'HT29-DKO_HT29DKO_PeVA1': "Qiao et al. PeV-A1",
 'HT29-DKO_HT29DKO_PeVA2': "Qiao et al. PeV-A2",
 'U87MG_ReoT3D_U87M': 'Richards et al. ReoT3D'
}

In [None]:
import os
from scipy.stats import zscore
import pandas as pd
import seaborn as sns; sns.set_theme(color_codes=True)
import matplotlib.pyplot as plt
from pathlib import Path
from matplotlib.patches import Patch
plt.rcParams.update({'figure.max_open_warning': 0})
from utils import *

# Resolve the figure folder below ./output and create it (including parents) when it does not exist yet.
plot_outdir = Path.cwd().joinpath("output", plot_outdir)
plot_outdir.mkdir(exist_ok=True, parents=True)

# Threshold as returned by reverse_transform_metric() (presumably back on the untransformed scale -- see utils);
# it is embedded in the names of the output files.
threshold_to_display = reverse_transform_metric(threshold, transform_method)

### load data

In [None]:
# Read every screen into df_list: {cleaned gene-summary file name: gene summary table}.
# A screen is any entry of mageck_out_dir that contains an "rra" sub-folder.
df_list = {}
rra_folder_count = 0
for screen_dir in os.listdir(mageck_out_dir):
    rra_dir = mageck_out_dir / screen_dir / "rra"
    if not os.path.isdir(rra_dir):
        continue
    rra_folder_count += 1

    # find_summary_file() comes from utils; its first element is either the file name or the message below
    summary_name = find_summary_file(os.listdir(rra_dir))[0]
    if summary_name == "gene_summary file not found or multiple files found":
        print("gene_summary file not found or multiple files found")
        continue

    summary_table = pd.read_csv(rra_dir / summary_name, sep='\t')
    # clean up the screen name before using it as key
    screen_key = summary_name.replace("Huh751", "").replace("Gecko_", "")
    df_list[screen_key] = summary_table

print(f"number of MAGeCK output (rra) folders found: {rra_folder_count}")
# list all screens
print(f"number of screens read into memory: {len(df_list)}")
print("names of gene summary files:")
for screen_key in df_list:
    print(f"\t{screen_key}")

In [None]:
# Matia et al. Vaccinia: this dataset didn't have fastq files, so the processed FDR data is read directly.
# sheet_name=None loads every sheet of the workbook into a dict of data frames.
Matia_Vaccinia = pd.read_excel(matia_et_al_path, sheet_name=None)

# Build a gene summary table for this screen: stack all sheets and keep the lowest FDR of each gene.
list_of_dfs = list(Matia_Vaccinia.values())
combined_df = pd.concat(list_of_dfs)
Matia_Vaccinia_df = (
    combined_df
    .groupby('Gene')
    .agg({'FDR': 'min'})
    .reset_index()
    [['Gene', 'FDR']]
)
# use the same column names as the gene summary tables of the other screens
Matia_Vaccinia_df.columns = ['id', 'pos|fdr']
# register the screen next to the others
df_list["Matia_Vaccinia"] = Matia_Vaccinia_df

In [None]:
# len() of the dict counts its keys, i.e. the screens loaded so far
print(f"number of screens before applying exclusion rules: {len(df_list)}")

### associate cell type with the screen name

In [None]:
# Cell type of every screen, keyed by the keys of df_list (cleaned gene-summary file name, or
# "Matia_Vaccinia" for the table read from Excel). The loop after this dictionary prepends the
# cell type to each key; an empty string means that no prefix is added.
# NOTE(review): "Kulsuptrakul_HAV_" is skipped via toSkip but is not part of exclude_list -- confirm
# that this is intended.
cell_type_dict = {

"KSHV_CRISPRi_.gene_summary.txt":"", #unpublished, exclude
"Broeckel_SARS-CoV1_.gene_summary.txt":"Huh7.5.1", #unpublished, exclude
"Flather_MERS_CMK_.gene_summary.txt":"Huh7.5.1", #unpublished, exclude
"AT_SARS2_omicron_BA1_.gene_summary.txt":"Huh7.5.1",#unpublished? , exclude
"Wang_NL63_.gene_summary.txt":"Huh7.5.1",#unpublished? , exclude
"Kulsuptrakul_HAV_.gene_summary.txt":"Huh7.5.1", #exclude this one b/c Joe & Sara said so
"Diep_EV_.gene_summary.txt":"HeLa",
"Diep_RV_.gene_summary.txt":"HeLa",
"Flint_EBOV_.gene_summary.txt":"HAP1",
"Han_InfluenzaA_.gene_summary.txt":"A549",
"Hoffmann_YFV_.gene_summary.txt":"HAP1",
"Hoffmann_ZIKV_.gene_summary.txt":"HAP1",
"Marceau_DENV_.gene_summary.txt":"Huh7.5.1",
"Marceau_HCV_.gene_summary.txt":"Huh7.5.1",
"Ooi_DENV1_276RKI_.gene_summary.txt":"Huh7.5.1",
"Ooi_DENV2_429557_.gene_summary.txt":"Huh7.5.1",
"Ooi_DENV3_Philippines-H871856_.gene_summary.txt":"Huh7.5.1",
"Ooi_DENV4_BC287-97_.gene_summary.txt":"Huh7.5.1",
"229E_A-B-C_229E_infected_vs_229E_Control.gene_summary.txt":"Huh7.5",#Schneider
"NL63_A-B-C_NL63_infected_vs_NL63_Control.gene_summary.txt":"Huh7.5",#Schneider
"OC43_A-B-C_OC43_infected_vs_OC43_Control.gene_summary.txt":"Huh7.5",#Schneider
"SARS-CoV-2-33_A-B-C_SARS-CoV-2_infected_vs_SARS-CoV-2_Control.gene_summary.txt":"Huh7.5",#Schneider
"SARS-CoV-2-37_A-B-C_SARS-CoV-2_infected_vs_SARS-CoV-2_Control.gene_summary.txt":"Huh7.5",#Schneider
"Liu_LCMV_.gene_summary.txt":"A549",
"Sunshine_RSV_screen1_.gene_summary.txt":"A549",
"Sunshine_RSV_screen2_.gene_summary.txt":"A549",
"Wang_229E_.gene_summary.txt":"Huh7.5.1",
"Wang_OC43_.gene_summary.txt":"Huh7.5.1",
"Wang_SARS-CoV2_.gene_summary.txt":"Huh7.5.1",
"HCMV_R1_CMV-surviving_vs_CMV-t0.gene_summary.txt": "HFF",
"SARS-CoV-2_R1-R2_SARS-CoV-2-low-stringency_vs_Control_for_two.gene_summary.txt": "Huh7", 
"229E_R1-R2_229E_infected_vs_Control_for_two.gene_summary.txt": "Huh7",
"SARS-CoV-2_R1-R2_SARS-CoV-2-MOI-0.01_vs_SARS-CoV-2-control.gene_summary.txt": "A549", 
"Matia_Vaccinia":"HeLa", # newly added, Aug 2024, Matia et al. Vaccinia
"HT29DKO_PeVA1_.gene_summary.txt": "HT29-DKO", # newly added, Aug 2024, PeVA1
"HT29DKO_PeVA2_.gene_summary.txt": "HT29-DKO", # newly added, Aug 2024, PeVA2
"ReoT3D_U87MG_.gene_summary.txt": "U87MG", # newly added, Aug 2024, ReoT3D
}

# Re-key df_list so that every screen name starts with its cell type: "<cell type>_<old key>".
# Screens whose cell type is the empty string keep their old key.
# Note: this cell replaces df_list, so it can only be run once per kernel session.
df_list_key_rev = {}
for key, screen_df in df_list.items():
    if key not in cell_type_dict:
        raise KeyError(
            f"screen '{key}' has no entry in cell_type_dict; add its cell type "
            f"(or restart the kernel if this cell has already been run)"
        )
    cell = cell_type_dict[key]
    # decide on the cell type (the key itself is never empty)
    newkey = f"{cell}_{key}" if cell != "" else key
    # leading underscores are dropped, as before
    newkey = newkey.lstrip("_")
    df_list_key_rev[newkey] = screen_df

df_list = df_list_key_rev

### select a subset of screens to include in the final plot

In [None]:
# Screens of interest (SOI) are all loaded screens except the ones named in the exclusion list below.
toSkip = [
    "KSHV_CRISPRi_.gene_summary.txt",
    "Huh7.5.1_Kulsuptrakul_HAV_.gene_summary.txt",
    "Huh7.5.1_Wang_NL63_.gene_summary.txt",
    "Huh7.5.1_Broeckel_SARS-CoV1_.gene_summary.txt",
    "Huh7.5.1_AT_SARS2_omicron_BA1_.gene_summary.txt",
    "Huh7.5.1_Flather_MERS_CMK_.gene_summary.txt",
]
# iterating the dict yields its keys in insertion order, so SOI keeps the loading order
SOI = [screen for screen in df_list if screen not in toSkip]
print(f"number of screens: {len(SOI)}")

### get hit genes

In [None]:
# Get hit genes: a gene is a hit when, in at least one screen of interest, its transformed metric
# passes the threshold (with the default parameters: -log10(FDR) > 2, i.e. FDR < 0.01).
hit_gene_set = set()
for scr in SOI:
    df = df_list[scr]
    passed = myoperater(transform_metric(df[metric], t=transform_method), threshold)
    hit_gene_set.update(df[passed]["id"].tolist())
# Sort the unique hits so that their order (and with it the input order of the heat map) is the same in
# every kernel session; iterating a set of strings is not. key=str also tolerates non-string ids.
hit_genes = sorted(hit_gene_set, key=str)
print(f"number of hit genes: {len(hit_genes)}")

In [None]:
# Check the presence of each hit gene in all screens and keep only the hits found in every screen of
# interest. The gene ids of each screen are collected once into a set, so every membership test is a
# constant-time lookup instead of a scan over a freshly built list.
ids_per_screen = [set(df_list[scr]["id"]) for scr in SOI]
hit_genes_cp = [hit for hit in hit_genes if all(hit in ids for ids in ids_per_screen)]

# "dummyguide" is always removed from the hits (presumably a placeholder entry rather than a gene).
if "dummyguide" in hit_genes_cp:
    hit_genes_cp.remove("dummyguide")
    print("..removed dummyguide from hit_genes")

print(f"number of hit genes after removing those not found in all screens: {len(hit_genes_cp)}")

In [None]:
# update hit_genes: from here on it holds only the hits that are present in all screens of interest
# (with "dummyguide" removed)
hit_genes = hit_genes_cp

### annotate hit genes

In [None]:
# Prepare the matrix for the plot with get_df_plot() (from utils) and transpose it:
# afterwards the index is used as hit genes and the columns as screens.
df_plot = get_df_plot(
    df_list=df_list,
    transform_method=transform_method,
    hit_genes=hit_genes,
    metric=metric,
    screen_lst=SOI,
    exclude_list=exclude_list,
).transpose()

In [None]:
# Read the metascape result: the three sheets are parsed in one pass over the workbook
# (a list of sheet names makes read_excel return a dict {sheet name: data frame}).
metascape_sheets = pd.read_excel(
    "input/metascape_result.xlsx",
    sheet_name=["manual", "Enrichment", "Annotation"],
)
manual_df = metascape_sheets["manual"]
enrich_df = metascape_sheets["Enrichment"]
annots_df = metascape_sheets["Annotation"]

In [None]:
# Fixed annotation from the "manual" sheet: each row lists comma-separated gene symbols ("Symbol")
# and the functional group they belong to ("Fgroup"). The first group seen for a gene is kept.
gene2funcAnno = {}

for _, manual_row in manual_df.iterrows():
    symbols = manual_row["Symbol"]
    functional_group = manual_row["Fgroup"]
    # split the list and strip leading and trailing spaces of every symbol
    for gene in (symbol.strip() for symbol in symbols.split(",")):
        # setdefault() only writes when the gene has no annotation yet
        gene2funcAnno.setdefault(gene, functional_group)

In [None]:
# Gene function categories shown in the final plot. The list fixes the order in which the categories
# are processed; because an existing annotation is never overwritten, earlier categories take precedence.
fucntion_categories = ["Heparan sulfate biosynthesis", "Glycosylation", "Lysosome related", "Golgi related", "Vesicle trafficking", "Membrane trafficking", "Transmembrane"]

# Keywords searched (case-insensitively) in the "Description" column of the enrichment sheet.
fucntion_categories_keywords = {
    "Heparan sulfate biosynthesis": ["Heparan", "Heparin"],
    "Glycosylation": ["Glycosylation", "Glycosyltransferase", "Glycoprotein"],
    "Lysosome related": ["Lysosome", "Lysosomal",],
    "Vesicle trafficking": ["Vesicle", "Vesicle trafficking", "Vesicle transport", "Vesicle fusion", "Vesicular"],
    "Golgi related": ["Golgi"],
    "Membrane trafficking": ["Membrane trafficking"],
    "Transmembrane": ["transmembrane",  "TM protein", "TM domain", "transmembrane transport"]
}
for func_cat in fucntion_categories:  # all categories, in the order of the list
    print(f"function category: {func_cat}")
    for keyword in fucntion_categories_keywords[func_cat]:  # all keywords of the category
        keyword_lower = keyword.lower()
        for _, enrich_row in enrich_df.iterrows():
            description = enrich_row["Description"]
            if keyword_lower not in description.lower():
                continue
            print(f"... keyword '{keyword}' found in GO description: '{description}'")

            # label every gene of this enrichment term with the category, unless it is annotated already
            for gene in (symbol.strip() for symbol in enrich_row["Symbols"].split(",")):
                gene2funcAnno.setdefault(gene, func_cat)

                

In [None]:
# Define the colour of every gene function category.

# Category names in alphabetical order; "Miscellaneous" is added for genes without annotation.
funcAnno_types_uni = sorted(list(set(gene2funcAnno.values())) + ["Miscellaneous"])

# The n-th category gets the n-th colour of the palette.
funcAnno2color = {
    category: funcAnno_types_color_pal[position]
    for position, category in enumerate(funcAnno_types_uni)
}

# keep the keys in alphabetical order (this is also the order of the legend entries)
funcAnno2color = dict(sorted(funcAnno2color.items()))
# "Miscellaneous" is always drawn in light grey instead of its palette colour
funcAnno2color["Miscellaneous"] = "#d3d3d3"
funcAnno2color

In [None]:
# Translate each gene (row of df_plot) to its function category and to the colour of that category;
# genes without annotation fall into "Miscellaneous".
funcAnno_lst = [gene2funcAnno.get(gene, "Miscellaneous") for gene in df_plot.index]
funcAnno_color_lst = [funcAnno2color[category] for category in funcAnno_lst]

# Row colour table for the heat map: index = gene, one column holding the category colours.
df_funcAnno_color = pd.DataFrame(
    {"gene": df_plot.index, "Function category": funcAnno_color_lst}
).set_index("gene")


### prepare for plotting

In [None]:
# screens (columns of df_plot) that will be plotted, for a check before plotting
scr_list = df_plot.columns.tolist()

In [None]:
# generate color blocks for cell types

# Cell type label of every screen, keyed by screen name (the column names of df_plot before they are
# renamed for display). Looking the label up by name keeps it attached to the right screen whatever
# order the screens were loaded in.
cell_type_by_screen = {
 'Huh7_229E_R1-R2_229E_infected_vs_Control_for_tw': "Huh7",
 'Huh7_SARS-CoV-2_R1-R2_SARS-CoV-2-low-stringency_vs_Control_for_tw': "Huh7",
 'A549_SARS-CoV-2_R1-R2_SARS-CoV-2-MOI-0.01_vs_SARS-CoV-2-control': "A549",
 'HeLa_Diep_EV': "H1-HeLa",
 'HeLa_Diep_RV': "H1-HeLa",
 'HAP1_Flint_EBOV': "HAP1",
 'A549_Han_InfluenzaA': "A549",
 'HFF_HCMV_R1_CMV-surviving_vs_CMV-t0': "HFF",
 'HAP1_Hoffmann_YFV': "HAP1",
 'HAP1_Hoffmann_ZIKV': "HAP1",
 'HT29-DKO_HT29DKO_PeVA1': "HT29-DKO\n(HS/SA deficient)",
 'HT29-DKO_HT29DKO_PeVA2': "HT29-DKO\n(HS/SA deficient)",
 'A549_Liu_LCMV': "A549",
 'Huh7.5.1_Marceau_DENV': "Huh7.5.1",
 'Huh7.5.1_Marceau_HCV': "Huh7.5.1",
 'Huh7.5.1_Ooi_DENV1_276RKI': "Huh7.5.1",
 'Huh7.5.1_Ooi_DENV2_429557': "Huh7.5.1",
 'Huh7.5.1_Ooi_DENV3_Philippines-H871856': "Huh7.5.1",
 'Huh7.5.1_Ooi_DENV4_BC287-97': "Huh7.5.1",
 'U87MG_ReoT3D_U87M': "U87MG",  # label spelled like the cell type in cell_type_dict
 'Huh7.5_229E_A-B-C_229E_infected_vs_229E_Control': "Huh7.5",
 'Huh7.5_NL63_A-B-C_NL63_infected_vs_NL63_Control': "Huh7.5",
 'Huh7.5_OC43_A-B-C_OC43_infected_vs_OC43_Control': "Huh7.5",
 'Huh7.5_SARS-CoV-2-33_A-B-C_SARS-CoV-2_infected_vs_SARS-CoV-2_Control': "Huh7.5",
 'Huh7.5_SARS-CoV-2-37_A-B-C_SARS-CoV-2_infected_vs_SARS-CoV-2_Control': "Huh7.5",
 'A549_Sunshine_RSV_screen1': "A549",
 'A549_Sunshine_RSV_screen2': "A549",
 'Huh7.5.1_Wang_229E': "Huh7.5.1",
 'Huh7.5.1_Wang_OC43': "Huh7.5.1",
 'Huh7.5.1_Wang_SARS-CoV2': "Huh7.5.1",
 'HeLa_Matia_Vaccinia': "HeLa",
}

screens_without_cell_type = [scr for scr in df_plot.columns if scr not in cell_type_by_screen]
if screens_without_cell_type:
    raise KeyError(
        f"no cell type defined for screen(s) {screens_without_cell_type}; add them to cell_type_by_screen "
        f"(if the columns of df_plot were already renamed, re-run the notebook from the cell that builds df_plot)"
    )

# cell type info, one entry per column of df_plot
cell_type = [cell_type_by_screen[scr] for scr in df_plot.columns]

# create color mapping: the n-th cell type (alphabetical order) gets the n-th colour of the palette
cell_types_uni = sorted(set(cell_type))
cell2color = {cell: cell_types_color_pal[idx] for idx, cell in enumerate(cell_types_uni)}

# translate color
cell_type_color = [cell2color[c] for c in cell_type]


In [None]:
# color blocks for virus families

# Virus family of every screen, keyed by screen name (the column names of df_plot before they are
# renamed for display). Looking the family up by name keeps it attached to the right screen whatever
# order the screens were loaded in.
virus_family_by_screen = {
 'Huh7_229E_R1-R2_229E_infected_vs_Control_for_tw': "Coronaviridae",
 'Huh7_SARS-CoV-2_R1-R2_SARS-CoV-2-low-stringency_vs_Control_for_tw': "Coronaviridae",
 'A549_SARS-CoV-2_R1-R2_SARS-CoV-2-MOI-0.01_vs_SARS-CoV-2-control': "Coronaviridae",
 'HeLa_Diep_EV': "Picornaviridae",
 'HeLa_Diep_RV': "Picornaviridae",
 'HAP1_Flint_EBOV': "Filoviridae",
 'A549_Han_InfluenzaA': "Orthomyxoviridae",
 'HFF_HCMV_R1_CMV-surviving_vs_CMV-t0': "Herpesviridae",
 'HAP1_Hoffmann_YFV': "Flaviviridae",
 'HAP1_Hoffmann_ZIKV': "Flaviviridae",
 'HT29-DKO_HT29DKO_PeVA1': "Picornaviridae",
 'HT29-DKO_HT29DKO_PeVA2': "Picornaviridae",
 'A549_Liu_LCMV': "Arenaviridae",
 'Huh7.5.1_Marceau_DENV': "Flaviviridae",
 'Huh7.5.1_Marceau_HCV': "Flaviviridae",
 'Huh7.5.1_Ooi_DENV1_276RKI': "Flaviviridae",
 'Huh7.5.1_Ooi_DENV2_429557': "Flaviviridae",
 'Huh7.5.1_Ooi_DENV3_Philippines-H871856': "Flaviviridae",
 'Huh7.5.1_Ooi_DENV4_BC287-97': "Flaviviridae",
 'U87MG_ReoT3D_U87M': "Reoviridae",
 'Huh7.5_229E_A-B-C_229E_infected_vs_229E_Control': "Coronaviridae",
 'Huh7.5_NL63_A-B-C_NL63_infected_vs_NL63_Control': "Coronaviridae",
 'Huh7.5_OC43_A-B-C_OC43_infected_vs_OC43_Control': "Coronaviridae",
 'Huh7.5_SARS-CoV-2-33_A-B-C_SARS-CoV-2_infected_vs_SARS-CoV-2_Control': "Coronaviridae",
 'Huh7.5_SARS-CoV-2-37_A-B-C_SARS-CoV-2_infected_vs_SARS-CoV-2_Control': "Coronaviridae",
 'A549_Sunshine_RSV_screen1': "Pneumoviridae",
 'A549_Sunshine_RSV_screen2': "Pneumoviridae",
 'Huh7.5.1_Wang_229E': "Coronaviridae",
 'Huh7.5.1_Wang_OC43': "Coronaviridae",
 'Huh7.5.1_Wang_SARS-CoV2': "Coronaviridae",
 'HeLa_Matia_Vaccinia': "Poxviridae",
}

screens_without_family = [scr for scr in df_plot.columns if scr not in virus_family_by_screen]
if screens_without_family:
    raise KeyError(
        f"no virus family defined for screen(s) {screens_without_family}; add them to virus_family_by_screen "
        f"(if the columns of df_plot were already renamed, re-run the notebook from the cell that builds df_plot)"
    )

# virus family info, one entry per column of df_plot
Virus_family = [virus_family_by_screen[scr] for scr in df_plot.columns]

# create color mapping: the n-th family (alphabetical order, for a consistent legend) gets the n-th colour
fam_uni = sorted(set(Virus_family))
fam2color = {fam: fam_color_pal[idx] for idx, fam in enumerate(fam_uni)}

# translate color
family_color = [fam2color[c] for c in Virus_family]
    

In [None]:
# Colour blocks drawn along the screen axis of the heat map:
# one row per screen (column of df_plot) with the colour of its virus family and of its cell type.
df_cell_color = pd.DataFrame(
    {
        "screen": df_plot.columns,
        "Virus family": family_color,
        "Cell type": cell_type_color,
    }
).set_index("screen")

# visual check of the colour table
df_cell_color

In [None]:
# Translate the screen names to the ones shown in the final plot, using the mapping defined at the top
# of this notebook. Every column of df_plot must be a key of screen_name_mapping (KeyError otherwise),
# which is also why this cell cannot be run a second time on the already renamed columns.
cols = [screen_name_mapping[screen] for screen in df_plot.columns.to_list()]
df_plot.columns = cols

# Use the same display names as index of the colour table of the screens.
df_cell_color = df_cell_color.assign(screen_new=cols).set_index("screen_new")

### plot

In [None]:
# Plot the production heat map (clustered in both directions) and save it as eps, jpg and pdf.

prod_method = "average"
prod_metric = "cosine"
dimx=10
dimy=40
# clustering settings
c_params1 = c_params(
    method=prod_method,   # single, complete, average, weighted, centroid, median, ward
    metric=prod_metric,   # any distance metric accepted by clustermap, e.g. "cosine", "euclidean", "correlation"
    z_score=None,         # 0, 1, or None
    std_scale=None,       # 0, 1, or None
)

clustermap_options = dict(
    cmap=sns.color_palette("Purples", as_cmap=True),
    method=c_params1.method,
    metric=c_params1.metric,
    z_score=c_params1.z_score,
    standard_scale=c_params1.std_scale,
    col_colors=df_cell_color,
    row_colors=df_funcAnno_color,
    colors_ratio=[0.03,0.007],
    cbar_pos=(0.25,1.02, .5, .01),
    cbar_kws={"orientation":"horizontal", "label":"-log10(FDR)", "extend":"min"},
    vmin=1.3010299956639813,  # = -log10(0.05)
    dendrogram_ratio=(.1, .028),
    robust=True,
    figsize=(dimx,dimy),
    yticklabels=True,
)
g = sns.clustermap(df_plot, **clustermap_options)

# tick labels of the heat map: font sizes first, then the rotation of the screen names
g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), fontsize = 15)
g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), fontsize = 11)
g.ax_heatmap.tick_params(bottom=True)
g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), rotation=60, rotation_mode="anchor", ha="right")


def _color_block_legend(lut, title, anchor):
    """Create a legend with one colour patch per entry of `lut` ({label: colour}).

    `anchor` is the upper left corner of the legend in figure coordinates.
    """
    patches = [Patch(facecolor=lut[name]) for name in lut]
    return plt.legend(patches, lut, title=title,
                      bbox_to_anchor=anchor, bbox_transform=plt.gcf().transFigure, loc='upper left')


# colour block legends; each plt.legend() call replaces the legend of the current axes,
# so all three are added back as artists further down
lut2 = dict(zip(fam_uni,fam_color_pal))
col_legend2 = _color_block_legend(lut2, 'Virus family', (-0.25, 1.035))

lut1 = dict(zip(cell_types_uni,cell_types_color_pal))
col_legend1 = _color_block_legend(lut1, 'Cell type', (-0.05, 1.035))

lut3 = funcAnno2color
col_legend3 = _color_block_legend(lut3, 'Functional category', (-0.25, 0.9))

g.ax_col_colors.set_yticklabels(g.ax_col_colors.get_yticklabels(), fontsize = 15) # column color label style
g.ax_row_colors.set_xticklabels(g.ax_row_colors.get_xticklabels(), fontsize = 15, rotation = 60, ha="right") # row color label style

for legend in (col_legend1, col_legend2, col_legend3):
    plt.gca().add_artist(legend)

# output folder
dirpath = plot_outdir
if not os.path.exists(dirpath):
    os.makedirs(dirpath)

# same figure in three formats: eps, jpeg and pdf
for extension in ("eps", "jpg", "pdf"):
    filepath = os.path.join(dirpath, f"heatmap_hit_genes_FDR_cutoff={threshold_to_display}_nHits={len(hit_genes)}.{extension}")
    plt.savefig(filepath, dpi=600, bbox_inches='tight')

### save plot data

In [None]:
# Table behind the plot: a deep copy of df_plot with the functional category as first column.
df_plot_for_save = df_plot.copy(deep=True)
df_plot_for_save.insert(0, "Functional category", funcAnno_lst)
# name of the index
df_plot_for_save.index.name = "Hit gene name"

# Two-level column header: screen columns (their display names contain "et al." or "This study")
# are grouped under "FDR", every other column under "Annotation".
tuple_list = [
    ("FDR", column) if ("et al." in column or "This study" in column) else ("Annotation", column)
    for column in df_plot_for_save.columns
]
df_plot_for_save.columns = pd.MultiIndex.from_tuples(tuple_list)

In [None]:
# save the plot data; the file name records the displayed threshold and the number of hit genes
table_csv_path = f"output/table_hit_genes_FDR_cutoff={threshold_to_display}_nHits={len(hit_genes)}.csv"
df_plot_for_save.to_csv(table_csv_path)

In [None]:
# save a copy of the unannotated genes, i.e. the rows whose functional category is "Miscellaneous"
is_miscellaneous = df_plot_for_save[("Annotation","Functional category")] == "Miscellaneous"
df_plot_misc = df_plot_for_save.loc[is_miscellaneous]
misc_csv_path = f"output/table_hit_genes_FDR_cutoff={threshold_to_display}_nHits={len(hit_genes)}_Miscellaneous.csv"
df_plot_misc.to_csv(misc_csv_path)

## Grid search of clustering methods

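Here we scan the clustering (linkage) methods and distance metrics used for the rows and columns, to find the best combination for the given data. The z-score and standard-scale options are currently not scanned.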

In [None]:
# Grid search: draw the heat map once for every combination of clustering method and distance metric.
# Combinations that fail are reported and skipped; the figures of the other combinations stay open
# so that they can be inspected.

method_lst = ["single", "complete", "average", "weighted", "centroid", "median", "ward"]
metric_lst = ["braycurtis", "canberra", "chebyshev", "cityblock", "correlation", "cosine", "dice", "euclidean", "hamming", "jaccard", "jensenshannon", "mahalanobis", "matching", "minkowski", "rogerstanimoto", "russellrao", "seuclidean", "sokalmichener", "sokalsneath", "sqeuclidean", "yule"] #kulczynski1
# candidate values for z_score / standard_scale; NOTE that currently they are not scanned (both stay None)
z_score_lst = [0, 1, None]
std_scale = [0, 1, None]


def _grid_color_block_legend(lut, title, anchor):
    """Create a legend with one colour patch per entry of `lut` ({label: colour}).

    `anchor` is the upper left corner of the legend in figure coordinates.
    """
    patches = [Patch(facecolor=lut[name]) for name in lut]
    return plt.legend(patches, lut, title=title,
                      bbox_to_anchor=anchor, bbox_transform=plt.gcf().transFigure, loc='upper left')


if grid_search:
    dimx=10
    dimy=30
    for meth in method_lst: # scan clustering methods
        for metr in metric_lst: # scan metrics
            prod_method = meth
            prod_metric = metr
            # clustering settings of this combination
            c_params1 = c_params(method = prod_method, metric = prod_metric, z_score = None, std_scale = None)
            # remember which figures exist, to be able to tell which ones this attempt opened
            fignums_before = set(plt.get_fignums())
            try:
                g = sns.clustermap(df_plot,cmap=sns.color_palette("Purples", as_cmap=True), method = c_params1.method, metric = c_params1.metric, z_score = c_params1.z_score, standard_scale = c_params1.std_scale,
                                col_colors = df_cell_color,
                                row_colors = df_funcAnno_color,
                                colors_ratio=[0.03,0.007],
                                cbar_pos=(0.25,1.02, .5, .01),
                                cbar_kws = {"orientation":"horizontal", "label":"-log10(FDR)", "extend":"min"},
                                vmin=1.3010299956639813,  # = -log10(0.05)
                                dendrogram_ratio=(.1, .028),
                                robust=True, figsize=(dimx,dimy),
                                yticklabels=True)
                g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), fontsize = 15)
                g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), fontsize = 11)
                g.ax_heatmap.tick_params(bottom=True)
                g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), rotation=60, rotation_mode="anchor", ha="right")
                # title naming the combination
                g.fig.suptitle(f"method={c_params1.method}  metric={c_params1.metric}  zscore={c_params1.z_score}  std_scale={c_params1.std_scale}", fontsize=16)

                # colour block legends; each plt.legend() call replaces the legend of the current axes,
                # so all three are added back as artists below
                lut2 = dict(zip(fam_uni,fam_color_pal))
                col_legend2 = _grid_color_block_legend(lut2, 'Virus family', (-0.25, 1.035))

                lut1 = dict(zip(cell_types_uni,cell_types_color_pal))
                col_legend1 = _grid_color_block_legend(lut1, 'Cell type', (-0.05, 1.035))

                lut3 = funcAnno2color
                col_legend3 = _grid_color_block_legend(lut3, 'Functional category', (-0.25, 0.9))

                g.ax_col_colors.set_yticklabels(g.ax_col_colors.get_yticklabels(), fontsize = 15) # column color label style
                g.ax_row_colors.set_xticklabels(g.ax_row_colors.get_xticklabels(), fontsize = 15, rotation = 60, ha="right") # row color label style

                plt.gca().add_artist(col_legend1)
                plt.gca().add_artist(col_legend2)
                plt.gca().add_artist(col_legend3)
                print(f"SUCCEED: method={c_params1.method}  metric={c_params1.metric}  zscore={c_params1.z_score}  std_scale={c_params1.std_scale}")
            except Exception as e:
                # best effort: report the failing combination and continue with the next one
                print(f"FAILED: cmethod={c_params1.method}  cdist={c_params1.metric}  zscore={c_params1.z_score}  std_scale={c_params1.std_scale}")
                print(e)
                # Close only the figure(s) opened by this failed attempt. A bare plt.close() closes the
                # current figure, which is the plot of an earlier, successful combination whenever the
                # failure happens before a new figure has been created.
                for fignum in set(plt.get_fignums()) - fignums_before:
                    plt.close(fignum)