In [1]:
#!pip install matplotlib
#!pip install .

In [2]:
import json
import pprint
import matplotlib.pyplot as plt

from anytree import PreOrderIter, RenderTree, Node
from anytree.search import findall
from comorbid_graphs import ComorbidGraph, ComorbidGraphNode

### Creating graph 

In [3]:
with open('../tests/fixtures/symp_tree.json') as f:
    data = json.load(f)
cg = ComorbidGraph(data, ComorbidGraphNode, assign_ids=True)

In [4]:
print(cg.pretty_print_tree()[:500])
print('...')

Source
├── symptom
│   ├── urinary system symptom
│   ├── general symptom
│   ├── respiratory system and chest symptom
│   ├── neurological and physiological symptom
│   ├── musculoskeletal system symptom
│   ├── nervous system symptom
│   ├── abdominal symptom
│   ├── head and neck symptom
│   ├── skin and integumentary tissue symptom
│   ├── hemic and immune system symptom
│   ├── digestive system symptom
│   ├── cardiovascular system symptom
│   ├── nutrition, metabolism, and development symp
...


## Searching
Searching should be extensible to allow sql search for larger data than KGs.
For this we have a many step search.
- search for ancestors, parents first
  - get the names of these, include-exclude
  - merge all of the nodes, merging with include-exclude
- for docs in the ending KnowledgeGraph
  - filter by type
  - filter by body length
- search in the body

### 1. Filtering
Filter based on node properties if included or not - `name`, `parent`, `type`, `body-length`, `content`.


In [5]:
class FilterableNodeMixin(object):
    """Filterable graph nodes
      - filter by name
      - filter by parent name
      - filter by type
      - filter by body length
      - filter by content containing keywords

    # define functions as follows
    def _filter_x(self, inc_list, exc_list, strict=False):
        if condition for failing:
            return False
        return True  # it passed the condition for failing
    """

    def apply_filters(self, query_dict: dict, strict:bool = False, late_body_loading: bool = False) -> bool:
        """Main iterative filter, iterates through the necessary filters for running."""
        
        # define filter functions
        # for keeping the code clean, and consistent
        # TODO: find a more elegant way for this
        filters_dict = {
            "name": self._filter_name,
            "parent": self._filter_parent,
            "type": self._filter_type,
            "text_length": self._filter_text_length,
        }
        for filter_ in ["name", "parent", "type", "text_length"]:
            if filter_ in query_dict and not filters_dict[filter_](
                query_dict[filter_]["inc"], query_dict[filter_]["exc"], strict=strict
            ):
                return False

        # exceptional case of body filtering
        if not late_body_loading:
            if "content" in query_dict and not self._filter_content(
                query_dict["content"]["inc"], query_dict["content"]["exc"]
            ):
                return False
        return True

    def _filter_name(self, inc_list: list, exc_list: list, strict:bool =False) -> bool:
        """Checks if the name is in the included list, and not in the excluded."""
        if strict:
            if self.name in inc_list and self.name not in exc_list:
                return True
            return False
        if any(i in self.name for i in inc_list) and not any(
            i in self.name for i in exc_list
        ):
            return True
        return False

    def _filter_parent(self, inc_list: list, exc_list: list, strict:bool =False) -> bool:
        if not self.parent or not self.parent.name:
            return False
        if strict:
            if self.name in inc_list and self.name not in exc_list:
                return True
            return False
        if any(i in self.parent.name for i in inc_list) and not any(
            i in self.parent.name for i in exc_list
        ):
            return True
        return False

    def _filter_type(self, inc_list: list, exc_list: list, strict:bool =False) -> bool:
        if not hasattr(self, "type"):
            return False
        if self.type in inc_list and self.type not in exc_list:
            return True
        return False

    def _filter_text_length(self, inc_list: list, exc_list: list, strict:bool =False) -> bool:
        if not hasattr(self, "body"):
            return False
        if inc_list != [] and len(self.body) <= int(inc_list[0]):
            return False
        if exc_list != [] and len(self.body) > int(exc_list[0]):
            return False
        return True

    def _filter_content(self, inc_list: list, exc_list: list, strict:bool =False) -> bool:
        """Filters content that has the phrases put in the included list."""

        if not hasattr(self, "body"):
            return False
        if any(i in self.body for i in inc_list) and not any(
            i in self.body for i in exc_list
        ):
            return True
        return False

