# DataFinder


*****************
*******************
## Studying a fast way to enforce strong typing


In [7]:
import sys

sys.path.insert(0, "../src/")

### Some tests

In [8]:
from multimethod import multidispatch, multimethod


@multidispatch
def disp(events, detectors, duration):
    """Un-annotated base overload of ``disp``; it always raises.

    Raises:
        NotImplementedError: Unconditionally.
    """
    unsupported = "The type of input parameters is not supported"
    raise NotImplementedError(unsupported)


@disp.register
def _(events: str | list, detectors: str | list, duration: int):
    """Overload registered for ``(str | list, str | list, int)``; does nothing."""
    return None


@multimethod
def meth(events, detectors, duration):
    """Un-annotated base overload of ``meth``; it always raises.

    Raises:
        NotImplementedError: Unconditionally.
    """
    unsupported = "The type of input parameters is not supported"
    raise NotImplementedError(unsupported)


@meth.register
def _(events: str | list, detectors: str | list, duration: int):
    """Overload registered for ``(str | list, str | list, int)``; does nothing."""
    return None

from functools import singledispatch


@singledispatch
def single(events, detectors, duration):
    raise NotImplementedError("The type of input parameters is not supported")


@single.register
def _(events: str | list, detectors: str | list, duration: int): ...


def manual(events, detectors, duration):
    """Hand-written baseline: assert the exact type of every argument.

    ``events`` and ``detectors`` must be exactly ``str`` and ``duration``
    exactly ``int``; subclasses are rejected because the comparison is on
    ``type(...)`` itself.

    Raises:
        AssertionError: If any argument has a different type.
    """
    expectations = (
        (events, str),
        (detectors, str),
        (duration, int),
    )
    for value, expected_type in expectations:
        assert type(value) == expected_type


def nothing(events: str | list, detectors: str | list, duration: int): ...

In [9]:
# Time a single call of each variant defined above (IPython ``%timeit``
# magic) so that their per-call overhead can be compared.
%timeit disp("hi", "", 3)
%timeit meth("", "", 1)
%timeit single("", "", 1)
%timeit manual("", "", 1)
%timeit nothing("", "", 1)

---------
### The best option so far is the following:

In [154]:
from functools import lru_cache, wraps
from typing import (
    TypeVar,
    get_type_hints,
    Callable,
    GenericAlias,
    Union,
    List,
    get_origin,
    get_args,
)
import typing
from types import UnionType
from numpy import float32
import numpy

T = TypeVar("T")
import types


def _check_arg(arg, arg_type_hints):
    """
    Checks if an argument is an instance of one or more specified types.

    Args:
        arg: The argument to check.
        arg_type_hints: Type hints specifying the type(s) the argument should have.
            It can be a simple type or a composed type, such as a Union of types.

    Raises:
        TypeError: If the argument is not an instance of the specified type.
        NotImplementedError: If type checking for a certain type is not implemented.

    Examples:
        # Check if 'value' is an integer
        _check_arg(value, int)

        # Check if 'data' is a list of integers
        _check_arg(data, list[int])

        # Check if 'value' is either a string or an integer
        _check_arg(value, typing.Union[str, int])

        # Check if 'matrix' is a list of lists of integers
        _check_arg(matrix, list[list[int]])
    """
    type_error_msg = f"{arg} is not an instance of {arg_type_hints}"
    arg_type_origin = typing.get_origin(arg_type_hints)

    # Check for simple types
    if arg_type_origin is None:
        error_formats = [dict, list]
        if arg_type_hints in error_formats:
            raise NotImplementedError(
                f"Type checking for {arg_type_hints} not implemented"
            )
        if not isinstance(arg, arg_type_hints):
            raise TypeError(type_error_msg)

    # Check for composed types
    else:
        arg_types = typing.get_args(arg_type_hints)
        if arg_type_origin not in (typing.Union, types.UnionType):
            non_monadic = (
                any(typing.get_origin(_) for _ in arg_types) or len(arg_types) > 1
            )
            if non_monadic:
                raise NotImplementedError(
                    f"Type checking for {arg_type_hints} not implemented: Non monadic!"
                )
            # Checking if parent object is right
            if not isinstance(arg, arg_type_origin):
                raise TypeError(type_error_msg)
            if hasattr(arg, "dtype"):
                if arg.dtype not in arg_types:
                    raise TypeError(type_error_msg)
            elif isinstance(arg, list):
                if not all(isinstance(elem, arg_types) for elem in arg):
                    raise TypeError(type_error_msg)
            else:
                raise NotImplementedError(
                    f"Type checking for {arg_types} not implemented"
                )
        else:
            for arg_type in arg_types:
                try:
                    _check_arg(arg, arg_type)
                    break
                except TypeError:
                    continue
                except NotImplementedError as err:
                    raise err
            else:
                raise TypeError(type_error_msg)


