In [None]:
def is_notebook() -> bool:
    """Return True when running inside a Jupyter kernel (notebook or qtconsole).

    Terminal IPython, any other IPython front-end, and the plain Python
    interpreter (where ``get_ipython`` is not defined) all give False.
    """
    try:
        shell_name = get_ipython().__class__.__name__
    except NameError:
        # get_ipython only exists when IPython has injected it.
        return False
    return shell_name == 'ZMQInteractiveShell'

In [None]:
# Load the names of the reference FASTA, the reads file (BAM) and the FASTA
# with the candidate contaminant genomes.
import argparse

if is_notebook():
    # Interactive session: use the example input files.
    ref_fname     = 'refchrm.fa'
    bam_fname     = 'iintest.bam'
    genomes_fname = 'contaminants.fa'
else:
    # Command-line run: all three paths are required positional arguments.
    parser = argparse.ArgumentParser(description='Supply reference fasta and bam file')
    parser.add_argument('ref', help='reference fasta')
    parser.add_argument('bam', help='bam file')
    parser.add_argument('cont', help='list of contaminants fasta')
    args = parser.parse_args()
    ref_fname, bam_fname, genomes_fname = args.ref, args.bam, args.cont

In [None]:
import os
from collections import Counter
import pysam
import numpy as np
from tqdm import tqdm
from IPython.display import clear_output
from scipy.special import binom
import scipy.stats as st
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from multiprocess import Pool
import matplotlib.pyplot as plt
import seaborn as sns
from preprocess import *

# Build the extension(s) declared in setup.py in place (presumably the `MN`
# module imported on the next line). Use the kernel's own interpreter instead
# of whatever `python` is first on PATH, and stop with an error if the build
# fails rather than silently importing a stale or missing module.
import subprocess
import sys
subprocess.run([sys.executable, 'setup.py', 'build_ext', '--inplace'], check=True)
from MN import *

In [None]:
def get_base_err(bam_fname, ref, aln_pos, same_set):
    """Estimate the per-base error rate from gap-free reads on contig 'chrM'.

    NOTE: a later cell redefines ``get_base_err(bam_fname, same_dict)``. After
    Run All, only that later definition is reachable under this name.

    Only mapped reads whose CIGAR has no insertion ('I') or deletion ('D') are
    used. For those reads the aligned bases map one-to-one onto consecutive
    reference coordinates starting at ``read.reference_start``. A base is
    scored only when its alignment column is in ``same_set``.

    Parameters
    ----------
    bam_fname : str
        Path to an indexed BAM file with reads on contig 'chrM'.
    ref : indexable of str
        Sequence indexed by alignment column; compared against the
        upper-cased read base.
    aln_pos : array-like of int
        Maps an ungapped reference coordinate to an alignment column.
    same_set : container of int
        Alignment columns to score (tested with ``in``).

    Returns
    -------
    tuple
        ``(correct, incorrect, error_rate)``. ``error_rate`` is NaN when no
        base was scored.
    """
    correct = 0
    incorrect = 0
    with pysam.AlignmentFile(bam_fname, "rb") as bam:
        for read in bam.fetch('chrM'):
            if not read.is_mapped or 'D' in read.cigarstring or 'I' in read.cigarstring:
                continue
            # query_alignment_sequence excludes soft-clipped bases at BOTH ends.
            # Trimming only a leading soft clip would leave trailing clipped
            # bases to be compared with reference positions past the alignment.
            seq = read.query_alignment_sequence or ''
            start = read.reference_start
            for k, base in enumerate(seq):
                column = aln_pos[start + k]
                if column in same_set:
                    if base.upper() == ref[column]:
                        correct += 1
                    else:
                        incorrect += 1
    total = correct + incorrect
    error_rate = incorrect / total if total else float('nan')
    return correct, incorrect, error_rate
                
    

