### This notebook tries to fit tick's EM model (`HawkesEM`) to the real-world datasets to get a feeling for what their kernels look like. Because the EM algorithm takes a single (long) path, we concatenate the paths of each real-world dataset into one path.

In [15]:
import pickle

import matplotlib.pyplot as plt
import numpy as np
from tick.hawkes import HawkesEM
from tick.plot import plot_hawkes_kernels

In [16]:
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    Only use on trusted local files: unpickling can execute arbitrary code.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Real-world evaluation datasets, loaded in the same order as before.
mimic = _load_pickle("../../data/evaluation/hawkes/mimic_II.pkl")
mooc = _load_pickle("../../data/evaluation/hawkes/mooc.pkl")
stack = _load_pickle("../../data/evaluation/hawkes/stackOverflow.pkl")
retweet = _load_pickle("../../data/evaluation/hawkes/retweet.pkl")

In [17]:
# List the fields stored in the pickled dataset (shown in the output below).
retweet.keys()

dict_keys(['timestamps', 'types', 'lengths', 'timeintervals'])

In [18]:
# Number of entries in the "timestamps" field -- presumably one per sequence (path); confirm.
retweet["timestamps"].shape

(24000,)

In [19]:
# Length of the first sequence's timestamp array (assumed: one timestamp per event).
retweet["timestamps"][0].shape

(88,)

In [20]:
def get_number_of_marks(types_data):
    """
    Return the number of marks (event types) in a dataset.

    Marks are assumed to be integer ids ``0..M-1``, so the count is
    ``max id + 1``. Returns 0 when there are no events at all (instead of
    failing on an empty concatenation).

    :param types_data: iterable of per-path arrays of marks.
    :return: number of marks as a plain ``int``.
    """
    non_empty = [arr for arr in (np.asarray(t) for t in types_data) if arr.size]
    if not non_empty:
        return 0
    # A plain max is enough; the previous np.unique only added a sort.
    return int(max(arr.max() for arr in non_empty)) + 1


def merge_to_single_path(dataset):
    """
    Merge all paths to a single path.
    We return the timestamps for every mark separately.

    Paths are laid end to end on one shared time axis. Every event of path
    ``k`` is shifted by the same offset: the time of the last event of the
    merged path so far. Using one offset for all marks keeps the relative
    timing between marks inside each path intact, which the multivariate
    Hawkes fit needs to estimate the cross-mark kernels.

    :param dataset: dict with ``"timestamps"`` and ``"types"``, each holding one
        array per path (event times and marks, aligned element-wise). Times are
        assumed to be sorted, non-negative and relative to the path start.
    :return: list with one ``float`` array of timestamps per mark.
    """
    num_marks = get_number_of_marks(dataset["types"])
    res = [[] for _ in range(num_marks)]
    offset = 0.0
    for times, types in zip(dataset["timestamps"], dataset["types"]):
        times = np.asarray(times, dtype=float)
        types = np.asarray(types)
        if times.size == 0:
            continue
        shifted = times + offset
        for mark in np.unique(types):
            res[int(mark)].extend(shifted[types == mark].tolist())
        # The next path starts where this one ends on the merged time axis.
        offset = float(shifted.max())
    return [np.asarray(ts, dtype=float) for ts in res]


def normalize_timestamps(timestamps):
    """
    Normalize the timestamps to have max delta time of 1.

    The delta is the gap between consecutive events of the *same* mark. Gaps
    are not taken across the boundary between two marks' arrays, because the
    concatenation of per-mark arrays is not a time-ordered sequence. If there
    is no positive gap (e.g. fewer than two events per mark), the timestamps
    are returned unscaled rather than dividing by zero.

    :param timestamps: list of 1-D arrays, one per mark (each assumed sorted).
    :return: list of ``float`` arrays divided by the maximal per-mark delta.
    """
    max_delta = 0.0
    for ts in timestamps:
        ts = np.asarray(ts, dtype=float)
        if ts.size > 1:
            max_delta = max(max_delta, float(np.diff(ts).max()))
    if max_delta <= 0.0:
        return [np.asarray(ts, dtype=float) for ts in timestamps]
    return [np.asarray(ts, dtype=float) / max_delta for ts in timestamps]

