In [None]:
# %load ../snippets/basic_settings.py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from pathlib import Path
import seaborn as sns
import sys
import plotly.express as px
import yaml

# Seaborn context goes first so that the explicit matplotlib settings below
# are applied after it.
sns.set_context("notebook", font_scale=1.1)

# Matplotlib defaults shared by every figure in this notebook.
plt.rcParams.update({
    "figure.figsize": (16, 12),
    'savefig.dpi': 200,
    'figure.autolayout': False,
    'axes.labelsize': 18,
    'axes.titlesize': 20,
    'font.size': 16,
    'lines.linewidth': 2.0,
    'lines.markersize': 8,
    'legend.fontsize': 14,
    'text.usetex': False,  # True activates latex output in fonts!
    'font.family': "serif",
    'font.serif': "cm",
})

# pandas display: up to 100 rows/columns, floats rendered with a thousands
# separator and two decimals.
pd.set_option("display.max_columns", 100)
pd.set_option("display.max_rows", 100)
pd.set_option('display.float_format', '{:,.2f}'.format)

In [None]:
# dataDir = Path("/nfs/nas22/fs2202/biol_micro_bioinf_nccr/hardt/nguyenb/tnseq/scratch/03_22/counts")
# countData = pd.read_csv(dataDir/"dnaid1315_mbarq_merged_counts.csv")
# sampleData = pd.read_csv(dataDir/"example_sample_data.csv")

In [None]:
# Relative path to the project's YAML configuration file.
config_file = "../nguyenb_config.yaml"
with open(config_file) as file:
    # Parse the YAML document into Python objects; the cells below index
    # `configs` like a nested dictionary (e.g. configs['root']['server']).
    # NOTE(review): yaml.safe_load is the stricter alternative if this file
    # could ever come from an untrusted source.
    configs = yaml.load(file, Loader=yaml.FullLoader)

In [None]:
# Run on server: use the 'server' entry of the 'root' and 'scratchDir'
# settings from the config as base directories.
root = Path(configs['root']['server'])
scratchDir = Path(configs['scratchDir']['server'])

In [None]:
# Project locations, each resolved relative to `root` from the matching
# config key.
mapDir = root/configs['mapDir']
countDir = root/configs['libraryCountsDir']
resultDir = root/configs['resultDir']
sample_data_file = root/configs['sampleData']
# Unlike the entries above, this file name is hard-coded rather than read
# from the config.
control_file_short = root/"controls_6barcodes.csv"

In [None]:
%ls /nfs/cds-peta/exports/biol_micro_cds_gr_sunagawa/scratch/ansintsova/Projects_NCCR/hardt/nguyenb/tnseq/scratch/04_22/results

In [None]:
resultDir

In [None]:
def read_in_sample_data(sample_data_file, sampleIDs, treatment_col="", batch_col=""):
    """
    Load the sample metadata table from a CSV file.

    Only ``sample_data_file`` is used. ``sampleIDs``, ``treatment_col`` and
    ``batch_col`` are accepted but currently ignored.

    TODO: add data validation code.

    Returns the ``pandas.DataFrame`` produced by ``pd.read_csv``, unmodified.
    """
    sample_table = pd.read_csv(sample_data_file)
    return sample_table

In [None]:
def read_merged_count_file(merged_count_file):
    """
    Read a merged count table and split its header by position.

    The first two columns are returned as the annotation columns and every
    remaining column as a sample ID.

    Returns
    -------
    tuple
        ``(counts, annotation_cols, sampleIDs)``: the full DataFrame and two
        lists of column names.
    """
    counts = pd.read_csv(merged_count_file)
    column_names = list(counts.columns)
    annotation_cols, sampleIDs = column_names[:2], column_names[2:]
    return counts, annotation_cols, sampleIDs


