diff --git a/.gitignore b/.gitignore index d6aadb5..98fe745 100644 --- a/.gitignore +++ b/.gitignore @@ -181,5 +181,8 @@ cython_debug/ # scripts scripts -### IDE +# IDE .idea + +# notes +notes/ diff --git a/README.md b/README.md index 011a52a..1875a95 100644 --- a/README.md +++ b/README.md @@ -40,8 +40,40 @@ As the header suggests, `setriq` is a no-frills Python package for fast computat a focus on immunoglobulins. It is a declarative framework and borrows many concepts from the popular `torch` library. It has been optimized for parallel compute on CPU architectures. -It can **only** perform pairwise, all-v-all distance computations. This decision was made to maximize consistency and -cohesion. +Available distance functions: +* CDRdist +* Levenshtein +* TCRdist +* Hamming +* Jaro +* Jaro-Winkler + +These distance functions are available either through the object-based API (as seen above), which provides the CPU-based +parallelism, or the functional API in `setriq.single_dispatch`. Unlike the object-based API, the functional API does a +single comparison between two sequences for every call, i.e. it exposes the `C++` distance functions without the +parallelism wrapper. This can be useful for integration of `setriq` with other tools such as `PySpark`. 
For example: + +```python +from pyspark.sql import SparkSession +from pyspark.sql.functions import udf +from pyspark.sql.types import DoubleType + +from setriq import single_dispatch as sd + +spark = SparkSession \ + .builder \ + .appName("setriq-spark") \ + .getOrCreate() + +df = spark.createDataFrame([('CASSLKPNTEAFF',), ('CASSAHIANYGYTF',), ('CASRGATETQYF',)], ['sequence']) +df = df.withColumnRenamed('sequence', 'a').crossJoin(df.withColumnRenamed('sequence', 'b')) + +lev_udf = udf(sd.levenshtein, returnType=DoubleType()) # single dispatch levenshtein distance +df = df.withColumn('distance', lev_udf('a', 'b')) +df.show() +``` + +It is important to note that for `setriq.single_dispatch` the returned value is always a single float value. ## Requirements A `Python` version of 3.7 or above is required, as well as a `C++` compiler equipped with OpenMP. The package has been @@ -62,8 +94,13 @@ brew install libomp llvm 1. Dash, P., Fiore-Gartland, A.J., Hertz, T., Wang, G.C., Sharma, S., Souquette, A., Crawford, J.C., Clemens, E.B., Nguyen, T.H., Kedzierska, K. and La Gruta, N.L., 2017. Quantifiable predictive features define epitope-specific T cell receptor repertoires. Nature, 547(7661), pp.89-93. (https://doi.org/10.1038/nature22383) - 2. Levenshtein, V.I., 1966, February. Binary codes capable of correcting deletions, insertions, and reversals. In + 2. Jaro, M.A., 1989. Advances in record-linkage methodology as applied to matching the 1985 census of Tampa, + Florida. Journal of the American Statistical Association, 84(406), pp.414-420. + 3. Levenshtein, V.I., 1966, February. Binary codes capable of correcting deletions, insertions, and reversals. In Soviet physics doklady (Vol. 10, No. 8, pp. 707-710). - 3. python-Levenshtein (https://github.com/ztane/python-Levenshtein) - 4. Thakkar, N. and Bailey-Kellogg, C., 2019. Balancing sensitivity and specificity in distinguishing TCR groups by CDR + 4. python-Levenshtein (https://github.com/ztane/python-Levenshtein) + 5. 
Thakkar, N. and Bailey-Kellogg, C., 2019. Balancing sensitivity and specificity in distinguishing TCR groups by CDR sequence similarity. BMC bioinformatics, 20(1), pp.1-14. (https://doi.org/10.1186/s12859-019-2864-8) + 6. Van der Loo, M.P., 2014. The stringdist package for approximate string matching. R J., 6(1), p.111. + 7. Winkler, W.E., 1990. String comparator metrics and enhanced decision rules in the Fellegi-Sunter model of record + linkage. diff --git a/include/setriq/metrics/Hamming.h b/include/setriq/metrics/Hamming.h new file mode 100644 index 0000000..02ca095 --- /dev/null +++ b/include/setriq/metrics/Hamming.h @@ -0,0 +1,22 @@ +// +// Created by Benjamin Tenmann on 20/02/2022. +// + +#ifndef SETRIQ_HAMMING_H +#define SETRIQ_HAMMING_H + +#include "utils/type_defs.h" + +namespace metric { + class Hamming { + private: + double mismatch_score_{}; + + public: + explicit Hamming(const double &mismatch_score) : mismatch_score_{mismatch_score} {}; + + double forward(const std::string &a, const std::string &b) const; + }; +} + +#endif //SETRIQ_HAMMING_H diff --git a/include/setriq/metrics/Jaro.h b/include/setriq/metrics/Jaro.h new file mode 100644 index 0000000..32d098b --- /dev/null +++ b/include/setriq/metrics/Jaro.h @@ -0,0 +1,27 @@ +// +// Created by Benjamin Tenmann on 05/03/2022. +// + +#ifndef SETRIQ_JARO_H +#define SETRIQ_JARO_H + +#include +#include "utils/type_defs.h" + +typedef std::array jaro_weighting_t; + +namespace metric { + class Jaro { + private: + jaro_weighting_t weights_ = {1. / 3, 1. / 3, 1. 
/ 3}; + + public: + Jaro() = default; + + explicit Jaro(jaro_weighting_t weights) : weights_(weights) {}; + + double forward(const std::string &a, const std::string &b) const; + }; +} + +#endif //SETRIQ_JARO_H diff --git a/include/setriq/metrics/JaroWinkler.h b/include/setriq/metrics/JaroWinkler.h new file mode 100644 index 0000000..d4952a7 --- /dev/null +++ b/include/setriq/metrics/JaroWinkler.h @@ -0,0 +1,27 @@ +// +// Created by Benjamin Tenmann on 21/02/2022. +// + +#ifndef SETRIQ_JAROWINKLER_H +#define SETRIQ_JAROWINKLER_H + +#include "utils/type_defs.h" +#include "metrics/Jaro.h" + +namespace metric { + class JaroWinkler { + private: + double p_ = 0.; + size_t max_l_ = 4; + Jaro jaro_{}; + + public: + JaroWinkler() = default; + + explicit JaroWinkler(const double &p, const size_t &max_l, Jaro jaro) : p_{p}, max_l_{max_l}, jaro_{jaro} {}; + + double forward(const std::string &a, const std::string &b) const; + }; +} + +#endif //SETRIQ_JAROWINKLER_H diff --git a/src/setriq/_C/metrics/Hamming.cpp b/src/setriq/_C/metrics/Hamming.cpp new file mode 100644 index 0000000..04c7115 --- /dev/null +++ b/src/setriq/_C/metrics/Hamming.cpp @@ -0,0 +1,20 @@ +// +// Created by Benjamin Tenmann on 20/02/2022. +// + +#include "metrics/Hamming.h" + +double metric::Hamming::forward(const std::string &a, const std::string &b) const { + /*! + * Compute the Hamming distance between two input strings. + * + * @param a: an input string to be compared + * @param b: an input string to be compared + */ + auto&& distance = 0.; + for (auto i = 0ul; i < a.size(); i++) { + if (a[i] != b[i]) + distance += this->mismatch_score_; + } + return distance; +} diff --git a/src/setriq/_C/metrics/Jaro.cpp b/src/setriq/_C/metrics/Jaro.cpp new file mode 100644 index 0000000..2629e25 --- /dev/null +++ b/src/setriq/_C/metrics/Jaro.cpp @@ -0,0 +1,81 @@ +// +// Created by Benjamin Tenmann on 05/03/2022. 
+// + +#include +#include "metrics/Jaro.h" + +#define either_zero(x, y) (x == 0) || (y == 0) +#define max(x, y) x > y ? x : y +#define min(x, y) x > y ? y : x + + +void collapse_into_match_str(const std::string& sequence, const std::vector& matches_idx, char* match_str) { + auto&& j = 0ul; + for (const auto& idx : matches_idx) { + if (idx){ + match_str[j] = sequence[idx - 1]; + j++; + } + } +} + +double metric::Jaro::forward(const std::string &a, const std::string &b) const { + /*! + * Compute the Jaro distance between two input strings. + * Adapted from https://github.com/markvanderloo/stringdist/blob/master/pkg/src/jaro.c + * + * @param a: an input string to be compared + * @param b: an input string to be compared + */ + const auto& s_i = a.size(); + const auto& s_j = b.size(); + if (either_zero(s_i, s_j)) + // if one of the strings is of length 0 and the other isn't, then the distance is maximal (1) + // if both are length 0, then the distance is minimal, i.e. 0 + return (double) ((s_i > 0) || (s_j > 0)); + + const auto& max_len = s_i > s_j ? s_i : s_j; + const auto& max_match_distance = (int) std::floor(max_len / 2) - 1; + if (max_match_distance < 0) + // catch the case when both strings are of length == 1 + return a[0] == b[0] ? 0.0 : 1.0; + + auto&& matches_s_i = std::vector(s_i, 0); + auto&& matches_s_j = std::vector(s_j, 0); + + auto&& n_matches = 0ul; + for (auto i = 0; i < s_i; i++) { + const auto& left = max((i - max_match_distance), 0); + const auto& right = min((i + max_match_distance) + 1, s_j); + // can we collapse this in some way? 
+ for (auto j = left; j < right; j++) { + if ((a[i] == b[j]) && (matches_s_j[j] == 0)) { + n_matches++; + matches_s_i[i] = i + 1; + matches_s_j[j] = j + 1; + break; + } + } + } + if (n_matches == 0) + return 1.0; + + char *match_str_i = new char[n_matches]; + char *match_str_j = new char[n_matches]; + + collapse_into_match_str(a, matches_s_i, match_str_i); + collapse_into_match_str(b, matches_s_j, match_str_j); + + auto&& t = 0.0; + for (auto k = 0ul; k < n_matches; k++) { + if (match_str_i[k] != match_str_j[k]) + t += 0.5; + } + delete []match_str_i; + delete []match_str_j; + + const auto& m = (double) n_matches; + // allow arbitrary weighting + return 1 - (this->weights_[0] * (m / s_i) + this->weights_[1] * (m / s_j) + this->weights_[2] * ((m - t) / m)); +} diff --git a/src/setriq/_C/metrics/JaroWinkler.cpp b/src/setriq/_C/metrics/JaroWinkler.cpp new file mode 100644 index 0000000..5321b4c --- /dev/null +++ b/src/setriq/_C/metrics/JaroWinkler.cpp @@ -0,0 +1,27 @@ +// +// Created by Benjamin Tenmann on 21/02/2022. +// + +#include "metrics/JaroWinkler.h" + +size_t min_sequence_len(const std::string& a, const std::string& b) { + const auto& length_a = a.size(); + const auto& length_b = b.size(); + return length_a < length_b ? length_a : length_b; +} + +double metric::JaroWinkler::forward(const std::string &a, const std::string &b) const { + /*! + * Compute the Jaro-Winkler distance between two input strings. 
+ * + * @param a: an input string to be compared + * @param b: an input string to be compared + */ + const auto& jaro_distance = this->jaro_.forward(a, b); + const auto& min_length = min_sequence_len(a, b); + + auto&& l = 0ul; + while ((a[l] == b[l]) && (l < min_length) && (l < this->max_l_)) + l++; + return jaro_distance * (1 - l * this->p_); +} diff --git a/src/setriq/_C/python_wrapper.cpp b/src/setriq/_C/python_wrapper.cpp index c9347cf..332e827 100644 --- a/src/setriq/_C/python_wrapper.cpp +++ b/src/setriq/_C/python_wrapper.cpp @@ -16,10 +16,14 @@ #include "metrics/CdrDist.h" #include "metrics/Levenshtein.h" #include "metrics/TcrDist.h" +#include "metrics/Hamming.h" +#include "metrics/Jaro.h" +#include "metrics/JaroWinkler.h" #include "utils/type_defs.h" namespace py = pybind11; +// ----- pairwise distances ----------------------------------------------------------------------------------------- // py::list cdr_dist(const string_vector_t& sequences, const double_matrix_t& substitution_matrix, const token_index_map_t& index, @@ -50,9 +54,84 @@ py::list tcr_dist_component(const string_vector_t& sequences, return py::cast(out); } +py::list hamming(const string_vector_t& sequences, const double& mismatch_score) { + metric::Hamming metric {mismatch_score}; + + double_vector_t out = pairwise_distance_computation(metric, sequences); + return py::cast(out); +} + +py::list jaro(const string_vector_t& sequences, const jaro_weighting_t& jaro_weights) { + metric::Jaro metric {jaro_weights}; + + double_vector_t out = pairwise_distance_computation(metric, sequences); + return py::cast(out); +} + +py::list jaro_winkler(const string_vector_t& sequences, + const double& p, + const size_t& max_l, + const jaro_weighting_t& jaro_weights) { + metric::JaroWinkler metric {p, max_l, metric::Jaro{jaro_weights}}; + + double_vector_t out = pairwise_distance_computation(metric, sequences); + return py::cast(out); +} + +// ----- single dispatch 
-------------------------------------------------------------------------------------------- // +py::float_ cdr_dist_sd(const std::string& a, std::string& b, + const double_matrix_t& substitution_matrix, + const token_index_map_t& index, + const double& gap_opening_penalty, + const double& gap_extension_penalty) { + metric::CdrDist metric {substitution_matrix, index, gap_opening_penalty, gap_extension_penalty}; + double out = metric.forward(a, b); + return py::cast(out); +} + +py::float_ levenshtein_sd(const std::string& a, const std::string& b, const double& extra_cost) { + metric::Levenshtein metric {extra_cost}; + double out = metric.forward(a, b); + return py::cast(out); +} + +py::float_ tcr_dist_component_sd(const std::string& a, const std::string& b, + const double_matrix_t& substitution_matrix, + const token_index_map_t& index, + const double& gap_penalty, + const char& gap_symbol, + const double& distance_weight) { + metric::TcrDist metric {substitution_matrix, index, gap_penalty, gap_symbol, distance_weight}; + double out = metric.forward(a, b); + return py::cast(out); +} + +py::float_ hamming_sd(const std::string& a, const std::string& b, const double& mismatch_score) { + metric::Hamming metric {mismatch_score}; + double out = metric.forward(a, b); + return py::cast(out); +} + +py::float_ jaro_sd(const std::string& a, const std::string& b, const jaro_weighting_t& jaro_weights) { + metric::Jaro metric {jaro_weights}; + double out = metric.forward(a, b); + return py::cast(out); +} + +py::float_ jaro_winkler_sd(const std::string& a, const std::string& b, + const double& p, + const size_t& max_l, + const jaro_weighting_t& jaro_weights) { + metric::JaroWinkler metric {p, max_l, metric::Jaro{jaro_weights}}; + double out = metric.forward(a, b); + return py::cast(out); +} + +// ----- module def ------------------------------------------------------------------------------------------------- // PYBIND11_MODULE(EXTENSION_NAME, m) { m.doc() = "Python module written 
in C++ for pairwise distance computation for sequences."; + // pairwise m.def("cdr_dist", &cdr_dist, "Compute the pairwise CDR-dist metric for a set of CDR3 sequences.", py::arg("sequences"), py::arg("substitution_matrix"), py::arg("index"), py::arg("gap_opening_penalty"), py::arg("gap_extension_penalty")); @@ -64,6 +143,36 @@ PYBIND11_MODULE(EXTENSION_NAME, m) { py::arg("sequences"), py::arg("substitution_matrix"), py::arg("index"), py::arg("gap_penalty"), py::arg("gap_symbol"), py::arg("weight")); + m.def("hamming", &hamming, "Compute pairwise Hamming distance for a set of sequences.", + py::arg("sequences"), py::arg("mismatch_score")); + + m.def("jaro", &jaro, "Compute pairwise Jaro distance for a set of sequences.", + py::arg("sequences"), py::arg("jaro_weights")); + + m.def("jaro_winkler", &jaro_winkler, "Compute pairwise Jaro-Winkler distance for a set of sequences.", + py::arg("sequences"), py::arg("p"), py::arg("max_l"), py::arg("jaro_weights")); + + // single dispatch + m.def("cdr_dist_sd", &cdr_dist_sd, "Compute the CDR-dist metric between two CDR3 sequences.", + py::arg("a"), py::arg("b"), py::arg("substitution_matrix"), py::arg("index"), + py::arg("gap_opening_penalty"), py::arg("gap_extension_penalty")); + + m.def("levenshtein_sd", &levenshtein_sd, "Compute the Levenshtein distance between two sequences.", + py::arg("a"), py::arg("b"), py::arg("extra_cost")); + + m.def("tcr_dist_component_sd", &tcr_dist_component_sd, "Compute TCR-dist between two TCR components.", + py::arg("a"), py::arg("b"), py::arg("substitution_matrix"), py::arg("index"), + py::arg("gap_penalty"), py::arg("gap_symbol"), py::arg("weight")); + + m.def("hamming_sd", &hamming_sd, "Compute the Hamming distance between two sequences.", + py::arg("a"), py::arg("b"), py::arg("mismatch_score")); + + m.def("jaro_sd", &jaro_sd, "Compute the Jaro distance between two sequences.", + py::arg("a"), py::arg("b"), py::arg("jaro_weights")); + + m.def("jaro_winkler_sd", &jaro_winkler_sd, "Compute the 
Jaro-Winkler distance between two sequences.", + py::arg("a"), py::arg("b"), py::arg("p"), py::arg("max_l"), py::arg("jaro_weights")); + #ifdef VERSION_INFO m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO); #else diff --git a/src/setriq/__init__.py b/src/setriq/__init__.py index 393c393..ab778e0 100644 --- a/src/setriq/__init__.py +++ b/src/setriq/__init__.py @@ -39,6 +39,7 @@ """ from .modules import ( - CdrDist, Levenshtein, TcrDist, - SubstitutionMatrix, BLOSUM45, BLOSUM62, BLOSUM90 + CdrDist, Levenshtein, TcrDist, Hamming, Jaro, JaroWinkler, + SubstitutionMatrix, BLOSUM45, BLOSUM62, BLOSUM90, + single_dispatch ) diff --git a/src/setriq/modules/__init__.py b/src/setriq/modules/__init__.py index 738c731..a70c2ff 100644 --- a/src/setriq/modules/__init__.py +++ b/src/setriq/modules/__init__.py @@ -22,9 +22,14 @@ from .distances import ( CdrDist, Levenshtein, - TcrDist + TcrDist, + Hamming, + Jaro, + JaroWinkler ) +from . import single_dispatch + __all__ = [ 'SubstitutionMatrix', 'BLOSUM45', @@ -33,4 +38,8 @@ 'CdrDist', 'Levenshtein', 'TcrDist', + 'Hamming', + 'Jaro', + 'JaroWinkler', + 'single_dispatch' ] diff --git a/src/setriq/modules/distances.py b/src/setriq/modules/distances.py index c84a387..c3d9f2f 100644 --- a/src/setriq/modules/distances.py +++ b/src/setriq/modules/distances.py @@ -17,12 +17,22 @@ BLOSUM62, SubstitutionMatrix ) +from .utils import ( + enforce_list, + ensure_equal_sequence_length, + check_jaro_weights, + check_jaro_winkler_params, + TCR_DIST_DEFAULT +) __all__ = [ 'CdrDist', 'Levenshtein', 'TcrDist', 'TcrDistComponent', + 'Hamming', + 'Jaro', + 'JaroWinkler' ] @@ -43,6 +53,7 @@ class Metric(abc.ABC): def forward(self, *args, **kwargs): pass + @enforce_list(argnum=1, convert_iterable=True) def __call__(self, *args, **kwargs): out = self.forward(*args, **kwargs) @@ -151,9 +162,8 @@ def __init__(self, } self.fn = C.tcr_dist_component + @ensure_equal_sequence_length(argnum=1) def forward(self, sequences: List[str]) -> List[float]: - if not 
(len(sequences[0]) == pd.Series(sequences).str.len()).all(): - raise ValueError('Sequences must be of equal length') out = self.fn(sequences, **self.call_args) return out @@ -185,12 +195,7 @@ class TcrDist(Metric): Nguyen, T.H., Kedzierska, K. and La Gruta, N.L., 2017. Quantifiable predictive features define epitope-specific T cell receptor repertoires. Nature, 547(7661), pp.89-93. (https://doi.org/10.1038/nature22383) """ - _default = [ - ('cdr_1', {'substitution_matrix': BLOSUM62, 'gap_penalty': 4., 'weight': 1.}), - ('cdr_2', {'substitution_matrix': BLOSUM62, 'gap_penalty': 4., 'weight': 1.}), - ('cdr_2_5', {'substitution_matrix': BLOSUM62, 'gap_penalty': 4., 'weight': 1.}), - ('cdr_3', {'substitution_matrix': BLOSUM62, 'gap_penalty': 8., 'weight': 3.}) - ] + _default = TCR_DIST_DEFAULT _default_msg = ( 'TcrDist has been initialized using the default configuration. ' 'Please ensure that the input is a list of dictionaries, each with keys: {}' @@ -302,3 +307,82 @@ def forward(self, sequences: List[Dict[str, str]]) -> List[float]: # aggregate the component outputs out = np.array(out).sum(axis=0) return out.tolist() + + +class Hamming(Metric): + """ + Hamming distance class. Inherits from Metric. Sequences must be of equal length. + + Examples + -------- + >>> metric = Hamming(mismatch_score=2.0) + >>> sequences = ['CASSLKPNTEAFF', 'CASSAHIANYGYTF', 'CASRGATETQYF'] + >>> distances = metric(sequences) + + References + ---------- + [1] ... + """ + # TODO: add reference + def __init__(self, mismatch_score: float = 1.0): + self.call_args = { + 'mismatch_score': mismatch_score + } + self.fn = C.hamming + + @ensure_equal_sequence_length(argnum=1) + def forward(self, sequences: List[str]) -> List[float]: + out = self.fn(sequences, **self.call_args) + return out + + +class Jaro(Metric): + """ + Jaro distance class. Inherits from Metric. Adapted from [2]. 
+ + Examples + -------- + >>> metric = Jaro() + >>> sequences = ['CASSLKPNTEAFF', 'CASSAHIANYGYTF', 'CASRGATETQYF'] + >>> distances = metric(sequences) + + References + ---------- + [1] Jaro, M.A., 1989. Advances in record-linkage methodology as applied to matching the 1985 census of Tampa, + Florida. Journal of the American Statistical Association, 84(406), pp.414-420. + [2] Van der Loo, M.P., 2014. The stringdist package for approximate string matching. R J., 6(1), p.111. + """ + @check_jaro_weights + def __init__(self, jaro_weights: List[float] = None): + self.call_args = { + 'jaro_weights': jaro_weights + } + self.fn = C.jaro + + def forward(self, sequences: List[str]) -> List[float]: + out = self.fn(sequences, **self.call_args) + return out + + +class JaroWinkler(Jaro): + """ + Jaro-Winkler distance class. Inherits from Jaro. + + Examples + -------- + >>> metric = JaroWinkler(p=0.10) + >>> sequences = ['CASSLKPNTEAFF', 'CASSAHIANYGYTF', 'CASRGATETQYF'] + >>> distances = metric(sequences) + + References + ---------- + [1] Winkler, W.E., 1990. String comparator metrics and enhanced decision rules in the Fellegi-Sunter model of record + linkage. + + """ + @check_jaro_winkler_params + def __init__(self, p: float, max_l: int = 4, jaro_weights: List[float] = None): + super(JaroWinkler, self).__init__(jaro_weights) + self.call_args['p'] = p + self.call_args['max_l'] = max_l + self.fn = C.jaro_winkler diff --git a/src/setriq/modules/single_dispatch.py b/src/setriq/modules/single_dispatch.py new file mode 100644 index 0000000..1bdc044 --- /dev/null +++ b/src/setriq/modules/single_dispatch.py @@ -0,0 +1,254 @@ +""" +single_dispatch +=============== + +Functional implementations of the normal Metric classes, with the difference that the comparison happens only between +two strings. This can be useful for integration with other tools such as PySpark, where we want to access the fast +distance function implementations, without the forced pairwise comparisons. 
+ +Examples +-------- +An example for computing the pairwise sequence distances using PySpark and setriq: +>>> from pyspark.sql import SparkSession +>>> from pyspark.sql.functions import udf +>>> from pyspark.sql.types import DoubleType +>>> +>>> spark = SparkSession \ +... .builder \ +... .appName("setriq-spark") \ +... .getOrCreate() +>>> +>>> df = spark.createDataFrame([('CASSLKPNTEAFF',), ('CASSAHIANYGYTF',), ('CASRGATETQYF',)], ['sequence']) +>>> df = df.withColumnRenamed('sequence', 'a').crossJoin(df.withColumnRenamed('sequence', 'b')) +>>> +>>> lev_udf = udf(levenshtein, returnType=DoubleType()) # single dispatch levenshtein distance +>>> df = df.withColumn('distance', lev_udf('a', 'b')) +>>> df.show() ++--------------+--------------+--------+ +| a| b|distance| ++--------------+--------------+--------+ +| CASSLKPNTEAFF| CASSLKPNTEAFF| 0.0| +| CASSLKPNTEAFF|CASSAHIANYGYTF| 8.0| +| CASSLKPNTEAFF| CASRGATETQYF| 8.0| +|CASSAHIANYGYTF| CASSLKPNTEAFF| 8.0| +|CASSAHIANYGYTF|CASSAHIANYGYTF| 0.0| +|CASSAHIANYGYTF| CASRGATETQYF| 9.0| +| CASRGATETQYF| CASSLKPNTEAFF| 8.0| +| CASRGATETQYF|CASSAHIANYGYTF| 9.0| +| CASRGATETQYF| CASRGATETQYF| 0.0| ++--------------+--------------+--------+ + +""" + +from typing import List + +import setriq._C as C +from .substitution import SubstitutionMatrix, BLOSUM45 +from .utils import ( + ensure_equal_sequence_length_sd, + single_dispatch, + tcr_dist_sd_component_check, + check_jaro_weights, + check_jaro_winkler_params +) + +__all__ = [ + 'cdr_dist', + 'levenshtein', + 'tcr_dist', + 'hamming', + 'jaro', + 'jaro_winkler' +] + + +@single_dispatch +def cdr_dist(a: str, b: str, + substitution_matrix: SubstitutionMatrix = BLOSUM45, + gap_opening_penalty: float = 10.0, + gap_extension_penalty: float = 1.0) -> float: + """ + Compute the CDRdist [1] metric between two sequences. + + {params} + substitution_matrix: SubstitutionMatrix + A substitution matrix object to inform the alignment scoring. 
+ gap_opening_penalty: float + The penalty given to an alignment based on a gap opening. Values other than 0 give affine gap scoring. + (default=10.0) + gap_extension_penalty: float + The penalty used to score the extension of the gap after opening. (default=1.0) + + {returns} + + Examples + -------- + >>> seq = ('AASQ', 'PASQ') + >>> cdr_dist(*seq) # default params + >>> cdr_dist(*seq, substitution_matrix=BLOSUM45, # custom params + ... gap_opening_penalty=5.0, gap_extension_penalty=2.0) + + References + ---------- + [1] Thakkar, N. and Bailey-Kellogg, C., 2019. Balancing sensitivity and specificity in distinguishing TCR groups by + CDR sequence similarity. BMC bioinformatics, 20(1), pp.1-14. (https://doi.org/10.1186/s12859-019-2864-8) + + """ + distance = C.cdr_dist_sd(a, b, **substitution_matrix, + gap_opening_penalty=gap_opening_penalty, + gap_extension_penalty=gap_extension_penalty) + return distance + + +@single_dispatch +def levenshtein(a: str, b: str, extra_cost: float = 0.0) -> float: + """ + Compute the Levenshtein distance [1] between two sequences. Based on the implementation in [2]. + + {params} + extra_cost: float + Additional cost assigned by Levenshtein algorithm. + + {returns} + + Examples + -------- + >>> levenshtein('AASQ', 'PASQ') + + References + ---------- + [1] Levenshtein, V.I., 1966, February. Binary codes capable of correcting deletions, insertions, and reversals. In + Soviet physics doklady (Vol. 10, No. 8, pp. 707-710). + [2] python-Levenshtein (https://github.com/ztane/python-Levenshtein) + + """ + distance = C.levenshtein_sd(a, b, extra_cost=extra_cost) + return distance + + +@single_dispatch +@ensure_equal_sequence_length_sd +def tcr_dist_component(a: str, b: str, + substitution_matrix: SubstitutionMatrix, + gap_penalty: float, + gap_symbol: str = '-', + weight: float = 1.) 
-> float: + distance = C.tcr_dist_component_sd(a, b, + **substitution_matrix, + gap_penalty=gap_penalty, + gap_symbol=gap_symbol, + weight=weight) + return distance + + +@tcr_dist_sd_component_check +def tcr_dist(a: dict, b: dict, **component_def) -> float: + """ + Compute the TCRdist metric between two sequences. + + Parameters + ---------- + a: dict + + b: dict + component_def + + Returns + ------- + distance: float + The distance between the two sequences. + + References + ---------- + [1] Dash, P., Fiore-Gartland, A.J., Hertz, T., Wang, G.C., Sharma, S., Souquette, A., Crawford, J.C., Clemens, E.B., + Nguyen, T.H., Kedzierska, K. and La Gruta, N.L., 2017. Quantifiable predictive features define + epitope-specific T cell receptor repertoires. Nature, 547(7661), pp.89-93. (https://doi.org/10.1038/nature22383) + + """ + distance = 0.0 + for name, component in component_def.items(): + distance += tcr_dist_component(a[name], b[name], **component_def[name]) + return distance + + +@single_dispatch +@ensure_equal_sequence_length_sd +def hamming(a: str, b: str, mismatch_score: float = 1.0) -> float: + """ + Compute the Hamming [1] distance between two sequences. + + {params} + mismatch_score: float + The weight given to a mismatch between the two sequence positions. + + {returns} + + Examples + -------- + >>> hamming('AASQ', 'PASQ') + >>> hamming('AASQ', 'PAS') # error! different length sequences + + References + ---------- + [1] ... + + """ + distance = C.hamming_sd(a, b, mismatch_score=mismatch_score) + return distance + + +@single_dispatch +@check_jaro_weights +def jaro(a: str, b: str, jaro_weights: List[float] = None) -> float: + """ + Compute the Jaro [1] distance between two sequences. Adapted from [2]. + + {params} + jaro_weights: List[float] + + {returns} + + Examples + -------- + >>> jaro('AASQ', 'PASQ') + + References + ---------- + [1] Jaro, M.A., 1989. Advances in record-linkage methodology as applied to matching the 1985 census of Tampa, + Florida. 
Journal of the American Statistical Association, 84(406), pp.414-420. + [2] Van der Loo, M.P., 2014. The stringdist package for approximate string matching. R J., 6(1), p.111. + + """ + distance = C.jaro_sd(a, b, jaro_weights=jaro_weights) + return distance + + +@single_dispatch +@check_jaro_weights +@check_jaro_winkler_params +def jaro_winkler(a: str, b: str, p: float, max_l: int = 4, jaro_weights: List[float] = None) -> float: + """ + Compute the Jaro-Winkler [1] distance between two sequences. + + {params} + p: float + The scaling factor applied to the common prefix re-weighting. The value needs to be in the range [0.0, 0.25]. If + set to 0.0, Jaro-Winkler reduces down to Jaro. + max_l: int + The maximum length common prefix. (default=4) + jaro_weights: List[float] + + {returns} + + Examples + -------- + >>> jaro_winkler('AASQ', 'PASQ', p=0.10) + + References + ---------- + [1] Winkler, W.E., 1990. String comparator metrics and enhanced decision rules in the Fellegi-Sunter model of record + linkage. + + """ + distance = C.jaro_winkler_sd(a, b, p=p, max_l=max_l, jaro_weights=jaro_weights) + return distance diff --git a/src/setriq/modules/utils.py b/src/setriq/modules/utils.py new file mode 100644 index 0000000..28aad7f --- /dev/null +++ b/src/setriq/modules/utils.py @@ -0,0 +1,350 @@ +""" +Package utilities. Not meant for outside use. 
+"""
+
+import enum
+import inspect
+from functools import wraps, WRAPPER_ASSIGNMENTS
+from typing import Callable, Iterable
+
+from .substitution import SubstitutionMatrix, BLOSUM62
+
+__all__ = [
+    'enforce_list',
+    'ensure_equal_sequence_length',
+    'single_dispatch',
+    'tcr_dist_sd_component_check',
+    'ensure_equal_sequence_length_sd',
+    'check_jaro_weights',
+    'check_jaro_winkler_params',
+    'TCR_DIST_DEFAULT'
+]
+
+WRAPPER_ASSIGNMENTS = (*WRAPPER_ASSIGNMENTS, '__signature__')
+TCR_DIST_DEFAULT = [
+    ('cdr_1', {'substitution_matrix': BLOSUM62, 'gap_penalty': 4., 'weight': 1.}),
+    ('cdr_2', {'substitution_matrix': BLOSUM62, 'gap_penalty': 4., 'weight': 1.}),
+    ('cdr_2_5', {'substitution_matrix': BLOSUM62, 'gap_penalty': 4., 'weight': 1.}),
+    ('cdr_3', {'substitution_matrix': BLOSUM62, 'gap_penalty': 8., 'weight': 3.})
+]
+
+
+class Argument(enum.Enum):
+    POSITIONAL = 0
+    KEYWORD = 1
+    DEFAULT = 2
+
+
+def _get_argument_index_from_argnum(n_params: int, argnum: int):
+    """Normalise a possibly negative `argnum` into a non-negative index, given the total number of parameters."""
+    if argnum < 0:
+        return argnum + n_params
+    return argnum
+
+
+def _get_argument(params, argname: str, argnum: int, args: tuple, kwargs: dict):
+    """Fetch one argument of a call: looked up by keyword first, then by position, then from the signature default.
+    Any argument may be passed positionally or by keyword (unless declared otherwise), or omitted entirely, so
+    all three cases are handled and the place it was found is returned alongside the value."""
+    if argname in kwargs:
+        # we also return where we found the argument, i.e. the enum describes the argument type. This helps us
+        # downstream, if we want to put an augmented argument back into args, kwargs
+        return kwargs[argname], Argument.KEYWORD
+    if len(args) > argnum:
+        return args[argnum], Argument.POSITIONAL
+    return params[argname].default, Argument.DEFAULT
+
+
+def _put_argument(params, argname: str, argnum: int, arg_type: enum.Enum, argval, args: tuple, kwargs: dict):
+    """Complement of `_get_argument`: write `argval` back into `args` / `kwargs` where the argument was found."""
+    if arg_type == Argument.POSITIONAL:
+        args = tuple(arg if argnum != idx else argval for idx, arg in enumerate(args))
+    elif arg_type == Argument.KEYWORD:
+        kwargs[argname] = argval
+    elif arg_type == Argument.DEFAULT:
+        if params[argname].default != argval:
+            kwargs[argname] = argval
+    return args, kwargs
+
+
+def _get_func_argument_info(fn: Callable, argnum: int):
+    signature = inspect.signature(fn)
+    params = signature.parameters
+    argname = list(params)[argnum]
+
+    return signature, params, argname
+
+
+def _get_func_argument_info_from_name(fn: Callable, argname: str):
+    signature = inspect.signature(fn)
+    params = signature.parameters
+    argnum = list(params).index(argname)
+
+    return signature, params, argnum
+
+
+def _add_func_signature(fn: Callable, signature: inspect.Signature) -> Callable:
+    if not hasattr(fn, '__signature__'):
+        fn.__signature__ = None
+    fn.__signature__ = (fn.__signature__ or signature)
+    return fn
+
+
+def enforce_list(argnum: int = 0, convert_iterable: bool = True):
+    """
+    Enforce that a specified argument is always passed as a list to a given function. This is a decorator factory.
+
+    Parameters
+    ----------
+    argnum: int
+        The positional (integer) index of the argument to be forced into list format. Works like regular positional
+        indexing. (default=0)
+    convert_iterable: bool
+        Boolean defining whether to force convert any iterable (except `str`) into a list.
+
+    Returns
+    -------
+    decorator: Callable
+        A new decorator function, which can be used to decorate an arbitrary function / method.
+
+    Examples
+    --------
+    A basic example:
+    >>> @enforce_list()
+    ... def f(x):
+    ...     return x * 3
+    >>>
+    >>> f([3])
+    ... [3, 3, 3]
+    >>> f(3)
+    ... [3, 3, 3]
+    Notice, that we omit the `argnum` parameter here. This is because `argnum` is 0 by default, i.e. it looks at the
+    first argument passed to `f`. Note, that `enforce_list` needs to be called before decorating a function.
+
+    Enforce list can also be composed arbitrarily, to enforce multiple arguments to be lists:
+    >>> @enforce_list(argnum=0)
+    ... @enforce_list(argnum=1)
+    ... def f(x, y):
+    ...     return x + y
+
+    Finally, `enforce_list` can also force convert other iterables (excluding str) into lists:
+    >>> @enforce_list(argnum=0, convert_iterable=True)
+    ... def f(x):
+    ...     return x * 3
+    >>>
+    >>> x = np.array([1.])
+    >>> f(x)
+    ... [1., 1., 1.]
+
+    """
+    if callable(argnum):
+        raise TypeError(f'Make sure to call {repr(enforce_list.__name__)} before decorating.')
+
+    def decorator(fn):
+        signature, params, argname = _get_func_argument_info(fn, argnum)
+        argidx = _get_argument_index_from_argnum(len(params), argnum)
+        fn = _add_func_signature(fn, signature)
+
+        @wraps(fn, assigned=WRAPPER_ASSIGNMENTS)
+        def _fn(*args, **kwargs):
+            argument, arg_type = _get_argument(params, argname, argidx, args, kwargs)
+            if isinstance(argument, Iterable) and not isinstance(argument, str):
+                if not isinstance(argument, list) and convert_iterable:
+                    argument = list(argument)
+            else:
+                argument = [argument]  # wrap the value itself (scalars and `str`), not the parameter's name
+            args, kwargs = _put_argument(params, argname, argidx, arg_type, argument, args, kwargs)
+            out = fn(*args, **kwargs)
+            return out
+
+        return _fn
+
+    return decorator
+
+
+def ensure_equal_sequence_length(argnum: int):
+    """
+    Ensure that all input sequences are of equal length.
+
+    Parameters
+    ----------
+    argnum: int
+        The positional (integer) index of the argument holding the sequences to check. Works like regular positional
+        indexing. This parameter is required.
+
+    Returns
+    -------
+    decorator: Callable
+        The decorator used to wrap the function where all sequences ought to be of equal length.
+
+    Examples
+    --------
+    >>> @ensure_equal_sequence_length(argnum=0)
+    ... def f(sequences):
+    ...     return sequences
+    >>>
+    >>> a = ['AASQ', 'PWSQ']  # sequences of equal length
+    >>> b = ['GAT', 'AAFFD']  # sequences with varying length
+    >>>
+    >>> f(a)  # no error
+    >>> f(b)  # error!
+    """
+
+    def decorator(fn):
+        signature, params, argname = _get_func_argument_info(fn, argnum)
+        argidx = _get_argument_index_from_argnum(len(params), argnum)
+        fn = _add_func_signature(fn, signature)
+
+        @wraps(fn, assigned=WRAPPER_ASSIGNMENTS)
+        def _fn(*args, **kwargs):
+            argument, _ = _get_argument(params, argname, argidx, args, kwargs)
+            lengths = {len(sequence) for sequence in argument}
+            if len(lengths) > 1:  # an empty input has no lengths at all and is accepted
+                raise ValueError('Sequences must be of equal length')
+            out = fn(*args, **kwargs)
+            return out
+
+        return _fn
+
+    return decorator
+
+
+def ensure_equal_sequence_length_sd(fn: Callable) -> Callable:
+    signature = inspect.signature(fn)
+    fn = _add_func_signature(fn, signature)
+
+    add_doc = """Note
+    ----
+    `a` and `b` must be of equal length.
+    """
+
+    fn.__doc__ = (fn.__doc__ or '') + add_doc
+
+    @wraps(fn, assigned=WRAPPER_ASSIGNMENTS)
+    def _fn(a, b, *args, **kwargs):
+        if len(a) != len(b):
+            raise ValueError('Sequences must be of equal length')
+        out = fn(a, b, *args, **kwargs)
+        return out
+
+    return _fn
+
+
+def single_dispatch(fn: Callable) -> Callable:
+    signature = inspect.signature(fn)
+    fn = _add_func_signature(fn, signature)
+
+    empty_doc = f"""
+    Compute the `{fn.__name__}` metric between two sequences.
+
+    {{params}}
+
+    {{returns}}
+    """
+
+    param_doc = """
+    Parameters
+    ----------
+    a: str
+        A sequence to be compared.
+    b: str
+        A sequence to be compared."""
+
+    return_doc = """Returns
+    -------
+    distance: float
+        The computed distance between the two sequences."""
+
+    fn.__doc__ = (fn.__doc__ or empty_doc).format(params=param_doc, returns=return_doc)
+
+    @wraps(fn, assigned=WRAPPER_ASSIGNMENTS)
+    def _fn(a, b, *args, **kwargs):
+        if any(not isinstance(sequence, str) for sequence in (a, b)):
+            raise TypeError('`a` and `b` must be of type str')
+        out = fn(a, b, *args, **kwargs)
+        return out
+
+    return _fn
+
+
+def tcr_dist_sd_component_check(fn):
+    """Decorator for the single dispatch tcr_dist function; validates the component definitions passed to it."""
+    import os
+
+    signature = inspect.signature(fn)
+    fn = _add_func_signature(fn, signature)
+
+    def _check_component(name, component: dict):
+        essential_keys = ['substitution_matrix', 'gap_penalty']
+        optional_keys = ['gap_symbol', 'weight']
+        missing_keys = set(essential_keys).difference(component)
+        if missing_keys:
+            msg = ', '.join(map(repr, sorted(missing_keys)))  # sorted, so that the message is deterministic
+            raise ValueError(f'missing keys in component def {repr(name)}: {msg}')
+
+        init_types = [SubstitutionMatrix, (float, int), str, (float, int)]
+        for key, _type in zip(essential_keys + optional_keys, init_types):
+            elem = component.get(key)
+            if elem is not None and not isinstance(elem, _type):
+                given_type = type(elem)
+                raise TypeError(f'{repr(key)} needs to be of type {repr(_type)}, not {repr(given_type)}')
+
+    @wraps(fn, assigned=WRAPPER_ASSIGNMENTS)
+    def _fn(a, b, **component_def):
+        if not os.environ.get('SKIP_TCR_DIST_COMPONENT_CHECK'):
+            for name, component in component_def.items():
+                _check_component(name, component)
+        if not component_def:
+            component_def = dict(TCR_DIST_DEFAULT)
+        if set(component_def).difference(set(a).union(b)):
+            raise ValueError('key mismatch between payloads (`a` and `b`) and defined components.')
+        out = fn(a, b, **component_def)
+        return out
+
+    return _fn
+
+
+def check_jaro_weights(fn: Callable):
+    """Decorator validating `jaro_weights` (length 3, summing to 1.0); `None` is replaced by uniform weights."""
+    argname = 'jaro_weights'
+    signature, params, argnum = _get_func_argument_info_from_name(fn, argname)
+    fn = _add_func_signature(fn, signature)
+
+    @wraps(fn, assigned=WRAPPER_ASSIGNMENTS)
+    def _fn(*args, **kwargs):
+        jaro_weights, arg_type = _get_argument(params, argname, argnum, args, kwargs)
+        if jaro_weights is None:
+            jaro_weights = [1 / 3] * 3
+        if len(jaro_weights) != 3:
+            raise ValueError('`jaro_weights` has to be of length 3')
+        if abs(sum(jaro_weights) - 1.0) > 1e-9:  # tolerance: sums of floats are not exact in general
+            raise ValueError('`jaro_weights` has to sum to 1.0')
+
+        args, kwargs = _put_argument(params, argname, argnum, arg_type, jaro_weights, args, kwargs)
+        out = fn(*args, **kwargs)
+        return out
+
+    return _fn
+
+
+def check_jaro_winkler_params(fn: Callable):
+    """Decorator validating the Jaro-Winkler parameters: `p` in [0.0, 0.25] and `max_l` a non-negative integer."""
+    argname_p = 'p'
+    argname_l = 'max_l'
+    signature, params, argnum_p = _get_func_argument_info_from_name(fn, argname_p)
+    _, _, argnum_l = _get_func_argument_info_from_name(fn, argname_l)
+    fn = _add_func_signature(fn, signature)
+
+    @wraps(fn, assigned=WRAPPER_ASSIGNMENTS)
+    def _fn(*args, **kwargs):
+        p, _ = _get_argument(params, argname_p, argnum_p, args, kwargs)
+        max_l, _ = _get_argument(params, argname_l, argnum_l, args, kwargs)
+
+        if not (0.0 <= p <= 0.25):
+            raise ValueError('`p` must be in range [0.0, 0.25]')
+        if not isinstance(max_l, int) or max_l < 0:  # type check first, so a non-number cannot fail on `<`
+            raise ValueError('`max_l` must be a non-negative integer')
+        out = fn(*args, **kwargs)
+        return out
+
+    return _fn
diff --git a/tests/test_distances.py b/tests/test_distances.py
index bb9a975..685be1f 100644
--- a/tests/test_distances.py
+++ b/tests/test_distances.py
@@ -38,6 +38,20 @@
     [dc.Decimal('0.0')],
 ]
 