In [21]:
# em = HawkesEM(1, kernel_size=100, n_threads=8, verbose=False, tol=1e-3)
# timestamps = merge_to_single_path(mimic)
# # timestamps = normalize_timestamps(timestamps)

# num_marks_to_consider = 6
# em.fit(timestamps[:num_marks_to_consider])

# fig = plot_hawkes_kernels(em, show=True)
# fig.set_size_inches(15, 10)  # Adjust the size as needed
# plt.tight_layout()  # Adjust the layout to prevent overlap

# plt.show()
# plt.savefig("mimic.jpg")

In [22]:
# em = HawkesEM(0.04, kernel_size=100, n_threads=8, verbose=False, tol=1e-3)
# timestamps = merge_to_single_path(mooc)
# # timestamps = normalize_timestamps(timestamps)

# num_marks_to_consider = 6
# em.fit(timestamps[:num_marks_to_consider])

# fig = plot_hawkes_kernels(em, show=True)
# fig.set_size_inches(15, 10)  # Adjust the size as needed
# plt.tight_layout()  # Adjust the layout to prevent overlap

# plt.show()
# plt.savefig("mooc.jpg")

In [23]:
# em = HawkesEM(30, kernel_size=100, n_threads=8, verbose=False, tol=1e-3)
# timestamps = merge_to_single_path(stack)
# # timestamps = normalize_timestamps(timestamps)

# num_marks_to_consider = 6
# em.fit(timestamps[:num_marks_to_consider])

# fig = plot_hawkes_kernels(em, show=True)
# fig.set_size_inches(15, 10)  # Adjust the size as needed
# plt.tight_layout()  # Adjust the layout to prevent overlap

# plt.show()
# plt.savefig("stack.jpg")

In [24]:
# em = HawkesEM(150, kernel_size=100, n_threads=8, verbose=False, tol=1e-3)
# timestamps = merge_to_single_path(retweet)
# # timestamps = normalize_timestamps(timestamps)

# num_marks_to_consider = 3
# em.fit(timestamps[:num_marks_to_consider])

# fig = plot_hawkes_kernels(em, show=True)
# fig.set_size_inches(15, 10)  # Adjust the size as needed
# plt.tight_layout()  # Adjust the layout to prevent overlap

# plt.show()
# plt.savefig("retweet.jpg")

In [25]:
# Helpers for EasyTPP-format datasets (Hugging Face)


def get_num_marks_easytpp(ds_dict):
    """
    Number of marks in an EasyTPP dict-of-lists, i.e. ``max(type id) + 1``.

    Empty sequences are ignored. Returns 0 when ``"type_event"`` is missing or
    empty, when there are no events at all, or when every type id is negative.
    """
    seq_maxima = [int(np.max(seq)) for seq in ds_dict.get("type_event", []) if len(seq)]
    top = max(seq_maxima, default=-1)
    return top + 1 if top >= 0 else 0


def merge_easytpp_to_single_path(ds_dict, max_sequences=None):
    """
    Merge EasyTPP dict-of-lists into a single path of per-mark timestamps
    suitable for tick.HawkesEM.

    Sequences are laid end to end on one shared time axis. Every event of
    sequence ``i`` is shifted by the same offset (the time of the last event of
    the merged path so far). This keeps the relative timing between marks
    within a sequence, which the cross-mark kernels depend on.

    :param ds_dict: dict with ``"seq_len"``, ``"time_since_start"`` and
        ``"type_event"`` lists (one entry per sequence).
    :param max_sequences: if given, only the first ``max_sequences`` sequences
        are merged (the number of marks is still taken from all sequences).

    Returns: list[np.ndarray] of length M (num marks)
    """
    num_seqs = len(ds_dict.get("seq_len", []))
    if num_seqs == 0:
        return []
    M = get_num_marks_easytpp(ds_dict)
    res = [[] for _ in range(M)]

    limit = num_seqs if max_sequences is None else min(num_seqs, int(max_sequences))
    offset = 0.0
    for i in range(limit):
        times = np.asarray(ds_dict["time_since_start"][i], dtype=float)
        types = np.asarray(ds_dict["type_event"][i], dtype=int)
        if times.size == 0:
            continue
        shifted = times + offset
        for m in np.unique(types):
            res[int(m)].extend(shifted[types == m].tolist())
        # The next sequence starts where this one ends on the merged time axis.
        offset = float(shifted.max())

    return [np.asarray(ts, dtype=float) for ts in res]

