# In [None]:
import numpy as np
from scipy.linalg import det, inv

def _measurement_variances(free_variances, locked_parameters):
    """Return the diagonal of the measurement-error covariance matrix H.

    Parameters
    ----------
    free_variances : 1-D ndarray
        Variances that are estimated freely (the ``psi`` entries after the
        first seven).
    locked_parameters : int or array_like
        1-based positions (MATLAB convention) of contracts whose variance is
        locked at zero. An all-zero value (``0``, ``[0]``) means nothing is
        locked -- which is why positions cannot be 0-based.

    Raises
    ------
    ValueError
        If the locked positions are not distinct values in ``1..n``, where
        ``n = free_variances.size + number of locked entries``.
    """
    locked = np.asarray(locked_parameters).ravel()
    if np.sum(locked) == 0:
        return free_variances.copy()

    n_total = free_variances.size + locked.size
    # A position is free unless it appears in the (1-based) locked list.
    is_free = ~np.isin(np.arange(1, n_total + 1), locked)
    if np.count_nonzero(is_free) != free_variances.size:
        raise ValueError(
            "locked_parameters must hold distinct 1-based positions in "
            f"1..{n_total}, got {locked.tolist()}"
        )
    variances = np.zeros(n_total)
    variances[is_free] = free_variances  # locked slots stay at zero
    return variances


def kalman_estimation(y, psi, matur, dt, a0, P0, N, nobs, locked_parameters):
    """Return the negative Gaussian log-likelihood of a two-factor model.

    Runs a Kalman filter on the linear state-space model

        transition:   a_t = c + T a_{t-1} + w_t,   Cov(w_t) = Q
        measurement:  y_t = d + Z a_t + e_t,       Cov(e_t) = H = diag(s)

    The first state component mean-reverts at rate ``k``. The second is a
    random walk with drift ``mu``. This form matches the Schwartz-Smith
    (2000) short-term / long-term model.

    Parameters
    ----------
    y : array_like, shape (nobs, N)
        Observations; row ``t`` is the measurement vector at step ``t``
        (presumably log futures prices -- confirm with the caller).
    psi : array_like, shape (7 + n_free,) or (7 + n_free, 1)
        ``[k, sigmax, lambdax, mu, sigmae, rnmu, pxe, s_1, ..., s_n_free]``.
        The trailing entries are the measurement-error variances that are
        not locked. ``k`` must be non-zero.
    matur : array_like
        Maturities of the contracts (the first ``N`` are used), in the same
        time unit as ``dt``.
    dt : float
        Time step between consecutive observations.
    a0 : array_like, shape (2,) or (2, 1)
        Initial state mean.
    P0 : array_like, shape (2, 2)
        Initial state covariance.
    N : int
        Number of contracts (columns of ``y``).
    nobs : int
        Number of time steps to filter.
    locked_parameters : int or array_like
        1-based positions of measurement-error variances locked at zero; an
        all-zero value means none are locked.

    Returns
    -------
    float
        ``-log L``. Returns ``np.inf`` if any innovation covariance ``F_t``
        has a non-positive determinant, so that minimizers reject the
        parameter set.

    Raises
    ------
    ValueError
        If the number of measurement-error variances differs from ``N``,
        fewer than ``N`` maturities are given, or ``locked_parameters`` is
        inconsistent with ``psi``.
    """
    params = np.asarray(psi, dtype=float).ravel()
    k, sigmax, lambdax, mu, sigmae, rnmu, pxe = params[:7]

    variances = _measurement_variances(params[7:], locked_parameters)
    if variances.size != N:
        raise ValueError(
            f"expected {N} measurement-error variances, got {variances.size}"
        )
    H = np.diag(variances)

    # THE TRANSITION EQUATION (the noise loading R is the identity, so
    # R Q R' is simply Q).
    c = np.array([[0.0], [mu * dt]])
    T = np.array([[np.exp(-k * dt), 0.0], [0.0, 1.0]])
    cov_dt = (1 - np.exp(-k * dt)) * pxe * sigmax * sigmae / k
    Q = np.array([
        [(1 - np.exp(-2 * k * dt)) * sigmax**2 / (2 * k), cov_dt],
        [cov_dt, sigmae**2 * dt],
    ])

    # THE MEASUREMENT EQUATION, one row per contract maturity.
    tau = np.asarray(matur, dtype=float).ravel()
    if tau.size < N:
        raise ValueError(f"expected at least {N} maturities, got {tau.size}")
    tau = tau[:N]
    decay = np.exp(-k * tau)
    var_x = (1 - np.exp(-2 * k * tau)) * sigmax**2 / (2 * k)
    var_e = sigmae**2 * tau
    cov_xe = 2 * (1 - decay) * pxe * sigmax * sigmae / k
    d = (rnmu * tau - (1 - decay) * lambdax / k
         + 0.5 * (var_x + var_e + cov_xe)).reshape(-1, 1)
    Z = np.column_stack([decay, np.ones(N)])

    # RUNNING THE KALMAN FILTER
    y = np.asarray(y, dtype=float)
    att = np.asarray(a0, dtype=float).reshape(-1, 1)  # guard against 1-D a0
    Ptt = np.asarray(P0, dtype=float)
    sum_logdet = 0.0
    sum_quad = 0.0

    for t in range(nobs):
        P_pred = T @ Ptt @ T.T + Q
        F = Z @ P_pred @ Z.T + H
        # slogdet avoids det() underflowing to 0 (log -> -inf) for small F.
        sign, logdet = np.linalg.slogdet(F)
        if sign <= 0:
            return np.inf  # not a valid covariance: reject these parameters

        a_pred = T @ att + c
        v = y[t, :].reshape(-1, 1) - (Z @ a_pred + d)

        # Solve with F instead of forming F^{-1} explicitly.
        f_inv_v = np.linalg.solve(F, v)
        f_inv_zp = np.linalg.solve(F, Z @ P_pred)
        att = a_pred + P_pred @ Z.T @ f_inv_v
        Ptt = P_pred - P_pred @ Z.T @ f_inv_zp

        sum_logdet += logdet
        sum_quad += (v.T @ f_inv_v).item()

    log_l = -(N * nobs / 2) * np.log(2 * np.pi) - 0.5 * sum_logdet - 0.5 * sum_quad
    return float(-log_l)