In [1]:
# Module metadata.
__author__, __version__ = "Maggie Ruimin Sun", "v0.1"

In [2]:
import os
import datetime
import sys
import pysam
import math
import statistics
import operator
import argparse
import random
import multiprocessing
import traceback
import subprocess
from scipy import stats
from collections import defaultdict

In [3]:
# global constants
pcr_error = 1e-6            # per-base PCR error rate folded into each read's error in cal_posterior
prc_no_error = 1.0 - 3e-5   # NOTE(review): not referenced in the visible code
atgc = tuple('ATGC')        # canonical nucleotide order: ('A', 'T', 'G', 'C')

In [36]:
def cal_posterior(barcode_k, num_reads_k, pcr_err=None):
    """Estimate, for each allele seen in one MT (barcode), the probability its reads are correct.

    Each read contributes a combined error probability
    ``1 - (1 - base_err) * (1 - seq_err) * (1 - pcr_err)``; the per-allele result is
    ``1 - median`` of those values.

    Parameters
    ----------
    barcode_k : dict
        read_id -> list of read_info records for one barcode; only the first record of
        each read is used.  read_info[0] is the allele, read_info[1] the base-call error
        probability and read_info[4] the position-dependent error probability (the
        layout built in call_variants).
    num_reads_k : int
        Currently unused (kept for callers and for the disabled binomial model).
    pcr_err : float, optional
        PCR error rate; defaults to the module constant ``pcr_error``.

    Returns
    -------
    defaultdict(float)
        allele -> probability of a correct call.  Only alleles observed in at least one
        read are present (an unobserved allele has no median); 'A', 'T', 'G', 'C' come
        first in that order, then other alleles (e.g. 'DEL', indel labels) in
        first-seen order.
    """
    if pcr_err is None:
        pcr_err = pcr_error
    nucleotides = ('A', 'T', 'G', 'C')

    prob_err = defaultdict(list)
    extra_bases = []  # non-ATGC alleles, in order of first appearance
    for read_info in barcode_k.values():
        first = read_info[0]
        base = first[0]
        if base not in nucleotides and base not in prob_err:
            extra_bases.append(base)
        prob_err[base].append(1 - (1 - first[1]) * (1 - first[4]) * (1 - pcr_err))

    out_dict = defaultdict(float)
    for base in [nt for nt in nucleotides if nt in prob_err] + extra_bases:
        # every list here is non-empty, so statistics.median cannot raise
        out_dict[base] = 1.0 - statistics.median(prob_err[base])
    return out_dict

In [35]:
# Sanity check: the mode of Binomial(n=300, p=0.98), e.g. the most likely number of
# correct calls among 300 reads at a 2% per-read error rate.
prbs = [stats.binom.pmf(x, 300, 0.98) for x in range(301)]
# max() over indices returns the first maximiser (same as the original early break),
# without recomputing max(prbs) on every iteration.
mode_idx = max(range(len(prbs)), key=prbs.__getitem__)
print('%d\t%f' % (mode_idx, prbs[mode_idx]))

294	0.162253


In [5]:
def convert_to_vcf(orig_ref, alt_base):
    vtype = '.'
    ref = orig_ref
    alt = alt_base
    if len(alt_base) == 1:
        vtype = 'SNP'
    elif alt_base == 'DEL':
        vtype = 'SDEL'
    else:
        vals = alt_base.split('|')
        if vals[0] in ('DEL', 'INS'):
            vtype = 'INDEL'
            ref = vals[1]
            alt = vals[2]
    return (ref, alt, vtype)

In [17]:
bamFile = '/home/yaneng/RSun/Data/NGS2018_03_20/N223-M2_S10_vcready_sorted.bam'
samfile = pysam.AlignmentFile(bamFile, 'rb')

# Debug: list the non-deleted reads covering chr13:32911570.  Region strings use the
# samtools form 'chrom:start-end' (1-based, inclusive); pileupcolumn.pos is 0-based.
for pileupcolumn in samfile.pileup(region='chr13:32911570-32911570', truncate=True,
                                   max_depth=1000000, stepper='nofilter'):
    print('coverage at base %s = %s' % (pileupcolumn.pos, pileupcolumn.n))
    for pileupread in pileupcolumn.pileups:
        if not pileupread.is_del:
            print('%s\t%s\t%s' % (pileupread.alignment.query_name, pileupread.query_position,
                                  pileupread.alignment.pos))

