# Implementation of the EBfilter by genomon

* EBrun (originally EBFilter) is an argparse wrapper passing command line arguments to run.py (is not needed for internal use)
* passed arguments:
    * targetMutationFile: the .vcf or .anno containing the mutations – needed --> mut_file
    * targetBamPath: path to the tumor bam file (+.bai) – needed --> tumor_bam
    * controlBamPathList: text list of path to PoN bam files (+ .bai) – needed --> pon_list
    * outputPath: clear  – needed --> output_path
    * -f option for anno or vcf – not needed --> will be inferred from .ext
    * thread_num: –not needed --> taken from config
    * -q option for quality threshold – not needed --> default _q config
    * -Q option for base quality threshold - not needed --> default _Q from config
    * --ff option for filter flags – not needed because of preprocessing??
    * --loption: use the samtools mpileup -l option (pileup restricted to a BED list of the mutation positions) – must elaborate..
    * --region option for restriction of regions on mpileup -l – must elaborate..
    * --debug – not needed

## Initiation

### imports

In [54]:
import pandas as pd
import numpy as np
import vcf
import pysam
import sys
import os
import subprocess
import math
import scipy.stats as ss
import scipy.optimize as so
import re
import multiprocessing
# Matches the start of a "chrom:start-end" region string.
# Groups: (1) chromosome name, (2) start, (3) end.
# Raw string so that \d reaches the regex engine without an invalid-escape warning.
region_exp = re.compile(r'^([^ \t\n\r\f\v,]+):(\d+)-(\d+)')

### snakemake config

In [55]:
# Stand-in for the snakemake config: all EBfilter settings live under config['EB'].
params = {
    'map_quality': 20,                             # mapping quality threshold
    'base_quality': 15,                            # base quality threshold
    'filter_flags': 'UNMAP,SECONDARY,QCFAIL,DUP',  # read flags to exclude
    'loption': True,                               # use the mpileup -l position list
}
config = {'EB': {'run': True, 'threads': 1, 'params': params}}

### function args

In [56]:
# Arguments handed to main(); paths point at the bundled test data.
args = {
    'mut_file': 'testdata/input.anno',               # mutations to score (.anno or .vcf)
    'tumor_bam': 'testdata/tumor.bam',               # tumor bam
    'pon_list': 'testdata/list_normal_sample.txt',   # text file listing the PoN bam paths
    'output_path': 'output/output.anno',             # output file
    'region': '',                                    # optional "chrom:start-end" restriction
}

### load the config and GLOBAL STATE

In [57]:
# Module-level state read by the worker functions below.
debug_mode = True  # when True, intermediate files are kept
params = config['EB']['params']
threads = config['EB']['threads']
# mapping quality threshold, kept as a string because it is passed to `samtools mpileup -q`
_q = str(params['map_quality'])  # 20
# base quality threshold, kept as a string because it is passed to `samtools mpileup -Q`
_Q = str(params['base_quality'])
# Quality characters with ASCII codes 33 .. 33 + base_quality - 1; varCountCheck
# skips bases whose quality character is in this string.
# (offset 33 presumably corresponds to Phred+33 encoding -- confirm against the bam files)
# int(...) is required here: _Q is a string and cannot be added to 33.
filter_quals = ''.join(chr(qual) for qual in range(33, 33 + int(params['base_quality'])))

_ff = params['filter_flags']  # 'UNMAP,SECONDARY,QCFAIL,DUP'
is_loption = params['loption']  # True in the config above
log_file = 'output/logs'

### utils

In [58]:
def validate_region(region):
    '''
    checks that region starts with a "chrom:start-end" specification

    The chromosome name may not contain whitespace or commas; start and end
    must be digits. Only the beginning of the string is checked, so trailing
    characters after the end coordinate are accepted.

    returns True if region is non-empty and matches, False otherwise
    '''
    if not region:
        return False
    # raw string: \d must reach the regex engine unescaped
    return re.match(r'^[^ \t\n\r\f\v,]+:\d+-\d+', region) is not None

