In [42]:
import numpy as np
import scipy
from scipy.special import expit
import arviz as az
import matplotlib.pyplot as plt
import multiprocessing as mp
import pandas as pd
import seaborn as sns
from datetime import datetime
now = datetime.now

In [37]:
# Notebook display setup: select the inline matplotlib backend and ask it
# for the 'retina' figure format.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [38]:
# Switch subsequent plots to seaborn's default theme.
sns.set()

# Data

In [4]:
# Load the full response log, then discard (in place) every row that is not
# from the "active.fi" domain or whose user did not complete 14 sessions.
df = pd.read_csv("data_full.csv", index_col=0)
outside_study = (df.domain != "active.fi") | (df.n_session_done != 14)
df.drop(df.loc[outside_study].index, inplace=True)
df

Unnamed: 0,user,domain,condition,item,success,teacher_md,learner_md,psy_md,session,is_eval,ts_display,ts_reply,n_session_done
49739,goldfish@active.fi,active.fi,ThresholdCondition,1506,True,leitner,,grid,1900,False,2020-09-30 17:51:37.666000+00:00,2020-09-30 17:51:41.042000+00:00,14
49740,goldfish@active.fi,active.fi,ThresholdCondition,190,True,leitner,,grid,1900,False,2020-09-30 17:51:42.081000+00:00,2020-09-30 17:51:43.910000+00:00,14
49741,goldfish@active.fi,active.fi,ThresholdCondition,1000,True,leitner,,grid,1900,False,2020-09-30 17:51:44.980000+00:00,2020-09-30 17:51:46.413000+00:00,14
49742,goldfish@active.fi,active.fi,ThresholdCondition,1506,True,leitner,,grid,1900,False,2020-09-30 17:51:47.476000+00:00,2020-09-30 17:51:50.206000+00:00,14
49743,goldfish@active.fi,active.fi,ThresholdCondition,190,True,leitner,,grid,1900,False,2020-09-30 17:51:51.245000+00:00,2020-09-30 17:51:53.477000+00:00,14
...,...,...,...,...,...,...,...,...,...,...,...,...,...
147617,azalea@active.fi,active.fi,ThresholdCondition,179,True,threshold,exp_decay,grid,2571,True,2020-10-08 05:52:21.524000+00:00,2020-10-08 05:52:31.640000+00:00,14
147618,azalea@active.fi,active.fi,ThresholdCondition,1071,False,threshold,exp_decay,grid,2571,True,2020-10-08 05:52:32.658000+00:00,2020-10-08 05:52:45.959000+00:00,14
147619,azalea@active.fi,active.fi,ThresholdCondition,1196,True,threshold,exp_decay,grid,2571,True,2020-10-08 05:52:48.143000+00:00,2020-10-08 05:52:51.527000+00:00,14
147620,azalea@active.fi,active.fi,ThresholdCondition,1282,True,threshold,exp_decay,grid,2571,True,2020-10-08 05:52:52.543000+00:00,2020-10-08 05:52:56.460000+00:00,14


In [5]:
# Both timestamp columns are read from the CSV as strings: parse them to datetimes.
for ts_col in ("ts_display", "ts_reply"):
    df[ts_col] = pd.to_datetime(df[ts_col])

In [6]:
# Express each reply time as a number of seconds counted from 1970-01-01 UTC
beginning_history = pd.Timestamp("1970-01-01", tz="UTC")
since_beginning = df["ts_reply"] - beginning_history
df["timestamp"] = since_beginning.dt.total_seconds().values

In [7]:
# Number of distinct users and of distinct items left after filtering
n_u = df.user.unique().size
n_w = df.item.unique().size
print(f"n_u={n_u}, n_w={n_w}")

n_u=53, n_w=1998


In [8]:
# Keep the original item identifiers in their own column; the `item` column
# is overwritten with consecutive codes further down.
df["item_id"] = df["item"]

In [9]:
# Re-code items as consecutive integers 0, 1, 2, ... numbered in order of
# first appearance of each `item_id`. `pd.factorize` does this in a single
# pass instead of one full-column scan per distinct item.
# NOTE: a missing `item_id` would be coded -1 here; the data loaded above
# shows integer ids only.
df["item"] = pd.factorize(df["item_id"])[0]

In [10]:
# Sanity check on the re-coded items: smallest code, then largest code
print(df.item.min(), df.item.max(), sep="\n")

0
1997


In [11]:
# Observations per user = presentations minus distinct items, because the
# first presentation of each item is not counted as an observation.
n_o_by_u = np.zeros(shape=n_u, dtype=int)
for pos, (_, trials) in enumerate(df.groupby("user")):
    n_presentations = len(trials)
    n_distinct_items = len(trials.item.unique())
    n_o_by_u[pos] = n_presentations - n_distinct_items