In [26]:
# Define manual EasyTPP loader above usage to avoid NameError
import inspect
import io
import json
import os
from typing import Any, Dict, List


try:
    from huggingface_hub import snapshot_download
except Exception:
    snapshot_download = None


# ---------- file discovery / parsing ----------


def _find_split_files(base_dir: str, split: str) -> List[str]:
    split_low = split.lower()
    hits = []
    for root, _, files in os.walk(base_dir):
        for name in files:
            low = name.lower()
            if split_low in low and (
                low.endswith(".json") or low.endswith(".jsonl") or low.endswith(".json.gz") or low.endswith(".jsonl.gz")
            ):
                hits.append(os.path.join(root, name))
    return sorted(hits)


def _open_textmaybe_gz(path: str):
    if path.endswith(".gz"):
        import gzip

        return gzip.open(path, "rt", encoding="utf-8")
    return io.open(path, "r", encoding="utf-8")


def _read_json_records(path: str) -> List[Dict[str, Any]]:
    """
    Read a list of records from a JSON or JSON Lines file (optionally gzipped).

    Accepted layouts:
    * a JSON array of records -> returned as is;
    * a JSON object wrapping the records in one of its values, e.g.
      ``{"data": [{...}, ...]}`` -> the first list-of-objects value is returned;
    * any other single JSON object -> treated as one record;
    * JSON Lines (one JSON value per line) -> malformed lines are skipped.

    Returns an empty list when nothing usable is found.
    """
    with _open_textmaybe_gz(path) as f:
        text = f.read()

    # A file holding one JSON document (array or object, possibly pretty-printed)
    # parses here; multi-line JSON Lines raises ("Extra data") and falls through
    # to the line-by-line reader below.
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        pass
    else:
        if isinstance(data, list):
            return data
        if isinstance(data, dict):
            for value in data.values():
                if isinstance(value, list) and value and all(isinstance(v, dict) for v in value):
                    return value
            return [data]
        return []

    recs = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            recs.append(json.loads(line))
        except json.JSONDecodeError:
            # Best effort: skip malformed lines instead of rejecting the whole file.
            continue
    return recs


def _records_to_easytpp_dict(records: List[Dict[str, Any]]) -> Dict[str, List[List[Any]]]:
    res = {"time_since_start": [], "time_since_last_event": [], "type_event": [], "seq_len": []}
    for r in records:
        times = None
        types = None
        if isinstance(r, dict):
            if "time_since_start" in r and "type_event" in r:
                times = r.get("time_since_start")
                types = r.get("type_event")
            elif "timestamps" in r and "types" in r:
                times = r.get("timestamps")
                types = r.get("types")
            elif "time" in r and "type" in r:
                times = r.get("time")
                types = r.get("type")
        if not isinstance(times, (list, tuple)) or not isinstance(types, (list, tuple)):
            continue
        if len(times) != len(types) or len(times) == 0:
            continue
        times = [float(t) for t in times]
        types = [int(c) for c in types]
        deltas = [0.0] + [times[i] - times[i - 1] for i in range(1, len(times))]
        res["time_since_start"].append(times)
        res["time_since_last_event"].append(deltas)
        res["type_event"].append(types)
        res["seq_len"].append(len(times))
    return res


# ---------- HTTP fallback (no git, no Arrow) ----------