coverage at base 32911569 = 2968
M05257:4:000000000-BHV93:1:1114:19541:12603:chr13-0-32911444-GAATAATGGCAT:GAATAATGGCAT	126	32911443
M05257:4:000000000-BHV93:1:1109:9041:21151:chr13-1-32911448-TTCGATCTCAAC:TTCGATCTCAAC	122	32911447
M05257:4:000000000-BHV93:1:2109:24847:5403:chr13-1-32911448-TTCGATCTCAAC:TTCGATCTCAAC	122	32911447
M05257:4:000000000-BHV93:1:2114:9933:11064:chr13-1-32911448-TTCGATCTCAAC:TTCGATCTCAAC	122	32911447
M05257:4:000000000-BHV93:1:1101:18405:3939:chr13-0-32911451-CTCAGTTTTTGG:CTCAGTTTTTGG	119	32911450
M05257:4:000000000-BHV93:1:1101:25525:8071:chr13-0-32911451-GATAATACAAGT:GATAATACAAGT	119	32911450
M05257:4:000000000-BHV93:1:1101:12026:8275:chr13-0-32911451-GGCGTAACGAGT:GGCGTAACGAGT	119	32911450
M05257:4:000000000-BHV93:1:1101:3529:8480:chr13-0-32911451-GGCGTAACGAGT:GGCGTAACGAGT	119	32911450
M05257:4:000000000-BHV93:1:1101:26951:8515:chr13-0-32911451-CACGTTCATCTT:CACGTTCATCTT	119	32911450
M05257:4:000000000-BHV93:1:1101:25068:9513:chr13-0-32911451-TAACGATTGCGA:TAA

M05257:4:000000000-BHV93:1:2113:16050:16949:chr13-1-32911562-CATTCTAGGCAT:CATTCTAGGCAT	8	32911561
M05257:4:000000000-BHV93:1:2114:6930:10208:chr13-1-32911562-CATTCTAGGCAT:CATTCTAGGCAT	8	32911561
M05257:4:000000000-BHV93:1:2114:28032:21993:chr13-1-32911562-CGTTCTCTCAGG:CGTTCTCTCAGG	8	32911561
M05257:4:000000000-BHV93:1:2114:20080:23314:chr13-1-32911562-CACTCCTACGTC:CACTCCTACGTC	8	32911561
M05257:4:000000000-BHV93:1:2114:10355:23898:chr13-1-32911562-CGTTCTCTCAGG:CGTTCTCTCAGG	8	32911561
M05257:4:000000000-BHV93:1:1107:20333:22890:chr13-1-32911563-ATATTTAAGGAA:ATATTTAAGGAA	7	32911562
M05257:4:000000000-BHV93:1:1110:17106:26132:chr13-1-32911563-ATATTTAAGGAA:ATATTTAAGGAA	7	32911562
M05257:4:000000000-BHV93:1:1101:24300:20016:chr13-0-32911564-CGGTATGCAGCT:CGGTATGCAGCT	6	32911563
M05257:4:000000000-BHV93:1:1103:13022:18246:chr13-0-32911564-CGGTATGCAGCT:CGGTATGCAGCT	6	32911563
M05257:4:000000000-BHV93:1:1109:21503:7874:chr13-0-32911564-CGGTATGCAGCT:CGGTATGCAGCT	6	32911563
M05257:4:000000000-BHV

In [6]:
refGenome = '/home/yaneng/RSun/Data/qiagen-breast/target_breast.refSeq.fa'

# Index every position of every target contig.  Contig names encode their span as
# '<chrom>_<start>_<stop>'; ref_loci maps (chrom, str(pos)) to the contig name, the
# offset of pos inside the contig and the contig length (all values as strings).
refseq = pysam.FastaFile(refGenome)
ref_loci = {}
ref_ids = refseq.references
for ref_id in ref_ids:
    chrom, ref_start, ref_stop = ref_id.split('_')
    first, last = int(ref_start), int(ref_stop)
    span = str(last - first + 1)
    for pos in range(first, last + 1):
        ref_key = (chrom, str(pos))
        ref_loci[ref_key] = {
            'reference': ref_id,
            'pos_relative': str(pos - first),
            'length': span,
        }

In [7]:
# Length of the target contig that contains chr1:45794680.
ref_loci['chr1', '45794680']['length']

'743'

In [8]:
# Sanity check of scipy's two-sided Fisher exact test on a small 2x2 table -> (odds ratio, p-value).
stats.fisher_exact([[10, 1], [9, 2]], alternative='two-sided')

(2.2222222222222223, 1.0)

In [9]:
def is_homo_or_low_complex(chrom, pos, length, ref_base, alt_base, refseq, ref_loci):
    """Check whether a candidate variant lies in a homopolymer or low-complexity context.

    The REF and ALT haplotypes are rebuilt as ``left flank + allele + right flank``.  The
    right flank starts at the first reference base after ``ref_base``; ALT replaces the
    whole REF allele, so both haplotypes share that flank.

    Parameters
    ----------
    chrom, pos : str
        Key into ``ref_loci``.
    length : int
        Homopolymer run length, also used as the flank size for the homopolymer test;
        the low-complexity test uses flanks of ``2 * length``.
    ref_base, alt_base : str
        VCF-style REF and ALT alleles starting at ``pos``.
    refseq : object with ``fetch(reference=, start=, end=)`` (e.g. pysam.FastaFile)
    ref_loci : dict
        (chrom, pos) -> {'reference', 'pos_relative', 'length'} (values as strings).

    Returns
    -------
    (bool, bool)
        ``(is_homopolymer, is_low_complexity)``.
        - Homopolymer: either haplotype contains ``length`` identical bases in a row.
        - Low complexity: the two most frequent nucleotides make up >= 80% of either
          low-complexity window.
    """
    locus = ref_loci[(chrom, pos)]
    ref_name = locus['reference']
    ref_length = int(locus['length'])
    pos_relative = int(locus['pos_relative'])
    right_start = pos_relative + len(ref_base)

    def _haplotypes(flank):
        """REF and ALT sequences with up to ``flank`` reference bases on each side."""
        left = refseq.fetch(reference=ref_name, start=max(0, pos_relative - flank),
                            end=pos_relative).upper()
        right_end = min(right_start + flank, ref_length)
        right = ''
        if right_start < right_end:  # the allele may reach the end of the contig
            right = refseq.fetch(reference=ref_name, start=right_start, end=right_end).upper()
        return left + ref_base + right, left + alt_base + right

    def _top2_fraction(seq):
        """Fraction of ``seq`` made up by its two MOST frequent nucleotides."""
        counts = sorted((seq.count(nt) for nt in 'ACGT'), reverse=True)
        return 1.0 * (counts[0] + counts[1]) / len(seq)

    ref_seq, alt_seq = _haplotypes(length)
    is_homo = any(nt * length in seq for nt in 'ACGT' for seq in (ref_seq, alt_seq))

    ref_seq_lc, alt_seq_lc = _haplotypes(2 * length)
    is_lowcomp = _top2_fraction(ref_seq_lc) >= 0.8 or _top2_fraction(alt_seq_lc) >= 0.8

    return (is_homo, is_lowcomp)