def type_check(func: Callable[..., T]) -> Callable[..., T]:
    """
    Decorator that validates call arguments against the function's type hints.

    The hints are resolved once, at decoration time, and laid out per
    parameter so that each call only performs list/dict lookups. Every
    argument whose parameter is annotated is passed to ``_check_arg``;
    parameters without an annotation are not checked. For ``*args`` and
    ``**kwargs`` parameters the annotation is applied to each extra
    positional / keyword value. Default values are not checked, and the
    return annotation is ignored.

    Args:
        func: The function to wrap.

    Returns:
        A wrapper with the same call signature that checks the arguments
        and then returns ``func(*args, **kwargs)``.

    Raises:
        TypeError: (at call time) if an argument does not match its hint.
        NotImplementedError: (at call time) if a hint is not supported by
            ``_check_arg``.
    """
    import inspect  # local import: only needed here, at decoration time

    hints = get_type_hints(func)
    unchecked = inspect.Parameter.empty  # marker for "no annotation"

    positional_hints = []  # hint for each positional slot, in order
    keyword_hints = {}  # hint for each parameter that can be passed by name
    var_positional_hint = unchecked  # hint applied to extra *args values
    var_keyword_hint = unchecked  # hint applied to extra **kwargs values
    for name, parameter in inspect.signature(func).parameters.items():
        hint = hints.get(name, unchecked)
        if parameter.kind is inspect.Parameter.VAR_POSITIONAL:
            var_positional_hint = hint
        elif parameter.kind is inspect.Parameter.VAR_KEYWORD:
            var_keyword_hint = hint
        else:
            if parameter.kind is not inspect.Parameter.KEYWORD_ONLY:
                positional_hints.append(hint)
            if parameter.kind is not inspect.Parameter.POSITIONAL_ONLY:
                keyword_hints[name] = hint
    n_positional = len(positional_hints)

    @wraps(func)
    def type_checker(*args, **kwargs):
        for position, arg in enumerate(args):
            if position < n_positional:
                hint = positional_hints[position]
            else:
                hint = var_positional_hint
            if hint is not unchecked:
                _check_arg(arg, hint)

        for kwarg_name, kwarg in kwargs.items():
            hint = keyword_hints.get(kwarg_name, var_keyword_hint)
            if hint is not unchecked:
                _check_arg(kwarg, hint)

        # A call that does not fit the signature (too many arguments,
        # unknown keyword, ...) is reported by ``func`` itself.
        return func(*args, **kwargs)

    return type_checker


@type_check
def manual_dec(
    events: str | list[int],
    detectors: str | int,
    duration: int,
    dt: numpy.ndarray[numpy.int32],
    alpha: float = 0.1,
):
    """Empty function whose annotations exercise the ``type_check`` decorator.

    The annotations cover a union with a parametrised list, a plain union,
    a plain class, a parametrised ``numpy.ndarray`` and a parameter with a
    default value. The body does nothing.
    """
    return None


# Usage example
manual_dec([2], 3, 3, numpy.array(5), alpha=0.1)

In [None]:
from pyburst import type_check


@type_check
def manual_dec(
    events: list[str],
    detectors: str,
    duration: int,
    dt: float,
    alpha: float,
):
    """Empty function used to try the packaged ``type_check`` decorator.

    Every parameter is annotated and none has a default value. The body
    does nothing.
    """
    return None


# Call mixing positional and keyword arguments.
manual_dec(["hi"], "", duration=3, dt=0.5, alpha=0.1)

In [None]:
[""]

------------
-------------
## Timeseries data loading


In [106]:
import sys

sys.path.insert(0, "../src/")