In [None]:
# def consensus_caller(ref_fname, bam_fname):
#     base = bam_fname[:-4]
#     os.system(f"samtools view {bam_fname} chrM -o {base+'_mt.bam'}")
#     base = base + '_mt'
#     os.system(f'samtools consensus -o {base}_st.fa {bam_fname}') #st means samtools
#     os.system(f'bwa index -a bwtsw {base}_st.fa') #indexing consensus
#     os.system(f'samtools faidx {base}_st.fa')
#     os.system(f'rm {base}.dict')
#     os.system(f'picard CreateSequenceDictionary R={base}.fa O={base}.dict')
#     os.system(f'samtools fastq {bam_fname} > {base}.fq')
#     os.system(f'bwa aln -l 1000 -t 10 {base}_st.fa {base}.fq > {base}_ra.sai')
#     os.system(f"bwa samse -r '@RG\\tID:{base}\\tLB:{base}_L1\\tPL:ILLUMINA\\tSM:{base}' {base}_st.fa {base}_ra.sai {base}.fq |  samtools sort -O BAM -o {base}_ra.sort.bam")
#     os.system(f'samtools index {base}_ra.sort.bam')
#     consensus = bam2consensus(f'{base}_st.fa', f'{base}_ra.sort.bam')
#     consensus_fa = '>chrM\n'+''.join(consensus) +'\n'
#     with open(f'{base}.fa', 'w') as new_genomes:
#         new_genomes.write(consensus_fa)
#     return consensus
    
    

In [None]:
def make_genomes_arr(genomes_fname):
    """Load every record of a FASTA file into a 2-D array of single bytes.

    Parameters
    ----------
    genomes_fname : str
        Path to a FASTA file. Presumably the records are aligned genomes of
        equal length; otherwise numpy cannot build a 2-D 'S1' array.

    Returns
    -------
    numpy.ndarray
        dtype 'S1', one row per record and one column per sequence position.
    """
    sequences = [str(record.seq) for record in SeqIO.parse(genomes_fname, "fasta")]
    return np.array([list(sequence) for sequence in sequences], dtype = 'S1')

In [None]:
def get_same(genomes_arr):
    """Return the alignment columns where every genome has the same symbol.

    Parameters
    ----------
    genomes_arr : numpy.ndarray
        2-D array with one row per genome and one column per alignment
        position (e.g. the 'S1' array from ``make_genomes_arr``). The
        comparison is exact, so ``b'a'`` and ``b'A'`` count as different.

    Returns
    -------
    set of int
        Indices of columns in which all rows agree. Empty if there are no rows
        or no columns.
    """
    n_columns = genomes_arr.shape[1]
    if n_columns == 0 or genomes_arr.shape[0] == 0:
        return set()
    # A column is conserved iff every row equals the first row there. One
    # vectorised comparison replaces a per-column np.unique (one sort each).
    conserved = (genomes_arr == genomes_arr[0]).all(axis=0)
    return set(np.flatnonzero(conserved).tolist())

In [None]:
def get_base_err(bam_fname, same_dict):
    """Estimate the per-base error rate from a pileup over contig 'chrM'.

    Every base that is neither a deletion nor a reference skip, at a position
    present in ``same_dict``, is compared case-insensitively with
    ``same_dict[pos]``.

    Parameters
    ----------
    bam_fname : str
        Path to an indexed BAM file with reads on contig 'chrM'.
    same_dict : Mapping[int, bytes | str]
        0-based reference position -> expected base. Values may be ASCII
        bytes (e.g. items of an 'S1' numpy array) or str.

    Returns
    -------
    float
        ``1 - correct / total``, or NaN if no base was scored.
    """
    correct = 0
    total = 0
    with pysam.AlignmentFile(bam_fname, "rb") as samfile:
        for pileupcolumn in tqdm(samfile.pileup("chrM")):
            pos = pileupcolumn.reference_pos
            # Dict membership is O(1) and uses the argument, not a global.
            if pos not in same_dict:
                continue
            expected = same_dict[pos]
            if isinstance(expected, bytes):
                expected = expected.decode('ascii')
            expected = expected.upper()
            for pileupread in pileupcolumn.pileups:
                # query_position is None if is_del or is_refskip is set.
                if pileupread.is_del or pileupread.is_refskip:
                    continue
                total += 1
                nbase = pileupread.alignment.query_sequence[pileupread.query_position]
                if nbase.upper() == expected:
                    correct += 1
    if total == 0:
        return float('nan')
    return 1 - correct / total

