In [1]:
import numpy as np 
import pandas as pd 
import time
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax.scipy.special import logsumexp

import numpy as np
import optax
from IPython.display import display, Latex
from warnings import filterwarnings
filterwarnings('ignore')

In [2]:
np.random.seed(123)

N = 1000
J = 4
T = 150

# Generate the data
np.random.seed(123)
mu = np.array([-1.71, 0.44, -1.37, -0.91, -1.23, 1]).reshape(-1, 1)
sigma = np.diag(np.array([3.22, 3.24, 2.87, 4.15, 1.38, 1])).reshape(6, 6)

#sigma = np.diag(np.zeros(6)).reshape(6, 6) #This was run as a test for theta MLE homogeneous
# Under the "correct" distributional assumption, that is if we generate homogeneous consumers,
# our MLE hmg estimate for Theta, including gamma, is very well identified. So any problem in 
# identification encountered later is driven by consumer heterogeneity as a confounder.

print(mu, '\n')
print(sigma)


# generate the random parameters
betas = np.random.multivariate_normal(mu.flatten(), sigma, N)
betas_np = betas[:, :-2]
etas_np = betas[:, -2]
gammas_np = betas[:, -1]

[[-1.71]
 [ 0.44]
 [-1.37]
 [-0.91]
 [-1.23]
 [ 1.  ]] 

[[3.22 0.   0.   0.   0.   0.  ]
 [0.   3.24 0.   0.   0.   0.  ]
 [0.   0.   2.87 0.   0.   0.  ]
 [0.   0.   0.   4.15 0.   0.  ]
 [0.   0.   0.   0.   1.38 0.  ]
 [0.   0.   0.   0.   0.   1.  ]]


In [3]:
price_transition_states = pd.read_csv(r'price_transition_states.csv')
price_transition_matrix = pd.read_csv(r'transition_prob_matrix.csv')
price_transition_matrix_np = price_transition_matrix.to_numpy()
price_transition_states_np = price_transition_states.to_numpy()

In [4]:
def simulate_prices(states, transition, T):
    state_indices = np.arange(states.shape[0])

    price_simu = np.zeros((T, 6)) #create a matrix to store the simulated prices
    price_simu[0] = states[0] #fix the initial vector of prices
    
    for t in range(1, T):
        preceding_state = price_simu[t-1, :] #take the preceding state
        index_preceding_state = int(preceding_state[-1] - 1) #take the index of the preceding state (-1 for 0-indexing in Python)
        index_next_state = np.random.choice(state_indices, p=(transition[index_preceding_state, :].flatten())) #draw the next state
        price_simu[t, :] = states[index_next_state] #update the price vector and store it
    return price_simu

In [5]:
price_150_by_6 = simulate_prices(price_transition_states_np, price_transition_matrix_np, T)
prices_150_by_4 = price_150_by_6[:, :-2] #remove the indices column

In [6]:
## generate baseline utility data (no loyalty)
utility_np = np.zeros((T, 1+J, N)) # 1 for the outside option, J for the number of products
for t in range(1, T):
    for i in range(N):
        utility_np[t, 0, i] = np.random.gumbel() #outside option, just a random noise
        utility_np[t, 1:, i] = betas_np[i, :] + etas_np[i]*prices_150_by_4[t, :] + np.random.gumbel(size=J) #utility for the J products

#utility_np_orig = utility_np.copy()

In [7]:
### add loyalty
state_matrix = np.zeros((T, N), dtype=int) #the state at time 0 is 0
state_matrix[1, :] = np.argmax(utility_np[0, :, :], axis=0) #initialize the state simulation

for t in range(1, T-1):
    for i in range(N):
        state_it = state_matrix[t, i]
        for j in range(1, J+1): #exclude the outside option
            utility_np[t, j, i] += gammas_np[i] * (j == state_it)
        choice = np.argmax(utility_np[t, :, i])
        if choice==0:
            state_matrix[t+1, i] = state_it ### if the outside option is chosen, the state remains the same
        else:
            state_matrix[t+1, i] = choice ### if a product is chosen, the state is updated

In [8]:
#utility_orig_jnp = jnp.array(utility_np_orig[100:, :, :])  #50 x 5 x 1000
utility_jnp = jnp.array(utility_np[100:, :, :])            #50 x 5 x 1000
choice_jnp = jnp.argmax(utility_np, axis=1)[100:, :]       #50 x 1000
prices_50_by_4_jnp = jnp.array(prices_150_by_4[100:, :])   #50 x 4
state_matrix_jnp = jnp.array(state_matrix[100:, :])        #50 x 1000

