Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/overview/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ train_model: !obj:selene_sdk.TrainModel {
data_parallel: True,
logging_verbosity: 2,
metrics: {
roc_auc: !import:sklearn.metrics.roc_auc_score,
average_precision: !import:sklearn.metrics.average_precision_score
roc_auc: !import sklearn.metrics.roc_auc_score,
average_precision: !import sklearn.metrics.average_precision_score
},
checkpoint_resume: False
}
Expand Down
28 changes: 28 additions & 0 deletions selene_sdk/predict/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,34 @@ def get_reverse_complement(allele, complementary_base_dict):
return ''.join(list(reversed(a_complement)))


def get_reverse_complement_encoding(allele_encoding,
                                    bases_arr,
                                    complementary_base_dict):
    """
    Get the reverse complement of the input allele one-hot encoding.

    Parameters
    ----------
    allele_encoding : numpy.ndarray
        The sequence allele encoding, :math:`L \\times 4`
    bases_arr : list(str)
        The base ordering for the one-hot encoding
    complementary_base_dict : dict(str: str)
        The dictionary that maps each base to its complement

    Returns
    -------
    numpy.ndarray
        The reverse complement encoding of the allele, shape
        :math:`L \\times 4`.

    """
    # Map each base to the column it occupies in the encoding.
    index_of = {}
    for position, base in enumerate(bases_arr):
        index_of[base] = position
    # Column permutation: entry i is the column of the complement of
    # the i-th base, so fancy-indexing swaps each base column with its
    # complement's column.
    column_order = [index_of[complementary_base_dict[b]] for b in bases_arr]
    complemented = allele_encoding[:, column_order]
    # Reverse the row (position) order to finish the reverse complement.
    return complemented[::-1, :]