In [None]:
def get_aln_pos(reference):
    """Map each ungapped reference coordinate to its alignment column.

    Parameters
    ----------
    reference : str
        Aligned reference sequence in which '-' marks a gap column.

    Returns
    -------
    numpy.ndarray
        ``out[k]`` is the alignment column holding the k-th non-gap symbol.
    """
    non_gap_columns = [column for column, symbol in enumerate(reference) if symbol != '-']
    return np.asarray(non_gap_columns)

In [None]:
def do_mcmc(n_iterations = 50000, output_file='', n_threads=8, model=0, show_each=10,
            mc=None, base_error=None):
    """Run an MCMC sampler over the genome mixture proportions ``p``.

    Each iteration does the following:

    1. Draw a per-read value with ``get_Zi`` for every row of ``mc``.
    2. Summarise the draws with ``get_eta``. Both helpers presumably come from
       ``from MN import *``.
    3. Resample ``p``:

       * ``model == 0``: draw ``p[0] ~ Beta(1 + eta[0], 1 + num_reads - eta[0])``,
         then split the remaining ``1 - p[0]`` according to a
         ``Dirichlet(1 + eta[1:])`` draw.
       * otherwise: ``p ~ Dirichlet(1 + eta)``.

    Parameters
    ----------
    n_iterations : int
        Number of iterations.
    output_file : str
        If non-empty, one tab-separated line ``iteration <i>  <p[0]>`` is
        written per iteration.
    n_threads : int
        Currently unused (the process pool is disabled); kept for compatibility.
    model : int
        0 for the Beta + Dirichlet split, any other value for a flat Dirichlet.
    show_each : int
        Print the current ``p`` every ``show_each`` iterations.
    mc : numpy.ndarray, optional
        (num_reads, num_genomes) matrix passed to ``get_Zi``. Defaults to the
        notebook-global ``MC``.
    base_error : optional
        Third argument passed to ``get_Zi``. Defaults to the notebook-global
        ``base_err``.

    Returns
    -------
    list
        The value of ``p[0]`` after every iteration.
    """
    if mc is None:
        mc = MC
    if base_error is None:
        base_error = base_err
    p_list = []
    num_reads, num_genomes = mc.shape
    print(mc.shape)
    p = np.random.dirichlet([1]*num_genomes)
    res = open(output_file, 'w') if output_file != '' else None
    try:
        for i in tqdm(range(n_iterations)):
            Z = np.array([get_Zi(mc, p, base_error, s) for s in range(num_reads)])
            eta = get_eta(Z, num_genomes)
            if model == 0:
                p0 = np.random.beta(1 + eta[0], 1 + num_reads - eta[0])
                p_other = np.random.dirichlet(1 + eta[1:])
                # Rescale so the other genomes share exactly the remaining 1 - p0.
                p_other *= (1 - p0) / p_other.sum()
                p[0] = p0
                p[1:] = p_other
            else:
                p = np.random.dirichlet(1 + eta)
            p_list.append(p[0])
            if res is not None:
                # One record per line. Writing the label and value back to back
                # with no separator or newline made the file unparseable.
                res.write(f'iteration {i}\t{p[0]}\n')
            if i % show_each == 0:
                print(p)
    finally:
        if res is not None:
            res.close()
    return p_list

In [None]:
# Preprocess the inputs. NOTE(review): `preprocess` comes from
# `from preprocess import *`. Judging by later use, `bam` is a BAM path
# (passed to pysam.view / get_MN) and `genomes` a FASTA path (passed to
# make_genomes_arr) — confirm.
bam, genomes = preprocess(ref_fname, genomes_fname, bam_fname)

In [None]:
# bam = 'simulated_data.bam'

In [None]:
# One row per genome, one 'S1' byte per alignment column.
genomes_arr = make_genomes_arr(genomes)

