# Using vireo for demultiplexing, followed by lineage tracing of the resulting donor groups

# 1. Change to analysis directory

In [None]:
import os
# Work from the analysis directory so that the relative paths used further
# down (e.g. INDIR, which starts with "data/") resolve against it.
os.chdir("/data2/mito_lineage/Analysis/multiplex")

In [None]:
# Show the current working directory to confirm which directory the relative
# paths below are resolved against.
# os.getcwd() is used instead of the bare `pwd` IPython automagic so that the
# cell also works when automagic is disabled or the notebook is executed as a
# plain Python script (where `pwd` is an undefined name).
print(os.getcwd())

### Remove this when running in snakemake

## 1.1 Parameters

In [None]:
# Parameters
# NOTE(review): these look like defaults meant to be overridden by the workflow
# runner — confirm how the notebook is parameterised.

# INDIR: directory holding the inputs read below (cellSNP.tag.AD.mtx,
# cellSNP.tag.DP.mtx, cell_indices_*.txt and cell_labels.txt). It is a relative
# path, so it is resolved against the current working directory.
INDIR="data/CHIP_april08_2021/MTblacklist/chrM/pseudo/minC200_minAF0.0001/numC25000_ispropFalse/"
#INDIR="data/jan21_2021/chrM/pseudo/minC200_minAF0.01/numC25000_ispropFalse/"
# N_DONORS: number of donors passed to run_vireo / separate_donors; run_elbo
# scans the range N_DONORS-2 .. N_DONORS+2 around it.
N_DONORS=5
# OUTDIR: directory that every output file name below is joined onto.
OUTDIR=""#"data/CHIP_april08_2021/MTblacklist/chrM/pseudo/minC200_minAF0.0001/numC25000_ispropFalse/flt3"
#sample_csv="/data2/mito_lineage/parameters/CHIP_april08_2021/CHIP_april08_2021.csv"
# sample_names: comma-separated sample names; split on "," further down and
# matched to the samples by position.
sample_names="" #"Control,Flt3l,Input"

In [None]:
# INDIR="data/jan21_2021/chrM/pseudo/minC200_minAF0.01/numC25000_ispropFalse"
# OUTDIR= "data/jan21_2021/chrM/pseudo/minC200_minAF0.01/numC25000_ispropFalse/flt3"
# N_DONORS=4 


## 1.2 Import packages

In [None]:
from os.path import join, exists, dirname
from glob import glob
import mplh.cluster_help as ch
import os
from vireoSNP.plot.base_plot import heat_matrix
from vireoSNP import Vireo
import src.pseudo_batch as pb
import vireoSNP
print(vireoSNP.__version__)
import numpy as np
from scipy import sparse
from scipy.io import mmread
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from src.vireo.vireo_utils import run_vireo, plot_vireo_out, extract_clusters, run_elbo, separate_donors
%load_ext autoreload
%autoreload 2

In [None]:
# Print every NumPy float with exactly five digits after the decimal point.
np.set_printoptions(formatter={"float": "{:.5f}".format})

## 1.3 Load

In [None]:
# Locations of the inputs, all of them directly under INDIR.
# The two count matrices (Matrix Market files):
AD_F, DP_F = (
    join(INDIR, mtx_name)
    for mtx_name in ("cellSNP.tag.AD.mtx", "cellSNP.tag.DP.mtx")
)
# Glob pattern for the cell-index files (expanded with glob() further down).
cell_inds_names = join(INDIR, "cell_indices_*.txt")
# Table of per-cell labels; note this name is rebound to the loaded DataFrame
# once the file has been read.
sample_labels = join(INDIR, "cell_labels.txt")

In [None]:
# The label table has the columns 'ID', 'raw ID' and 'new index':
#   'ID'        - the cell ID including the suffix with the old id
#   'raw ID'    - the initial cell ID, without the suffix
#   'new index' - 1-based index of the cell among the outputted subsampled cells
# Rows are ordered by 'new index'.
# NOTE: both names below are rebound (path -> DataFrame, string -> dict), so
# this cell cannot be re-run without first re-running the cells that set them.
sample_labels = pd.read_csv(sample_labels).sort_values('new index')
print('sample_labels', sample_labels.head())
# Turn the comma-separated sample_names parameter into {position: name}.
sample_names = dict(enumerate(sample_names.split(",")))
print('sample_names', sample_names)