def calculate_cpms(merged_df, annotation_cols, sampleIDs):
    """
    Convert raw counts to log2 counts-per-million.

    Rows whose numeric columns sum to 10 or less are dropped. Each remaining
    column is divided by its column total, scaled to one million, offset by
    0.5 and log2-transformed.

    Parameters
    ----------
    merged_df : DataFrame holding the annotation columns plus count columns.
    annotation_cols : column names moved into the index during the
        calculation and restored as ordinary columns in the result.
    sampleIDs : accepted but not used by the calculation.

    Returns
    -------
    DataFrame with the annotation columns followed by the log2 CPM values.
    """
    enough_reads = merged_df.sum(axis=1, numeric_only=True) > 10
    counts = merged_df.loc[enough_reads].set_index(list(annotation_cols))
    # Normalise for library depth (per-column total), then log transform.
    library_depth = counts.sum()
    log_cpm = np.log2(counts / library_depth * 1000000 + 0.5)
    return log_cpm.reset_index()

In [None]:
# Load the merged count table together with its annotation/sample column names.
countData, ann_cols, sampleIDs  = read_merged_count_file(countDir/"library_13_2_mbarq_merged_counts.csv")
# Keep only rows whose total count across all remaining columns exceeds 100;
# 'barcode' and 'Name' are moved into the index so they are left out of the sum.
countData = countData.set_index(['barcode', 'Name'])
countData = countData[countData.sum(axis=1) > 100].reset_index()
# Sample metadata table (sampleIDs is passed through but not used by the reader).
sampleData = read_in_sample_data(sample_data_file, sampleIDs)



In [None]:
countData

In [None]:
# Figure out good samples (from 28-04-2022-mageck-analysis)
%store -r good_samples

In [None]:
# Hard-coded list of sample IDs. Assigning it here replaces whatever value
# `%store -r good_samples` restored in the previous cell.
# NOTE(review): `good_samples` is not read by any later cell visible in this
# notebook — confirm whether it is still needed.
good_samples = ['dnaid1315_10',
 'dnaid1315_107',
 'dnaid1315_117',
 'dnaid1315_124',
 'dnaid1315_128',
 'dnaid1315_129',
 'dnaid1315_131',
 'dnaid1315_136',
 'dnaid1315_17',
 'dnaid1315_18',
 'dnaid1315_19',
 'dnaid1315_20',
 'dnaid1315_28',
 'dnaid1315_40',
 'dnaid1315_42',
 'dnaid1315_50',
 'dnaid1315_52',
 'dnaid1315_66',
 'dnaid1315_81',
 'dnaid1315_90',
 'dnaid1315_92',
 'dnaid1315_94',
 'dnaid1315_96']


In [None]:
cpms = calculate_cpms(countData, ann_cols, sampleIDs)

In [None]:
cpms.head()

In [None]:
day1 = sampleData[(sampleData.day.isin(['d1'])) & (sampleData.sampleID.isin(sampleIDs))]

In [None]:
day1CPM = cpms[['barcode', 'Name'] + list(day1.sampleID.values)].set_index(['barcode', 'Name'])

In [None]:
# Per-row variance across the day-1 sample columns, sorted ascending.
# reset_index() turns the (barcode, Name) index back into columns; the
# unnamed variance Series becomes the column labelled 0.
varBcs = day1CPM.var(axis=1).sort_values().reset_index()#.barcode.values
# Discard rows without a gene name.
varBcs = varBcs.dropna(subset=['Name'])
# Keep only rows whose variance exceeds 0.2.
varBcs = varBcs[varBcs[0] > 0.2]

In [None]:
varBcs

In [None]:
annot_full[annot_full.Name == 'lipB']

In [None]:
annot_full = pd.read_csv('/nfs/nas22/fs2202/biol_micro_bioinf_nccr/hardt/nguyenb/tnseq/scratch/10_22/control_norm_analysis/26-10-22-annotated-results.csv')

In [None]:
# One row per distinct (gene name, KEGG pathway string) pair.
# Derived from `annot_full`: the original line read `annot` before it had
# been assigned anywhere in the visible notebook, which raises NameError on a
# fresh kernel.
annot = annot_full[['Name', 'KEGG_Pathway']].drop_duplicates()
# Genes without a pathway annotation get the placeholder '-'.
annot['KEGG_Pathway'] = annot['KEGG_Pathway'].fillna('-')
#annot['KEGG_Pathway'] = annot['KEGG_Pathway'].str.split(',', expand=True)[0]


In [None]:
# Gene name -> KEGG pathway string (the string is split on commas below).
d = annot.set_index('Name').to_dict()['KEGG_Pathway']
all_genes = d.keys()
# Split every pathway string into its comma-separated entries.
all_pathways = [p.split(',') for p in d.values()  if type(p)== str]

