In [1]:
import pandas as pd
import numpy as np
import cassiopeia as cas
import seaborn as sns
from os.path import join, exists
from os import makedirs, getcwd
from pandarallel import pandarallel
import networkx as nx
import matplotlib.pyplot as plt
from networkx.drawing.nx_agraph import write_dot, graphviz_layout

pandarallel.initialize(nb_workers=32)

# sys is imported here only to adjust the interpreter's recursion limit below.
import sys

# Raise Python's recursion limit from its default to 10**6 for the rest of
# this kernel session.
# NOTE(review): the last recorded output of this notebook still ends in
# "RecursionError: maximum recursion depth exceeded", so confirm that this
# setting actually reaches the code path where the deep recursion happens.
 
sys.setrecursionlimit(10**6)
import pickle

INFO: Pandarallel will run on 32 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


In [2]:
# --- Input locations -------------------------------------------------------
# Directory holding the per-donor matrices read later with the "donor{d}" prefix.
indir = "/data2/mito_lineage/data/processed/mttrace/jan21_2021/MTblacklist/merged/filters/minC10_minR50_topN0_hetT0.001_hetC10_hetCount5_bq20/filter_mgatk/vireoIn/multiplex/"

# Per-cell metadata table (tab separated) that lives inside `indir`.
donors = join(indir, "cells_meta.tsv")

# --- Run labels and output location ----------------------------------------
prefix = "jan21_2021/MTblacklist"
name = "merged_missingV"
# A relative alternative used previously: "output/cass/data"
outdir = "/data/isshamie/mito_lineage/output/cass"

# --- Thresholds and switches -----------------------------------------------
af_thresh = 0.01    # allele-frequency cutoff used when binarizing
dp_thresh = 2       # entries with depth below this are marked as missing (-1)
set_missing = True  # True: run with missing values; False: the "_noMissingVals" run

In [3]:
import os
from src.config import ROOT_DIR
#os.chdir(ROOT_DIR)

Project Directory: /data2/mito_lineage


In [4]:
# outdir = join(outdir, prefix)
# if not exists(outdir):
#     print(f"Making outdir {outdir} in folder {getcwd()}")
#     makedirs(outdir)

In [5]:
# Read the tab-separated cell metadata, then cast the two donor columns to the
# nullable Int32 dtype so cells without a donor assignment stay as <NA>.
donors_raw = pd.read_csv(donors, sep='\t')
donors_df = donors_raw.astype({"donor": "Int32", "donor_index": "Int32"})
donors_df

Unnamed: 0,new index,ID,raw ID,condition,donor,donor_index
0,1,AAACGAAAGAGGTCCA-1_Control,AAACGAAAGAGGTCCA-1,Control,1,1
1,2,AAACGAAAGCGATACG-1_Control,AAACGAAAGCGATACG-1,Control,0,1
2,3,AAACGAAAGTCGTGAG-1_Control,AAACGAAAGTCGTGAG-1,Control,3,1
3,4,AAACGAACAATAGTGA-1_Control,AAACGAACAATAGTGA-1,Control,1,2
4,5,AAACGAACACAATAAG-1_Control,AAACGAACACAATAAG-1,Control,3,2
...,...,...,...,...,...,...
17401,17402,TTTGTGTTCGAGTTAC-1_Flt3l,TTTGTGTTCGAGTTAC-1,Flt3l,3,4205
17402,17403,TTTGTGTTCGCATAAC-1_Flt3l,TTTGTGTTCGCATAAC-1,Flt3l,,
17403,17404,TTTGTGTTCGTGGTAT-1_Flt3l,TTTGTGTTCGTGGTAT-1,Flt3l,0,3853
17404,17405,TTTGTGTTCGTTACAG-1_Flt3l,TTTGTGTTCGTTACAG-1,Flt3l,1,3750


## Loop through each donor and load their depth and allele depth matrices

In [6]:
from src.utils.data_io import wrap_load_mtx_df

