# This notebook contains identical code to SAILOR's read filter. 
- Doing this because there are some reads in our filtered set that don't seem to also be in the 'all' set??? Let's figure out why.

In [1]:
import pandas as pd
import numpy as np
import pysam 
import os
import re
import glob
from collections import defaultdict

In [2]:
input_dir = '/home/bay001/projects/kris_apobec_20200121/permanent_data2/07_scRNA_groups/sailor_outputs_groups/original_outputs/RPS2-STAMP_possorted_genome_bam_MD/results'
output_dir = '/oasis/tscc/scratch/bay001/sailor/'

input_bam = os.path.join(input_dir, 'RPS2-STAMP_possorted_genome_bam_MD.rev.sorted.rmdup.bam')
output_bam = os.path.join(output_dir, os.path.basename(input_bam) + ".filtered.bam")

assert os.path.exists(input_bam)

In [3]:
def get_softclip(cigar):
    """
    Returns the number of bases to be softclipped on either left or right
    side. Or both. If not softclipped, return 0

    :param cigar: string
        BAM/SAM CIGAR string
    :return left: int
        number of softclipped reads at the beginning
    :return right: int
        number of softclipped reads at the end
    """
    softclip_regex = ur"(\d+)S"
    softclip = re.findall(softclip_regex,cigar)

    softclip_right_regex = ur"[\w\d]{1}(\d+)S" # if the softclip comes from the right side
    softclip_right = re.findall(softclip_right_regex,cigar)

    left = 0
    right = 0

    if softclip:
        if len(softclip) == 2: # softclipped on both sides
            left = int(softclip[0])
            right = int(softclip[1])
        elif len(softclip_right) == 1: # softclipped only on the RIGHT side
            right = int(softclip[0])
        else: # softclipped only on the LEFT side
            left = int(softclip[0])
    return left, right


def remove_softclipped_reads(left, right, read_seq):
    """
    Returns the read after removing softclipped bases.
    :param left: int
        left softclip
    :param right: int
        right softclip
    :param read_seq: string
        read sequence
    :return softclipped_read_sequence: string


    """
    if right == 0:
        return read_seq[left:]
    return read_seq[left:-right]


def get_junction_overhangs(cigar, min_jct_overhang):
    """
    Returns the MIN number of reads left/right of ANY junction
    as indicated by the N in a cigar string.
    Return -1, -1 for reads that don't span junctions.

    :param cigar: string
    :param min_jct_overhang: int
    :return:
    """
    if cigar.count('N') == 1:
        return get_single_junction_overhang(cigar)

    min_left, min_right = get_single_junction_overhang(cigar)
    for i in [m.start() for m in re.finditer(r"N",cigar)]:
        sub_cigar = cigar[i+1:]
        left, right = get_single_junction_overhang(sub_cigar)
        if left != -1 and right != -1:
            min_left = min(min_left, left)
            min_right = min(min_right, right)
    return min_left, min_right


def get_single_junction_overhang(cigar):
    """
    Returns the number of reads left/right of a junction as indicated
    by the LEFTMOST N in a cigar string. Return -1, -1 for reads that don't span
    junctions.

    :param cigar: string
    :return left: int
    :return right: int
    """
    cigar_overhang_regex = ur"(\d+)M[\d]+N(\d+)M"

    overhangs = re.findall(cigar_overhang_regex, cigar)
    if overhangs:
        return int(overhangs[0][0]), int(overhangs[0][1])
    else:
        return -1, -1


def is_mismatch_before_n_flank_of_read(md, n):
    """
    Returns True if there is a mismatch before the first n nucleotides
    of a read, or if there is a mismatch before the last n nucleotides
    of a read.

    :param md: string
    :param n: int
    :return is_mismatch: boolean
    """
    is_mismatch = False
    flank_mm_regex = ur"^(\d+).*[ACGT](\d+)$"
    flank_mm = re.findall(flank_mm_regex,md)
    if flank_mm:
        flank_mm = flank_mm[0]
        if flank_mm[1]:
            if int(flank_mm[1]) < n:
                is_mismatch = True
        if flank_mm[0]:
            if int(flank_mm[0]) < n:
                is_mismatch = True
    return is_mismatch


def non_ag_mismatches(read_seq, md, sense):
    """
    Given a read sequence, MD tag, and 'sense' (look for AG if sense,
    look for TC if antisense), return the number of non-AG/TC mismatches
    seen in the read.

    :param read_seq: string
    :param md: string
    :param sense: boolean
    :return nonAG: int
    """
    mismatches_regex = ur"(\d+)([ATCG])"
    mismatches = re.findall(mismatches_regex,md)
    non_ag_mm_counts = 0
    if mismatches:
        read_pos = 0
        for mismatch in mismatches:
            ref_allele = mismatch[1]
            read_pos += int(mismatch[0])

            read_allele = read_seq[read_pos]
            if(not((ref_allele == 'A' and read_allele == 'G' and sense == True) or
                   (ref_allele == 'T' and read_allele == 'C' and sense == False))):
                non_ag_mm_counts += 1
            read_pos += 1

    return non_ag_mm_counts