# Flatten to the unique entries that start with 'ko'.
all_pathways = set([a for p in all_pathways for a in p if a.startswith('ko')] )

In [None]:
d['prgK']

In [None]:
from collections import defaultdict

# fd[pathway][gene] is 1 when the pathway id occurs in the gene's KEGG
# pathway string (substring test) and 0 otherwise.
fd = defaultdict(dict)
for gene in all_genes:
    for pathway in all_pathways:
        fd[pathway][gene] = int(pathway in d[gene])

In [None]:
anot_df = pd.DataFrame(fd)

In [None]:
anot_df = anot_df.reset_index().rename({'index':'Name'}, axis=1)

In [None]:
fdf = varBcs[['barcode', 'Name']].merge(cpms, how='left', on=['barcode', 'Name'])
full_fdf = fdf.merge(anot_df, how='left', on='Name')

In [None]:
full_fdf.head()

In [None]:
fdf.head()
X = fdf.iloc[:,2:]

In [None]:
from sklearn.preprocessing import MinMaxScaler
from sklearn.neighbors import KDTree, BallTree
# Rescale the feature columns of X with a MinMaxScaler before building the
# neighbour index.
s = MinMaxScaler()
X_s = s.fit_transform(X)
# NOTE: despite its name, `kdt` is a BallTree (metric 'l1'); the imported
# KDTree is not used in this cell.
kdt = BallTree(X_s, leaf_size=30,  metric='l1')

In [None]:
fdf_s = pd.concat([fdf.iloc[:, [0,1]], pd.DataFrame(X_s)], axis=1)

In [None]:
# Scaled feature vectors (every column from the third onwards) of three rows
# picked by position; they are used as query points for the tree below.
# NOTE(review): the row positions are hard-coded; the trailing comments name
# the gene each row is presumed to hold — verify with
# fdf_s[fdf_s.Name == ...] whenever the upstream filtering changes.
t1 =fdf_s.iloc[878, 2:].values #rfaL
t2 =fdf_s.iloc[997, 2:].values #rfaI
t3 =fdf_s.iloc[1059, 2:].values #rfbD

In [None]:
fdf_s[fdf_s.Name == 'rfaL']

In [None]:
d1, loci1 = kdt.query([t1],  k=25, return_distance=True)
d2, loci2= kdt.query([t2],  k=25, return_distance=True)
d3, loci3 = kdt.query([t3],  k=25, return_distance=True)

In [None]:
d

In [None]:
n1 = fdf_s.iloc[loci1[0]].Name.values

In [None]:
n2 = fdf_s.iloc[loci2[0]].Name.values

In [None]:
n3 = fdf_s.iloc[loci3[0]].Name.values

In [None]:
n1

In [None]:
n2

In [None]:
n3

In [None]:
n1

In [None]:
n2

In [None]:
n2

In [None]:
# Print the gene names that occur in both neighbour lists.
# n1 and n2 are NumPy arrays of names (`.Name.values`), for which `&` is an
# element-wise operator rather than an intersection, so they are compared as
# sets. Sorting keeps the printed order stable between runs.
for i in sorted(set(n1) & set(n2), key=str):
    print(i)

In [None]:
fdf_s.iloc[loci[0]]

In [None]:
# Rows of full_fdf at the positions in loci[0], restricted to 'barcode',
# 'Name' and every column whose label contains the substring 'ko'.
# NOTE(review): `loci` is not assigned in any visible cell (only loci1, loci2
# and loci3 are) — confirm which query result is meant.
kor = full_fdf.iloc[loci[0]][['barcode', 'Name']+[c for c in full_fdf.columns if 'ko' in c]]

In [None]:
# Use "<barcode><Name>" as the row label and drop the separate barcode column,
# leaving only the pathway indicator columns.
kor = (
    kor.assign(Name=kor['barcode'] + kor['Name'])
    .set_index('Name')
    .drop(columns=['barcode'])
)

In [None]:
kor

In [None]:
annot[annot.Name == 'leuO']

In [None]:
kor = kor.loc[:,kor.sum()>0]

In [None]:
kor

In [None]:
px.imshow(kor, width=800, height=800)

In [None]:
kor.sum()

