# Imports: 

In [137]:
import pandas as pd
import os
import re

import holoviews as hv 
import bokeh.io
import bokeh.plotting
import colorcet as cc

import matplotlib.pyplot as plt
import seaborn as sns
cols = sns.color_palette()


import hvplot.pandas


# Functions: 

In [130]:
def _interaction_partners(rvids, df_interact):
    """Sorted unique genes found in any interaction row that involves one of `rvids`.

    A row "involves" a gene when the gene is its `lead_gene` or `partner_gene`.
    """
    involved = df_interact.lead_gene.isin(rvids) | df_interact.partner_gene.isin(rvids)
    hits = df_interact[involved]
    return sorted(set(hits.lead_gene.tolist() + hits.partner_gene.tolist()))


def get_NN12(rvid_query, df_interact):
    """Return the 1st-, 2nd- and 3rd-order interaction neighbourhoods of a gene.

    Despite the name, three levels are returned.

    Parameters
    ----------
    rvid_query : str
        Gene ID whose neighbourhood is explored.
    df_interact : pandas.DataFrame
        Interaction table with `lead_gene` and `partner_gene` columns; each
        row is one interaction.

    Returns
    -------
    (list_rvid_NN1, list_rvid_NN2, list_rvid_NN3) : tuple of sorted lists
        NN1 holds every gene sharing an interaction with `rvid_query`, and
        includes the query itself when it has at least one interaction.
        NN2 holds every gene sharing an interaction with any NN1 gene.
        NN3 holds every gene sharing an interaction with any NN2 gene.
        Each level therefore contains the previous one. All three are empty
        if the query never appears in `df_interact`.
    """
    list_rvid_NN1 = _interaction_partners([rvid_query], df_interact)
    list_rvid_NN2 = _interaction_partners(list_rvid_NN1, df_interact)
    list_rvid_NN3 = _interaction_partners(list_rvid_NN2, df_interact)
    return list_rvid_NN1, list_rvid_NN2, list_rvid_NN3

In [131]:
def interactive_scatter(df_lfc, x_rvid, y_rvid):
    """Interactive scatter of two genes' LFC profiles, one point per screen.

    Parameters
    ----------
    df_lfc : pandas.DataFrame
        One row per gene. The gene-ID column is 'Rv_ID' (as in the LFC table
        loaded in this notebook) or, if that is absent, 'rvid'. Every other
        column is treated as a screen.
    x_rvid, y_rvid : str
        Gene IDs plotted on the x and y axes.

    Returns
    -------
    The object returned by `DataFrame.hvplot.scatter`; hovering a point shows
    its screen name.
    """
    # Derive the ID and screen columns from df_lfc itself instead of relying
    # on a global `cols_data` and a hard-coded 'rvid' column.
    id_col = 'Rv_ID' if 'Rv_ID' in df_lfc.columns else 'rvid'
    screen_cols = [col for col in df_lfc.columns if col != id_col]

    df_xy = df_lfc[df_lfc[id_col].isin([x_rvid, y_rvid])][[id_col] + screen_cols]
    # Transpose so that each screen is a row and each gene is a column.
    df_xy = df_xy.set_index(id_col).T.rename_axis('screen').reset_index()

    # Named `plot`, not `hv`, so the holoviews alias is not shadowed.
    plot = df_xy.hvplot.scatter(x=x_rvid, y=y_rvid, width=400, height=400, size=200,
                                line_color='k', line_width=3, hover_cols=['screen'])
    return plot

