In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pickle

In [None]:
import pandas as pd

# Load the training CSV from the tests folder. The path is relative, so it is
# resolved against the current working directory of the kernel.
data = pd.read_csv("../../tests/testing_data/train.csv")

In [None]:
data.columns

In [None]:
from twinn_ml_interface.objectmodels import (
    Node,
    Unit,
)

In [None]:
# Codes grouped by the kind of measurement they provide.
discharge_stations = ["altenburg1", "eschweiler", "herzogenrath1", "juelich", "stah"]
precip_areas = ["middenroer", "urft"]

# Lookup from each code to its measurement kind. Built from the groups above so
# the two views cannot drift apart; key order is discharge, precipitation, evap.
type_dict = {
    **dict.fromkeys(discharge_stations, "discharge"),
    **dict.fromkeys(precip_areas, "precipitation"),
    "evap": "evaporation",
}

In [None]:
# Placeholder so the river node can be created before its parent exists;
# the real parent is attached at the end of this cell.
tenant_node = None

# River-level node. It starts without children; later cells attach them.
child_tenant_node = Node(
    val=Unit(
        unit_code="RUHR",
        unit_type_code="RIVER",
        active=True,
        geometry=None,
    ),
    parent=tenant_node,
    children=None,
)

# Root (tenant) node, which owns the river node as its only child.
tenant_node = Node(
    val=Unit(
        unit_code="WL",
        unit_type_code="TENANT",
        active=True,
    ),
    parent=None,
    children=[child_tenant_node],
)

# Now that the tenant node exists, point the river node back at it.
child_tenant_node.parent = tenant_node

tenant_node

In [None]:
# Create one DISCHARGE_STATION node per entry in discharge_stations (codes are
# upper-cased). This assignment replaces whatever children the river node had.
child_tenant_node.children = [
    Node(
        val=Unit(
            unit_code=unit_code.upper(),
            unit_type_code="DISCHARGE_STATION",
            active=True,
            geometry=None,
        ),
        parent=child_tenant_node,
        children=None,
    )
    for unit_code in discharge_stations
]

In [None]:
# Append one PRECIPITATION_AREA node per entry in precip_areas (codes upper-cased).
# NOTE: `+=` needs `children` to already be a list, and re-running this cell
# appends the same units again.
child_tenant_node.children += [
    Node(
        val=Unit(
            unit_code=unit_code.upper(),
            unit_type_code="PRECIPITATION_AREA",
            active=True,
            geometry=None,
        ),
        parent=child_tenant_node,
        children=None,
    )
    for unit_code in precip_areas
]

In [None]:
# Append the single EVAPORATION_STATION node ("EVAP") to the river node.
# NOTE: `+=` needs `children` to already be a list, and re-running this cell
# appends the same unit again.
child_tenant_node.children += [
    Node(
        val=Unit(
            unit_code="EVAP",
            unit_type_code="EVAPORATION_STATION",
            active=True,
            geometry=None,
        ),
        parent=child_tenant_node,
        children=None,
    )
]

In [None]:
tenant_node.children[0].parent == tenant_node

In [None]:
tenant_node

In [None]:
# Map an identifier to the root node of its unit tree.
# NOTE(review): the key looks like a tenant UUID — confirm what consumers expect.
hierarchy = {
    '17def674-2d70-4a5f-ad9e-e2eeffe27944': tenant_node
}

In [None]:
hierarchy

In [None]:
# Persist the hierarchy mapping to disk so it can be reloaded later.
# NOTE(review): hard-coded absolute path; adjust it for your environment.
with open("/mnt/c/tmp_data/darrow/hierarchy.pkl", "wb") as f:
    pickle.dump(hierarchy, f)

In [None]:
# Round-trip check: read the pickle back from the same path and display it.
with open("/mnt/c/tmp_data/darrow/hierarchy.pkl", "rb") as f:
    test = pickle.load(f)
    
test

## Can we get the hierarchy from the API?

In [None]:
import logging
import pickle
from abc import ABC, abstractmethod
from typing import Any

from azure.storage.blob import BlobClient
from twinn_ml_interface.objectmodels import Node, RelativeType