def _has_bam_index(bam_path):
    # an index may be named "<file>.bam.bai" or "<file>.bai"
    return os.path.exists(bam_path + ".bai") or os.path.exists(re.sub(r'bam$', "bai", bam_path))


def validate(mut_file, tumor_bam, pon_list):
    '''
    checks that all input files and bam indices exist

    mut_file:  path to the mutation file
    tumor_bam: path to the tumor bam (an index file must exist next to it)
    pon_list:  path to a text file with one control bam path per line
               (each needs an index file as well)

    Writes a message to stderr and exits with status 1 on the first missing file.
    returns None if everything is in place
    '''
    # file existence check
    if not os.path.exists(mut_file):
        sys.stderr.write(f"No target mutation file: {mut_file}")
        sys.exit(1)
    if not os.path.exists(tumor_bam):
        sys.stderr.write(f"No target bam file: {tumor_bam}")
        sys.exit(1)
    if not _has_bam_index(tumor_bam):
        sys.stderr.write(f"No index for target bam file: {tumor_bam}")
        sys.exit(1)

    if not os.path.exists(pon_list):
        sys.stderr.write(f"No control list file: {pon_list}")
        sys.exit(1)

    with open(pon_list) as pon_in:
        for line in pon_in:
            control_bam = line.rstrip()
            if not os.path.exists(control_bam):
                sys.stderr.write(f"No control bam file: {control_bam}")
                sys.exit(1)
            if not _has_bam_index(control_bam):
                sys.stderr.write(f"No index for control bam file: {control_bam}")
                # a missing control index is as fatal as a missing tumor index
                sys.exit(1)
                
def make_region_list(anno_path):
    '''
    writes "<anno_path>.region_list.bed" with one single-base interval per anno line

    Column 1 of the anno line is the chromosome, column 2 the position; the
    position is shifted down by one when column 5 is "-" (deletion).
    Each output line is "chrom <tab> pos-1 <tab> pos".
    '''
    bed_path = f"{anno_path}.region_list.bed"
    with open(anno_path) as anno, open(bed_path, 'w') as bed:
        for record in anno:
            cols = record.rstrip('\n').split('\t')
            end = int(cols[1])
            if cols[4] == "-":
                # deletion: use the base in front of the deleted stretch
                end -= 1
            bed.write(f"{cols[0]}\t{end - 1}\t{end}\n")

### main

In [59]:
def main(args):
    '''
    validates files and refers to respective functions

    args is a dict with the keys 'mut_file', 'tumor_bam', 'pon_list',
    'output_path' and 'region'. The module-level globals `threads` and
    `debug_mode` decide between single/multi process mode and whether
    intermediate files are kept.

    A file ending in .vcf is handled by EBFilter_worker_vcf, everything else
    by EBFilter_worker_anno. Exits with status 1 (via validate / region check)
    on invalid input; raises RuntimeError if a worker process fails.
    '''
    mut_file = args['mut_file']
    tumor_bam = args['tumor_bam']
    pon_list = args['pon_list']
    output_path = args['output_path']
    region = args['region']
    is_anno = os.path.splitext(mut_file)[-1] != '.vcf'

    # file existence check
    validate(mut_file, tumor_bam, pon_list)
    # the workers parse the region without further checks, so reject bad input here
    if region and not validate_region(region):
        sys.stderr.write(f"Malformed region (expected chrom:start-end): {region}\n")
        sys.exit(1)

    if not is_anno:
        # no partition/merge exists for vcf, so always run a single worker
        # instead of silently producing no output
        if threads != 1:
            sys.stderr.write("Multi-threading is not implemented for vcf input, using one process\n")
        EBFilter_worker_vcf(mut_file, tumor_bam, pon_list, output_path, region)
        return

    if threads == 1:
        # non multi-threading mode
        EBFilter_worker_anno(mut_file, tumor_bam, pon_list, output_path, region)
        return

    # multi-threading mode
    # The partitions get their own prefix: a worker must never read from the
    # same path it writes to, because opening the output truncates the file.
    partition_prefix = f"{output_path}.tmp.input.anno"
    # partition_anno may create fewer partitions than requested (few records)
    n_parts = partition_anno(mut_file, partition_prefix, threads)

    jobs = []
    for i in range(n_parts):
        worker_args = (f"{partition_prefix}.{i}", tumor_bam, pon_list, f"{output_path}.{i}", region)
        process = multiprocessing.Process(target=EBFilter_worker_anno, args=worker_args)
        jobs.append(process)
        process.start()

    # wait all the jobs to be done
    for process in jobs:
        process.join()
    failed = [i for i, process in enumerate(jobs) if process.exitcode != 0]
    if failed:
        raise RuntimeError(f"EBFilter worker(s) {failed} failed, see {log_file}")

    # merge the individual results
    merge_anno(output_path, n_parts)

    # delete intermediate files
    if not debug_mode:
        for i in range(n_parts):
            leftovers = (
                f"{partition_prefix}.{i}",
                f"{partition_prefix}.{i}.region_list.bed",
                f"{output_path}.{i}",
                f"{output_path}.{i}.control.pileup",
                f"{output_path}.{i}.target.pileup",
            )
            for path in leftovers:
                # the workers remove some of these themselves
                try:
                    os.remove(path)
                except FileNotFoundError:
                    pass

