Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify haploid Viterbi/FB to handle NONCOPY state in reference panel #31

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions lshmm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
path_ll_hap,
)

MISSING = -1
NONCOPY = -2

EQUAL_BOTH_HOM = 4
UNEQUAL_BOTH_HOM = 0
BOTH_HET = 7
Expand All @@ -27,18 +30,27 @@

def check_alleles(alleles, m):
"""
Checks the specified allele list and returns a list of lists
of alleles of length num_sites.
If alleles is a 1D list of strings, assume that this list is used
for each site and return num_sites copies of this list.
Otherwise, raise a ValueError if alleles is not a list of length
num_sites.
Checks the specified allele list and returns the number of distinct alleles at each of the m sites.

If alleles is a 1D list of strings, assume that this list is used for each site
and return its length for each of the m sites. Otherwise, raise a ValueError
if alleles is not a list of length m.

Note that MISSING and NONCOPY values are excluded from the counts when alleles are given per site.

:param list alleles: A list of lists of alleles or strings.
:param int m: Number of sites.
:return: An array of the number of distinct alleles at each site.
:rtype: numpy.ndarray
"""
if isinstance(alleles[0], str):
return np.int8([len(alleles) for _ in range(m)])
if len(alleles) != m:
raise ValueError("Malformed alleles list")
n_alleles = np.int8([(len(alleles_site)) for alleles_site in alleles])
raise ValueError("Number of alleles list is not equal to number of sites.")
exclusion_set = np.array([MISSING, NONCOPY])
n_alleles = np.zeros(m, dtype=np.int8)
for i in range(m):
n_alleles[i] = np.sum(~np.isin(np.unique(alleles[i]), exclusion_set))
return n_alleles


Expand Down Expand Up @@ -132,12 +144,11 @@ def set_emission_probabilities(
# Check alleles should go in here, and modify e before passing to the algorithm
# If alleles is not passed, we don't perform a test of alleles, but set n_alleles based on the reference_panel and the query.
if alleles is None:
n_alleles = np.int8(
[
len(np.unique(np.append(reference_panel[j, :], query[:, j])))
for j in range(reference_panel.shape[0])
]
)
exclusion_set = np.array([MISSING, NONCOPY])
n_alleles = np.zeros(m, dtype=np.int8)
for j in range(reference_panel.shape[0]):
uniq_alleles = np.unique(np.append(reference_panel[j, :], query[:, j]))
n_alleles[j] = np.sum(~np.isin(uniq_alleles, exclusion_set))
else:
n_alleles = check_alleles(alleles, m)

Expand Down
42 changes: 21 additions & 21 deletions lshmm/forward_backward/fb_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from lshmm import jit

MISSING = -1
NONCOPY = -2


@jit.numba_njit
Expand All @@ -17,9 +18,10 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):

c = np.zeros(m)
for i in range(n):
F[0, i] = (
1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
)
em_prob = 0
if H[0, i] != NONCOPY:
em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
F[0, i] = 1 / n * em_prob
c[0] += F[0, i]

for i in range(n):
Expand All @@ -29,9 +31,10 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
for l in range(1, m):
for i in range(n):
F[l, i] = F[l - 1, i] * (1 - r[l]) + r_n[l]
F[l, i] *= e[
l, np.int64(np.equal(H[l, i], s[0, l]) or s[0, l] == MISSING)
]
em_prob = 0
if H[l, i] != NONCOPY:
em_prob = e[l, np.int64(np.equal(H[l, i], s[0, l]) or s[0, l] == MISSING)]
F[l, i] *= em_prob
c[l] += F[l, i]

for i in range(n):
Expand All @@ -44,17 +47,19 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
c = np.ones(m)

for i in range(n):
F[0, i] = (
1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
)
em_prob = 0
if H[0, i] != NONCOPY:
em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
F[0, i] = 1 / n * em_prob

# Forwards pass
for l in range(1, m):
for i in range(n):
F[l, i] = F[l - 1, i] * (1 - r[l]) + np.sum(F[l - 1, :]) * r_n[l]
F[l, i] *= e[
l, np.int64(np.equal(H[l, i], s[0, l]) or s[0, l] == MISSING)
]
em_prob = 0
if H[l, i] != NONCOPY:
em_prob = e[l, np.int64(np.equal(H[l, i], s[0, l]) or s[0, l] == MISSING)]
F[l, i] *= em_prob

ll = np.log10(np.sum(F[m - 1, :]))