class Hierarchy:
    root: Node

    def __init__(self, root: Node) -> None:
        """A graph that tracks the relations between different units."""
        self.root = root

    def get_nodes(self, relatives: list[RelativeType], node: Node) -> list[Node] | None:
        """Get a list of nodes from a hierarchy tree.

        Args:
            relatives (list[RelativeType]): a list of relatives to get.
            node (Node): the node of the hierarchy tree.

        Returns:
            list[Node], optional: the list of nodes with units.
        """
        relative: RelativeType = relatives.pop(0)
        relative_nodes = self.get_next_level(node, relative)
        if not relatives or relative_nodes is None:
            return relative_nodes
        next_nodes = []
        for unit in relative_nodes:
            next_nodes = next_nodes + self.get_nodes(relatives.copy(), unit)
        return next_nodes

    def get_next_level(self, node: Node, relative_type: RelativeType) -> list[Node]:
        """Get a relative of the given node.

        Args:
            node (Node): the node.
            relative_type (RelativeType): where to go from the node.

        Returns:
            list[Node]: the resulting nodes.
        """
        match relative_type:
            case RelativeType.PARENT:
                if node.parent is None:
                    raise HierarchyError(BadHierarchyRequestErrorMsg())
                return [node.parent]
            case RelativeType.CHILDREN:
                if node.children is None:
                    return []
                return node.children
            case RelativeType.SELF:
                if node is None:
                    return []
                return [node]
        return None

    def find_node_in_tree(self, unit_name: str) -> Node:
        """Retrieve a node from a hierarchy tree.

        Args:
            unit_name (str): the unit to search for.

        Returns:
            Node: the node matching the unit.
        """
        to_visit = [self.root]
        while to_visit:
            current = to_visit.pop(0)
            if current.val.unit_code == unit_name:
                return current
            if current.children:
                to_visit += current.children
        raise HierarchyError(TargetNotInHierarchyErrorMsg(unit_name))

    def get_all_units_for_node(self, root: Node, *, only_leaf_nodes: bool = True) -> list[Node]:
        """Gets a list of all child units (optionally: leaf nodes only) for a given node.

        Args:
            root (Node): The node for which you want to retrieve the children.

        Kwargs:
            only_leaf_nodes (bool): Whether to retrieve leaf nodes, or all nodes. Defaults to True.

        Returns:
            list[Node]: The list of child nodes of the root Node.
        """
        root_children = []
        if root.children is None:
            root_children.append(root)
            return root_children
        if not only_leaf_nodes:
            root_children.append(root)
        for child in root.children:
            root_children += self.get_all_units_for_node(child, only_leaf_nodes=only_leaf_nodes)
        return root_children

    def get_all_units(self, *, only_leaf_nodes: bool = True) -> list[Node]:
        """Gets a list of all child units (Leaf nodes only) for the hierarchy.

        Kwargs:
            only_leaf_nodes (bool): Whether to retrieve leaf nodes, or all nodes. Defaults to True.

        Returns:
            list[Node]: The list of child nodes of the root Node.
        """
        return self.get_all_units_for_node(self.root, only_leaf_nodes=only_leaf_nodes)

    def get_prefix(self, node: Node) -> str:
        """Get prefix of a node.

        Args:
            node (Node): Node to get prefix for

        Returns:
            str: Prefix of the node
        """
        prefix_list = []
        current_node = node
        while current_node != self.root:
            if current_node.children:
                prefix_list.append(current_node.val.unit_code)
            current_node = current_node.parent
        root_code = self.root.val.unit_code
        return "/".join([root_code] + prefix_list[::-1])

    def find_lowest_common_ancestor(self, current: Node, units: list[str]) -> Node | None:
        """Finds the lowest common ancestor for a given node."""
        if current.val.unit_code in units:
            return current
        if not current.children:
            return None
        results = list(filter(None, [self.find_lowest_common_ancestor(child, units) for child in current.children]))
        if len(results) == 1:
            return results[0]
        if len(results) > 1:
            return current
        return None

In [None]:
import math

from twinn_ml_interface.objectmodels import (
    DataLabelConfigTemplate,
    DataLevel,
    Node,
    RelativeType,
    Unit,
    UnitTag,
    UnitTagTemplate,
)