### worker_anno

In [73]:
def _load_pileup(pileup_path):
    # maps "chrom:pos" to the pileup columns after chrom/pos/ref, e.g.
    # pos2pileup['chr1:123453'] = [depth1, reads1, rQ1, depth2, reads2, rQ2, ...]
    pos2pileup = {}
    with open(pileup_path, 'r') as file_in:
        for line in file_in:
            field = line.rstrip('\n').split('\t')
            pos2pileup[f"{field[0]}:{field[1]}"] = field[3:]
    return pos2pileup


def EBFilter_worker_anno(mut_file, tumor_bam, pon_list, output_path, region):
    '''
    computes the EB score for every line of an anno file

    mut_file:    tab-separated anno file (chr, pos, pos2, ref, alt, ...)
    tumor_bam:   path to the tumor bam
    pon_list:    text file listing the control bams, one per line
    output_path: output file; every processed input line is written with the
                 EB score appended as an additional column
    region:      optional "chrom:start-end"; with is_loption set, lines
                 outside of it are skipped

    Creates "<output_path>.target.pileup" and "<output_path>.control.pileup"
    (and "<mut_file>.region_list.bed" when is_loption is set); they are
    removed afterwards unless debug_mode is set.
    '''
    with open(pon_list, 'r') as pon_in:
        pon_count = sum(1 for _ in pon_in)

    if is_loption:
        make_region_list(mut_file)  # in utils

    # generate pileup files
    anno2pileup(mut_file, output_path, tumor_bam, region)
    anno2pileup(mut_file, output_path, pon_list, region)

    # delete region_list.bed
    if is_loption and not debug_mode:
        try:
            os.remove(f"{mut_file}.region_list.bed")
        except FileNotFoundError:
            pass

    pos2pileup_target = _load_pileup(f"{output_path}.target.pileup")
    pos2pileup_control = _load_pileup(f"{output_path}.control.pileup")

    # get restricted region if given
    restrict = bool(is_loption and region)
    if restrict:
        region_match = region_exp.match(region)
        if region_match is None:
            raise ValueError(f"malformed region: {region}")
        reg_chr = region_match.group(1)
        reg_start = int(region_match.group(2))
        reg_end = int(region_match.group(3))

    # Read all records before the output is opened: opening for writing
    # truncates the file, which would wipe the input if both paths are equal.
    with open(mut_file, 'r') as file_in:
        records = file_in.readlines()

    with open(output_path, 'w') as file_out:
        for line in records:
            field = line.rstrip('\n').split('\t')
            chrom, ref, alt = field[0], field[3], field[4]
            pos = int(field[1])
            # adjust pos for deletion (same shift as in make_region_list / anno2pileup)
            if alt == "-":
                pos -= 1

            if restrict:
                if reg_chr != chrom:
                    continue
                if pos < reg_start or pos > reg_end:
                    continue

            # positions without pileup output get an empty list
            key = f"{chrom}:{pos}"
            field_target = pos2pileup_target.get(key, [])
            field_control = pos2pileup_control.get(key, [])

            # set the variance
            # ref   alt    var
            #  A     T      T
            #  -     T     +T
            #  A     -     -A
            if ref != "-" and alt != "-":
                var = alt
            elif ref == "-":
                var = "+" + alt
            else:
                var = "-" + ref

            EB_score = "."  # stays "." if no variant could be derived
            if var:
                # get_eb_score('+A', [depth, reads, rQ], [depth1, reads1, rQ1, depth2, reads2, rQ2, depth3, reads3, rQ3], 3)
                EB_score = get_eb_score(var, field_target, field_control, pon_count)

            # add the score and write the record
            print('\t'.join(field + [str(EB_score)]), file=file_out)

    # delete intermediate files
    if not debug_mode:
        os.remove(output_path + '.target.pileup')
        os.remove(output_path + '.control.pileup')

