In [None]:
# import argparse

# parser = argparse.ArgumentParser(description='Supply reference fasta and bam file')
# parser.add_argument('ref',
#                     help='reference fasta')
# parser.add_argument('bam',
#                     help='bam file')


# args = parser.parse_args()
# ref_fname = args.ref
# bam_fname = args.bam
# Hard-coded input paths; they stand in for the command-line arguments that
# are commented out above and are passed to preprocess() further down.
ref_fname = 'data/reference/refchrm.fa'
bam_fname = 'test/in10.bam'
genomes_fname ='data/genomes/311humans.fasta'

In [None]:
import os
from collections import Counter
import pysam
import numpy as np
from tqdm import tqdm
from IPython.display import clear_output
from scipy.special import binom
import scipy.stats as st
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import subprocess
from multiprocess import Pool
%load_ext cython

In [None]:
#Cython part

In [None]:
%%cython -a --compile-args=-O3
from tqdm import tqdm
import pysam
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from collections import Counter
cimport cython
import numpy as np
from cython.parallel import prange
from libc.math cimport pow
from scipy.special import binom
# from mc_lib.rndm cimport RndmWrapper


def get_num_reads(str bam_fname):
    """Count the mapped reads that ``fetch('chrM')`` yields from a BAM file.

    Parameters
    ----------
    bam_fname : str
        Path of the BAM file. ``fetch`` with a contig name is used, so the
        file presumably needs an index next to it.

    Returns
    -------
    int
        Number of fetched reads whose ``is_mapped`` flag is true.
    """
    cdef long num_reads = 0
    # The context manager closes the handle even if iteration raises.
    with pysam.AlignmentFile(bam_fname, "rb") as samfile:
        for read in samfile.fetch('chrM'):
            if read.is_mapped:
                num_reads += 1
    return num_reads

def bam2consensus(
        ref_fname, bam_fname, double ac_threshold=0, double af_threshold=0):
    """Call a majority-vote consensus string from a pileup of ``bam_fname``.

    Parameters
    ----------
    ref_fname : str
        Reference FASTA. It is only used as a sanity check: every record in
        it must be named ``chrM``.
    bam_fname : str
        BAM file to pile up. Every pileup column must lie on ``chrM``.
    ac_threshold : double, optional
        Minimum count the most frequent allele needs for the column to be
        emitted.
    af_threshold : double, optional
        Minimum fraction of the column the most frequent allele needs for the
        column to be emitted.

    Returns
    -------
    str
        One character per emitted pileup column (``A``, ``C``, ``G``, ``T``,
        ``N`` or ``-`` for a deletion). Columns that are not reported by the
        pileup, or that fail a threshold, contribute nothing, so the result
        can be shorter than the reference.
    """
    cdef int max_count, total_count
    cdef str allele
    cdef str max_allele
    # Collect the calls in a list and join once; repeated ``str +=`` copies
    # the growing string on every column.
    cdef list called = []

    for record in SeqIO.parse(ref_fname, "fasta"):
        assert record.id == 'chrM'

    with pysam.AlignmentFile(bam_fname, "rb") as bam:
        allele_counter = Counter()
        for pileup_column in tqdm(bam.pileup(), total=16569, desc = 'consensus dna'):
            assert pileup_column.reference_name == 'chrM'

            allele_counter.clear()
            for pileup_read in pileup_column.pileups:
                if pileup_read.is_del:
                    allele = "-"
                elif pileup_read.query_position is None:
                    # No read base at this column (e.g. a reference skip that
                    # is not flagged as a deletion): nothing to count.
                    continue
                else:
                    allele = pileup_read.alignment.query_sequence[
                        pileup_read.query_position]
                allele_counter[allele] += 1

            max_allele = "N"
            max_count, total_count = 0, 0
            # Strict ``>`` keeps the first-seen allele on ties.
            for allele, count in allele_counter.items():
                if count > max_count:
                    max_count = count
                    max_allele = allele
                total_count += count

            if total_count == 0:
                continue

            assert max_allele in "ACGTN-"
            # Cast so the ratio is a floating-point division regardless of the
            # language level the cell is compiled with.
            if (max_count >= ac_threshold and
                    <double>max_count / total_count >= af_threshold):
                called.append(max_allele)

    return ''.join(called)