def get_unit_tags_from_data_template(
    data_config_template: DataLabelConfigTemplate,
    node: Node,
    hierarchy: Hierarchy,
    *,
    fallback: bool = False,
) -> list[UnitTag]:
    """Get the unit tags described by every unit tag template of a config template.

    Args:
        data_config_template (DataLabelConfigTemplate): the config template whose
            ``unit_tag_templates`` are resolved.
        node (Node): the node in the hierarchy the relative paths start from.
        hierarchy (Hierarchy): the hierarchy used to resolve the relative paths.

    Kwargs:
        fallback (bool, optional): when True, each template's relative path is first passed
            through ``_make_fallback`` before it is resolved. Defaults to False.

    Returns:
        list[UnitTag]: the unit tags of all templates, concatenated in template order.
    """
    import copy  # local import: not provided by this cell's imports

    unit_tags: list[UnitTag] = []
    for unit_tag_template in data_config_template.unit_tag_templates:
        template = unit_tag_template
        if fallback:
            # Resolve the fallback path on a shallow copy. Assigning it to the caller's template
            # would leave that template in fallback mode for every later (non-fallback) call.
            # NOTE(review): _make_fallback is not defined in the visible cells; it must be in
            # scope for fallback=True to work.
            template = copy.copy(unit_tag_template)
            template.relative_path = _make_fallback(unit_tag_template.relative_path)
        unit_tags += get_unit_tags_from_template(
            template,
            node,
            hierarchy,
        )
    return unit_tags


def get_unit_tags_from_template(
    unit_tag_template: UnitTagTemplate,
    node: Node,
    hierarchy: Hierarchy,
) -> list[UnitTag]:
    """Expand a UnitTagTemplate into concrete UnitTags.

    Args:
        unit_tag_template (UnitTagTemplate): supplies the relative path and the tags.
        node (Node): the node the relative path starts from.
        hierarchy (Hierarchy): the hierarchy used to resolve the path.

    Returns:
        list[UnitTag]: one UnitTag per (unit, tag) pair, grouped by unit in the order the units
            were resolved.
    """
    resolved_units = get_units_from_hierarchy(unit_tag_template.relative_path, node, hierarchy)
    return [UnitTag(unit, tag) for unit in resolved_units for tag in unit_tag_template.tags]


def get_units_from_hierarchy(relatives: list[RelativeType], node: Node, hierarchy: Hierarchy) -> list[Unit]:
    """Resolve a relatives path to the units stored on the nodes it reaches.

    Args:
        relatives (list[RelativeType]): the path of relatives to follow; a copy is handed to the
            hierarchy so this list is left as it was.
        node (Node): the node the path starts from.
        hierarchy (Hierarchy): the hierarchy that resolves the path.

    Returns:
        list[Unit]: the ``val`` of every reached node, or an empty list when the hierarchy
            returns nothing or a result that contains None.
    """
    found = hierarchy.get_nodes(relatives.copy(), node)
    # Treat a missing, empty or partially unresolved result as "no units".
    if not found or None in found:
        return []
    return [found_node.val for found_node in found]


In [None]:
def get_unit_tags_from_data_templates(
    data_config_templates: list[DataLabelConfigTemplate],
    node: Node,
    hierarchy: Hierarchy,
) -> tuple[dict[DataLevel, list[UnitTag]], dict[DataLevel, list[UnitTag]]]:
    """Resolve every DataLabelConfigTemplate to UnitTags, keyed by data level.

    Args:
        data_config_templates (list[DataLabelConfigTemplate]): the config templates to resolve.
        node (Node): the node in the hierarchy the model is created for.
        hierarchy (Hierarchy): a unit hierarchy.

    Returns:
        dict[DataLevel, list[UnitTag]]: per data level, all unit tags (fallbacks included).
        dict[DataLevel, list[UnitTag]]: per data level, the unit tags found without fallbacks.
    """
    all_by_level = {}
    priority_by_level = {}
    # TODO (Miguel): continue here: Discuss if UnitTag is a real alternative to DataLabelConfigTemplate!
    for template in data_config_templates:
        # The helper returns (priority, all); this function hands them back as (all, priority).
        level_priority, level_all = _get_unit_tags_from_data_template(template, node, hierarchy)
        all_by_level[template.data_level] = level_all
        priority_by_level[template.data_level] = level_priority
    return all_by_level, priority_by_level