In [None]:
# Number of alignment columns where the first two genomes differ.
np.sum(genomes_arr[0] != genomes_arr[1])

In [None]:
genomes_arr.shape

In [None]:
genomes_arr

In [None]:
genomes

In [None]:
genomes_arr

In [None]:
# Set of alignment columns where all genomes agree.
same = get_same(genomes_arr)

In [None]:
# First two genomes as upper-case strings (gap symbols '-' kept).
genomes0 = (''.join( np.array(genomes_arr, dtype = str)[0])).upper()
genomes1 = (''.join( np.array(genomes_arr, dtype = str)[1])).upper()

In [None]:
genomes0.count('N')

In [None]:
# `samtools view -c`: number of alignment records in the BAM.
pysam.view('-c', bam)

In [None]:
genomes_arr

In [None]:
# Ungapped coordinate of genome 0 -> alignment column.
aln_coords = get_aln_pos(genomes0)

In [None]:
#glmc = get_glM(genomes_arr, bam, aln_coords, same)

In [None]:
#glmc

In [None]:
# NOTE(review): get_MN presumably comes from the compiled `MN` module. M and N
# are indexed as M[:, j] below (rows x genomes) — confirm what they count.
M, N, base_err = get_MN(genomes_arr, bam, aln_coords, same)

In [None]:
base_err

In [None]:
# Rows with M[:, 0] < N[:, 0], plus rows flagged with M[:, 0] == -1.
((M[:, 0] < N[:, 0])).sum() + (M[:,0] == -1).sum()

In [None]:
# Debug snippet (disabled): inspect a single read against genome 1.
# it = 4147
# samfile = pysam.AlignmentFile(bam, "rb" )
# read = list(samfile.fetch('chrM'))[it]
# print(M[it], N[it])
# print(read.cigarstring)
# aln_pos = aln_coords[read.pos]
# print(aln_coords[read.pos])
# print(read.seq)
# genome = (''.join( np.array(genomes_arr, dtype = str)[1])).upper()
# print(genome[aln_coords[read.pos]: aln_coords[read.pos] + 120])

In [None]:
M

In [None]:
genomes0.count('-')

In [None]:
# With a high error rate, many reads fail to map, which lowers the accuracy of the base_error estimate.

In [None]:
# Wherever M < N, overwrite the entry with -1 in BOTH matrices.
# The mask is computed once; each element is decided independently.
_invalid_mask = M < N
M[_invalid_mask] = -1
N[_invalid_mask] = -1

In [None]:
# (M[:,0] > M[:,1]).sum()/(M[:,0] != M[:,1]).sum()

In [None]:
print(f'#base error is {base_err}')

In [None]:
# NOTE(review): get_mc presumably comes from a star import (`preprocess` or
# `MN`). The 0.01 is an unexplained constant, while base_err was estimated
# above — confirm which one is intended and name it.
MC = get_mc(M, N, 0.01)

In [None]:
# Keep only rows of MC whose entries are not all equal. Presumably, equal
# entries mean the read does not discriminate between genomes.
MC = MC[~np.all(MC == MC[:, [0]], axis=1)]

In [None]:
# vec = np.array([2,2,2,2])

In [None]:
# vec[0] == vec[1:]

In [None]:
# idx

In [None]:
 # (MC[:,0]!=0) *  (MC[:,1]!=0) * (M[:,1]>N[:,1])

In [None]:
MC

In [None]:
# Fraction of rows where the MC value for genome 0 exceeds that for genome 1.
(MC[:,0]>MC[:,1]).mean()

In [None]:
# MC = get_glM()

In [None]:
# np.sum((M[:,1]>M[:,0] + 4))