@cython.cdivision(True)
@cython.wraparound(False)
@cython.boundscheck(False)
def get_MN(char[:, :] genomes,str bam_fname):
    """Build quality-weighted match / mismatch tables for every read and genome.

    For each mapped read fetched from ``chrM`` and each row of ``genomes`` the
    read is compared base by base with the genome, starting at the read's
    ``reference_start`` (the comparison is positional; CIGAR operations are
    not consulted). With ``e = 10 ** (-Q / 10)`` the Phred error probability
    of a base:

    * read base equals the genome base: ``M += 1 - e``, ``N += e``
    * read base differs:                ``M += e / 3``, ``N += 1 - e / 3``

    so ``M[i, j]`` is the expected number of positions at which the true base
    of read ``i`` agrees with genome ``j`` and ``N[i, j]`` the expected number
    of disagreements.

    Parameters
    ----------
    genomes : char[:, :]
        One row per genome, one column per alignment position, holding the
        character code of each base.
    bam_fname : str
        BAM file with the reads.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        ``M`` and ``N``, both float64 with shape ``(num_reads, num_genomes)``,
        where ``num_reads`` comes from ``get_num_reads``.
    """
    cdef double[:, :] N, M
    cdef double[:] p_errs
    cdef int k, i, j, pos, read_len
    cdef double correct, incorrect, p_err
    cdef str seq
    cdef long num_reads = get_num_reads(bam_fname)
    cdef int num_genomes = genomes.shape[0]
    cdef int genome_len = genomes.shape[1]
    N = np.zeros((num_reads, num_genomes))
    M = np.zeros((num_reads, num_genomes))
    i = 0
    with pysam.AlignmentFile(bam_fname, "rb") as bam:
        # num_reads is already known, so it doubles as the progress total
        # without another pass over the file.
        for read in tqdm(bam.fetch('chrM'), total = num_reads, desc = 'MN tables'):
            if not read.is_mapped:
                continue
            seq = read.query_sequence
            qual = read.query_qualities
            if seq is None or qual is None:
                # Nothing to compare; the row stays all-zero.
                i += 1
                continue
            pos = read.reference_start

            # Compare only the part of the read that lies inside the genomes.
            read_len = min(len(seq), genome_len - pos)
            if read_len < 0:
                read_len = 0

            # The error probabilities depend on the read only, so compute
            # them once instead of once per genome.
            p_errs = np.power(
                10.0, -np.asarray(qual[:read_len], dtype=np.float64) / 10.0)

            for j in range(num_genomes):
                correct = 0
                incorrect = 0
                for k in range(read_len):
                    p_err = p_errs[k]
                    if seq[k] == chr(genomes[j, k+pos]).upper():
                        # Observed base agrees: it is right unless miscalled.
                        correct += 1 - p_err
                        incorrect += p_err
                    else:
                        # Observed base disagrees: the true base can only
                        # agree if this call is an error that hit this one of
                        # the three alternative bases.
                        correct += p_err / 3
                        incorrect += 1 - p_err / 3
                M[i, j] = correct
                N[i, j] = incorrect
            i += 1
    return np.array(M, dtype=np.float64), np.array(N, dtype=np.float64)