### 2. Content
Filtering the body for keywords.
But it shouldnt be like the previous case, because loading the body of many of these guys will tire the machine - so will have to allow a filtering done by an en engine (like `sqlite`-engine) which are optimized for these kinds of updates.

In [6]:
"""For Future Performance Issues"""

class LBLNodeMixin(object):
    """Example of Late Loading: Late Body Loading (LBL) 
    
    Title: Filtering the body for keywords.
    Description: Shouldnt be like the python-anytree-filter case,
    because loading the body of many of instances will be not performant -
    so will have to allow a filtering done 
    by an en engine (like `sqlite`-engine) which are optimized for these."""

    def apply_lbl_content_filter(
        self, query_dict: dict, strict:bool = False, late_body_loading: bool = False
    ) -> bool:
        """In case the lbl-loading-filter hasnt been updated, this will be run.
        Preparing in case there is an error."""

        if "content" in query_dict and not self._filter_content(
            query_dict["content"]["inc"], query_dict["content"]["exc"]
        ):
            return False
        return True

class LBLGraphMixin(object):
    def apply_lbl_filters(self, query_dict: dict, late_body_loading: bool = True) ->  dict:
        """OVERWRITE THIS FOR DB PERFORMANCE.
        Late Body Loading filter content.
        
        Get ids of contents that have the query
        - overwritable function to fit with database-prepared query
        """
        return set(
            findall(
                self.tree,
                filter_=lambda node: 
                node.apply_lbl_content_filter(query_dict, late_body_loading=late_body_loading)
            )
        )

### 3. Subgraph
Filter the `ancestors` and `parents`, and use `inc-exc` to zoom in.

In [12]:
class FilterableSubgraphMixin(object):
    """Allow cutting and merging ancestors.
    - get the names of these, include-exclude
    - merge all of the nodes, merging with include-exclude

    TODO: Needs proper testing."""

    @staticmethod
    def get_node_list(base_node, list_words, strict: bool = False):
        return list(
            findall(
                base_node,
                filter_=lambda node: any(x in node.name for x in list_words)
                if not strict
                else any(x == node.name for x in list_words),
            )
        )

    def filter_subgraph(
        self, inc_list, exc_list, node_type, base_name, strict: bool = False
    ):
        """Does the ancestor filtering and can do merging of data-results.
        Prepares for merging into a tree of nodes.
        """

        # copying because of reference-issues
        f = self.tree.deep_copy()
        exclude_nodes = self.get_node_list(f, exc_list, strict=strict)
        include_nodes = self.get_node_list(f, inc_list, strict=strict)

        for node in PreOrderIter(f):

            if node.name == 'spasticity':
                print("*"*100)
                print(node.ancestors)
                print("*"*100)
            if node.name == 'pain':
                print(list(node.ancestors))

            # which nodes crossover
            inc_ancestors = set(include_nodes) & set(list(node.ancestors) + [node])
            if inc_ancestors:
                inc_max_level = max([i.depth for i in inc_ancestors])
            else:
                inc_max_level = -1

            # find nodes that are excluding this node
            exc_ancestors = set(exclude_nodes) & set(list(node.ancestors) + [node])
            if exc_ancestors:
                exc_max_level = max([i.depth for i in exc_ancestors])
            else:
                exc_max_level = -1

            # add if index of inclusion > ind-of-exclusion
            if inc_max_level > exc_max_level:
                include_nodes.append(node)
            else:
                if node.parent:
                    # connect children to previous
                    for i in node.children:
                        i.parent = node.parent
                    node.parent.children += node.children

                    # remove this node
                    node.parent.children = list(
                        [i for i in node.parent.children if i.name != node.name]
                    )
                # deleting its reference just in case
                del node
        return self.merge_nodes_into_tree(include_nodes, node_type=node_type)

    def merge_nodes_into_tree(
        self,
        node_list: list,
        node_type: Node,
        base_name: str = "source",
    ):
        """Merges list of nodes into one main source,
        and re-order if none of ancestors in list.

        Parameters:
        node_list:    list of nodes to include
        node_type:   types of nodes, for creation of new node
        base_name:   for the source node
        """

        # create main node for keeping results
        result_node = node_type(name=base_name)
        
        # iterate through the nodes, and append if not excluded
        for node in node_list:
            # find the closest included ancestor
            included_ancestors = set(node_list) & set(node.ancestors)
            if included_ancestors:
                closest_index = max([i.depth for i in included_ancestors])
                closest_ancestor = [i for i in included_ancestors if i.depth == closest_index][0]
            else:
                closest_ancestor = None
    
            # if parent not found, add directly
            if not node.parent or not node.parent:
                node.old_parent = None
                node.parent = result_node
            elif closest_ancestor and closest_ancestor == node.parent:
                pass
            # if parent not found in list, add directly, but fix children issues
            elif node.parent.name not in [i.name for i in node_list]:
                node.old_parent = node.parent
                # fix parenting issues
                if closest_ancestor:
                    node.parent = closest_ancestor
                else:
                    node.parent = result_node
                node.parent.children = list(
                    [i for i in node.parent.children if i.name != node.name]
                ) + [node]

        return result_node

