In [1]:
import pandas as pd
from os.path import join, isdir, exists, basename, dirname
import os
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from icecream import ic
from src.utils import variant_utils as vu
from glob import glob
from itertools import product
from collections import defaultdict
from tqdm.notebook import tqdm
import pickle

In [2]:
# don_dir = "/data/Mito_Trace/output/pipeline/v02/CHIP_b1/MTBlacklist_A2/data/merged/MT/cellr_True/numread_200/filters/minC10_minR50_topN0_hetT0.001_hetC10_hetCount5_bq20/mgatk/vireoIn/multiplex/clones_simpleUnion/"
# outdir = "/data/Mito_Trace/output/pipeline/v02/CHIP_b1/MTBlacklist_A2/data/merged/MT/cellr_True/numread_200/filters/minC10_minR50_topN0_hetT0.001_hetC10_hetCount5_bq20/mgatk/vireoIn/multiplex/clones_simpleUnion/mt_clones_thresh/"
# #samples = "preB,postB"

# ## Parameters:
# # af thresholds
# # coverage threshold
# # other_af thresholds
# # number of cells / fraction cells
# # number of other cells / fraction other cells

# ### Smaller params
# af_ts = [0.01, 0.1, 0.4, 0.8,]
# oth_af_ts = [0.01, 0.1, 0.4]
# cov_ts = [2, 10, 30] # mean coverage at position for cells with the AF
# oth_cov_ts = [2, 10, 30]  # mean coverage for cells without AF 
# num_cells = [5, 10, 30] # num cells and fraction are for cells with sufficient coverage
# oth_num_cells = [0.25, 0.6, 0.8] # fraction is of cells with sufficient coverage
# mean_pos_cov = [0, 10] #populatiion average coverage at that position

# use_small = False
# to_plots = True




# af_ts = [0.01, 0.1, 0.4, 0.8, 0.95]
# oth_af_ts = [0.01, 0.1, 0.4, 0.8, 0.95]
# cov_ts = [2, 5, 10, 30] # mean coverage at position for cells with the AF
# oth_cov_ts = [2, 5, 10, 30]  # mean coverage for cells without AF 
# num_cells = [5, 10, 0.1, 0.2, 0.5 ] # num cells and fraction are for cells with sufficient coverage
# oth_num_cells = [0.25, 0.6, 0.8] # fraction is of cells with sufficient coverage
# mean_pos_cov = [0, 2, 5, 10] #populatiion average coverage at that position


In [None]:
use_small = snakemake.params.get("use_small", False)
to_plots = snakemake.params.get("to_plots", True)
don_dir = snakemake.input["don_dir"]
outdir = snakemake.params["outdir"]

params = snakemake.config["mt_as_clones"]["params"]
af_ts = params["af_ts"]
oth_af_ts = params["oth_af_ts"]
cov_ts = params["cov_ts"]
oth_cov_ts = params["oth_cov_ts"]
num_cells = params["num_cells"]
oth_num_cells = params["oth_num_cells"]
mean_pos_cov = params["mean_pos_cov"]

In [3]:
cells_dir = join(outdir, "cells")
if not exists(cells_dir):
    os.mkdir(cells_dir)

In [4]:
cells_dir

'/data/Mito_Trace/output/pipeline/v02/CHIP_b1/MTBlacklist_A2/data/merged/MT/cellr_True/numread_200/filters/minC10_minR50_topN0_hetT0.001_hetC10_hetCount5_bq20/mgatk/vireoIn/multiplex/clones_simpleUnion/mt_clones_thresh/cells'

In [5]:
#samples = samples.split(",")

In [6]:
#don_dirs = [d for d in glob(don_dir + "/donor*") if isdir(d)]
don_dirs = {int(basename(dirname(x)).split("donor")[-1]):dirname(x) for x in don_dir} #donor{d}/af.tsv

don_dirs

{0: '/data/Mito_Trace/output/pipeline/v02/CHIP_b1/MTBlacklist_A2/data/merged/MT/cellr_True/numread_200/filters/minC10_minR50_topN0_hetT0.001_hetC10_hetCount5_bq20/mgatk/vireoIn/multiplex/clones_simpleUnion/donor0',
 1: '/data/Mito_Trace/output/pipeline/v02/CHIP_b1/MTBlacklist_A2/data/merged/MT/cellr_True/numread_200/filters/minC10_minR50_topN0_hetT0.001_hetC10_hetCount5_bq20/mgatk/vireoIn/multiplex/clones_simpleUnion/donor1'}