### anno2pileup

In [74]:
def anno2pileup(anno_path, out_path, bam_or_pon, region):
    '''
    creates a pileup from all the entries in the anno file

    anno_path:  tab-separated anno file (chr, pos, ..., alt in column 5)
    out_path:   prefix of the pileup file; the result goes to
                "<out_path>.target.pileup" if bam_or_pon ends in .bam,
                otherwise to "<out_path>.control.pileup"
    bam_or_pon: a single bam file, or a text file listing bams (passed via -b)
    region:     optional region string, only used together with is_loption

    With is_loption set, samtools mpileup runs once using
    "<anno_path>.region_list.bed" (created by the caller) as -l list.
    Otherwise it runs once per anno line with -r set to that position.
    samtools' stderr is appended to log_file.
    Raises subprocess.CalledProcessError if samtools fails.
    '''
    # determine whether it is bam or pon
    is_bam = (os.path.splitext(bam_or_pon)[-1] == '.bam')
    out_file = f"{out_path}.target.pileup" if is_bam else f"{out_path}.control.pileup"

    base_cmd = ["samtools", "mpileup", "-B", "-d", "10000000", "-q", _q, "-Q", _Q, "--ff", _ff]
    # add tumor_bam or pon_list of bam files depending on file extension of bam_or_pon
    base_cmd += [bam_or_pon] if is_bam else ["-b", bam_or_pon]

    # append: the target and the control run (and parallel workers) share one
    # log file, mode 'w' would erase the messages of the previous call
    with open(log_file, 'a') as log, open(out_file, 'w') as file_out:
        if is_loption:
            # region_list.bed is generated by worker_anno
            mpileup_cmd = base_cmd + ["-l", f"{anno_path}.region_list.bed"]
            if region:
                mpileup_cmd += ["-r", region]
            subprocess.check_call([str(command) for command in mpileup_cmd], stdout=file_out, stderr=log)
            return

        # no loption: one mpileup call per line of the anno file
        with open(anno_path, 'r') as file_in:
            for line in file_in:
                print('anno2pileup', line, bam_or_pon)
                field = line.rstrip('\n').split('\t')
                loc = int(field[1]) - (field[4] == "-")  # -1 if field 4 == '-' eg. deletion
                mutReg = f"{field[0]}:{loc}-{loc}"

                # build a fresh command per position; extending the shared list
                # would pile up one more -r option with every line
                mpileup_cmd = base_cmd + ["-r", mutReg]
                subprocess.check_call([str(command) for command in mpileup_cmd], stdout=file_out, stderr=log)

### get EB score

In [75]:
def _fisher_combination(pvalues):
    # Fisher's method: -2 * sum(ln p) is chi-square distributed with
    # 2 * len(pvalues) degrees of freedom; sf() is the upper tail probability.
    # NOTE(review): replaces the call to utils.fisher_combination, since no
    # `utils` module is imported in this file -- confirm it matches the intended helper.
    if any(p <= 0 for p in pvalues):
        return 0.0
    statistic = -2.0 * sum(math.log(p) for p in pvalues)
    return float(ss.chi2.sf(statistic, 2 * len(pvalues)))