In [10]:
def filter_variants(ref, alt, vtype, orig_ref, alt_base, used_MT, chrom, pos, hpLen, refseq, ref_loci,
                    count_MT, count_allele, cov_eff, count_discord_pair, count_concord_pair, count_strand):
    """Build the ';'-delimited filter string for one candidate variant.

    Flags appended (each followed by ';'):
      LM   -- fewer than 5 usable MTs
      HP   -- homopolymer context (from is_homo_or_low_complex)
      LowC -- low-complexity context (from is_homo_or_low_complex)
      DP   -- discordant + concordant mate pairs for alt_base >= cov_eff / 4 and
              discordant > concordant
      SB   -- Fisher exact test of forward/reverse counts (orig_ref vs alt_base) gives
              p < 0.01 with odds ratio > 20 or < 0.05

    ``vtype``, ``count_MT`` and ``count_allele`` are accepted for interface compatibility
    but are not used.

    Returns ';' when no filter applies.
    """
    fltr = ';'
    if used_MT < 5:
        fltr += 'LM;'

    # homopolymer / low-complexity context
    is_homopolymer, is_lowcomplexity = is_homo_or_low_complex(chrom, pos, hpLen, ref, alt, refseq, ref_loci)
    if is_homopolymer:
        fltr += 'HP;'
    if is_lowcomplexity:
        fltr += 'LowC;'

    # discordant read pairs; float division keeps the threshold identical on Python 2 and 3
    discord = count_discord_pair[alt_base]
    concord = count_concord_pair[alt_base]
    if discord + concord >= cov_eff / 4.0 and discord > concord:
        fltr += 'DP;'

    # strand bias
    odds_ratio, pvalue = stats.fisher_exact([[count_strand[0][orig_ref], count_strand[1][orig_ref]],
                                             [count_strand[0][alt_base], count_strand[1][alt_base]]])
    if pvalue < 0.01 and (odds_ratio > 20 or odds_ratio < 0.05):
        fltr += 'SB;'
    return fltr