In [None]:
# Scaled feature rows at the positions in loci[0].
# NOTE(review): `loci` is not assigned in any visible cell (only loci1, loci2
# and loci3 are) — confirm which query result is meant.
r = fdf_s.iloc[loci[0]]

In [None]:
#r = pd.concat([r.iloc[:, [0,1]], r[r.iloc[0, 2:].sort_values().index]], axis=1)
#r.columns = ['barcode', 'Name'] + list(range(0, 49))

In [None]:
r = r.melt(id_vars=['barcode', 'Name'])

In [None]:
px.scatter(r, x='variable', y = 'value', color='Name', trendline="lowess", trendline_options=dict(frac=0.1))

In [None]:
annot[annot.Name == 'ychF']

In [None]:
fdf = fdf.merge(annot, on='Name', how='left')

In [None]:
fdf = fdf.dropna()

In [None]:
fdf

In [None]:
# Day-1 log2 CPM values of the high-variance barcodes, indexed by barcode.
df = day1CPM.reset_index()
# `varBcs` is a DataFrame, so membership has to be tested against its
# 'barcode' column; passing the frame itself compares against its column
# labels and selects no rows.
df = df[df.barcode.isin(varBcs.barcode)].drop(['Name'], axis=1).set_index('barcode').drop_duplicates()

In [None]:
# Pairwise correlation between barcodes: transposing puts one barcode per
# column, and DataFrame.corr() correlates columns.
df2 = df.T.corr()
# Name the column axis so that melt() labels the second barcode 'barcode2'.
df2.columns.name = 'barcode2'
# Long format: one row per (barcode, barcode2) pair, correlation in 'value'.
df2 = df2.reset_index().melt(id_vars=['barcode'])
# Squared correlation.
df2['r2'] = df2.value**2

In [None]:
df2 = df2[(abs(df2.r2) > 0.8) & (df2.r2 < 0.99)]

In [None]:
df2.sort_values('value').tail(10)

In [None]:
plt.plot(df.loc['TCCGCGAATAGAATAGC'], df.loc['CGAGTACCAACCGTGAC'], '.')

In [None]:
# Attach gene names to both barcodes of every pair: the first merge matches
# on 'barcode', the second matches 'barcode2' against countData's 'barcode'.
# The overlapping columns come back with the _x / _y suffixes selected below.
df3 = (df2.merge(countData[['barcode', 'Name']], on='barcode', how='left')
 .merge(countData[['barcode', 'Name']], left_on='barcode2', right_on='barcode', how='left'))[['barcode_x', 'barcode2', 'value', 'r2', 'Name_x', 'Name_y']]

In [None]:
df3[['Name_x', 'Name_y', 'value']].drop_duplicates().sort_values('value').dropna().head(30)

In [None]:
df3[['Name_x', 'Name_y', 'value']].drop_duplicates().sort_values('value').dropna().tail(50)

In [None]:
network_df = df3.groupby(['Name_x', 'Name_y']).agg({'value':['mean'], 'r2':['mean']}).reset_index()

In [None]:
network_df.nunique()

In [None]:
df3[df3.Name_x == 'pilP']

In [None]:
x = day1CPM.reset_index()
x = x[x.Name=='pilQ']
x

In [None]:
x = day1CPM.reset_index()
x = x[x.Name=='pilT']
x

In [None]:
# Long-format day-1 log2 CPM values for the genes named 'pilR' and 'pilP'.
# NOTE: this reuses the name `fdf`, replacing the feature table that earlier
# cells stored under the same name.
fdf = day1CPM.reset_index()
fdf = fdf[fdf.Name.isin(['pilR', 'pilP'])]
fdf = fdf.melt(id_vars=['barcode', 'Name'], var_name='sampleID', value_name='log2CPM')

In [None]:
fdf


In [None]:
bcs = ['CGTATCCCAGGATCTGT','TATCGAACCACATCATA']
bcs2 = ['AACTATACGGGAACGCC', 'AAGTAACCAGTCGAAGA']
#bcs2 = ['AAACAACCGGTACTGAG', 'GGGGTATGAAACTTAAG']


