In [1]:
import numpy as np 
import pandas as pd 
import time
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax.scipy.special import logsumexp

import numpy as np
import optax
from IPython.display import display, Latex
from warnings import filterwarnings
filterwarnings('ignore')

In [2]:
np.random.seed(123)

N = 1000
J = 4
T = 150

# Generate the data
np.random.seed(123)
mu = np.array([-1.71, 0.44, -1.37, -0.91, -1.23, 1]).reshape(-1, 1)
sigma = np.diag(np.array([3.22, 3.24, 2.87, 4.15, 1.38, 1])).reshape(6, 6)

print(mu, '\n')
print(sigma)


# generate the random parameters
betas = np.random.multivariate_normal(mu.flatten(), sigma, N)
betas_np = betas[:, :-2]
etas_np = betas[:, -2]
gammas_np = betas[:, -1]

[[-1.71]
 [ 0.44]
 [-1.37]
 [-0.91]
 [-1.23]
 [ 1.  ]] 

[[3.22 0.   0.   0.   0.   0.  ]
 [0.   3.24 0.   0.   0.   0.  ]
 [0.   0.   2.87 0.   0.   0.  ]
 [0.   0.   0.   4.15 0.   0.  ]
 [0.   0.   0.   0.   1.38 0.  ]
 [0.   0.   0.   0.   0.   1.  ]]


In [3]:
price_transition_states = pd.read_csv(r'price_transition_states.csv')
price_transition_matrix = pd.read_csv(r'transition_prob_matrix.csv')
price_transition_matrix_np = price_transition_matrix.to_numpy()
price_transition_states_np = price_transition_states.to_numpy()

In [4]:
def simulate_prices(states, transition, T):
    state_indices = np.arange(states.shape[0])

    price_simu = np.zeros((T, 6)) #create a matrix to store the simulated prices
    price_simu[0] = states[0] #fix the initial vector of prices
    
    for t in range(1, T):
        preceding_state = price_simu[t-1, :] #take the preceding state
        index_preceding_state = int(preceding_state[-1] - 1) #take the index of the preceding state (-1 for 0-indexing in Python)
        index_next_state = np.random.choice(state_indices, p=(transition[index_preceding_state, :].flatten())) #draw the next state
        price_simu[t, :] = states[index_next_state] #update the price vector and store it
    return price_simu

In [5]:
price_150_by_6 = simulate_prices(price_transition_states_np, price_transition_matrix_np, T)
prices_150_by_4 = price_150_by_6[:, :-2] #remove the indices column

In [6]:
## generate baseline utility data (no loyalty)
utility_np = np.zeros((T, 1+J, N)) # 1 for the outside option, J for the number of products
for t in range(1, T):
    for i in range(N):
        utility_np[t, 0, i] = np.random.gumbel() #outside option, just a random noise
        utility_np[t, 1:, i] = betas_np[i, :] + etas_np[i]*prices_150_by_4[t, :] + np.random.gumbel(size=J) #utility for the J products

#utility_np_orig = utility_np.copy()

In [7]:
### add loyalty
state_matrix = np.zeros((T, N), dtype=int) #the state at time 0 is 0
state_matrix[1, :] = np.argmax(utility_np[0, :, :], axis=0) #initialize the state simulation

for t in range(1, T-1):
    for i in range(N):
        state_it = state_matrix[t, i]
        for j in range(1, J+1): #exclude the outside option
            utility_np[t, j, i] += gammas_np[i] * (j == state_it)
        choice = np.argmax(utility_np[t, :, i])
        if choice==0:
            state_matrix[t+1, i] = state_it ### if the outside option is chosen, the state remains the same
        else:
            state_matrix[t+1, i] = choice ### if a product is chosen, the state is updated

In [8]:
#utility_orig_jnp = jnp.array(utility_np_orig[100:, :, :])  #50 x 5 x 1000
utility_jnp = jnp.array(utility_np[100:, :, :])            #50 x 5 x 1000
choice_jnp = jnp.argmax(utility_np, axis=1)[100:, :]       #50 x 1000
prices_50_by_4_jnp = jnp.array(prices_150_by_4[100:, :])   #50 x 4
state_matrix_jnp = jnp.array(state_matrix[100:, :])        #50 x 1000