n_o_max = n_o_by_u.max()
n_obs = n_o_by_u.sum()

print(f"n_o_max={n_o_max}")
print(f"n_o_min={n_o_by_u.min()}")
print(f"n_obs={n_obs}")

n_o_max=1404
n_o_min=1285
n_obs=70618


In [12]:
def _repetition_features(items, timestamps):
    """Describe the history of each presentation in one user's ordered log.

    Parameters
    ----------
    items : sequence of item codes, in presentation order.
    timestamps : sequence of presentation times (seconds), same length.

    Returns
    -------
    n_rep : float array; -1 for the first presentation of an item, then
        0, 1, 2, ... for its following presentations.
    delay : float array; time elapsed since the previous presentation of the
        same item (left at 0 for a first presentation).
    """
    n_rep = np.zeros(len(items))
    delay = np.zeros(len(items))
    n_seen = {}      # item -> value of n_rep for its next presentation
    last_pres = {}   # item -> timestamp of its latest presentation
    for i, (word, ts) in enumerate(zip(items, timestamps)):
        previous = n_seen.get(word, -1)
        n_rep[i] = previous
        if word in last_pres:
            delay[i] = ts - last_pres[word]
        n_seen[word] = previous + 1
        last_pres[word] = ts
    return n_rep, delay


# Flat observation arrays, one entry per presentation that is NOT the first
# presentation of its item for that user.
y = np.zeros(shape=n_obs, dtype=int)    # success (1) / failure (0)
d = np.zeros(shape=n_obs, dtype=float)  # seconds since previous presentation of the item
w = np.zeros(shape=n_obs, dtype=int)    # item code
r = np.zeros(shape=n_obs, dtype=int)    # repetition counter (0 on the second presentation)
u = np.zeros(shape=n_obs, dtype=int)    # user index, in `groupby("user")` order

idx = 0

for i_u, (user, user_df) in enumerate(df.groupby("user")):

    user_df = user_df.sort_values(by="timestamp")
    r_u, d_u = _repetition_features(user_df.item.values, user_df.timestamp.values)

    # First presentations carry r == -1 and are dropped
    to_keep = r_u >= 0
    n_ou = int(to_keep.sum())
    rows = slice(idx, idx + n_ou)

    y[rows] = user_df.success.values[to_keep]
    d[rows] = d_u[to_keep]
    w[rows] = user_df.item.values[to_keep]
    r[rows] = r_u[to_keep]
    u[rows] = i_u

    idx += n_ou

# Unfilled slots would otherwise stay in the arrays as fake observations
# (y=0, d=0, user 0, item 0).
assert idx == n_obs, f"filled {idx} observations but n_obs={n_obs}"

# Sizes of the per-user / per-item parameter tables. They must cover the
# largest index stored in `u` / `w`, since the tables are indexed directly
# with these arrays; counting distinct values would be too small whenever a
# user or an item has no kept observation.
n_u_param = int(u.max()) + 1 if n_obs else 0
n_w_param = int(w.max()) + 1 if n_obs else 0

data = {'n_u': n_u_param, 'n_w': n_w_param, 'n_obs': len(y),
        'u': u, 'w': w,
        'd': d, 'r': r,
        'y': y}

# Inference