In [132]:
def interactive_scatter_grid(df_lfc, list_rvid):
    """Build a grid of pairwise LFC scatter plots for a set of genes.

    Each point is one screen (a data column of `df_lfc`); hovering shows the
    screen name.

    Parameters
    ----------
    df_lfc : pandas.DataFrame
        Must contain an 'Rv_ID' column; every other column is treated as a
        screen.
    list_rvid : list of str
        Genes to compare. IDs not present in `df_lfc` are skipped with a
        warning, duplicates are plotted once, and the caller's order is kept.

    Returns
    -------
    list
        One entry per plotted gene. Entry i is the hvplot object for gene i
        on the x axis against every plotted gene on the y axis
        (`subplots=True`), laid out with `.cols(<number of plotted genes>)`.
        Empty if none of the genes are present.
    """
    import warnings

    id_col = 'Rv_ID'
    screen_cols = [col for col in df_lfc.columns if col != id_col]

    # Genes can be missing from df_lfc (e.g. rows dropped for NaNs); plotting
    # a missing column would fail, so drop those IDs up front.
    unique_rvids = list(dict.fromkeys(list_rvid))
    available = set(df_lfc[id_col])
    genes = [rvid for rvid in unique_rvids if rvid in available]
    missing = [rvid for rvid in unique_rvids if rvid not in available]
    if missing:
        warnings.warn(f'Skipping genes absent from df_lfc: {missing}', stacklevel=2)

    df_xy = df_lfc[df_lfc[id_col].isin(genes)][[id_col] + screen_cols]
    # Transpose so that each screen is a row and each gene is a column.
    df_xy = df_xy.set_index(id_col).T.rename_axis('screen').reset_index()

    list_hv = []
    for rvid in genes:
        row_plot = df_xy.hvplot.scatter(
            x=rvid, y=genes, width=300, height=300, size=200, line_color='k',
            line_width=3, hover_cols=['screen'], subplots=True,
            fontsize={'xlabel': '15pt'}, xlabel=rvid,
        ).cols(len(genes))
        list_hv.append(row_plot)

    return list_hv

# Loading datasets: 

Annotations: 

In [138]:
# UniProt-derived annotation table.
# NOTE(review): absolute, machine-specific path; consider a configurable data dir.
fn = '/home/ajinich/Documents/repos/mtb_tn_db/data/annotations/uniprot_mtb_with_location.xlsx'
df_mtb_w_loc = pd.read_excel(fn)
df_mtb_w_loc = df_mtb_w_loc.fillna('')

# Raw string: a plain literal containing backslash-d is an invalid escape
# sequence (SyntaxWarning on Python 3.12+).
# NOTE(review): only an optional trailing 'c' is captured; an ID with any other
# suffix letter would be truncated to 'Rv' + 4 digits — confirm that is intended.
re_str = r'Rv\d\d\d\dc?'
# First match in each 'Gene names' entry; raises IndexError if an entry has none.
list_rvids = [re.findall(re_str, str_temp)[0] for str_temp in df_mtb_w_loc['Gene names']]
df_mtb_w_loc['Rv_ID'] = list_rvids

# First whitespace-separated token of 'Gene names' (IndexError on an empty entry).
list_gene_names = [gn.split()[0] for gn in df_mtb_w_loc['Gene names']]
df_mtb_w_loc['gene_names'] = list_gene_names

df_rvid_to_name = df_mtb_w_loc[['Rv_ID', 'gene_names']].copy()

# Rv_ID -> gene name lookup; if an Rv_ID repeats, the last row wins.
dict_rvid_to_name = dict(zip(df_rvid_to_name.Rv_ID, df_rvid_to_name.gene_names))

Interaction data: 

In [141]:
# Interaction table consumed by get_NN12, which reads its 'lead_gene' and
# 'partner_gene' columns.
# NOTE(review): absolute, machine-specific path.
path, fn = (
    '/home/ajinich/Dropbox/KyuRhee/unknown_function/unknown_redox/data/GLS_TnSeq_v2/',
    'test_SI_data_1_fdr.001.xlsx',
)
fn_path = os.path.join(path, fn)
df_interact = pd.read_excel(fn_path)

Log2 fold-changes (LFC), optionally L2-normalized per screen column (controlled by `norm`):

In [140]:
norm = 0  # 0/False: use raw LFC values; truthy: L2-normalize each data column when the LFC table is loaded

