In [1]:
import numpy as np 
import pandas as pd 
import time
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax.scipy.special import logsumexp

import numpy as np
import optax
from IPython.display import display, Latex
from warnings import filterwarnings
filterwarnings('ignore')

In [2]:
# Problem dimensions: N households, J inside products, T time periods.
N, J, T = 1000, 4, 50

# Population mean (column vector) and diagonal covariance of the five random
# coefficients; the first four enter utility as product intercepts and the last
# one multiplies prices.
mu = np.array([-1.71, 0.44, -1.37, -0.91, -1.23])[:, np.newaxis]
sigma = np.diag(np.array([3.22, 3.24, 2.87, 4.15, 1.38]))
# Replacing sigma by a 5x5 zero matrix removes heterogeneity; this was used to
# check that the homogeneous MLE recovers the true parameters.

print(mu, '\n')
print(sigma)

# Draw one coefficient vector per household (seeded for reproducibility).
np.random.seed(123)
betas = np.random.multivariate_normal(mu.flatten(), sigma, N)
betas_np = betas[:, :-1]  # product intercepts, shape (N, 4)
etas_np = betas[:, -1]    # price coefficients, shape (N,)

[[-1.71]
 [ 0.44]
 [-1.37]
 [-0.91]
 [-1.23]] 

[[3.22 0.   0.   0.   0.  ]
 [0.   3.24 0.   0.   0.  ]
 [0.   0.   2.87 0.   0.  ]
 [0.   0.   0.   4.15 0.  ]
 [0.   0.   0.   0.   1.38]]


In [3]:
# Load the price states and the state-transition probabilities from CSV.
price_transition_states = pd.read_csv(r'price_transition_states.csv')
price_transition_matrix = pd.read_csv(r'transition_prob_matrix.csv')

# Preview the first rows of both tables.
print(price_transition_states.head(), price_transition_matrix.head(), sep='\n')

         X1        X2        X3        X4  X5  id
0  0.809830  2.730574  0.770491  2.966650   0   1
1  0.785441  2.453663  0.771422  2.964841   0   2
2  0.823926  2.238301  0.790901  2.914317   0   3
3  1.005571  1.581941  0.918171  2.084867   0   4
4  0.798525  2.693030  0.762935  2.690978   0   5
         V1        V2        V3        V4        V5        V6        V7  \
0  0.050283  0.020251  0.017505  0.003655  0.026568  0.006548  0.006352   
1  0.040564  0.023169  0.016224  0.003816  0.024936  0.007435  0.007776   
2  0.034444  0.016302  0.015208  0.004359  0.022437  0.006756  0.007631   
3  0.017591  0.008223  0.009868  0.007142  0.013897  0.006043  0.011003   
4  0.042309  0.016379  0.015619  0.003890  0.027125  0.006523  0.006811   

         V8        V9       V10  ...       V91       V92       V93       V94  \
0  0.008375  0.016022  0.013590  ...  0.017482  0.003807  0.008912  0.005710   
1  0.006751  0.014843  0.012902  ...  0.020574  0.003186  0.010063  0.005734   
2  0.0084

In [4]:
# Render the price-state table as the cell output.
price_transition_states

Unnamed: 0,X1,X2,X3,X4,X5,id
0,0.809830,2.730574,0.770491,2.966650,0,1
1,0.785441,2.453663,0.771422,2.964841,0,2
2,0.823926,2.238301,0.790901,2.914317,0,3
3,1.005571,1.581941,0.918171,2.084867,0,4
4,0.798525,2.693030,0.762935,2.690978,0,5
...,...,...,...,...,...,...
95,1.096901,2.540894,0.933728,2.026588,0,96
96,1.228067,1.985250,0.941085,2.525047,0,97
97,1.118906,3.728712,1.055625,2.475710,0,98
98,0.847625,2.102348,1.363938,1.771260,0,99


In [5]:
# Render the transition-probability table as the cell output.
price_transition_matrix