In [7]:
# There are 7 params to use for calling the clone
params = {"af":af_ts,
          "oth_af":oth_af_ts,
          "ncells": num_cells,
          "oth_ncells":oth_num_cells,
          "mean_cov": mean_pos_cov,
          "cov":cov_ts,
          "oth_cov":oth_cov_ts,
         }

full_params = list(product(*list(params.values())))
full_params = pd.DataFrame(full_params, columns=params.keys())
full_params.head()

Unnamed: 0,af,oth_af,ncells,oth_ncells,mean_cov,cov,oth_cov
0,0.01,0.01,5,0.25,0,2,2
1,0.01,0.01,5,0.25,0,2,10
2,0.01,0.01,5,0.25,0,2,30
3,0.01,0.01,5,0.25,0,10,2
4,0.01,0.01,5,0.25,0,10,10


In [8]:
full_params.shape

(1944, 7)

In [9]:
def load_donor(don_dir):
    vcf = pd.read_csv(join(don_dir, "cellSNP.base.vcf"), sep="\t")
    variants_meta = vcf[["#CHROM", "POS", "REF", "ALT"]]
    variants_meta.index = variants_meta.apply(lambda x: f"{x['POS']}{x['REF'][0]}>{x['ALT']}", axis=1)
    variants_meta = vu.type_of_variants(variants_meta.index)
    variants_meta["ID"] = variants_meta.index
    variants_meta.index = variants_meta.apply(lambda x: f"{x['position']}{x['alt']}", axis=1)
    af = pd.read_csv(join(don_dir, "af.tsv"), sep="\t", index_col=0)
    cov = pd.read_csv(join(don_dir, "dp.tsv"), sep="\t", index_col=0)                   
    cells_meta = pd.read_csv(join(don_dir,"cells_meta.tsv"), sep="\t").set_index("ID")
    ic(cells_meta.shape)
    cells_meta = cells_meta.loc[af.index]
    ic("After filtering on af indices", cells_meta.shape)
    
    variants = set(af.columns)
    return af, cov, cells_meta, variants



## Informative variants function
#af	oth_af	ncells	oth_ncells	mean_cov	cov	oth_cov
def get_vars(thresholds, position_af, position_cov):
    """ Sees if current position is an 'informative' variant using current thresholds
    """
    # preprocess ncells thresholds
    if thresholds["ncells"] < 1: #fraction of cells or total
        ncells = int(np.floor(thresholds["ncells"] * len(position_af)))
    else:
        ncells = thresholds["ncells"]
        
    # Get cells that pass af and coverage thresholds
    bin_position_af_cov = (position_af > thresholds["af"]) & (position_cov > thresholds["cov"])
    in_cells = bin_position_af_cov[bin_position_af_cov].index
    
    #print('bin_pos', bin_position_af_cov)
    #print(bin_position_af_cov.sum())
    
    if thresholds["oth_ncells"] < 1: # Fraction of the remaining population
        oth_ncells = int(np.floor(thresholds["oth_ncells"] * len(position_af)-len(in_cells)))
    else:
        oth_ncells = thresholds["oth_ncells"]    

    # positions that pass the coverage threshold but not the AF threshold. 
    # These are cells confidently not with variant, and oth_ncells needed to pass threshold.
    oth_bin_position_af_cov = (position_af <= thresholds["oth_af"]) & (position_cov > thresholds["oth_cov"])
    oth_bin_position_af_cov = oth_bin_position_af_cov.loc[~(oth_bin_position_af_cov.index.isin(in_cells))]
    
    if (bin_position_af_cov.sum() > ncells) & (oth_bin_position_af_cov.sum() > oth_ncells):
        return True
    return False

def run_per_pos(af_cov, thresholds ,cov_id="COV-"):
    split = int(af_cov.shape[0]/2)
    af = af_cov.iloc[:split]
    cov = af_cov.iloc[split:]

    cov.index = [x.split(cov_id,1)[1] for x in cov.index]
    thresholds['isVar'] = thresholds.apply(get_vars, args=(af, cov), axis=1)
    thresholds["var"] = af_cov.name
    return thresholds


## Plot 
def plots(params_results, outdir, prefix):
    sns.histplot(params_results["n_vars"])

    cov_groups = params_results.groupby(["cov", "oth_cov", "mean_cov"])

    for ind, val in cov_groups:
        print(ind)
        sns.catplot(data=val, y="n_vars", x= "ncells", 
                    hue="af", row="oth_af", col="oth_ncells", kind="bar")
        plt.suptitle(f"{prefix} cov {ind[0]} oth_cov {ind[1]} mean_cov {ind[2]}")
        plt.tight_layout()
        plt.savefig(join(outdir,f"{prefix}_nvars_thresholds.cov_{ind[0]}_othcov_{ind[1]}_mean_{ind[2]}.png"))
    return