In [None]:
@jit
def ccp(theta):
    """
    Compute the choice probabilities for each time period and product for a given theta, for each possible state
    There are 4 possible states (individuals are never in state 0). For a given theta, compute the choice probabilities for each state
    Should a return a (T, J, J+1) array. That is, for each period, for each possible state, the choice probas
    """
    theta_jnp = jnp.array(theta).flatten()
    betas = theta_jnp[:-2]
    eta = theta_jnp[-2]
    gamma = theta_jnp[-1]
    
    #possible states: 0, 1, 2, 3, 4 
    v_1to4_utility_state0 = (betas + eta * prices_50_by_4_jnp).reshape(50, 1, 4)
    v_1to4_utility_state1to4 = (betas + eta * prices_50_by_4_jnp).reshape(50, 1, 4) + gamma * jnp.eye(4)
    v_utility = jnp.concatenate((v_1to4_utility_state0, v_1to4_utility_state1to4), axis=1)
    v_default = jnp.zeros((50, 5, 1))
    v_utility_full = jnp.concatenate((v_default, v_utility), axis=2)

    # Compute choice probabilities 
    log_sumexps = logsumexp(v_utility_full, axis=2, keepdims=True)
    probas = jnp.exp(v_utility_full - log_sumexps) #get the choice probabilities for each time period and product

    return probas
ccp_vec = vmap(ccp)

In [10]:
@jit
def likelihood(theta): #(log)-likelihood function
    probas_theta = ccp(theta) #get the choice probabilities for the candidate theta
    log_likelihood = jnp.sum(jnp.log(probas_theta[jnp.arange(50)[:, None], state_matrix_jnp, choice_jnp])) #sum the log-probabilities of the observed choices
    return -log_likelihood

grad_likelihood = jit(grad(likelihood)) ## gradient of the likelihood function

In [12]:
def minimize_adam(f, grad_f, x0, norm=1e9, tol=0.1, lr=0.05, maxiter=1000, verbose=0, *args): ## generic adam optimizer
  """
  Generic Adam Optimizer. Specify a function f, a starting point x0, possibly a \n
  learning rate in (0, 1). The lower the learning rate, the more stable (and slow) the convergence.
  """
  tic = time.time()
  solver = optax.adam(learning_rate=lr)
  params = jnp.array(x0, dtype=jnp.float32)
  opt_state = solver.init(params)
  iternum = 0
  while norm > tol and iternum < maxiter :
    iternum += 1
    grad = grad_f(params, *args)
    updates, opt_state = solver.update(grad, opt_state, params)
    params = optax.apply_updates(params, updates)
    params = jnp.asarray(params, dtype=jnp.float32)
    norm = jnp.max(jnp.abs(grad))
    if verbose > 0:
      if iternum % 100 == 0:
        print(f"Iteration: {iternum}  Norm: {norm}  theta: {jnp.round(params, 2)}")
    if verbose > 1:
      print(f"Iteration: {iternum}  Norm: {norm}  theta: {jnp.round(params, 2)}")
  tac = time.time()
  if iternum == maxiter:
    print(f"Convergence not reached after {iternum} iterations. \nTime: {tac-tic} seconds. Norm: {norm}")
  else:
    print(f"Convergence reached after {iternum} iterations. \nTime: {tac-tic} seconds. Norm: {norm}")

  return params

In [17]:
theta_MLE_homo = minimize_adam(likelihood, grad_likelihood, jnp.zeros(6), lr=0.01, verbose=0, maxiter=5000)

Convergence reached after 2517 iterations. 
Time: 29.813474893569946 seconds. Norm: 0.09705352783203125


In [18]:
print(theta_MLE_homo)

[-2.5353696  -1.2799662  -2.3902292  -1.5117126  -0.35722518  2.5106406 ]


In [22]:
###Computation of standard errors

@jit
def likelihood_it(theta, i, t):
    """
    Computes the likelihood for an individual observation
    """
    probas_theta = ccp(theta)
    likelihood_it = jnp.log(probas_theta[t, state_matrix_jnp[t, i], choice_jnp[t, i]])
    return likelihood_it

grad_likelihood_it = jit(grad(likelihood_it)) ### Takes the gradient of the individual likelihood

@jit
def outer_grad_likelihood(theta, i, t):
    """
    Takes the outer product (column vector x row vector) of the gradient of the individual likelihood
    """
    grad_it = (grad_likelihood_it(theta, i, t)).reshape(-1, 1) 
    return grad_it@grad_it.T


