diff --git a/tests/test_distances.py b/tests/test_distances.py index 42a1af0..de12b2a 100644 --- a/tests/test_distances.py +++ b/tests/test_distances.py @@ -69,6 +69,11 @@ def _method(): return _method +@pytest.fixture() +def tcr_dist_keys(): + return ['cdr_1', 'cdr_2', 'cdr_2_5', 'cdr_3'] + + @pytest.fixture() def tcr_dist_base(): def _method(): @@ -158,8 +163,11 @@ def test_tcr_dist_component(tcr_dist_component, sequences, distances): @pytest.mark.parametrize(['sequences', 'distances'], convert_to_tcr_dist_format(test_cases, tcr_dist_results)) -def test_tcr_dist(tcr_dist_base, sequences, distances): +def test_tcr_dist(tcr_dist_base, tcr_dist_keys, sequences, distances): metric = tcr_dist_base() + assert metric.required_input_keys == tcr_dist_keys + assert all(k == key for (k, _), key in zip(metric.default_definition, tcr_dist_keys)) + response = metric(sequences) n = len(sequences) @@ -178,8 +186,10 @@ def test_tcr_dist_error(tcr_dist_base, sequences, distances): @pytest.mark.parametrize(['sequences', 'distances'], convert_to_tcr_dist_format(test_cases, tcr_dist_results)) -def test_tcr_dist_custom(tcr_dist_custom, sequences, distances): +def test_tcr_dist_custom(tcr_dist_custom, tcr_dist_keys, sequences, distances): metric = tcr_dist_custom() + assert metric.required_input_keys == tcr_dist_keys + response = metric(sequences) n = len(sequences)