In [1]:
import numba
import numpy as np
from numpy.typing import ArrayLike
from collections import namedtuple
from typing import Dict, List, Any
from verticox.likelihood import Parameters

# Original code (pure-Python baseline)

In [2]:
def _get_dt(t, Dt):
    # If there is no entry for time t it means that it was from a right-censored sample
    # I'm making it an empty list so I can call len on it but it's semantically sound
    d = Dt.get(t, [])
    return len(d)

def bottom(z, params, t):
    """Sum of ``exp(K * z[j])`` over the indices ``j`` in ``params.Rt[t]``.

    This is the shared denominator used by the second-derivative formulas.
    Terms are accumulated left to right starting from ``0.``, so an empty
    index set yields ``0.0``.
    """
    risk_set = params.Rt[t]
    return sum((np.exp(params.K * z[j]) for j in risk_set), 0.)

def derivative_2_diagonal(z: ArrayLike, params: Parameters, u):
    """Diagonal Hessian entry ``H[u, u]`` with respect to ``z``.

    For every key ``t`` of ``params.Rt`` with ``t <= params.event_times[u]``
    the term

        d_t * (K**2 * e_u / S_t - K**2 * e_u**2 / S_t**2)

    is accumulated, where ``e_u = exp(K * z[u])``,
    ``S_t = bottom(z, params, t)`` and ``d_t = _get_dt(t, params.Dt)``.
    ``K * rho`` is added to the total.
    """
    cutoff = params.event_times[u]
    total = 0

    for t in params.Rt.keys():
        # Written as `not <=` rather than `>` so NaN times are skipped exactly
        # as the `t <= cutoff` filter would skip them.
        if not t <= cutoff:
            continue

        s_t = bottom(z, params, t)
        k_sq = np.square(params.K)
        a = k_sq * np.exp(params.K * z[u]) / s_t
        b = k_sq * np.square(np.exp(params.K * z[u])) / np.square(s_t)

        total += _get_dt(t, params.Dt) * (a - b)

    return total + params.K * params.rho


def derivative_2_off_diagonal(z: ArrayLike, params: Parameters, u, v):
    """Off-diagonal Hessian entry ``H[u, v]`` (``u != v``) with respect to ``z``.

    For every key ``t`` of ``params.Rt`` not later than both
    ``params.event_times[u]`` and ``params.event_times[v]``, accumulates

        deaths_per_t[t] * K**2 * exp(K*z[u]) * exp(K*z[v]) / S_t**2

    where ``S_t`` is the sum of ``exp(K * z[j])`` over ``j`` in
    ``params.Rt[t]``. Returns the negated total.
    """
    horizon = min(params.event_times[u], params.event_times[v])
    k = params.K
    acc = 0

    for t in params.Rt.keys():
        if not t <= horizon:
            continue
        numerator = params.deaths_per_t[t] * np.square(k) * np.exp(k * z[u]) * np.exp(k * z[v])
        risk_sum = np.exp(k * z[params.Rt[t]]).sum()
        acc += numerator / np.square(risk_sum)

    return -1 * acc

def hessian_parametrized(z: ArrayLike, params: Parameters):
    """Dense ``N x N`` Hessian with respect to ``z``, where ``N = z.shape[0]``.

    Diagonal entries come from ``derivative_2_diagonal`` and off-diagonal
    entries from ``derivative_2_off_diagonal``. Every entry is computed
    independently; symmetry is not exploited.
    """
    n = z.shape[0]
    hessian = np.zeros((n, n))

    for row in range(n):
        for col in range(n):
            hessian[row, col] = (
                derivative_2_diagonal(z, params, row)
                if row == col
                else derivative_2_off_diagonal(z, params, row, col)
            )

    return hessian

In [3]:
from numpy import array

# Five-sample example problem used by every benchmark in this notebook.
params = Parameters(
    gamma=array([0., 0., 0., 0., 0.]),
    sigma=array([0.85583803, 0.79890457, 0.91337144, 0.68948236, 0.78043699]),
    rho=0.5,
    Rt={
        1.0: array([0, 1, 2, 3, 4]),
        297.0: array([0, 2, 3, 4]),
        2172.0: array([2, 3, 4]),
        2178.0: array([2, 4]),
        2190.0: array([4]),
    },
    K=2,
    event_times=array([2.970e+02, 1.000e+00, 2.178e+03, 2.172e+03, 2.190e+03]),
    Dt={297.0: [0], 1.0: [1]},
    deaths_per_t={297.0: 1, 1.0: 1, 2178.0: 0, 2172.0: 0, 2190.0: 0},
)

# Point at which the Hessian is evaluated: the origin.
z = np.zeros((5,))

In [4]:
# Display the example parameters.
params

Parameters(gamma=array([0., 0., 0., 0., 0.]), sigma=array([0.85583803, 0.79890457, 0.91337144, 0.68948236, 0.78043699]), rho=0.5, Rt={1.0: array([0, 1, 2, 3, 4]), 297.0: array([0, 2, 3, 4]), 2172.0: array([2, 3, 4]), 2178.0: array([2, 4]), 2190.0: array([4])}, K=2, event_times=array([2.970e+02, 1.000e+00, 2.178e+03, 2.172e+03, 2.190e+03]), Dt={297.0: [0], 1.0: [1]}, deaths_per_t={297.0: 1, 1.0: 1, 2178.0: 0, 2172.0: 0, 2190.0: 0})

In [5]:
# Baseline timing of the pure-Python Hessian on the example problem.
%timeit hessian_parametrized(z, params)

603 µs ± 1.62 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


# Replace `_get_dt` with a direct `deaths_per_t` lookup

In [6]:
def bottom(z, params, t):
    """Sum of ``exp(K * z[j])`` over the indices ``j`` in ``params.Rt[t]``.

    Terms are added sequentially onto ``0.``; an empty index set gives ``0.0``.
    """
    terms = [np.exp(params.K * z[idx]) for idx in params.Rt[t]]
    total = 0.
    for term in terms:
        total += term
    return total

def derivative_2_diagonal(z: ArrayLike, params: Parameters, u):
    """Diagonal Hessian entry ``H[u, u]`` with respect to ``z``.

    Sums, over each key ``t`` of ``params.Rt`` with
    ``t <= params.event_times[u]``,

        deaths_per_t[t] * (K**2 * e_u / S_t - K**2 * e_u**2 / S_t**2)

    with ``e_u = exp(K * z[u])`` and ``S_t = bottom(z, params, t)``, then adds
    ``K * rho``. The event count is read straight from
    ``params.deaths_per_t``, which must contain every key of ``params.Rt``.
    """
    k = params.K
    cutoff = params.event_times[u]
    relevant = [t for t in params.Rt.keys() if t <= cutoff]

    total = 0
    for t in relevant:
        s_t = bottom(z, params, t)
        e_u = np.exp(k * z[u])
        k_sq = np.square(k)
        curvature = k_sq * e_u / s_t - k_sq * np.square(e_u) / np.square(s_t)
        total += params.deaths_per_t[t] * curvature

    return total + k * params.rho


def derivative_2_off_diagonal(z: ArrayLike, params: Parameters, u, v):
    """Off-diagonal Hessian entry ``H[u, v]`` (``u != v``) with respect to ``z``.

    Returns ``-sum_t deaths_per_t[t] * K**2 * exp(K*z[u]) * exp(K*z[v]) / S_t**2``
    over the keys ``t`` of ``params.Rt`` with
    ``t <= min(event_times[u], event_times[v])``. Here ``S_t`` is the sum of
    ``exp(K * z[j])`` for ``j`` in ``params.Rt[t]``.
    """
    k = params.K
    cutoff = min(params.event_times[u], params.event_times[v])

    terms = [
        params.deaths_per_t[t] * np.square(k) * np.exp(k * z[u]) * np.exp(k * z[v])
        / np.square(np.exp(k * z[params.Rt[t]]).sum())
        for t in params.Rt.keys()
        if t <= cutoff
    ]

    # `sum` starts at 0 and adds left to right, like a manual accumulator.
    return -1 * sum(terms)

def hessian_parametrized(z: ArrayLike, params: Parameters):
    """Dense ``N x N`` Hessian with respect to ``z`` (``N = z.shape[0]``).

    Each row's diagonal entry is filled with ``derivative_2_diagonal``, and
    every other entry of the row with ``derivative_2_off_diagonal``.
    """
    size = z.shape[0]
    out = np.zeros((size, size))

    for i in range(size):
        out[i, i] = derivative_2_diagonal(z, params, i)
        for j in range(size):
            if j != i:
                out[i, j] = derivative_2_off_diagonal(z, params, i, j)

    return out

In [7]:
# Time the variant that reads `deaths_per_t` instead of calling `_get_dt`.
%timeit hessian_parametrized(z, params)

603 µs ± 1.47 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


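No measurable difference (603 µs in both runs), which was expected: swapping `_get_dt` for a dictionary lookup leaves the expensive inner loops unchanged.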

# Compile the Hessian functions with Numba

In [36]:
from numba import types

# Numba field layout for `Parameters_numba`. Every field stored by the class
# is listed explicitly, including the scalars `rho` and `K`, so the compiled
# types do not depend on jitclass falling back to the class annotations.
# These match the `rho: float` / `K: int` annotations on the class.
spec = [
    ('gamma', types.float64[:]),
    ('sigma', types.float64[:]),
    ('rho', types.float64),
    ('Rt', types.DictType(types.float64, types.int64[:])),
    ('K', types.int64),
    ('event_times', types.float64[:]),
    ('Dt', types.DictType(types.float64, types.int64[:])),
    ('deaths_per_t', types.DictType(types.float64, types.int64)),
]

# jitclass counterpart of `Parameters`, built so the @njit Hessian functions
# can receive the model parameters. Field types are taken from `spec`.
@numba.experimental.jitclass(spec)
class Parameters_numba:
    # NOTE(review): these annotations are not only hints. jitclass falls back
    # to them for any field that `spec` does not list, so editing them can
    # change the compiled field types.
    gamma: np.ndarray
    sigma: np.ndarray
    rho: float
    Rt: Dict[float, np.ndarray]
    K: int
    event_times:np.ndarray
    Dt: Dict[float,np.ndarray]
    deaths_per_t: Dict[float, int]
    
    def __init__(self, gamma, sigma, rho, Rt ,K, event_times, Dt, deaths_per_t):
        # Store each constructor argument on the field of the same name.
        self.gamma = gamma
        self.sigma = sigma
        self.rho = rho
        self.Rt = Rt
        self.K = K
        self.event_times = event_times
        self.Dt = Dt
        self.deaths_per_t = deaths_per_t
    


@numba.njit()
def bottom(z, params, t):
    """Sum of exp(K * z[j]) over the indices j stored in params.Rt[t].

    Accumulated sequentially from 0.0; an empty index array yields 0.0.
    """
    idx = params.Rt[t]
    total = 0.
    for pos in range(idx.shape[0]):
        total += np.exp(params.K * z[idx[pos]])
    return total

@numba.njit()
def derivative_2_diagonal(z: ArrayLike, params: Parameters, u):
    """Diagonal Hessian entry H[u, u] with respect to z (Numba-compiled).

    Sums deaths_per_t[t] * (K**2 * e_u / S_t - K**2 * e_u**2 / S_t**2) over
    the keys t of params.Rt with t <= event_times[u], where
    e_u = exp(K * z[u]) and S_t = bottom(z, params, t), then adds K * rho.
    """
    cutoff = params.event_times[u]
    total = 0

    for t in params.Rt.keys():
        # `not <=` keeps the exact semantics of the `t <= cutoff` filter.
        if not t <= cutoff:
            continue
        s_t = bottom(z, params, t)
        k_sq = np.square(params.K)
        e_u = np.exp(params.K * z[u])
        total += params.deaths_per_t[t] * (k_sq * e_u / s_t - k_sq * np.square(e_u) / np.square(s_t))

    return total + params.K * params.rho

@numba.njit()
def derivative_2_off_diagonal(z: ArrayLike, params, u, v):
    """Off-diagonal Hessian entry H[u, v] (u != v) with respect to z (Numba).

    Returns -sum_t deaths_per_t[t] * K**2 * exp(K*z[u]) * exp(K*z[v]) / S_t**2
    over keys t of params.Rt with t <= min(event_times[u], event_times[v]),
    where S_t is the sum of exp(K * z[j]) for j in params.Rt[t].
    """
    horizon = min(params.event_times[u], params.event_times[v])
    k = params.K
    acc = 0

    for t in params.Rt.keys():
        if not t <= horizon:
            continue
        risk = np.exp(k * z[params.Rt[t]]).sum()
        acc += params.deaths_per_t[t] * np.square(k) * np.exp(k * z[u]) * np.exp(k * z[v]) / np.square(risk)

    return -1 * acc

@numba.njit()
def hessian_parametrized(z: ArrayLike, params: Parameters):
    """Dense N x N Hessian with respect to z, N = z.shape[0] (Numba-compiled).

    Off-diagonal entries use derivative_2_off_diagonal and diagonal entries
    use derivative_2_diagonal; all N*N entries are computed.
    """
    n = z.shape[0]
    out = np.zeros((n, n))

    for row in range(n):
        for col in range(n):
            if row != col:
                out[row, col] = derivative_2_off_diagonal(z, params, row, col)
            else:
                out[row, col] = derivative_2_diagonal(z, params, row)

    return out

In [37]:
from numba import typed, types

# Copy the plain-Python containers of `params` into Numba typed dicts and a
# `Parameters_numba` instance whose dtypes match `spec`.
# Dtypes are forced explicitly:
# - an empty Dt entry would become float64 via `np.array([])`;
# - NumPy's default integer is int32 on some platforms (e.g. Windows with
#   NumPy < 2).
# Neither would match the declared `int64[:]` / `float64[:]` field types.

typed_Rt = typed.Dict.empty(types.float64, types.int64[:])

for key, value in params.Rt.items():
    typed_Rt[key] = np.asarray(value, dtype=np.int64)

typed_deaths_per_t = typed.Dict.empty(types.float64, types.int64)

for key, value in params.deaths_per_t.items():
    typed_deaths_per_t[key] = value

typed_Dt = typed.Dict.empty(types.float64, types.int64[:])

for key, value in params.Dt.items():
    typed_Dt[key] = np.asarray(value, dtype=np.int64)

typed_params = Parameters_numba(gamma=np.asarray(params.gamma, dtype=np.float64),
                                sigma=np.asarray(params.sigma, dtype=np.float64),
                                rho=params.rho,
                                Rt=typed_Rt, K=params.K,
                                event_times=np.asarray(params.event_times, dtype=np.float64),
                                Dt=typed_Dt, deaths_per_t=typed_deaths_per_t)

In [38]:
%timeit hessian_parametrized(z, typed_params)

25.1 µs ± 10.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [29]:
# Inspect the event times stored on the jitclass instance.
typed_params.event_times

array([2.970e+02, 1.000e+00, 2.178e+03, 2.172e+03, 2.190e+03])