def get_mc(m, n, eps):
    """Binomial likelihood of every (read, genome) pair.

    Computes ``binom(m + n, m) * (1 - eps) ** m * eps ** n`` element-wise.
    Entries where ``m == -1`` are treated as "no data" and get likelihood 0.

    The tables may be integer or floating point; ``get_MN`` returns float64
    tables, which a ``long``-typed signature cannot accept. All the work is
    done by NumPy/SciPy, so no Cython directives are needed on this function.

    Parameters
    ----------
    m, n : array-like, shape (num_reads, num_genomes)
        Match and mismatch tables.
    eps : float
        Per-base error rate.

    Returns
    -------
    numpy.ndarray
        C-contiguous float64 array with the same shape as ``m``.

    Raises
    ------
    ValueError
        If ``m`` and ``n`` are not two-dimensional arrays of equal shape.
    """
    matches = np.ascontiguousarray(m, dtype=np.float64)
    mismatches = np.ascontiguousarray(n, dtype=np.float64)
    if matches.ndim != 2 or matches.shape != mismatches.shape:
        raise ValueError(
            "m and n must be 2-D arrays of identical shape, got "
            f"{matches.shape} and {mismatches.shape}")
    eps = float(eps)

    mc = (binom(matches + mismatches, matches)
          * np.power(1.0 - eps, matches)
          * np.power(eps, mismatches))
    # -1 marks a pair without data; overwrite whatever the formula produced.
    mc[matches == -1] = 0.0
    return np.ascontiguousarray(mc, dtype=np.float64)
    




@cython.cdivision(True)
@cython.wraparound(False)
@cython.boundscheck(False)
def get_Zi(double[:,::1] mc, double[::1] p,double eps, long i):
    """Draw a genome index for read ``i``.

    The weight of genome ``j`` is ``mc[i, j] * p[j]``. When all weights sum
    to exactly zero an index is drawn with ``np.random.randint``; otherwise
    the weights are normalised and passed to ``np.random.choice``.
    ``eps`` is accepted but not used.
    """
    cdef long n_candidates = mc.shape[1]
    cdef long col
    cdef long drawn
    cdef double total = 0
    cdef double[:] weights = np.zeros(n_candidates, dtype = float)

    for col in range(n_candidates):
        weights[col] = mc[i, col] * p[col]
        total += weights[col]

    # Guard clause: no genome has any support for this read.
    if total == 0:
        return np.random.randint(0, n_candidates)

    for col in range(n_candidates):
        weights[col] = weights[col] / total

    drawn = np.random.choice(np.arange(0, n_candidates), p = weights)
    return drawn



@cython.cdivision(True)
@cython.wraparound(False)
@cython.boundscheck(False)
def get_Zi_new(double[:,::1] mc, double[::1] p,double eps, long i):
    """Draw a genome index for read ``i`` by inverse-CDF sampling.

    The weight of genome ``j`` is ``mc[i, j] * p[j]``. One number from
    ``uniform()`` is scaled by the total weight and the first index whose
    running weight sum exceeds it is returned. When the row has no positive
    total weight the index is taken uniformly from the same random number.
    ``eps`` is accepted but not used.

    Returns
    -------
    int
        A 0-based index in ``[0, num_genomes)``.

    Raises
    ------
    ValueError
        If ``mc`` has no genome columns.
    """
    cdef long num_genomes = mc.shape[1]
    cdef long j
    cdef long last_positive = -1
    cdef double weight
    cdef double total = 0
    cdef double cumulative = 0
    cdef double threshold
    cdef double rval = uniform()

    if num_genomes <= 0:
        raise ValueError("mc must have at least one genome column")

    for j in range(num_genomes):
        total += mc[i, j] * p[j]

    # ``not (total > 0)`` also catches NaN, which would otherwise never
    # satisfy the stopping condition below.
    if not (total > 0):
        return <long>(rval * num_genomes)

    # Comparing against rval * total avoids normalising every weight.
    threshold = rval * total
    for j in range(num_genomes):
        weight = mc[i, j] * p[j]
        if weight > 0:
            cumulative += weight
            last_positive = j
            if cumulative > threshold:
                # Return the index whose weight crossed the threshold, not
                # the one after it.
                return j

    # Rounding can leave the running sum a hair below the threshold; fall
    # back to the last genome that had support instead of running past the
    # end of the arrays.
    if last_positive < 0:
        return <long>(rval * num_genomes)
    return last_positive