In [39]:
def call_variants(samfile, chrom, pos, minBQ, minMQ, hpLen, mismatchThr, mtDrop, refseq, ref_loci):
    """Call a variant at one reference position from molecular-barcode (MT) tagged reads.

    Parameters
    ----------
    samfile : pysam.AlignmentFile
        Open BAM.  The barcode is the last ':'-separated field of the read name; the read
        id is everything before the last two fields.
    chrom, pos : str
        Chromosome and position; used for the pileup region and as the key into ref_loci.
    minBQ, minMQ : int
        Minimum base / mapping quality; reads below either are skipped.  Deletions get
        base quality minBQ.
    hpLen : int
        Homopolymer length passed to filter_variants.
    mismatchThr : float
        Maximum (NM - indel bases) per 100 read bases for a read to be used.
    mtDrop : int
        MTs with fewer than ``mtDrop`` read fragments are dropped.
    refseq : pysam.FastaFile
        Target reference sequences.
    ref_loci : dict
        (chrom, pos) -> {'reference', 'pos_relative', 'length'}.

    Returns
    -------
    str
        One tab-separated line, in this order:
        CHROM, POS, REF, ALT, TYPE, effective depth, fragments, MTs, used fragments,
        used MTs, PI threshold, PI, variant depth/%, variant MTs/%, per-nucleotide
        (A,T,G,C) depths and %, MT counts and %, PI, FILTER.
        When no MT survives, a placeholder line ending in 'Zero_Coverage' is returned.
    """
    locus = ref_loci[(chrom, pos)]
    ref_name = locus['reference']
    pos_relative = int(locus['pos_relative'])
    orig_ref = refseq.fetch(reference=ref_name, start=pos_relative, end=pos_relative + 1).upper()

    coverage_eff = 0                                          # reads passing all read-level filters
    count_allele = defaultdict(int)                           # allele -> reads
    count_strand = {0: defaultdict(int), 1: defaultdict(int)}  # strand (1 = reverse) -> allele -> reads
    count_concord_pair = defaultdict(int)                     # allele -> mates agreeing with their pair
    count_discord_pair = defaultdict(int)                     # allele -> mates disagreeing with their pair
    barcode_dict = defaultdict(lambda: defaultdict(list))     # barcode -> read_id -> [read_info]
    all_barcode_dict = defaultdict(list)                      # barcode -> read ids before any filter
    final_PI_dict = defaultdict(float)                        # allele -> summed -log10(error) over MTs
    count_MT = defaultdict(int)                               # allele -> MTs supporting it

    # samtools-style region 'chrom:start-end' (1-based, inclusive)
    region = '%s:%s-%s' % (chrom, pos, pos)
    for column in samfile.pileup(region=region, truncate=True, max_depth=1000000, stepper='nofilter'):
        for pileup_read in column.pileups:
            alignment = pileup_read.alignment
            qname_split = alignment.query_name.split(':')
            read_id = ':'.join(qname_split[:-2])
            barcode = qname_split[-1]
            if read_id not in all_barcode_dict[barcode]:
                all_barcode_dict[barcode].append(read_id)

            # FILTER: mapping quality
            if alignment.mapping_quality < minMQ:
                continue

            # A deletion at the site has no base quality; it is assigned minBQ.
            if pileup_read.is_del:
                qual_base = minBQ
            else:
                qual_base = alignment.query_qualities[pileup_read.query_position]
                if qual_base < minBQ:  # FILTER: base quality
                    continue

            # FILTER: mismatches per 100 bp.  NM is the edit distance; subtracting the
            # inserted/deleted bases (CIGAR ops 1 and 2) leaves the mismatches.
            tag_NM = alignment.get_tag('NM') if alignment.has_tag('NM') else 0
            num_indel = sum(length for op, length in (alignment.cigartuples or []) if op in (1, 2))
            num_mismatch = tag_NM - num_indel
            read_len = alignment.query_length
            mismatch_per_100bp = 100.0 * num_mismatch / read_len if read_len > 0 else 0.0
            if mismatch_per_100bp > mismatchThr:
                continue

            # unpaired reads are counted as R1 rather than leaving pair_order undefined
            pair_order = 'R2' if alignment.is_read2 else 'R1'
            strand = 1 if alignment.is_reverse else 0
            query_pos = pileup_read.query_position

            # pileup_read.indel describes the operation FOLLOWING this site
            if pileup_read.is_del:
                base = 'DEL'
            elif pileup_read.indel > 0:
                site = alignment.query_sequence[query_pos]
                inserted = alignment.query_sequence[query_pos + 1:query_pos + 1 + pileup_read.indel]
                base = 'INS|' + site + '|' + site + inserted
            elif pileup_read.indel < 0:
                site = alignment.query_sequence[query_pos]
                # the deleted reference bases start one base after the pileup site
                deleted = refseq.fetch(reference=ref_name, start=pos_relative + 1,
                                       end=pos_relative + 1 + abs(pileup_read.indel)).upper()
                base = 'DEL|' + site + deleted + '|' + site
            else:
                base = alignment.query_sequence[query_pos]

            count_allele[base] += 1
            count_strand[strand][base] += 1
            coverage_eff += 1

            reads = barcode_dict[barcode]
            if read_id not in reads:
                prob_base = pow(10.0, -qual_base / 10.0)
                # deletion sites have no query position; use the next aligned base instead
                pos_in_read = query_pos if query_pos is not None else pileup_read.query_position_or_next
                # distance to the sequencing 5' end, clamped to >= 1 so log() is defined
                dist_to_5end = max(1, strand * (1 + read_len) + pow(-1, strand) * pos_in_read)
                prob_seq_error = 0.01 * math.log(dist_to_5end) / math.log(read_len) if read_len > 1 else 0.0
                reads[read_id].append([base, prob_base, pair_order, alignment.next_reference_start,
                                       prob_seq_error])
            elif base == reads[read_id][0][0] or base in ('N', '*'):
                # second mate overlaps the site and agrees (or is uninformative)
                first = reads[read_id][0]
                first[1] = max(pow(10.0, -qual_base / 10.0), first[1])
                first[2] = 'Paired'
                if base == first[0]:
                    count_concord_pair[base] += 1
            else:
                # mates disagree: the fragment is discarded
                del reads[read_id]
                count_discord_pair[base] += 1

    all_MT = len(all_barcode_dict)
    all_frag = sum(len(read_ids) for read_ids in all_barcode_dict.values())
    # iterate over a snapshot: deleting from a dict while iterating it raises RuntimeError
    for bc in list(barcode_dict):
        if len(barcode_dict[bc]) < mtDrop:
            del barcode_dict[bc]

    used_MT = len(barcode_dict)
    threshold_PI = 2.0 * used_MT
    if used_MT == 0:
        return '\t'.join([chrom, pos, orig_ref] + [' '] * 50 + ['Zero_Coverage'])
    used_frag = sum(len(reads) for reads in barcode_dict.values())

    # Per-MT evidence: each allele seen in an MT adds -log10(P(error)) (capped at 10).
    for bc in barcode_dict:
        # second argument: fragment count plus a pseudo-count of 4
        post_pr_base = cal_posterior(barcode_dict[bc], len(barcode_dict[bc]) + 4)
        for allele in post_pr_base:
            x = 1.0 - post_pr_base[allele]
            log10P = -math.log10(x) if x > 1e-10 else 10.0
            final_PI_dict[allele] += log10P
            count_MT[allele] += 1

    sorted_PI_list = sorted(final_PI_dict.items(), key=operator.itemgetter(1), reverse=True)
    # Pad with unobserved nucleotides (PI 0) so a second candidate always exists,
    # e.g. at monomorphic sites.  PIs are >= 0, so padding never outranks real entries.
    ranked = set(allele for allele, _ in sorted_PI_list)
    sorted_PI_list.extend((nt, 0.0) for nt in ('A', 'T', 'G', 'C') if nt not in ranked)
    max_base = sorted_PI_list[0][0]
    second_max_base = sorted_PI_list[1][0]

    alt_base = second_max_base if max_base == orig_ref else max_base
    (ref, alt, vtype) = convert_to_vcf(orig_ref, alt_base)

    filtration = ';'
    if used_MT > 1 and vtype in ('SNP', 'INDEL'):
        filtration = filter_variants(ref, alt, vtype, orig_ref, alt_base, used_MT, chrom, pos, hpLen, refseq,
                                     ref_loci, count_MT, count_allele, coverage_eff, count_discord_pair,
                                     count_concord_pair, count_strand)

    mf_alt = 1.0 * count_MT[max_base] / used_MT
    mf_alt2 = 1.0 * count_MT[second_max_base] / used_MT

    # Two non-reference alleles each in >= 40% of MTs: consider a second ALT.
    if max_base != orig_ref and second_max_base != orig_ref and mf_alt >= 0.4 and mf_alt2 >= 0.4:
        alt_base2 = second_max_base
        (ref2, alt2, vtype2) = convert_to_vcf(orig_ref, alt_base2)
        filtration2 = ';'
        if vtype2 in ('SNP', 'INDEL'):
            filtration2 = filter_variants(ref2, alt2, vtype2, orig_ref, alt_base2, used_MT, chrom, pos, hpLen,
                                          refseq, ref_loci, count_MT, count_allele, coverage_eff,
                                          count_discord_pair, count_concord_pair, count_strand)
        if filtration == ';' and filtration2 == ';':
            alt = alt + ',' + alt2
            vtype = vtype + ',' + vtype2
        elif filtration != ';' and filtration2 == ';':
            # report the second allele instead; REF/TYPE must follow it (they differ for indels)
            ref = ref2
            alt = alt2
            vtype = vtype2
            filtration = filtration2
            alt_base = alt_base2

    nts = ('A', 'T', 'G', 'C')
    frac_alt = round(100.0 * count_allele[alt_base] / coverage_eff, 4)
    frac_alt_MT = round(100.0 * count_MT[alt_base] / used_MT, 4)
    fracs = [count_allele[nt] for nt in nts] + [round(100.0 * count_allele[nt] / coverage_eff, 4) for nt in nts]
    fracs_MT = [count_MT[nt] for nt in nts] + [round(100.0 * count_MT[nt] / used_MT, 4) for nt in nts]
    pi_by_nt = [round(final_PI_dict[nt], 2) for nt in nts]

    vec_out = [chrom, pos, ref, alt, vtype, coverage_eff, all_frag, all_MT, used_frag, used_MT, threshold_PI,
               round(final_PI_dict[alt_base], 2), count_allele[alt_base], frac_alt, count_MT[alt_base],
               frac_alt_MT] + fracs + fracs_MT + pi_by_nt + [filtration]
    return '\t'.join(str(x) for x in vec_out)