def _get_unit_tags_from_data_template(
    data_config_template: DataLabelConfigTemplate,
    node: Node,
    hierarchy: Hierarchy,
) -> tuple[list[UnitTag], list[UnitTag]]:
    """Resolve one config template to its priority unit tags and its full set of unit tags.

    Args:
        data_config_template (DataLabelConfigTemplate): the config template to resolve; its
            ``desired_tag_number`` decides whether tags are trimmed or topped up.
        node (Node): the node in the hierarchy the model is created for.
        hierarchy (Hierarchy): a unit hierarchy.

    Returns:
        tuple[list[UnitTag], list[UnitTag]]: ``(priority, all)``. Both are the same list when no
            desired number is set, when it is met exactly, or when the priority tags had to be
            trimmed. Otherwise ``all`` is the priority tags followed by the selected fallbacks.
    """
    desired_number = data_config_template.desired_tag_number
    priority_unit_tags = get_unit_tags_from_data_template(data_config_template, node, hierarchy)
    if not desired_number or desired_number == len(priority_unit_tags):
        return priority_unit_tags, priority_unit_tags
    if len(priority_unit_tags) > desired_number:  # too many tags
        # NOTE(review): _select_units_by_distance is not defined in the visible cells; it must be
        # in scope for this branch and the fallback branch below to work.
        filtered_priority_unit_tags = _select_units_by_distance(node.val, priority_unit_tags, desired_number)
        return filtered_priority_unit_tags, filtered_priority_unit_tags
    # Too few tags: top up with fallback units that are not already selected.
    fallback_unit_tags = get_unit_tags_from_data_template(
        data_config_template,
        node,
        hierarchy,
        fallback=True,
    )
    already_selected = set(priority_unit_tags)
    # dict.fromkeys drops duplicates but keeps first-seen order, so the candidates reach the
    # selection in a reproducible order (a plain set difference would order them arbitrarily).
    candidate_unit_tags = [
        unit_tag for unit_tag in dict.fromkeys(fallback_unit_tags) if unit_tag not in already_selected
    ]
    filtered_fallback_unit_tags = _select_units_by_distance(
        node.val,
        candidate_unit_tags,
        desired_number - len(priority_unit_tags),
    )
    return priority_unit_tags, priority_unit_tags + filtered_fallback_unit_tags


In [None]:
from twinn_ml_interface.objectmodels import (
    AvailabilityLevel,
    Configuration,
    DataLabelConfigTemplate,
    DataLevel,
    MetaDataLogger,
    ModelCategory,
    RelativeType,
    Tag,
    UnitTagTemplate,
    UnitTag,
    Unit,
    WindowViability,
)

In [None]:
# Reload the pickled hierarchy mapping (same path the earlier cell wrote to).
# pickle.load can execute code embedded in the file, so only load files you created yourself.
with open("/mnt/c/tmp_data/darrow/hierarchy.pkl", "rb") as f:
    hierarchy_dict = pickle.load(f)
    
hierarchy_dict

In [None]:
# Wrap the root node stored under this key in a Hierarchy. Note that this rebinds
# `hierarchy`, which an earlier cell used for the raw dict.
hierarchy = Hierarchy(hierarchy_dict['17def674-2d70-4a5f-ad9e-e2eeffe27944'])

In [None]:
# Both templates walk PARENT -> CHILDREN from the target node, i.e. they resolve to
# every child of the target's parent (the target itself included).
data_config_templates = [
    DataLabelConfigTemplate(
        data_level=DataLevel.SENSOR,
        unit_tag_templates=[UnitTagTemplate([RelativeType.PARENT, RelativeType.CHILDREN], [Tag("DISCHARGE")])],
        availability_level=AvailabilityLevel.FILTER_UNTIL_NOW,
    ),
    DataLabelConfigTemplate(
        data_level=DataLevel.WEATHER,
        unit_tag_templates=[
            UnitTagTemplate([RelativeType.PARENT, RelativeType.CHILDREN], [Tag("PRECIPITATION"), Tag("EVAPORATION")])
        ],
        availability_level=AvailabilityLevel.FILTER_UNTIL_NOW,
    ),
]


# Target node: the upper-cased "stah" discharge station.
node = hierarchy.find_node_in_tree("STAH")
    

# The function returns (all unit tags, priority unit tags), each keyed by data level.
all_unit_tags_dict, priority_unit_tags_dict = get_unit_tags_from_data_templates(data_config_templates, node, hierarchy)

In [None]:
priority_unit_tags_dict

In [None]:
all_unit_tags_dict.values()