def get_eb_score(var, F_target, F_control, pon_count):
    """
    calculate the EBCall score from pileup bases of tumor and control samples

    var:       variant, e.g. 'T', '+A' or '-A'
    F_target:  [depth, reads, rQ] of the tumor sample, or [] if no pileup exists
    F_control: [depth1, reads1, rQ1, depth2, reads2, rQ2, ...] of the controls,
               or [] if no pileup exists
    pon_count: number of control samples

    returns 60 if the combined p-value is below 1e-60, 0 if it is above
    1 - 1e-10, otherwise -log10(p-value) rounded to 3 digits
    """

    # obtain the mismatch numbers and depths of target sequence data for positive and negative strands
    if len(F_target) >= 3:
        vars_target_p, depth_target_p, vars_target_n, depth_target_n = varCountCheck(var, *F_target[:3], False)
    else:
        vars_target_p, depth_target_p, vars_target_n, depth_target_n = 0, 0, 0, 0

    # create [0,0,0,0,0,...,0] arrays for the 4 parameters
    vars_control_p = [0] * pon_count
    vars_control_n = [0] * pon_count
    depth_control_p = [0] * pon_count
    depth_control_n = [0] * pon_count

    # obtain the mismatch numbers and depths (for positive and negative strands) of control sequence data
    for i in range(pon_count):
        sample = F_control[3 * i:3 * i + 3]
        if len(sample) < 3:
            # no pileup columns for this sample: keep the zero counts
            continue
        vars_control_p[i], depth_control_p[i], vars_control_n[i], depth_control_n[i] = varCountCheck(var, *sample, True)

    # estimate the beta-binomial parameters for positive and negative strands
    alpha_p, beta_p = fit_beta_binomial(np.array(depth_control_p), np.array(vars_control_p))
    alpha_n, beta_n = fit_beta_binomial(np.array(depth_control_n), np.array(vars_control_n))

    # evaluate the p-values of target mismatch numbers for positive and negative strands
    pvalue_p = beta_binom_pvalue([alpha_p, beta_p], depth_target_p, vars_target_p)
    pvalue_n = beta_binom_pvalue([alpha_n, beta_n], depth_target_n, vars_target_n)

    # perform Fisher's combination methods for integrating two p-values of positive and negative strands
    EB_pvalue = _fisher_combination([pvalue_p, pvalue_n])
    if EB_pvalue < 1e-60:
        return 60
    if EB_pvalue > 1.0 - 1e-10:
        return 0
    return -round(math.log10(EB_pvalue), 3)


### control count

In [76]:
indel_re = re.compile(r'([+-])([0-9]+)([ACGTNacgtn]+)')  # +23ATTTNNGC or -34TTCCAAG
# read start marker ('^' plus one mapping quality character) and read end marker ('$')
sign_re = re.compile(r'\^.|\$')