In [None]:
# Load the two count matrices and convert them to CSC format.
# NOTE(review): by vireo convention AD is presumably the alternative-allele
# depth and DP the total depth — confirm against the upstream step.
AD = mmread(AD_F).tocsc()
DP = mmread(DP_F).tocsc()
# A sparse matrix exposes .shape directly; converting to a dense matrix just to
# read the shape would allocate the full variants x cells array for nothing.
AD_shape = AD.shape
print('AD_shape', AD_shape)
# The number of columns has to equal the number of rows in the label table,
# otherwise the labels cannot be lined up with the matrix.
assert AD_shape[1] == sample_labels.shape[0], (
    f"AD has {AD_shape[1]} columns but sample_labels has "
    f"{sample_labels.shape[0]} rows; these should match up."
)

## Add in cell IDs to get sample information

In [None]:
# Preview the first four colours of the "Set2" palette, the palette used
# further down to colour cells by sample.
set2_preview = sns.color_palette("Set2", 4)
sns.palplot(set2_preview)

In [None]:
# Map every cell of the pseudo population back to the sample it came from.
#
# A.
# Input: cell indices files, where the first line is the name of the sample
#        file and the rest is a comma-separated table with a header row
#        ('old index', 'new index').
# Output: A. cell_map: dict where each key is the sample file (relative) and
#            the value is the 'new index' column as an int array.
#         B. in_cell_names: list of the keys, in the order the files were read.
#
# glob() returns paths in arbitrary order; sorting keeps the sample -> integer
# assignment below reproducible from run to run.
# NOTE(review): sample_names is matched to the samples by position, so it
# presumably has to be listed in this same (sorted) order — confirm.
in_cell_f = sorted(glob(cell_inds_names))
print('cell inds', in_cell_f)
in_cell_names = []
cell_map = {}
for curr_cell_f in in_cell_f:
    print(curr_cell_f)
    with open(curr_cell_f, "r") as f:
        # First line: name of the sample file. Remaining lines: the table.
        curr_f = f.readline().strip()
        # Skip blank lines (e.g. a trailing newline) so they do not turn into
        # bogus rows that fail the int conversion below.
        lines = [line.strip().split(',') for line in f if line.strip()]
    curr_df = pd.DataFrame(lines[1:], columns=lines[0])
    cell_map[curr_f] = curr_df["new index"].astype(int).values
    in_cell_names.append(curr_f)

# Largest 'new index' over all samples; since the index is 1-based this is the
# length of the per-cell array built below. Empty tables are skipped.
max_v = 0
for new_index in cell_map.values():
    if new_index.size:
        max_v = max(max_v, int(new_index.max()))

# B. Input: in_cell_names: the list of keys.
#    Output: cell_ind_map: dict where the keys are the sample file names and the
#            values are unique 0-indexed ints, ordered as in in_cell_names.
cell_ind_map = {name: ind for ind, name in enumerate(in_cell_names)}

# cell_inds: int array with one element per cell of the pseudo population; each
#            element is the sample integer from cell_ind_map (-1 = unassigned).
cell_inds = np.full(max_v, -1, dtype=int)
for name in in_cell_names:
    # 'new index' is 1-based, the array is 0-based.
    cell_inds[cell_map[name] - 1] = cell_ind_map[name]
print('cell_inds', cell_inds)

# Fail with a clear message instead of a bare KeyError in the lookups below.
unassigned = np.flatnonzero(cell_inds < 0)
if unassigned.size:
    raise ValueError(
        f"{unassigned.size} cell(s) are not listed in any cell indices file "
        f"(first missing 1-based index: {unassigned[0] + 1})"
    )
unnamed = [name for name, ind in cell_ind_map.items() if ind not in sample_names]
if unnamed:
    raise ValueError(
        f"sample_names has {len(sample_names)} entries but {len(in_cell_names)} "
        f"cell indices files were found; no name for: {unnamed}"
    )

