In [1]:

import os
from pprint import pprint

import matplotlib.pyplot as plt

from fim.data.dataloaders import DataLoaderFactory

In [2]:
# Root of the synthetic Hawkes dataset. Set the FIM_HAWKES_DATA_DIR environment
# variable to point the notebook at a local copy; by default the original
# author's path is used, so behavior is unchanged when the variable is unset.
HAWKES_DATA_DIR = os.environ.get(
    "FIM_HAWKES_DATA_DIR",
    "/home/berghaus/FoundationModels/FIM/data/synthetic_data/hawkes/1k_5_st_hawkes_mixed_2000_paths_250_events",
).rstrip("/")

# Configuration passed verbatim to DataLoaderFactory.create(**dataset_config).
# The meaning of the loader/dataset kwargs is defined by HawkesDataLoader,
# which is not visible in this notebook.
dataset_config = {
    "name": "HawkesDataLoader",
    "path_collections": {
        "train": (f"{HAWKES_DATA_DIR}/train",),
        "validation": (f"{HAWKES_DATA_DIR}/val",),
    },
    "loader_kwargs": {
        "batch_size": 3,
        "num_workers": 8,
        "test_batch_size": 1,
        "variable_num_of_paths": True,
        "min_path_count": 100,
        "max_path_count": 1000,
        "max_number_of_minibatch_sizes": 10,
        "variable_sequence_lens": True,
        "min_sequence_len": 10,
        "max_sequence_len": 250,
        "num_kernel_evaluation_points": 10,
        "is_bulk_model": False,
    },
    "dataset_kwargs": {
        # Logical tensor name -> file name inside each split directory
        # (presumably loaded with torch.load by the dataset — TODO confirm).
        "files_to_load": {
            "base_intensities": "base_intensities.pt",
            "event_times": "event_times.pt",
            "event_types": "event_types.pt",
            "kernel_evaluations": "kernel_evaluations.pt",
            "kernel_grids": "kernel_grids.pt",
        }
    },
}

In [3]:
from torch import Tensor


def normalize_obs_grid(obs_grid: Tensor) -> Tensor:
    """Divide each slice of ``obs_grid`` by that slice's maximum value.

    The maximum is taken jointly over the last three dimensions, so for a
    4-D input there is one normalization constant per entry of dim 0
    (assumed to be the batch axis — TODO confirm against HawkesDataLoader).
    Any number of leading dimensions (including none) is supported.

    Args:
        obs_grid: Tensor with at least three dimensions.

    Returns:
        A single tensor with the same shape as ``obs_grid``. Slices whose
        maximum is exactly 0 are returned unchanged rather than turning
        into NaN.
    """
    # keepdim=True leaves the reduced axes as size-1 dims, so the constants
    # broadcast correctly for any rank; view(-1, 1, 1, 1) only worked for 4-D.
    norm_constants = obs_grid.amax(dim=(-3, -2, -1), keepdim=True)
    # Guard against 0/0 -> NaN for slices that are entirely zero.
    norm_constants = norm_constants.masked_fill(norm_constants == 0, 1)
    return obs_grid / norm_constants

In [4]:
# Build the loader described by dataset_config; dataset_config["name"]
# presumably selects the concrete loader class inside DataLoaderFactory
# (not visible here). The result exposes at least `train_it`, used below.
dataloader = DataLoaderFactory.create(**dataset_config)

In [6]:
# Materialize every training batch in memory; each batch's event times are
# rescaled with normalize_obs_grid before the batch is stored.
data = []
for sample in dataloader.train_it:
    # The batch mapping is updated in place, so the stored entry is `sample` itself.
    sample["event_times"] = normalize_obs_grid(sample["event_times"])
    data += [sample]

In [10]:
# Sanity check: smallest normalized event time in the first stored batch
# (expected to lie in [0, 1] if raw event times are non-negative — assumption).
data[0]["event_times"].min()

tensor(6.1418e-05)