Expand All @@ -75,15 +80,10 @@ def backwards_ls_hap(n, m, H, s, e, c, r):
tmp_B = np.zeros(n)
tmp_B_sum = 0
for i in range(n):
tmp_B[i] = (
e[
l + 1,
np.int64(
np.equal(H[l + 1, i], s[0, l + 1]) or s[0, l + 1] == MISSING
),
]
* B[l + 1, i]
)
em_prob = 0
if H[l + 1, i] != NONCOPY:
em_prob = e[l + 1, np.int64(np.equal(H[l + 1, i], s[0, l + 1]) or s[0, l + 1] == MISSING)]
tmp_B[i] = em_prob * B[l + 1, i]
tmp_B_sum += tmp_B[i]
for i in range(n):
B[l, i] = r_n[l + 1] * tmp_B_sum
Expand Down
70 changes: 45 additions & 25 deletions lshmm/vit_haploid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from . import jit

MISSING = -1
NONCOPY = -2


@jit.numba_njit
Expand All @@ -13,10 +14,10 @@ def viterbi_naive_init(n, m, H, s, e, r):
P = np.zeros((m, n)).astype(np.int64)
r_n = r / n
for i in range(n):
V[0, i] = (
1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
)

em_prob = 0
if H[0, i] != NONCOPY:
em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
V[0, i] = 1 / n * em_prob
return V, P, r_n


Expand All @@ -29,9 +30,10 @@ def viterbi_init(n, m, H, s, e, r):
r_n = r / n

for i in range(n):
V_previous[i] = (
1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
)
em_prob = 0
if H[0, i] != NONCOPY:
em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
V_previous[i] = 1 / n * em_prob

return V, V_previous, P, r_n

Expand All @@ -47,10 +49,10 @@ def forwards_viterbi_hap_naive(n, m, H, s, e, r):
# Get the vector to maximise over
v = np.zeros(n)
for k in range(n):
v[k] = (
e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
* V[j - 1, k]
)
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
v[k] = em_prob * V[j - 1, k]
if k == i:
v[k] *= 1 - r[j] + r_n[j]
else:
Expand All @@ -74,7 +76,10 @@ def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r):
for i in range(n):
v = np.copy(v_tmp)
v[i] += V[j - 1, i] * (1 - r[j])
v *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
v *= em_prob
P[j, i] = np.argmax(v)
V[j, i] = v[P[j, i]]

Expand All @@ -94,10 +99,10 @@ def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r):
# Get the vector to maximise over
v = np.zeros(n)
for k in range(n):
v[k] = (
e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
* V_previous[k]
)
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
v[k] = (em_prob * V_previous[k])
if k == i:
v[k] *= 1 - r[j] + r_n[j]
else:
Expand Down Expand Up @@ -125,10 +130,10 @@ def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r):
# Get the vector to maximise over
v = np.zeros(n)
for k in range(n):
v[k] = (
e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
* V_previous[k]
)
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
v[k] = em_prob * V_previous[k]
if k == i:
v[k] *= 1 - r[j] + r_n[j]
else:
Expand Down Expand Up @@ -161,7 +166,10 @@ def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r):
if V[i] < r_n[j]:
V[i] = r_n[j]
P[j, i] = argmax
V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
V[i] *= em_prob
V_previous = np.copy(V)

ll = np.sum(np.log10(c)) + np.log10(np.max(V))
Expand All @@ -175,7 +183,10 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r):
# Initialise
V = np.zeros(n)
for i in range(n):
V[i] = 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
em_prob = 0
if H[0, i] != NONCOPY:
em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
V[i] = 1 / n * em_prob
P = np.zeros((m, n)).astype(np.int64)
r_n = r / n
c = np.ones(m)
Expand All @@ -190,7 +201,10 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r):
if V[i] < r_n[j]:
V[i] = r_n[j]
P[j, i] = argmax
V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
V[i] *= em_prob

ll = np.sum(np.log10(c)) + np.log10(np.max(V))

Expand All @@ -203,7 +217,10 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r):
# Initialise
V = np.zeros(n)
for i in range(n):
V[i] = 1 / n * e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
em_prob = 0
if H[0, i] != NONCOPY:
em_prob = e[0, np.int64(np.equal(H[0, i], s[0, 0]) or s[0, 0] == MISSING)]
V[i] = 1 / n * em_prob
r_n = r / n
c = np.ones(m)
recombs = [
Expand All @@ -224,7 +241,10 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r):
recombs[j] = np.append(
recombs[j], i
) # We add template i as a potential template to recombine to at site j.
V[i] *= e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
em_prob = 0
if H[j, i] != NONCOPY:
em_prob = e[j, np.int64(np.equal(H[j, i], s[0, j]) or s[0, j] == MISSING)]
V[i] *= em_prob

V_argmaxes[m - 1] = np.argmax(V)
ll = np.sum(np.log10(c)) + np.log10(np.max(V))
Expand Down
Loading
Loading