## Merging all
Create the search language by allowing all entries.   
Control for inputs irregularities and more.

In [13]:
from anytree import Node
import collections

DIRECTION = ["inc_", "exc_", "include_", "exclude_"]
FILTERS = ["name", "content", "type", "text_length", "ancestor", "parent"]
LBL_WARNING = """Warning:
This feature is assuming that you have a DATABASE which is doing the heavylifting.
If you do not have that, please use the normal version.
"""


def flatten(lol):
    return [i for j in lol for i in j]


def flatten_wname(lol):
    return [i.name for j in lol for i in j]


class SearchableMixin(object):
    @staticmethod
    def build_query(query_str):
        # init, fix
        query_str = query_str.replace("  ", " ").replace("\n", " ").rstrip().lstrip()
        if not query_str.startswith("inc_") and not query_str.startswith("exc_"):
            query_str = "inc_content:" + query_str

        # break into lines
        for i in DIRECTION:
            query_str = query_str.replace(i, "_BREAK_" + i)
        values = [
            i.rstrip().lstrip()
            for i in query_str.split("_BREAK_")
            if i.rstrip().lstrip() != ""
        ]

        # build query
        query_dict = {}
        for line in values:
            u_ind = line.find("_")  # underline index
            c_ind = line.find(":")  # comma index
            key = line[u_ind + 1 : c_ind]
            if key not in FILTERS:
                continue
            if key not in query_dict:
                query_dict[key] = {}
            query_dict[key][line[:u_ind]] = [
                i for i in line[c_ind + 1 :].rstrip().lstrip().split(",") if i != ""
            ]

        # fix query
        for key, val in query_dict.items():
            if "inc" not in val:
                query_dict[key]["inc"] = []
            if "exc" not in val:
                query_dict[key]["exc"] = []
        return query_dict

    def advanced_search(
        self,
        query_str: str,
        groupby_type: bool = None,
        full_tree: bool = False,
        base_name: str = "search results",
        late_body_loading: bool = False,
        silent_lbl_warning: bool = False,
        node_type: Node = ComorbidGraphNode,
    ):
        """Deals with the independent steps of searching
        - filter-content
        - filter-others
        - filter-ancestors
        - merge

        For performance we can have this feature enabled: late_body_loading
        - which considers a potential missing body if loading from db.
        This is for cases of large-text bodies that can block the memory.
        For this, we do the filtering in the engine of the DATABASE that we are using.
        Write WARNING when using this feature.
        Otherwise, we load KnowledgeGraph with Body.

        Parameters:
        full_tree:           keep the full tree of ancestors
        base_name:           name for the search top
        node_type:           allows different type of results
                             - maybe if extra properties needed.
        groupby_type:        group results by type
                             keep in mind document:section differentiation

        late_body_loading:   takes care of late-loading of body for
                             performance improvement.
        silent_lbl_warning:  silent the warning of DATABASE required.
        """
        query_dict = self.build_query(query_str)

        nodes = []

        # content filtering
        if late_body_loading:
            if not silent_lbl_warning:
                print(LBL_WARNING)
            # content_filtered_nodes
            nodes.append(self.apply_lbl_filters(query_dict, late_body_loading=True))

        # nodes based on other filters
        nodes.append(
            set(
                findall(
                    self.tree,
                    filter_=lambda node: node.apply_filters(
                        query_dict, late_body_loading=late_body_loading
                    ),
                )
            )
        )
        pprint.pprint(flatten_wname(nodes))

        # ancestor filtering
        if "ancestor" in query_dict:
            ancestor_node = self.filter_subgraph(
                inc_list=query_dict["ancestor"]["inc"],
                exc_list=query_dict["ancestor"]["exc"],
                node_type=node_type,
                base_name=base_name,
            )
            nodes.append([i for i in PreOrderIter(ancestor_node)])
        else:
            ancestor_node = node_type(name=base_name)

        # merging
        # get count of nodes found
        all_nodes = [i.name for j in nodes for i in j]

        # filter out nodes that appear less than expected
        baseline = {node.name: node for node in nodes[0]}
        
        count_nodes = len(nodes)
        for name_key, count in collections.Counter(all_nodes).items():
            if name_key == 'pain':
                print(baseline[name_key].ancestors)
            if count < count_nodes and name_key in baseline:
                del baseline[name_key]

        pprint.pprint([i.name for i in list(baseline.values())])

        # post processing
        incl_list = set([i.name for i in list(baseline.values())])
        ancestor_node = self.filter_subgraph(
            inc_list=incl_list,
            exc_list=list(set(all_nodes) - incl_list),
            node_type=node_type,
            base_name=base_name,
            strict=True,
        )
        return ancestor_node