In [7]:
# Per-donor matrices keyed by donor id: allele depth (AD), total depth (DP)
# and the allele frequency AF = AD / DP.
allAD = {}
allDP = {}
allAF = {}
# Cells without a donor assignment are dropped before grouping.
for d, df in donors_df.dropna(axis=0, subset=["donor"]).groupby("donor"):
    print('donor', d)
    allAD[d], allDP[d] = wrap_load_mtx_df(indir, oth_f=False, prefix=f"donor{d}",
                             columns=('Variant', 'Cell', 'integer'), inc_af=False,
                             as_dense=True, var_names=True, vcf_prefix=f"donor{d}", verbose=False)

    # Transpose so that rows can be labelled with the cell IDs below.
    allAD[d] = allAD[d].transpose()
    allDP[d] = allDP[d].transpose()
    # NOTE(review): assumes the loaded rows are in the same order as this
    # donor's rows in donors_df -- confirm against wrap_load_mtx_df.
    allAD[d].index = df["ID"].values
    allDP[d].index = df["ID"].values
    allAF[d] = allAD[d]/allDP[d]
    print(allAF[d].shape, allAF[d].shape[0]*allAF[d].shape[1])
    print("0s in allAF", (allAF[d]==0).sum().sum())
    print("missing values in allAF", (allAF[d].isnull()).sum().sum())
    # DataFrame.fillna returns a new frame; the result must be assigned back,
    # otherwise the NaNs counted above are silently left in place.
    if set_missing:
        allAF[d] = allAF[d].fillna(-1)
    else:
        allAF[d] = allAF[d].fillna(0)

donor 0
(3853, 992) 3822176
0s in allAF 0
missing values in allAF 3624893
donor 1
(3751, 960) 3600960
0s in allAF 0
missing values in allAF 3412484
donor 2
(3073, 967) 2971591
0s in allAF 0
missing values in allAF 2905901
donor 3
(4205, 979) 4116695
0s in allAF 0
missing values in allAF 3890272


In [8]:
def dp_where(x, dp, dp_thresh):
    """Mark low-depth entries of one row as missing (-1).

    x is a single row (a Series whose ``name`` is its row label); the row with
    the same label is looked up in ``dp`` and every position whose depth is
    below ``dp_thresh`` is overwritten with -1. The row is modified in place
    and the same object is returned.
    """
    low_depth = dp.loc[x.name].lt(dp_thresh)
    x.loc[low_depth] = -1  # -1 is the missing-value marker
    return x


def plot_network(cas_tree, outdir, name):
    """Export a tree as a Graphviz dot file and as a PNG drawing.

    The nodes and edges of ``cas_tree`` are copied into a networkx DiGraph,
    which is written to ``{name}_hybrid.dot`` and drawn (no labels, with
    arrows, 'dot' layout rooted at ``cas_tree.root``) into
    ``{name}_hybrid.png``; both files are placed in ``outdir``.

    The matplotlib figure is left open after saving. Returns None.
    """
    dot_path = join(outdir, f'{name}_hybrid.dot')
    png_path = join(outdir, f'{name}_hybrid.png')

    graph = nx.DiGraph()
    graph.add_nodes_from(cas_tree.nodes)
    graph.add_edges_from(cas_tree.edges)

    fig, ax = plt.subplots(figsize=(10, 10), dpi=300)

    # The dot file can be rendered separately, e.g. "dot -Tpng test.dot >test.png"
    write_dot(graph, dot_path)

    # Same hierarchical layout, drawn with matplotlib and without node labels
    ax.set_title('draw_networkx')
    layout = graphviz_layout(graph, prog='dot', root=cas_tree.root)
    nx.draw(graph, layout, ax=ax, with_labels=False, arrows=True)
    fig.savefig(png_path)

def _binarize_af(af, af_thresh):
    """Binarize allele frequencies while keeping missing entries missing.

    NaN or negative entries (negative is the -1 missing sentinel) become -1.
    Every other entry becomes 1 when it is >= af_thresh and 0 otherwise.
    """
    is_missing = af.isna() | (af < 0)
    called = (af >= af_thresh).astype(int)
    return called.mask(is_missing, -1)