In [53]:
def EM(data, max_iter=None):
    """Alternate between fitting per-user / per-item parameters and updating
    the population mean and standard deviations, until the latter stop moving.

    Model used by the objective: for observation i,
    ``Z = Zu[u[i]] + Zw[w[i]]``, ``a = exp(Z[0])``, ``b = expit(Z[1])`` and
    ``log P(y[i] = 1) = -a * d[i] * (1 - b) ** r[i]``.

    Parameters
    ----------
    data : dict
        Must provide ``'n_u'``, ``'n_w'`` (number of rows of the user and item
        parameter tables), and the per-observation arrays ``'u'``, ``'w'``
        (row indices into those tables), ``'d'``, ``'r'`` and ``'y'``.
    max_iter : int or None, optional
        Upper bound on the number of outer iterations. ``None`` (default)
        keeps iterating until convergence. At least one iteration always runs.

    Returns
    -------
    dict
        ``'mu'``, ``'sg_u'``, ``'sg_w'`` (arrays of length 2), ``'Zu'``
        (n_u x 2), ``'Zw'`` (n_w x 2), ``'n_iter'`` and ``'converged'``.

    One line is printed per outer iteration: index, mu, sg_u, sg_w, elapsed time.
    """
    # Import the submodules explicitly: a bare `import scipy` does not
    # guarantee that `scipy.optimize` / `scipy.stats` are loaded.
    from scipy import optimize, stats

    t0 = datetime.now()

    n_u = data['n_u']
    n_w = data['n_w']
    # Take the index arrays from `data` like every other input, rather than
    # from whatever `u` / `w` happen to exist in the notebook namespace.
    u = data['u']
    w = data['w']
    x = data['d']
    r = data['r']
    y = data['y']

    eps = np.finfo(float).eps
    n_param_u = n_u * 2
    n_param = n_param_u + n_w * 2

    def unpack(param):
        """Split the flat parameter vector into the (n_u, 2) and (n_w, 2) tables."""
        return (param[:n_param_u].reshape((n_u, 2)),
                param[n_param_u:].reshape((n_w, 2)))

    def optimize_Zu_Zw(sg_w=None, sg_u=None, mu=None):
        """Minimise the negative log-likelihood; when `mu` is given, add
        Normal(mu/2, sg) log-priors on the item and user tables."""
        use_prior = mu is not None

        def opt(param):
            Zu, Zw = unpack(param)

            Z = Zu[u] + Zw[w]

            a = np.exp(Z[:, 0])
            b = expit(Z[:, 1])
            neg_rate = -a * x * (1 - b) ** r
            # eps keeps the log finite when exp(neg_rate) rounds to 1
            ll_y = np.sum(neg_rate * y + (1 - y) * np.log(1 - np.exp(neg_rate) + eps))

            loss = -ll_y
            if use_prior:
                loss -= stats.norm.logpdf(Zw, loc=mu / 2, scale=sg_w).sum()
                loss -= stats.norm.logpdf(Zu, loc=mu / 2, scale=sg_u).sum()
            return loss

        # NOTE(review): every outer iteration restarts the optimiser from
        # zeros; warm-starting from the previous solution would likely be
        # faster but would change the iterates.
        res = optimize.minimize(opt, x0=np.zeros(n_param))
        return unpack(res.x)

    old_mu = np.zeros(2)
    old_sg_u = np.full(2, -np.inf)   # -inf: the first convergence test always fails
    old_sg_w = np.full(2, -np.inf)

    itr = 1
    converged = False

    while True:
        if itr == 1:
            # No population estimates yet: plain maximum likelihood
            Zu, Zw = optimize_Zu_Zw()
        else:
            Zu, Zw = optimize_Zu_Zw(sg_w=old_sg_w, sg_u=old_sg_u, mu=old_mu)

        mu = np.mean(Zu[u] + Zw[w], axis=0)
        sg_u = np.std(Zu, axis=0)
        sg_w = np.std(Zw, axis=0)

        print(itr, mu, sg_u, sg_w, str(datetime.now() - t0))
        converged = bool(np.allclose([mu, sg_u, sg_w], [old_mu, old_sg_u, old_sg_w]))
        if converged:
            print("Converged!")
            break
        if max_iter is not None and itr >= max_iter:
            break
        old_mu = mu
        old_sg_u = sg_u
        old_sg_w = sg_w
        itr += 1

    return {'mu': mu, 'sg_u': sg_u, 'sg_w': sg_w,
            'Zu': Zu, 'Zw': Zw,
            'n_iter': itr, 'converged': converged}

In [None]:
# Fit the model on the prepared `data` dict. Each printed line is one outer
# iteration: index, mu, sg_u, sg_w, elapsed wall-clock time.
EM(data)

1 [-16.94619839 -20.99870203] [16.0909386  49.05058915] [ 9.71073934 22.63069936] 4:49:27.182351
2 [-17.00821374 -20.95043688] [13.52368251 47.20173079] [11.8018168  29.35290579] 8:29:05.314733
3 [-19.08501404 -23.17194158] [27.40584356 42.09876167] [ 7.72446913 32.22399177] 13:59:48.819766
4 [-22.48122818 -24.57472509] [30.8497112  71.54688336] [19.56515469 48.64090824] 17:17:31.219626
5 [-20.78283871 -24.81993658] [31.54914947 88.36781003] [10.57581621 60.18088907] 22:02:16.240500
6 [-18.28303232 -20.95307164] [14.46041845 35.75133612] [ 8.50047785 15.02349218] 23:50:50.699903
7 [-16.50181968 -19.63288655] [ 9.33840308 26.56615665] [ 4.66815941 11.81378488] 1 day, 1:26:44.334660
8 [-17.23302386 -20.90760523] [26.89921168 68.49521575] [ 5.15342896 29.7392249 ] 1 day, 5:17:56.582611
9 [-26.2403717  -33.36425159] [52.8815587  85.14245768] [12.28110172 73.75505398] 1 day, 9:00:26.363676
10 [-18.35777946 -20.70083084] [14.61272705 40.3262262 ] [ 6.34044552 18.17971938] 1 day, 10:58:40.812