def varCountCheck(var, depth, reads, rQ, is_verbose):
    '''
    counts variant-supporting reads and depth per strand for one pileup sample

    var:        variant: a single base ('T'), an insertion ('+A') or a deletion ('-A')
    depth:      read depth of the pileup (int or numeric string)
    reads:      base column of the pileup, e.g. 'AAATTCCGGG^]ACGTA$CCT'
    rQ:         quality column of the pileup, e.g. 'IIIIIIIFFCDDD'
    is_verbose: for indel vars only: if True every insertion (or deletion)
                counts as mismatch, otherwise only those matching var
                (insertions by sequence, deletions by length)

    Upper case bases are counted for the positive strand, lower case bases for
    the negative strand. For single base vars, bases whose quality character
    is in the global filter_quals are ignored; for indel vars the qualities
    are not used.

    returns [mismatch_p, depth_p, mismatch_n, depth_n]
    Prints to stderr and exits with status 1 if var has a wrong format or if
    the number of bases and qualities differ.
    '''
    if not var or var[0].upper() not in "+-ATGCN":
        print(f"{var}: input var has wrong format!", file=sys.stderr)
        sys.exit(1)

    # depth is a string when it comes straight from a pileup line
    if int(depth) == 0:
        return [0, 0, 0, 0]

    # delete the start and end signs
    reads = sign_re.sub('', reads)

    # init
    ins_p = ins_n = del_p = del_n = 0
    ins_vb_p = ins_vb_n = del_vb_p = del_vb_n = 0

    #######################       INDELS   #########################################
    # count the indels and cut them out of the read string
    kept = []
    cursor = 0
    for m in indel_re.finditer(reads):  # match object generator
        indel_type = m.group(1)
        indel_size = int(m.group(2))
        # group 3 is greedy and may run into the following bases,
        # only the first indel_size letters belong to the indel
        var_char = m.group(3)[:indel_size]

        # deletions are compared by size only, insertions by sequence
        if var[0] == "-":
            var_match = len(var) - 1 == len(var_char)
        else:
            var_match = var[1:].upper() == var_char.upper()

        is_plus_strand = var_char.isupper()
        if indel_type == "+":  # ins
            if is_plus_strand:
                ins_vb_p += 1
                ins_p += var_match
            else:
                ins_vb_n += 1
                ins_n += var_match
        else:  # del
            if is_plus_strand:
                del_vb_p += 1
                del_p += var_match
            else:
                del_vb_n += 1
                del_n += var_match

        # keep everything in front of the indel, continue right behind its last letter
        kept.append(reads[cursor:m.start()])
        cursor = m.start(3) + indel_size
    kept.append(reads[cursor:])
    reads = ''.join(kept)
    #############################################################################

    # only after removing signs and indels there is one quality per base
    if len(reads) != len(rQ):
        print(f"{reads}\n{rQ}", file=sys.stderr)
        print("lengths of bases and qualities are different!", file=sys.stderr)
        sys.exit(1)

    # `in "ACGT"` alone is a substring test and would accept e.g. 'AC'
    is_snv = len(var) == 1 and var.upper() in "ACGT"

    base_count = dict.fromkeys("ACGTNacgtn", 0)
    # count all the bases in the read and allocate to base_count dict
    if is_snv:
        for base, qual in zip(reads, rQ):
            if qual not in filter_quals and base in base_count:
                base_count[base] += 1
    # for indel check we ignore base qualities
    else:
        for base in reads:
            if base in base_count:
                base_count[base] += 1

    # sum up the forward and reverse bases
    depth_p = sum(base_count[base] for base in "ACGTN")
    depth_n = sum(base_count[base] for base in "acgtn")

    mismatch_p = 0
    mismatch_n = 0

    if is_snv:
        mismatch_p = base_count[var.upper()]
        mismatch_n = base_count[var.lower()]
    elif var[0] == "+":
        if is_verbose:
            mismatch_p, mismatch_n = ins_vb_p, ins_vb_n
        else:
            mismatch_p, mismatch_n = ins_p, ins_n
    elif var[0] == "-":
        if is_verbose:
            mismatch_p, mismatch_n = del_vb_p, del_vb_n
        else:
            mismatch_p, mismatch_n = del_p, del_n

    return [mismatch_p, depth_p, mismatch_n, depth_n]


## Single threading

In [77]:
threads = 1
main(args)
!ls output

