diff --git a/docs/source/overview/cli.md b/docs/source/overview/cli.md index 25bc86c6..50b6076c 100644 --- a/docs/source/overview/cli.md +++ b/docs/source/overview/cli.md @@ -112,8 +112,8 @@ train_model: !obj:selene_sdk.TrainModel { data_parallel: True, logging_verbosity: 2, metrics: { - roc_auc: !import:sklearn.metrics.roc_auc_score, - average_precision: !import:sklearn.metrics.average_precision_score + roc_auc: !import sklearn.metrics.roc_auc_score, + average_precision: !import sklearn.metrics.average_precision_score }, checkpoint_resume: False } diff --git a/selene_sdk/predict/_common.py b/selene_sdk/predict/_common.py index baa38841..9dd15a7a 100644 --- a/selene_sdk/predict/_common.py +++ b/selene_sdk/predict/_common.py @@ -35,6 +35,34 @@ def get_reverse_complement(allele, complementary_base_dict): return ''.join(list(reversed(a_complement))) +def get_reverse_complement_encoding(allele_encoding, + bases_arr, + complementary_base_dict): + """ + Get the reverse complement of the input allele one-hot encoding. + + Parameters + ---------- + allele_encoding : numpy.ndarray + The sequence allele encoding, :math:`L \\times 4` + bases_arr : list(str) + The base ordering for the one-hot encoding + complementary_base_dict : dict(str: str) + The dictionary that maps each base to its complement + + Returns + ------- + np.ndarray + The reverse complement encoding of the allele, shape + :math:`L \\times 4`. + + """ + base_ixs = {b: i for (i, b) in enumerate(bases_arr)} + complement_indices = [ + base_ixs[complementary_base_dict[b]] for b in bases_arr] + return allele_encoding[:, complement_indices][::-1, :] + + def predict(model, batch_sequences, use_cuda=False): """ Return model predictions for a batch of sequences. diff --git a/selene_sdk/predict/_variant_effect_prediction.py b/selene_sdk/predict/_variant_effect_prediction.py index 1b4da3fc..1c78d04c 100644 --- a/selene_sdk/predict/_variant_effect_prediction.py +++ b/selene_sdk/predict/_variant_effect_prediction.py @@ -26,10 +26,12 @@ def read_vcf_file(input_path, Path to the VCF file. strand_index : int or None, optional Default is None. By default we assume the input sequence - surrounding a variant should be on the forward strand. If your - model is strand-specific, you may want to specify the column number - (0-based) in the VCF file that includes the strand corresponding - to each variant. + surrounding a variant is on the forward strand. If your + model is strand-specific, you may specify a column number + (0-based) in the VCF file that includes strand information. Please + note that variant position, ref, and alt should still be specified + for the forward strand and Selene will apply reverse complement + to this variant. require_strand : bool, optional Default is False. Whether strand can be specified as '.'. If False, Selene accepts strand value to be '+', '-', or '.' and automatically @@ -115,7 +117,7 @@ def read_vcf_file(input_path, continue for a in alt.split(','): variants.append((chrom, pos, name, ref, a, strand)) - + if reference_sequence and seq_context and output_NAs_to_file: with open(output_NAs_to_file, 'w') as file_handle: for na_row in na_rows: @@ -123,20 +125,11 @@ def read_vcf_file(input_path, return variants -def _get_ref_idxs(seq_len, strand, ref_len): - mid = None - if strand == '-': - mid = math.ceil(seq_len / 2) - else: - mid = seq_len // 2 - - start_pos = mid - if strand == '-' and ref_len > 1: - start_pos = mid - (ref_len + 1) // 2 + 1 - elif strand == '+' and ref_len == 1: - start_pos = mid - 1 - elif strand == '+': - start_pos = mid - ref_len // 2 - 1 +def _get_ref_idxs(seq_len, ref_len): + mid = seq_len // 2 + if seq_len % 2 == 0: + mid -= 1 + start_pos = mid - ref_len // 2 end_pos = start_pos + ref_len return (start_pos, end_pos) @@ -147,7 +140,6 @@ def _process_alt(chrom, alt, start, end, - strand, wt_sequence, reference_sequence): """ @@ -168,8 +160,6 @@ def _process_alt(chrom, The start coordinate for genome query end : int The end coordinate for genome query - strand : {'+', '-'} - The strand the variant is on wt_sequence : numpy.ndarray The reference sequence encoding reference_sequence : selene_sdk.sequences.Sequence @@ -194,13 +184,13 @@ def _process_alt(chrom, alt_encoding = reference_sequence.sequence_to_encoding(alt) if ref_len == alt_len: # substitution - start_pos, end_pos = _get_ref_idxs(len(wt_sequence), strand, ref_len) + start_pos, end_pos = _get_ref_idxs(len(wt_sequence), ref_len) sequence = np.vstack([wt_sequence[:start_pos, :], alt_encoding, wt_sequence[end_pos:, :]]) return sequence elif alt_len > ref_len: # insertion - start_pos, end_pos = _get_ref_idxs(len(wt_sequence), strand, ref_len) + start_pos, end_pos = _get_ref_idxs(len(wt_sequence), ref_len) sequence = np.vstack([wt_sequence[:start_pos, :], alt_encoding, wt_sequence[end_pos:, :]]) @@ -213,18 +203,13 @@ def _process_alt(chrom, chrom, start - ref_len // 2 + alt_len // 2, pos + 1, - strand=strand, pad=True) rhs = reference_sequence.get_sequence_from_coords( chrom, pos + 1 + ref_len, end + math.ceil(ref_len / 2.) - math.ceil(alt_len / 2.), - strand=strand, pad=True) - if strand == '-': - sequence = rhs + alt + lhs - else: - sequence = lhs + alt + rhs + sequence = lhs + alt + rhs return reference_sequence.sequence_to_encoding( sequence) @@ -232,11 +217,10 @@ def _process_alt(chrom, def _handle_standard_ref(ref_encoding, seq_encoding, seq_length, - reference_sequence, - strand): + reference_sequence): ref_len = ref_encoding.shape[0] - start_pos, end_pos = _get_ref_idxs(seq_length, strand, ref_len) + start_pos, end_pos = _get_ref_idxs(seq_length, ref_len) sequence_encoding_at_ref = seq_encoding[ start_pos:start_pos + ref_len, :] @@ -256,15 +240,11 @@ def _handle_long_ref(ref_encoding, seq_encoding, start_radius, end_radius, - reference_sequence, - reverse=True): + reference_sequence): ref_len = ref_encoding.shape[0] sequence_encoding_at_ref = seq_encoding - ref_start = ref_len // 2 - start_radius - ref_end = ref_len // 2 + end_radius - if not reverse: - ref_start -= 1 - ref_end -= 1 + ref_start = ref_len // 2 - start_radius - 1 + ref_end = ref_len // 2 + end_radius - 1 ref_encoding = ref_encoding[ref_start:ref_end] references_match = np.array_equal( sequence_encoding_at_ref, ref_encoding) diff --git a/selene_sdk/predict/model_predict.py b/selene_sdk/predict/model_predict.py index a336ad7f..12ada3c4 100644 --- a/selene_sdk/predict/model_predict.py +++ b/selene_sdk/predict/model_predict.py @@ -15,6 +15,7 @@ from ._common import _pad_sequence from ._common import _truncate_sequence from ._common import get_reverse_complement +from ._common import get_reverse_complement_encoding from ._common import predict from ._in_silico_mutagenesis import _ism_sample_id from ._in_silico_mutagenesis import in_silico_mutagenesis_sequences @@ -164,10 +165,10 @@ def __init__(self, self.sequence_length = sequence_length - self._start_radius = int(sequence_length / 2) + self._start_radius = sequence_length // 2 self._end_radius = self._start_radius if sequence_length % 2 != 0: - self._end_radius += 1 + self._start_radius += 1 self.batch_size = batch_size self.features = features @@ -483,7 +484,7 @@ def get_predictions_for_fasta_file(self, ["predictions"], os.path.join(output_dir, output_prefix), output_format, - ["index", "name", "contains_unk"], + ["index", "name"], output_size=len(fasta_file.keys()), mode="prediction")[0] sequences = np.zeros((self.batch_size, @@ -500,7 +501,6 @@ def get_predictions_for_fasta_file(self, elif len(cur_sequence) > self.sequence_length: cur_sequence = _truncate_sequence(cur_sequence, self.sequence_length) - contains_unk = self.reference_sequence.UNK_BASE in cur_sequence cur_sequence_encoding = self.reference_sequence.sequence_to_encoding( cur_sequence) @@ -511,14 +511,9 @@ def get_predictions_for_fasta_file(self, reporter.handle_batch_predictions(preds, batch_ids) batch_ids = [] - batch_ids.append([i, fasta_record.name, contains_unk]) + batch_ids.append([i, fasta_record.name]) sequences[i % self.batch_size, :, :] = cur_sequence_encoding - if contains_unk: - warnings.warn("Sequence ({0},{1}) " - " contains unknown base(s). " - "--will be marked `True` in the `contains_unk` column " - "of the .tsv or the row_labels .txt file.".format( - i, fasta_record.name )) + if (batch_ids and i == 0) or i % self.batch_size != 0: sequences = sequences[:i % self.batch_size + 1, :, :] preds = predict(self.model, sequences, use_cuda=self.use_cuda) @@ -956,41 +951,31 @@ def variant_effect_prediction(self, center = pos + len(ref) // 2 start = center - self._start_radius end = center + self._end_radius - seq_encoding, contains_unk = self.reference_sequence.get_encoding_from_coords_check_unk( - chrom, - start, - end, - strand=strand) - if len(ref) and strand == '-': - ref = get_reverse_complement( - ref, - self.reference_sequence.COMPLEMENTARY_BASE_DICT) - alt = get_reverse_complement( - alt, - self.reference_sequence.COMPLEMENTARY_BASE_DICT) + ref_sequence_encoding, contains_unk = \ + self.reference_sequence.get_encoding_from_coords_check_unk( + chrom, start, end) ref_encoding = self.reference_sequence.sequence_to_encoding(ref) - alt_encoding = _process_alt( - chrom, pos, ref, alt, start, end, strand, - seq_encoding, self.reference_sequence) + alt_sequence_encoding = _process_alt( + chrom, pos, ref, alt, start, end, + ref_sequence_encoding, + self.reference_sequence) match = True seq_at_ref = None if len(ref) and len(ref) < self.sequence_length: - match, seq_encoding, seq_at_ref = _handle_standard_ref( + match, ref_sequence_encoding, seq_at_ref = _handle_standard_ref( ref_encoding, - seq_encoding, + ref_sequence_encoding, self.sequence_length, - self.reference_sequence, - strand) + self.reference_sequence) elif len(ref) >= self.sequence_length: - match, seq_encoding, seq_at_ref = _handle_long_ref( + match, ref_sequence_encoding, seq_at_ref = _handle_long_ref( ref_encoding, - seq_encoding, + ref_sequence_encoding, self._start_radius, self._end_radius, - self.reference_sequence, - strand) + self.reference_sequence) if contains_unk: warnings.warn("For variant ({0}, {1}, {2}, {3}, {4}, {5}), " @@ -1008,8 +993,17 @@ def variant_effect_prediction(self, "column of the .tsv or the row_labels .txt file".format( chrom, pos, name, ref, alt, strand, seq_at_ref)) batch_ids.append((chrom, pos, name, ref, alt, strand, match, contains_unk)) - batch_ref_seqs.append(seq_encoding) - batch_alt_seqs.append(alt_encoding) + if strand == '-': + ref_sequence_encoding = get_reverse_complement_encoding( + ref_sequence_encoding, + self.reference_sequence.BASES_ARR, + self.reference_sequence.COMPLEMENTARY_BASE_DICT) + alt_sequence_encoding = get_reverse_complement_encoding( + alt_sequence_encoding, + self.reference_sequence.BASES_ARR, + self.reference_sequence.COMPLEMENTARY_BASE_DICT) + batch_ref_seqs.append(ref_sequence_encoding) + batch_alt_seqs.append(alt_sequence_encoding) if len(batch_ref_seqs) >= self.batch_size: _handle_ref_alt_predictions( diff --git a/selene_sdk/predict/tests/test__common.py b/selene_sdk/predict/tests/test__common.py new file mode 100644 index 00000000..1a85a903 --- /dev/null +++ b/selene_sdk/predict/tests/test__common.py @@ -0,0 +1,52 @@ +""" +Test methods in the _common methods module +""" +import numpy as np +import unittest + +from selene_sdk.predict._common import get_reverse_complement_encoding +from selene_sdk.sequences import Genome + + +class TestReverseComplement(unittest.TestCase): + + def setUp(self): + self.example_encoding = np.array( + [[0., 0., 0., 1.], + [1., 0., 0., 0.], + [0., 1., 0., 0.], + [0., 0., 1., 0.], + [0.25, 0.25, 0.25, 0.25]] + ) + + def test_rc_encoding_default(self): + observed = get_reverse_complement_encoding( + self.example_encoding, + Genome.BASES_ARR, + Genome.COMPLEMENTARY_BASE_DICT) + expected = [ + [0.25, 0.25, 0.25, 0.25], + [0., 1., 0., 0.], + [0., 0., 1., 0.], + [0., 0., 0., 1.], + [1., 0., 0., 0.], + ] + self.assertEqual(observed.tolist(), expected) + + def test_rc_encoding_nonstandard_ordering(self): + observed = get_reverse_complement_encoding( + self.example_encoding, + ['A', 'T', 'G', 'C'], + Genome.COMPLEMENTARY_BASE_DICT) + expected = [ + [0.25, 0.25, 0.25, 0.25], + [0., 0., 0., 1.], + [1., 0., 0., 0.], + [0., 1., 0., 0.], + [0., 0., 1., 0.], + ] + self.assertEqual(observed.tolist(), expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tutorials/regression_mpra_example/regression_train.yml b/tutorials/regression_mpra_example/regression_train.yml index d0743a7d..20ddd450 100644 --- a/tutorials/regression_mpra_example/regression_train.yml +++ b/tutorials/regression_mpra_example/regression_train.yml @@ -41,7 +41,7 @@ train_model: !obj:selene_sdk.TrainModel { data_parallel: False, logging_verbosity: 2, metrics: { - r2: !import:sklearn.metrics.r2_score + r2: !import sklearn.metrics.r2_score } } output_dir: ./