In [None]:
# %load ../snippets/basic_settings.py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from pathlib import Path
import seaborn as sns
import sys
import plotly.express as px
import yaml

sns.set_context("notebook", font_scale=1.1)

# pandas display: up to 100 columns and 100 rows before truncation.
pd.set_option("display.max_columns", 100, "display.max_rows", 100)

# Global matplotlib defaults for every figure in this notebook.
plt.rcParams.update({
    "figure.figsize": (16, 12),
    "savefig.dpi": 200,
    "figure.autolayout": False,
    "axes.labelsize": 18,
    "axes.titlesize": 20,
    "font.size": 16,
    "lines.linewidth": 2.0,
    "lines.markersize": 8,
    "legend.fontsize": 14,
    "text.usetex": False,  # set True to render all text through LaTeX
    "font.family": "serif",
    "font.serif": "cm",
})

# Floats shown with thousands separators and two decimals, e.g. 1,234.57.
pd.set_option('display.float_format', '{:,.2f}'.format)

In [None]:
# dataDir = Path("/nfs/nas22/fs2202/biol_micro_bioinf_nccr/hardt/nguyenb/tnseq/scratch/03_22/counts")
# countData = pd.read_csv(dataDir/"dnaid1315_mbarq_merged_counts.csv")
# sampleData = pd.read_csv(dataDir/"example_sample_data.csv")

In [None]:
config_file = "../nguyenb_config.yaml"
# Parse the project YAML config into a plain Python dict (`configs`).
# NOTE(review): for a plain path config, yaml.safe_load would suffice and avoids
# FullLoader's extra constructors; FullLoader is kept here to preserve behaviour.
with Path(config_file).open() as file:
    configs = yaml.load(file, Loader=yaml.FullLoader)

In [None]:
# Run on server:
# Resolve the server-side base directories from the config.
root, scratchDir = (Path(configs[key]['server']) for key in ('root', 'scratchDir'))

In [None]:
# Every project location is relative to the server root taken from the config.
mapDir, countDir, resultDir, sample_data_file = (
    root / configs[key]
    for key in ('mapDir', 'libraryCountsDir', 'resultDir', 'sampleData')
)
control_file_short = root / "controls_6barcodes.csv"

In [None]:
def read_in_sample_data(sample_data_file, sampleIDs, treatment_col="", batch_col=""):
    """Read the sample metadata CSV and run basic consistency checks.

    Parameters
    ----------
    sample_data_file : str or path-like
        CSV file with the sample metadata.
    sampleIDs : iterable of str
        Sample identifiers present in the count table. If the metadata has a
        ``sampleID`` column, IDs that do not appear in it trigger a warning.
    treatment_col, batch_col : str, optional
        Metadata columns that must be present. An empty string (the default)
        skips the corresponding check.

    Returns
    -------
    pandas.DataFrame
        The metadata table exactly as read from disk (no rows are dropped).

    Raises
    ------
    ValueError
        If ``treatment_col`` or ``batch_col`` is given but missing from the file.
    """
    import warnings

    sample_data = pd.read_csv(sample_data_file)

    requested = [col for col in (treatment_col, batch_col) if col]
    missing_cols = [col for col in requested if col not in sample_data.columns]
    if missing_cols:
        raise ValueError(f"Sample data {sample_data_file} is missing column(s): {missing_cols}")

    # Only warn: downstream cells already intersect with sampleIDs, so missing
    # metadata is worth knowing about but should not stop the notebook.
    if "sampleID" in sample_data.columns:
        known = set(sample_data["sampleID"])
        missing_ids = sorted(sid for sid in set(sampleIDs or []) if sid not in known)
        if missing_ids:
            warnings.warn(
                f"{len(missing_ids)} sample ID(s) from the count data are missing from "
                f"{sample_data_file}: {missing_ids[:10]}",
                stacklevel=2,
            )
    return sample_data

In [None]:
def read_merged_count_file(merged_count_file):
    """Load a merged barcode count table.

    The first two columns are treated as annotation columns and every
    remaining column as one sample.

    Returns
    -------
    tuple
        ``(counts, annotation_cols, sampleIDs)``: the full DataFrame, the list
        of the first two column names, and the list of the remaining names.
    """
    counts = pd.read_csv(merged_count_file)
    column_names = counts.columns.tolist()
    return counts, column_names[:2], column_names[2:]