### i: Homogeneous, Two classes, and Mixed Logit
* Homogeneous

In [9]:
@jit
def ccp(theta):
    """
    Compute the choice probabilities for each time period and product for a given theta, for each possible state
    There are 4 possible states (individuals are never in state 0). For a given theta, compute the choice probabilities for each state
    Should a return a (T, J, J+1) array. That is, for each period, for each possible state, the choice probas
    """
    theta_jnp = jnp.array(theta).flatten()
    betas = theta_jnp[:-2]
    eta = theta_jnp[-2]
    gamma = theta_jnp[-1]
    
    #possible states: 0, 1, 2, 3, 4 
    v_1to4_utility_state0 = (betas + eta * prices_50_by_4_jnp).reshape(50, 1, 4)
    v_1to4_utility_state1to4 = (betas + eta * prices_50_by_4_jnp).reshape(50, 1, 4) + gamma * jnp.eye(4)
    v_utility = jnp.concatenate((v_1to4_utility_state0, v_1to4_utility_state1to4), axis=1)
    v_default = jnp.zeros((50, 5, 1))
    v_utility_full = jnp.concatenate((v_default, v_utility), axis=2)

    # Compute choice probabilities 
    log_sumexps = logsumexp(v_utility_full, axis=2, keepdims=True)
    probas = jnp.exp(v_utility_full - log_sumexps) #get the choice probabilities for each time period and product

    return probas
ccp_vec = vmap(ccp)

In [10]:
@jit
def likelihood(theta): #(log)-likelihood function
    probas_theta = ccp(theta) #get the choice probabilities for the candidate theta
    log_likelihood = jnp.sum(jnp.log(probas_theta[jnp.arange(50)[:, None], state_matrix_jnp, choice_jnp])) #sum the log-probabilities of the observed choices
    return -log_likelihood

grad_likelihood = jit(grad(likelihood)) ## gradient of the likelihood function

In [11]:
def minimize_adam(f, grad_f, x0, norm=1e9, tol=0.1, lr=0.05, maxiter=1000, verbose=0, *args): ## generic adam optimizer
  """
  Generic Adam Optimizer. Specify a function f, a starting point x0, possibly a \n
  learning rate in (0, 1). The lower the learning rate, the more stable (and slow) the convergence.
  """
  tic = time.time()
  solver = optax.adam(learning_rate=lr)
  params = jnp.array(x0, dtype=jnp.float32)
  opt_state = solver.init(params)
  iternum = 0
  while norm > tol and iternum < maxiter :
    iternum += 1
    grad = grad_f(params, *args)
    updates, opt_state = solver.update(grad, opt_state, params)
    params = optax.apply_updates(params, updates)
    params = jnp.asarray(params, dtype=jnp.float32)
    norm = jnp.max(jnp.abs(grad))
    if verbose > 0:
      if iternum % 100 == 0:
        print(f"Iteration: {iternum}  Norm: {norm}  theta: {jnp.round(params, 2)}")
    if verbose > 1:
      print(f"Iteration: {iternum}  Norm: {norm}  theta: {jnp.round(params, 2)}")
  tac = time.time()
  if iternum == maxiter:
    print(f"Convergence not reached after {iternum} iterations. \nTime: {tac-tic} seconds. Norm: {norm}")
  else:
    print(f"Convergence reached after {iternum} iterations. \nTime: {tac-tic} seconds. Norm: {norm}")

  return params

In [12]:
theta_MLE_homo = minimize_adam(likelihood, grad_likelihood, jnp.zeros(6), lr=0.01, verbose=1, maxiter=5000)