In [None]:
# Sampler with the flat-Dirichlet model (model=1). P holds p[0] per iteration.
P = do_mcmc(10000, n_threads=1, model=1, show_each=100)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Trace (left) and density (right) of the sampled p[0], starting at
# iteration 100 and keeping every 5th sample.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15,3), gridspec_kw={'width_ratios': [3.2, 1]})
ax1.plot(np.asarray(P[100::5]), label='MCMC')
# x = [0, 50000]
# y = [0.6, 0.6]
# ax1.plot(x, y, label='Presice')
ax1.legend()
# Draw on ax2 explicitly instead of relying on pyplot's implicit current axes.
# NOTE(review): the density is shifted by -0.01 with no stated reason — confirm.
sns.kdeplot(np.asarray(P[100::5])-0.01, cut=0, ax=ax2);
# fig.suptitle('Test 9.')
# fig.suptitle('Test 9.')

In [None]:
# Trace (left) and density (right) of the sampled p[0], starting at
# iteration 100 and keeping every 5th sample (no shift applied).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15,3), gridspec_kw={'width_ratios': [3.2, 1]})
ax1.plot(np.asarray(P[100::5]), label='MCMC')
# x = [0, 50000]
# y = [0.6, 0.6]
# ax1.plot(x, y, label='Presice')
ax1.legend()
# Draw on ax2 explicitly instead of relying on pyplot's implicit current axes.
sns.kdeplot(np.asarray(P[100::5]), cut=0, ax=ax2);

In [None]:
# KDE and histogram of every 1000th sample of p[0].
sns.displot(np.asarray(P[::1000]), kind='kde', cut=0);
plt.show()
plt.hist(P[::1000], bins=20);

In [None]:
num_reads, num_genomes = MC.shape

In [None]:
p = np.array([0.5] + [1/(num_genomes-1)]*(num_genomes-1))

In [None]:
%timeit get_Zi(MC, p, num_genomes, 2)

In [None]:
# NOTE(review): `dif` is only assigned in cells further down, so this cell
# raises NameError on a fresh kernel — move or delete.
dif

In [None]:
# NOTE(review): `contamix_ref` is not defined anywhere in the visible notebook.
contamix_ref[514]

In [None]:
# NOTE(review): `cons` is only assigned further down (bam2consensus call).
cons[16179]

In [None]:
def get_num_indels(bam_fname, trunc = 7, ops = 'I'):
    """Count mapped 'chrM' reads whose CIGAR string contains any of ``ops``.

    With the default ``ops='I'``, only reads with an insertion are counted,
    despite the function name. Pass ``ops='ID'`` to count reads with an
    insertion or a deletion.

    Parameters
    ----------
    bam_fname : str
        Path to an indexed BAM file with reads on contig 'chrM'.
    trunc : int
        Reads whose 0-based alignment start is below this value are skipped.
        NOTE(review): the reason for the default 7 is not stated — confirm.
    ops : str
        CIGAR operation letters to look for.

    Returns
    -------
    int
        Number of qualifying reads.
    """
    num_reads = 0
    with pysam.AlignmentFile(bam_fname, "rb") as samfile:
        for read in samfile.fetch('chrM'):
            if not read.is_mapped or read.reference_start < trunc:
                continue
            if any(op in read.cigarstring for op in ops):
                num_reads += 1
    return num_reads

In [None]:
# Mapped chrM reads starting at position >= 7 whose CIGAR contains 'I'.
get_num_indels(bam)

In [None]:
def get_cigar_string(bam_fname):
    """Print the CIGAR tuples of every read on contig 'chrM'.

    Parameters
    ----------
    bam_fname : str
        Path to an indexed BAM file.

    Returns
    -------
    None
        Output is printed only.
    """
    # The loop body must be indented under the for-statement.
    with pysam.AlignmentFile(bam_fname, "rb") as samfile:
        for read in samfile.fetch('chrM'):
            print(read.cigartuples)

In [None]:
# Inspect a single read: CIGAR, alignment column of its start, its sequence,
# and a hard-coded stretch of genome 1 with gaps removed.
# NOTE(review): list(samfile.fetch(...)) loads every chrM read into memory,
# samfile is never closed, and 432 / 213:326 are hard-coded — confirm.
samfile = pysam.AlignmentFile(bam, "rb" )
read = list(samfile.fetch('chrM'))[432]
print(read.cigartuples)
print(aln_coords[read.pos])
print(read.seq)
genome = (''.join( np.array(genomes_arr, dtype = str)[1])).upper()
print(genome[213: 326].replace('-',''))

