# Data sets
Parsing, discovery, and loading of data sets.

In [None]:
#| default_exp data_sets

In [None]:
#| hide 
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| hide
from nbdev.showdoc import *
from fastcore.basics import *

In [None]:
#| export
import os
import re
import logging
import warnings
import numpy as np
import polars as pl
from pathlib import Path
from copy import deepcopy
from typing import Dict, List
from typing import DefaultDict, Iterable
from collections import defaultdict
from pisces.utils import determine_header_rows_and_delimiter

## Data set discovery using Prefix Trees

A data set is any folder under the provided data set root directory that contains subdirectories whose names start with `cleaned_`.  

Once the data sets are discovered, we take the `cleaned_<feature>` subdirectories and use the `<feature>` as the feature name. 

Then we take the files within each `cleaned_<feature>` subdirectory and discover the IDs that the data set has for that feature. The IDs do not need to be the same across features, which is why the data getters may return `None`.

Automagic ID discovery is done using a prefix tree, which is a data structure that allows for efficient searching of strings based on their prefixes.

In [None]:
#| export
class SimplifiablePrefixTree:
    """
    A standard prefix tree with the ability to "simplify" itself by combining nodes with only one child.
    These also have the ability to "flatten" themselves, which means to convert all nodes at and below a certain depth into leaves on the most recent ancestor of that depth.

    Words that end at an internal node (e.g. "ab" when "abc" is also stored) are kept by
    `simplify`, `flattened` and `reversed`: the `is_end_of_word` flag travels with the node.
    """
    def __init__(self, delimiter: str = "", # The delimiter to use when splitting words into characters. If empty, the words are treated as sequences of characters.
                 key: str = "", # The key of the current node in its parent's `.children` dictionary. If empty, the node is (likely) the root of the tree.
                 ):
        """
        key : str
            The key of the current node in its parent's `.children` dictionary. If empty, the node is (likely) the root of the tree.
        children : Dict[str, SimplifiablePrefixTree]
            The children of the current node, stored in a dictionary with the keys being the children's keys.
        is_end_of_word : bool
            Whether a word ends at the current node. Leaves are always word ends; an internal node is one when a stored word is a prefix of another stored word.
        delimiter : str
            The delimiter to use when splitting words into characters. If empty, the words are treated as sequences of characters.
        print_spacer : str
            The string to use to indent the printed tree.
        """
        self.key = key
        self.children: Dict[str, 'SimplifiablePrefixTree'] = {}
        self.is_end_of_word = False
        self.delimiter = delimiter
        self.print_spacer = "++"

    def chars_from(self, word: str):
        """
        Splits a word into characters, using the `delimiter` attribute as the delimiter.
        With an empty delimiter the word itself is returned and iterated character by character.
        """
        return word.split(self.delimiter) if self.delimiter else word

    def insert(self, word: str):
        """
        Inserts a word into the tree. If the word is already in the tree, nothing happens.
        """
        node = self
        for char in self.chars_from(word):
            if char not in node.children:
                node.children[char] = SimplifiablePrefixTree(self.delimiter, key=char)
            node = node.children[char]
        node.is_end_of_word = True

    def search(self, word: str) -> bool:
        """
        Searches for a word in the tree. Only meaningful on a tree that has not been
        simplified or flattened, because the lookup walks one token per level.
        """
        node = self
        for char in self.chars_from(word):
            if char not in node.children:
                return False
            node = node.children[char]
        return node.is_end_of_word

    def simplified(self) -> 'SimplifiablePrefixTree':
        """
        Returns a simplified copy of the tree. The original tree is not modified.
        """
        self_copy = deepcopy(self)
        return self_copy.simplify()

    def simplify(self):
        """
        Simplifies the tree in place and returns `self`.

        A node that has exactly one child and is not itself the end of a word absorbs
        that child: the keys are joined (with `delimiter` between them, unless this
        node's key is empty), and the child's children and `is_end_of_word` flag are
        taken over. Every remaining child is then simplified and re-registered under
        its (possibly longer) key.
        """
        # Absorb single-child chains iteratively so long words do not deepen the call stack.
        while len(self.children) == 1 and not self.is_end_of_word:
            child_key, child = next(iter(self.children.items()))
            # Join with the delimiter so merged keys agree with `_pushdown` and cannot
            # collide with a sibling token (e.g. "a" + "b" versus the token "ab").
            self.key = (self.key + self.delimiter + child_key) if self.key else child_key
            self.children = child.children
            # Keep the word marker; without it a word ending here would be merged away.
            self.is_end_of_word = child.is_end_of_word

        # Snapshot the keys: the dictionary is re-keyed while we walk it.
        for key in list(self.children):
            child = self.children.pop(key)
            child.simplify()
            self.children[child.key] = child
        return self

    def reversed(self) -> 'SimplifiablePrefixTree':
        """
        Returns a copy of the tree with the same structure, but with every `node.key` (and the matching key in `.children`) reversed character by character. The original tree is not modified.
        """
        rev_self = SimplifiablePrefixTree(self.delimiter, key=self.key[::-1])
        rev_self.is_end_of_word = self.is_end_of_word
        rev_self.children = {k[::-1]: v.reversed() for k, v in self.children.items()}
        return rev_self

    def flattened(self, max_depth: int = 1) -> 'SimplifiablePrefixTree':
        """
        Returns a Tree identical to `self` up to the given depth, but with all nodes at + below `max_depth` converted into leaves on the most recent ancestor of depth `max_depth - 1`.

        With `max_depth == 0` a childless copy of this node is returned; a warning is
        emitted when the node is not the end of a word, since its descendants are dropped.
        """
        flat_self = SimplifiablePrefixTree(self.delimiter, key=self.key)
        flat_self.is_end_of_word = self.is_end_of_word
        if max_depth == 0:
            if not self.is_end_of_word:
                warnings.warn(f"max_depth is 0, but {self.key} is not a leaf.")
            return flat_self
        if max_depth == 1:
            # Every word at or below a child becomes a direct leaf of the copy.
            for child in self.children.values():
                for leaf in child._pushdown():
                    flat_self.children[leaf.key] = leaf
        else:
            for k, v in self.children.items():
                flat_self.children[k] = v.flattened(max_depth - 1)
        return flat_self

    def _pushdown(self) -> List['SimplifiablePrefixTree']:
        """
        Returns fresh leaf nodes, one per word stored at or below `self`, each keyed by
        the path from `self` down to that word (keys joined with `delimiter`).
        A childless node always yields a single leaf carrying its own key.
        """
        leaves: List['SimplifiablePrefixTree'] = []
        if self.is_end_of_word or not self.children:
            own_leaf = SimplifiablePrefixTree(self.delimiter, key=self.key)
            own_leaf.is_end_of_word = True
            leaves.append(own_leaf)
        for child in self.children.values():
            for leaf in child._pushdown():
                # The leaves are new objects, so re-keying them leaves this tree untouched.
                leaf.key = self.key + self.delimiter + leaf.key
                leaves.append(leaf)
        return leaves

    def __str__(self):
        # prints .children recursively with indentation
        return self.key + "\n" + self.print_tree()

    def print_tree(self, indent=0) -> str:
        """
        Returns one line per descendant ("( " followed by its key), indented by
        `print_spacer` repeated once per level below this node.
        """
        pieces = []
        for child in self.children.values():
            pieces.append(self.print_spacer * indent + "( " + child.key + "\n")
            pieces.append(SimplifiablePrefixTree.print_tree(child, indent + 1))
        return "".join(pieces)


