In [1]:
import pandas as pd
import numpy as np
import pymc as pm
from pymc.model import Model



In [18]:
def process_mrt_data(df):
  """Turn raw MRT decision-time rows into the table used for model fitting.

  Processing steps:
    1. Keep only rows with ``viability == 1``.
    2. Add ``state_day_type``: 1 when ``decision_time`` falls on a Saturday or
       Sunday (``weekday() >= 5``), otherwise 0.
    3. Add ``state_day_in_study``: the 1-based number of calendar days between
       ``user_start_day`` and ``decision_time``, rescaled as
       ``(day - 35.5) / 34.5`` (so day 1 maps to -1 and day 70 maps to +1).
    4. Keep only the modelling columns, in a fixed order.

  Args:
    df: DataFrame containing at least ``viability``, ``decision_time``,
      ``user_start_day``, ``user_id``, ``action``, ``quality``, ``state_tod``,
      ``state_b_bar``, ``state_a_bar``, ``state_app_engage`` and ``state_bias``.
      Any other columns (e.g. ``state_modif``) are discarded.

  Returns:
    A new DataFrame with the modelling columns; the input frame is not modified.
  """
  # Rescaling constants for the day-in-study feature
  # (presumably chosen for a 70-day study window -- TODO confirm).
  day_center = 35.5
  day_half_range = 34.5

  # Work on an explicit copy so the column assignments below never write into
  # a view of the caller's frame (avoids SettingWithCopy issues).
  df = df.query('viability == 1').copy()

  decision_ts = pd.to_datetime(df['decision_time'])

  # Truncate BOTH timestamps to calendar dates before subtracting. Previously
  # only the decision time was truncated, so a time-of-day component in
  # `user_start_day` made `.dt.days` floor to one day too few.
  decision_date = pd.to_datetime(decision_ts.dt.date)
  start_date = pd.to_datetime(pd.to_datetime(df['user_start_day']).dt.date)

  # weekday(): Monday == 0 ... Sunday == 6, so >= 5 means weekend.
  df['state_day_type'] = (decision_ts.dt.weekday >= 5).astype('int64')

  day_in_study = (decision_date - start_date).dt.days + 1
  df['state_day_in_study'] = (day_in_study - day_center) / day_half_range

  # Selecting exactly these columns also discards every other column
  # (including `state_modif`), so no separate drop is needed.
  desired_order = ['user_id', 'action', 'quality', 'state_tod',
                 'state_b_bar', 'state_a_bar', 'state_app_engage', 'state_day_type',
                 'state_bias', 'state_day_in_study']
  return df[desired_order]

def get_user_data(df, user_id):
  """Return the rows of ``df`` whose ``user_id`` column equals ``user_id``.

  The original row index is preserved; an unknown user yields an empty frame.
  """
  belongs_to_user = df['user_id'] == user_id
  return df[belongs_to_user]

def get_batch_data(df, user_id):
  """Extract one user's state matrix, outcomes and actions as NumPy arrays.

  Args:
    df: processed MRT DataFrame with ``user_id``, ``quality``, ``action`` and
      ``state_``-prefixed feature columns.
    user_id: value of ``user_id`` selecting the rows to extract.

  Returns:
    Tuple ``(states, outcomes, actions)``: ``states`` holds every column whose
    name starts with ``state_`` (in frame order, one row per decision time),
    ``outcomes`` the ``quality`` column and ``actions`` the ``action`` column.
  """
  user_df = get_user_data(df, user_id)
  # DataFrame.filter(regex=...) matches with re.search, so the former pattern
  # 'state_*' ("state" followed by zero or more underscores) selected any
  # column merely containing "state". Anchor it to the intended prefix.
  states_df = user_df.filter(regex=r'^state_')
  outcomes = user_df['quality']
  actions = user_df['action']

  return np.array(states_df), np.array(outcomes), np.array(actions)