Iteration: 100  Norm: 7784.0830078125  theta: [-0.84999996  0.31       -0.82       -0.25       -0.55        0.93      ]
Iteration: 200  Norm: 3878.31787109375  theta: [-1.31        0.19999999 -1.24       -0.14       -0.68        1.5999999 ]
Iteration: 300  Norm: 1851.5341796875  theta: [-1.62  0.08 -1.53 -0.16 -0.75  2.  ]
Iteration: 400  Norm: 887.0517578125  theta: [-1.8299999 -0.06      -1.74      -0.25      -0.76       2.23     ]
Iteration: 500  Norm: 418.6636657714844  theta: [-1.99 -0.21 -1.88 -0.37 -0.74  2.36]
Iteration: 600  Norm: 192.72698974609375  theta: [-2.1        -0.35999998 -1.99       -0.52       -0.7         2.4199998 ]
Iteration: 700  Norm: 137.88356018066406  theta: [-2.19 -0.51 -2.07 -0.68 -0.65  2.46]
Iteration: 800  Norm: 120.94245910644531  theta: [-2.26       -0.65       -2.1299999  -0.83       -0.59999996  2.48      ]
Iteration: 900  Norm: 100.07978820800781  theta: [-2.31 -0.77 -2.19 -0.96 -0.55  2.49]
Iteration: 1000  Norm: 79.70939636230469  theta: [-2.36 

In [13]:
print(theta_MLE_homo)

[-2.5353696  -1.2799662  -2.3902292  -1.5117126  -0.35722518  2.5106406 ]


In [15]:
###Computation of standard errors

@jit
def likelihood_it(theta, i, t):
    """
    Computes the likelihood for an individual observation
    """
    probas_theta = ccp(theta)
    likelihood_it = jnp.log(probas_theta[t, state_matrix_jnp[t, i], choice_jnp[t, i]])
    return likelihood_it

grad_likelihood_it = jit(grad(likelihood_it)) ### Takes the gradient of the individual likelihood

@jit
def outer_grad_likelihood(theta, i, t):
    """
    Takes the outer product (column vector x row vector) of the gradient of the individual likelihood
    """
    grad_it = (grad_likelihood_it(theta, i, t)).reshape(-1, 1) 
    return grad_it@grad_it.T


#computes the outer product above for each individual and time period
grad_likelihood_it_vec = vmap(vmap(outer_grad_likelihood, in_axes=(None, 0, None)), in_axes=(None, None, 0)) 

@jit
def compute_standard_errors(theta):
    sum_outers = jnp.sum(grad_likelihood_it_vec(theta, jnp.arange(N), jnp.arange(T)), axis=(0, 1))
    return jnp.diag(jnp.sqrt(jnp.linalg.inv(sum_outers)))

In [16]:
se = compute_standard_errors(theta_MLE_homo)
se

Array([0.02599621, 0.05506609, 0.02310293, 0.06025254, 0.02323857,
       0.00804735], dtype=float32)

* MLE 2 classes

In [17]:
###Two classes: instead of estimating theta, we want to estimate the weights phi_1, phi_2 of each class (?)
theta_k1 = theta_MLE_homo - 3*se
theta_k2 = theta_MLE_homo + 3*se
print(theta_k1)
print(theta_k2)

[-2.6133583 -1.4451644 -2.459538  -1.6924702 -0.4269409  2.4864986]
[-2.457381   -1.114768   -2.3209205  -1.3309549  -0.28750947  2.5347826 ]


In [18]:
@jit
def choice_probas_2classes(phi1):
    phi2 = 1 - phi1
    probas_k1 = ccp(theta_k1)
    probas_k2 = ccp(theta_k2)
    probas = phi1 * probas_k1 + phi2 * probas_k2
    return probas

In [19]:
@jit
def likelihood_2classes(phi1): #(log)-likelihood function
    probas_theta = choice_probas_2classes(phi1) #get the choice probabilities for the candidate theta
    log_likelihood = jnp.sum(jnp.log(probas_theta[jnp.arange(50)[:, None], state_matrix_jnp, choice_jnp])) #sum the log-probabilities of the observed choices
    return -log_likelihood


grad_likelihood_2classes = jit(grad(likelihood_2classes))

In [20]:
weight_1 = minimize_adam(likelihood_2classes, grad_likelihood_2classes, 0.5, verbose=False)
display(Latex(f'$\Theta_1^h$: {theta_k1}'))
display(Latex(f'$\Theta_2^h$: {theta_k2}'))
print(f'Weights: {weight_1.item(), 1-weight_1.item()}')

Convergence reached after 71 iterations. 
Time: 1.3041553497314453 seconds. Norm: 0.09765625


<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

Weights: (0.5020579099655151, 0.49794209003448486)


* MLE assuming that $\Theta^h$ has a normal distribution across the households 

First: sanity check. Let's assume we know $\sigma$: do we recover the proper mu under the correct distributional assumption ?

In [21]:
# Define the number of Monte Carlo draws
S = 1000  # Number of simulation draws

# Generate random draws for Monte Carlo integration
key = jax.random.PRNGKey(123)
mc_draws = jax.random.normal(key, (S, 6))  # 5 parameters (4 betas + 1 eta)

In [22]:
@jit
def mixed_logit_likelihood_correct(theta):
    mu = theta.flatten()  # Mean of the random parameters
    betas_eta = mu + mc_draws * jnp.sqrt(jnp.diag(sigma))
    
    probas_theta = ccp_vec(betas_eta)  # Shape: (S, T, 6)
    probas_theta_avg = jnp.mean(probas_theta, axis=0)  # Shape: (T, 6)
    log_likelihood = jnp.sum(jnp.log(probas_theta_avg[jnp.arange(50)[:, None], state_matrix_jnp, choice_jnp])) 
    return -log_likelihood

grad_mixed_logit_likelihood_correct = jit(grad(mixed_logit_likelihood_correct))

In [23]:
### This does not converge properly and as the norm reduces, the bias is still very high
theta_mixed_logit_correct = minimize_adam(mixed_logit_likelihood_correct, 
                                          grad_mixed_logit_likelihood_correct, 
                                          theta_MLE_homo, 
                                          lr=0.05, maxiter=5000, verbose=0, tol=0.1)
print(theta_mixed_logit_correct)

Convergence reached after 587 iterations. 
Time: 15.25252914428711 seconds. Norm: 0.0977182388305664
[-3.1571164 -0.8895682 -2.804269  -1.686722  -1.1715082  3.8051717]


There is bias. Maybe this is to be expected given the instructions of the exercise.

In [24]:
@jit
def mixed_logit_likelihood(theta):
    mu = theta[:6]  # Mean of the random parameters
    sigma = jnp.exp(theta[6:])
    betas_eta = mu + mc_draws * sigma
    probas_theta = ccp_vec(betas_eta)  # Shape: (S, T, 6)
    probas_theta_avg = jnp.mean(probas_theta, axis=0)  # Shape: (T, 6)
    log_likelihood = jnp.sum(jnp.log(probas_theta_avg[jnp.arange(50)[:, None], state_matrix_jnp, choice_jnp])) 
    return -log_likelihood

grad_mixed_logit_likelihood = jit(grad(mixed_logit_likelihood))

In [25]:
### This does not seem to converge well and the bias is high.
x_start = jnp.concatenate((theta_MLE_homo, jnp.zeros(6)))
theta_mixed_logit = minimize_adam(mixed_logit_likelihood, grad_mixed_logit_likelihood, x_start, lr=0.3, maxiter=1000, verbose=1, tol=0.2)

Iteration: 100  Norm: 21.240543365478516  theta: [-4.47       -0.98999995 -3.98       -1.36       -0.55        3.1
  0.96       -2.1699998   1.         -2.22       -2.8799999   1.1999999 ]
Iteration: 200  Norm: 0.8320649266242981  theta: [-4.69       -0.96999997 -4.14       -1.36       -0.55        3.1399999
  1.02       -2.6499999   1.05       -2.45       -2.98        1.3       ]
Iteration: 300  Norm: 0.579755961894989  theta: [-4.77       -0.96999997 -4.2        -1.37       -0.55        3.1599998
  1.04       -3.1399999   1.06       -2.6        -3.04        1.3399999 ]
Iteration: 400  Norm: 0.4046667218208313  theta: [-4.79       -0.96999997 -4.21       -1.37       -0.55        3.1599998
  1.05       -3.55        1.0699999  -2.72       -3.08        1.35      ]
Iteration: 500  Norm: 0.295127809047699  theta: [-4.79       -0.96999997 -4.21       -1.37       -0.55        3.1599998
  1.05       -3.8999999   1.0699999  -2.82       -3.12        1.3399999 ]
Iteration: 600  Norm: 0.224427655

In [26]:
theta_mixed_logit[:6]

Array([-4.7796974 , -0.96995264, -4.20472   , -1.366045  , -0.5476076 ,
        3.1601624 ], dtype=float32)

In [28]:
jnp.sqrt(jnp.exp(theta_mixed_logit[6:]))

Array([1.6847361 , 0.11596691, 1.7028682 , 0.22995608, 0.20615704,
       1.9556462 ], dtype=float32)

### ii: Re-initialize the initial state and re-run the state computation given choice data

In [29]:
state_matrix_np2 = np.array(jnp.vstack((jnp.zeros((1, 1000)), state_matrix_jnp[1:, :])))

In [30]:
###rerun the state based on choices. Takes 9 secs.
for i in range(N):
    for t in range(49):
        choice_it = choice_jnp[t, i]
        if choice_it != 0:
            state_matrix_np2[t+1, i] = choice_it
        else:
            state_matrix_np2[t+1, i] = state_matrix_np2[t, i]
state_matrix_jnp2 = jnp.array(state_matrix_np2).astype(int)

Again, sanity check. Let's assume sigma is known, and see if we can recover the true mu.

In [31]:
@jit
def mixed_logit_likelihood_correct2(theta):
    mu = theta.flatten()  # Mean of the random parameters
    betas_eta = mu + mc_draws * jnp.sqrt(jnp.diag(sigma))
    
    probas_theta = ccp_vec(betas_eta)  # Shape: (S, T, 6)
    probas_theta_avg = jnp.mean(probas_theta, axis=0)  # Shape: (T, 6)
    log_likelihood = jnp.sum(jnp.log(probas_theta_avg[jnp.arange(50)[:, None], state_matrix_jnp2, choice_jnp])) 
    return -log_likelihood

grad_mixed_logit_likelihood_correct2 = jit(grad(mixed_logit_likelihood_correct2))

In [32]:
### Much closer, but still not there, and one sign is off.
theta_mixed_logit_correct2 = minimize_adam(mixed_logit_likelihood_correct2, 
                                          grad_mixed_logit_likelihood_correct2, 
                                          theta_MLE_homo, 
                                          lr=0.05, maxiter=5000, verbose=0, tol=0.1)
print(theta_mixed_logit_correct2)

Convergence reached after 620 iterations. 
Time: 15.92100715637207 seconds. Norm: 0.09795951843261719
[-2.8360603  -0.38305134 -2.5122223  -1.1973017  -1.3416026   3.8803911 ]


Actual Mixed Logit estimation of $\Theta$ and variance

In [33]:
@jit
def mixed_logit_likelihood2(theta):
    mu = theta[:6]  # Mean of the random parameters
    sigma = jnp.exp(theta[6:]) 
    betas_eta = mu + mc_draws * sigma
    
    probas_theta = ccp_vec(betas_eta)  # Shape: (S, T, 6)
    probas_theta_avg = jnp.mean(probas_theta, axis=0)  # Shape: (T, 6)
    
    log_likelihood = jnp.sum(jnp.log(probas_theta_avg[jnp.arange(50)[:, None], state_matrix_jnp2, choice_jnp])) 
    return -log_likelihood

grad_mixed_logit_likelihood2 = jit(grad(mixed_logit_likelihood2))

In [34]:
### This does not converge easily. Under more cautious learning rates, the bias is very high

x_start = jnp.concatenate((theta_MLE_homo, jnp.zeros(6)))
theta_mixed_logit2 = minimize_adam(mixed_logit_likelihood2, grad_mixed_logit_likelihood2, x_start, lr=0.3, maxiter=1000, verbose=1, tol=0.2)

Iteration: 100  Norm: 17.723278045654297  theta: [-4.17       -0.47       -3.75       -0.85999995 -0.72999996  3.27
  0.95       -2.1399999   1.         -1.9        -2.6299999   1.22      ]
Iteration: 200  Norm: 0.8716120719909668  theta: [-4.14       -0.47       -3.7099998  -0.87       -0.72999996  3.25
  0.95       -2.72        0.98999995 -2.09       -2.75        1.1999999 ]
Iteration: 300  Norm: 0.5742689967155457  theta: [-4.12       -0.47       -3.6999998  -0.87       -0.71999997  3.25
  0.94       -3.23        0.98999995 -2.25       -2.85        1.1899999 ]
Iteration: 400  Norm: 0.39445576071739197  theta: [-4.11       -0.48       -3.6899998  -0.87       -0.71999997  3.24
  0.94       -3.6599998   0.97999996 -2.36       -2.9199998   1.18      ]
Iteration: 500  Norm: 0.2860749661922455  theta: [-4.1        -0.48       -3.6799998  -0.87       -0.71999997  3.24
  0.93       -4.          0.97999996 -2.46       -2.98        1.18      ]
Iteration: 600  Norm: 0.21706745028495789  theta:

Variance-Covariance Matrix:

In [35]:
jnp.sqrt(jnp.exp(theta_mixed_logit[6:]))

Array([1.6847361 , 0.11596691, 1.7028682 , 0.22995608, 0.20615704,
       1.9556462 ], dtype=float32)