In [12]:
from typing import Union

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy
import numpy as np

from environments.continuous_teaching import ContinuousTeaching
from baseline_policies.threshold import Threshold

In [13]:
%config InlineBackend.figure_format = "retina"
sns.set()

In [14]:
df_param = pd.read_csv("data/param_exp_data.csv", index_col=0)

# Fail early, at load time, if the file lacks the columns/row that
# generate_param reads (mu1/mu2, sigma_u1/sigma_u2, sigma_w1/sigma_w2 of "unconstrained").
_expected_cols = {f"{name}{i}" for name in ("mu", "sigma_u", "sigma_w") for i in (1, 2)}
_missing_cols = _expected_cols - set(df_param.columns)
assert not _missing_cols, f"param file is missing columns: {sorted(_missing_cols)}"
assert "unconstrained" in df_param.index, "param file has no 'unconstrained' row"

df_param

Unnamed: 0,mu1,sigma_u1,sigma_w1,mu2,sigma_u2,sigma_w2
unconstrained,-5.661843,1.844262,1.616331,-0.723793,1.720237,1.218074


In [22]:
def generate_param(n_u, n_w, df_param, row="unconstrained", rng=None):
    """Draw random-effect parameters for ``n_u`` users and ``n_w`` items.

    Population-level values are read from one row of ``df_param``. That row
    must provide the columns ``mu1``, ``mu2``, ``sigma_u1``, ``sigma_u2``,
    ``sigma_w1`` and ``sigma_w2``.

    :param n_u: number of users to sample
    :param n_w: number of items to sample
    :param df_param: DataFrame whose index labels parameter sets
    :param row: index label of the row of ``df_param`` to use
    :param rng: random source exposing ``normal(loc, scale, size)``, e.g. a
        ``np.random.Generator`` for reproducible draws. ``None`` (default)
        uses the global ``np.random`` state, as before.
    :return: tuple ``(mu, Zu, Zw)``. ``mu`` has shape ``(2,)``; ``Zu`` has
        shape ``(n_u, 2)``; ``Zw`` has shape ``(n_w, 2)``.
    """
    if rng is None:
        rng = np.random

    params = df_param.loc[row]

    def _pair(prefix):
        # Collect the two components (suffix 1 and 2) of one parameter family.
        return np.array([params[f"{prefix}{i}"] for i in (1, 2)], dtype=float)

    mu = _pair("mu")
    sg_u = _pair("sigma_u")
    sg_w = _pair("sigma_w")

    # Zero-mean deviations; column j is scaled by the j-th sigma via broadcasting.
    Zu = rng.normal(np.zeros(2), sg_u, size=(n_u, 2))
    Zw = rng.normal(np.zeros(2), sg_w, size=(n_w, 2))

    return mu, Zu, Zw

In [31]:
def p_recall(u: int, w: int, r: int, x: Union[int, float], Zu: np.ndarray, Zw: np.ndarray, mu: np.ndarray):
    """Probability that user ``u`` recalls item ``w``.

    The model is ``p = exp(-a * x * (1 - b) ** r)``, where

    * ``a = exp(mu[0] + Zu[u, 0] + Zw[w, 0])``
    * ``b = expit(mu[1] + Zu[u, 1] + Zw[w, 1])``

    Each repetition therefore multiplies the decay rate ``a`` by ``1 - b``.

    :param u: user ID (row index into ``Zu``)
    :param w: item ID (row index into ``Zw``)
    :param r: number of repetition (number of presentation - 1)
    :param x: time elapsed since last presentation
    :param Zu: per-user deviations, read as ``Zu[u, 0]`` and ``Zu[u, 1]``
    :param Zw: per-item deviations, read as ``Zw[w, 0]`` and ``Zw[w, 1]``
    :param mu: population means. ``mu[0]`` enters ``a`` and ``mu[1]`` enters ``b``.
    :return: probability of recall
    """
    # Import the submodule explicitly: a bare `import scipy` does not expose
    # `scipy.special` on older SciPy releases.
    from scipy.special import expit

    Za = mu[0] + Zu[u, 0] + Zw[w, 0]
    Zb = mu[1] + Zu[u, 1] + Zw[w, 1]

    a = np.exp(Za)
    # 1 - expit(Zb) == expit(-Zb). The right-hand form avoids cancellation for
    # large Zb, where expit(Zb) rounds to exactly 1.0 and the rate would collapse to 0.
    one_minus_b = expit(-Zb)
    neg_rate = -a * x * one_minus_b ** r
    return np.exp(neg_rate)

In [33]:
# Sample a small synthetic population: 10 users x 10 items.
n_u, n_w = 10, 10
mu, Zu, Zw = generate_param(n_u, n_w, df_param)

In [34]:
# Recall probability for user 0 on item 0 after 3 presentations (r=2),
# evaluated x=3.2 after the last presentation.
p_recall(
    u=0, w=0,
    r=2, x=3.2,
    Zu=Zu, Zw=Zw, mu=mu,
)

0.9962307154707872