### Data Caching

Un primo approccio potrebbe essere il seguente:

In [107]:
from collections import OrderedDict
from warnings import warn

import numpy
import cupy
import pandas


def _obj_size(obj):
    """
    Calculate the total size in bytes of a nested object, considering numpy.ndarray, cupy.ndarray, and pandas.Series objects as well.

    Parameters
    ----------
    obj : dict or list or tuple or numpy.ndarray or cupy.ndarray or pandas.Series
        The nested object to calculate the size of.

    Returns
    -------
    int
        The total size of the nested object in bytes.

    Examples
    --------
    >>> nested_dict = {
    ...     'a': 1,
    ...     'b': {'c': 2, 'd': 3},
    ...     'e': {'f': {'g': 4, 'h': 5}},
    ...     'numpy_array': numpy.zeros((10, 10)),
    ...     'cupy_array': cupy.zeros((10, 10)),
    ...     'pandas_series': pandas.Series(range(10))
    ... }
    >>> size_in_bytes = _obj_size(nested_dict)
    >>> print("Total size of the nested object:", size_in_bytes, "bytes")
    Total size of the nested object: XXX bytes
    """
    total_size = 0

    if isinstance(obj, dict):
        for key, value in obj.items():
            if isinstance(value, dict):
                total_size += _obj_size(value)
            else:
                total_size += sys.getsizeof(value)
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            total_size += _obj_size(item)
    elif isinstance(obj, (numpy.ndarray, cupy.ndarray, pandas.Series)):
        total_size += obj.nbytes
    else:
        total_size += sys.getsizeof(obj)

    return total_size