In [143]:
# LFC dataset: the first column is the gene ID (Rv_ID); the remaining columns
# are the per-screen data columns.
fn_lfc_basis = '../data/standardized_data/result_logfc_matrix_2021_11_15_BASIS_invitro.csv'
# Genes with any missing value are dropped, so they are absent from df_lfc
# and cannot be plotted below.
df_lfc_basis = pd.read_csv(fn_lfc_basis).dropna(axis=0)

cols_data = df_lfc_basis.columns[1:]

if norm:
    # Scale each data column to unit L2 norm. This matches
    # sklearn.preprocessing.normalize(X, norm='l2', axis=0), which was called
    # here without being imported.
    X = df_lfc_basis[cols_data].values
    col_norms = (X ** 2).sum(axis=0) ** 0.5
    col_norms[col_norms == 0] = 1  # leave all-zero columns unchanged, as sklearn does
    X_norm = X / col_norms

    df_lfc_basis_norm_invitro = df_lfc_basis.copy()
    df_lfc_basis_norm_invitro[cols_data] = X_norm

    df_lfc = df_lfc_basis_norm_invitro.copy()

else:
    df_lfc = df_lfc_basis.copy()

## Interactive scatter plots: 

In [144]:
# Query gene. To explore a random gene instead, use:
#   rvid_query = df_interact.sample()['lead_gene'].values[0]
rvid_query = 'Rv2940c'
(list_rvid_NN1,
 list_rvid_NN2,
 list_rvid_NN3) = get_NN12(rvid_query, df_interact)

In [128]:
# Plot every pair among the query's 2nd-order interaction neighbourhood.
list_rvid = list_rvid_NN2.copy()
# Optionally append a hand-picked set of genes:
# list_rvid = list_rvid + ['Rv1344', 'Rv1345', 'Rv1346', 'Rv1347c', 'Rv1348', 'Rv1349']
list_hv = interactive_scatter_grid(df_lfc, list_rvid)
# Display all rows at once; looping with `for hv in ...` would rebind the
# holoviews alias `hv` imported at the top of the notebook.
display(*list_hv)

# Random pair of genes: 

In [123]:
# Set to an int to draw the same "random" pair on every run; None keeps it random.
random_pair_seed = None
list_rvid = df_lfc.Rv_ID.sample(2, random_state=random_pair_seed).tolist()
list_hv = interactive_scatter_grid(df_lfc, list_rvid)
# display(*...) avoids a loop variable that would shadow the holoviews alias `hv`.
display(*list_hv)

# Interaction neighbours of Rv0796 (insertion-element / transposase genes): 

In [145]:
rvid_query = 'Rv0796'
(list_rvid_NN1,
 list_rvid_NN2,
 list_rvid_NN3) = get_NN12(rvid_query, df_interact)

# Annotation rows for the query's first-order neighbours.
# NOTE(review): the three lists below take rows 0, 1 and 2 purely by position
# in the annotation table; which gene family sits in which row is assumed,
# not checked — confirm against the data.
df_insertion_elements = df_mtb_w_loc.loc[df_mtb_w_loc.Rv_ID.isin(list_rvid_NN1)].copy()
# Rows 0 and 1: ';'-separated entries, keep the first token of each.
list_IS6110 = [chunk.split()[0]
               for chunk in df_insertion_elements['Gene names'].iloc[0].split(';')]
list_transposase = [chunk.split()[0]
                    for chunk in df_insertion_elements['Gene names'].iloc[1].split(';')]
# Row 2: every whitespace-separated token.
list_mutator_fam_transposase = df_insertion_elements['Gene names'].iloc[2].split()
# Full (unsampled) alternative:
# list_rvid = list_IS6110 + list_transposase + list_mutator_fam_transposase + ['Rv3508']

num_sample = 2
list_rvid = [
    *list_IS6110[:num_sample],
    *list_transposase[:num_sample],
    *list_mutator_fam_transposase[:num_sample],
    'Rv3508',
    'Rv3514',
]

In [147]:
list_hv = interactive_scatter_grid(df_lfc, list_rvid)
# display(*...) avoids a loop variable that would shadow the holoviews alias `hv`.
display(*list_hv)