### Testing
* create Graph
* search basics
* search advanced
* merge

In [14]:
class FilterableNode(ComorbidGraphNode, FilterableNodeMixin, LBLNodeMixin):
    pass
class FilterableGraph(ComorbidGraph, SearchableMixin, FilterableSubgraphMixin, LBLGraphMixin):
    pass

In [15]:
search_cg = FilterableGraph(data, FilterableNode, assign_ids=True)

### Advanced Query

In [16]:
query_str = """
inc_name:symptom,pain
inc_ancestor:nervous system
"""

print(search_cg.build_query(query_str))
print()

ancestor = search_cg.advanced_search(query_str, node_type=FilterableNode, late_body_loading=False)
print()
for pre, fill, node in RenderTree(ancestor):
    print("%s%s" % (pre, node.name))
# test case - check if search is complementary
# disorder = 6
# disease = 18
# disorder-disease = 4
# disease-disorder = 16
# disorder+disease = 22
# passed? yes

{'name': {'inc': ['symptom', 'pain'], 'exc': []}, 'ancestor': {'inc': ['nervous system'], 'exc': []}}

['thyroid symptom',
 'hemic and immune system symptom',
 'salivary gland symptom',
 'infant symptom',
 'painful respiration',
 'nutrition, metabolism, and development symptom',
 'general symptom',
 'urinary system symptom',
 'acute pain',
 'swelling symptom',
 'throat pain',
 'symptom',
 'gas pain',
 'knee pain',
 'severe joint pain',
 'reproductive system symptom',
 'visceral pain',
 'acute painful vision loss',
 'retrobulbar pain',
 'precordial pain',
 'neurological and physiological symptom',
 'neuropathic pain',
 'reflex symptom',
 'phantom pain',
 'breakthrough pain',
 'pleuritic chest pain',
 'neck pain',
 'elbow pain',
 'painful lymph glands',
 'chest pain',
 'colicky pain',
 'muscle pain',
 'parotid pain',
 'shoulder pain',
 'severe chest pain',
 'digestive system symptom',
 'nociceptive pain',
 'immune system symptom',
 'testicular pain',
 'hemic system symptom',
 'chronic pa


source


### Subgraph Filtering

In [None]:
res = search_cg.filter_subgraph(
    inc_list=['nervous system', 'pain'],
    exc_list=[],
    node_type=ComorbidGraphNode,
    base_name="search"
)
for pre, fill, node in RenderTree(res):
    print("%s%s" % (pre, node.name))

In [None]:
[i.name for i in includes]

## Ordering Results
There should be two options - first the graph properties.  
Second our simple algorithm based on combination of scores - as found in `comorbid-lab`.