class LRUCache(OrderedDict):
    """
    A Least Recently Used (LRU) cache implemented using an OrderedDict.

    The order of the underlying ``OrderedDict`` tracks recency: the first
    key is the least recently used one, the last key the most recent.
    Reading an item with ``cache[key]`` or writing it with
    ``cache[key] = value`` makes it the most recent.

    Parameters
    ----------
    max_size_mb : float, optional
        The maximum size of the cache in megabytes (default is 100 MB).

    Attributes
    ----------
    max_size_bytes : float
        The maximum size of the cache in bytes.
    nbytes : int
        The current size of the cache in bytes.
    cache_size_mb : float
        The maximum size of the cache in megabytes.
    cache_contents : list
        The keys present in the cache.

    Notes
    -----
    Sizes are measured with ``_obj_size``. When a new item does not fit,
    the least recently used items are evicted (with a warning) until it
    does. The size limit is only enforced in ``__setitem__``: values that
    are mutated in place after being stored are not re-checked.

    Examples
    --------
    >>> cache = LRUCache(max_size_mb=1)
    >>> cache['a'] = 1
    >>> cache['b'] = 2
    >>> cache['c'] = 3
    >>> print(cache.cache_contents)
    ['a', 'b', 'c']
    >>> print(cache.cache_size_mb)
    1.0
    """

    def __init__(self, max_size_mb: float = 100):
        super().__init__()
        self.max_size_bytes = max_size_mb * (1024 * 1024)  # Convert MB to bytes

    def __getitem__(self, key):
        """Return the value stored under ``key`` and mark it as most recent.

        Raises
        ------
        KeyError
            If ``key`` is not in the cache.
        """
        value = super().__getitem__(key)  # raises KeyError(key) when missing
        self.move_to_end(key)  # Update the item as the most recent
        return value

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``, evicting old items to make room.

        Raises
        ------
        ValueError
            If ``value`` alone is larger than the whole cache. The cache
            content is left untouched in that case.
        """
        item_size_bytes = _obj_size(value)
        if item_size_bytes > self.max_size_bytes:
            # Checked before touching the cache: evicting everything could
            # never make room for this item.
            raise ValueError(
                f"Item {key!r} ({item_size_bytes} bytes) is larger than the "
                f"cache ({self.max_size_bytes} bytes)."
            )

        if key in self:
            # Drop the value being replaced so that it is not counted
            # against the limit; re-inserting below also makes the key the
            # most recently used one.
            super().__delitem__(key)

        while self and self._get_size() + item_size_bytes > self.max_size_bytes:
            self._evict_least_recently_used()
        super().__setitem__(key, value)

    def _get_size(self):
        """Return the current size of the cached values in bytes."""
        return _obj_size(self)

    def _evict_least_recently_used(self):
        """Remove the least recently used item and warn about it."""
        key, _ = self.popitem(last=False)  # Remove the least recent item
        warn(f"Removed {key} to free cache space.")

    @property
    def nbytes(self):
        """
        int: The current size of the cache in bytes.
        """
        return self._get_size()

    @property
    def cache_size_mb(self):
        """
        float: The maximum size of the cache in megabytes.
        """
        return self.max_size_bytes / (1024 * 1024)

    @property
    def cache_contents(self):
        """
        list: The keys present in the cache.
        """
        return list(self.keys())

In [60]:
from tabulate import tabulate


headers = ["cache_key_1", "cache_key_2","event_name", "detector_name", "timeseries", "gps_time", "duration"]
def nested_dict_to_table(nested_dict, headers):
    """Render a two-level nested dictionary as a ``fancy_grid`` table.

    Each ``nested_dict[outer_key][inner_key]`` mapping becomes one row made
    of the two keys followed by the mapping's values, in iteration order.
    Values that are lists are shown as the placeholder
    ``<TimeSeries(...)>`` instead of being printed in full.

    Parameters
    ----------
    nested_dict : dict
        Mapping of outer key -> inner key -> record (a mapping).
    headers : list
        Column titles passed to ``tabulate``.

    Returns
    -------
    str
        The table produced by ``tabulate``.
    """

    def _cell(value):
        # Long list payloads are replaced by a short placeholder.
        return "<TimeSeries(...)>" if isinstance(value, list) else value

    table_rows = [
        [outer_key, inner_key, *(_cell(value) for value in record.values())]
        for outer_key, per_inner in nested_dict.items()
        for inner_key, record in per_inner.items()
    ]
    return tabulate(table_rows, headers=headers, tablefmt="fancy_grid")


# NOTE(review): `d` is not defined in this cell (see the NameError below);
# presumably it is a nested event -> detector -> record dict - TODO confirm.
print(nested_dict_to_table(d, headers))

NameError: name 'd' is not defined

Voglio provare a creare una classe che gestisca in maniera specifica il tipo di dati di cui ho bisogno

### Remote Data Loading

In [1]:
import sys

sys.path.insert(0, "../src/")

In [2]:
from pyburst._typing import type_check
from pyburst._data_loader import LRUCache
from gwpy.timeseries import TimeSeries
import gwosc.datasets
from warnings import warn

import os
from concurrent.futures import ThreadPoolExecutor


class EventDataLoader:
    """
    Download, cache and save strain time series for named events.

    Every method is a classmethod and all of them share the class-level
    ``_CACHED_DATA`` cache, so fetched data is visible to every caller.
    """

    # Maps each accepted ``source`` value to the name of the classmethod
    # that fetches data for it.
    _AVAILABLE_SOURCES = {
        "remote_open": "_fetch_remote",
        "local": "_fetch_local",
    }

    # event name -> detector name -> record returned by the fetch methods.
    _CACHED_DATA = LRUCache()

    @classmethod
    @type_check(classmethod=True)
    def _validate_source(cls, source: str):
        """Raise ``NotImplementedError`` if ``source`` is not a known source."""
        if source not in cls._AVAILABLE_SOURCES:
            raise NotImplementedError(f"{source} is not a valid source.")

    @classmethod
    def _fetch_open_data(
        cls,
        event_name,
        detector_name,
        duration,
        sample_rate,
        data_format,
        max_attempts,
        this_attempt,
        verbose,
    ):
        """
        Download the open data around an event, retrying on failure.

        The GPS time of ``event_name`` is looked up with
        ``gwosc.datasets.event_gps`` and a window of ``duration`` centred
        on it is requested with ``TimeSeries.fetch_open_data``. Attempts
        are numbered from ``this_attempt``; a failed attempt emits a
        ``ResourceWarning`` and is retried until attempt number
        ``max_attempts`` has failed.

        Returns
        -------
        dict
            Record with the keys ``event_name``, ``detector_name``,
            ``time_series``, ``gps_time`` and ``duration``.

        Raises
        ------
        ConnectionError
            If the last allowed attempt fails; the original error is
            chained as its cause.
        """
        attempt = this_attempt
        while True:
            try:
                event_gps_time = gwosc.datasets.event_gps(event_name)
                start_time = event_gps_time - duration / 2
                end_time = event_gps_time + duration / 2
                signal = TimeSeries.fetch_open_data(
                    detector_name,
                    start_time,
                    end_time,
                    sample_rate,
                    format=data_format,
                    verbose=verbose,
                )
            except Exception as err:
                # Any failure is retried; KeyboardInterrupt and SystemExit
                # are not caught, so the loop can always be interrupted.
                if attempt >= max_attempts:
                    raise ConnectionError(
                        f"Failed downloading too many times ({attempt})"
                    ) from err
                warn(
                    f"Failed downloading {attempt}/{max_attempts} times, retrying...",
                    ResourceWarning,
                )
                attempt += 1
            else:
                return {
                    "event_name": event_name,
                    "detector_name": detector_name,
                    "time_series": signal,
                    "gps_time": event_gps_time,
                    "duration": duration,
                }

    @classmethod
    @type_check(classmethod=True)
    def _fetch_remote(
        cls,
        event_name: str,
        detector_name: str,
        duration: float,
        sample_rate: int,
        url: str,
        format: str,
        max_attempts: int,
        this_attempt: int = 1,
        verbose: bool = False,
    ):
        """
        Fetch the record of one event / detector pair from the open data.

        ``url`` is accepted for signature compatibility but is not used.
        See ``_fetch_open_data`` for the returned record and the retry
        behaviour.
        """
        return cls._fetch_open_data(
            event_name,
            detector_name,
            duration,
            sample_rate,
            format,
            max_attempts,
            this_attempt,
            verbose,
        )

    @classmethod
    @type_check(classmethod=True)
    def _fetch_local(
        cls,
        event_name: str,
        detector_name: str,
        duration: float,
        sample_rate: int,
        url: str,
        format: str,
        max_attempts: int,
        this_attempt: int = 1,
        verbose: bool = False,
    ):
        """
        Fetch the record of one event / detector pair for the "local" source.

        NOTE(review): this performs the same open-data download as
        ``_fetch_remote``; reading from local files is not implemented.
        ``url`` is accepted but not used.
        """
        return cls._fetch_open_data(
            event_name,
            detector_name,
            duration,
            sample_rate,
            format,
            max_attempts,
            this_attempt,
            verbose,
        )

    @classmethod
    @type_check(classmethod=True)
    def _save_event(cls, event_data: dict, save_path: str, fmt: str):
        """
        Write the time series of one record to disk.

        The file is written to
        ``<save_path>/<event_name>/<detector_name>/`` and named
        ``<event_name>_<detector_name>_<gps_time>_.<fmt>``; the directory
        is created if needed and the write is requested with
        ``overwrite=True``.
        """
        event_name = event_data["event_name"]
        detector_name = event_data["detector_name"]
        gps_time = event_data["gps_time"]
        timeseries = event_data["time_series"]
        directory = os.path.join(save_path, event_name, detector_name)
        # exist_ok avoids failing when another worker thread creates the
        # same directory between a check and the creation.
        os.makedirs(directory, exist_ok=True)
        file_name = f"{event_name}_{detector_name}_{gps_time}_.{fmt}"
        timeseries.write(
            os.path.join(directory, file_name),
            format=fmt,
            overwrite=True,
        )

    @classmethod
    @type_check(classmethod=True)
    def save_event_data(cls, data_dict: dict, save_path: str, fmt: str = "hdf5"):
        """
        Save every record of ``data_dict`` to disk, one thread-pool task each.

        Parameters
        ----------
        data_dict : dict
            Mapping event name -> detector name -> record, as returned by
            ``get_event_data``.
        save_path : str
            Root directory for the output files.
        fmt : str, optional
            Format passed to the writer and used as the file extension.
        """
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(cls._save_event, event_data, save_path, fmt)
                for per_detector in data_dict.values()
                for event_data in per_detector.values()
            ]
            for future in futures:
                future.result()  # re-raises any exception from the worker

    @classmethod
    def _needs_fetch(cls, event_name, detector_name, duration, force_cache_overwrite):
        """Return True when the pair is not usable from the cache.

        A pair must be fetched when it is missing from the cache, when an
        overwrite is forced, or when it was cached with another duration.
        """
        if event_name not in cls._CACHED_DATA:
            return True
        cached_event = cls._CACHED_DATA[event_name]
        if detector_name not in cached_event:
            return True
        if force_cache_overwrite:
            return True
        return cached_event[detector_name]["duration"] != duration

    @classmethod
    @type_check(classmethod=True)
    def get_event_data(
        cls,
        event_names: list[str],
        detector_names: list[str],
        duration: float = 50.0,
        source: str = "remote_open",
        url: str = "",
        sample_rate: int = 4096,
        format: str = "hdf5",
        max_attempts: int = 100,
        cache_results: bool = True,
        force_cache_overwrite: bool = False,
        parallel: bool = True,
        verbose: bool = False,
    ):
        """
        Fetch the data of every event / detector pair not already cached.

        Parameters
        ----------
        event_names : list[str]
            Names of the events to fetch.
        detector_names : list[str]
            Names of the detectors to fetch for each event.
        duration : float, optional
            Length of the window centred on the event GPS time.
        source : str, optional
            One of the keys of ``_AVAILABLE_SOURCES``.
        url : str, optional
            Forwarded to the fetch method.
        sample_rate : int, optional
            Forwarded to the fetch method.
        format : str, optional
            Forwarded to the fetch method.
        max_attempts : int, optional
            Maximum number of download attempts per pair.
        cache_results : bool, optional
            If False, the class-level cache is cleared before returning.
        force_cache_overwrite : bool, optional
            If True, fetch again even the pairs found in the cache.
        parallel : bool, optional
            If True, fetch the pairs through a thread pool.
        verbose : bool, optional
            Forwarded to the fetch method in the sequential mode only.

        Returns
        -------
        dict
            A shallow copy of the whole cache (event name -> detector name
            -> record), which may include pairs fetched by earlier calls.

        Raises
        ------
        NotImplementedError
            If ``source`` is not supported.
        ConnectionError
            If a download keeps failing after ``max_attempts`` attempts.
        """
        # checking if source is supported
        cls._validate_source(source)

        # getting the correct fetch function depending on input
        _fetch_function = getattr(cls, cls._AVAILABLE_SOURCES[source])

        if parallel:
            with ThreadPoolExecutor() as executor:
                # NOTE(review): verbose is hard-coded to False here while the
                # sequential branch forwards ``verbose`` - confirm intent.
                futures = [
                    executor.submit(
                        _fetch_function,
                        event_name,
                        detector_name,
                        duration,
                        sample_rate,
                        url,
                        format,
                        max_attempts,
                        1,
                        False,
                    )
                    for event_name in event_names
                    for detector_name in detector_names
                    if cls._needs_fetch(
                        event_name, detector_name, duration, force_cache_overwrite
                    )
                ]

                for future in futures:
                    result = future.result()
                    event_name = result["event_name"]
                    detector_name = result["detector_name"]
                    cls._CACHED_DATA.setdefault(event_name, {})[detector_name] = result

        else:
            for event_name in event_names:
                for detector_name in detector_names:
                    if not cls._needs_fetch(
                        event_name, detector_name, duration, force_cache_overwrite
                    ):
                        continue
                    result = _fetch_function(
                        event_name,
                        detector_name,
                        duration,
                        sample_rate,
                        url,
                        format,
                        max_attempts,
                        1,
                        verbose,
                    )
                    cls._CACHED_DATA.setdefault(event_name, {})[detector_name] = result

        out_var = dict(cls._CACHED_DATA)
        if not cache_results:
            cls._CACHED_DATA.clear()

        return out_var

In [3]:
# Request a window of duration 10 around GW150914-v3 for the L1 and H1
# detectors from the "remote_open" source.
data = EventDataLoader.get_event_data(
    ["GW150914-v3"],
    ["L1", "H1"],
    10.,
    "remote_open",
)

In [4]:
# Save every fetched record under the current directory (default fmt "hdf5").
EventDataLoader.save_event_data(data, '.')

OSError: File exists: .\GW150914-v3\L1/GW150914-v3_L1_1126259462.4_.hdf5

### Save data locally