def get_eta(z, num_genomes):
    """Count how many reads are assigned to each genome.

    Parameters
    ----------
    z : array-like of int, shape (num_reads,)
        Genome index assigned to each read.
    num_genomes : int
        Number of genomes; the length of the result.

    Returns
    -------
    numpy.ndarray
        int64 array ``eta`` with ``eta[j]`` the number of entries of ``z``
        equal to ``j``.

    Raises
    ------
    TypeError
        If ``z`` holds non-integer values.
    ValueError
        If ``z`` is not one-dimensional or holds an index outside
        ``[0, num_genomes)``. Such an index would otherwise be written
        outside the counts buffer when bounds checking is disabled.
    """
    assignments = np.asarray(z)
    if assignments.ndim != 1:
        raise ValueError("z must be one-dimensional")
    if assignments.size and not np.issubdtype(assignments.dtype, np.integer):
        raise TypeError("z must contain integer genome indices")
    assignments = assignments.astype(np.int64, copy=False)

    num_genomes = int(num_genomes)
    if assignments.size and (
            assignments.min() < 0 or assignments.max() >= num_genomes):
        raise ValueError(
            f"z contains genome indices outside [0, {num_genomes})")

    counts = np.bincount(assignments, minlength=num_genomes)
    return counts.astype(np.int64, copy=False)

# Declare drand48 / srand48 from the C header stdlib.h so this cell can call
# them directly.
cdef extern from "stdlib.h":
    double drand48()
    void srand48(long int seedval)

def uniform():
    """Return the next value produced by the C function ``drand48``."""
    # NOTE(review): srand48 is declared above but never called in the visible
    # code, so the generator is presumably left at its default seed -- confirm
    # whether explicit seeding is intended.
    return drand48() #This gives a float in range [0,1)