['36', 'C$gggGGggcGgGcggCcGgggGgGgcGCgggcgggg', 'qIIIqqIIIIIIIIIIIIIIIIIIIIIIIIIIIIII']
['19', 'cccTcCCCcCCCTCTCTCT', 'IIIqIIIIIIIIIIIIIII']
['29', 'cccaacccacacaaccacccCCCCcCCCC', 'IIIIIIIIIIIIIIIIIIIIIIIIIIIII']
['41', 'TTCCCCCCCTCTCCCccCCCCCtCccCCCTCcTcCCCCcTC', 'IqIIIIIqqIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII']
['49', 'G$GgGccCGgCCcgGGGGggCGcgGGGgggcgcCgggcgggGgGGcgCcC', 'IqIIIIqIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII']
['29', 'GGGGGTTGGGTGGTTGTGGTGGGTGGTGG', 'IIIIIIIIqIIIIIIIIIIIIIIIIIIII']
['40', 'C$aACccccACAAcAccacacAaccaccaaccaccccacac', 'IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII']
['23', 'GGAGAAGAGAGAGGAGGGGAGAA', 'IIIIIIIIIIIIIIIIIIIIIII']
['41', 'C$CgGgGGGCCCCgCCcGCccgccCCgGgCgCCGggccccCc', 'qIIqIqIqqIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII']
['49', 'AGaAgAAAggaAAAaaaaAAgGAAGAaaaAaAagAgaaaaGgGgGaAG^]A', 'IIIIIqIqIIIqIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII']
logs                       output.anno.control.pileup
output.anno                output.anno.target.pileup


## Multithreading

In [65]:
def partition_anno(anno_path, out_path, threads):
    '''
    splits the anno file into consecutive chunks "<out_path>.0", "<out_path>.1", ...

    anno_path: file to split
    out_path:  prefix of the partition files
    threads:   requested number of partitions

    The number of partitions is limited to the number of lines; at least one
    (possibly empty) partition is always written.

    returns the number of partitions actually written
    '''
    with open(anno_path, 'r') as file_in:
        # get line number
        record_num = sum(1 for _ in file_in)
        file_in.seek(0, 0)
        # never more partitions than lines; max(1, ...) avoids a division by
        # zero for an empty file or a non-positive thread count
        threads = max(1, min(record_num, threads))
        # get lines per subprocess
        frac_lines = record_num / threads

        current_sub = current_line = 0
        file_out = open(f"{out_path}.{current_sub}", 'w')
        try:
            for line in file_in:
                print(line.rstrip("\n"), file=file_out)
                current_line += 1
                # the last partition takes all remaining lines
                if (current_line >= frac_lines) and (current_sub < threads - 1):
                    current_sub += 1
                    current_line = 0
                    file_out.close()
                    file_out = open(f"{out_path}.{current_sub}", 'w')
        finally:
            file_out.close()

    return threads


def merge_anno(out_path, threads):
    '''
    concatenates "<out_path>.0" ... "<out_path>.<threads-1>" into out_path

    out_path: merged output file, also the prefix of the partial files
    threads:  number of partial files

    Every line is written with exactly one trailing newline.
    '''
    # context managers make sure the merged file is flushed and closed
    # before the caller removes the partial files
    with open(out_path, 'w') as file_out:
        for i in range(threads):
            with open(f"{out_path}.{i}", 'r') as file_in:
                for line in file_in:
                    print(line.rstrip('\n'), file=file_out)

In [37]:
threads = 3
debug_mode = True
main(args)

In [112]:
ls output/

logs         output.anno


In [105]:
rmpileup = pd.read_csv('output/tumor.mpileup', sep='\t', header=None, names=['Chr', 'Start', 'ref', 'depth', 'reads', 'mapQ'], dtype={'Start':int, 'reads':str, 'mapQ': str})

In [168]:
reads = '+AGC^]TT^CC$GAN'
start_re = re.compile(r'\^\]|\$')
start_re.sub('', reads)

'+AGCTT^CCGAN'

In [160]:
len(test)


TypeError: object of type 'callable_iterator' has no len()

In [112]:
a = 'asdfgasdgf'
b = 'asdfsadfgd'
for a in zip(a,b):
    print(a[0], '\t', a[1])

a 	 a
s 	 s
d 	 d
f 	 f
g 	 s
a 	 a
s 	 d
d 	 f
g 	 g
f 	 d


In [96]:
for m in match:
    m.group(1)

In [120]:
filter_quals = ''
for qual in range( 33, 33 + 40 ):
    filter_quals += chr(qual)

In [165]:
'+'.upper()

'+'