def predict(model, batch_sequences, use_cuda=False):
"""
Return model predictions for a batch of sequences.
Expand Down
60 changes: 20 additions & 40 deletions selene_sdk/predict/_variant_effect_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ def read_vcf_file(input_path,
Path to the VCF file.
strand_index : int or None, optional
Default is None. By default we assume the input sequence
surrounding a variant should be on the forward strand. If your
model is strand-specific, you may want to specify the column number
(0-based) in the VCF file that includes the strand corresponding
to each variant.
surrounding a variant is on the forward strand. If your
model is strand-specific, you may specify a column number
(0-based) in the VCF file that includes strand information. Please
note that the variant position, ref, and alt should still be specified
for the forward strand, and Selene will apply the reverse complement
to this variant.
require_strand : bool, optional
Default is False. Whether strand can be specified as '.'. If False,
Selene accepts strand value to be '+', '-', or '.' and automatically
Expand Down Expand Up @@ -115,28 +117,19 @@ def read_vcf_file(input_path,
continue
for a in alt.split(','):
variants.append((chrom, pos, name, ref, a, strand))

if reference_sequence and seq_context and output_NAs_to_file:
with open(output_NAs_to_file, 'w') as file_handle:
for na_row in na_rows:
file_handle.write(na_row)
return variants


def _get_ref_idxs(seq_len, strand, ref_len):
mid = None
if strand == '-':
mid = math.ceil(seq_len / 2)
else:
mid = seq_len // 2

start_pos = mid
if strand == '-' and ref_len > 1:
start_pos = mid - (ref_len + 1) // 2 + 1
elif strand == '+' and ref_len == 1:
start_pos = mid - 1
elif strand == '+':
start_pos = mid - ref_len // 2 - 1
def _get_ref_idxs(seq_len, ref_len):
mid = seq_len // 2
if seq_len % 2 == 0:
mid -= 1
start_pos = mid - ref_len // 2
end_pos = start_pos + ref_len
return (start_pos, end_pos)

Expand All @@ -147,7 +140,6 @@ def _process_alt(chrom,
alt,
start,
end,
strand,
wt_sequence,
reference_sequence):
"""
Expand All @@ -168,8 +160,6 @@ def _process_alt(chrom,
The start coordinate for genome query
end : int
The end coordinate for genome query
strand : {'+', '-'}
The strand the variant is on
wt_sequence : numpy.ndarray
The reference sequence encoding
reference_sequence : selene_sdk.sequences.Sequence
Expand All @@ -194,13 +184,13 @@ def _process_alt(chrom,

alt_encoding = reference_sequence.sequence_to_encoding(alt)
if ref_len == alt_len: # substitution
start_pos, end_pos = _get_ref_idxs(len(wt_sequence), strand, ref_len)
start_pos, end_pos = _get_ref_idxs(len(wt_sequence), ref_len)
sequence = np.vstack([wt_sequence[:start_pos, :],
alt_encoding,
wt_sequence[end_pos:, :]])
return sequence
elif alt_len > ref_len: # insertion
start_pos, end_pos = _get_ref_idxs(len(wt_sequence), strand, ref_len)
start_pos, end_pos = _get_ref_idxs(len(wt_sequence), ref_len)
sequence = np.vstack([wt_sequence[:start_pos, :],
alt_encoding,
wt_sequence[end_pos:, :]])
Expand All @@ -213,30 +203,24 @@ def _process_alt(chrom,
chrom,
start - ref_len // 2 + alt_len // 2,
pos + 1,
strand=strand,
pad=True)
rhs = reference_sequence.get_sequence_from_coords(
chrom,
pos + 1 + ref_len,
end + math.ceil(ref_len / 2.) - math.ceil(alt_len / 2.),
strand=strand,
pad=True)
if strand == '-':
sequence = rhs + alt + lhs
else:
sequence = lhs + alt + rhs
sequence = lhs + alt + rhs
return reference_sequence.sequence_to_encoding(
sequence)


def _handle_standard_ref(ref_encoding,
seq_encoding,
seq_length,
reference_sequence,
strand):
reference_sequence):
ref_len = ref_encoding.shape[0]

start_pos, end_pos = _get_ref_idxs(seq_length, strand, ref_len)
start_pos, end_pos = _get_ref_idxs(seq_length, ref_len)

sequence_encoding_at_ref = seq_encoding[
start_pos:start_pos + ref_len, :]
Expand All @@ -256,15 +240,11 @@ def _handle_long_ref(ref_encoding,
seq_encoding,
start_radius,
end_radius,
reference_sequence,
reverse=True):
reference_sequence):
ref_len = ref_encoding.shape[0]
sequence_encoding_at_ref = seq_encoding
ref_start = ref_len // 2 - start_radius
ref_end = ref_len // 2 + end_radius
if not reverse:
ref_start -= 1
ref_end -= 1
ref_start = ref_len // 2 - start_radius - 1
ref_end = ref_len // 2 + end_radius - 1
ref_encoding = ref_encoding[ref_start:ref_end]
references_match = np.array_equal(
sequence_encoding_at_ref, ref_encoding)
Expand Down
66 changes: 30 additions & 36 deletions selene_sdk/predict/model_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ._common import _pad_sequence
from ._common import _truncate_sequence
from ._common import get_reverse_complement
from ._common import get_reverse_complement_encoding
from ._common import predict
from ._in_silico_mutagenesis import _ism_sample_id
from ._in_silico_mutagenesis import in_silico_mutagenesis_sequences
Expand Down Expand Up @@ -164,10 +165,10 @@ def __init__(self,

self.sequence_length = sequence_length

self._start_radius = int(sequence_length / 2)
self._start_radius = sequence_length // 2
self._end_radius = self._start_radius
if sequence_length % 2 != 0:
self._end_radius += 1
self._start_radius += 1

self.batch_size = batch_size
self.features = features
Expand Down Expand Up @@ -483,7 +484,7 @@ def get_predictions_for_fasta_file(self,
["predictions"],
os.path.join(output_dir, output_prefix),
output_format,
["index", "name", "contains_unk"],
["index", "name"],
output_size=len(fasta_file.keys()),
mode="prediction")[0]
sequences = np.zeros((self.batch_size,
Expand All @@ -500,7 +501,6 @@ def get_predictions_for_fasta_file(self,
elif len(cur_sequence) > self.sequence_length:
cur_sequence = _truncate_sequence(cur_sequence, self.sequence_length)

contains_unk = self.reference_sequence.UNK_BASE in cur_sequence
cur_sequence_encoding = self.reference_sequence.sequence_to_encoding(
cur_sequence)

Expand All @@ -511,14 +511,9 @@ def get_predictions_for_fasta_file(self,
reporter.handle_batch_predictions(preds, batch_ids)
batch_ids = []

batch_ids.append([i, fasta_record.name, contains_unk])
batch_ids.append([i, fasta_record.name])
sequences[i % self.batch_size, :, :] = cur_sequence_encoding
if contains_unk:
warnings.warn("Sequence ({0},{1}) "
" contains unknown base(s). "
"--will be marked `True` in the `contains_unk` column "
"of the .tsv or the row_labels .txt file.".format(
i, fasta_record.name ))

if (batch_ids and i == 0) or i % self.batch_size != 0:
sequences = sequences[:i % self.batch_size + 1, :, :]
preds = predict(self.model, sequences, use_cuda=self.use_cuda)
Expand Down Expand Up @@ -956,41 +951,31 @@ def variant_effect_prediction(self,
center = pos + len(ref) // 2
start = center - self._start_radius
end = center + self._end_radius
seq_encoding, contains_unk = self.reference_sequence.get_encoding_from_coords_check_unk(
chrom,
start,
end,
strand=strand)
if len(ref) and strand == '-':
ref = get_reverse_complement(
ref,
self.reference_sequence.COMPLEMENTARY_BASE_DICT)
alt = get_reverse_complement(
alt,
self.reference_sequence.COMPLEMENTARY_BASE_DICT)
ref_sequence_encoding, contains_unk = \
self.reference_sequence.get_encoding_from_coords_check_unk(
chrom, start, end)

ref_encoding = self.reference_sequence.sequence_to_encoding(ref)
alt_encoding = _process_alt(
chrom, pos, ref, alt, start, end, strand,
seq_encoding, self.reference_sequence)
alt_sequence_encoding = _process_alt(
chrom, pos, ref, alt, start, end,
ref_sequence_encoding,
self.reference_sequence)

match = True
seq_at_ref = None
if len(ref) and len(ref) < self.sequence_length:
match, seq_encoding, seq_at_ref = _handle_standard_ref(
match, ref_sequence_encoding, seq_at_ref = _handle_standard_ref(
ref_encoding,
seq_encoding,
ref_sequence_encoding,
self.sequence_length,
self.reference_sequence,
strand)
self.reference_sequence)
elif len(ref) >= self.sequence_length:
match, seq_encoding, seq_at_ref = _handle_long_ref(
match, ref_sequence_encoding, seq_at_ref = _handle_long_ref(
ref_encoding,
seq_encoding,
ref_sequence_encoding,
self._start_radius,
self._end_radius,
self.reference_sequence,
strand)
self.reference_sequence)

if contains_unk:
warnings.warn("For variant ({0}, {1}, {2}, {3}, {4}, {5}), "
Expand All @@ -1008,8 +993,17 @@ def variant_effect_prediction(self,
"column of the .tsv or the row_labels .txt file".format(
chrom, pos, name, ref, alt, strand, seq_at_ref))
batch_ids.append((chrom, pos, name, ref, alt, strand, match, contains_unk))
batch_ref_seqs.append(seq_encoding)
batch_alt_seqs.append(alt_encoding)
if strand == '-':
ref_sequence_encoding = get_reverse_complement_encoding(
ref_sequence_encoding,
self.reference_sequence.BASES_ARR,
self.reference_sequence.COMPLEMENTARY_BASE_DICT)
alt_sequence_encoding = get_reverse_complement_encoding(
alt_sequence_encoding,
self.reference_sequence.BASES_ARR,
self.reference_sequence.COMPLEMENTARY_BASE_DICT)
batch_ref_seqs.append(ref_sequence_encoding)
batch_alt_seqs.append(alt_sequence_encoding)

if len(batch_ref_seqs) >= self.batch_size:
_handle_ref_alt_predictions(
Expand Down
52 changes: 52 additions & 0 deletions selene_sdk/predict/tests/test__common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
Test methods in the _common methods module
"""
import numpy as np
import unittest

from selene_sdk.predict._common import get_reverse_complement_encoding
from selene_sdk.sequences import Genome


class TestReverseComplement(unittest.TestCase):
    """Tests for `get_reverse_complement_encoding`."""

    def setUp(self):
        # Under Genome's (A, C, G, T) column ordering the rows encode
        # T, A, C, G, followed by an unknown base (uniform 0.25s).
        self.example_encoding = np.array(
            [[0., 0., 0., 1.],
             [1., 0., 0., 0.],
             [0., 1., 0., 0.],
             [0., 0., 1., 0.],
             [0.25, 0.25, 0.25, 0.25]]
        )

    def test_rc_encoding_default(self):
        # Standard base ordering: reverse complement of TACGN is NCGTA.
        result = get_reverse_complement_encoding(
            self.example_encoding,
            Genome.BASES_ARR,
            Genome.COMPLEMENTARY_BASE_DICT)
        expected_rows = [
            [0.25, 0.25, 0.25, 0.25],
            [0., 1., 0., 0.],
            [0., 0., 1., 0.],
            [0., 0., 0., 1.],
            [1., 0., 0., 0.],
        ]
        self.assertEqual(result.tolist(), expected_rows)

    def test_rc_encoding_nonstandard_ordering(self):
        # Same encoding matrix, but the columns are declared in a
        # non-default (A, T, G, C) order, so the complement column
        # permutation differs.
        result = get_reverse_complement_encoding(
            self.example_encoding,
            ['A', 'T', 'G', 'C'],
            Genome.COMPLEMENTARY_BASE_DICT)
        expected_rows = [
            [0.25, 0.25, 0.25, 0.25],
            [0., 0., 0., 1.],
            [1., 0., 0., 0.],
            [0., 1., 0., 0.],
            [0., 0., 1., 0.],
        ]
        self.assertEqual(result.tolist(), expected_rows)


# Allow running this test module directly (e.g. `python test__common.py`)
# in addition to discovery via a test runner.
if __name__ == "__main__":
    unittest.main()
2 changes: 1 addition & 1 deletion tutorials/regression_mpra_example/regression_train.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ train_model: !obj:selene_sdk.TrainModel {
data_parallel: False,
logging_verbosity: 2,
metrics: {
r2: !import:sklearn.metrics.r2_score
r2: !import sklearn.metrics.r2_score
}
}
output_dir: ./
Expand Down