In [10]:
for d in don_dirs: 
    curr_af, curr_cov, curr_cells_meta, curr_variants = load_donor(don_dirs[d])
    if curr_af.shape[0] == 0:
        continue
    curr_cov_name = curr_cov.copy()   
    curr_cov_name.index = "COV-" + curr_cov.index
    curr_af_cov = pd.concat([curr_af, curr_cov_name], verify_integrity=True, sort=False)

    if use_small: 
        curr_af_cov = pd.concat([curr_af.iloc[:100,:100], curr_cov_name.iloc[:100,:100]], verify_integrity=True, sort=False)
        full_params_df = full_params.iloc[:100]
    else:
        full_params_df = full_params.copy()
    #curr_af_cov.apply(run_per_pos, args=(full_params_df,),expand=True)
    all_passed_vars = defaultdict(list)
    params_results = full_params_df.copy()
    params_results["Variants"] = None
    for ind, val in tqdm(full_params_df.iterrows()):
        # get boolean variant array and keep the ones that are true
        curr_vars = []

        #c_vars_df = pd.DataFrame(index=curr_af.o)
        for c_var in curr_af.columns:
            pos_af = curr_af[c_var]
            pos_cov = curr_cov[c_var]
            isVar = get_vars(val, pos_af, pos_cov)
            if isVar:
                all_passed_vars[tuple(val.values)].append(c_var)
                curr_vars.append(c_var)
        params_results.loc[ind,"Variants"] = ";".join(curr_vars)
        #full_params_df.loc[ind, "Variants"] = curr_vars
    #     if ind>200:
    #         break

    seen_variants = params_results["Variants"].dropna().apply(lambda x: x.split(";"))
    params_results["Variants List"] = seen_variants
    seen_variants = set(sum(list(seen_variants.values), []))
    seen_variants

    params_results["n_vars"] = params_results["Variants List"].apply(lambda x: len(x) if type(x) is list else 0)
    
    ## Save
    params_results.drop("Variants List", axis=1).to_csv(join(outdir, f"donor_{d}_thresh_results.tsv"), sep="\t", index=False)
    
    if to_plots:
        plots(params_results,outdir, prefix=f"donor_{d}")



ic| cells_meta.shape: (9023, 6)
ic| 'After filtering on af indices', cells_meta.shape: (9023, 6)


0it [00:00, ?it/s]

ic| cells_meta.shape: (8275, 6)
ic| 'After filtering on af indices', cells_meta.shape: (8275, 6)


0it [00:00, ?it/s]

## Check subset variants

## For each thresh get the variants and other and output the list of cells

In [11]:
len(params_results.groupby(["af", "oth_af", "ncells", "oth_ncells", "cov", "oth_cov"]))

972

In [12]:
len(params_results.groupby(["af", "oth_af", "cov", "oth_cov"]))

108

In [13]:
def get_cells(af, cov, oth_af, oth_cov, position_af, position_cov):
    """ Gets cells and oth_cells that pass thresholds. Assumes position already passed.
    """
    # Get cells that pass af and coverage thresholds
    bin_position_af_cov = (position_af > af) & (position_cov > cov)
    in_cells = bin_position_af_cov[bin_position_af_cov].index
    # positions that pass the coverage threshold but not the AF threshold. 
    # These are cells confidently not with variant, and oth_ncells needed to pass threshold.
    oth_bin_position_af_cov = (position_af <= oth_af) & (position_cov > oth_cov)
    # drop the cells
    oth_bin_position_af_cov = oth_bin_position_af_cov.loc[~(oth_bin_position_af_cov.index.isin(in_cells))]
    oth_cells = oth_bin_position_af_cov[oth_bin_position_af_cov].index
    return {"clone_cells": in_cells, "other_cells": oth_cells}

