In [520]:
# pair_haps - mate pairs of haplotypes that overlap in their softclipped en
# to optimize:
# - no need to prefetch before and after parts
# - pairing loop: right hap list traversal area can be limited
# - find_best_match() can be optimized for the thresholds that are later enforced

In [521]:
import os
import sys

# args
print(sys.argv)
if "pair_haps" in sys.argv[0] or "stdin" in sys.argv[0]:
    if len(sys.argv) != 5:
        print("usage: " + sys.argv[0] + " <input-cram> <in-bed> <ref-fasta> <output-prefix>\n")
        sys.exit(-1)
    # commandline invocation
    IN_CRAM = sys.argv[1]
    IN_BED = sys.argv[2]
    REF_FASTA = sys.argv[3]
    OUT_PREFIX = sys.argv[4]
else:
    IN_CRAM = os.path.expanduser("~/tmp/data/pair_haps/0030945-Z0114_merged_assembly.bam")
    IN_BED = os.path.expanduser("~/tmp/data/pair_haps/chr4_ex1.bed")
    REF_FASTA = os.path.expanduser("~/tmp/ref/Homo_sapiens_assembly38.fasta")
    OUT_PREFIX = "/tmp/pair_haps." + str(os.getpid())



['/Users/drorkessler/miniconda3/lib/python3.10/site-packages/ipykernel_launcher.py', '-f', '/Users/drorkessler/Library/Jupyter/runtime/kernel-07b28a21-3ea8-47ad-a6ab-d6136c112a3b.json']


In [522]:
# collect breakpoints from read
bp_quant = 500
min_softclip = 60
def quant(n):
    return round((round(n / bp_quant) + 0.5) * bp_quant)
    
def collect_breakpoints(read):

    lbp = []
    rbp = []
    
    if read.cigartuples[-1][0] == pysam.CSOFT_CLIP and read.cigartuples[-1][1] >= min_softclip:
        lbp.append((quant(read.reference_end), read))
    if read.cigartuples[0][0] == pysam.CSOFT_CLIP and read.cigartuples[0][1] >= min_softclip:
        rbp.append((quant(read.reference_start), read))

    return (lbp, rbp)    
        
        

In [523]:
# loop on entries from bed file, process one at a time
import pysam
from itertools import groupby

read_count = 0
lbp = []
rbp = []
header = None
with pysam.AlignmentFile(IN_CRAM, "rb", reference_filename=REF_FASTA) as samf:

    header = samf.header
    
    # loop on bed regions
    with open(IN_BED) as f:
        for line in f:
            bed_line = line.strip().split()
            chrom, start, end = bed_line[:3]
            start = int(start)
            end = int(end)

            # loop on reads
            for read in samf.fetch(chrom, start, end):
                read_count += 1
                if read_count % 1000000 == 0:
                    print("read_count", read_count, read.reference_name + ":" + str(read.reference_start))
                    sys.stdout.flush()
                    
                bp = collect_breakpoints(read)
                lbp += bp[0]
                rbp += bp[1]

# sort, groupby
lbp.sort(key=lambda x:x[0])
lbp = [(key, [x[1] for x in group]) for key, group in groupby(lbp, lambda x: x[0])]
rbp.sort(key=lambda x:x[0])
rbp = [(key, [x[1] for x in group]) for key, group in groupby(rbp, lambda x: x[0])]
print("len(lbp)", len(lbp))
print("len(rbp)", len(rbp))

len(lbp) 11479
len(rbp) 12424


In [530]:
from difflib import SequenceMatcher

# next we look for lr pairs which are 1K-10K in distance
min_d = 1000
max_d = 10000
max_match_jump = 20
min_match_size = 20
portion_min = 0.5

#h1_debug = ["HC_chr4:64571_1001", "HC_chr4:64571_1015"]
h1_debug = ["HC_chr4:41229_1002", "HC_chr4:64571_1015"]
h2_debug = ["HC_chr4:67904_1008", "HC_chr4:67904_1006"]
on_debug = False
find_longest_match_count = 0

def haps_are_paired(h1, h2):

    after1 = h1[4]
    after2 = h2[4]
    result = False
    jumps = None
    global find_longest_match_count

    a, b, size = SequenceMatcher(None, after1, after2).find_longest_match()
    find_longest_match_count += 1
    if on_debug and ((h1[0].qname in h1_debug and h2[0].qname in h2_debug)):
        print("======")
        print("a", a, "b", b, "size", size)
    if size >= min_match_size and a <= max_match_jump and b <= max_match_jump:
        p1 = size / len(after1)
        p2 = size / len(after2)
        if max(p1, p2) >= portion_min:
            result = True;
            jumps = (a, b, size, p1, p2, h1[0].cigarstring, h1[0].reference_start, h1[0].reference_end, len(h1[0].seq), h2[0].cigarstring, h2[0].reference_start, h2[0].reference_end, len(h2[0].seq))
    
    # debug
    if on_debug and ((h1[0].qname in h1_debug and h2[0].qname in h2_debug)):
        print("-----")
        print("h1")
        for h in h1[1:]:
            print("-", h)
        print("h2")
        for h in h2[1:]:
            print("-", h)
        print("result", result, "jumps", jumps)

    return (result, jumps)

def match_haps(lhaps, rhaps):
    h1s = []
    h2s = []
    matched = []
    # extract sequenecs
    for hap in lhaps:
        clip = hap.cigartuples[-1][1]
        before = hap.seq[:-clip]
        after = hap.seq[-clip:]
        h1s.append((hap, hap.qname, hap.cigartuples, before, after))
    for hap in rhaps:
        clip = hap.cigartuples[0][1]
        before = hap.seq[:clip]
        after = hap.seq[clip:]
        h2s.append((hap, hap.qname, hap.cigartuples, before, after))

    # look for matches
    for h1 in h1s:
        for h2 in h2s:
            result, jump = haps_are_paired(h1, h2)
            if result:
                matched.append((h1[0], h2[0], h1[0].qname, h2[0].qname, jump))

    return matched

