In [1]:
from jax import numpy as jnp
import numpy as np
import jax
from jax.nn import sigmoid
import sys
import jax.random as random
# NOTE(review): hard-coded absolute path to one developer's local checkout; this
# cell only has an effect on a machine where that directory exists. Prefer
# installing the package or reading the location from configuration.
sys.path.append("/Users/changjc/workspace/bayesianquilts")

In [2]:
# Toy parameters, each reshaped to a rank-3 array so the three broadcast
# against one another:
#   abilities       -> (3, 1, 1)
#   difficulties    -> (1, 2, 4)
#   discriminations -> (1, 2, 1)
abilities = np.array([0, 0.5, 0.25]).reshape(-1, 1, 1)
difficulties = np.expand_dims(np.array([[0, 1, 2, 3], [-2, 0, 3, 4]]), axis=0)
discriminations = np.array([1, 2]).reshape(1, -1, 1)

In [None]:
def dsigmoid(x):
    """First derivative of the logistic sigmoid: s(x) * (1 - s(x))."""
    s = sigmoid(x)
    return s * (1 - s)


def ddsigmoid(x):
    """Second derivative of the logistic sigmoid: s(x) * (1 - s(x)) * (1 - 2 * s(x))."""
    s = sigmoid(x)
    return s * (1 - s) * (1 - 2 * s)


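# Axis convention used below: arrays are laid out as N x I x K, where
# N = abilities.shape[0], I = discriminations.shape[1], K = difficulties.shape[-1] + 1.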


def _broadcast_over_params(term, param_shape):
    """Append one axis per parameter dimension and broadcast ``term`` across them.

    The result has shape ``term.shape + param_shape`` and holds the same value
    along every appended axis.
    """
    param_shape = tuple(param_shape)
    expanded = jnp.reshape(term, term.shape + (1,) * len(param_shape))
    return jnp.broadcast_to(expanded, term.shape + param_shape)


def _category_difference(cum_term):
    """Turn a derivative of the cumulative curves into one of the category probabilities.

    Axis 2 indexes the K - 1 cumulative curves. The two boundary curves are the
    constants 1 and 0, so their derivatives are zero: axis 2 is zero-padded on
    both sides and adjacent entries are subtracted, leaving K entries.
    """
    pad_width = [(0, 0)] * cum_term.ndim
    pad_width[2] = (1, 1)
    padded = jnp.pad(cum_term, pad_width, constant_values=0)
    return padded[:, :, :-1, ...] - padded[:, :, 1:, ...]


def p_ni(abilities, difficulties, discriminations):
    """Category probabilities of a cumulative-logit (graded-response-style) model
    together with hand-written first and second partial derivatives.

    The cumulative curves are ``sigmoid(discriminations * (abilities - difficulties))``
    and the category probabilities are differences of adjacent curves after
    padding with the boundary values 1 (left) and 0 (right). With thresholds
    that increase along the last axis this gives non-negative probabilities
    that sum to one.

    Args:
        abilities: array whose first axis has length N; in the demo cell it has
            shape (N, 1, 1).
        difficulties: array whose last axis has length K - 1; in the demo cell
            it has shape (1, I, K - 1).
        discriminations: array whose axis 1 has length I; in the demo cell it
            has shape (1, I, 1).

    Returns:
        dict with keys
          "p":             (N, I, K) category probabilities,
          "log(p)":        elementwise log of "p",
          "grad(p)":       dict keyed by parameter name; each entry has shape
                           (N, I, K, *param_axes) with param_axes = (N,) for
                           abilities, (I, K - 1) for difficulties and (I,) for
                           discriminations,
          "grad(log(p))":  the same entries divided by "p",
          "grad(grad(p))": dict keyed by pairs of parameter names; each entry
                           has shape (N, I, K, *axes_of_first, *axes_of_second).

    NOTE(review): every derivative is broadcast unchanged over the trailing
    parameter axes (the original multiplied by ``jnp.ones``); it is not masked
    down to the parameter entries that actually enter ``p[n, i, k]``. Confirm
    this is the intended layout before using these as Jacobians/Hessians.
    """
    N = abilities.shape[0]
    I = discriminations.shape[1]
    K = difficulties.shape[-1] + 1

    offsets = abilities - difficulties
    p_cum = sigmoid(discriminations * offsets)
    slope = p_cum * (1 - p_cum)  # sigmoid'
    curvature = slope * (1 - 2 * p_cum)  # sigmoid''

    ability_axes = (N,)
    difficulty_axes = (I, K - 1)
    discrimination_axes = (I,)

    # First partials of the cumulative curves (chain rule through a * (theta - d)).
    dp_cum = {
        "abilities": _broadcast_over_params(slope * discriminations, ability_axes),
        "difficulties": _broadcast_over_params(
            -slope * discriminations, difficulty_axes
        ),
        "discriminations": _broadcast_over_params(
            slope * offsets, discrimination_axes
        ),
    }

    # Second partials of the cumulative curves.
    mixed_with_discrimination = slope + curvature * discriminations * offsets
    d2p_cum = {
        ("abilities", "abilities"): _broadcast_over_params(
            curvature * discriminations**2, ability_axes
        ),
        ("discriminations", "discriminations"): _broadcast_over_params(
            curvature * offsets**2, discrimination_axes
        ),
        # d^2/dd^2 picks up (-a)^2 = a^2, not a^3.
        ("difficulties", "difficulties"): _broadcast_over_params(
            curvature * discriminations**2, difficulty_axes
        ),
        ("abilities", "difficulties"): _broadcast_over_params(
            -curvature * discriminations**2, ability_axes + difficulty_axes
        ),
        # d/da [a * sigmoid'(a z)] = sigmoid'(a z) + a z sigmoid''(a z), z = theta - d.
        ("abilities", "discriminations"): _broadcast_over_params(
            mixed_with_discrimination, ability_axes + discrimination_axes
        ),
        ("difficulties", "discriminations"): _broadcast_over_params(
            -mixed_with_discrimination, difficulty_axes + discrimination_axes
        ),
    }

    # Boundary curves: the first is identically 1, the last identically 0.
    p_cum = jnp.pad(p_cum, ((0, 0), (0, 0), (1, 0)), constant_values=1)
    p_cum = jnp.pad(p_cum, ((0, 0), (0, 0), (0, 1)), constant_values=0)
    p = p_cum[..., :-1] - p_cum[..., 1:]

    gradients = {name: _category_difference(term) for name, term in dp_cum.items()}

    # grad(log p) = grad(p) / p; divide by the category probabilities, with
    # singleton axes appended so p lines up with the parameter axes.
    grad_log_p = {
        name: grad / jnp.reshape(p, p.shape + (1,) * (grad.ndim - p.ndim))
        for name, grad in gradients.items()
    }

    grad2_p = {pair: _category_difference(term) for pair, term in d2p_cum.items()}

    return {
        "p": p,
        "log(p)": jnp.log(p),
        "grad(p)": gradients,
        "grad(log(p))": grad_log_p,
        "grad(grad(p))": grad2_p,
    }


def find_a(abilities, difficulties, discriminations):
    """Unfinished placeholder: evaluates ``p_ni`` on the inputs and discards the result.

    Nothing is returned yet (the call falls through to an implicit ``None``).
    """
    model_terms = p_ni(abilities, difficulties, discriminations)


def find_b(abilities, difficulties, discriminations, a):
    """Unfinished placeholder: ignores every argument and returns ``None``."""
    return None


# Evaluate the model on the toy parameters; the returned dict is the cell output.
p_ni(
    abilities=abilities,
    difficulties=difficulties,
    discriminations=discriminations,
)

In [6]:
# NOTE(review): K is assigned here but not read by any cell visible in this
# notebook export (p_ni derives its own local K from `difficulties`).
K = 10
# Project-local import. The recorded output of this cell shows it failing with
# "ModuleNotFoundError: No module named 'tensorflow.io'", so the names below
# are not defined for later cells until that dependency problem is resolved.
from autoencirt.data.rwa import item_text, get_data, to_reverse


ModuleNotFoundError: No module named 'tensorflow.io'

In [None]:
# Load the data through the project helper. Element 0 of its return value is
# indexed with .iloc below, so it is presumably a pandas DataFrame
# (pandas=True is requested) — TODO confirm against get_data.
pd_data = get_data(reorient=True, pandas=True)
# Keep the first 22 columns. NOTE(review): 22 is a magic number — presumably
# the number of questionnaire items; confirm and give it a named constant.
X = pd_data[0].iloc[:, :22]
# Number of rows and number of columns of the selected block.
N, I = X.shape