def non_ct_mismatches(read_seq, md, sense):
    """
    Given a read sequence, MD tag, and 'sense' (look for CT if sense,
    look for GA if antisense), return the number of non-CT/GA mismatches
    seen in the read.

    :param read_seq: string
    :param md: string
    :param sense: boolean
    :return nonAG: int
    """
    mismatches_regex = ur"(\d+)([ATCG])"
    mismatches = re.findall(mismatches_regex,md)
    non_ct_mm_counts = 0
    if mismatches:
        read_pos = 0
        for mismatch in mismatches:
            ref_allele = mismatch[1]
            read_pos += int(mismatch[0])

            read_allele = read_seq[read_pos]
            if(not((ref_allele == 'C' and read_allele == 'T' and sense == True) or
                   (ref_allele == 'G' and read_allele == 'A' and sense == False))):
                non_ct_mm_counts += 1
            read_pos += 1

    return non_ct_mm_counts


def non_gt_mismatches(read_seq, md, sense):
    """
    Given a read sequence, MD tag, and 'sense' (look for GT if sense,
    look for CA if antisense), return the number of non-GT/CA mismatches
    seen in the read.

    :param read_seq: string
    :param md: string
    :param sense: boolean
    :return nonAG: int
    """
    mismatches_regex = ur"(\d+)([ATCG])"
    mismatches = re.findall(mismatches_regex,md)
    non_gt_mm_counts = 0
    if mismatches:
        read_pos = 0
        for mismatch in mismatches:
            ref_allele = mismatch[1]
            read_pos += int(mismatch[0])

            read_allele = read_seq[read_pos]
            if(not((ref_allele == 'G' and read_allele == 'T' and sense == True) or
                   (ref_allele == 'C' and read_allele == 'A' and sense == False))):
                non_gt_mm_counts += 1
            read_pos += 1

    return non_gt_mm_counts