Unnamed: 0,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,...,V91,V92,V93,V94,V95,V96,V97,V98,V99,V100
0,0.050283,0.020251,0.017505,0.003655,0.026568,0.006548,0.006352,0.008375,0.016022,0.013590,...,0.017482,0.003807,0.008912,0.005710,0.016266,0.008533,0.006934,0.002275,0.001474,0.008172
1,0.040564,0.023169,0.016224,0.003816,0.024936,0.007435,0.007776,0.006751,0.014843,0.012902,...,0.020574,0.003186,0.010063,0.005734,0.013634,0.008189,0.006430,0.001822,0.001520,0.009950
2,0.034444,0.016302,0.015208,0.004359,0.022437,0.006756,0.007631,0.008431,0.019615,0.015308,...,0.016124,0.004657,0.010655,0.004589,0.014261,0.009842,0.009434,0.002486,0.001853,0.009703
3,0.017591,0.008223,0.009868,0.007142,0.013897,0.006043,0.011003,0.009424,0.022074,0.012562,...,0.010919,0.007949,0.013787,0.002806,0.011120,0.011316,0.016923,0.003837,0.003444,0.015517
4,0.042309,0.016379,0.015619,0.003890,0.027125,0.006523,0.006811,0.008304,0.017501,0.014749,...,0.016193,0.004160,0.009280,0.004994,0.017674,0.008824,0.008212,0.002280,0.001601,0.008316
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,0.022406,0.010178,0.011546,0.005559,0.015755,0.005905,0.007434,0.010891,0.027838,0.015147,...,0.010508,0.008855,0.010696,0.003277,0.013045,0.012874,0.017172,0.004015,0.003121,0.010308
96,0.014125,0.006274,0.008583,0.005706,0.010867,0.005148,0.006759,0.012477,0.030100,0.013551,...,0.007031,0.013129,0.010076,0.002364,0.012457,0.014006,0.033324,0.005777,0.003969,0.009970
97,0.013230,0.006321,0.007653,0.004281,0.010171,0.004477,0.006573,0.022304,0.023152,0.009947,...,0.007212,0.014083,0.006911,0.002808,0.011422,0.010620,0.013875,0.008872,0.003258,0.008857
98,0.016398,0.007619,0.009392,0.006894,0.012828,0.006098,0.009410,0.009549,0.022639,0.012384,...,0.010010,0.010943,0.011931,0.002618,0.011045,0.012419,0.017758,0.004460,0.005623,0.013918


In [6]:
# Plain NumPy copies of both tables, as consumed by simulate_prices.
price_transition_states_np = price_transition_states.to_numpy()
price_transition_matrix_np = price_transition_matrix.to_numpy()

In [7]:
def simulate_prices(states, transition, T):
    """
    Simulate a path of T price states from a first-order Markov chain.

    Parameters
    ----------
    states : 2-D array, one row per state. The LAST column must hold the
        1-based id of the state; it is used to pick the row of `transition`.
    transition : 2-D array; row k holds the probabilities of moving from the
        state with id k+1 to each row of `states`.
    T : int, number of periods to simulate.

    Returns
    -------
    (T, states.shape[1]) array whose row t is the full state row (id column
    included) visited in period t. The path always starts at states[0].

    Draws come from the global NumPy RNG (np.random.choice), so seed it
    beforehand for reproducibility.
    """
    states = np.asarray(states)
    transition = np.asarray(transition)
    state_indices = np.arange(states.shape[0])

    # Size the output from the data instead of assuming exactly 6 columns.
    price_simu = np.zeros((T, states.shape[1]))
    if T == 0:
        return price_simu  # nothing to simulate; avoids indexing an empty array

    price_simu[0] = states[0]  # fix the initial state

    for t in range(1, T):
        # The id column is 1-based, hence the -1 to get a row of `transition`.
        index_preceding_state = int(price_simu[t - 1, -1] - 1)
        next_state_probas = transition[index_preceding_state, :].flatten()
        index_next_state = np.random.choice(state_indices, p=next_state_probas)
        price_simu[t, :] = states[index_next_state]
    return price_simu