def calculate_cpms(merged_df, annotation_cols, sampleIDs, min_total_count=10, pseudocount=0.5):
    """Filter low-count rows and return log2 counts-per-million per sample.

    Parameters
    ----------
    merged_df : pandas.DataFrame
        Annotation columns plus one numeric count column per sample.
    annotation_cols : iterable of str
        Columns that identify a row (used as the index during normalisation).
    sampleIDs : list of str
        Currently unused; kept for interface compatibility. All non-annotation
        columns are normalised.
    min_total_count : int, default 10
        Rows whose summed count across numeric columns is <= this are dropped.
    pseudocount : float, default 0.5
        Added to CPM before log2 so zero counts stay finite.

    Returns
    -------
    pandas.DataFrame
        Annotation columns followed by ``log2(CPM + pseudocount)`` per sample.
        A sample column summing to zero after filtering yields NaN (0/0).
    """
    kept = merged_df[merged_df.sum(axis=1, numeric_only=True) > min_total_count]
    counts = kept.set_index(list(annotation_cols))
    # Normalise each sample (column) for library depth, then log-transform.
    cpm = counts.div(counts.sum(axis=0), axis=1) * 1_000_000
    return np.log2(cpm + pseudocount).reset_index()

In [None]:
countData, ann_cols, sampleIDs = read_merged_count_file(countDir / "library_13_2_mbarq_merged_counts.csv")
# Keep only barcodes with more than 100 reads summed across all samples.
countData = (countData
             .set_index(['barcode', 'Name'])
             .pipe(lambda counts: counts[counts.sum(axis=1) > 100])
             .reset_index())
sampleData = read_in_sample_data(sample_data_file, sampleIDs)



In [None]:
# Restore `good_samples` persisted via %store by the 28-04-2022-mageck-analysis notebook.
# NOTE(review): the next cell reassigns `good_samples` to a hard-coded list, so the value
# restored here is overwritten; one of the two sources is redundant.
%store -r good_samples

In [None]:
# Hard-coded list of samples judged good in the 28-04-2022 MAGeCK analysis.
good_samples = [
    f"dnaid1315_{num}"
    for num in (10, 107, 117, 124, 128, 129, 131, 136, 17, 18, 19, 20,
                28, 40, 42, 50, 52, 66, 81, 90, 92, 94, 96)
]


In [None]:
cpms = calculate_cpms(merged_df=countData, annotation_cols=ann_cols, sampleIDs=sampleIDs)  # log2 CPM table

In [None]:
# Day-1 samples that are also present in the count table.
day1 = sampleData[sampleData['day'].isin(['d1']) & sampleData['sampleID'].isin(sampleIDs)]

In [None]:
# log2 CPM restricted to the day-1 sample columns, indexed by (barcode, Name).
day1CPM = cpms.set_index(['barcode', 'Name']).loc[:, list(day1.sampleID.values)]

In [None]:
# The 500 barcodes whose day-1 log2CPM varies most across samples (ascending variance).
# Rows with undefined variance (fewer than two values) are dropped first: sort_values puts
# NaN last, so tail() would otherwise select them ahead of genuinely variable barcodes.
varBcs = (day1CPM.var(axis=1)
          .dropna()
          .sort_values()
          .tail(500)
          .reset_index()
          .barcode.values)

In [None]:
# Show the selected high-variance barcodes (ordered from lowest to highest variance).
varBcs

In [None]:
# Barcode x sample matrix of the high-variance barcodes; rows with identical profiles are dropped.
df = (day1CPM
      .reset_index()
      .loc[lambda frame: frame['barcode'].isin(varBcs)]
      .drop(columns='Name')
      .set_index('barcode')
      .drop_duplicates())

In [None]:
# Pairwise Pearson correlation between barcodes, reshaped to one row per (barcode, barcode2) pair.
df2 = df.T.corr().rename_axis(columns='barcode2')
df2 = df2.reset_index().melt(id_vars=['barcode'])
df2 = df2.assign(r2=lambda pairs: pairs['value'] ** 2)