In [None]:
def preprocess(ref_fname, genomes_fname, bam_fname):
    """Prepare the remapped BAM and the aligned genome panel for one sample.

    Steps, in order:

    1. index ``bam_fname`` and extract its ``chrM`` reads into
       ``<base>_mt.bam`` (also indexed);
    2. call a consensus from the extracted reads and write ``<base>_mt.fa``;
    3. concatenate the consensus with ``genomes_fname`` and align the result
       with mafft into ``<base>_mt_aligned.fa``;
    4. write the first aligned record to ``<base>_mt.real.fa`` and index it
       for bwa / samtools / picard;
    5. convert the extracted reads to FASTQ, remap them with
       ``bwa aln`` / ``bwa samse``, filter at MAPQ 30, sort, remove
       duplicates, run ``samtools calmd`` and index the final BAM.

    The extracted ``chrM`` BAM is used for the consensus and for the FASTQ
    export: ``bam2consensus`` asserts that every pileup column is on
    ``chrM``, so it cannot be run on a multi-contig input.

    Parameters
    ----------
    ref_fname : str
        Reference FASTA handed to ``bam2consensus``.
    genomes_fname : str
        FASTA with the panel of genomes.
    bam_fname : str
        Input BAM; ``<base>`` is this path without its extension.

    Returns
    -------
    (str, str)
        Path of the final BAM and path of the aligned genomes FASTA.

    Raises
    ------
    subprocess.CalledProcessError
        If any external command exits with a non-zero status.
    RuntimeError
        If the mafft output contains no record.
    """
    import shlex  # only this function builds shell command lines

    quote = shlex.quote

    def run(command):
        # check=True stops at the first failing tool instead of letting a
        # missing or empty file surface several steps later.
        subprocess.run(command, shell=True, check=True)

    base = os.path.splitext(bam_fname)[0]
    pysam.index(bam_fname)
    mt_bam = f'{base}_mt.bam'
    # -b forces BAM output regardless of the samtools version's defaults.
    run(f'samtools view -b -o {quote(mt_bam)} {quote(bam_fname)} chrM')
    pysam.index(mt_bam)
    base = base + '_mt'
    print('#EXTRACTING MTDNA OK')

    consensus = bam2consensus(ref_fname, mt_bam)
    consensus_fa = '>chrM\n' + ''.join(consensus) + '\n'
    consensus_path = f'{base}.fa'
    with open(consensus_path, 'w') as new_genomes:
        new_genomes.write(consensus_fa)

    combined_path = f'{base}_genomes.fa'
    aligned_genomes = f'{base}_aligned.fa'
    run(f'cat {quote(consensus_path)} {quote(genomes_fname)} '
        f'> {quote(combined_path)}')
    run(f'mafft {quote(combined_path)} > {quote(aligned_genomes)}')
    print("#ALL GENOMES ARE READY")

    # The consensus was written first, so it is the first aligned record.
    new_cons = next(SeqIO.parse(aligned_genomes, "fasta"), None)
    if new_cons is None:
        raise RuntimeError(f'{aligned_genomes} contains no sequences')
    real_fa = f'{base}.real.fa'
    SeqIO.write(new_cons, real_fa, "fasta")
    run(f'bwa index -a bwtsw {quote(real_fa)}')
    run(f'samtools faidx {quote(real_fa)}')

    # Remove a stale dictionary left by an earlier run before recreating it.
    dict_path = f'{base}.dict'
    if os.path.exists(dict_path):
        os.remove(dict_path)
    run(f'picard CreateSequenceDictionary R={quote(consensus_path)} '
        f'O={quote(dict_path)}')

    fastq_path = f'{base}.fq'
    sai_path = f'{base}_ra.sai'
    sorted_bam = f'{base}_ra.sort.bam'
    rmdup_bam = f'{base}_ra.sort.rmdup.bam'
    bam_final = f'{base}_ra.final.bam'

    run(f'samtools fastq {quote(mt_bam)} > {quote(fastq_path)}')
    run(f'bwa aln -l 1000 -t 10 {quote(real_fa)} {quote(fastq_path)} '
        f'> {quote(sai_path)}')
    read_group = f'@RG\\tID:{base}\\tLB:{base}_L1\\tPL:ILLUMINA\\tSM:{base}'
    run(f'bwa samse -r {quote(read_group)} {quote(real_fa)} '
        f'{quote(sai_path)} {quote(fastq_path)} '
        f'| samtools view -bh -q 30 '
        f'| samtools sort -O BAM -o {quote(sorted_bam)}')
    run(f'picard MarkDuplicates I={quote(sorted_bam)} O={quote(rmdup_bam)} '
        f'METRICS_FILE=metrics.txt REMOVE_DUPLICATES=true '
        f'ASSUME_SORTED=false VALIDATION_STRINGENCY=LENIENT')
    pysam.index(rmdup_bam)
    run(f'samtools calmd -Erb {quote(rmdup_bam)} {quote(consensus_path)} '
        f'> {quote(bam_final)} 2>/dev/null')
    run(f'samtools index {quote(bam_final)}')
    if os.path.exists(sai_path):
        os.remove(sai_path)
    print("#BAM FILE IS READY")
    return bam_final, aligned_genomes

In [None]:
def make_genomes_arr(genomes_fname):
    """Read a FASTA file into a NumPy array of single-byte characters.

    Each record becomes one row holding the characters of its sequence; the
    array is created with dtype ``'S1'``.
    """
    per_record_bases = [
        list(str(entry.seq)) for entry in SeqIO.parse(genomes_fname, "fasta")
    ]
    return np.array(per_record_bases, dtype = 'S1')

In [None]:
def get_same(genomes_arr):
    """Map every invariant column of ``genomes_arr`` to its shared value.

    A column is invariant when it holds exactly one distinct value across all
    rows. The result maps the column index to the entry in row 0 and keeps
    the columns in ascending order.
    """
    n_columns = genomes_arr.shape[1]
    return {
        column: genomes_arr[0, column]
        for column in range(n_columns)
        if np.unique(genomes_arr[:, column]).size == 1
    }