#computes the outer product above for each individual and time period
grad_likelihood_it_vec = vmap(vmap(outer_grad_likelihood, in_axes=(None, 0, None)), in_axes=(None, None, 0)) 

@jit
def compute_standard_errors(theta):
    sum_outers = (1/(N*50))*(jnp.sum(grad_likelihood_it_vec(theta, jnp.arange(N), jnp.arange(T)), axis=(0, 1)))
    return jnp.diag(jnp.sqrt(jnp.linalg.inv(sum_outers)))

In [23]:
se = compute_standard_errors(theta_MLE_homo)
se

Array([ 5.812934 , 12.313161 ,  5.165976 , 13.472887 ,  5.1963058,
        1.7994431], dtype=float32)

### MLE 2 classes

In [28]:
###Two classes: instead of estimating theta, we want to estimate the weights phi_1, phi_2 of each class (?)
theta_k1 = theta_MLE_homo - se
theta_k2 = theta_MLE_homo + se
print(theta_k1)
print(theta_k2)

[ -8.348304  -13.593127   -7.5562053 -14.9846     -5.5535307   0.7111975]
[ 3.2775643 11.033195   2.7757468 11.961174   4.839081   4.310084 ]


In [29]:
@jit
def choice_probas_2classes(phi1):
    phi2 = 1 - phi1
    probas_k1 = ccp(theta_k1)
    probas_k2 = ccp(theta_k2)
    probas = phi1 * probas_k1 + phi2 * probas_k2
    return probas

In [34]:
@jit
def likelihood_2classes(phi1): #(log)-likelihood function
    probas_theta = choice_probas_2classes(phi1) #get the choice probabilities for the candidate theta
    log_likelihood = jnp.sum(jnp.log(probas_theta[jnp.arange(50)[:, None], state_matrix_jnp, choice_jnp])) #sum the log-probabilities of the observed choices
    return -log_likelihood


grad_likelihood_2classes = jit(grad(likelihood_2classes))

In [35]:
weight_1 = minimize_adam(likelihood_2classes, grad_likelihood_2classes, 0.5, verbose=False)
display(Latex(f'$\Theta_1^h$: {theta_k1}'))
display(Latex(f'$\Theta_2^h$: {theta_k2}'))
print(f'Weights: {weight_1.item(), 1-weight_1.item()}')

Convergence reached after 204 iterations. 
Time: 2.9518580436706543 seconds. Norm: 0.08203125


<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

Weights: (0.6556161046028137, 0.3443838953971863)


#### MLE assuming that $\Theta^h$ has a normal distribution across the households 

In [139]:
# Define the number of Monte Carlo draws
S = 1000  # Number of simulation draws

# Generate random draws for Monte Carlo integration
key = jax.random.PRNGKey(123)
mc_draws = jax.random.normal(key, (S, 6))  # 5 parameters (4 betas + 1 eta)

In [140]:
@jit
def mixed_logit_likelihood(theta):
    mu = theta[:6]  # Mean of the random parameters
    sigma = jnp.diag(jnp.exp(theta[6:]))
    betas_eta = mu + jnp.dot(mc_draws, sigma.T)
    
    probas_theta = ccp_vec(betas_eta)  # Shape: (S, T, 6)
    probas_theta_avg = jnp.mean(probas_theta, axis=0)  # Shape: (T, 6)
    
    log_likelihood = jnp.sum(jnp.log(probas_theta_avg[jnp.arange(50)[:, None], state_matrix_jnp, choice_jnp])) 
    return -log_likelihood

grad_mixed_logit_likelihood = jit(grad(mixed_logit_likelihood))

In [141]:
### This does not converge properly and as the norm reduces, the bias is still very high
x_start = jnp.concatenate((theta_MLE_homo, 0.1*jnp.ones(6)))
theta_mixed_logit = minimize_adam(mixed_logit_likelihood, grad_mixed_logit_likelihood, x_start, lr=0.005, maxiter=5000, verbose=1, tol=0.1)

Iteration: 100  Norm: 1204.3084716796875  theta: [-2.95       -1.18       -2.62       -1.74       -0.61        2.9199998
  0.39       -0.24        0.39999998 -0.19999999 -0.31        0.06      ]
Iteration: 200  Norm: 657.1005859375  theta: [-3.51       -1.13       -3.11       -1.63       -0.64        3.11
  0.65999997 -0.38        0.7        -0.31       -0.59        0.69      ]