def run_cass(af, outdir, name, to_bin=True, dp=None, af_thresh=None, dp_thresh=None, priors=None):
    """Build a Cassiopeia tree from an allele-frequency matrix and save it.

    Steps:
      1. If ``to_bin`` is set and ``af_thresh`` is given, binarize ``af``
         (1 if >= af_thresh, else 0). NaN and negative entries are kept as
         the missing marker -1 instead of being turned into 1 / 0.
      2. If ``dp`` and ``dp_thresh`` are given, mark every entry whose depth
         is below ``dp_thresh`` as missing (-1), row by row via ``dp_where``.
      3. Rename the columns to ``r0, r1, ...`` and solve the tree with a
         hybrid solver (greedy on top, ILP at the bottom).
      4. Relabel nodes whose name contains a comma to ``node0, node1, ...``,
         pickle the tree and plot it.

    Parameters
    ----------
    af : DataFrame of allele frequencies; its row labels must also be row
        labels of ``dp`` when ``dp`` is given.
    outdir : directory receiving ``{name}.hybrid.log``,
        ``{name}.full.missingV.castree.p`` and the plot_network outputs.
    name : file-name stem for all outputs.
    to_bin, af_thresh : control binarization (step 1).
    dp, dp_thresh : depth matrix and cutoff for step 2.
    priors : optional priors passed to CassiopeiaTree; when given, the ILP
        solver is run with ``weighted=True``.

    Returns
    -------
    The solved and relabelled CassiopeiaTree.
    """
    if to_bin and af_thresh is not None:
        af = _binarize_af(af, af_thresh)
    if dp is not None and dp_thresh is not None:
        af = af.parallel_apply(dp_where, axis=1, args=(dp, dp_thresh))

    # Work on a copy so that renaming never touches the caller's frame.
    character_matrix = af.copy()
    var_map = {val: f"r{i}" for i, val in enumerate(character_matrix.columns)}
    character_matrix = character_matrix.rename(var_map, axis=1)
    print('character mat')
    print(character_matrix.shape)
    print(character_matrix.head())

    # Only pass priors when the caller supplied them.
    tree_kwargs = {"character_matrix": character_matrix}
    if priors is not None:
        tree_kwargs["priors"] = priors
    cas_tree = cas.data.CassiopeiaTree(**tree_kwargs)

    # Instantiate the top (greedy) and bottom (ILP) solvers; the ILP solver is
    # weighted exactly when priors are available.
    vanilla_greedy = cas.solver.VanillaGreedySolver()
    ilp_solver = cas.solver.ILPSolver(
        convergence_time_limit=12600,
        maximum_potential_graph_layer_size=max(character_matrix.shape[0]+10, 10000),
        weighted=priors is not None,
        seed=1234,
    )
    hybrid_solver = cas.solver.HybridSolver(top_solver=vanilla_greedy, bottom_solver=ilp_solver, cell_cutoff=40, threads=24)
    hybrid_solver.solve(cas_tree, logfile=join(outdir, f"{name}.hybrid.log"))

    # Give nodes whose label contains a comma a short sequential name.
    comma_nodes = [n for n in cas_tree.nodes if ',' in n]
    rndict = {n: f'node{i}' for i, n in enumerate(comma_nodes)}
    cas_tree.relabel_nodes(rndict)

    print("Saving")
    # Use a context manager so the file handle is closed (and flushed) even
    # if pickling fails part-way.
    with open(join(outdir, name+".full.missingV.castree.p"), 'wb') as fh:
        pickle.dump(cas_tree, fh)

    plot_network(cas_tree, outdir, name)
    return cas_tree

In [9]:
# The two modes differ only in the output name suffix and in whether the
# depth matrix / depth cutoff are handed to run_cass.
if set_missing:
    name_suffix, depth_cutoff = "", dp_thresh
else:
    name_suffix, depth_cutoff = "_noMissingVals", None

for d in allAD:
    print('donor', d)
    print("num cells:", allAF[d].shape[0])
    depth_matrix = allDP[d] if set_missing else None
    run_cass(allAF[d], outdir, name=f"{name}_donor{d}{name_suffix}", to_bin=True,
             dp=depth_matrix, af_thresh=af_thresh, dp_thresh=depth_cutoff, priors=None)

donor 0
num cells: 3853
character mat
(3853, 992)
                            r0  r1  r2  r3  r4  r5  r6  r7  r8  r9  ...  r982  \
AAACGAAAGCGATACG-1_Control  -1  -1  -1  -1   1  -1  -1  -1   1  -1  ...    -1   
AAACGAAGTACCAAGG-1_Control  -1  -1  -1  -1   1  -1  -1  -1   1  -1  ...    -1   
AAACTCGCACAAGGGT-1_Control  -1  -1  -1  -1   1  -1  -1  -1   1  -1  ...    -1   
AAACTCGCACCTCGTT-1_Control  -1  -1  -1  -1   1  -1  -1  -1   1  -1  ...    -1   
AAACTCGGTGCCCTAG-1_Control  -1  -1  -1  -1   1  -1  -1  -1   1  -1  ...    -1   

                            r983  r984  r985  r986  r987  r988  r989  r990  \
AAACGAAAGCGATACG-1_Control    -1    -1    -1    -1    -1    -1    -1    -1   
AAACGAAGTACCAAGG-1_Control    -1    -1    -1    -1    -1    -1    -1    -1   
AAACTCGCACAAGGGT-1_Control    -1    -1    -1    -1    -1    -1    -1    -1   
AAACTCGCACCTCGTT-1_Control    -1    -1    -1    -1    -1    -1    -1    -1   
AAACTCGGTGCCCTAG-1_Control    -1    -1    -1    -1    -1    -1    -1    -

RecursionError: maximum recursion depth exceeded in comparison