In [None]:
def get_base_err(bam_fname, same_dict):
    """Estimate the per-base error rate from invariant positions.

    Only pileup columns on ``chrM`` whose position is a key of ``same_dict``
    are used. Every read base there (deletions and reference skips are
    ignored) is compared with the expected base for that position.

    Parameters
    ----------
    bam_fname : str
        BAM file to pile up.
    same_dict : dict
        Maps a 0-based position to the expected base as a bytes-like value;
        it is decoded as ASCII and upper-cased before the comparison.

    Returns
    -------
    float
        ``1 - correct / total`` over all compared bases.

    Raises
    ------
    ZeroDivisionError
        If no read base covers any position of ``same_dict``.
    """
    correct = 0
    total = 0
    with pysam.AlignmentFile(bam_fname, "rb") as samfile:
        for pileupcolumn in tqdm(samfile.pileup("chrM")):
            pos = pileupcolumn.reference_pos
            # Dict membership is O(1); use the argument, not a notebook global.
            if pos not in same_dict:
                continue
            expected_base = same_dict[pos].decode('ascii').upper()

            for pileupread in pileupcolumn.pileups:
                # query_position is None if is_del or is_refskip is set.
                if pileupread.is_del or pileupread.is_refskip:
                    continue
                total += 1
                nbase = pileupread.alignment.query_sequence[
                    pileupread.query_position]
                if nbase == expected_base:
                    correct += 1

    if total == 0:
        raise ZeroDivisionError(
            "no read bases cover the invariant positions; "
            "cannot estimate the base error rate")
    return 1 - correct/total

In [None]:
# samfile.close()
# del genomes

In [None]:
# def Pr_Dep(p, m, n, i, eps):
#     num_reads, num_genomes = M.shape
#     ret = 0
#     for j in range(num_genomes):
#         ret += p[j]* Pr_De(m, n, i, j, eps)
#     return ret

In [None]:
# def get_Z(m, n, p, eps):
#     num_reads, num_genomes = M.shape
#     Z = np.zeros(num_reads)
#     for i in tqdm(range(num_reads), desc = f'gettitng Z'):
#         probs = np.array([Pr_De(m, n, i, j, eps) * p[j] for j in range(num_genomes)])
#         probs = probs/probs.sum()
#         Z[i] = np.random.choice(np.arange(0, num_genomes), p = probs)
#         # print(f'p = {p}\nprobs={probs}\nZ={Z[j]}')
#     return Z

In [None]:
# def get_eta(z, num_genomes):
#     return np.array([len(z[z==j]) for j in range(num_genomes)])

In [None]:
# p = np.array(np.random.dirichlet([1]*num_genomes)
def do_mcmc(n_iterations = 50000, output_file = ''):
    """Run the sampler for the genome proportions ``p``.

    Reads the notebook globals ``MC`` (likelihood table, reads x genomes) and
    ``base_err``. Each iteration

    1. draws a genome index for every read with ``get_Zi`` in a worker pool,
    2. counts the reads per genome with ``get_eta``,
    3. draws ``p[0]`` from ``Beta(1 + eta[0], 1 + num_reads - eta[0])`` and
       the remaining proportions from ``Dirichlet(1 + eta[1:])`` rescaled to
       sum to ``1 - p[0]``.

    Every 100th iteration ``p[0]`` and ``p[1:].sum()`` are printed.

    Parameters
    ----------
    n_iterations : int, optional
        Number of iterations.
    output_file : str, optional
        If not empty, one line ``iteration <i>\\t<p[0]>`` is written to this
        path per iteration.
    """
    num_reads, num_genomes = MC.shape
    print(MC.shape)
    # Start from a proper probability vector: half the mass on genome 0 and
    # the other half spread evenly over the rest.
    p = np.array([0.5] + [0.5/(num_genomes-1)]*(num_genomes-1))

    res = open(output_file, 'w') if output_file != '' else None
    try:
        # Forked workers start with a copy of the parent's NumPy random
        # state; reseed each one so they do not produce identical draws.
        with Pool(initializer=np.random.seed) as pool:
            for i in tqdm(range(n_iterations)):
                func = lambda x: get_Zi(MC, p, base_err, x)
                Z = np.array(pool.map_async(func, range(num_reads)).get())

                eta = get_eta(Z, num_genomes)
                p0 = np.random.beta(1 + eta[0], 1 + num_reads - eta[0])
                p_other = np.random.dirichlet(1 + eta[1:])
                p_other *= (1 - p0)/p_other.sum()
                p[0] = p0
                p[1:] = p_other

                if res is not None:
                    # One record per line so the trace can be parsed.
                    res.write(f'iteration {i}\t{p[0]}\n')
                if i % 100 == 0:
                    print(p[0], p[1:].sum())
    finally:
        if res is not None:
            res.close()