In [40]:
def call_variants_wrapper(*args):
    """Run call_variants(*args), turning a failure into a marker string.

    This is the pool task, so one failing locus does not kill the pool.  The result is
    the call_variants output line, or 'Exception thrown!' followed by a newline and the
    traceback.  Only ``Exception`` subclasses are caught, so KeyboardInterrupt and
    SystemExit still propagate.

    args[1] and args[2] are the chromosome and position (as in call_variants' signature).
    """
    try:
        return call_variants(*args)
    except Exception:
        print('Exception thrown in call_variants() function at genome location:',
              args[1], args[2])
        return 'Exception thrown!\n' + traceback.format_exc()

In [41]:
parser = None


def arg_parse_init():
    """Create the module-level argparse ``parser`` for the command-line interface.

    ``fromfile_prefix_chars='@'`` allows options to be read from a file, one argument
    per line (e.g. ``@params.txt``).
    """
    global parser
    parser = argparse.ArgumentParser(description='Variant calling using molecular barcodes',
                                     fromfile_prefix_chars='@')
    parser.add_argument('--outPrefix', default=None, required=True, help='prefix for output files')
    parser.add_argument('--bamFile', default=None, required=True, help='BAM file')
    parser.add_argument('--bedTarget', default=None, required=True, help='bed file of target region')
    parser.add_argument('--nCPU', type=int, default=1, help='CPU number used in parallel')
    parser.add_argument('--minBQ', type=int, default=20, help='minimum base quality allowed for analysis')
    parser.add_argument('--minMQ', type=int, default=30, help='minimum mapping quality allowed for analysis')
    parser.add_argument('--hpLen', type=int, default=10, help='minimum length of homopolymers')
    parser.add_argument('--mismatchThr', type=float, default=6.0,
                        help='average number of mismatches per 100bp allowed')
    parser.add_argument('--mtDrop', type=int, default=0,
                        help='drop MTs supported by fewer than [mtDrop] reads')
    parser.add_argument('--threshold', type=int, default=0,
                        help='minimum prediction index (PI >= 0) required to call variants. '
                             'A range of 10 to 60 is recommended. If 0 (the default), smCounter '
                             'will choose an appropriate cutoff based on the MT depth.')
    parser.add_argument('--refGenome', default=None, help='reference DNA sequences of target')
    parser.add_argument('--bedTandemRepeats', default=None, help='UCSC tandem repeats')
    parser.add_argument('--bedRepeatMaskerSubset', default=None,
                        help='RepeatMasker simple repeats, low complexity, microsatellite regions')
    parser.add_argument('--bedtoolsPath', default=None, help='path to bedtools')
    parser.add_argument('--runPath', default=None, help='path to working directory')
    parser.add_argument('--logFile', default=None, help='log file')
    parser.add_argument('--paramFile', default=None,
                        help='optional parameter file that contains the above parameters')
                        parameter.')