In [14]:
for d in don_dirs: 
    curr_af, curr_cov, curr_cells_meta, curr_variants = load_donor(don_dirs[d])
    if curr_af.shape[0] == 0:
        continue
    curr_cov_name = curr_cov.copy()   
    curr_cov_name.index = "COV-" + curr_cov.index
    curr_af_cov = pd.concat([curr_af, curr_cov_name], verify_integrity=True, sort=False)

    params_results = pd.read_csv(join(outdir, f"donor_{d}_thresh_results.tsv"), sep="\t")
    for thresh, curr_df in params_results.groupby(["af", "oth_af", "cov", "oth_cov"]):
        print(thresh)
        # get all variants seen here:
        curr_vars = set()
        for x in curr_df["Variants"].dropna().values:
            curr_vars = curr_vars.union(set(x.split(";")))
        
        curr_cells_vars_d = {}
        for c_var in curr_vars: #curr_af.columns:
            pos_af = curr_af[c_var]
            pos_cov = curr_cov[c_var]
            curr_cells_vars_d[c_var] = get_cells(*(thresh), pos_af, pos_cov)
        curr_f = join(cells_dir, f"don.{d}_af.{thresh[0]}_othaf.{thresh[1]}_cov.{thresh[2]}_othcov.{thresh[3]}.p")
        pickle.dump(curr_cells_vars_d, open(curr_f,'wb'))

ic| cells_meta.shape: (9023, 6)
ic| 'After filtering on af indices', cells_meta.shape: (9023, 6)


(0.01, 0.01, 2, 2)
541
751
419
776
433
120
14
784
581
132
591
535
407
113
453
530
589
586
529
655
930
31
747
199
92
257
727
590
577
109
617
626
83
672
40
676
4357
593
641
(0.01, 0.01, 2, 10)
541
751
419
776
433
120
14
784
581
132
591
535
407
113
453
530
589
586
529
655
930
31
747
199
92
257
727
590
577
109
617
626
83
672
40
676
4357
593
641
(0.01, 0.01, 2, 30)
541
751
419
776
433
120
14
784
581
132
591
535
407
113
453
530
589
586
655
930
31
747
199
92
257
727
590
577
109
617
626
83
672
40
676
529
593
641
(0.01, 0.01, 10, 2)
541
751
419
776
433
120
14
784
581
132
591
535
407
113
453
530
589
586
529
655
930
31
747
199
92
257
727
590
577
109
617
626
83
672
40
676
4357
593
641
(0.01, 0.01, 10, 10)
541
751
419
776
433
120
14
784
581
132
591
535
407
113
453
530
589
586
529
655
930
31
747
199
92
257
727
590
577
109
617
626
83
672
40
676
4357
593
641
(0.01, 0.01, 10, 30)
541
751
419
776
433
120
14
784
581
132
591
535
407
113
453
530
589
586
655
930
31
747
199
92
257
727
590
577
109
617
626
83


192
132
89
226
246
67
209
230
215
210
233
(0.1, 0.4, 30, 10)
221
210
190
238
183
61
225
221
78
228
180
170
64
198
217
216
191
334
130
4186
215
124
63
220
209
192
132
89
226
246
67
209
230
215
210
233
(0.1, 0.4, 30, 30)
221
210
190
238
183
61
225
221
78
228
180
170
64
198
217
216
191
334
130
215
124
63
220
209
192
132
89
226
246
67
209
230
215
210
233
(0.4, 0.01, 2, 2)
23
10
18
13
16
13
19
8
41
18
7
10
43
14
20
9
11
238
3575
7
11
33
165
6
17
67
15
16
50
18
12
15
11
12
(0.4, 0.01, 2, 10)
23
10
18
13
16
13
19
8
41
18
7
10
43
14
20
9
11
238
3575
7
11
33
165
6
17
67
15
16
50
18
12
15
11
12
(0.4, 0.01, 2, 30)
23
10
18
13
16
13
19
8
41
18
7
10
43
14
20
9
11
238
7
11
33
165
6
17
67
15
16
50
18
12
15
11
12
(0.4, 0.01, 10, 2)
238
67
13
19
50
43
33
165
3575
41
(0.4, 0.01, 10, 10)
238
67
13
19
50
43
33
165
3575
41
(0.4, 0.01, 10, 30)
238
67
13
19
50
43
33
165
41
(0.4, 0.01, 30, 2)
238
67
50
43
33
165
3575
41
(0.4, 0.01, 30, 10)
238
67
50
43
33
165
3575
41
(0.4, 0.01, 30, 30)
238
67
50
43
33
165
41

ic| cells_meta.shape: (8275, 6)
ic| 'After filtering on af indices', cells_meta.shape: (8275, 6)