In [None]:
# Read the raw MRT export and reduce it to the modelling table in one step,
# then collect the distinct user ids in order of first appearance.
MRT_DATA = process_mrt_data(pd.read_csv('../../data/oralytics_mrt_data.csv'))
MRT_USERS = MRT_DATA['user_id'].unique().tolist()

## Fitting Models
---

### Helpers

In [38]:
def sigmoid(x):
  """Logistic function ``1 / (1 + exp(-x))``, applied elementwise."""
  denominator = 1 + np.exp(-x)
  return 1 / denominator

def build_zip_model(X, A, Y):
  """Build a PyMC zero-inflated Poisson model of outcomes given states and actions.

  Four coefficient vectors of length ``d = X.shape[1]`` each get an
  independent standard-normal prior (zero mean, identity covariance):

    * ``w_b``, ``delta_b``: baseline and action-effect weights of the
      Bernoulli (zero-inflation) component.
    * ``w_p``, ``delta_p``: baseline and action-effect weights of the
      Poisson component.

  The likelihood is ``ZeroInflatedPoisson`` with
  ``psi = 1 - sigmoid(X @ w_b + A * (X @ delta_b))`` and
  ``mu = exp(X @ w_p + A * (X @ delta_p))``.

  Args:
    X: 2-D state matrix, one row per observation and ``d`` columns.
    A: per-observation actions, multiplied elementwise with ``X @ delta``
      (presumably a 0/1 vector with one entry per row of X -- TODO confirm).
    Y: observed outcomes passed to the likelihood as ``observed``.

  Returns:
    The constructed ``pm.Model`` (no fitting is performed here).
  """
  d = X.shape[1]
  # A single model context is enough; the previous extra `pm.Model()` instance
  # was created and immediately discarded.
  with pm.Model() as model:
    w_b = pm.MvNormal('w_b', mu=np.zeros(d, ), cov=np.eye(d), shape=d)
    delta_b = pm.MvNormal('delta_b', mu=np.zeros(d,), cov=np.eye(d), shape=(d,))
    w_p = pm.MvNormal('w_p', mu=np.zeros(d, ), cov=np.eye(d), shape=d)
    delta_p = pm.MvNormal('delta_p', mu=np.zeros(d,), cov=np.eye(d), shape=(d,))
    bern_term = X @ w_b + A * (X @ delta_b)
    poisson_term = X @ w_p + A * (X @ delta_p)
    # NOTE(review): in PyMC, `psi` is documented as the expected proportion of
    # Poisson draws, so sigmoid(bern_term) acts as the extra-zero probability
    # -- confirm against the installed PyMC version.
    pm.ZeroInflatedPoisson("likelihood", psi=1 - sigmoid(bern_term), mu=np.exp(poisson_term), observed=Y)

  return model