In [None]:
# Print the indices of rows of MC that sum to zero.
# NOTE(review): `j` is unused. `num_reads` comes from MC.shape in an earlier cell.
j = 0
for i in range(num_reads):
    if (MC[i,:].sum()) == 0:
        print(i)

In [None]:
# NOTE(review): `get_num_reads` is not defined in the visible notebook
# (possibly provided by a star import) — confirm.
get_num_reads(bam)

In [None]:
def calculate_likelihood(probs, mc):
    """Log-likelihood of mixture proportions ``probs`` under the matrix ``mc``.

    Computes ``sum_i log(sum_j probs[j] * mc[i, j])``.

    Parameters
    ----------
    probs : array-like of float
        Mixture weights, one per column of ``mc`` (broadcast across rows).
    mc : array-like
        2-D array with one row per read and one column per genome.

    Returns
    -------
    float
        The log-likelihood. It is ``-inf`` if some row receives zero mass.
    """
    probs = np.asarray(probs)
    mc = np.asarray(mc)
    # Use the `mc` argument, not the notebook-global `MC`.
    per_read = (mc * probs).sum(axis=1)
    return np.log(per_read).sum()

            

In [None]:
# Log-likelihood at a few hand-picked proportion vectors.
# NOTE(review): these assume 3 genomes, while a later cell passes a 2-element
# vector — confirm which matches MC.shape[1].
calculate_likelihood([0.8, 0.2, 0], MC)

In [None]:
calculate_likelihood([0.7, 0.1, 0.2], MC)

In [None]:
calculate_likelihood([0.99, 0.01, 0], MC)

In [None]:
# 0.9 on genome 0, the remaining 0.1 split evenly over the other genomes.
# Giving each of them 0.1 would only sum to 1 when num_genomes == 2.
p3 = np.asarray([0.9] + [0.1/(num_genomes-1)]*(num_genomes-1))

In [None]:
# Row indices where M for genome 0 exceeds M for genome 1 by more than 5.
np.where(M[:,0]>M[:,1]+5)

In [None]:
calculate_likelihood(p3, MC)

In [None]:
# 0.7 on genome 0, 0.3 split evenly over the other genomes.
# NOTE(review): p4 is not used anywhere in the visible notebook.
p4 = np.asarray([0.7] + [0.3/(num_genomes-1)]*(num_genomes-1))

In [None]:
# NOTE(review): 2-element proportions; only valid if MC has 2 columns.
calculate_likelihood([0.8, 0.2], MC)

In [None]:
# 0.6 on genome 0, 0.4 split evenly over the other genomes.
p5 = np.asarray([0.6] + [0.4/(num_genomes-1)]*(num_genomes-1))

In [None]:
calculate_likelihood(p5, MC)

In [None]:
# NOTE(review): seq1 / seq2 are not defined in the visible notebook.
seq1==seq2

In [None]:
genomes_arr

In [None]:
def get_probs(mc, p, read_idx=None):
    """Normalise ``mc * p`` row by row into per-read genome probabilities.

    ``out[i, j] = mc[i, j] * p[j] / sum_k(mc[i, k] * p[k])``.

    Parameters
    ----------
    mc : array-like
        2-D array with one row per read and one column per genome.
    p : array-like of float
        Mixture proportions, one per column of ``mc``.
    read_idx : int, optional
        If given, return only row ``read_idx`` of the result.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``mc.shape``, or a single row if ``read_idx`` is
        given. Rows whose weighted sum is 0 become NaN.
    """
    # Compute in float. np.zeros_like(mc) would inherit an integer dtype from
    # an integer `mc` and truncate every probability to 0 or 1.
    weighted = np.asarray(mc, dtype=float) * np.asarray(p, dtype=float)
    probs = weighted / weighted.sum(axis=1, keepdims=True)
    if read_idx is not None:
        return probs[read_idx]
    return probs