def _http_try_download_split(repo_id: str, split: str, token: str, timeout: float = 30.0) -> str:
    """
    Download the file for ``split`` of a Hugging Face repo over plain HTTP.

    Tries common file names (``{split}.json[l][.gz]`` at the repo root and under
    ``data/``) against the dataset and the model ``resolve/main`` URLs. The
    first one answering 200 is saved under ``~/.cache/hf_ds_http/<repo>/``.

    :param repo_id: repo id, optionally prefixed with ``datasets/``.
    :param split: split name, e.g. ``"train"``.
    :param token: optional HF token, sent as a Bearer header when truthy.
    :param timeout: per-request timeout in seconds, so an unresponsive server
        cannot hang the notebook.
    :return: local path of the downloaded file.
    :raises RuntimeError: if none of the candidate URLs answered 200.
    """
    import requests

    repo = repo_id.split("datasets/")[-1]
    # common filenames
    cands = [
        f"{split}.jsonl",
        f"{split}.json",
        f"{split}.jsonl.gz",
        f"{split}.json.gz",
        f"data/{split}.jsonl",
        f"data/{split}.json",
        f"data/{split}.jsonl.gz",
        f"data/{split}.json.gz",
    ]
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    cache_dir = os.path.join(os.path.expanduser("~/.cache"), "hf_ds_http", repo.replace("/", "_"))
    os.makedirs(cache_dir, exist_ok=True)
    base_urls = [
        f"https://huggingface.co/datasets/{repo}/resolve/main/",
        f"https://huggingface.co/{repo}/resolve/main/",
    ]
    for base in base_urls:
        for cand in cands:
            url = base + cand
            # The context manager releases the connection for every response,
            # including streamed non-200 ones that are never read.
            with requests.get(url, headers=headers, stream=True, timeout=timeout) as r:
                if r.status_code != 200:
                    continue
                local_path = os.path.join(cache_dir, cand.replace("/", "_"))
                tmp_path = local_path + ".part"
                with open(tmp_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                # Publish only complete downloads: a truncated file under the final
                # name would later be found by the split-file search and parsed as
                # partial data.
                os.replace(tmp_path, local_path)
                return local_path
    raise RuntimeError("Could not fetch split file via HTTP for %s (%s)" % (repo_id, split))


# ---------- snapshot or HTTP ----------


def _snapshot_dataset(repo_id: str, token: str) -> str:
    """
    Return a local directory holding files of the HF dataset ``repo_id``.

    Uses ``huggingface_hub.snapshot_download`` when it was importable. If it is
    unavailable, or every snapshot attempt fails, this falls back to
    downloading the ``train`` split over HTTP and returns that file's cache
    directory. Snapshot failures are deliberately best effort (swallowed) so
    the HTTP path still gets a chance.
    """
    if snapshot_download is not None:
        try:
            params = inspect.signature(snapshot_download).parameters
        except (TypeError, ValueError):
            params = {}
        # Newer huggingface_hub releases take `token`; `use_auth_token` is the
        # deprecated spelling, kept for versions that lack `token`.
        auth_kwargs = {"token": token} if "token" in params else {"use_auth_token": token}
        if "repo_type" in params:
            try:
                return snapshot_download(repo_id, repo_type="dataset", **auth_kwargs)
            except Exception:
                pass
        # Attempts without repo_type (presumably for hub versions lacking it).
        # dict.fromkeys de-duplicates while keeping order, so an already
        # prefixed id is tried only once.
        prefixed = repo_id if repo_id.startswith("datasets/") else "datasets/" + repo_id
        for rid in dict.fromkeys((prefixed, repo_id)):
            try:
                return snapshot_download(rid, **auth_kwargs)
            except Exception:
                continue
    # Fall back to HTTP single-file download into a cache dir
    local_file = _http_try_download_split(repo_id, split="train", token=token)
    return os.path.dirname(local_file)


def load_easytpp_as_dict(repo_id: str, split: str) -> Dict[str, List[List[Any]]]:
    """
    Load one split of an EasyTPP-format HF dataset as a dict of lists.

    Returns the ``time_since_start`` / ``time_since_last_event`` /
    ``type_event`` / ``seq_len`` dict built by ``_records_to_easytpp_dict``.
    Files are located in a local snapshot of the repo (or the HTTP fallback
    cache) and parsed with the manual JSON/JSONL reader, bypassing Arrow.
    The HF token, if any, is read from the ``HF_TOKEN`` environment variable.

    :raises RuntimeError: if no candidate file yields at least one sequence;
        the last parse error (if any) is chained as the cause.
    """
    token = os.getenv("HF_TOKEN")
    local_dir = _snapshot_dataset(repo_id, token)
    candidates = _find_split_files(local_dir, split)
    if not candidates:
        # try HTTP direct for this split
        path = _http_try_download_split(repo_id, split, token)
        candidates = [path]
    last_error = None
    for path in candidates:
        try:
            recs = _read_json_records(path)
            if isinstance(recs, list) and recs:
                out = _records_to_easytpp_dict(recs)
                if out["seq_len"]:
                    return out
        except Exception as exc:
            # Best effort: remember the failure and try the next candidate file.
            last_error = exc
            continue
    raise RuntimeError(
        "Failed to parse any JSON/JSONL for split '%s' in %s (tried: %s)"
        % (split, local_dir, ", ".join(candidates))
    ) from last_error

In [27]:
# Helper to plot Hawkes kernels at a readable size based on number of marks
# Global matplotlib defaults for every figure created after this cell.
plt.rcParams["font.size"] = 12
plt.rcParams["figure.dpi"] = 120


def plot_kernels_large(em, per_subplot=4.5, extra_w=8.0, extra_h=4.5, dpi=380, max_marks=18, max_side_px=None):
    """
    Plot the kernels of a fitted Hawkes learner on a figure scaled to its marks.

    The figure is ``per_subplot * marks + extra_w`` by
    ``per_subplot * marks + extra_h`` inches (at least 12 x 9). ``marks`` is the
    length of ``em.baseline`` (1 if unavailable), clamped to ``[1, max_marks]``.

    :param em: fitted learner accepted by ``plot_hawkes_kernels``.
    :param dpi: requested resolution for saving.
    :param max_marks: upper clamp on the number of marks used for sizing.
    :param max_side_px: optional cap on the longest side of the saved image in
        pixels. When given, the returned dpi is lowered so that
        ``max(width, height) * dpi <= max_side_px``. With the defaults and 18
        marks the figure is 89 in wide, i.e. ~34k px at 380 dpi.
    :return: ``(fig, dpi)`` -- the figure and the dpi to pass to ``savefig``.
    """
    try:
        num_marks = int(len(em.baseline))
    except Exception:
        num_marks = 1
    fig = plot_hawkes_kernels(em, show=False)
    # Avoid extreme sizes: clamp marks and figure dims
    marks = max(1, min(num_marks, max_marks))
    width = max(12.0, per_subplot * marks + extra_w)
    height = max(9.0, per_subplot * marks + extra_h)
    fig.set_size_inches(width, height)
    try:
        fig.tight_layout()
    except Exception:
        # Best effort: a layout failure should not prevent returning the figure.
        pass
    if max_side_px is not None:
        dpi = min(dpi, max_side_px / max(width, height))
    return fig, dpi


# Convenience wrapper to show+save


def show_and_save(fig, path, dpi):
    """
    Save ``fig`` to ``path`` (tight bounding box, ``dpi``), display it, then close it.

    Saving happens before showing, so the file is written even if displaying
    closes the figure. Displaying is best effort (errors are ignored). The
    figure is always closed, even when saving raises, so figures do not pile
    up in memory across calls.
    """
    try:
        fig.savefig(path, dpi=dpi, bbox_inches="tight")
        try:
            plt.show(block=False)
            plt.pause(0.1)
        except Exception:
            pass
    finally:
        plt.close(fig)

In [28]:
# Load and fit EM on easytpp/taobao (Arrow bypass).
# plt, HawkesEM and plot_hawkes_kernels come from the imports in the first code cell.
# NOTE(review): the "Taobao via robust loader" cell below repeats this fit and
# writes the same taobao_hf.jpg, overwriting the figure saved here.

# Load using manual snapshot + JSON parser to avoid Arrow
taobao_train = load_easytpp_as_dict("easytpp/taobao", split="train")

# Merge train into a single multi-mark timestamp list
taobao_timestamps = merge_easytpp_to_single_path(taobao_train)

em = HawkesEM(1.0, kernel_size=100, n_threads=8, verbose=False, tol=1e-3)
num_marks_to_consider = len(taobao_timestamps)  # all marks
em.fit(taobao_timestamps[:num_marks_to_consider])

# Build and size the figure without showing it, save it, and only then show it.
# Once a figure has been shown (the inline backend then closes it), plt.savefig
# would act on a fresh, empty current figure and write a blank image.
fig = plot_hawkes_kernels(em, show=False)
fig.set_size_inches(15, 10)
fig.tight_layout()
fig.savefig("taobao_hf.jpg")
plt.show()



In [29]:
# Unified loader shim: same arguments and return value as load_easytpp_as_dict.


def load_easytpp(repo_id: str, split: str):
    """Shorter alias that forwards both arguments to ``load_easytpp_as_dict``."""
    return load_easytpp_as_dict(repo_id=repo_id, split=split)

In [30]:
# Amazon via robust loader (manual JSON parsing fallback).
# plt, HawkesEM and plot_hawkes_kernels are already imported in the first code cell.
amazon_train = load_easytpp_as_dict("easytpp/amazon", split="train")
amazon_timestamps = merge_easytpp_to_single_path(amazon_train)

num_marks_to_consider = len(amazon_timestamps)  # use every mark
em = HawkesEM(1.0, kernel_size=100, n_threads=8, verbose=False, tol=1e-3)
em.fit(amazon_timestamps[:num_marks_to_consider])

# Large, mark-scaled figure; saved to disk, displayed, then closed.
fig, dpi = plot_kernels_large(em)
show_and_save(fig, "amazon_hf.jpg", dpi)

In [31]:
# Taxi via robust loader (manual JSON parsing fallback).
# plt, HawkesEM and plot_hawkes_kernels are already imported in the first code cell.
taxi_train = load_easytpp_as_dict("easytpp/taxi", split="train")
taxi_timestamps = merge_easytpp_to_single_path(taxi_train)

num_marks_to_consider = len(taxi_timestamps)  # use every mark
em = HawkesEM(1.0, kernel_size=100, n_threads=8, verbose=False, tol=1e-3)
em.fit(taxi_timestamps[:num_marks_to_consider])

# Large, mark-scaled figure; saved to disk, displayed, then closed.
fig, dpi = plot_kernels_large(em)
show_and_save(fig, "taxi_hf.jpg", dpi)

In [32]:
# Taobao via robust loader (manual JSON parsing fallback).
# plt, HawkesEM and plot_hawkes_kernels are already imported in the first code cell.
# NOTE(review): repeats the earlier taobao fit and overwrites taobao_hf.jpg.
taobao_train = load_easytpp_as_dict("easytpp/taobao", split="train")
taobao_timestamps = merge_easytpp_to_single_path(taobao_train)

num_marks_to_consider = len(taobao_timestamps)  # use every mark
em = HawkesEM(1.0, kernel_size=100, n_threads=8, verbose=False, tol=1e-3)
em.fit(taobao_timestamps[:num_marks_to_consider])

# Large, mark-scaled figure; saved to disk, displayed, then closed.
fig, dpi = plot_kernels_large(em)
show_and_save(fig, "taobao_hf.jpg", dpi)

In [33]:
# # Manual JSON loader fallback that bypasses Arrow
# import os, json, io
# from typing import List, Dict, Any

# try:
#     from huggingface_hub import snapshot_download
# except Exception:
#     snapshot_download = None


# def _find_split_files(base_dir: str, split: str) -> List[str]:
#     split_low = split.lower()
#     hits = []
#     for root, _, files in os.walk(base_dir):
#         for name in files:
#             low = name.lower()
#             if split_low in low and (low.endswith(".json") or low.endswith(".jsonl")):
#                 hits.append(os.path.join(root, name))
#     return sorted(hits)


# def _read_json_records(path: str) -> List[Dict[str, Any]]:
#     with io.open(path, "r", encoding="utf-8") as f:
#         head = f.read(2048)
#         f.seek(0)
#         if head.lstrip().startswith("["):
#             data = json.load(f)
#             if isinstance(data, dict):
#                 # If top-level dict, pick first list value
#                 for v in data.values():
#                     if isinstance(v, list):
#                         data = v
#                         break
#             if isinstance(data, list):
#                 return data
#             return []
#         else:
#             # JSONL
#             recs = []
#             for line in f:
#                 line = line.strip()
#                 if not line:
#                     continue
#                 try:
#                     recs.append(json.loads(line))
#                 except Exception:
#                     pass
#             return recs


# def _records_to_easytpp_dict(records: List[Dict[str, Any]]) -> Dict[str, List[List[Any]]]:
#     res = {"time_since_start": [], "time_since_last_event": [], "type_event": [], "seq_len": []}
#     for r in records:
#         times = None
#         types = None
#         # Common EasyTPP-style keys
#         if isinstance(r, dict):
#             if "time_since_start" in r and "type_event" in r:
#                 times = r.get("time_since_start")
#                 types = r.get("type_event")
#             elif "timestamps" in r and "types" in r:
#                 times = r.get("timestamps")
#                 types = r.get("types")
#             elif "time" in r and "type" in r:
#                 times = r.get("time")
#                 types = r.get("type")
#         if not isinstance(times, (list, tuple)) or not isinstance(types, (list, tuple)):
#             continue
#         if len(times) != len(types) or len(times) == 0:
#             continue
#         times = [float(t) for t in times]
#         types = [int(c) for c in types]
#         deltas = [0.0] + [times[i] - times[i - 1] for i in range(1, len(times))]
#         res["time_since_start"].append(times)
#         res["time_since_last_event"].append(deltas)
#         res["type_event"].append(types)
#         res["seq_len"].append(len(times))
#     return res


# def load_easytpp_as_dict(repo_id: str, split: str) -> Dict[str, List[List[Any]]]:
#     """
#     Return dict-of-lists (time_since_start, time_since_last_event, type_event, seq_len).
#     Tries datasets API first; falls back to snapshot + manual JSON parsing to avoid Arrow.
#     """
#     token = os.getenv("HF_TOKEN")
#     # Try datasets API (if available and supports remote code)
#     try:
#         from datasets import load_dataset
#         try:
#             ds = load_dataset(repo_id, split=split, trust_remote_code=True, use_auth_token=token)
#         except TypeError:
#             ds = load_dataset(repo_id, split=split, use_auth_token=token)
#         # Convert to plain dict-of-lists
#         return ds[: len(ds)]
#     except Exception:
#         pass

#     # Fallback to snapshot + manual parsing
#     if snapshot_download is None:
#         raise RuntimeError("huggingface_hub.snapshot_download not available in this environment")
#     local_dir = snapshot_download(repo_id, repo_type="dataset", use_auth_token=token)
#     candidates = _find_split_files(local_dir, split)
#     if not candidates:
#         raise RuntimeError("No JSON/JSONL files found for split '%s' in %s" % (split, local_dir))
#     # Try candidates until parse succeeds
#     for path in candidates:
#         try:
#             recs = _read_json_records(path)
#             if isinstance(recs, list) and recs:
#                 out = _records_to_easytpp_dict(recs)
#                 if out["seq_len"]:
#                     return out
#         except Exception:
#             continue
#     raise RuntimeError("Failed to parse any JSON/JSONL for split '%s' in %s" % (split, local_dir))
