In [1]:
from datasets import load_dataset
from fim.utils.interpolator import KernelInterpolator
from fim.models.hawkes.thinning import EventSampler
from fim.models.hawkes import FIMHawkes
import torch
from matplotlib import pyplot as plt
from functools import partial

In [2]:
# Hugging Face Hub repository and configuration of the synthetic Hawkes dataset
# that every cell below reads from.
DATASET_REPO = "FIM4Science/hawkes-synthetic-short-scale-single-process"
DATASET_CONFIG = "train_process_8"

data = load_dataset(DATASET_REPO, DATASET_CONFIG)
# Switch the output format to torch so that indexing rows yields tensors.
data.set_format(type="torch")

In [3]:
# Leading dimension of the first training sample's kernel evaluations; the cells
# below use it as the number of marks (event types).
first_train_sample = data["train"][0]
marks = first_train_sample["target_kernel_evaluations"].shape[0]

In [2]:
# Shell escape: print the GPUs visible to this kernel (needs the NVIDIA driver tools on PATH).
!nvidia-smi


Fri Apr  4 11:56:20 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.124.06             Driver Version: 570.124.06     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          On  |   00000000:07:00.0 Off |                    0 |
| N/A   31C    P0             52W /  400W |       1MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

# Intensity calculation

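$\hat{\lambda}_k(t) = \text{ReLU}\left( \hat{\lambda}^0_k + \sum_{i:\, t_{ki} < t} \hat{K}_k(t - t_{ki})\right)$

Here $\hat{\lambda}^0_k$ is the base intensity of mark $k$, $\hat{K}_k$ is its kernel, and the sum runs over the past events $t_{ki} < t$.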

In [5]:
# Dataset split that the cells below read from.
split = "test"
# Use the GPU when one is available and fall back to the CPU otherwise.
# (The fallback used to be overwritten by a hard-coded "cuda", which made every
# later `.to(device)` fail on machines without a GPU.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Overview grid: one panel per path showing the ground-truth intensity of every
# mark together with a raster of the observed events.
n_rows, n_cols = 10, 3
fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 30))
axes = axes.flatten()

# eventplot requires exactly one colour per row of events, so cycle through the
# palette instead of assuming there are exactly three marks.
event_palette = ['blue', 'orange', 'green']
event_colors = [event_palette[m % len(event_palette)] for m in range(marks)]

for i in range(n_rows * n_cols):
    sample = data[split][i]  # fetch the row once instead of on every access
    ax = axes[i]

    time_intensities = sample['target_intensity_times']
    intensities = sample['target_intensities']
    event_times = sample['time_since_start']
    event_types = sample['type_event']

    # The events do not depend on which intensity curve is being drawn, so the
    # raster is drawn once per panel rather than once per mark.
    event_times_by_type = [event_times[event_types == m].tolist() for m in range(marks)]
    ax.eventplot(event_times_by_type, colors=event_colors, linelengths=0.05)

    for mark in range(marks):
        ax.plot(time_intensities, intensities[mark], label=f'Mark {mark}')

    ax.set_xlim(0, time_intensities.max())
    # Leave some room below zero so the event raster at y = 0 stays visible.
    ax.set_ylim(-0.15, intensities.max())

    ax.set_title(f'Path {i}')
    ax.legend()

plt.tight_layout()
plt.show()


In [18]:
def calc_intensity(t: torch.Tensor, kernel_grids: torch.Tensor, kernel_evaluations: torch.Tensor, base_intensities: torch.Tensor, time_since_start: torch.Tensor):
    """Evaluate the Hawkes intensity implied by a tabulated kernel.

    The tabulated kernel is wrapped in a ``KernelInterpolator`` (``mode="interpolate"``)
    and passed, together with the event history and the base intensities, to
    ``FIMHawkes.intentsity``.

    Args:
        t: Query times. A trailing and a leading axis are added before the call.
        kernel_grids: Grid points of the tabulated kernel
            (presumably [num_marks, num_grid_points] -- TODO confirm).
        kernel_evaluations: Kernel values on ``kernel_grids``
            (presumably [num_marks, num_grid_points] -- TODO confirm).
        base_intensities: Base intensity per mark (presumably [num_marks] -- TODO confirm).
        time_since_start: Event times of the path. A leading axis is added before the call.

    Returns:
        The result of ``FIMHawkes.intentsity`` for these inputs; its exact shape
        cannot be verified from this notebook.
    """
    # Add the extra axes first: trailing + leading for the query times, leading for the history.
    query_times = t.unsqueeze(-1).unsqueeze(0)
    event_history = time_since_start.unsqueeze(0)

    # Continuous kernel obtained by interpolating the tabulated values.
    interpolated_kernel = KernelInterpolator(kernel_grids, kernel_evaluations, mode="interpolate")

    return FIMHawkes.intentsity(query_times, event_history, interpolated_kernel, base_intensities)

In [None]:
# Recompute the intensity of one path from its ground-truth kernel and compare it
# with the intensity stored in the dataset.
path_id = 27
sample = data[split][path_id]  # fetch the row once instead of on every access

plt.figure(figsize=(10, 6))
# Evaluate on the dataset's own time grid so both curves share the x-axis.
# (A linspace grid assigned here earlier was dead code: it was overwritten immediately.)
t_values = sample['target_intensity_times']

intensity = calc_intensity(
    t_values,
    sample['target_kernel_grids'],
    sample['target_kernel_evaluations'],
    sample['target_base_intensities'],
    sample['time_since_start'],
)
plt.plot(t_values, intensity.squeeze(), label='Predicted Intensity', color="blue", linewidth=2)
plt.plot(t_values.squeeze(), sample['target_intensities'].squeeze(), label='True Intensity', linewidth=4, color="orange", alpha=0.5)
# Observed events, drawn as a raster just below the curves.
plt.eventplot(sample['time_since_start'].squeeze(), linelengths=0.15, colors='orange', lineoffsets=-0.1)
plt.xlim(0, 10)
plt.legend()

In [None]:
# Ground-truth kernel, base intensity and event history of the first path of the split.
reference_path = data[split][0]

intensity_interpolator = KernelInterpolator(
    reference_path['target_kernel_grids'].to(device),
    reference_path['target_kernel_evaluations'].to(device),
    mode="interpolate",
    out_of_bounds_value=0,
)
# NOTE(review): `intensity_fn` is not defined in any cell visible in this notebook --
# confirm where it comes from before relying on Restart & Run All.
# NOTE(review): the closure is conditioned on path 0, while `instances` below is
# path 27 -- confirm that this mismatch is intended.
intensity = partial(
    intensity_fn,
    kernel=intensity_interpolator,
    base_intensity=reference_path['target_base_intensities'].to(device),
    time_seq=reference_path['time_since_start'].unsqueeze(0).to(device),
)

sampler = EventSampler(num_sample=1, device=device, num_exp=550)

# A batch holding a single path.
instances = data[split][27:28]

accepted_dtimes, weights = sampler.draw_next_time_one_step(
    instances['time_since_start'].squeeze(-1).to(device),
    instances['time_since_last_event'].squeeze(-1).to(device),
    instances['type_event'].squeeze(-1).to(device),
    intensity,
    False,
)

# The mark is sampled conditioned on every accepted time, not on the expected event time.
# Step 1: intensity at all accepted times.
# Expected shape (per the original author): [batch_size, seq_len, num_sample, num_marks]
intensities_at_times = intensity(accepted_dtimes)

# Step 2: normalise over the last axis so each entry is a categorical
# distribution over marks. Same expected shape as above.
per_time_totals = intensities_at_times.sum(dim=-1, keepdim=True)
intensities_normalized = intensities_at_times / per_time_totals

# Step 3: weighted sum of those distributions over the sample axis ...
# Expected shape: [batch_size, seq_len, num_marks]
intensities_weighted = torch.einsum('...s,...sm->...m', weights, intensities_normalized)

# ... followed by an argmax over marks. Expected shape: [batch_size, seq_len]
types_pred = intensities_weighted.argmax(dim=-1)

# Expected next event time = weighted sum of the accepted times. Expected shape: [batch_size, seq_len]
dtimes_pred = (accepted_dtimes * weights).sum(dim=-1)


In [None]:
# Evaluate the intensity on the dataset's target time grid in two ways: through
# the `partial` closure and through a direct positional call of `intensity_fn`.
# NOTE(review): `intensity_fn` is not defined in any cell visible here -- confirm its origin.
reference_path = data[split][0]
reference_base_intensity = reference_path['target_base_intensities'].to(device)
reference_history = reference_path['time_since_start'].unsqueeze(0).to(device)
target_query_times = instances['target_intensity_times'].unsqueeze(-1).to(device)

intensity = partial(
    intensity_fn,
    kernel=intensity_interpolator,
    base_intensity=reference_base_intensity,
    time_seq=reference_history,
)
predicted_intensity = intensity(target_query_times)
predicted_intensity_fn = intensity_fn(
    target_query_times,
    intensity_interpolator,
    reference_base_intensity,
    reference_history,
)


In [None]:
# Same evaluation again with a freshly built interpolator; the result `kk` is
# plotted next to `predicted_intensity` further down.
reference_path = data[split][0]
kernel_approx = KernelInterpolator(
    reference_path['target_kernel_grids'].to(device),
    reference_path['target_kernel_evaluations'].to(device),
    mode="interpolate",
    out_of_bounds_value=0,
)
kk = intensity_fn(
    instances['target_intensity_times'].unsqueeze(-1).to(device),
    kernel_approx,
    reference_path['target_base_intensities'].to(device),
    reference_path['time_since_start'].unsqueeze(0).to(device),
)

In [None]:
# Inspect entries 13..19 of the target time grid of the first (only) path in `instances`.
instances['target_intensity_times'][0, 13:20]


In [None]:
# Inspect the stored intensities at the same index range (13..19) for comparison.
instances['target_intensities'].squeeze()[ 13:20]

In [None]:
# Overlay both recomputed intensities on the ground truth for the path held in `instances`.
target_times = instances['target_intensity_times'].squeeze().cpu()

plt.plot(target_times, predicted_intensity.squeeze().cpu(), label='Predicted Intensity', color="blue", linewidth=2)
plt.plot(target_times, instances['target_intensities'].squeeze().cpu(), label='True Intensity', linewidth=4, color="orange", alpha=0.5)
# The red curve used to share the label 'Predicted Intensity' with the blue one,
# which made the legend ambiguous; it now names the source of the curve.
plt.plot(target_times, kk.squeeze().cpu(), label='Predicted Intensity (direct intensity_fn call)', color="red", linewidth=2)
# Observed events, drawn as a raster just below the curves.
plt.eventplot(instances['time_since_start'].squeeze(), linelengths=0.15, colors='orange', lineoffsets=-0.1)
plt.xlim(0, 10)
plt.legend()
plt.show()


In [None]:

c = ['blue', 'orange', 'green']
# NOTE(review): `path_id` indexes into the sampled batch here, not into the dataset;
# `instances` was sliced as data[split][27:28], so the title "Path 0" may be
# misleading -- confirm.
path_id = 0
plt.figure(figsize=(10, 6))

# Order the accepted times (and the intensities evaluated at them) so the line
# plot runs left to right.
accepted_for_path = accepted_dtimes[path_id].squeeze()
sorted_indices = torch.argsort(accepted_for_path)
sorted_accepted_dtimes = accepted_for_path[sorted_indices]
sorted_intensities_at_times = intensities_at_times[path_id].squeeze()[sorted_indices]

plt.plot(instances['target_intensity_times'].squeeze(), instances['target_intensities'].squeeze(), label='True Intensity', color="orange")
plt.plot(sorted_accepted_dtimes.cpu(), sorted_intensities_at_times.cpu(), label='Intensity at sampled times', color="blue")
# Two rasters below the curves: observed events on top, predicted times underneath.
plt.eventplot(instances['time_since_start'].squeeze(), linelengths=0.15, colors='orange', lineoffsets=-0.1, label='True events')
plt.eventplot(dtimes_pred[path_id].cpu(), linelengths=0.15, colors='blue', lineoffsets=-0.4, label='Sampled events')

plt.title(f'Path {path_id}')
plt.legend()
# Only one x-range is set; an earlier xlim(0, 150) call was dead because this
# one overrode it immediately.
plt.xlim(-1, 5)
plt.tight_layout()
plt.show()

In [None]:

# Compare the tabulated kernel with its interpolation, evaluated far beyond the
# tabulated range (out-of-bounds values are set to 0 by the interpolator).
kernel_path_id = 0
# NOTE(review): the tabulated curves come from the 'train' split while the
# interpolator is built from `split` -- confirm that both hold the same kernel.
train_sample = data['train'][kernel_path_id]
split_sample = data[split][kernel_path_id]

time_intensities = train_sample['target_kernel_grids']
intensities = train_sample['target_kernel_evaluations']
interpolator = KernelInterpolator(split_sample['target_kernel_grids'].to(device), split_sample['target_kernel_evaluations'].to(device), mode="interpolate", out_of_bounds_value=0)
# 2000 query points from the smallest grid value up to 50x the largest one.
sample_times = torch.linspace(split_sample['target_kernel_grids'].min(), split_sample['target_kernel_grids'].max()*50, 2000, device=device).unsqueeze(0)
approx = interpolator(sample_times.unsqueeze(0))

c = ['blue', 'orange', 'green']
for mark in range(marks):
    plt.plot(time_intensities[mark].cpu(), intensities[mark].cpu(), label=f'Mark {mark}', linewidth=5)
    # Dashed reference line at the base intensity; the palette is cycled so more
    # than three marks no longer raise an IndexError.
    plt.axhline(train_sample['target_base_intensities'][mark], color=c[mark % len(c)], linestyle='--')

# The interpolated curve does not depend on `mark`, so it is drawn once instead
# of once per loop iteration (which duplicated the line and its legend entry).
plt.plot(sample_times.squeeze().cpu(), approx.squeeze().cpu(), label='Approximation')

# The title used to read the stale loop variable `i` left over from the overview
# grid cell; it now names the sample that is actually plotted.
plt.title(f'Path {kernel_path_id}')
plt.legend()
plt.xlim(0, 50)

plt.tight_layout()
plt.show()