diff --git a/.github/workflows/performance-test.yml b/.github/workflows/performance-test.yml index fe48e0a..1df4f2d 100644 --- a/.github/workflows/performance-test.yml +++ b/.github/workflows/performance-test.yml @@ -225,17 +225,11 @@ jobs: echo "" >> $GITHUB_STEP_SUMMARY cat performance.md >> $GITHUB_STEP_SUMMARY - - name: Commit Performance Report - if: always() + - name: Commit and Push Performance Report + if: always() && github.ref == 'refs/heads/main' run: | git config --local user.email "action@github.com" git config --local user.name "GitHub Action" git add performance.md git diff --staged --quiet || git commit -m "Update performance test results [skip ci]" - - - name: Push Performance Report - if: always() - uses: ad-m/github-push-action@master - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - branch: ${{ github.ref }} + git push origin HEAD:main diff --git a/.gitignore b/.gitignore index 3bf8cf0..375503c 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ test_results.py .vscode/settings.json temp_examples_output.txt json_block_*.json +.idea/ \ No newline at end of file diff --git a/src/test/term_info_queries_test.py b/src/test/term_info_queries_test.py index b3e7849..f8a8313 100644 --- a/src/test/term_info_queries_test.py +++ b/src/test/term_info_queries_test.py @@ -299,7 +299,7 @@ def test_term_info_serialization_neuron_class2(self): self.assertFalse("thumbnail" in serialized) self.assertTrue("references" in serialized) - self.assertEqual(7, len(serialized["references"])) + self.assertEqual(9, len(serialized["references"])) self.assertTrue("targetingSplits" in serialized) self.assertEqual(6, len(serialized["targetingSplits"])) diff --git a/src/test/test_downstream_class_connectivity.py b/src/test/test_downstream_class_connectivity.py new file mode 100644 index 0000000..2483046 --- /dev/null +++ b/src/test/test_downstream_class_connectivity.py @@ -0,0 +1,120 @@ +"""Tests for DownstreamClassConnectivity query. 
+ +Tests the query that finds downstream partner neuron classes for a given +neuron class, using the pre-indexed downstream_connectivity_query Solr field. +""" + +import pytest +import pandas as pd + +from vfbquery.vfb_queries import ( + get_downstream_class_connectivity, + DownstreamClassConnectivity_to_schema, +) + +# FBbt_00001482 = lineage NB3-2 primary interneuron — known to have +# downstream_connectivity_query data in the vfb_json Solr core. +TEST_CLASS = "FBbt_00001482" +# A class that is unlikely to have downstream connectivity data. +EMPTY_CLASS = "FBbt_00000001" + + +class TestDownstreamClassConnectivityDict: + """Tests using return_dataframe=False (dict output).""" + + @pytest.mark.integration + def test_returns_results(self): + result = get_downstream_class_connectivity( + TEST_CLASS, return_dataframe=False, force_refresh=True + ) + assert isinstance(result, dict) + assert result["count"] > 0 + assert len(result["rows"]) > 0 + + @pytest.mark.integration + def test_row_has_expected_keys(self): + result = get_downstream_class_connectivity( + TEST_CLASS, return_dataframe=False, limit=1, force_refresh=True + ) + assert result["rows"], "Expected at least one row" + row = result["rows"][0] + expected_keys = { + "id", "downstream_class", "total_n", "connected_n", + "percent_connected", "pairwise_connections", "total_weight", "avg_weight", + } + assert expected_keys.issubset(row.keys()) + + @pytest.mark.integration + def test_headers_present(self): + result = get_downstream_class_connectivity( + TEST_CLASS, return_dataframe=False, limit=1, force_refresh=True + ) + assert "headers" in result + assert "downstream_class" in result["headers"] + + @pytest.mark.integration + def test_limit_respected(self): + result = get_downstream_class_connectivity( + TEST_CLASS, return_dataframe=False, limit=3, force_refresh=True + ) + assert len(result["rows"]) <= 3 + # count should reflect total, not the limited set + assert result["count"] >= len(result["rows"]) + + 
@pytest.mark.integration + def test_empty_class_returns_zero(self): + result = get_downstream_class_connectivity( + EMPTY_CLASS, return_dataframe=False, force_refresh=True + ) + assert result["count"] == 0 + assert result["rows"] == [] + + +class TestDownstreamClassConnectivityDataFrame: + """Tests using return_dataframe=True (DataFrame output).""" + + @pytest.mark.integration + def test_returns_dataframe(self): + df = get_downstream_class_connectivity( + TEST_CLASS, return_dataframe=True, force_refresh=True + ) + assert isinstance(df, pd.DataFrame) + assert not df.empty + + @pytest.mark.integration + def test_dataframe_has_expected_columns(self): + df = get_downstream_class_connectivity( + TEST_CLASS, return_dataframe=True, limit=1, force_refresh=True + ) + expected_cols = { + "id", "downstream_class", "total_n", "connected_n", + "percent_connected", "pairwise_connections", "total_weight", "avg_weight", + } + assert expected_cols.issubset(set(df.columns)) + + @pytest.mark.integration + def test_limit_respected(self): + df = get_downstream_class_connectivity( + TEST_CLASS, return_dataframe=True, limit=5, force_refresh=True + ) + assert len(df) <= 5 + + @pytest.mark.integration + def test_empty_class_returns_empty_dataframe(self): + df = get_downstream_class_connectivity( + EMPTY_CLASS, return_dataframe=True, force_refresh=True + ) + assert isinstance(df, pd.DataFrame) + assert df.empty + + +class TestDownstreamClassConnectivitySchema: + def test_schema_generation(self): + schema = DownstreamClassConnectivity_to_schema( + "test neuron class", {"short_form": TEST_CLASS} + ) + assert schema.query == "DownstreamClassConnectivity" + assert schema.function == "get_downstream_class_connectivity" + assert schema.preview == 5 + assert "downstream_class" in schema.preview_columns + assert "percent_connected" in schema.preview_columns diff --git a/src/test/test_hierarchy.py b/src/test/test_hierarchy.py new file mode 100644 index 0000000..da66d25 --- /dev/null +++ 
b/src/test/test_hierarchy.py @@ -0,0 +1,166 @@ +"""Tests for get_hierarchy function. + +Tests the hierarchy tree builder for both part_of (brain region structure) +and subclass_of (cell type hierarchies), in both ancestor and descendant +directions. +""" + +import pytest + +from vfbquery.vfb_queries import get_hierarchy + + +# Known test terms +MUSHROOM_BODY = "FBbt_00005801" +KENYON_CELL = "FBbt_00003686" + + +class TestHierarchyValidation: + def test_invalid_relationship_raises(self): + with pytest.raises(ValueError, match="relationship"): + get_hierarchy(KENYON_CELL, relationship="invalid") + + def test_invalid_direction_raises(self): + with pytest.raises(ValueError, match="direction"): + get_hierarchy(KENYON_CELL, direction="invalid") + + +class TestSubclassOfDescendants: + @pytest.mark.integration + def test_returns_descendants(self): + result = get_hierarchy(KENYON_CELL, 'subclass_of', 'descendants', max_depth=1) + assert result['id'] == KENYON_CELL + assert result['label'] == 'Kenyon cell' + assert result['relationship'] == 'subclass_of' + assert 'descendants' in result + assert len(result['descendants']) > 0 + + @pytest.mark.integration + def test_descendants_have_id_and_label(self): + result = get_hierarchy(KENYON_CELL, 'subclass_of', 'descendants', max_depth=1) + for child in result['descendants']: + assert 'id' in child + assert 'label' in child + assert child['id'].startswith('FBbt_') + assert child['label'] != child['id'] # label should be resolved + + @pytest.mark.integration + def test_depth_1_has_no_grandchildren(self): + result = get_hierarchy(KENYON_CELL, 'subclass_of', 'descendants', max_depth=1) + for child in result['descendants']: + assert 'descendants' not in child + + @pytest.mark.integration + def test_depth_2_has_nested_children(self): + result = get_hierarchy(KENYON_CELL, 'subclass_of', 'descendants', max_depth=2) + has_grandchildren = any('descendants' in child for child in result['descendants']) + assert has_grandchildren, "At least one 
direct subclass should have its own subclasses" + + +class TestSubclassOfAncestors: + @pytest.mark.integration + def test_returns_ancestors(self): + result = get_hierarchy(KENYON_CELL, 'subclass_of', 'ancestors', max_depth=1) + assert 'ancestors' in result + assert len(result['ancestors']) > 0 + + @pytest.mark.integration + def test_ancestors_have_id_and_label(self): + result = get_hierarchy(KENYON_CELL, 'subclass_of', 'ancestors', max_depth=1) + for anc in result['ancestors']: + assert 'id' in anc + assert 'label' in anc + + @pytest.mark.integration + def test_kenyon_cell_ancestor_is_mb_intrinsic_neuron(self): + result = get_hierarchy(KENYON_CELL, 'subclass_of', 'ancestors', max_depth=1) + ancestor_ids = [a['id'] for a in result['ancestors']] + assert 'FBbt_00007484' in ancestor_ids # mushroom body intrinsic neuron + + @pytest.mark.integration + def test_depth_2_has_nested_ancestors(self): + result = get_hierarchy(KENYON_CELL, 'subclass_of', 'ancestors', max_depth=2) + has_grandparent = any('ancestors' in anc for anc in result['ancestors']) + assert has_grandparent + + +class TestPartOfDescendants: + @pytest.mark.integration + def test_returns_parts(self): + result = get_hierarchy(MUSHROOM_BODY, 'part_of', 'descendants', max_depth=1) + assert result['id'] == MUSHROOM_BODY + assert result['label'] == 'mushroom body' + assert 'descendants' in result + assert len(result['descendants']) > 0 + + @pytest.mark.integration + def test_parts_have_id_and_label(self): + result = get_hierarchy(MUSHROOM_BODY, 'part_of', 'descendants', max_depth=1) + for part in result['descendants']: + assert 'id' in part + assert 'label' in part + assert part['id'].startswith('FBbt_') + + +class TestPartOfAncestors: + @pytest.mark.integration + def test_mushroom_body_part_of_protocerebrum(self): + result = get_hierarchy(MUSHROOM_BODY, 'part_of', 'ancestors', max_depth=1) + assert 'ancestors' in result + ancestor_ids = [a['id'] for a in result['ancestors']] + assert 'FBbt_00003627' in 
ancestor_ids # protocerebrum + + +class TestDisplayOutput: + @pytest.mark.integration + def test_display_field_present(self): + result = get_hierarchy(KENYON_CELL, 'subclass_of', 'both', max_depth=1) + assert 'display' in result + assert isinstance(result['display'], str) + assert 'Kenyon cell' in result['display'] + + @pytest.mark.integration + def test_display_shows_ancestors(self): + result = get_hierarchy(KENYON_CELL, 'subclass_of', 'both', max_depth=1) + assert 'ancestors' in result['display'].lower() + assert 'mushroom body intrinsic neuron' in result['display'] + + @pytest.mark.integration + def test_display_shows_tree_connectors(self): + result = get_hierarchy(KENYON_CELL, 'subclass_of', 'descendants', max_depth=1) + assert '├──' in result['display'] or '└──' in result['display'] + + @pytest.mark.integration + def test_html_field_present(self): + result = get_hierarchy(KENYON_CELL, 'subclass_of', 'both', max_depth=1) + assert 'html' in result + assert '' in result['html'] + assert 'Kenyon cell' in result['html'] + + @pytest.mark.integration + def test_html_contains_vfb_links(self): + result = get_hierarchy(KENYON_CELL, 'subclass_of', 'descendants', max_depth=1) + assert 'virtualflybrain.org' in result['html'] + assert KENYON_CELL in result['html'] + + +class TestBothDirections: + @pytest.mark.integration + def test_both_returns_ancestors_and_descendants(self): + result = get_hierarchy(KENYON_CELL, 'subclass_of', 'both', max_depth=1) + assert 'ancestors' in result + assert 'descendants' in result + assert len(result['ancestors']) > 0 + assert len(result['descendants']) > 0 + + @pytest.mark.integration + def test_descendants_only_has_no_ancestors(self): + result = get_hierarchy(KENYON_CELL, 'subclass_of', 'descendants', max_depth=1) + assert 'descendants' in result + assert 'ancestors' not in result + + @pytest.mark.integration + def test_ancestors_only_has_no_descendants(self): + result = get_hierarchy(KENYON_CELL, 'subclass_of', 'ancestors', max_depth=1) + 
assert 'ancestors' in result + assert 'descendants' not in result diff --git a/src/test/test_neuron_neuron_connectivity.py b/src/test/test_neuron_neuron_connectivity.py index aa6ff6e..b833b93 100644 --- a/src/test/test_neuron_neuron_connectivity.py +++ b/src/test/test_neuron_neuron_connectivity.py @@ -1,89 +1,118 @@ -#!/usr/bin/env python3 -""" -Test suite for NeuronNeuronConnectivityQuery. +"""Tests for NeuronNeuronConnectivityQuery. Tests the query that finds neurons connected to a given neuron. This implements the neuron_neuron_connectivity_query from the VFB XMI specification. - -Test cases: -1. Query execution with known neuron -2. Schema generation and validation -3. Term info integration (if applicable) -4. Preview results validation """ -import unittest -import sys -import os - -# Add the src directory to the path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) +import pytest +import pandas as pd from vfbquery.vfb_queries import ( get_neuron_neuron_connectivity, NeuronNeuronConnectivityQuery_to_schema, - get_term_info ) -class NeuronNeuronConnectivityTest(unittest.TestCase): - """Test suite for neuron_neuron_connectivity_query""" - - def setUp(self): - """Set up test fixtures""" - # Test neuron: LPC1 (FlyEM-HB:1775513344) [VFB_jrchk00s] - self.test_neuron = "VFB_jrchk00s" - - def test_query_execution(self): - """Test that the query executes successfully""" - print(f"\n=== Testing neuron_neuron_connectivity_query execution ===") - result = get_neuron_neuron_connectivity(self.test_neuron, return_dataframe=False, limit=5) - self.assertIsNotNone(result, "Query should return a result") - self.assertIsInstance(result, dict, "Result should be a dictionary") - print(f"Query returned {result.get('count', 0)} results") - if 'data' in result and len(result['data']) > 0: - first_result = result['data'][0] - self.assertIn('id', first_result, "Result should contain 'id' field") - self.assertIn('label', first_result, "Result should contain 'label' 
field") - print(f"First result: {first_result.get('label', 'N/A')} ({first_result.get('id', 'N/A')})") - else: - print("No connected neurons found (this is OK if none exist)") +# VFB_jrchk00s = LPC1 (FlyEM-HB:1775513344) — known to have connectivity data. +TEST_NEURON = "VFB_jrchk00s" + + +class TestNeuronNeuronConnectivityDict: + """Tests using return_dataframe=False (dict output).""" + + @pytest.mark.integration + def test_returns_results(self): + result = get_neuron_neuron_connectivity( + TEST_NEURON, return_dataframe=False + ) + assert isinstance(result, dict) + assert result["count"] > 0 + assert len(result["rows"]) > 0 + + @pytest.mark.integration + def test_row_has_expected_keys(self): + result = get_neuron_neuron_connectivity( + TEST_NEURON, return_dataframe=False, limit=1 + ) + assert result["rows"], "Expected at least one row" + row = result["rows"][0] + expected_keys = {"id", "label", "outputs", "inputs", "tags"} + assert expected_keys.issubset(row.keys()) + + @pytest.mark.integration + def test_headers_present(self): + result = get_neuron_neuron_connectivity( + TEST_NEURON, return_dataframe=False, limit=1 + ) + assert "headers" in result + assert "label" in result["headers"] + assert "outputs" in result["headers"] + assert "inputs" in result["headers"] + + @pytest.mark.integration + def test_limit_respected(self): + result = get_neuron_neuron_connectivity( + TEST_NEURON, return_dataframe=False, limit=3 + ) + assert len(result["rows"]) <= 3 + assert result["count"] >= len(result["rows"]) + + @pytest.mark.integration + def test_direction_upstream(self): + all_result = get_neuron_neuron_connectivity( + TEST_NEURON, return_dataframe=False + ) + up_result = get_neuron_neuron_connectivity( + TEST_NEURON, return_dataframe=False, direction='upstream' + ) + assert up_result["count"] > 0 + assert up_result["count"] <= all_result["count"] + + @pytest.mark.integration + def test_direction_downstream(self): + all_result = get_neuron_neuron_connectivity( + 
TEST_NEURON, return_dataframe=False + ) + down_result = get_neuron_neuron_connectivity( + TEST_NEURON, return_dataframe=False, direction='downstream' + ) + assert down_result["count"] > 0 + assert down_result["count"] <= all_result["count"] + + +class TestNeuronNeuronConnectivityDataFrame: + """Tests using return_dataframe=True (DataFrame output).""" + + @pytest.mark.integration + def test_returns_dataframe(self): + df = get_neuron_neuron_connectivity( + TEST_NEURON, return_dataframe=True + ) + assert isinstance(df, pd.DataFrame) + assert not df.empty + + @pytest.mark.integration + def test_dataframe_has_expected_columns(self): + df = get_neuron_neuron_connectivity( + TEST_NEURON, return_dataframe=True, limit=1 + ) + expected_cols = {"id", "label", "outputs", "inputs", "tags"} + assert expected_cols.issubset(set(df.columns)) + + @pytest.mark.integration + def test_limit_respected(self): + df = get_neuron_neuron_connectivity( + TEST_NEURON, return_dataframe=True, limit=5 + ) + assert len(df) <= 5 + +class TestNeuronNeuronConnectivitySchema: def test_schema_generation(self): - """Test schema function generates correct structure""" - print(f"\n=== Testing neuron_neuron_connectivity_query schema generation ===") - test_name = "LPC1" - test_takes = {"short_form": self.test_neuron} - schema = NeuronNeuronConnectivityQuery_to_schema(test_name, test_takes) - self.assertIsNotNone(schema, "Schema should not be None") - self.assertEqual(schema.query, "NeuronNeuronConnectivityQuery", "Query name should match") - self.assertEqual(schema.label, f"Neurons connected to {test_name}", "Label should be formatted correctly") - self.assertEqual(schema.function, "get_neuron_neuron_connectivity", "Function name should match") - self.assertEqual(schema.preview, 5, "Preview should be 5") - expected_columns = ["id", "label", "outputs", "inputs", "tags"] - self.assertEqual(schema.preview_columns, expected_columns, f"Preview columns should be {expected_columns}") - print(f"Schema generated 
successfully: {schema.label}") - - def test_preview_results(self): - """Test that preview results are properly formatted""" - print(f"\n=== Testing preview results ===") - result = get_neuron_neuron_connectivity(self.test_neuron, return_dataframe=False, limit=3) - self.assertIsNotNone(result, "Query should return a result") - if 'data' in result and len(result['data']) > 0: - first_result = result['data'][0] - self.assertIn('id', first_result, "Preview result should have 'id'") - self.assertIn('label', first_result, "Preview result should have 'label'") - print(f"First preview result: {first_result.get('label', 'N/A')}") - else: - print("No preview results available (this is OK if no connected neurons exist)") - - -def run_tests(): - """Run the test suite""" - suite = unittest.TestLoader().loadTestsFromTestCase(NeuronNeuronConnectivityTest) - runner = unittest.TextTestRunner(verbosity=2) - result = runner.run(suite) - return result.wasSuccessful() - -if __name__ == '__main__': - success = run_tests() - sys.exit(0 if success else 1) + schema = NeuronNeuronConnectivityQuery_to_schema( + "LPC1", {"short_form": TEST_NEURON} + ) + assert schema.query == "NeuronNeuronConnectivityQuery" + assert schema.function == "get_neuron_neuron_connectivity" + assert schema.label == "Neurons connected to LPC1" + assert schema.preview == 5 + assert schema.preview_columns == ["id", "label", "outputs", "inputs", "tags"] diff --git a/src/test/test_neuron_region_connectivity.py b/src/test/test_neuron_region_connectivity.py index 72f0efe..88c8b75 100644 --- a/src/test/test_neuron_region_connectivity.py +++ b/src/test/test_neuron_region_connectivity.py @@ -1,117 +1,99 @@ -#!/usr/bin/env python3 -""" -Test suite for NeuronRegionConnectivityQuery. +"""Tests for NeuronRegionConnectivityQuery. -Tests the query that shows connectivity to regions from a given neuron. +Tests the query that finds brain regions where a given neuron has synaptic terminals. 
This implements the neuron_region_connectivity_query from the VFB XMI specification. - -Test cases: -1. Query execution with known neuron -2. Schema generation and validation -3. Term info integration (if applicable) -4. Preview results validation """ -import unittest -import sys -import os - -# Add the src directory to the path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) +import pytest +import pandas as pd from vfbquery.vfb_queries import ( get_neuron_region_connectivity, NeuronRegionConnectivityQuery_to_schema, - get_term_info + get_term_info, ) -class NeuronRegionConnectivityTest(unittest.TestCase): - """Test suite for neuron_region_connectivity_query""" - - def setUp(self): - """Set up test fixtures""" - # Test neuron: LPC1 (FlyEM-HB:1775513344) [VFB_jrchk00s] - self.test_neuron = "VFB_jrchk00s" - - def test_query_execution(self): - """Test that the query executes successfully""" - print(f"\n=== Testing neuron_region_connectivity_query execution ===") - result = get_neuron_region_connectivity(self.test_neuron, return_dataframe=False, limit=5) - self.assertIsNotNone(result, "Query should return a result") - self.assertIsInstance(result, dict, "Result should be a dictionary") - print(f"Query returned {result.get('count', 0)} results") - if 'data' in result and len(result['data']) > 0: - first_result = result['data'][0] - self.assertIn('id', first_result, "Result should contain 'id' field") - self.assertIn('region', first_result, "Result should contain 'region' field") - self.assertIn('presynaptic_terminals', first_result, "Result should contain 'presynaptic_terminals' field") - self.assertIn('postsynaptic_terminals', first_result, "Result should contain 'postsynaptic_terminals' field") - print(f"First result: {first_result.get('region', 'N/A')} ({first_result.get('id', 'N/A')})") - print(f" Pre: {first_result.get('presynaptic_terminals', 0)}, Post: {first_result.get('postsynaptic_terminals', 0)}") - else: - print("No regions with 
connectivity found (this is OK if none exist)") +# VFB_jrchk00s = LPC1 (FlyEM-HB:1775513344) — known to have region connectivity data. +TEST_NEURON = "VFB_jrchk00s" + + +class TestNeuronRegionConnectivityDict: + """Tests using return_dataframe=False (dict output).""" + + @pytest.mark.integration + def test_returns_results(self): + result = get_neuron_region_connectivity( + TEST_NEURON, return_dataframe=False + ) + assert isinstance(result, dict) + assert result["count"] > 0 + assert len(result["rows"]) > 0 + + @pytest.mark.integration + def test_row_has_expected_keys(self): + result = get_neuron_region_connectivity( + TEST_NEURON, return_dataframe=False, limit=1 + ) + assert result["rows"], "Expected at least one row" + row = result["rows"][0] + expected_keys = {"id", "region", "presynaptic_terminals", "postsynaptic_terminals", "tags"} + assert expected_keys.issubset(row.keys()) + + @pytest.mark.integration + def test_headers_present(self): + result = get_neuron_region_connectivity( + TEST_NEURON, return_dataframe=False, limit=1 + ) + assert "headers" in result + assert "region" in result["headers"] + assert "presynaptic_terminals" in result["headers"] + assert "postsynaptic_terminals" in result["headers"] + @pytest.mark.integration + def test_limit_respected(self): + result = get_neuron_region_connectivity( + TEST_NEURON, return_dataframe=False, limit=3 + ) + assert len(result["rows"]) <= 3 + assert result["count"] >= len(result["rows"]) + + +class TestNeuronRegionConnectivityDataFrame: + """Tests using return_dataframe=True (DataFrame output).""" + + @pytest.mark.integration + def test_returns_dataframe(self): + df = get_neuron_region_connectivity( + TEST_NEURON, return_dataframe=True + ) + assert isinstance(df, pd.DataFrame) + assert not df.empty + + @pytest.mark.integration + def test_dataframe_has_expected_columns(self): + df = get_neuron_region_connectivity( + TEST_NEURON, return_dataframe=True, limit=1 + ) + expected_cols = {"id", "region", 
"presynaptic_terminals", "postsynaptic_terminals", "tags"} + assert expected_cols.issubset(set(df.columns)) + + @pytest.mark.integration + def test_limit_respected(self): + df = get_neuron_region_connectivity( + TEST_NEURON, return_dataframe=True, limit=3 + ) + assert len(df) <= 3 + + +class TestNeuronRegionConnectivitySchema: def test_schema_generation(self): - """Test that the schema function works correctly""" - print(f"\n=== Testing NeuronRegionConnectivityQuery schema generation ===") - - # Get term info for the test neuron - term_info = get_term_info(self.test_neuron) - if term_info: - neuron_name = term_info.get('Name', self.test_neuron) - else: - neuron_name = self.test_neuron - - # Generate schema - schema = NeuronRegionConnectivityQuery_to_schema(neuron_name, self.test_neuron) - - # Validate schema structure - self.assertIsNotNone(schema, "Schema should not be None") - self.assertEqual(schema.query, "NeuronRegionConnectivityQuery", "Query name should match") - self.assertEqual(schema.function, "get_neuron_region_connectivity", "Function name should match") - self.assertEqual(schema.preview, 5, "Preview should show 5 results") - self.assertIn("region", schema.preview_columns, "Preview should include 'region' column") - self.assertIn("presynaptic_terminals", schema.preview_columns, "Preview should include 'presynaptic_terminals' column") - self.assertIn("postsynaptic_terminals", schema.preview_columns, "Preview should include 'postsynaptic_terminals' column") - - print(f"Schema label: {schema.label}") - print(f"Preview columns: {schema.preview_columns}") - - def test_term_info_integration(self): - """Test that term info lookup works for the test neuron""" - print(f"\n=== Testing term_info integration ===") - term_info = get_term_info(self.test_neuron) - - self.assertIsNotNone(term_info, "Term info should not be None") - if term_info: - # get_term_info returns a dict with 'Name', 'Id', 'Tags', etc. 
- self.assertIn('Name', term_info, "Term info should contain 'Name'") - self.assertIn('Id', term_info, "Term info should contain 'Id'") - print(f"Neuron name: {term_info.get('Name', 'N/A')}") - print(f"Neuron tags: {term_info.get('Tags', [])}") - else: - print(f"Note: Term info not found for {self.test_neuron} (may not be in SOLR)") - - def test_preview_validation(self): - """Test that preview results are properly formatted""" - print(f"\n=== Testing preview results ===") - result = get_neuron_region_connectivity(self.test_neuron, return_dataframe=False, limit=5) - - if 'data' in result and len(result['data']) > 0: - # Check that all preview columns exist in the results - expected_columns = ['id', 'region', 'presynaptic_terminals', 'postsynaptic_terminals', 'tags'] - for item in result['data']: - for col in expected_columns: - self.assertIn(col, item, f"Result should contain '{col}' field") - - print(f"✓ All {len(result['data'])} results have required preview columns") - - # Print sample results - for i, item in enumerate(result['data'][:3], 1): - print(f"{i}. 
{item.get('region', 'N/A')} - Pre:{item.get('presynaptic_terminals', 0)}, Post:{item.get('postsynaptic_terminals', 0)}") - else: - print("No preview data available (query returned no results)") - - -if __name__ == '__main__': - unittest.main(verbosity=2) + term_info = get_term_info(TEST_NEURON) + neuron_name = term_info.get('Name', TEST_NEURON) if term_info else TEST_NEURON + + schema = NeuronRegionConnectivityQuery_to_schema(neuron_name, TEST_NEURON) + assert schema.query == "NeuronRegionConnectivityQuery" + assert schema.function == "get_neuron_region_connectivity" + assert schema.preview == 5 + assert "region" in schema.preview_columns + assert "presynaptic_terminals" in schema.preview_columns + assert "postsynaptic_terminals" in schema.preview_columns diff --git a/src/test/test_upstream_class_connectivity.py b/src/test/test_upstream_class_connectivity.py new file mode 100644 index 0000000..ae59e9f --- /dev/null +++ b/src/test/test_upstream_class_connectivity.py @@ -0,0 +1,120 @@ +"""Tests for UpstreamClassConnectivity query. + +Tests the query that finds upstream partner neuron classes for a given +neuron class, using the pre-indexed upstream_connectivity_query Solr field. +""" + +import pytest +import pandas as pd + +from vfbquery.vfb_queries import ( + get_upstream_class_connectivity, + UpstreamClassConnectivity_to_schema, +) + +# FBbt_00001482 = lineage NB3-2 primary interneuron — known to have +# upstream_connectivity_query data in the vfb_json Solr core. +TEST_CLASS = "FBbt_00001482" +# A class that is unlikely to have upstream connectivity data. 
+EMPTY_CLASS = "FBbt_00000001" + + +class TestUpstreamClassConnectivityDict: + """Tests using return_dataframe=False (dict output).""" + + @pytest.mark.integration + def test_returns_results(self): + result = get_upstream_class_connectivity( + TEST_CLASS, return_dataframe=False, force_refresh=True + ) + assert isinstance(result, dict) + assert result["count"] > 0 + assert len(result["rows"]) > 0 + + @pytest.mark.integration + def test_row_has_expected_keys(self): + result = get_upstream_class_connectivity( + TEST_CLASS, return_dataframe=False, limit=1, force_refresh=True + ) + assert result["rows"], "Expected at least one row" + row = result["rows"][0] + expected_keys = { + "id", "upstream_class", "total_n", "connected_n", + "percent_connected", "pairwise_connections", "total_weight", "avg_weight", + } + assert expected_keys.issubset(row.keys()) + + @pytest.mark.integration + def test_headers_present(self): + result = get_upstream_class_connectivity( + TEST_CLASS, return_dataframe=False, limit=1, force_refresh=True + ) + assert "headers" in result + assert "upstream_class" in result["headers"] + + @pytest.mark.integration + def test_limit_respected(self): + result = get_upstream_class_connectivity( + TEST_CLASS, return_dataframe=False, limit=3, force_refresh=True + ) + assert len(result["rows"]) <= 3 + # count should reflect total, not the limited set + assert result["count"] >= len(result["rows"]) + + @pytest.mark.integration + def test_empty_class_returns_zero(self): + result = get_upstream_class_connectivity( + EMPTY_CLASS, return_dataframe=False, force_refresh=True + ) + assert result["count"] == 0 + assert result["rows"] == [] + + +class TestUpstreamClassConnectivityDataFrame: + """Tests using return_dataframe=True (DataFrame output).""" + + @pytest.mark.integration + def test_returns_dataframe(self): + df = get_upstream_class_connectivity( + TEST_CLASS, return_dataframe=True, force_refresh=True + ) + assert isinstance(df, pd.DataFrame) + assert not df.empty + 
+ @pytest.mark.integration + def test_dataframe_has_expected_columns(self): + df = get_upstream_class_connectivity( + TEST_CLASS, return_dataframe=True, limit=1, force_refresh=True + ) + expected_cols = { + "id", "upstream_class", "total_n", "connected_n", + "percent_connected", "pairwise_connections", "total_weight", "avg_weight", + } + assert expected_cols.issubset(set(df.columns)) + + @pytest.mark.integration + def test_limit_respected(self): + df = get_upstream_class_connectivity( + TEST_CLASS, return_dataframe=True, limit=5, force_refresh=True + ) + assert len(df) <= 5 + + @pytest.mark.integration + def test_empty_class_returns_empty_dataframe(self): + df = get_upstream_class_connectivity( + EMPTY_CLASS, return_dataframe=True, force_refresh=True + ) + assert isinstance(df, pd.DataFrame) + assert df.empty + + +class TestUpstreamClassConnectivitySchema: + def test_schema_generation(self): + schema = UpstreamClassConnectivity_to_schema( + "test neuron class", {"short_form": TEST_CLASS} + ) + assert schema.query == "UpstreamClassConnectivity" + assert schema.function == "get_upstream_class_connectivity" + assert schema.preview == 5 + assert "upstream_class" in schema.preview_columns + assert "percent_connected" in schema.preview_columns diff --git a/src/vfbquery/ha_api.py b/src/vfbquery/ha_api.py index 061ecdd..a091d6b 100644 --- a/src/vfbquery/ha_api.py +++ b/src/vfbquery/ha_api.py @@ -897,6 +897,70 @@ def post_fn(result): ) +def _run_get_hierarchy(short_form, relationship, direction, max_depth): + """Worker: run get_hierarchy in a subprocess.""" + from . 
import vfb_queries as _vfb + return _convert_numpy_types( + _vfb.get_hierarchy(short_form, relationship=relationship, + direction=direction, max_depth=max_depth) + ) + + +async def handle_get_hierarchy(request): + """GET /get_hierarchy?id=FBbt_00005801&relationship=part_of&direction=both&max_depth=1""" + short_form = request.query.get("id") + if not short_form: + return web.json_response({"error": "id parameter is required"}, status=400) + relationship = request.query.get("relationship", "part_of") + if relationship not in ("part_of", "subclass_of"): + return web.json_response( + {"error": "relationship must be 'part_of' or 'subclass_of'"}, status=400 + ) + direction = request.query.get("direction", "both") + if direction not in ("descendants", "ancestors", "both"): + return web.json_response( + {"error": "direction must be 'descendants', 'ancestors', or 'both'"}, status=400 + ) + max_depth = int(request.query.get("max_depth", "1")) + + key = f"get_hierarchy:{short_form}:{relationship}:{direction}:{max_depth}" + return await _dispatch_to_pool( + request, key, _run_get_hierarchy, + short_form, relationship, direction, max_depth, + ) + + +async def handle_get_hierarchy_html(request): + """GET /get_hierarchy_html?id=FBbt_00005801&relationship=part_of&direction=both&max_depth=1 + + Serves the hierarchy as a self-contained HTML page (Content-Type: text/html). 
+ """ + short_form = request.query.get("id") + if not short_form: + return web.Response(text="Error: id parameter is required", status=400) + relationship = request.query.get("relationship", "part_of") + if relationship not in ("part_of", "subclass_of"): + return web.Response(text="Error: relationship must be 'part_of' or 'subclass_of'", status=400) + direction = request.query.get("direction", "both") + if direction not in ("descendants", "ancestors", "both"): + return web.Response(text="Error: direction must be 'descendants', 'ancestors', or 'both'", status=400) + max_depth = int(request.query.get("max_depth", "1")) + + key = f"get_hierarchy:{short_form}:{relationship}:{direction}:{max_depth}" + json_response = await _dispatch_to_pool( + request, key, _run_get_hierarchy, + short_form, relationship, direction, max_depth, + ) + + # Extract HTML from the JSON result + import json as _json + result = _json.loads(json_response.body) + html = result.get("html", "") + if not html: + return web.Response(text="No hierarchy data found", status=404) + return web.Response(text=html, content_type="text/html") + + # --------------------------------------------------------------------------- # Application factory # --------------------------------------------------------------------------- @@ -937,6 +1001,8 @@ def create_app(max_workers=None, max_concurrent=None, max_queue_depth=None, app.router.add_get("/find_combo_publications", handle_find_combo_publications) app.router.add_get("/list_connectome_datasets", handle_list_connectome_datasets) app.router.add_get("/query_connectivity", handle_query_connectivity) + app.router.add_get("/get_hierarchy", handle_get_hierarchy) + app.router.add_get("/get_hierarchy_html", handle_get_hierarchy_html) # Store config for /status and handlers app["max_workers"] = max_workers diff --git a/src/vfbquery/owlery_client.py b/src/vfbquery/owlery_client.py index 895cd28..af12656 100644 --- a/src/vfbquery/owlery_client.py +++ 
b/src/vfbquery/owlery_client.py @@ -105,7 +105,7 @@ def convert_short_form_to_iri(match): # Based on VFBConnect's query() method params = { 'object': iri_query, - 'direct': 'false', # Always use indirect (transitive) queries + 'direct': 'true' if direct else 'false', 'includeDeprecated': 'false', # Exclude deprecated terms 'includeEquivalent': 'true' # Include equivalent classes } diff --git a/src/vfbquery/vfb_queries.py b/src/vfbquery/vfb_queries.py index 36b07ed..9449492 100644 --- a/src/vfbquery/vfb_queries.py +++ b/src/vfbquery/vfb_queries.py @@ -4538,3 +4538,612 @@ def process_query(query): process_query(query) return term_info + + +def get_hierarchy(short_form, relationship='part_of', direction='both', max_depth=1): + """Build a hierarchy tree showing ancestors and/or descendants of a term. + + For ``subclass_of`` descendants, all descendants are fetched in one Owlery + call (fast, cached) and the tree is reconstructed by looking up each term's + parents in SOLR. For ``part_of`` descendants, direct children are fetched + per level via Owlery ``direct=True`` (slower on first call, but results are + cached by the Owlery server). + + :param short_form: Root term ID (e.g. 
'FBbt_00005801') + :param relationship: 'part_of' for brain region structure, 'subclass_of' for cell type hierarchies + :param direction: 'descendants', 'ancestors', or 'both' + :param max_depth: Levels to expand (default 1 = direct only; -1 = unlimited) + :return: Nested dict with id, label, ancestors, descendants + """ + if relationship not in ('part_of', 'subclass_of'): + raise ValueError("relationship must be 'part_of' or 'subclass_of'") + if direction not in ('descendants', 'ancestors', 'both'): + raise ValueError("direction must be 'descendants', 'ancestors', or 'both'") + + label_cache = {} + _ont_solr = pysolr.Solr('https://solr.virtualflybrain.org/solr/ontology/', always_commit=False, timeout=30) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _batch_lookup_labels(ids): + """Fetch labels for a list of IDs from the ontology SOLR core.""" + missing = [i for i in ids if i not in label_cache] + if not missing: + return + try: + id_list = ','.join(missing) + results = _ont_solr.search( + q='*:*', + fq=f'{{!terms f=short_form}}{id_list}', + fl='short_form,label', + rows=len(missing) + ) + for doc in results.docs: + label_cache[doc.get('short_form', '')] = doc.get('label', doc.get('short_form', '')) + except Exception: + pass + for i in missing: + label_cache.setdefault(i, i) + + def _get_all_children(term_id): + """Get all descendants (transitive) using the existing cached functions.""" + if relationship == 'part_of': + result = get_parts_of(term_id, return_dataframe=False) + else: + result = get_subclasses_of(term_id, return_dataframe=False) + if not result or not result.get('rows'): + return [] + return [row['id'] for row in result['rows'] if row.get('id') and row['id'] != term_id] + + def _term_info_parents(term_id): + """Return [(parent_sf, parent_label), ...] 
from SOLR term_info.""" + try: + results = vfb_solr.search(f'id:{term_id}', fl='term_info', rows=1) + if not results.docs or 'term_info' not in results.docs[0]: + return [] + raw = results.docs[0]['term_info'] + ti = json.loads(raw[0] if isinstance(raw, list) else raw) + if relationship == 'subclass_of': + return [(p['short_form'], p.get('label', p['short_form'])) for p in ti.get('parents', [])] + else: + # part_of: BFO_0000050 in relationships + out = [] + for r in ti.get('relationships', []): + if 'BFO_0000050' in r.get('relation', {}).get('iri', ''): + obj = r['object'] + out.append((obj['short_form'], obj.get('label', obj['short_form']))) + # Fallback to Neo4j edge + if not out: + try: + cypher = ( + f"MATCH (c:Class {{short_form: '{term_id}'}})" + f"-[:part_of]->(p:Class) " + f"RETURN p.short_form AS sf, p.label AS label" + ) + for row in get_dict_cursor()(vc.nc.commit_list([cypher])): + out.append((row['sf'], row.get('label', row['sf']))) + except Exception: + pass + return out + except Exception: + return [] + + # ------------------------------------------------------------------ + # Descendants + # ------------------------------------------------------------------ + + def _build_descendants_subclass(root_id): + """Build subclass tree: one cached Owlery call + batch SOLR parent lookup.""" + all_desc = _get_all_children(root_id) + if not all_desc: + return [] + + tree_ids = set(all_desc) | {root_id} + _batch_lookup_labels(list(tree_ids)) + + # Batch-fetch parents from vfb_json SOLR + children_of = {tid: [] for tid in tree_ids} + id_list = ','.join(all_desc) + try: + results = vfb_solr.search( + q='id:*', fq=f'{{!terms f=id}}{id_list}', fl='id,term_info', rows=len(all_desc) + ) + for doc in results.docs: + child_id = doc.get('id', '') + if 'term_info' not in doc: + continue + raw = doc['term_info'] + ti = json.loads(raw[0] if isinstance(raw, list) else raw) + parents_in_tree = [p['short_form'] for p in ti.get('parents', []) if p['short_form'] in tree_ids] + if 
parents_in_tree:
+                    for pid in parents_in_tree:
+                        children_of[pid].append(child_id)
+                else:
+                    children_of[root_id].append(child_id)
+        except Exception:
+            children_of[root_id] = all_desc
+
+        def build(node_id, depth):
+            node = {'id': node_id, 'label': label_cache.get(node_id, node_id)}
+            if max_depth == -1 or depth < max_depth:
+                kids = children_of.get(node_id, [])
+                if kids:
+                    node['descendants'] = [
+                        build(k, depth + 1)
+                        for k in sorted(kids, key=lambda x: label_cache.get(x, x))
+                    ]
+            return node
+
+        top = children_of.get(root_id, [])
+        return [build(k, 1) for k in sorted(top, key=lambda x: label_cache.get(x, x))]
+
+    def _build_descendants_part_of(root_id):
+        """Build part_of descendant tree via Ubergraph SPARQL.
+
+        Queries the Ubergraph redundant graph for all transitive part_of
+        edges within the subtree, then reconstructs the nesting by finding
+        each child's most specific parent.
+        """
+        import requests as _req
+        from collections import defaultdict
+
+        root_iri = _short_form_to_iri(root_id)
+        sparql = f'''
+PREFIX BFO: <http://purl.obolibrary.org/obo/BFO_>
+PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
+SELECT DISTINCT ?child ?childLabel ?parent ?parentLabel WHERE {{
+  GRAPH <http://reasoner.renci.org/redundant> {{
+    ?child BFO:0000050 <{root_iri}> .
+    ?child BFO:0000050 ?parent .
+  }}
+  FILTER(?parent != ?child)
+  FILTER(
+    ?parent = <{root_iri}> ||
+    EXISTS {{
+      GRAPH <http://reasoner.renci.org/redundant> {{
+        ?parent BFO:0000050 <{root_iri}> .
+      }}
+    }}
+  )
+  ?child rdfs:label ?childLabel .
+  ?parent rdfs:label ?parentLabel .
+ FILTER(STRSTARTS(STR(?child), "http://purl.obolibrary.org/obo/FBbt_")) +}} +''' + try: + resp = _req.get( + 'https://ubergraph.apps.renci.org/sparql', + params={'query': sparql}, + headers={'Accept': 'application/json'}, + timeout=30, + ) + resp.raise_for_status() + bindings = resp.json().get('results', {}).get('bindings', []) + except Exception: + # Fallback to flat list via Owlery + all_parts = _get_all_children(root_id) + if not all_parts: + return [] + _batch_lookup_labels(all_parts) + return [ + {'id': pid, 'label': label_cache.get(pid, pid)} + for pid in sorted(all_parts, key=lambda x: label_cache.get(x, x)) + ] + + if not bindings: + return [] + + # Parse SPARQL results into parent map + parents_of = defaultdict(set) + all_parts = set() + for b in bindings: + csf = b['child']['value'].rsplit('/', 1)[-1] + psf = b['parent']['value'].rsplit('/', 1)[-1] + parents_of[csf].add(psf) + label_cache[csf] = b['childLabel']['value'] + label_cache[psf] = b['parentLabel']['value'] + all_parts.add(csf) + + # Find most specific parent for each child + # (no other parent of this child is itself a descendant of this parent) + children_of = defaultdict(list) + for child in all_parts: + best = [] + for p in parents_of[child]: + if not any(p in parents_of.get(q, set()) for q in parents_of[child] if q != p): + best.append(p) + for bp in best: + children_of[bp].append(child) + + def build(node_id, depth): + node = {'id': node_id, 'label': label_cache.get(node_id, node_id)} + if max_depth == -1 or depth < max_depth: + kids = children_of.get(node_id, []) + if kids: + node['descendants'] = [ + build(k, depth + 1) + for k in sorted(kids, key=lambda x: label_cache.get(x, x)) + ] + return node + + top = children_of.get(root_id, []) + return [build(k, 1) for k in sorted(top, key=lambda x: label_cache.get(x, x))] + + # ------------------------------------------------------------------ + # Ancestors + # ------------------------------------------------------------------ + + def 
_build_ancestors_subclass(term_id, depth, visited):
+        """Build is-a ancestor chain from SOLR term_info parents.
+
+        Filters to FBbt cell terms only (types includes 'Cell') to
+        exclude cross-ontology parents (CL, UBERON, BFO, etc.) and
+        non-cell ancestors (developmental lineage, anatomical structure).
+        Stops at 'cell' (FBbt_00007002).
+        """
+        if term_id in visited or (max_depth != -1 and depth >= max_depth):
+            return []
+        if term_id == 'FBbt_00007002':  # cell — top of useful hierarchy
+            return []
+        visited.add(term_id)
+
+        try:
+            results = vfb_solr.search(f'id:{term_id}', fl='term_info', rows=1)
+            if not results.docs or 'term_info' not in results.docs[0]:
+                return []
+            raw = results.docs[0]['term_info']
+            ti = json.loads(raw[0] if isinstance(raw, list) else raw)
+            parents = ti.get('parents', [])
+        except Exception:
+            return []
+
+        ancestors = []
+        for p in parents:
+            psf = p['short_form']
+            # Filter: must be FBbt and must be a cell type
+            if not psf.startswith('FBbt_'):
+                continue
+            if 'Cell' not in p.get('types', []):
+                continue
+            plabel = p.get('label', psf)
+            label_cache[psf] = plabel
+            node = {'id': psf, 'label': plabel}
+            further = _build_ancestors_subclass(psf, depth + 1, visited)
+            if further:
+                node['ancestors'] = further
+            ancestors.append(node)
+        return ancestors
+
+    def _build_ancestors_part_of(term_id):
+        """Build part_of ancestor chain via Ubergraph SPARQL.
+
+        Filters ancestors to terms that are part of the nervous system
+        (or the nervous system itself) to exclude developmental lineage
+        terms and generic structural classes that leak in via is-a
+        propagation in the Ubergraph redundant graph.
+        """
+        import requests as _req
+        from collections import defaultdict
+
+        term_iri = _short_form_to_iri(term_id)
+        sparql = f'''
+PREFIX BFO: <http://purl.obolibrary.org/obo/BFO_>
+PREFIX FBbt: <http://purl.obolibrary.org/obo/FBbt_>
+PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
+SELECT DISTINCT ?ancestor ?ancestorLabel ?parent ?parentLabel WHERE {{
+  GRAPH <http://reasoner.renci.org/redundant> {{
+    <{term_iri}> BFO:0000050 ?ancestor .
+  }}
+  FILTER(?ancestor != <{term_iri}>)
+  FILTER(STRSTARTS(STR(?ancestor), "http://purl.obolibrary.org/obo/FBbt_"))
+  FILTER(
+    ?ancestor = FBbt:00005093 ||
+    EXISTS {{
+      GRAPH <http://reasoner.renci.org/redundant> {{
+        ?ancestor BFO:0000050 FBbt:00005093 .
+      }}
+    }}
+  )
+  ?ancestor rdfs:label ?ancestorLabel .
+  OPTIONAL {{
+    GRAPH <http://reasoner.renci.org/redundant> {{
+      ?ancestor BFO:0000050 ?parent .
+    }}
+    FILTER(
+      ?parent = FBbt:00005093 ||
+      EXISTS {{
+        GRAPH <http://reasoner.renci.org/redundant> {{
+          ?parent BFO:0000050 FBbt:00005093 .
+        }}
+      }}
+    )
+    FILTER(?parent != ?ancestor)
+    FILTER(STRSTARTS(STR(?parent), "http://purl.obolibrary.org/obo/FBbt_"))
+    FILTER(
+      EXISTS {{
+        GRAPH <http://reasoner.renci.org/redundant> {{
+          <{term_iri}> BFO:0000050 ?parent .
+        }}
+      }}
+    )
+    ?parent rdfs:label ?parentLabel .
+  }}
+}}
+'''
+        try:
+            resp = _req.get(
+                'https://ubergraph.apps.renci.org/sparql',
+                params={'query': sparql},
+                headers={'Accept': 'application/json'},
+                timeout=30,
+            )
+            resp.raise_for_status()
+            bindings = resp.json().get('results', {}).get('bindings', [])
+        except Exception:
+            # Fallback to term_info approach
+            return _build_ancestors_subclass(term_id, 0, set())
+
+        if not bindings:
+            return []
+
+        # Build parent map among ancestors
+        parents_of = defaultdict(set)
+        all_ancestors = set()
+        for b in bindings:
+            asf = b['ancestor']['value'].rsplit('/', 1)[-1]
+            label_cache[asf] = b['ancestorLabel']['value']
+            all_ancestors.add(asf)
+            if 'parent' in b:
+                psf = b['parent']['value'].rsplit('/', 1)[-1]
+                parents_of[asf].add(psf)
+                label_cache[psf] = b['parentLabel']['value']
+
+        # Find most specific ancestors (direct parents of the query term)
+        # = ancestors that aren't themselves ancestors of another ancestor
+        children_of = defaultdict(list)
+        for anc in all_ancestors:
+            best = []
+            for p in parents_of.get(anc, set()):
+                if p in all_ancestors:
+                    if not any(p in parents_of.get(q, set()) for q in parents_of.get(anc, set()) if q != p and q in all_ancestors):
+                        best.append(p)
+            for bp in best:
+                children_of[bp].append(anc)
+
+        # Direct parents of query term = ancestors with no child that is also an 
ancestor + direct_parents = [a for a in all_ancestors if not any(a in parents_of.get(other, set()) for other in all_ancestors if other != a)] + + def build(node_id, depth): + node = {'id': node_id, 'label': label_cache.get(node_id, node_id)} + if max_depth == -1 or depth < max_depth: + # Find this node's parents among the ancestors + node_parents = [p for p in parents_of.get(node_id, set()) if p in all_ancestors] + # Most specific parents + best = [] + for p in node_parents: + if not any(p in parents_of.get(q, set()) for q in node_parents if q != p): + best.append(p) + if best: + node['ancestors'] = [ + build(p, depth + 1) + for p in sorted(best, key=lambda x: label_cache.get(x, x)) + ] + return node + + return [build(dp, 1) for dp in sorted(direct_parents, key=lambda x: label_cache.get(x, x))] + + # ------------------------------------------------------------------ + # Assemble result + # ------------------------------------------------------------------ + + _batch_lookup_labels([short_form]) + root = { + 'id': short_form, + 'label': label_cache.get(short_form, short_form), + 'relationship': relationship, + } + + if direction in ('descendants', 'both'): + if relationship == 'subclass_of': + root['descendants'] = _build_descendants_subclass(short_form) + else: + root['descendants'] = _build_descendants_part_of(short_form) + + if direction in ('ancestors', 'both'): + if relationship == 'subclass_of': + root['ancestors'] = _build_ancestors_subclass(short_form, 0, set()) + else: + root['ancestors'] = _build_ancestors_part_of(short_form) + + # ------------------------------------------------------------------ + # Render display text and HTML + # ------------------------------------------------------------------ + + VFB_BASE = 'https://v2.virtualflybrain.org/org.geppetto.frontend/geppetto?id=' + DEFAULT_MAX_SIBLINGS = 10 # truncate large sibling groups in text display + + def _text_tree(node, prefix='', is_last=True, is_root=True, max_siblings=DEFAULT_MAX_SIBLINGS): + 
"""Render a node and its descendants as a text tree.""" + lines = [] + label = f'{node["label"]} ({node["id"]})' + if is_root: + lines.append(label) + else: + lines.append(prefix + ('└── ' if is_last else '├── ') + label) + child_prefix = prefix + (' ' if is_last else '│ ') + children = node.get('descendants', []) + for i, child in enumerate(children): + if max_siblings is not None and len(children) > max_siblings and i == max_siblings - 2: + lines.append(child_prefix + f'├── ... ({len(children) - max_siblings + 1} more)') + lines.extend(_text_tree(children[-1], child_prefix, True, False, max_siblings)) + break + lines.extend(_text_tree(child, child_prefix, i == len(children) - 1, False, max_siblings)) + return lines + + def _invert_ancestor_tree(ancestors, leaf_node): + """Invert ancestor tree so highest-level terms are roots and the query term is a leaf. + + Returns a list of top-level nodes, each with 'descendants' pointing downward + toward the query term. + """ + def _collect_roots(ancestors): + """Find the top-level ancestors (those with no further ancestors).""" + roots = [] + for a in ancestors: + if 'ancestors' in a and a['ancestors']: + roots.extend(_collect_roots(a['ancestors'])) + else: + roots.append(a) + return roots + + def _build_inverted(node, ancestors, target_leaf): + """Build downward tree from an ancestor node toward the target leaf.""" + # Find which of the ancestors list directly to this node + children_toward_leaf = [] + for a in ancestors: + if 'ancestors' in a and a['ancestors']: + for grandparent in a['ancestors']: + if grandparent['id'] == node['id']: + children_toward_leaf.append(a) + elif a['id'] == node['id']: + # This ancestor IS the current node — leaf's direct parent + pass + + result = {'id': node['id'], 'label': node['label']} + if children_toward_leaf: + result['descendants'] = [ + _build_inverted(c, ancestors, target_leaf) + for c in sorted(children_toward_leaf, key=lambda x: x.get('label', '')) + ] + else: + # This node's 
child is the query term itself + result['descendants'] = [leaf_node] + return result + + # Collect all ancestor nodes into a flat list with their parent links + all_nodes = {} # id -> node + parent_map = {} # child_id -> set of parent_ids + + def _walk(ancestors, child_id=None): + for a in ancestors: + all_nodes[a['id']] = {'id': a['id'], 'label': a['label']} + if child_id: + parent_map.setdefault(child_id, set()).add(a['id']) + if 'ancestors' in a and a['ancestors']: + _walk(a['ancestors'], a['id']) + + _walk(ancestors, leaf_node['id']) + + # Roots are nodes that aren't children of anything + all_children = set() + for children in parent_map.values(): + all_children.update(children) + all_parents = set(parent_map.keys()) + root_ids = all_children - all_parents + + if not root_ids: + # Fallback: all direct ancestors are roots + root_ids = {a['id'] for a in ancestors} + + # Add leaf node to all_nodes so its label is available + all_nodes[leaf_node['id']] = leaf_node + + # Build downward trees from each root + def _build_down(node_id): + node = {'id': node_id, 'label': all_nodes.get(node_id, {}).get('label', node_id)} + children_ids = [cid for cid, pids in parent_map.items() if node_id in pids] + if children_ids: + node['descendants'] = [ + _build_down(cid) + for cid in sorted(children_ids, key=lambda x: all_nodes.get(x, {}).get('label', x)) + ] + return node + + return [_build_down(rid) for rid in sorted(root_ids, key=lambda x: all_nodes.get(x, {}).get('label', x))] + + display_lines = [] + if 'ancestors' in root and root['ancestors']: + rel_label = 'Part of' if relationship == 'part_of' else 'Is a' + display_lines.append(f'{rel_label} (ancestors):') + inverted = _invert_ancestor_tree(root['ancestors'], {'id': root['id'], 'label': root['label']}) + for node in inverted: + display_lines.extend(_text_tree(node)) + display_lines.append('') + + if 'descendants' in root: + rel_label = 'Has parts' if relationship == 'part_of' else 'Subtypes' + 
display_lines.append(f'{rel_label} (descendants):') + display_lines.extend(_text_tree(root)) + + root['display'] = '\n'.join(display_lines) + + # Full display (no sibling truncation) + full_lines = [] + if 'ancestors' in root and root['ancestors']: + rel_label = 'Part of' if relationship == 'part_of' else 'Is a' + full_lines.append(f'{rel_label} (ancestors):') + inverted_full = _invert_ancestor_tree(root['ancestors'], {'id': root['id'], 'label': root['label']}) + for node in inverted_full: + full_lines.extend(_text_tree(node, max_siblings=None)) + full_lines.append('') + + if 'descendants' in root: + rel_label = 'Has parts' if relationship == 'part_of' else 'Subtypes' + full_lines.append(f'{rel_label} (descendants):') + full_lines.extend(_text_tree(root, max_siblings=None)) + + root['display_full'] = '\n'.join(full_lines) + + # HTML rendering + def _html_tree_nodes(node, depth=0, key='descendants'): + """Render a node as nested HTML list items.""" + sid = node['id'] + label = node['label'] + link = f'{label} ({sid})' + children = node.get(key, []) + if not children: + return f'
  • {link}
  • ' + items = ''.join(_html_tree_nodes(c, depth + 1, key) for c in children) + return f'
  • 1 else " open"}>{link}
      {items}
  • ' + + html_parts = [ + '', + f'Hierarchy: {root["label"]}', + '', + f'

    {root["label"]} ({root["id"]})

    ', + ] + + if 'ancestors' in root and root['ancestors']: + rel_label = 'Part of' if relationship == 'part_of' else 'Is a' + html_parts.append(f'

    {rel_label} (ancestors)

    ') + inverted_html = _invert_ancestor_tree(root['ancestors'], {'id': root['id'], 'label': root['label']}) + items = ''.join(_html_tree_nodes(n) for n in inverted_html) + html_parts.append(f'
      {items}
    ') + + if 'descendants' in root and root['descendants']: + rel_label = 'Has parts' if relationship == 'part_of' else 'Subtypes' + html_parts.append(f'

    {rel_label} (descendants)

    ') + root_node_html = _html_tree_nodes({'id': root['id'], 'label': root['label'], 'descendants': root['descendants']}) + html_parts.append(f'
      {root_node_html}
    ') + + html_parts.append('') + root['html'] = '\n'.join(html_parts) + + return root