In [44]:
# Per-process state for the worker pool.  pysam file handles cannot be pickled into
# pool tasks, so each worker opens its own handles once, in _init_worker.
_worker_state = {}


def _init_worker(bam_path, ref_path, ref_loci):
    """Pool initializer: open this process's BAM/FASTA handles and keep ref_loci."""
    _worker_state['samfile'] = pysam.AlignmentFile(bam_path, 'rb')
    _worker_state['refseq'] = pysam.FastaFile(ref_path)
    _worker_state['ref_loci'] = ref_loci


def _call_variants_worker(chrom, pos, minBQ, minMQ, hpLen, mismatchThr, mtDrop):
    """Pool task: call one locus using this worker's open handles."""
    return call_variants_wrapper(_worker_state['samfile'], chrom, pos, minBQ, minMQ, hpLen,
                                 mismatchThr, mtDrop, _worker_state['refseq'],
                                 _worker_state['ref_loci'])


def _build_ref_loci(refseq):
    """Map (chrom, str(pos)) -> {'reference', 'pos_relative', 'length'} for every contig position.

    Contig names must have the form '<chrom>_<start>_<stop>'.
    """
    ref_loci = {}
    for ref_id in refseq.references:
        chrom, ref_start, ref_stop = ref_id.split('_')
        start, stop = int(ref_start), int(ref_stop)
        length = str(stop - start + 1)
        for pos in range(start, stop + 1):
            ref_loci[(chrom, str(pos))] = {'reference': ref_id,
                                           'pos_relative': str(pos - start),
                                           'length': length}
    return ref_loci


def _read_tandem_repeats(bed_path):
    """Read a tandem-repeat BED into chrom -> [(start, stop, 'RepT;')]; empty when bed_path is None."""
    regions = defaultdict(list)
    if bed_path is None:
        return regions
    with open(bed_path) as handle:
        for line in handle:
            vals = line.strip().split()
            if not vals:
                continue
            chrom, start, stop = vals[0:3]
            regions[chrom].append((int(start), int(stop), 'RepT;'))
    return regions


_REPEAT_MASKER_LABELS = {'Simple_repeat': 'RepS', 'Low_complexity': 'LowC', 'Satellite': 'SL'}


def _read_repeat_masker(bed_path):
    """Read a 4-column RepeatMasker BED into chrom -> [(start, stop, 'Label;...;')].

    The 4th column is a comma-separated list of repeat classes; unknown classes become
    'Others'.  Returns an empty mapping when bed_path is None.
    """
    regions = defaultdict(list)
    if bed_path is None:
        return regions
    with open(bed_path) as handle:
        for line in handle:
            vals = line.strip().split()
            if not vals:
                continue
            chrom, start, stop, code_types = vals
            labels = [_REPEAT_MASKER_LABELS.get(code, 'Others') for code in code_types.split(',')]
            regions[chrom].append((int(start), int(stop), ';'.join(labels) + ';'))
    return regions


