Switch back to reading pcols, but store reads temporarily
SamStudio8 committed Jan 24, 2017
1 parent 69eddeb commit 6676135
Showing 2 changed files with 84 additions and 56 deletions.
10 changes: 5 additions & 5 deletions gretel/gretel.py
@@ -57,6 +57,7 @@ def reweight_hansel_from_path(hansel, path, ratio):
t_i = i
t_j = j
size += hansel.reweight_observation(path[t_i], path[t_j], t_i, t_j, ratio)
sys.stderr.write("[REWT] Ratio %.3f, Removed %.1f\n" % (ratio, size))
return size

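For context, reweight_hansel_from_path discounts the pairwise evidence along a freshly recovered haplotype so that subsequent searches favour other paths; the added [REWT] line reports the ratio applied and the total weight removed. A minimal sketch of that pattern follows, with a toy stand-in for the Hansel matrix (ToyHansel, add_observation and the all-pairs enumeration are illustrative assumptions; the real pair loop is truncated in the hunk above):

from collections import defaultdict
import sys

class ToyHansel:
    # Toy stand-in for Hansel: weights of pairwise observations
    # (symbol_i, symbol_j, position_i, position_j).
    def __init__(self):
        self.observations = defaultdict(float)

    def add_observation(self, a, b, i, j, weight=1.0):
        self.observations[(a, b, i, j)] += weight

    def reweight_observation(self, a, b, i, j, ratio):
        # Remove a fraction of the evidence for one pairwise observation
        # and return how much weight was taken away.
        removed = self.observations[(a, b, i, j)] * ratio
        self.observations[(a, b, i, j)] -= removed
        return removed

def reweight_from_path(hansel, path, ratio):
    # Visit each ordered pair (i, j), i < j, along the recovered path
    # and discount its evidence, logging the total as above.
    size = 0.0
    for i in range(len(path)):
        for j in range(i + 1, len(path)):
            size += hansel.reweight_observation(path[i], path[j], i, j, ratio)
    sys.stderr.write("[REWT] Ratio %.3f, Removed %.1f\n" % (ratio, size))
    return size

toy = ToyHansel()
toy.add_observation("A", "T", 0, 1, 10.0)
toy.add_observation("A", "G", 0, 2, 4.0)
reweight_from_path(toy, ["A", "T", "G"], 0.5)  # [REWT] Ratio 0.500, Removed 7.0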

@@ -166,14 +167,13 @@ def process_bam(vcf_handler, bam_path, contig_name, start_pos, end_pos, L, use_e
A dictionary of metadata returned from the BAM parsing, such as
a list of the number of variants that each read spans.
"""
bam = pysam.AlignmentFile(bam_path)

#NOTE(samstudio8)
# Could we optimise for lower triangle by collapsing one of the dimensions
# such that Z[m][n][i][j] == Z[m][n][i + ((j-1)*(j))/2]


meta = util.load_from_bam(None, bam, contig_name, start_pos, end_pos, vcf_handler, use_end_sentinels, n_threads)
meta = util.load_from_bam(None, bam_path, contig_name, start_pos, end_pos, vcf_handler, use_end_sentinels, n_threads)
hansel = Hansel(meta["hansel"], ['A', 'C', 'G', 'T', 'N', "_"], ['N', "_"], L=L)

if hansel.L == 0:
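The hunk is truncated here, but the visible change is the point: process_bam now hands util.load_from_bam the BAM path rather than an open pysam.AlignmentFile, so the handle can be opened inside each worker (see the pysam.AlignmentFile(bam_path) line added in util.py below). An open pysam handle wraps live htslib file state and cannot safely be shipped to another process, so passing the path is the usual fix. A minimal sketch, assuming an indexed BAM at a hypothetical path:

import multiprocessing
import pysam

def count_reads(bam_path, contig, start, end):
    # Each worker opens its own AlignmentFile from the path it was given.
    bam = pysam.AlignmentFile(bam_path)
    n = sum(1 for _ in bam.fetch(contig, start, end))
    bam.close()
    return n

if __name__ == "__main__":
    # Hypothetical inputs, for illustration only.
    blocks = [("my.bam", "contig1", 0, 1000), ("my.bam", "contig1", 1000, 2000)]
    with multiprocessing.Pool(2) as pool:
        print(pool.starmap(count_reads, blocks))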
@@ -258,14 +258,14 @@ def generate_path(n_snps, hansel, original_hansel):
# Find path
sys.stderr.write("*** ESTABLISH ***\n")
for snp in range(1, n_snps+1):
sys.stderr.write("\t*** ***\n")
sys.stderr.write("\t[SNP_] SNP %d\n" % snp)
#sys.stderr.write("\t*** ***\n")
#sys.stderr.write("\t[SNP_] SNP %d\n" % snp)

# Get marginal and calculate branch probabilities for each available
# allele, given the current path seen so far
# Select the next branch and append it to the path
curr_branches = hansel.get_edge_weights_at(snp, current_path)
sys.stderr.write("\t[TREE] %s\n" % curr_branches)
#sys.stderr.write("\t[TREE] %s\n" % curr_branches)
# Return the symbol and probability of the next base to add to the
# current path based on the best marginal
next_v = 0.0
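The ESTABLISH loop queries the Hansel matrix for the edge weights out of the current path at each SNP and greedily appends the most probable branch; the hunk cuts off just after next_v is initialised. A minimal sketch of the selection step, assuming get_edge_weights_at returns a mapping from candidate symbol to probability (that dict shape is an assumption, not a documented API):

def select_next_branch(curr_branches):
    # curr_branches: {symbol: probability} for the next SNP,
    # conditioned on the path so far. Pick the best marginal.
    next_v = 0.0
    next_m = None
    for symbol, prob in curr_branches.items():
        if prob > next_v:
            next_v = prob
            next_m = symbol
    return next_m, next_v

symbol, prob = select_next_branch({"A": 0.7, "C": 0.2, "T": 0.1})
assert symbol == "A" and prob == 0.7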
130 changes: 79 additions & 51 deletions gretel/util.py
@@ -24,7 +24,7 @@ def partition_snps(region, n_parts, start_1pos, end_1pos):
#TODO SENTINEL SYMBOLS BEFORE AND AFTER A READ
#TODO What happens if we traverse backwards...?
#TODO Single SNP reads could use a pairwise observation with themselves? (A, A, i, i)
def load_from_bam(h, bam, target_contig, start_pos, end_pos, vcf_handler, use_end_sentinels=False, n_threads=1):
def load_from_bam(h, bam_path, target_contig, start_pos, end_pos, vcf_handler, use_end_sentinels=False, n_threads=1):
"""
Load variants observed in a :py:class:`pysam.AlignmentFile` to
an instance of :py:class:`hansel.hansel.Hansel`.
@@ -34,8 +34,8 @@ def load_from_bam(h, bam, target_contig, start_pos, end_pos, vcf_handler, use_en
hansel : :py:class:`hansel.hansel.Hansel`
An initialised instance of the `Hansel` data structure.
bam : :py:class:`pysam.AlignmentFile`
A BAM alignment.
bam_path : str
Path to the BAM alignment.
target_contig : str
The name of the contig for which to recover haplotypes.
@@ -123,6 +123,7 @@ def __symbol_num(symbol):
slices = 0
covered_snps = 0

bam = pysam.AlignmentFile(bam_path)
while True:
work_block = bam_q.get()
if work_block is None:
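This hunk opens a per-worker pysam.AlignmentFile and enters the work loop: blocks arrive on bam_q, and a None sentinel (handled across this and the next hunk) makes the worker report its running totals and stop. A self-contained sketch of that sentinel-terminated queue pattern (queue_worker and the counter names are hypothetical):

import multiprocessing

def queue_worker(work_q, result_q, worker_i):
    processed = 0
    while True:
        work_block = work_q.get()
        if work_block is None:
            # Sentinel: report totals and stop.
            result_q.put({"worker_i": worker_i, "processed": processed})
            break
        processed += 1

if __name__ == "__main__":
    work_q = multiprocessing.Queue()
    result_q = multiprocessing.Queue()
    workers = [multiprocessing.Process(target=queue_worker, args=(work_q, result_q, i))
               for i in range(2)]
    for w in workers:
        w.start()
    for block in range(10):
        work_q.put(block)
    for _ in workers:
        work_q.put(None)  # one sentinel per worker
    for _ in workers:
        print(result_q.get())
    for w in workers:
        w.join()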
@@ -134,71 +135,98 @@ def __symbol_num(symbol):
})
break

for read in bam.fetch(target_contig, start=work_block["start"]-1, end=work_block["end"], multiple_iterators=True):

START_POS_OFFSET = 0
reads = {}
for p_col in bam.pileup(reference=target_contig, start=work_block["start"]-1, end=work_block["end"]):
if p_col.reference_pos + 1 > end_pos:
# Ignore positions beyond the end_pos
break

if read.is_duplicate or read.is_secondary:
if vcf_handler["region"][p_col.reference_pos+1] != 1:
continue

LEFTMOST_1pos = read.reference_start + 1 # Convert 0-based reference_start to 1-based position (to match region array and 1-based VCF)
RIGHTMOST_1pos = read.reference_end #ofc this is 1-indexed instead of 0

# Special case: Consider reads that begin before the start_pos, but overlap the 0th block
if work_block["i"] == 0:
if LEFTMOST_1pos < start_pos:
# Read starts before the start_pos
if read.reference_start + 1 + read.query_alignment_length < start_pos:
# Read ends before the start_pos
for p_read in p_col.pileups:

curr_read_1or2 = None
if p_read.alignment.is_read1:
curr_read_1or2 = 1
elif p_read.alignment.is_read2:
curr_read_1or2 = 2

curr_read_name = "%s_%d" % (p_read.alignment.query_name, curr_read_1or2)

LEFTMOST_1pos = p_read.alignment.reference_start + 1 # Convert 0-based reference_start to 1-based position (to match region array and 1-based VCF)

# Special case: Consider reads that begin before the start_pos, but overlap the 0th block
if work_block["i"] == 0:
if LEFTMOST_1pos < start_pos:
# Read starts before the start_pos
if p_read.alignment.reference_start + 1 + p_read.alignment.query_alignment_length < start_pos:
# Read ends before the start_pos
continue
LEFTMOST_1pos = start_pos
else:
# This read begins before the start of the current (non-0) block
# and will have already been covered by the block that preceded it
if LEFTMOST_1pos < work_block["start"]:
continue

LEFTMOST_1pos = start_pos
START_POS_OFFSET = (start_pos - (read.reference_start + 1))

else:
# This read begins before the start of the current (non-0) block
# and will have already been covered by the block that preceded it
if LEFTMOST_1pos < work_block["start"]:
if curr_read_name not in reads:
reads[curr_read_name] = {
"rank": np.sum(vcf_handler["region"][1 : LEFTMOST_1pos]),
"seq": [],
"quals": [],
"refs_1pos": [],
"read_variants_0pos": [],
}


## Read ends after the end_pos of interest, so clip it
#if RIGHTMOST_1pos > work_block["region_end"]:
# RIGHTMOST_1pos = work_block["region_end"]

sequence = None
qual = None
if p_read.query_position is None:
# qpos is None for deletion and reference skips
# TODO Not sure about how to estimate quality of deletion?
sequence = "_" * abs(p_read.indel)
qual = p_read.alignment.query_qualities[p_read.query_position_or_next] * abs(p_read.indel)
elif p_read.indel > 0:
sequence = p_read.alignment.query_sequence[p_read.query_position : p_read.query_position + p_read.indel + 1]
qual = p_read.alignment.query_qualities[p_read.query_position : p_read.query_position + p_read.indel + 1]
else:
sequence = p_read.alignment.query_sequence[p_read.query_position]
qual = p_read.alignment.query_qualities[p_read.query_position]

if not sequence:
print("Help!")
continue

# If the current read begins after the region of interest, stop parsing the sorted BAM
#if LEFTMOST_1pos > end_pos:
# break
reads[curr_read_name]["seq"].append(sequence)
reads[curr_read_name]["quals"].append(qual)
reads[curr_read_name]["refs_1pos"].append(p_col.reference_pos+1)
reads[curr_read_name]["read_variants_0pos"].append(p_read.query_position)

# Read ends after the end_pos of interest, so clip it
if RIGHTMOST_1pos > work_block["region_end"]:
RIGHTMOST_1pos = work_block["region_end"]
print("DONE")

# Check if the read actually covers any SNPs
support_len = np.sum(vcf_handler["region"][LEFTMOST_1pos : RIGHTMOST_1pos + 1])

# Ignore reads without evidence
if support_len == 0:
continue
num_reads = len(reads)
for qi, qname in enumerate(reads):
progress_q.put({"pos": num_reads-qi, "worker_i": worker_i})

if not len(reads[qname]["seq"]) > 0:
# Ignore reads without evidence
continue
slices += 1

rank = np.sum(vcf_handler["region"][1 : LEFTMOST_1pos])
support_seq = []
rank = reads[qname]["rank"]
support_len = len(reads[qname]["seq"])

aligned_residues = [x for x in read.get_aligned_pairs(with_seq=True) if x[1] is not None] # Filter out SOFTCLIP and INS
for i in range(0, support_len):
snp_rev = vcf_handler["snp_rev"][rank + i]

snp_pos_on_read = snp_rev - LEFTMOST_1pos + START_POS_OFFSET
snp_pos_on_aligned_read = aligned_residues[snp_pos_on_read][0]

try:
support_seq.append(read.query_sequence[snp_pos_on_aligned_read])
except TypeError:
sys.stderr.write("NoneType (DEL) found at SNP site on read '%s', reference position %d\n" % (read.qname, snp_rev))
support_seq.append("_")

support_seq = "".join(support_seq)
# TODO Still not really sure how to handle indels, our matrix is designed for single symbols... :<
support_seq = "".join([b[0] for b in reads[qname]["seq"]])
covered_snps += len(support_seq.replace("N", "").replace("-", ""))

progress_q.put({"pos": read.reference_start + 1, "worker_i": worker_i})

# For each position in the supporting sequence (that is, each covered SNP)
for i in range(0, support_len):
snp_a = support_seq[i]
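This hunk is the heart of the commit: instead of fetching whole reads and indexing into read.get_aligned_pairs (the removed approach), the worker now walks pileup columns, skips columns not flagged as SNPs in vcf_handler["region"], and buffers one observation per read in a reads dict keyed by query name and mate number; support sequences are then emitted per read once the columns are exhausted. A minimal sketch of that accumulation, assuming an indexed BAM and a set of 1-based SNP positions (collect_read_support and its inputs are illustrative). Note that pysam reports query_position as None for deletions and reference skips, hence the explicit "is None" test (plain truthiness would also trip on a query position of 0):

import pysam

def collect_read_support(bam_path, contig, snp_positions_1based):
    # Walk pileup columns, keep only SNP columns, and buffer each
    # read's symbol per column, keyed like curr_read_name above.
    bam = pysam.AlignmentFile(bam_path)
    reads = {}
    for p_col in bam.pileup(contig):
        ref_pos_1 = p_col.reference_pos + 1
        if ref_pos_1 not in snp_positions_1based:
            continue
        for p_read in p_col.pileups:
            mate = 1 if p_read.alignment.is_read1 else 2
            name = "%s_%d" % (p_read.alignment.query_name, mate)
            entry = reads.setdefault(name, {"refs_1pos": [], "seq": []})
            if p_read.query_position is None:
                # Deletion or reference skip at this column.
                symbol = "_"
            else:
                symbol = p_read.alignment.query_sequence[p_read.query_position]
            entry["refs_1pos"].append(ref_pos_1)
            entry["seq"].append(symbol)
    bam.close()
    # One support string per read slice; indels still collapse to a
    # single symbol, as the TODO in the hunk above notes.
    return {name: "".join(e["seq"]) for name, e in reads.items()}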