class IdExtractor(SimplifiablePrefixTree):
    """
    Class extending the prefix trees that incorporates the algorithm for extracting IDs from a list of file names. The algorithm is somewhat oblique, so it's better to just use the `extract_ids` method versus trying to use the prefix trees directly at the call site.

    The algorithm is based on the assumption that each file name is an ID followed by a suffix shared by all file names. The algorithm reverses the file names, inserts them into the tree, and then simplifies and flattens that tree in order to find the IDs as leaves of that simplified tree.

    1. Insert the file name string into the tree, but with each string **reversed**.
    2. Simplify the tree, combining nodes with only one child.
    3. There may be unexpected suffix matches for these IDs, so we flatten the tree to depth 1, meaning all children of the root are combined to make leaves.
    4. The leaves are the IDs we want to extract. However, we must reverse these leaf keys to get the original IDs, since we reversed the file names in step 1.

    Limitation: when all inserted names lie on one unbranched chain (for example, when
    only a single file name was inserted), simplification folds everything into the root,
    no leaves remain, and no IDs are reported.

    TODO:
    * If we want to find IDs for files with differing prefixes instead, we should instead insert the file names NOT reversed and then NOT reverse in the last step.

    * To handle IDs that appear in the middle of file names, we can use both methods to come up with a list of potential IDs based on prefix and suffix, then figure out the "intersection" of those lists. (Maybe using another prefix tree?)

    """
    def __init__(self, delimiter: str = "", key: str = ""):
        super().__init__(delimiter, key)

    def extract_ids(self, files: List[str]) -> List[str]:
        """
        Inserts every name in `files` (reversed) into this tree and returns the sorted
        IDs found among *all* names inserted so far. The tree keeps the names between
        calls, so use a fresh extractor for each unrelated batch of files.
        """
        for file in files:
            self.insert(file[::-1])
        id_leaves = self.prefix_flattened().children.values()
        return sorted(leaf.key for leaf in id_leaves)

    def prefix_flattened(self) -> SimplifiablePrefixTree:
        """
        Returns a new tree (this one is left untouched) whose root key is the suffix
        shared by the inserted names and whose children are keyed by the IDs.
        The result is a plain `SimplifiablePrefixTree`, since that is what
        `flattened` and `reversed` construct.
        """
        return self.simplified().flattened(1).reversed()
    