Iteration: 300  Norm: 364.98651123046875  theta: [-3.9299998  -1.06       -3.46       -1.54       -0.66999996  3.1999998
  0.82       -0.44        0.84999996 -0.41       -0.77        0.94      ]
Iteration: 400  Norm: 232.00759887695312  theta: [-4.19       -1.02       -3.6999998  -1.4599999  -0.68        3.26
  0.90999997 -0.47        0.94       -0.48       -0.9         1.04      ]
Iteration: 500  Norm: 158.82913208007812  theta: [-4.38       -0.98999995 -3.87       -1.41       -0.68        3.31
  0.96999997 -0.45999998  0.98999995 -0.53999996 -1.          1.13      ]
Iteration: 600  Norm: 115.94561767578125  th

In [142]:
theta_mixed_logit[:6] - mu.flatten() #quite strong bias

Array([-6.116932  , -4.1460795 , -5.8156753 , -2.709583  ,  0.42825902,
        5.4318976 ], dtype=float32)

In [143]:
state_matrix_np2 = np.array(jnp.vstack((jnp.zeros((1, 1000)), state_matrix_jnp[1:, :])))

In [144]:
###rerun the state based on choices. Takes 9 secs.
for i in range(N):
    for t in range(49):
        choice_it = choice_jnp[t, i]
        if choice_it != 0:
            state_matrix_np2[t+1, i] = choice_it
        else:
            state_matrix_np2[t+1, i] = state_matrix_np2[t, i]
state_matrix_jnp2 = jnp.array(state_matrix_np2).astype(int)

In [145]:
@jit
def mixed_logit_likelihood2(theta):
    mu = theta[:6]  # Mean of the random parameters
    sigma = jnp.diag(jnp.exp(theta[6:]))
    betas_eta = mu + jnp.dot(mc_draws, sigma.T)
    
    probas_theta = ccp_vec(betas_eta)  # Shape: (S, T, 6)
    probas_theta_avg = jnp.mean(probas_theta, axis=0)  # Shape: (T, 6)
    
    log_likelihood = jnp.sum(jnp.log(probas_theta_avg[jnp.arange(50)[:, None], state_matrix_jnp2, choice_jnp])) 
    return -log_likelihood

grad_mixed_logit_likelihood2 = jit(grad(mixed_logit_likelihood2))

In [147]:
### This does not converge properly and as the norm reduces, the bias is still very high
x_start = jnp.concatenate((theta_MLE_homo, 0.1*jnp.ones(6)))
theta_mixed_logit2 = minimize_adam(mixed_logit_likelihood2, grad_mixed_logit_likelihood2, x_start, lr=0.5, maxiter=10_000, verbose=1, tol=0.2)

Iteration: 100  Norm: 87.0124282836914  theta: [-3.8899999 -0.53      -3.54      -0.9       -0.7        3.1799998
  0.87      -2.32       0.93      -2.3999999 -2.75       1.05     ]
Iteration: 200  Norm: 0.744306206703186  theta: [-4.1        -0.48       -3.6799998  -0.87       -0.71999997  3.24
  0.93       -2.8899999   0.97999996 -2.51       -2.8899999   1.18      ]
Iteration: 300  Norm: 0.4810718894004822  theta: [-4.1        -0.48       -3.6799998  -0.87       -0.71999997  3.24
  0.93       -3.4199998   0.97999996 -2.56       -2.96        1.18      ]
Iteration: 400  Norm: 0.32755860686302185  theta: [-4.0899997  -0.48       -3.6699998  -0.87       -0.71999997  3.23
  0.93       -3.85        0.97999996 -2.61       -3.01        1.17      ]
Iteration: 500  Norm: 0.23669716715812683  theta: [-4.0899997  -0.48       -3.6699998  -0.87       -0.71999997  3.23
  0.93       -4.2         0.97999996 -2.6399999  -3.04        1.17      ]
Iteration: 600  Norm: 57.805503845214844  theta: [-4.13  

In [151]:
theta_mixed_logit[:6] - mu.flatten()

Array([-6.116932  , -4.1460795 , -5.8156753 , -2.709583  ,  0.42825902,
        5.4318976 ], dtype=float32)

In [152]:
theta_mixed_logit2[:6] - mu.flatten()

Array([-2.3790617 , -0.92039466, -2.3002887 ,  0.04121357,  0.5106937 ,
        2.2324889 ], dtype=float32)

array([-1.71,  0.44, -1.37, -0.91, -1.23,  1.  ])