# Assign a name and a colour to every cell according to its sample.
sample_colors = pd.DataFrame([sample_names[x] for x in cell_inds], columns=["sample ID"])
# One colour per sample file (not per distinct name), so that indexing the
# palette by sample integer stays in range even if two samples share a name.
colors = sns.color_palette("Set2", len(in_cell_names))
cell_colors = pd.Series([colors[x] for x in cell_inds], name="sample ID")

sample_colors

# 2. Run vireo to demultiplex

In [None]:
# Fit vireo with N_DONORS donors; outputs are written under the "donors"
# prefix inside OUTDIR.
donors_prefix = join(OUTDIR, "donors")
modelCA, elbo = run_vireo(
    AD,
    DP,
    N_DONORS,
    n_cores=32,
    plot_qc=True,
    out_f=donors_prefix,
)
# Take the first element returned by predict_doublet and sum it over axis 1.
# NOTE(review): presumably this yields one total doublet probability per
# cell — confirm against the vireoSNP documentation.
doublet_result = modelCA.predict_doublet(AD, DP, update_GT=False, update_ID=False)
doublet_prob = doublet_result[0].sum(axis=1)

## Extract donors

In [None]:
# cell_clusters = extract_clusters(modelCA, prob_thresh=0.9, doublet_thresh=0.9, doublet_prob=doublet_prob,
#                                  sample_colors=sample_colors, outdir=OUTDIR, out_f="donors")
# cell_clusters

In [None]:
# Split the cells by donor using the fitted model, with the same cut-off (0.9)
# for the assignment probability and the doublet probability.
donor_thresholds = {"prob_thresh": 0.9, "doublet_thresh": 0.9}
cell_clusters = separate_donors(
    AD,
    DP,
    modelCA,
    sample_labels,
    OUTDIR,
    N_DONORS,
    doublet_prob,
    sample_colors,
    **donor_thresholds,
)
cell_clusters

In [None]:
# Scan the donor counts N_DONORS-2 .. N_DONORS+2 (arange excludes the stop
# value); outputs go under the "donors_elbo" prefix inside OUTDIR.
donor_counts = np.arange(N_DONORS - 2, N_DONORS + 3)
run_elbo(
    AD,
    DP,
    out_f=join(OUTDIR, "donors_elbo"),
    n_clone_list=donor_counts,
    n_cores=12,
    sample_colors=sample_colors,
    save_clusters=False,
)

In [None]:
# Plot the fitted model for all cells together, labelled by sample; the two
# returned objects are reused by the plots below.
all_clusters_prefix = join(OUTDIR, "multiplex_clusters_all")
clust_df, AF_SNPs = plot_vireo_out(
    modelCA,
    out_f=all_clusters_prefix,
    labels=sample_colors,
    to_sqrt=False,
    doublet_prob=doublet_prob,
)

In [None]:
# For every sample, print and plot the rows of clust_df that belong to it and
# save the figure as "multiplex_clusters_<sample>" in OUTDIR.
# NOTE(review): this relies on clust_df sharing its index with sample_colors —
# confirm against plot_vireo_out.
for sample_id, sample_cells in sample_colors.groupby("sample ID"):
    print(sample_id)
    sample_clust = clust_df.loc[sample_cells.index]
    print(sample_clust)
    f = plt.figure()
    # Rows are clustered, columns are left in their given order.
    ch.plot_cluster(
        sample_clust,
        cmap='Oranges',
        alpha=0.8,
        to_row_clust=True,
        to_col_clust=False,
        to_legend=True,
        white_name=None,
    )
    plt.suptitle(sample_id)
    plt.savefig(join(OUTDIR, f"multiplex_clusters_{sample_id}"))

In [None]:
# Heatmap of the square-rooted AF_SNPs values, restricted to the rows whose
# sum over axis 1 exceeds 0.01.
f = plt.figure()
rows_above_cutoff = AF_SNPs.sum(axis=1) > 0.01
sqrt_af = np.sqrt(AF_SNPs[rows_above_cutoff, :])
im = heat_matrix(
    sqrt_af,
    cmap="Blues",
    alpha=0.8,
    display_value=False,
    row_sort=True,
)
plt.savefig(join(OUTDIR, "multiplex_AF_SNPs_all_afFilt"))

## Separate for each sample before