In [None]:
# Keep strongly correlated pairs. r2 is a square, so no abs() is needed. The < 0.99 cutoff
# removes self-pairs (r2 == 1).
# NOTE(review): the cutoff also drops genuinely near-perfect pairs; confirm that is intended.
df2 = df2.query('0.8 < r2 < 0.99')

In [None]:
# The ten most positively correlated pairs (ascending order).
df2.sort_values(by='value').iloc[-10:]

In [None]:
# Day-1 log2CPM of two barcodes plotted against each other, one point per sample.
fig, ax = plt.subplots()
ax.plot(df.loc['TCCGCGAATAGAATAGC'], df.loc['CGAGTACCAACCGTGAC'], '.')
ax.set(xlabel='TCCGCGAATAGAATAGC (log2CPM)',
       ylabel='CGAGTACCAACCGTGAC (log2CPM)',
       title='Day-1 log2CPM per sample for two barcodes');

In [None]:
# Attach gene names to both ends of every correlated barcode pair.
# After the second merge the duplicated columns become barcode_x/barcode_y and Name_x/Name_y.
df3 = (
    df2
    .merge(countData[['barcode', 'Name']], on='barcode', how='left')
    .merge(countData[['barcode', 'Name']], how='left', left_on='barcode2', right_on='barcode')
    .loc[:, ['barcode_x', 'barcode2', 'value', 'r2', 'Name_x', 'Name_y']]
)

In [None]:
# Gene pairs with the most negative correlations.
(df3[['Name_x', 'Name_y', 'value']]
 .drop_duplicates()
 .sort_values('value')
 .dropna()
 .iloc[:30])

In [None]:
# Gene pairs with the most positive correlations.
(df3[['Name_x', 'Name_y', 'value']]
 .drop_duplicates()
 .sort_values('value')
 .dropna()
 .iloc[-50:])

In [None]:
# Mean correlation and r2 per gene pair (columns are a (metric, 'mean') MultiIndex).
network_df = (df3
              .groupby(['Name_x', 'Name_y'])[['value', 'r2']]
              .agg(['mean'])
              .reset_index())

In [None]:
# Number of distinct values in each column of the gene-pair table.
network_df.nunique(axis=0)

In [None]:
# All correlated pairs whose first barcode lies in pilP.
df3.loc[df3['Name_x'].eq('pilP')]

In [None]:
# Day-1 log2CPM of every pilQ barcode.
x = day1CPM.reset_index().query("Name == 'pilQ'")
x

In [None]:
# Day-1 log2CPM of every pilT barcode.
x = day1CPM.reset_index().loc[lambda frame: frame['Name'].eq('pilT')]
x

In [None]:
# Long-format (barcode, Name, sampleID, log2CPM) table for the pilR and pilP barcodes.
fdf = (day1CPM
       .reset_index()
       .loc[lambda frame: frame['Name'].isin(['pilR', 'pilP'])]
       .melt(id_vars=['barcode', 'Name'], var_name='sampleID', value_name='log2CPM'))

In [None]:
# Inspect the long-format day-1 log2CPM values for the pilR/pilP barcodes.
fdf


In [None]:
# Barcode pairs compared in the scatter plots below.
bcs = 'CGTATCCCAGGATCTGT TATCGAACCACATCATA'.split()
bcs2 = 'AACTATACGGGAACGCC AAGTAACCAGTCGAAGA'.split()
# Alternative pair previously tried for bcs2: AAACAACCGGTACTGAG, GGGGTATGAAACTTAAG


In [None]:
# Wide (one row per sample) tables for the selected barcode pairs.
# fdf only holds pilR/pilP rows, so `bcs` must come from those genes.
fdf1 = fdf[fdf.barcode.isin(bcs)].pivot(index=['sampleID'], columns='Name').reset_index()

# Filter to the two barcodes *before* melting. Melting the whole day-1 table first is wasted
# work and yields the same rows.
fdf2 = (day1CPM.reset_index()
        .loc[lambda frame: frame['barcode'].isin(bcs2)]
        .melt(id_vars=['barcode', 'Name'], var_name='sampleID', value_name='log2CPM'))