In [8]:

# Simulate one price path of length T from the price-state Markov chain.
price_50_by_6 = simulate_prices(price_transition_states_np, price_transition_matrix_np, T)
prices_50_by_4 = price_50_by_6[:, :-2] #drop the last TWO columns (X5 and id), keeping the four price columns X1-X4

In [9]:
## Generate utility data.
# utility_np[j, i, t] is the utility of option j for household i in period t;
# index 0 is the outside option, indices 1..J are the products.
# NOTE: the loop order fixes the order of the Gumbel draws; vectorising it
# would change the simulated data for a given seed.
utility_np = np.zeros((1+J, N, T)) # 1 for the outside option, J for the number of products
for t in range(T):
    for i in range(N):
        utility_np[0, i, t] = np.random.gumbel() #outside option: pure Gumbel noise (systematic utility 0)
        utility_np[1:, i, t] = betas_np[i, :] + etas_np[i]*prices_50_by_4[t, :] + np.random.gumbel(size=J) #intercept + price coefficient * price + Gumbel noise for the J products



In [10]:
# JAX copies of the inputs used by the likelihood functions.
prices_50_by_4_jnp = jnp.array(prices_50_by_4)  # simulated price path as a jnp array
choice_jnp = jnp.argmax(utility_np, axis=0)  # (N, T) index of the utility-maximising option; 0 = outside option

Note: I converted most of the NumPy objects to `jax.numpy` objects. JAX provides just-in-time compilation (`jit`), automatic differentiation (`grad`) and automatic vectorization (`vmap`), which speeds up the computation.

* MLE assuming homogeneous $\Theta^h$

In [11]:
@jit
def choice_probas(theta):
    """
    Multinomial-logit choice probabilities for a single parameter vector.

    theta holds the product intercepts followed by the price coefficient (last
    entry). Returns one row per period of the simulated price path
    (prices_50_by_4_jnp); column 0 is the outside option, whose systematic
    utility is normalised to 0, and the remaining columns are the products.
    """
    params = jnp.array(theta)
    intercepts, price_coef = params[:-1], params[-1]

    # Systematic utility of every product in every period, with a zero column
    # prepended for the outside option.
    inside_utility = intercepts + price_coef * prices_50_by_4_jnp
    outside_utility = jnp.zeros((T, 1))
    utilities = jnp.concatenate((outside_utility, inside_utility), axis=1)

    # Softmax computed through logsumexp for numerical stability.
    log_denominator = logsumexp(utilities, axis=1, keepdims=True)
    return jnp.exp(utilities - log_denominator)

In [12]:
@jit
def likelihood(theta):
    """
    Negative log-likelihood of the observed choices under a homogeneous logit.

    Despite the name, the value returned is MINUS the log-likelihood, so that
    it can be minimised.
    """
    period_idx = jnp.arange(T)
    # Broadcasting (T,) against the (N, T) choice array picks, for every
    # household and period, the probability of the option actually chosen.
    chosen_probas = choice_probas(theta)[period_idx, choice_jnp]
    return -jnp.sum(jnp.log(chosen_probas))

# Gradient of the negative log-likelihood with respect to theta.
grad_likelihood = jit(grad(likelihood))

