Skip to content

Commit

Permalink
Update np.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dave-doty committed Mar 28, 2023
1 parent d1c1d37 commit 70ff9c2
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions nuad/np.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def seq2arr(seq: str, base2bits_local: Dict[str, int] | None = None) -> np.ndarr
return np.array([base2bits_local[base] for base in seq], dtype=np.ubyte)


# this was about 5 times slower than the new implementation of `seqs2arr` below
# def seqs2arr_old(seqs: Sequence[str]) -> np.ndarray:
# """Return numpy 2D array converting the given DNA sequences to integers."""
# if len(seqs) == 0:
Expand All @@ -67,7 +68,7 @@ def seqs2arr(seqs: Sequence[str]) -> np.ndarray:
if isinstance(seqs, str):
raise ValueError('seqs must be a sequence of strings, not a single string')

# check equal length (a bit faster than a Python loop,
# check equal length of each sequence (a bit faster than a Python loop,
# e.g., 3.5 ms for 10^5 seqs compared to 5 ms with Python loop)
seq_len = len(seqs[0])
lengths_it = map(len, seqs)
Expand All @@ -81,18 +82,25 @@ def seqs2arr(seqs: Sequence[str]) -> np.ndarray:
seqs_cat = ''.join(seqs)
seqs_cat = seqs_cat.upper()

# arr1d = np.fromstring(seqs_cat_bytes, dtype=np.ubyte)
# arr1d = np.fromstring(seqs_cat_bytes, dtype=np.ubyte) # generates warning about using frombuffer
seqs_cat_bytes = seqs_cat.encode()
seqs_cat_byte_array = bytearray(seqs_cat_bytes)
arr1d = np.frombuffer(seqs_cat_byte_array, dtype=np.ubyte)

arr2d = arr1d.reshape((num_seqs, seq_len))
# code below is somewhat magical to me, but it works: https://stackoverflow.com/a/35464758
from_values = [ord(base) for base in ['A', 'C', 'G', 'T']]
to_values = np.arange(4)
sort_idx = np.argsort(from_values)
idx = np.searchsorted(from_values, arr1d, sorter=sort_idx)
arr1d = to_values[sort_idx][idx]

# this is a bit slower than the code above, e.g., 75 ms compared to 55 ms for 10^5 sequences
# # convert ASCII bytes for 'A', 'C', 'G', 'T' to 0, 1, 2, 3, respectively
# for i, base in enumerate(['A', 'C', 'G', 'T']):
# idxs_with_base = arr2d == ord(base)
# arr2d[idxs_with_base] = i

# convert ASCII bytes for 'A', 'C', 'G', 'T' to 0, 1, 2, 3, respectively
arr2d[arr2d == ord('A')] = 0
arr2d[arr2d == ord('C')] = 1
arr2d[arr2d == ord('G')] = 2
arr2d[arr2d == ord('T')] = 3
arr2d = arr1d.reshape((num_seqs, seq_len))

return arr2d

Expand Down

0 comments on commit 70ff9c2

Please sign in to comment.