def run_zip_map_for_users(users_states, users_actions, users_rewards, num_restarts):
  """Fit the zero-inflated Poisson model per user by MAP with random restarts.

  For every user, ``pm.find_MAP`` is run ``num_restarts`` times, each from a
  standard-normal starting point drawn after ``np.random.seed(restart_index)``.
  The restart with the highest model log-probability is kept.

  Args:
    users_states: dict mapping user id -> 2-D state matrix.
    users_actions: dict mapping user id -> actions for that user.
    users_rewards: dict mapping user id -> observed outcomes for that user.
    num_restarts: number of random initialisations per user (must be >= 1).

  Returns:
    Dict mapping user id -> 1-D array of length ``4 * d`` holding
    ``w_b``, ``delta_b``, ``w_p``, ``delta_p`` concatenated in that order,
    where ``d`` is the number of state columns.

  Raises:
    ValueError: if ``num_restarts`` is smaller than 1.

  Side effects:
    Prints a progress line per user and reseeds NumPy's global random state.
  """
  if num_restarts < 1:
    raise ValueError("num_restarts must be at least 1")

  model_params = {}

  for user_id, user_states in users_states.items():
    print("FOR USER: ", user_id)
    d = user_states.shape[1]
    user_actions = users_actions[user_id]
    user_rewards = users_rewards[user_id]

    # The model depends only on this user's data, so build it and compile its
    # log-probability function once rather than once per restart.
    model = build_zip_model(user_states, user_actions, user_rewards)
    logp_fn = model.compile_logp()

    logp_vals = np.empty(shape=(num_restarts,))
    param_vals = np.empty(shape=(num_restarts, 4 * d))
    for seed in range(num_restarts):
      # Seeding with the restart index keeps every starting point reproducible.
      # The draw order (w_b, delta_b, w_p, delta_p) must not change, or the
      # starting points change with it.
      np.random.seed(seed)
      init_params = {'w_b': np.random.randn(d), 'delta_b': np.random.randn(d), 'w_p':  np.random.randn(d), 'delta_p': np.random.randn(d)}
      with model:
        map_estimate = pm.find_MAP(start=init_params)

      logp_vals[seed] = logp_fn(map_estimate)
      param_vals[seed] = np.concatenate(
          (map_estimate['w_b'], map_estimate['delta_b'], map_estimate['w_p'], map_estimate['delta_p']),
          axis=None)

    # np.argmax would report a NaN log-probability as the maximum and thereby
    # select a failed restart; ignore NaNs unless every restart produced one
    # (in which case fall back to the first restart, as before).
    if np.isnan(logp_vals).all():
      best_restart = 0
    else:
      best_restart = np.nanargmax(logp_vals)
    model_params[user_id] = param_vals[best_restart]

  return model_params

### Execution

In [31]:
# Per-user lookup tables (user id -> NumPy array) consumed by the MAP fitting step.
users_states, users_rewards, users_actions = {}, {}, {}
for user_id in MRT_USERS:
    # get_batch_data returns (states, outcomes, actions) for this user.
    user_batch = get_batch_data(MRT_DATA, user_id)
    users_states[user_id] = user_batch[0]
    users_rewards[user_id] = user_batch[1]
    users_actions[user_id] = user_batch[2]

In [None]:
# Fit every user's ZIP model by MAP, keeping the best of 5 random restarts.
zip_model_params = run_zip_map_for_users(
    users_states=users_states,
    users_actions=users_actions,
    users_rewards=users_rewards,
    num_restarts=5,
)

## Saving Parameter Values
---

In [40]:
def create_zip_df_from_params(model_columns, zip_model_params):
    """Tabulate per-user parameter vectors as a DataFrame, one row per user.

    Args:
        model_columns: output column names. The first entry is the user-id
            column (rows are keyed with 'User'); each following name labels
            the parameter at the matching position of a user's vector.
        zip_model_params: dict mapping user id -> indexable parameter vector
            with at least ``len(model_columns) - 1`` entries.

    Returns:
        DataFrame whose columns are exactly ``model_columns``.
    """
    param_names = model_columns[1:]
    records = [
        {'User': user, **{name: values[pos] for pos, name in enumerate(param_names)}}
        for user, values in zip_model_params.items()
    ]
    return pd.DataFrame(records, columns=model_columns)

# Column layout of the saved parameter table: 'User' followed by one column per
# coefficient, named '<feature>.<Base|Adv>.<Bern|Poisson>'. Blocks appear in
# the order Base.Bern, Adv.Bern, Base.Poisson, Adv.Poisson, with the seven
# features repeated in the same order inside every block.
non_stat_zip_model_columns = ['User'] + [
    f'{feature}.{effect}.{component}'
    for component in ('Bern', 'Poisson')
    for effect in ('Base', 'Adv')
    for feature in (
        'state_tod',
        'state_b_bar.norm',
        'state_a_bar.norm',
        'state_app_engage',
        'state_day_type',
        'state_bias',
        'state_day_in_study',
    )
]

In [41]:
# Arrange the fitted per-user parameter vectors into a labelled table.
non_stat_zip_df = create_zip_df_from_params(
    model_columns=non_stat_zip_model_columns,
    zip_model_params=zip_model_params,
)

In [42]:
# Show the fitted-parameter table (one row per user) as the cell output.
# NOTE(review): the 'User' column holds user identifiers; clear this output
# before sharing the notebook if those identifiers are sensitive.
non_stat_zip_df