+hamming_results = [
+    [dc.Decimal('1.0')],
+    [dc.Decimal('2.0'), dc.Decimal('3.0'), dc.Decimal('3.0')],
+    [dc.Decimal('0.0')],
+]
+
+jaro_results = [
+    [dc.Decimal('0.1667')],
+    [dc.Decimal('0.4444'), dc.Decimal('1.0'), dc.Decimal('1.0')],
+    [dc.Decimal('0.0')],
+]
+
+jaro_winkler_results = jaro_results  # change this in the future
+
 
 # ------ Fixtures ---------------------------------------------------------------------------------------------------- #
 @pytest.fixture()
@@ -204,3 +218,39 @@ def test_tcr_dist_custom_error():
     butterflies = {'wings': 'beat'}
     with pytest.raises(TypeError):
         setriq.TcrDist(rainbows=rainbows, butterflies=butterflies)
+
+
+@pytest.mark.parametrize(['sequences', 'distances'], zip(test_cases, hamming_results))
+def test_hamming(sequences, distances):
+    metric = setriq.Hamming()
+    response = metric(sequences)
+
+    n = len(sequences)
+    assert len(response) == (n * (n - 1) / 2)
+
+    res = response_to_decimal(response)
+    assert all(r == tgt for r, tgt in zip(res, distances))
+
+
+@pytest.mark.parametrize(['sequences', 'distances'], zip(test_cases, jaro_results))
+def test_jaro(sequences, distances):
+    metric = setriq.Jaro()
+    response = metric(sequences)
+
+    n = len(sequences)
+    assert len(response) == (n * (n - 1) / 2)
+
+    res = response_to_decimal(response)
+    assert all(r == tgt for r, tgt in zip(res, distances))
+
+
+@pytest.mark.parametrize(['sequences', 'distances'], zip(test_cases, jaro_winkler_results))
+def test_jaro_winkler(sequences, distances):
+    metric = setriq.JaroWinkler(p=0.10)
+    response = metric(sequences)
+
+    n = len(sequences)
+    assert len(response) == (n * (n - 1) / 2)
+
+    res = response_to_decimal(response)
+    assert all(r == tgt for r, tgt in zip(res, distances))
diff --git a/tests/test_single_dispatch.py b/tests/test_single_dispatch.py
new file mode 100644
index 0000000..2addc1d
--- /dev/null
+++ b/tests/test_single_dispatch.py
@@ -0,0 +1,88 @@
+import pytest
+
+from setriq import BLOSUM62
+from setriq import single_dispatch
+
+
+class Cases:
+    SEQUENCES = [
+        ('CASSLKPNTEAFF', 'CASSAHIANYGYTF'),
+        ('CASSLKPNTEAFF', 'CASRGATETQYF'),
+        ('CASSAHIANYGYTF', 'CASRGATETQYF')
+    ]
+    EQUAL_SEQUENCE_LENGTH = [
+        ('AASQ', 'PASQ'),
+        ('GTA', 'HLA'),
+        ('SEQVENCES', 'SEQVENCES')
+    ]
+
+
+class Results:
+    CDR_DIST = [
+        0.7121679380884349,
+        0.6498905737037513,
+        0.75209911355881
+    ]
+    LEVENSHTEIN = [
+        8.0,
+        8.0,
+        9.0
+    ]
+    TCR_DIST = [
+        4.0,
+        8.0,
+        0.0
+    ]
+    HAMMING = [
+        1.0,
+        2.0,
+        0.0
+    ]
+    JARO = [
+        0.3335622710622711,
+        0.3641636141636142,
+        0.3373015873015873
+    ]
+    JARO_WINKLER = [
+        0.20013736263736265,
+        0.25491452991452995,
+        0.2361111111111111
+    ]
+
+
+@pytest.mark.parametrize(['sequences', 'distance'], zip(Cases.SEQUENCES, Results.CDR_DIST))
+def test_cdr_dist(sequences, distance):
+    result = single_dispatch.cdr_dist(*sequences)
+    assert result == pytest.approx(distance)
+
+
+@pytest.mark.parametrize(['sequences', 'distance'], zip(Cases.SEQUENCES, Results.LEVENSHTEIN))
+def test_levenshtein(sequences, distance):
+    result = single_dispatch.levenshtein(*sequences)
+    assert result == pytest.approx(distance)
+
+
+@pytest.mark.parametrize(['sequences', 'distance'], zip(Cases.EQUAL_SEQUENCE_LENGTH, Results.TCR_DIST))
+def test_tcr_dist(sequences, distance):
+    a, b = sequences
+    component = {'substitution_matrix': BLOSUM62, 'gap_penalty': 4.0}
+    result = single_dispatch.tcr_dist({'cmp_1': a}, {'cmp_1': b}, cmp_1=component)
+    assert result == pytest.approx(distance)
+
+
+@pytest.mark.parametrize(['sequences', 'distance'], zip(Cases.EQUAL_SEQUENCE_LENGTH, Results.HAMMING))
+def test_hamming(sequences, distance):
+    result = single_dispatch.hamming(*sequences)
+    assert result == pytest.approx(distance)
+
+
+@pytest.mark.parametrize(['sequences', 'distance'], zip(Cases.SEQUENCES, Results.JARO))
+def test_jaro(sequences, distance):
+    result = single_dispatch.jaro(*sequences)
+    assert result == pytest.approx(distance)
+
+
+@pytest.mark.parametrize(['sequences', 'distance'], zip(Cases.SEQUENCES, Results.JARO_WINKLER))
+def test_jaro_winkler(sequences, distance):
+    result = single_dispatch.jaro_winkler(*sequences, p=0.10)
+    assert result == pytest.approx(distance)
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..555d1e5
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,109 @@
+import pytest
+
+import numpy as np
+
+from setriq.modules import utils
+
+
+class Cases:
+    ENFORCE_LIST = [
+        1, [], [1, 2, 3], "hello world", np.array([1, 2, 3])
+    ]
+    ENSURE_SEQ_LEN = [
+        ['AASQ', 'PSQ'],
+        ['GTA', 'LA', 'KKR'],
+        ['SQVENCES', 'SEQVENCES']
+    ]
+    SINGLE_DISPATCH = [
+        (1, 2), ('hello', 3), ([], [])
+    ]
+    TCR_DIST_COMPONENT = [
+        ({}, ValueError),  # checks the default error
+        ({'comp_1': {'gap_penalty': 1.0}}, ValueError),  # checks missing keys
+        ({'comp_1': {'substitution_matrix': [], 'gap_penalty': 1.0}}, TypeError)
+    ]
+    JARO_WEIGHTS = [
+        ([1 / 3] * 3,),
+        (None,),
+        ([0.5, 0.3, 0.2],)
+    ]
+    JARO_WINKLER_PARAMS = [
+        ({'p': 0.50, 'max_l': 4}, r'`p` must be in range \[0\.0, 0\.25\]'),
+        ({'p': 0.10, 'max_l': -1}, r'`max_l` must be a non-negative integer')
+    ]
+
+
+@pytest.mark.parametrize('test_case', Cases.ENFORCE_LIST)
+def test_enforce_list(test_case):
+    @utils.enforce_list()
+    def f(x):
+        return x
+
+    assert f(test_case) == ([test_case] if isinstance(test_case, (int, str)) else list(test_case))
+
+
+@pytest.mark.parametrize('test_case', Cases.ENSURE_SEQ_LEN)
+def test_ensure_equal_sequence_length(test_case):
+    @utils.ensure_equal_sequence_length(argnum=0)
+    def f(x):
+        return x
+
+    with pytest.raises(ValueError):
+        f(test_case)
+
+
+@pytest.mark.parametrize('test_case', Cases.SINGLE_DISPATCH)
+def test_single_dispatch(test_case):
+    @utils.single_dispatch
+    def f(a, b):
+        return a, b
+
+    with pytest.raises(TypeError):
+        f(*test_case)
+
+
+@pytest.mark.parametrize(['arguments', 'exception'], Cases.TCR_DIST_COMPONENT)
+def test_tcr_dist_component_check_sd(arguments, exception):
+    @utils.tcr_dist_sd_component_check
+    def f(a, b, **component_df):
+        return a, b, component_df
+
+    with pytest.raises(exception):
+        f({'comp_1': ''}, {'comp_1': ''}, **arguments)
+
+
+@pytest.mark.parametrize('test_case', Cases.ENSURE_SEQ_LEN)
+def test_ensure_equal_sequence_length_sd(test_case):
+    @utils.ensure_equal_sequence_length_sd
+    def f(a, b):
+        return a, b
+
+    with pytest.raises(ValueError):
+        f(*test_case[:2])
+
+
+@pytest.mark.parametrize('test_case', Cases.JARO_WEIGHTS)
+def test_check_jaro_weights(test_case):
+    @utils.check_jaro_weights
+    def f(a, b, jaro_weights=None):
+        return jaro_weights
+
+    arg, = test_case
+    assert f('', '', *test_case) == (arg or [1 / 3] * 3)
+    assert f('', '', jaro_weights=arg) == (arg or [1 / 3] * 3)
+
+    with pytest.raises(ValueError, match='`jaro_weights` has to be of length 3'):
+        f('', '', (arg or [1 / 3] * 3)[:2])
+
+    with pytest.raises(ValueError, match='`jaro_weights` has to sum to 1.0'):
+        f('', '', [1] * 3 if not arg else [elem + 1 for elem in arg])
+
+
+@pytest.mark.parametrize(['arguments', 'exception_message'], Cases.JARO_WINKLER_PARAMS)
+def test_check_jaro_winkler_params(arguments, exception_message):
+    @utils.check_jaro_winkler_params
+    def f(p, max_l):
+        return p, max_l
+
+    with pytest.raises(ValueError, match=exception_message):
+        f(**arguments)