def main(args):
    """Run variant calling over every position of the target BED and write the results.

    ``args`` is an argparse.Namespace or a dict of option name -> value; a dict is
    converted through the module-level parser.

    Writes three files:
      - '<outPrefix>.variant_calling.all.txt' -- every locus
      - '<outPrefix>.variant_calling.cut.txt' -- called variants, short format
      - '<outPrefix>.variant_calling.cut.vcf' -- called variants, VCF

    Returns the PI threshold applied to the last locus that had a PI value, or None.
    Raises Exception if any locus failed inside call_variants.
    """
    time_start = datetime.datetime.now()
    print('Variant calling with smCounter started at ' + str(time_start))

    # Initialize global argument parser
    if parser is None:
        arg_parse_init()

    # pass arguments in args
    if not isinstance(args, argparse.Namespace):
        args = parser.parse_args(['--{0}={1}'.format(name, val) for name, val in args.items()])
    elif args.paramFile is not None:
        # parse_args needs a list; a bare string would be split into single characters
        args = parser.parse_args(['@' + args.paramFile])

    for arg_name, arg_val in vars(args).items():
        print(arg_name, arg_val)

    if args.runPath is not None:
        os.chdir(args.runPath)

    # index the reference; workers reopen the FASTA themselves
    refseq = pysam.FastaFile(args.refGenome)
    ref_loci = _build_ref_loci(refseq)
    refseq.close()

    # make list of candidate loci to call variants
    loci_list = []
    with open(args.bedTarget) as bed:
        for line in bed:
            if not line.strip() or line.startswith("track "):
                continue
            (chrom, region_start, region_stop) = line.strip().split('\t')[0:3]
            # NOTE(review): BED starts are 0-based; this includes region_start itself as a
            # position key -- confirm against how ref_loci positions are numbered.
            loci_list.extend((chrom, str(p)) for p in range(int(region_start), int(region_stop) + 1))

    # call variants in parallel
    pool = multiprocessing.Pool(processes=args.nCPU, initializer=_init_worker,
                                initargs=(args.bamFile, args.refGenome, ref_loci))
    try:
        results = [pool.apply_async(_call_variants_worker,
                                    args=(locus[0], locus[1], args.minBQ, args.minMQ, args.hpLen,
                                          args.mismatchThr, args.mtDrop))
                   for locus in loci_list]
        output = [res.get() for res in results]
    finally:
        pool.close()
        pool.join()

    for locus, line in zip(loci_list, output):
        if line.startswith("Exception thrown!"):
            print(line)
            raise Exception("Exception thrown in variant_calling() at location: " + str(locus))

    # NOTE(review): as in the original, the repeat BED paths are formed by prefixing
    # outPrefix -- confirm this is intended.  Either file may be omitted.
    bed_tandem_repeat = None if args.bedTandemRepeats is None else args.outPrefix + args.bedTandemRepeats
    bed_repeat_masker = (None if args.bedRepeatMaskerSubset is None
                         else args.outPrefix + args.bedRepeatMaskerSubset)
    trf_regions = _read_tandem_repeats(bed_tandem_repeat)
    rm_regions = _read_repeat_masker(bed_repeat_masker)

    # Columns of call_variants() output lines (37 fields, FILTER last).
    header_all = ('CHROM', 'POS', 'REF', 'ALT', 'TYPE', 'DP', 'FR', 'MT', 'UFR', 'UMT', 'THR', 'PI',
                  'VDP', 'VAF', 'VMT', 'VMF', 'DP_A', 'DP_T', 'DP_G', 'DP_C', 'AF_A', 'AF_T', 'AF_G',
                  'AF_C', 'UMT_A', 'UMT_T', 'UMT_G', 'UMT_C', 'UMF_A', 'UMF_T', 'UMF_G', 'UMF_C',
                  'PI_A', 'PI_T', 'PI_G', 'PI_C', 'FILTER')
    header_var = ('CHROM', 'POS', 'REF', 'ALT', 'TYPE', 'DP', 'MT', 'UMT', 'THR', 'PI', 'VMT',
                  'VMF', 'FILTER')
    header_all_idx = dict((name, idx) for idx, name in enumerate(header_all))

    # Annotate repeat regions.  Tandem repeats only flag variants with MT fraction < 40%.
    for i, outline in enumerate(output):
        line_list = outline.split('\t')
        # Zero-coverage lines carry blank placeholder fields; leave them untouched.
        if not line_list[header_all_idx['PI']].strip():
            continue
        chrom_tr = line_list[header_all_idx['CHROM']]
        alt_tr = line_list[header_all_idx['ALT']]
        pos_tr = int(line_list[header_all_idx['POS']])
        frac_alt_MT_tr = float(line_list[header_all_idx['VMF']])

        if alt_tr != 'DEL':
            if frac_alt_MT_tr < 40:
                for (rep_start, rep_stop, rep_type) in trf_regions[chrom_tr]:
                    if rep_start < pos_tr <= rep_stop:
                        line_list[-1] += rep_type
                        break
            for (rep_start, rep_stop, rep_type) in rm_regions[chrom_tr]:
                if rep_start < pos_tr <= rep_stop:
                    line_list[-1] += rep_type
                    break
        line_list[-1] = 'PASS' if line_list[-1] == ';' else line_list[-1].strip(';')
        output[i] = '\t'.join(line_list)

    header_vcf = \
        '##fileformat=VCFv4.2\n' + \
        '##reference=GRCh37\n' + \
        '##INFO=<ID=TYPE,Number=1,Type=String,Description="Variant type: SNP or INDEL">\n' + \
        '##INFO=<ID=DP,Number=1,Type=Integer,Description="Total read depth">\n' + \
        '##INFO=<ID=MT,Number=1,Type=Integer,Description="Total MT depth">\n' + \
        '##INFO=<ID=UMT,Number=1,Type=Integer,Description="Filtered MT depth">\n' + \
        '##INFO=<ID=PI,Number=1,Type=Float,Description="Variant prediction index">\n' + \
        '##INFO=<ID=THR,Number=1,Type=Float,Description="Variant prediction index minimum threshold">\n' + \
        '##INFO=<ID=VMT,Number=1,Type=Integer,Description="Variant MT depth">\n' + \
        '##INFO=<ID=VMF,Number=1,Type=Float,Description="Variant MT fraction">\n' + \
        '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n' + \
        '##FORMAT=<ID=AD,Number=.,Type=Integer,Description="Filtered allelic MT depths for the ref and' + \
        ' alt alleles">\n' + \
        '##FORMAT=<ID=VF,Number=1,Type=Float,Description="Variant MT fraction, same as VMF">\n' + \
        '##FILTER=<ID=RepT,Description="Variant in simple tandem repeat region, as defined by Tandem Repeats' + \
        ' Finder">\n' + \
        '##FILTER=<ID=RepS,Description="Variant in simple repeat region, as defined by RepeatMasker">\n' + \
        '##FILTER=<ID=LowC,Description="Variant in low complexity region, as defined by RepeatMasker">\n' + \
        '##FILTER=<ID=SL,Description="Variant in microsatellite region, as defined by RepeatMasker">\n' + \
        '##FILTER=<ID=HP,Description="Inside or flanked by homopolymer region">\n' + \
        '##FILTER=<ID=LM,Description="Low coverage (fewer than 5 MTs)">\n' + \
        '##FILTER=<ID=SB,Description="Strand bias">\n' + \
        '##FILTER=<ID=MM,Description="Too many genome reference mismatches in reads (default threshold is 6.5 per' + \
        ' 100 bases)">\n' + \
        '##FILTER=<ID=DP,Description="Too many discordant read pairs">\n' + \
        '\t'.join(('#CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT', args.outPrefix)) + '\n'

    num_called_snps = 0
    num_called_indels = 0
    threshold = None
    with open(args.outPrefix + '.variant_calling.all.txt', 'w') as out_all, \
            open(args.outPrefix + '.variant_calling.cut.txt', 'w') as out_var, \
            open(args.outPrefix + '.variant_calling.cut.vcf', 'w') as out_vcf:
        out_all.write('\t'.join(header_all) + '\n')
        out_var.write('\t'.join(header_var) + '\n')
        out_vcf.write(header_vcf)

        for line in output:
            out_all.write(line + '\n')

            fields = line.split('\t')
            PI = fields[header_all_idx['PI']]
            if not PI.strip():
                continue
            # a positive --threshold overrides the per-site THR column
            if args.threshold > 0:
                threshold = float(args.threshold)
            else:
                threshold = float(fields[header_all_idx['THR']])
            alt = fields[header_all_idx['ALT']]
            qual = int(float(PI))
            if qual <= threshold or alt == 'DEL':
                continue

            chrom = fields[header_all_idx['CHROM']]
            pos = fields[header_all_idx['POS']]
            ref = fields[header_all_idx['REF']]
            vtype = fields[header_all_idx['TYPE']]
            dp = fields[header_all_idx['DP']]
            mt = fields[header_all_idx['MT']]
            umt = fields[header_all_idx['UMT']]
            vmt = fields[header_all_idx['VMT']]
            vmf = fields[header_all_idx['VMF']]
            fltr = fields[-1]
            thr = str(threshold)
            info = ';'.join(('TYPE=' + vtype, 'DP=' + dp, 'MT=' + mt, 'UMT=' + umt, 'PI=' + PI,
                             'THR=' + thr, 'VMT=' + vmt, 'VMF=' + vmf))

            alts = alt.split(',')
            ref_mt = str(int(umt) - int(vmt))
            ad = ref_mt + ',' + vmt
            if len(alts) == 2:
                genotype = '1/2'
                ad = ad + ',1'
            elif len(alts) != 1:
                raise Exception("error hacking genotype field for " + alt)
            elif chrom in ('chrY', 'chrM'):
                genotype = '1'
            elif float(vmf) > 95.0:  # VMF is a percentage (0-100)
                genotype = '1/1'
            else:
                genotype = '0/1'

            out_format = 'GT:AD:VF'
            sample = ':'.join((genotype, ad, vmf))
            var_id = '.'
            vcf_line = '\t'.join((chrom, pos, var_id, ref, alt, str(qual), fltr, info, out_format,
                                  sample)) + '\n'
            short_line = '\t'.join((chrom, pos, ref, alt, vtype, dp, mt, umt, thr, PI, vmt, vmf,
                                    fltr)) + '\n'
            out_vcf.write(vcf_line)
            out_var.write(short_line)

            if vtype == 'SNP':
                num_called_snps += 1
            else:
                num_called_indels += 1

    print('Called %d SNPs and %d INDELs' % (num_called_snps, num_called_indels))
    time_end = datetime.datetime.now()
    print('Variant calling process completed running at ' + str(time_end))
    print('Variant calling total running time: ' + str(time_end - time_start))

    return threshold

In [None]:
# Script entry point.  Skipped inside a Jupyter kernel: there __name__ is also
# '__main__', but sys.argv belongs to the kernel process, so parse_args() would exit
# with a usage error on "Run All".
if __name__ == '__main__' and 'ipykernel' not in sys.modules:
    arg_parse_init()
    args = parser.parse_args()

    main(args)