(0.01, 0.01, 2, 2)
607
815
535
372
195
50
554
107
893
565
521
767
634
725
39
291
866
463
990
773
120
754
90
33
575
5539
14
436
15
96
300
(0.01, 0.01, 2, 10)
607
815
535
372
195
50
554
107
893
565
521
767
634
725
39
291
866
463
990
773
120
754
90
33
575
5539
14
436
15
96
300
(0.01, 0.01, 2, 30)
607
291
815
866
463
990
773
120
535
372
195
50
754
90
33
554
107
575
893
565
14
521
767
634
725
436
15
39
96
300
(0.01, 0.01, 10, 2)
607
815
535
372
195
50
554
107
893
565
521
767
634
725
39
291
866
463
990
773
120
754
90
33
575
5539
14
436
15
96
300
(0.01, 0.01, 10, 10)
607
815
535
372
195
50
554
107
893
565
521
767
634
725
39
291
866
463
990
773
120
754
90
33
575
5539
14
436
15
96
300
(0.01, 0.01, 10, 30)
607
291
815
866
463
990
773
120
535
372
195
50
754
90
33
554
107
575
893
565
14
521
767
634
725
436
15
39
96
300
(0.01, 0.01, 30, 2)
607
815
535
372
195
50
554
107
893
565
521
767
634
725
39
291
866
463
990
773
120
754
90
33
575
5539
14
436
15
96
300
(0.01, 0.01, 30, 10)
607
815
535
372
195
50

22
19
26
53
7
(0.4, 0.4, 2, 10)
24
71
15
18
24
83
23
43
41
37
22
26
18
32
24
14
16
5040
15
6
22
19
26
53
7
(0.4, 0.4, 2, 30)
24
71
15
18
24
83
23
43
41
37
22
26
18
32
24
14
16
15
6
22
19
26
53
7
(0.4, 0.4, 10, 2)
71
15
24
83
23
43
41
37
22
26
32
14
16
5040
19
26
53
(0.4, 0.4, 10, 10)
71
15
24
83
23
43
41
37
22
26
32
14
16
5040
19
26
53
(0.4, 0.4, 10, 30)
71
15
24
83
23
43
41
37
22
26
32
14
16
19
26
53
(0.4, 0.4, 30, 2)
71
5040
53
83
43
41
37
26
32
(0.4, 0.4, 30, 10)
71
5040
53
83
43
41
37
26
32
(0.4, 0.4, 30, 30)
71
53
83
43
41
37
26
32
(0.8, 0.01, 2, 2)
30
4503
40
65
17
31
14
(0.8, 0.01, 2, 10)
30
4503
40
65
17
31
14
(0.8, 0.01, 2, 30)
30
14
65
17
31
40
(0.8, 0.01, 10, 2)
30
4503
40
65
17
31
14
(0.8, 0.01, 10, 10)
30
4503
40
65
17
31
14
(0.8, 0.01, 10, 30)
17
30
31
14
40
65
(0.8, 0.01, 30, 2)
17
30
31
4503
14
40
65
(0.8, 0.01, 30, 10)
17
30
31
4503
14
40
65
(0.8, 0.01, 30, 30)
17
30
31
14
40
65
(0.8, 0.1, 2, 2)
30
4503
40
65
17
31
14
(0.8, 0.1, 2, 10)
30
4503
40
65
17
31
14
(0.8, 0.1,

In [17]:
params_results.sort_values("n_vars", ascending=False)

Unnamed: 0,af,oth_af,ncells,oth_ncells,mean_cov,cov,oth_cov,Variants,n_vars
0,0.01,0.01,5,0.25,0,2,2,10397G;10589A;11453A;11761T;13188T;14674C;146C...,31
54,0.01,0.01,10,0.25,0,2,2,10397G;10589A;11453A;11761T;13188T;14674C;146C...,31
325,0.01,0.40,5,0.25,0,2,10,10397G;10589A;11453A;11761T;13188T;14674C;146C...,31
327,0.01,0.40,5,0.25,0,10,2,10397G;10589A;11453A;11761T;13188T;14674C;146C...,31
328,0.01,0.40,5,0.25,0,10,10,10397G;10589A;11453A;11761T;13188T;14674C;146C...,31
...,...,...,...,...,...,...,...,...,...
1132,0.40,0.01,30,0.80,10,30,10,15297C,1
1133,0.40,0.01,30,0.80,10,30,30,,1
38,0.01,0.01,5,0.80,0,2,30,,1
35,0.01,0.01,5,0.60,10,30,30,3244A,1


In [None]:
cmd = f"touch {outdir}/.complete"
!{cmd}

## For each pairwise variants check if the cells with the variant (and coverage) is over 50%, if it is remove the smaller one (create tree)