def filter_reads(
        input_bam, output_bam, min_overhang, min_underhang,
        non_ag_mm_threshold, reverse_stranded=True, ct=False, gt=False
):
    """
    # Step 3: filter reads

    Per Boyko/Mike Washburn's editing paper:

    b) it had a junction overhang < 10nt according to its SAMtools CIGAR string
    c) it had > 1 non-A2G or non-C2T mismatch or any short indel, per its MD tag
    d) it had a mismatch less than 25nt away from either end of the read
    (this was changed to 5nt in the relaxed version used for quantification)

    * Removed 5d per discussions with the Hundley lab.

    :param input_bam: basestring
    :param output_bam: basestring
    :param min_overhang: int
        minimum number required to span a junction (filter b)
    :param min_underhang: int
        minimum number required at the end of the read (filter c)
    :param non_ag_mm_threshold: int
        any more non A-G mismatches than this will filter the read.
    :return:
    """

    print("Filtering reads on: {}".format(input_bam))
    i = pysam.AlignmentFile(input_bam)
    o = pysam.AlignmentFile(output_bam, "wb", template=i)
    flags = defaultdict(list) # number of flags in the bam file
    warn_mm = False
    warn_junction = False
    warn_x = False

    counter = 0
    for read in i:
        if counter % 1000000 == 0:
            print("Parsed {} reads".format(counter))
        try:
            flag = 1  # start out as a 'good' read
            cigar = read.cigarstring
            if 'X' in cigar or '=' in cigar:
                warn_x = True
            read_seq = read.query_sequence
            read_name = read.query_name
            """
            Throw out unmapped reads
            """
            if read.is_unmapped:
                flags['unmapped'].append(read.name)
                continue  # must be here, otherwise reads won't have CIGAR
            try:
                mm = read.get_tag('MD')
            except KeyError:
                warn_mm = True
                mm = ''
            """
            Remove
            """
            """
            Takes care of soft clipped bases
            (remove bases from the read_seq which are soft clipped
            to not interfere with mis-alignments downstream)
            """
            left_softclip, right_softclip = get_softclip(cigar)
            read_seq = remove_softclipped_reads(
                left_softclip, right_softclip, read_seq
            )

            """
            # 5b) Check for small junction overhangs.
            If match junction over hang is small, (if mismatches occur close to
            junctions), remove read
            """

            left_overhang, right_overhang = get_junction_overhangs(cigar, min_overhang)
            if left_overhang != -1 and right_overhang != -1:
                if left_overhang < min_overhang or right_overhang < min_overhang:
                    flags['small_overhang'].append(read_name)
                    flag = 3
            elif left_overhang != -1 or right_overhang != -1:
                warn_junction = True


            """
            # 5c) If there exists indels, remove them.
            """
            if 'I' in cigar or 'D' in cigar:
                flags['indel'].append(read_name)
                flag = 4

            """
            # 2) If primary not primary alignment, throw out
            """
    
            if read.is_secondary:
                flags['not_primary'].append(read_name)
                flag = 5

            """
            # MD:Z tag-based filters. If there is a mismatch on either end of the read, throw out.
            """

            if is_mismatch_before_n_flank_of_read(mm, min_underhang):
                flags['small_underhang'].append(read_name)
                flag = 6

            """
            # Manually setting reversed reads to 'sense' strand per truseq
            library protocols (Default is truseq reverse stranded)
            """
            if reverse_stranded:
                sense = True if read.is_reverse == True else False
            else:
                sense = True if read.is_reverse == False else False

            """
            # 5c) Search MDZ for A, T's in reference, if mutations are not 
            #     A-G (sense) or T-C (antisense), add to non_ag_mm_counts
            #     threshold. If non_ag_mm_counts > threshold, toss the whole
            #     read. Otherwise, allow up to the maximum allowable non-AG
            #     mismatches before tossing. Default: 1 mm
            """
            if ct:
                if non_ct_mismatches(read_seq, mm, sense) > non_ag_mm_threshold:
                    flags['non_ct_exceeded'].append(read_name)
                    flag = 7
            elif gt:
                if non_gt_mismatches(read_seq, mm, sense) > non_ag_mm_threshold:
                    flags['non_gt_exceeded'].append(read_name)
                    flag = 7
            else:
                if non_ag_mismatches(read_seq, mm, sense) > non_ag_mm_threshold:
                    flags['non_ag_exceeded'].append(read_name)
                    flag = 7

            if flag == 1:
                o.write(read)
                
            if read_name == 'A00953:41:HWMCFDSXX:1:1259:28800:4617':
                print("Flag for {} is {}".format("A00953:41:HWMCFDSXX:1:1259:28800:4617", flag))
        except TypeError as e:
            print("error! {}".format(e))
            
        counter += 1
    i.close()
    o.close()
    # warn if optional MD tag missing from any read
    if warn_mm:
        print("Warning, reads lack optional MD: tag (could not calculate non-ag mismatches)")
    if warn_junction:
        print("Warning: CIGAR {} has weird junction mark (or the regex is wrong)")
    if warn_x:
        print("Warning: I don't like X or = in CIGAR (use M instead)")

    return flags

In [4]:
flags = filter_reads(
    input_bam, output_bam, min_overhang=10, min_underhang=5,
    non_ag_mm_threshold=1, reverse_stranded=False, ct=True, gt=False
)

Filtering reads on: /home/bay001/projects/kris_apobec_20200121/permanent_data2/07_scRNA_groups/sailor_outputs_groups/original_outputs/RPS2-STAMP_possorted_genome_bam_MD/results/RPS2-STAMP_possorted_genome_bam_MD.rev.sorted.rmdup.bam
Parsed 0 reads
Parsed 1000000 reads
Parsed 2000000 reads
Parsed 3000000 reads
Parsed 4000000 reads
Parsed 5000000 reads
Parsed 6000000 reads
Parsed 7000000 reads
Parsed 8000000 reads
Parsed 9000000 reads
Parsed 10000000 reads
Parsed 11000000 reads
Parsed 12000000 reads
Parsed 13000000 reads
Parsed 14000000 reads
Parsed 15000000 reads
Parsed 16000000 reads
Parsed 17000000 reads
Parsed 18000000 reads
Parsed 19000000 reads
Parsed 20000000 reads
Parsed 21000000 reads
Parsed 22000000 reads
Parsed 23000000 reads
Parsed 24000000 reads
Parsed 25000000 reads
Parsed 26000000 reads
Parsed 27000000 reads
Parsed 28000000 reads
Parsed 29000000 reads
Parsed 30000000 reads
Parsed 31000000 reads
Parsed 32000000 reads
Parsed 33000000 reads
Parsed 34000000 reads
Parsed 350000

In [5]:
flags.keys()

['small_overhang', 'non_ct_exceeded', 'small_underhang', 'indel']

In [6]:
len(flags['small_overhang'])

6961507

In [7]:
len(flags['non_ct_exceeded'])

9137940

In [8]:
len(flags['indel'])

3798831