In [None]:
# def do_mcmc_new(n_iterations = 50000):
#     res = open(f'{base}_mcmc.txt','w')
#     p = np.array([0.5] + [1/(num_genomes-1)]*(num_genomes-1))
#     pool = Pool()    
#     for i in tqdm(range(n_iterations), miniters=10):
        
#         func = lambda s : get_Zi_new(MC, p, base_err, s)
#         # print(rvals[i])
#         Z = np.array(pool.map_async(func, range(num_reads)).get())
        
        
#         eta = get_eta(Z, num_genomes)
#         p0 = np.random.beta(1 + eta[0],1+num_reads-eta[0])
#         p_other = np.random.dirichlet(1+ eta[1:])
#         p_other *= (1 - p0)/p_other.sum()
#         p[0] = p0
#         p[1:] = p_other
#         # res.write(f'iteration {i}')
#         # res.write(str(p[0]))
#         if i % 100 == 0:
#             print(p[0], p[1:].sum()) 
#     pool.close()
#     res.close()

In [None]:
# p = np.array([0.5] + [1/(num_genomes-1)]*(num_genomes-1))

In [None]:
# Run the external-tool pipeline. `bam` and `genomes` are the two file paths
# returned by preprocess (final BAM, aligned genomes FASTA), not open handles.
bam, genomes = preprocess(ref_fname, genomes_fname, bam_fname)

In [None]:
# Load the aligned genomes FASTA into an array of dtype 'S1' (one row per record).
genomes_arr = make_genomes_arr(genomes)

In [None]:
# Columns where every genome carries the same value: position -> that value.
same = get_same(genomes_arr)

In [None]:
# Fraction of read bases that disagree with the expected base at those columns.
base_err = get_base_err(bam, same)

In [None]:
# Per-read, per-genome tables returned by get_MN as float64 arrays.
M, N = get_MN(genomes_arr, bam)

In [None]:
# Likelihood table read by do_mcmc as a global; base_err is passed as eps.
# NOTE(review): get_mc declares long[:, ::1] inputs while get_MN returns
# float64 arrays -- confirm this call is accepted.
MC = get_mc(M, N, base_err)

In [None]:
# Display M as the cell output (inspection only).
M

In [None]:
# Run the sampler with its defaults (50000 iterations, no output file).
# do_mcmc reads the globals MC and base_err set by the cells above.
do_mcmc()

In [None]:
# num_genomes is otherwise only a local inside do_mcmc; derive it from the
# likelihood table so this cell also runs on a fresh kernel.
num_genomes = MC.shape[1]
p = np.array([0.5] + [1/(num_genomes-1)]*(num_genomes-1))

In [None]:
# Time one draw for read index 2. The third positional argument is get_Zi's
# `eps` parameter, which get_Zi does not use, so passing num_genomes here has
# no effect on the result.
%timeit get_Zi(MC, p, num_genomes, 2)

In [None]:
# FIXME(review): `save` is not defined anywhere in the visible notebook, so
# evaluating this bare name raises NameError on a fresh Run All.
save

In [None]:
# Scratch check: sequence of the first read on chrM in the original input BAM.
# The handle is bam_1; `bam` is the path string returned by preprocess and has
# no .fetch method.
bam_1 = pysam.AlignmentFile(bam_fname, "rb")
next(bam_1.fetch('chrM')).query_sequence

In [None]:
# Scratch check: 10 ** (-Q / 10) for the last base quality Q of the first
# read fetched from chrM through bam_1.
10**(-list(bam_1.fetch('chrM'))[0].query_qualities[-1]/10)