Unnamed: 0,User,state_tod.Base.Bern,state_b_bar.norm.Base.Bern,state_a_bar.norm.Base.Bern,state_app_engage.Base.Bern,state_day_type.Base.Bern,state_bias.Base.Bern,state_day_in_study.Base.Bern,state_tod.Adv.Bern,state_b_bar.norm.Adv.Bern,...,state_day_type.Base.Poisson,state_bias.Base.Poisson,state_day_in_study.Base.Poisson,state_tod.Adv.Poisson,state_b_bar.norm.Adv.Poisson,state_a_bar.norm.Adv.Poisson,state_app_engage.Adv.Poisson,state_day_type.Adv.Poisson,state_bias.Adv.Poisson,state_day_in_study.Adv.Poisson
0,robas+119@developers.pg.com,0.704819,-0.219070,-0.551283,0.296057,0.095310,-0.388889,0.528720,-0.552845,-0.452327,...,0.122480,4.750021,-0.190037,-0.051027,0.173068,-0.011210,0.237931,-0.103514,0.046726,-0.008124
1,robas+126@developers.pg.com,1.887637,0.148931,0.227601,-0.227763,-0.251803,-0.873269,-0.034331,-0.067863,-0.492663,...,-0.310084,3.509668,-0.602872,-0.058350,1.980269,0.187624,0.336037,-0.353362,0.965874,0.696232
2,digitaldentalcoach+214@gmail.com,0.030180,0.452681,0.583783,0.229515,0.227598,-0.915784,-0.081071,-0.016913,0.005055,...,-0.077419,3.856038,0.040141,-0.070137,-0.005468,-0.020570,0.061503,-0.096538,0.186943,-0.018372
3,robas+135@developers.pg.com,0.509380,0.806978,0.080763,-0.036188,-0.266871,-1.128456,0.013485,-0.407499,0.164025,...,-0.062033,4.886023,-0.212265,-0.014294,-0.223266,0.071092,-0.054011,-0.026570,0.018324,0.078535
4,robas+199@developers.pg.com,2.524064,-1.479089,0.571467,0.051567,-0.246646,-1.155922,1.673725,0.860665,-0.446714,...,0.134673,4.558882,-0.325112,-0.087705,-0.876020,0.354455,-0.008029,-0.186698,-0.127382,0.385549
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
66,digitaldentalcoach+249@gmail.com,0.780269,0.532733,0.708382,-0.324244,0.266438,-0.533279,-0.357277,0.248036,0.120544,...,-0.007471,3.844040,0.024753,0.109853,-0.056942,0.142006,-0.078811,0.033063,0.105845,0.058399
67,digitaldentalcoach+271@gmail.com,0.821079,-0.073651,-0.056678,-0.357528,0.313157,-1.112828,-0.278666,1.089296,0.554711,...,-0.025668,4.906387,0.071867,-0.016929,-0.009032,0.297080,-0.155015,-0.074284,0.268140,0.064047
68,digitaldentalcoach+236@gmail.com,1.735764,-0.897853,0.534857,-0.025728,0.233794,-1.348579,-0.436784,0.855996,0.122716,...,-0.099628,4.887838,-0.007512,0.163930,0.212617,0.248223,0.014151,0.195384,-0.041612,-0.105540
69,robas+118@developers.pg.com,0.689961,-1.203569,-0.901893,0.460361,0.043133,-0.639418,-0.422947,0.364517,0.313125,...,0.000180,4.948944,-0.291100,-0.171781,-0.374005,0.463592,-0.667070,0.016242,0.118200,-0.107490


## Saving to CSV
---

In [176]:
# Write the per-user parameter table to disk.
# NOTE(review): to_csv also writes the DataFrame index as an unnamed first
# column by default; pass index=False only if downstream readers of this file
# do not rely on that column -- TODO confirm before changing.
non_stat_zip_df.to_csv('../../sim_env_data/v4_non_stat_zip_model_params.csv')