In [None]:
#| hide
# File-name-like entries: a 4-character ID followed by the shared suffix "abc12".
entries = [
    '3XYZabc12',
    '3XY&abc12',
    '3XYAabc12',
    '3XYBabc12',
    'MMVQabc12',
    'NMVQabc12',
]

# extract_ids returns its result sorted, so sort the expectation the same way.
expected_ids = sorted([
    '3XYZ',
    '3XY&',
    '3XYA',
    '3XYB',
    'MMVQ',
    'NMVQ',
])

id_extractor = IdExtractor()

ids = id_extractor.extract_ids(entries)

# zip() stops at the shorter sequence, so check the lengths first; otherwise an
# empty or truncated result would pass the element-wise comparison vacuously.
assert len(ids) == len(expected_ids), f"Expected {len(expected_ids)} IDs, but got {len(ids)}: {ids}"

for i, (expected, actual) in enumerate(zip(expected_ids, ids)):
    assert expected == actual, f"Expected {expected}, but got {actual} at index {i}"


In [None]:
#| hide
# Show the raw tree built from the reversed entries: an empty root-key line,
# then one indented "( <char>" line per node.
print(id_extractor)


( 2
++( 1
++++( c
++++++( b
++++++++( a
++++++++++( Z
++++++++++++( Y
++++++++++++++( X
++++++++++++++++( 3
++++++++++( &
++++++++++++( Y
++++++++++++++( X
++++++++++++++++( 3
++++++++++( A
++++++++++++( Y
++++++++++++++( X
++++++++++++++++( 3
++++++++++( B
++++++++++++( Y
++++++++++++++( X
++++++++++++++++( 3
++++++++++( Q
++++++++++++( V
++++++++++++++( M
++++++++++++++++( M
++++++++++++++++( N



In [None]:
#| hide
# Show the simplified, flattened and re-reversed tree: the root key is the shared
# suffix and each child line is one extracted ID.
print(id_extractor.prefix_flattened())

abc12
( 3XYZ
( 3XY&
( 3XYA
( 3XYB
( MMVQ
( NMVQ



In [None]:
#| export
LOG_LEVEL = logging.INFO

class DataSetObject:
    """
    A data set on disk: a folder holding one `cleaned_<feature>` subdirectory per
    feature, each containing one file per ID.

    Attributes set in `__init__`:
    name : the data set's name (its folder name when built by `find_data_sets`).
    path : the data set's folder.
    ids  : sorted list of every ID seen in any feature.
    """
    FEATURE_PREFIX = "cleaned_"

    # Set up logging
    logger = logging.getLogger(__name__)
    logger.setLevel(LOG_LEVEL)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    # The class body runs again whenever the module is re-imported (e.g. by autoreload),
    # while the logger object persists; only attach a handler the first time so
    # messages are not printed once per reload.
    if not logger.handlers:
        logger.addHandler(console_handler)

    def __init__(self, name: str, path: Path):
        self.name = name
        self.path = path
        self.ids: List[str] = []

        # keeps track of the files for each feature and user: feature -> {id -> file name}
        self._feature_map: DefaultDict[str, Dict[str, str]] = defaultdict(dict)
        # loaded data frames: feature -> {id -> DataFrame}
        self._feature_cache: DefaultDict[str, Dict[str, pl.DataFrame]] = defaultdict(dict)

    @property
    def features(self) -> List[str]:
        """Names of the features registered so far."""
        return list(self._feature_map.keys())

    def __str__(self):
        return f"{self.name}: {self.path}"

    def get_feature_data(self, feature: str, id: str) -> pl.DataFrame | None:
        """
        Returns the data of `feature` for `id`, sorted by its first column, loading and
        caching it on first use. Returns `None` (after a warning) when the feature or ID
        is unknown, when this feature has no file for the ID, or when the file cannot
        be read.
        """
        if feature not in self.features:
            warnings.warn(f"Feature {feature} not found in {self.name}. Returning None.")
            return None
        if id not in self.ids:
            warnings.warn(f"ID {id} not found in {self.name}")
            return None

        cached = self._feature_cache[feature].get(id)
        if cached is not None:
            return cached

        file = self.get_filename(feature, id)
        if file is None:
            return None
        self.logger.debug(f"Loading {file}")
        try:
            n_rows, delimiter = determine_header_rows_and_delimiter(file)
            # When header rows were detected, skip all but the last one and let
            # polars read that one as the column header.
            df = pl.read_csv(file, has_header=n_rows > 0,
                             skip_rows=max(n_rows - 1, 0),
                             separator=delimiter)
        except Exception as e:
            # Best effort: a file that cannot be parsed is reported, not fatal.
            warnings.warn(f"Error reading {file}:\n{e}")
            return None
        # Sort by the first column (expected to be the time column). `DataFrame.sort`
        # returns a new frame, so the result has to be kept.
        df = df.sort(df.columns[0])
        self._feature_cache[feature][id] = df
        return df

    def get_filename(self, feature: str, id: str) -> Path | None:
        """
        Returns the path of the file holding `feature` for `id`, or `None` (after a
        warning) when the feature is unknown or has no file for that ID.
        """
        feature_ids = self._feature_map.get(feature)
        if feature_ids is None:
            warnings.warn(f"Feature {feature} not found in {self.name}")
            return None
        file = feature_ids.get(id)
        if file is None:
            warnings.warn(f"ID {id} not found in {self.name}")
            return None
        return self.get_feature_path(feature).joinpath(file)

    def get_feature_path(self, feature: str) -> Path:
        """Returns the directory of `feature`: `<path>/<FEATURE_PREFIX><feature>`."""
        return self.path.joinpath(self.FEATURE_PREFIX + feature)

    def _extract_ids(self, files: List[str]) -> List[str]:
        """Extracts the IDs from `files` with a fresh `IdExtractor`."""
        return IdExtractor().extract_ids(files)

    @staticmethod
    def _match_ids_to_files(ids: List[str], files: List[str]) -> Dict[str, str]:
        """
        Returns `{id: file}`, pairing every file with the ID it starts with.

        The IDs are expected to be the file names with one shared suffix removed, so
        the suffix is recovered from the first file and applied to all of them. If no
        single suffix explains every file, each file is paired with the longest ID
        that is a prefix of it, and files without any matching ID are skipped with a
        warning.
        """
        known_ids = set(ids)
        if not known_ids or not files:
            return {}

        first = files[0]
        # Whatever follows a known ID at the start of the first file is a candidate suffix.
        candidate_suffixes = sorted(
            {first[len(known):] for known in known_ids if first.startswith(known)},
            key=len,
            reverse=True,
        )
        for suffix in candidate_suffixes:
            stems = [f[:len(f) - len(suffix)] for f in files if f.endswith(suffix)]
            if len(stems) == len(files) and known_ids.issuperset(stems):
                return dict(zip(stems, files))

        matched: Dict[str, str] = {}
        for file in files:
            prefixes = [known for known in known_ids if file.startswith(known)]
            if prefixes:
                matched[max(prefixes, key=len)] = file
            else:
                warnings.warn(f"No ID found for file {file}")
        return matched

    def add_feature_files(self, feature: str, files: Iterable[str]):
        """
        Registers `files` (file names inside the feature's directory) under `feature`,
        extracting one ID per file, and merges the IDs into `self.ids`.
        """
        if feature not in self.features:
            self.logger.debug(f"Adding feature {feature} to {self.name}")
            self._feature_map[feature] = {}
        # Materialize once: `files` may be a one-shot iterable and is needed twice.
        files = sorted(files)
        extracted_ids = self._extract_ids(files)
        if files and not extracted_ids:
            warnings.warn(
                f"No IDs could be extracted from {len(files)} file(s) "
                f"for feature {feature} in {self.name}"
            )
        # Pair IDs and files by name instead of by sorted position: the two sort orders
        # differ for names such as "1_hr.csv" / "10_hr.csv" ("10_..." sorts first,
        # whereas the ID "1" sorts before "10").
        id_to_file = self._match_ids_to_files(extracted_ids, files)
        for found_id in sorted(id_to_file):
            self._feature_map[feature][found_id] = id_to_file[found_id]
        # a set dedupes IDs already known from other features
        self.ids = sorted(set(self.ids).union(id_to_file))

    def get_feature_files(self, feature: str) -> Dict[str, str]:
        """Returns a copy of `{id: file name}` for `feature` (empty if the feature is unknown)."""
        # `.get` avoids the defaultdict silently registering an unknown feature.
        return dict(self._feature_map.get(feature, {}))

    def get_id_files(self, id: str) -> Dict[str, str]:
        """Returns `{feature: file name}` for every feature that has a file for `id`."""
        # IDs need not exist in every feature, so skip the features that lack this one.
        return {k: v[id] for k, v in self._feature_map.items() if id in v}

    def load_feature_data(self, feature: str | None, id: str | None) -> Dict[str, np.ndarray]:
        """
        Validates that `feature` is known. Raises `ValueError` otherwise.
        NOTE(review): no loading is implemented here yet; for a known feature the
        method falls through and returns `None` despite the annotated return type.
        """
        if feature not in self.features:
            raise ValueError(f"Feature {feature} not found in {self.name}")

    @classmethod
    def find_data_sets(cls, root: str | Path) -> Dict[str, 'DataSetObject']:
        """
        Walks `root` (following symlinks) and returns `{name: DataSetObject}` for every
        folder that contains `<FEATURE_PREFIX><feature>` subdirectories, with the files
        of each such subdirectory registered under that feature.
        """
        root = str(root).replace("\\", "/") # Use consistent separators
        if not root.endswith("/"):
            root += "/"

        # group 1: data set folder name, group 2: feature name
        feature_dir_regex = re.compile(rf".*/(.+)/{re.escape(cls.FEATURE_PREFIX)}(.+)/?")

        data_sets: Dict[str, DataSetObject] = {}
        for root_dir, dirs, files in os.walk(root, followlinks=True):
            normalized_root_dir = root_dir.replace("\\", "/")
            root_match = feature_dir_regex.match(normalized_root_dir)
            if root_match is None:
                continue
            data_set_name = root_match.group(1)
            feature_name = root_match.group(2)
            if (data_set := data_sets.get(data_set_name)) is None:
                data_set = DataSetObject(data_set_name, Path(root_dir).parent)
                data_sets[data_set.name] = data_set
            # Filter out unwanted files
            relevant_files = [f for f in files if not f.startswith(".") and not f.endswith(".tmp")]
            # NOTE: with a single file in the directory the ID extraction yields no IDs
            # (there is no second name to compare against), so the feature is
            # registered without any IDs.
            data_set.add_feature_files(feature_name, relevant_files)

        return data_sets

In [None]:
#| hide
# Run nbdev's export step; presumably this writes the `#| export` cells to the
# module named by `#| default_exp` at the top of the notebook — confirm against nbdev docs.
import nbdev
nbdev.nbdev_export()