# this is naive
pairs = []
count = 0
for l,lhaps in lbp:
    count += 1
    if count % 1000 == 0:
        print(l)
    for r,rhaps in rbp:
        if r-l >= min_d:
            if r-l <= max_d:
                pairs += match_haps(lhaps, rhaps)
            else:
                break
    

print("len(pairs)", len(pairs))
#for pair in pairs:
#    print(pair)


882250
1696750
2556250
3359250
4234250
5098750
5966250
6816750
7662250
8500750
9576750
len(pairs) 22


In [531]:
# group by left and right haps
lgrp = [(x,list(g)) for x,g in groupby(sorted(pairs, key=lambda x: x[2]), lambda x: x[2])]
rgrp = [(x,list(g)) for x,g in groupby(sorted(pairs, key=lambda x: x[3]), lambda x: x[3])]
if len(lgrp): lgrp[0]

In [532]:
# realignment code
def realign(grp, left):
    debug = grp[0] in h1_debug or grp[0] in h2_debug

    # sort pairs on lowest jump, get first
    pair = sorted(grp[1], key=lambda x: max(x[4][0], x[4][1]))[0]
    jump = pair[4]
    if debug:
        print("left", left)
        print("pair", pair[2:4])
        print("jump", jump)

    # create a copy, access other hap
    if left:
        d = pair[0].to_dict()
        hap1 = pysam.AlignedSegment.from_dict(d, header)
        hap2 = pair[1]
    else:
        hap1 = pair[0]
        d = pair[1].to_dict()
        hap2 = pysam.AlignedSegment.from_dict(d, header)

    # modify alignment according to jump
    if left:
        del_delta = hap2.reference_start - hap1.reference_end
        ins_delta = jump[0]
        if debug:
            print("del_delta", del_delta, "ins_delta", ins_delta)
        clip = hap1.cigartuples[-1][1]
        hap1.cigartuples = hap1.cigartuples[:-1] # remove softclip
        if ins_delta:
            hap1.cigartuples += [(pysam.CDIFF, ins_delta)] # add X element
        hap1.cigartuples += [(pysam.CDEL, del_delta - ins_delta)] # add D element
        hap1.cigartuples += [(pysam.CMATCH, clip - ins_delta)] # add M element for what was clipped
        if debug:
            print("new cigar: ", hap1.cigarstring)
        
        return hap1
    else:
        del_delta = hap2.reference_start - hap1.reference_end - 1
        ins_delta = jump[0]
        #ins_delta = 0
        if debug:
            print("del_delta", del_delta, "ins_delta", ins_delta)
        clip = hap2.cigartuples[0][1]
        hap2.cigartuples = hap2.cigartuples[1:] # remove softclip
        hap2.cigartuples = [(pysam.CDEL, del_delta - ins_delta)] + hap2.cigartuples # add D element
        if ins_delta:
            hap2.cigartuples = [(pysam.CDIFF, ins_delta)] + hap2.cigartuples # add X element
        hap2.cigartuples = [(pysam.CMATCH, clip - ins_delta)] + hap2.cigartuples # add M element for what was clipped
        hap2.reference_start -= (del_delta + clip - ins_delta)
        if debug:
            print("new cigar: ", hap2.cigarstring)

        return hap2


#realign(rgrp[0], False)

In [533]:
# create output file
sam_fname = OUT_PREFIX + ".sam"
print("sam_fname", sam_fname)
with pysam.AlignmentFile(sam_fname, "w", header=header) as sam_file:

    for grp in lgrp:
        read = realign(grp, True)
        sam_file.write(read)

    for grp in rgrp:
        read = realign(grp, False)
        sam_file.write(read)

# convert to sorted bam
bam_fname = OUT_PREFIX + ".bam"
cmd = "samtools sort " + sam_fname + " >" + bam_fname
print("cmd", cmd)
os.system(cmd)
cmd = "samtools index " + bam_fname
print("cmd", cmd)
os.system(cmd)

print("-----")
os.system("samtools view -c " + bam_fname)
None

sam_fname /tmp/pair_haps.54108.sam
left True
pair ('HC_chr4:64571_1015', 'HC_chr4:67904_1006')
jump (15, 0, 175, 0.6782945736434108, 1.0, '112M1D38M258S', 64470, 64621, 408, '74S175M', 67995, 68170, 249)
del_delta 3374 ins_delta 15
new cigar:  112M1D38M15X3359D243M
left False
pair ('HC_chr4:64571_1014', 'HC_chr4:67904_1006')
jump (14, 0, 72, 0.5950413223140496, 0.4114285714285714, '151M121S', 64470, 64621, 272, '74S175M', 67995, 68170, 249)
del_delta 3373 ins_delta 14
new cigar:  60M14X3359D175M
left False
pair ('HC_chr4:64571_1014', 'HC_chr4:67904_1008')
jump (14, 0, 72, 0.5950413223140496, 0.4114285714285714, '151M121S', 64470, 64621, 272, '83S175M', 67995, 68170, 258)
del_delta 3373 ins_delta 14
new cigar:  69M14X3359D175M
cmd samtools sort /tmp/pair_haps.54108.sam >/tmp/pair_haps.54108.bam
cmd samtools index /tmp/pair_haps.54108.bam
-----
26


In [534]:
find_longest_match_count

595578