In [None]:
# Wide table for the barcodes in `bcs`: one row per sample, one column group
# per gene name.
fdf1 = fdf[fdf.barcode.isin(bcs)].pivot(index=[ 'sampleID'], columns='Name').reset_index()
# Same reshaping for the barcodes in `bcs2`, starting from all day-1 values.
fdf2 = (day1CPM.reset_index()
        .melt(id_vars=['barcode', 'Name'], var_name='sampleID', value_name='log2CPM'))
fdf2 = fdf2[fdf2.barcode.isin(bcs2)].pivot(index=[ 'sampleID'], columns='Name').reset_index()
# Flatten the column labels by position.
# NOTE(review): this assumes the pivot produced exactly five columns
# (sampleID, two barcode columns, two log2CPM columns), i.e. that the two
# barcodes belong to two different gene names — confirm.
fdf2.columns =['sampleID', 'barcode_1', 'barcode_2', 'gene1', 'gene2']

In [None]:
fdf2

In [None]:
px.scatter(fdf2, x='gene1', y='gene2', hover_data=['sampleID'] )

In [None]:
countData

In [None]:
test = df3.dropna().sample(20)

In [None]:
test = test[['Name_x', 'Name_y', 'value']]

In [None]:
test

In [None]:
import networkx as nx

# Colors
# Limits of the colour scale, which is centred on 0 below.
vmin= -3
vmax=0.5
cmap = plt.cm.coolwarm
from matplotlib import colors
# Two-slope normalisation: vmin..0 and 0..vmax are scaled separately, so the
# centre of the colormap corresponds to the value 0.
divnorm=colors.TwoSlopeNorm(vmin=vmin, vcenter=0., vmax=vmax)
# Mappable used to turn values into colours (its to_rgba is called in
# plot_cluster).
sm = plt.cm.ScalarMappable(cmap=cmap,norm=divnorm)


def graph_from_df(df, node1, node2, score):
    """
    Build a graph from an edge table and give every edge a drawing width.

    Parameters
    ----------
    df : DataFrame with one row per edge.
    node1, node2 : names of the columns holding the two endpoints.
    score : name of the column stored as the edge attribute.

    Returns
    -------
    The networkx graph; each edge additionally carries an ``'edge_width'``
    attribute equal to five times its ``score`` value.
    """
    graph = nx.from_pandas_edgelist(df, node1, node2, score)
    widths = {}
    for source, target, attrs in graph.edges(data=True):
        widths[(source, target)] = attrs[score] * 5
    nx.set_edge_attributes(graph, name="edge_width", values=widths)
    return graph


def plot_cluster(graph, ax, sm, node_size=300):
    """
    Draw ``graph`` on ``ax`` using a Kamada-Kawai layout.

    Parameters
    ----------
    graph : networkx graph
        Edge widths are read from the ``'edge_width'`` edge attribute and
        node colour values from the ``'day_ci'`` node attribute.
    ax : matplotlib Axes to draw on.
    sm : object with a ``to_rgba`` method (e.g. a ScalarMappable) that turns
        a ``'day_ci'`` value into a colour.
    node_size : scalar or sequence, optional
        Forwarded to ``nx.draw_networkx_nodes``. The original body used
        ``node_size`` without ever defining it, which raises NameError; it is
        now a parameter whose default of 300 matches networkx's own default.

    Returns
    -------
    None. The drawing is added to ``ax``.
    """
    pos = nx.kamada_kawai_layout(graph)
    edge_width = list(nx.get_edge_attributes(graph, 'edge_width').values())

    # One colour per node, in node order. Nodes without a 'day_ci' value get a
    # neutral grey so the colour list always matches the number of nodes.
    day_ci = nx.get_node_attributes(graph, 'day_ci')
    missing_color = (0.8, 0.8, 0.8, 1.0)
    ncolor = [
        sm.to_rgba(day_ci[node]) if node in day_ci else missing_color
        for node in graph.nodes()
    ]

    nx.draw_networkx_edges(graph, pos, alpha=0.8, width=edge_width, ax=ax)
    nx.draw_networkx_nodes(graph, pos, node_size=node_size, node_color=ncolor,
                           alpha=0.9, ax=ax)
    label_options = {"ec": "k", "fc": "white", "alpha": 0.5}
    nx.draw_networkx_labels(graph, pos, font_size=16, bbox=label_options, ax=ax)
    