In [None]:
# Per-row genome probabilities at p = [0.5, 0.5].
f = get_probs(MC,[0.5, 0.5])

In [None]:
# Fraction of rows giving genome 0 more probability than genome 1.
(f[:,0]>f[:,1]).mean()

In [None]:
p = np.array([0.5, 0.5])

In [None]:
np.where(np.bitwise_and((-1<M[:,0]),  M[:,0]+1<M[:,1]))[0].shape

In [None]:
np.where(M[:,0]<M[:,1])[0].shape

In [None]:
M[1590]
# N[1590]

In [None]:
# NOTE(review): `genome` is assigned in the earlier read-inspection cell.
genome[801:801+100]

In [None]:
read.pos

In [None]:
# Upper-case strings of genome 0 and genome 1 (gap symbols kept).
genome1 = (''.join( np.array(genomes_arr, dtype = str)[0])).upper()
genome2 = (''.join( np.array(genomes_arr, dtype = str)[1])).upper()

In [None]:
# Side-by-side view of a 100-column window of the two genomes.
s = 100
print(genome1[s: s+100])
print(genome2[s: s+100])

In [None]:
M[19245, 1]

In [None]:
import sys

In [None]:
np.set_printoptions(threshold=300)

In [None]:
np.where(-1 < M[:,0])[0][0:]

In [None]:
M[598]

In [None]:
np.where(np.bitwise_and(M[:,1]<N[:,1], M[:,0] != -1))

In [None]:
p

In [None]:
# NOTE(review): passes a third argument (apparently a read index); a
# two-parameter get_probs(mc, p) raises TypeError here.
get_probs(MC, p, 158)

In [None]:
np.where(np.bitwise_and(M[:,1]>M[:,0]+1, M[:,0]>0))[0].shape

In [None]:
np.where(M[:,0]<M[:,1])[0]

In [None]:
M[101]

In [None]:
(M[:,0]>M[:,1]).sum()

In [None]:
N[157]

In [None]:
M.shape

In [None]:
genomes0.count('-')

In [None]:
# NOTE(review): rebinds genome1 to genome index 1; an earlier cell bound it to index 0.
genome1 = (''.join( np.array(genomes_arr, dtype = str)[1])).upper()

In [None]:
# NOTE(review): in10_1 is not used anywhere in the visible notebook.
in10_1 =  genome.replace('-', '')

In [None]:
# Load the two FASTA files as plain sequence strings: drop the header line
# and join the remaining lines. Files are closed via `with`. Dropping the
# whole first line (instead of a fixed 5-character slice) does not depend on
# the header being exactly '>chrM', and splitlines() also handles CRLF endings.
with open('data/bam/in10.fa') as fh:
    in10 = ''.join(fh.read().splitlines()[1:])
with open('data/bam/in1.fa') as fh:
    in1 = ''.join(fh.read().splitlines()[1:])

In [None]:
# NOTE(review): `consensus` is not defined in the visible notebook, and 16569
# (presumably the human mtDNA reference length) is hard-coded.
dif = [i for i in range(16569) if consensus[i] != in10[i]]

In [None]:
dif

In [None]:
# NOTE(review): bam2consensus presumably comes from a star import; the meaning
# of the arguments 1 and 0.5 is not visible here.
cons = bam2consensus('data/bam/output40_in1_60_in10.bam', 1, 0.5)

In [None]:
# samtools consensus through pysam. [5:] drops the first 5 characters
# (assumed to be a '>chrM' header) and the newlines are then removed.
cons1 = pysam.consensus('data/bam/output40_in1_60_in10.bam')[5:].replace('\n','')

In [None]:
len(cons)

In [None]:
len(cons1)

In [None]:
# Positions where the two consensus sequences differ.
# NOTE(review): iterates over len(cons1) but also indexes cons; raises
# IndexError if cons is shorter.
dif = [i for i in range(len(cons1)) if cons[i] !=cons1[i]]

In [None]:
A = 300
B = 30
print(cons[A: A + B])
print(cons1[A: A + B])

In [None]:
cons.count('N')

In [None]:
def glMC()