# The positional rename below only fits when bcs2 hits exactly two distinct genes. pivot
# orders the gene columns by name, so gene1/gene2 follow gene-name order, not bcs2 order.
assert fdf2['Name'].nunique() == 2, (
    f"bcs2 must map to exactly two distinct genes, got {sorted(fdf2['Name'].dropna().unique())}"
)
fdf2 = fdf2.pivot(index=['sampleID'], columns='Name').reset_index()
fdf2.columns = ['sampleID', 'barcode_1', 'barcode_2', 'gene1', 'gene2']

In [None]:
# Inspect the per-sample wide table for the barcode pair in bcs2.
fdf2

In [None]:
# Interactive scatter of the two barcodes' log2CPM; hover shows the sample ID.
px.scatter(data_frame=fdf2, x='gene1', y='gene2', hover_data=['sampleID'])

In [None]:
# Inspect the filtered count table (barcodes with more than 100 total reads).
countData

In [None]:
# Random subset of fully annotated correlation pairs for a quick network sketch.
# Seeded so Restart & Run All draws the same subset; min() avoids a ValueError
# when fewer than 20 annotated pairs remain.
test = df3.dropna().pipe(lambda pairs: pairs.sample(n=min(20, len(pairs)), random_state=0))

In [None]:
# Keep just the gene pair and its correlation.
test = test.loc[:, ['Name_x', 'Name_y', 'value']]

In [None]:
# Inspect the sampled gene pairs and their correlation values.
test

In [None]:
import networkx as nx

# Colors
# Diverging colour scale for node colouring: centred on 0, asymmetric limits [-3, 0.5].
vmin, vmax = -3, 0.5
cmap = plt.cm.coolwarm
from matplotlib import colors
divnorm = colors.TwoSlopeNorm(vcenter=0., vmin=vmin, vmax=vmax)
sm = plt.cm.ScalarMappable(norm=divnorm, cmap=cmap)


def graph_from_df(df, node1, node2, score):
    """Build a networkx graph from an edge-list DataFrame.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per edge.
    node1, node2 : str
        Columns holding the two endpoint labels of each edge.
    score : str
        Column stored as an edge attribute and used to size the edges.

    Returns
    -------
    networkx.Graph
        Edges carry ``score`` plus an ``edge_width`` attribute equal to
        ``5 * |score|``.
    """
    G = nx.from_pandas_edgelist(df, node1, node2, score)
    # Use the magnitude: scores such as correlations can be negative, and a negative line
    # width is meaningless. A strong anti-correlation should draw as a thick edge.
    edge_width = {(a, b): abs(data[score]) * 5 for a, b, data in G.edges(data=True)}
    nx.set_edge_attributes(G, name="edge_width", values=edge_width)
    return G


def plot_cluster(graph, ax, sm, node_size=300):
    """Draw ``graph`` on ``ax`` using a Kamada-Kawai layout.

    Parameters
    ----------
    graph : networkx.Graph
        Edges may carry an ``edge_width`` attribute (missing ones draw at width
        1.0). Every node must carry a numeric ``day_ci`` attribute, which is
        mapped to a colour.
    ax : matplotlib.axes.Axes
        Axes to draw into.
    sm : matplotlib.cm.ScalarMappable
        Maps ``day_ci`` values to RGBA colours.
    node_size : scalar or sequence, default 300
        Passed to ``nx.draw_networkx_nodes``. 300 is networkx's own default.

    Raises
    ------
    KeyError
        If any node lacks the ``day_ci`` attribute.
    """
    pos = nx.kamada_kawai_layout(graph)
    # Build widths and colours in the graph's own edge/node order, which is the order the
    # draw functions use by default. This keeps each value attached to the right edge/node.
    edge_width = [width for _, _, width in graph.edges(data='edge_width', default=1.0)]
    ncolor = [sm.to_rgba(graph.nodes[node]['day_ci']) for node in graph]
    nx.draw_networkx_edges(graph, pos, alpha=0.8, width=edge_width, ax=ax)
    nx.draw_networkx_nodes(graph, pos, node_size=node_size, node_color=ncolor,
                           alpha=0.9, ax=ax)
    label_options = {"ec": "k", "fc": "white", "alpha": 0.5}
    nx.draw_networkx_labels(graph, pos, font_size=16, bbox=label_options, ax=ax)
    