In [13]:
def minimize_adam(f, grad_f, x0, norm=1e9, tol=0.1, lr=0.05, maxiter=1000, verbose=0, *args): ## generic adam optimizer
  """
  Generic Adam minimiser built on optax.

  Parameters
  ----------
  f : objective function. Not evaluated here (only grad_f is used); kept in
      the signature so existing calls keep working.
  grad_f : callable returning the gradient of the objective at `params`;
      called as grad_f(params, *args).
  x0 : starting point (scalar or array-like); cast to float32.
  norm : initial value of the stopping statistic; the default is large enough
      to enter the loop.
  tol : stop once the largest absolute gradient component is <= tol.
  lr : Adam learning rate. The lower it is, the more stable (and slow) the
      convergence.
  maxiter : maximum number of Adam steps.
  verbose : 0 = silent, 1 = progress every 100 iterations, >1 = every iteration.
  *args : extra positional arguments forwarded to grad_f.

  Returns
  -------
  The parameters after the last update, as a float32 jnp array.

  Note: the reported norm is the gradient evaluated BEFORE the update of the
  same iteration.
  """
  tic = time.time()
  solver = optax.adam(learning_rate=lr)
  params = jnp.array(x0, dtype=jnp.float32)
  opt_state = solver.init(params)
  iternum = 0
  while norm > tol and iternum < maxiter:
    iternum += 1
    # Named `gradient` so that jax's `grad` is not shadowed inside the function.
    gradient = grad_f(params, *args)
    updates, opt_state = solver.update(gradient, opt_state, params)
    params = jnp.asarray(optax.apply_updates(params, updates), dtype=jnp.float32)
    norm = jnp.max(jnp.abs(gradient))
    # A single print per iteration, also when verbose > 1 on a multiple of 100.
    if verbose > 1 or (verbose > 0 and iternum % 100 == 0):
      print(f"Iteration: {iternum}  Norm: {norm}  theta: {params}")
  tac = time.time()
  # Decide on the tolerance itself rather than on the iteration count: a run
  # that meets the tolerance exactly at iteration `maxiter` did converge, and
  # a NaN norm (which also ends the loop) did not.
  if norm <= tol:
    print(f"Convergence reached after {iternum} iterations. \nTime: {tac-tic} seconds. Norm: {norm}")
  else:
    print(f"Convergence not reached after {iternum} iterations. \nTime: {tac-tic} seconds. Norm: {norm}")

  return params

In [14]:
# Fit the homogeneous logit by Adam, starting from a vector of ones.
theta_MLE_homo = minimize_adam(likelihood, grad_likelihood, jnp.ones(5), lr=0.1, verbose=1, tol=0.01, maxiter=1500)

Iteration: 100  Norm: 151.35691833496094  theta: [-1.3924706   0.14809617 -1.1837898  -0.5084201  -0.5308946 ]
Iteration: 200  Norm: 64.03196716308594  theta: [-1.4601713  -0.01200273 -1.2231857  -0.6660787  -0.46587494]
Iteration: 300  Norm: 34.01972961425781  theta: [-1.5036387  -0.12207926 -1.2650094  -0.7935174  -0.41740423]
Iteration: 400  Norm: 15.053604125976562  theta: [-1.5295694  -0.18775776 -1.2898902  -0.86957914 -0.388516  ]
Iteration: 500  Norm: 5.780517578125  theta: [-1.5423031  -0.22003625 -1.302108   -0.90698284 -0.3743311 ]
Iteration: 600  Norm: 1.8632659912109375  theta: [-1.5475703  -0.23339154 -1.307161   -0.92246413 -0.36846307]
Iteration: 700  Norm: 0.5316925048828125  theta: [-1.5494206  -0.23808509 -1.3089368  -0.92790514 -0.366402  ]
Iteration: 800  Norm: 0.141448974609375  theta: [-1.5499736  -0.2394883  -1.3094672  -0.92953134 -0.36578575]
Iteration: 900  Norm: 0.02227783203125  theta: [-1.5501127  -0.23984279 -1.3096005  -0.9299428  -0.3656294 ]
Convergenc

In [15]:
# Homogeneous MLE: four product intercepts followed by the price coefficient.
print(theta_MLE_homo)

[-1.550132   -0.23989092 -1.3096188  -0.9299984  -0.36560854]


Formula to retrieve MLE Standard Errors

\begin{align*}
\nabla l_{it} &= \frac{\partial l_{it} (\widehat{\theta})}{\partial \widehat{\theta}} \tag{column vector} \\
SE(\widehat{\theta}) &= \sqrt{\operatorname{diag}\Bigg[\Big( \frac{1}{NT}\sum_{i=1}^N \sum_{t=1}^T \nabla l_{it} \cdot \nabla l_{it}' \Big)^{-1}\Bigg]}
\end{align*}

In [16]:
###Computation of standard errors

@jit
def likelihood_it(theta, i, t):
    """
    Log-probability, under theta, of the choice observed for household i in
    period t (one observation's contribution to the log-likelihood).
    """
    observed_choice = choice_jnp[i, t]
    return jnp.log(choice_probas(theta)[t, observed_choice])

# Gradient of a single observation's log-likelihood with respect to theta.
grad_likelihood_it = jit(grad(likelihood_it))

@jit
def outer_grad_likelihood(theta, i, t):
    """
    Outer product g g' of the gradient g of observation (i, t)'s
    log-likelihood, evaluated at theta.
    """
    score = grad_likelihood_it(theta, i, t)
    return jnp.outer(score, score)


# Vectorised outer_grad_likelihood over all (i, t) pairs: the inner vmap maps
# over the household index i and the outer vmap over the period index t, so the
# leading axes of the result are (len(t), len(i)), followed by the two axes of
# each outer-product matrix. theta is shared by every evaluation (in_axes None).
grad_likelihood_it_vec = vmap(vmap(outer_grad_likelihood, in_axes=(None, 0, None)), in_axes=(None, None, 0)) 

@jit
def compute_standard_errors(theta):
    """
    Outer-product-of-gradients standard errors evaluated at theta.

    Averages the per-observation outer products g_it g_it' over all N*T
    observations, inverts that matrix and returns the square root of its
    diagonal (one entry per parameter).

    The diagonal is extracted BEFORE taking the square root: an elementwise
    sqrt of the whole inverse produces NaN for every negative off-diagonal
    entry, which is wasted work and trips NaN checks even though those entries
    are discarded. The returned values are unchanged.

    NOTE(review): because the outer products are averaged (factor 1/(N*T))
    rather than summed, the result is not divided by sqrt(N*T). Confirm that
    this is the intended scaling for the standard error of the estimator.
    """
    outer_products = grad_likelihood_it_vec(theta, jnp.arange(N), jnp.arange(T))
    avg_outer = (1/(N*T)) * jnp.sum(outer_products, axis=(0, 1))
    return jnp.sqrt(jnp.diag(jnp.linalg.inv(avg_outer)))

In [17]:
# Standard errors at the homogeneous MLE, shown next to the point estimate.
se = compute_standard_errors(theta_MLE_homo)
# Raw f-string: '\T' is not a valid escape sequence in a normal string literal
# (Python warns about it at compile time); the rendered text is unchanged.
display(Latex(rf'$\Theta^h$: {theta_MLE_homo}'))
display(Latex(f'$se$: {se}'))

<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

* MLE assuming two classes

In [18]:
# Two-class model: the class parameter vectors are held fixed at the
# homogeneous estimate minus / plus its standard errors; only the class
# weights phi_1, phi_2 are estimated in the cells below.
theta_k1, theta_k2 = theta_MLE_homo - se, theta_MLE_homo + se
print(theta_k1)
print(theta_k2)

[ -6.965951  -10.912307   -6.3434596 -13.264233   -4.9947257]
[ 3.8656867 10.432525   3.7242217 11.404236   4.2635083]


In [19]:
@jit
def choice_probas_2classes(phi1):
    """
    Choice probabilities of the two-class mixture: class 1 (parameters
    theta_k1) has weight phi1 and class 2 (parameters theta_k2) has weight
    1 - phi1. phi1 is used as given; it is not constrained to [0, 1] here.
    """
    class1_probas = choice_probas(theta_k1)
    class2_probas = choice_probas(theta_k2)
    phi2 = 1 - phi1
    return phi1 * class1_probas + phi2 * class2_probas

In [20]:
@jit
def likelihood_2classes(phi1):
    """
    Negative log-likelihood of the observed choices under the two-class
    mixture, as a function of the class-1 weight phi1.
    """
    period_idx = jnp.arange(T)
    # Probability of the chosen option for every household and period.
    chosen_probas = choice_probas_2classes(phi1)[period_idx, choice_jnp]
    return -jnp.sum(jnp.log(chosen_probas))


# Derivative of the negative log-likelihood with respect to phi1.
grad_likelihood_2classes = jit(grad(likelihood_2classes))

In [21]:
# Estimate the class-1 weight by Adam, starting from equal weights.
weight_1 = minimize_adam(likelihood_2classes, grad_likelihood_2classes, 0.5, verbose=False)
# Raw f-strings: '\T' is not a valid escape sequence in a normal string literal
# (Python warns about it at compile time); the rendered text is unchanged.
display(Latex(rf'$\Theta_1^h$: {theta_k1}'))
display(Latex(rf'$\Theta_2^h$: {theta_k2}'))
print(f'Weights: {weight_1.item(), 1-weight_1.item()}')

Convergence reached after 165 iterations. 
Time: 2.404602527618408 seconds. Norm: 0.0703125


<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

Weights: (0.725957453250885, 0.274042546749115)


* MLE assuming that $\Theta^h$ has a normal distribution across the households 

In [22]:
# Number of simulation draws used for Monte Carlo integration.
S = 5000

# Standard-normal draws, one row per simulation draw and one column per random
# coefficient (4 intercepts + 1 price coefficient). The fixed PRNG key makes
# the draws identical on every run.
key = jax.random.PRNGKey(123)
mc_draws = jax.random.normal(key, shape=(S, 5))

In [23]:
@jit
def mixed_logit_likelihood(theta):
    """
    Simulated negative log-likelihood of the mixed logit with independent
    normal coefficients.

    theta[:5] are the means; theta[5:] are the LOG-variances, so the variances
    exp(theta[5:]) stay positive during unconstrained optimisation.

    NOTE(review): the choice probabilities are averaged over the draws period
    by period and only then indexed by the observed choices, so every (i, t)
    observation is treated separately rather than integrating over a
    household's whole choice sequence. Confirm this pooling is intended.
    """
    means, log_variances = theta[:5], theta[5:]
    std_devs = jnp.sqrt(jnp.exp(log_variances))
    coefficient_draws = means + mc_draws * std_devs  # one coefficient vector per draw

    # choice_probas per draw, then Monte Carlo average over the draw axis.
    simulated_probas = vmap(choice_probas)(coefficient_draws).mean(axis=0)

    chosen_probas = simulated_probas[jnp.arange(T), choice_jnp]
    return -jnp.sum(jnp.log(chosen_probas))

# Gradient of the simulated negative log-likelihood with respect to theta.
grad_mixed_logit_likelihood = jit(grad(mixed_logit_likelihood))

In [24]:
# Start from the homogeneous MLE for the means and from ones for the
# log-variance parameters (theta[5:] is exponentiated inside the likelihood).
x_start = jnp.concatenate((theta_MLE_homo, jnp.ones(5)))
theta_mixed_logit = minimize_adam(mixed_logit_likelihood, grad_mixed_logit_likelihood, x_start, lr=0.1, maxiter=5000, verbose=1, tol=0.2)

Iteration: 100  Norm: 14.214195251464844  theta: [-1.5548645  -0.02482525 -1.3207934  -0.99813336 -1.0442771   0.70649076
  1.551473    0.80982864  1.3152521  -0.09989887]
Iteration: 200  Norm: 4.883601665496826  theta: [-1.364749    0.0413483  -1.145643   -0.9625387  -1.1118015   0.48166847
  1.6117646   0.61018     1.3661939   0.05143078]
Iteration: 300  Norm: 2.6269466876983643  theta: [-1.2292743   0.08922753 -1.0077752  -0.93832475 -1.1660131   0.30134913
  1.66933     0.4228158   1.4257964   0.15263534]
Iteration: 400  Norm: 1.1853811740875244  theta: [-1.1455742   0.11061405 -0.9148244  -0.9353715  -1.1994507   0.1719849
  1.7142999   0.2672374   1.4741052   0.21238129]
Iteration: 500  Norm: 0.49642378091812134  theta: [-1.0949494   0.11227275 -0.8538697  -0.9485005  -1.216657    0.07611551
  1.7485735   0.13745132  1.5133996   0.24276757]
Iteration: 600  Norm: 0.41016486287117004  theta: [-1.062452    0.10262876 -0.811927   -0.9712046  -1.2241095  -0.00178099
  1.7757965   0.02

In [25]:
### Estimated means of the random coefficients (first five entries): not too bad, and the signs are correct
theta_mixed_logit[:5]

Array([-0.9368873 ,  0.01981145, -0.6546789 , -1.1986973 , -1.2291056 ],      dtype=float32)

In [26]:
### Estimate of the (diagonal) variance-covariance matrix: exp() maps the estimated log-variances back to variances.
### It seems hard to recover, which probably also causes a larger bias in the point estimate.
jnp.diag(jnp.exp(theta_mixed_logit[5:]))

Array([[0.6204143, 0.       , 0.       , 0.       , 0.       ],
       [0.       , 6.54409  , 0.       , 0.       , 0.       ],
       [0.       , 0.       , 0.5004512, 0.       , 0.       ],
       [0.       , 0.       , 0.       , 5.850934 , 0.       ],
       [0.       , 0.       , 0.       , 0.       , 1.3153167]],      dtype=float32)

* Sanity check: what if we fix $\sigma$ at its true value and estimate only $\mu$?

Conclusion: when $\sigma$ is assumed to be known, the estimate of $\mu$ is close to its true value.

In [27]:
@jit
def mixed_logit_likelihood_correct(theta):
    """
    Simulated negative log-likelihood of the mixed logit when the variances
    are fixed at the diagonal of the global `sigma` used to generate the data;
    only the means (theta) are free.
    """
    means = theta.flatten()
    true_std_devs = jnp.sqrt(jnp.diag(sigma))
    coefficient_draws = means + mc_draws * true_std_devs  # one coefficient vector per draw

    # choice_probas per draw, then Monte Carlo average over the draw axis.
    simulated_probas = vmap(choice_probas)(coefficient_draws).mean(axis=0)

    chosen_probas = simulated_probas[jnp.arange(T), choice_jnp]
    return -jnp.sum(jnp.log(chosen_probas))

# Gradient of the simulated negative log-likelihood with respect to the means.
grad_mixed_logit_likelihood_correct = jit(grad(mixed_logit_likelihood_correct))

In [28]:
# Estimate only the means, with the variances fixed inside the likelihood;
# the homogeneous MLE is used as the starting point.
theta_mixed_logit_correct = minimize_adam(mixed_logit_likelihood_correct, 
                                          grad_mixed_logit_likelihood_correct, 
                                          theta_MLE_homo, 
                                          lr=0.05, maxiter=5000, verbose=1, tol=0.1)

Iteration: 100  Norm: 54.19804000854492  theta: [-1.8604772  -0.00691507 -1.4520905  -1.3856395  -1.0146875 ]
Iteration: 200  Norm: 21.64241600036621  theta: [-1.7583201   0.26995337 -1.3512261  -1.0736028  -1.1389674 ]
Iteration: 300  Norm: 7.9752068519592285  theta: [-1.696638    0.42954588 -1.2925091  -0.88821095 -1.2112354 ]
Iteration: 400  Norm: 2.3028757572174072  theta: [-1.670866    0.49620482 -1.2679734  -0.8107995  -1.2414314 ]
Iteration: 500  Norm: 0.535210132598877  theta: [-1.6627408  0.517212  -1.2602401 -0.7864084 -1.250948 ]
Iteration: 600  Norm: 0.09628894925117493  theta: [-1.6607764   0.52229446 -1.2583694  -0.78050625 -1.2532526 ]
Convergence reached after 600 iterations. 
Time: 12.209501504898071 seconds. Norm: 0.09628894925117493


In [29]:
# Estimated means when the variances are fixed at the values used to generate the data.
theta_mixed_logit_correct

Array([-1.6607764 ,  0.52229446, -1.2583694 , -0.78050625, -1.2532526 ],      dtype=float32)