In [163]:
import numpy as np 
import pandas as pd 
import time
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax.scipy.special import logsumexp

import numpy as np
import optax
from IPython.display import display, Latex
from warnings import filterwarnings
filterwarnings('ignore')

In [2]:
np.random.seed(123)

N = 1000
J = 4
T = 50

# Generate the data
np.random.seed(123)
mu = np.array([-1.71, 0.44, -1.37, -0.91, -1.23]).reshape(-1, 1)
sigma = np.diag(np.array([3.22, 3.24, 2.87, 4.15, 1.38])).reshape(5, 5)

print(mu, '\n')
print(sigma)


# generate the random parameters
betas = np.random.multivariate_normal(mu.flatten(), sigma, N)
betas_np = betas[:, :-1]
etas_np = betas[:, -1]

[[-1.71]
 [ 0.44]
 [-1.37]
 [-0.91]
 [-1.23]] 

[[3.22 0.   0.   0.   0.  ]
 [0.   3.24 0.   0.   0.  ]
 [0.   0.   2.87 0.   0.  ]
 [0.   0.   0.   4.15 0.  ]
 [0.   0.   0.   0.   1.38]]


In [3]:
###Extract data

price_transition_states = pd.read_csv(r'price_transition_states.csv')
price_transition_matrix = pd.read_csv(r'transition_prob_matrix.csv')

print(price_transition_states.head())
print(price_transition_matrix.head())

         X1        X2        X3        X4  X5  id
0  0.809830  2.730574  0.770491  2.966650   0   1
1  0.785441  2.453663  0.771422  2.964841   0   2
2  0.823926  2.238301  0.790901  2.914317   0   3
3  1.005571  1.581941  0.918171  2.084867   0   4
4  0.798525  2.693030  0.762935  2.690978   0   5
         V1        V2        V3        V4        V5        V6        V7  \
0  0.050283  0.020251  0.017505  0.003655  0.026568  0.006548  0.006352   
1  0.040564  0.023169  0.016224  0.003816  0.024936  0.007435  0.007776   
2  0.034444  0.016302  0.015208  0.004359  0.022437  0.006756  0.007631   
3  0.017591  0.008223  0.009868  0.007142  0.013897  0.006043  0.011003   
4  0.042309  0.016379  0.015619  0.003890  0.027125  0.006523  0.006811   

         V8        V9       V10  ...       V91       V92       V93       V94  \
0  0.008375  0.016022  0.013590  ...  0.017482  0.003807  0.008912  0.005710   
1  0.006751  0.014843  0.012902  ...  0.020574  0.003186  0.010063  0.005734   
2  0.0084

In [4]:
price_transition_states

Unnamed: 0,X1,X2,X3,X4,X5,id
0,0.809830,2.730574,0.770491,2.966650,0,1
1,0.785441,2.453663,0.771422,2.964841,0,2
2,0.823926,2.238301,0.790901,2.914317,0,3
3,1.005571,1.581941,0.918171,2.084867,0,4
4,0.798525,2.693030,0.762935,2.690978,0,5
...,...,...,...,...,...,...
95,1.096901,2.540894,0.933728,2.026588,0,96
96,1.228067,1.985250,0.941085,2.525047,0,97
97,1.118906,3.728712,1.055625,2.475710,0,98
98,0.847625,2.102348,1.363938,1.771260,0,99


In [5]:
price_transition_matrix

Unnamed: 0,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,...,V91,V92,V93,V94,V95,V96,V97,V98,V99,V100
0,0.050283,0.020251,0.017505,0.003655,0.026568,0.006548,0.006352,0.008375,0.016022,0.013590,...,0.017482,0.003807,0.008912,0.005710,0.016266,0.008533,0.006934,0.002275,0.001474,0.008172
1,0.040564,0.023169,0.016224,0.003816,0.024936,0.007435,0.007776,0.006751,0.014843,0.012902,...,0.020574,0.003186,0.010063,0.005734,0.013634,0.008189,0.006430,0.001822,0.001520,0.009950
2,0.034444,0.016302,0.015208,0.004359,0.022437,0.006756,0.007631,0.008431,0.019615,0.015308,...,0.016124,0.004657,0.010655,0.004589,0.014261,0.009842,0.009434,0.002486,0.001853,0.009703
3,0.017591,0.008223,0.009868,0.007142,0.013897,0.006043,0.011003,0.009424,0.022074,0.012562,...,0.010919,0.007949,0.013787,0.002806,0.011120,0.011316,0.016923,0.003837,0.003444,0.015517
4,0.042309,0.016379,0.015619,0.003890,0.027125,0.006523,0.006811,0.008304,0.017501,0.014749,...,0.016193,0.004160,0.009280,0.004994,0.017674,0.008824,0.008212,0.002280,0.001601,0.008316
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,0.022406,0.010178,0.011546,0.005559,0.015755,0.005905,0.007434,0.010891,0.027838,0.015147,...,0.010508,0.008855,0.010696,0.003277,0.013045,0.012874,0.017172,0.004015,0.003121,0.010308
96,0.014125,0.006274,0.008583,0.005706,0.010867,0.005148,0.006759,0.012477,0.030100,0.013551,...,0.007031,0.013129,0.010076,0.002364,0.012457,0.014006,0.033324,0.005777,0.003969,0.009970
97,0.013230,0.006321,0.007653,0.004281,0.010171,0.004477,0.006573,0.022304,0.023152,0.009947,...,0.007212,0.014083,0.006911,0.002808,0.011422,0.010620,0.013875,0.008872,0.003258,0.008857
98,0.016398,0.007619,0.009392,0.006894,0.012828,0.006098,0.009410,0.009549,0.022639,0.012384,...,0.010010,0.010943,0.011931,0.002618,0.011045,0.012419,0.017758,0.004460,0.005623,0.013918


In [6]:
price_transition_matrix_np = price_transition_matrix.to_numpy()
price_transition_states_np = price_transition_states.to_numpy()

In [7]:
def simulate_prices(states, transition, T):
    state_indices = np.arange(states.shape[0])

    price_simu = np.zeros((T, 6)) #create a matrix to store the simulated prices
    price_simu[0] = states[0] #fix the initial vector of prices
    
    for t in range(1, T):
        preceding_state = price_simu[t-1, :] #take the preceding state
        index_preceding_state = int(preceding_state[-1] - 1) #take the index of the preceding state (-1 for 0-indexing in Python)
        index_next_state = np.random.choice(state_indices, p=(transition[index_preceding_state, :].flatten())) #draw the next state
        price_simu[t, :] = states[index_next_state] #update the price vector and store it
    return price_simu


In [8]:

price_50_by_6 = simulate_prices(price_transition_states_np, price_transition_matrix_np, T)
prices_50_by_4 = price_50_by_6[:, :-2] #remove the indices column

In [9]:
## generate Utility data
utility_np = np.zeros((1+J, N, T)) # 1 for the outside option, J for the number of products
for t in range(T):
    for i in range(N):
        utility_np[0, i, t] = np.random.gumbel() #outside option, just a random noise
        utility_np[1:, i, t] = betas_np[i, :] + etas_np[i]*prices_50_by_4[t, :] + np.random.gumbel(size=J) #utility for the J products



In [10]:
choice_jnp = jnp.argmax(utility_np, axis=0) #argmax to get the choice number
prices_50_by_4_jnp = jnp.array(prices_50_by_4) #convert the prices to jnp array

Note: I converted most of the numpy objects to jax.numpy objects. JAX is a library that allows us to do parallel computing, which massively speeds up the computation. In addition, it allows us to use relatively easily the optimizer of our choice (I went with Adam), and do automatic differentiation (any function f(x) that is defined, I can just do f_prime = grad(f), and now I have a function that gives me the gradient of f). The only requirement is to write vectorized code (avoid for loops as much as you can). If you don't know how, write the function with for loops, and then ask chatgpt how to make it jit-compatible.

In [96]:
@jit
def choice_probas(theta):
    theta_jnp = jnp.array(theta)
    betas = theta_jnp[:-1]
    eta = theta_jnp[-1]
    v_1to4_utility = betas + eta * prices_50_by_4_jnp #for a candidate theta, compute systematic utility for each time period and product
    v_default = jnp.zeros((T, 1))
    v_utility = jnp.concatenate((v_default, v_1to4_utility), axis=1)

    # Compute choice probabilities with improved numerical stability
    log_sumexps = logsumexp(v_utility, axis=1)
    probas = jnp.exp(v_utility - log_sumexps[:, None]) #get the choice probabilities for each time period and product

    return probas

In [40]:
@jit
def likelihood(theta): #(log)-likelihood function
    probas_theta = choice_probas(theta) #get the choice probabilities for the candidate theta
    log_likelihood = jnp.sum(jnp.log(probas_theta[jnp.arange(T), choice_jnp])) #sum the log-probabilities of the observed choices
    return -log_likelihood

grad_likelihood = jit(grad(likelihood)) ## gradient of the likelihood function

In [141]:
def minimize_adam(f, grad_f, x0, norm=1e9, tol=0.1, lr=0.05, maxiter=1000, verbose=0, *args): ## generic adam optimizer
  """
  Generic Adam Optimizer. Specify a function f, a starting point x0, possibly a \n
  learning rate in (0, 1). The lower the learning rate, the more stable (and slow) the convergence.
  """
  tic = time.time()
  solver = optax.adam(learning_rate=lr)
  params = jnp.array(x0, dtype=jnp.float32)
  opt_state = solver.init(params)
  iternum = 0
  while norm > tol and iternum < maxiter :
    iternum += 1
    grad = grad_f(params, *args)
    updates, opt_state = solver.update(grad, opt_state, params)
    params = optax.apply_updates(params, updates)
    params = jnp.asarray(params, dtype=jnp.float32)
    norm = jnp.max(jnp.abs(grad))
    if verbose > 0:
      if iternum % 100 == 0:
        print(f"Iteration: {iternum}  Norm: {norm}  theta: {params}")
    if verbose > 1:
      print(f"Iteration: {iternum}  Norm: {norm}  theta: {params}")
  tac = time.time()
  if iternum == maxiter:
    print(f"Convergence not reached after {iternum} iterations. \nTime: {tac-tic} seconds. Norm: {norm}")
  else:
    print(f"Convergence reached after {iternum} iterations. \nTime: {tac-tic} seconds. Norm: {norm}")

  return params

In [142]:
theta_MLE_homo = minimize_adam(likelihood, grad_likelihood, jnp.ones(5), lr=0.2, verbose=1)

Iteration: 100  Norm: 335.4840087890625  theta: [-1.4224386   0.07654052 -1.1986421  -0.5866302  -0.5061392 ]
Iteration: 200  Norm: 24.041778564453125  theta: [-1.5150231  -0.15054329 -1.2758065  -0.82709366 -0.40474826]
Iteration: 300  Norm: 4.161994934082031  theta: [-1.5439749  -0.22423369 -1.3037015  -0.9119768  -0.37245965]
Iteration: 400  Norm: 0.4395751953125  theta: [-1.5494919  -0.23826328 -1.3090049  -0.92812485 -0.3663207 ]
Convergence reached after 447 iterations. 
Time: 5.415534734725952 seconds. Norm: 0.0990447998046875


### Formula to retrieve MLE Standard Errors

\begin{align*}
\nabla l_{it} &= \frac{\partial l_{it} (\theta)}{\partial \theta} \tag{column vector} \\
SE(\widehat{\theta}) &= diag\Bigg[{\sqrt{\Big(\sum_{i=1}^N \sum_{t=1}^T \nabla l_{it} \cdot \nabla l_{it}' \Big)^{-1}}}\Bigg]
\end{align*}

In [44]:
###Computation of standard errors

@jit
def likelihood_it(theta, i, t):
    """
    Computes the likelihood for an individual observation
    """
    probas_theta = choice_probas(theta)
    likelihood_it = jnp.log(probas_theta[t, choice_jnp[i, t]])
    return likelihood_it

grad_likelihood_it = jit(grad(likelihood_it)) ### Takes the gradient of the individual likelihood

@jit
def outer_grad_likelihood(theta, i, t):
    """
    Takes the outer product (column vector x row vector) of the gradient of the individual likelihood
    """
    grad_it = (grad_likelihood_it(theta, i, t)).reshape(-1, 1) 
    return grad_it@grad_it.T


#computes the outer product above for each individual and time period
grad_likelihood_it_vec = vmap(vmap(outer_grad_likelihood, in_axes=(None, 0, None)), in_axes=(None, None, 0)) 

@jit
def compute_standard_errors(theta):
    sum_outers = (jnp.sum(grad_likelihood_it_vec(theta, jnp.arange(N), jnp.arange(T)), axis=(0, 1)))
    return jnp.diag(jnp.sqrt(jnp.linalg.inv(sum_outers)))

#### MLE assuming homogeneous $\Theta^h$

In [45]:
se = compute_standard_errors(theta_MLE_homo)
display(Latex(f'$\Theta^h$: {theta_MLE_homo}'))
display(Latex(f'$se$: {se}'))

<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

In [46]:
###Two classes: instead of estimating theta, we want to estimate the weights phi_1, phi_2 of each class (?)
theta_k1 = theta_MLE_homo - 3*se
theta_k2 = theta_MLE_homo + 3*se
print(theta_k1)
print(theta_k2)

[-1.6226189 -0.3826288 -1.3769871 -1.094965  -0.4279044]
[-1.4773014  -0.09627143 -1.2419189  -0.7640183  -0.3036981 ]


In [49]:
@jit
def choice_probas_2classes(phi1):
    phi2 = 1 - phi1
    probas_k1 = choice_probas(theta_k1)
    probas_k2 = choice_probas(theta_k2)
    probas = phi1 * probas_k1 + phi2 * probas_k2
    return probas

In [50]:
@jit
def likelihood_2classes(phi1): #(log)-likelihood function
    probas_theta = choice_probas_2classes(phi1) #get the choice probabilities for the candidate theta
    log_likelihood = jnp.sum(jnp.log(probas_theta[jnp.arange(T), choice_jnp])) #sum the log-probabilities of the observed choices
    return -log_likelihood


grad_likelihood_2classes = jit(grad(likelihood_2classes))

In [69]:
weight_1 = minimize_adam(likelihood_2classes, grad_likelihood_2classes, 0.5, verbose=False)
display(Latex(f'$\Theta_1^h$: {theta_k1}'))
display(Latex(f'$\Theta_2^h$: {theta_k2}'))
print(f'Weights: {weight_1.item(), 1-weight_1.item()}')

Convergence reached after 47 iterations. 
Time: 0.5442581176757812 seconds. Norm: 0.0


<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

Weights: (0.5217000246047974, 0.47829997539520264)


#### MLE assuming that $\Theta^h$ has a normal distribution across the households 

Prompt, given to French AI 'Le Chat, by Mistral AI':
* Yo le chat ! Je voudrais faire maximum likelihood, under assumption that beta is a random coefficient that is normally distributed. HOw do I do this ?
* I'll show you my current code for MLE estimation under homogeneous theta. Show me how to modify it to get what I want

In [None]:
choice_probas_vec = vmap(choice_probas)

@jit
def likelihood_normal(params, n_draws=500): #(log)-likelihood function
    mu, log_sigma = params[:5], params[5:]

    sigma = jnp.exp(log_sigma)

    key = jax.random.PRNGKey(0)

    theta_list = mu + sigma * jax.random.normal(key, (n_draws, len(mu)))
    probas_theta = choice_probas_vec(theta_list)
    
    log_likelihood = jnp.sum(jnp.log(probas_theta[:, jnp.arange(T), choice_jnp]))
    return -log_likelihood/n_draws

grad_likelihood_normal = jit(grad(likelihood_normal)) ## gradient of the likelihood function

In [224]:
#x_start = jnp.concatenate((theta_MLE_homo, jnp.diag(jnp.ones(5)).flatten()))
x_start = jnp.concatenate((theta_MLE_homo, jnp.ones(5)))

In [225]:
likelihood_normal(x_start)

Array(94102.47, dtype=float32)

In [226]:
theta_normal = minimize_adam(likelihood_normal, grad_likelihood_normal, x_start, verbose=1)

Iteration: 100  Norm: 256.5911865234375  theta: [-1.5359179  -0.20468949 -1.2961324  -0.8844269  -0.38002786 -0.00313769
  0.00301184 -0.00271258 -0.00261913  0.00476202]
Iteration: 200  Norm: 2.271724224090576  theta: [-1.5472858e+00 -2.3251373e-01 -1.3069135e+00 -9.2161208e-01
 -3.6882609e-01  2.9319970e-05 -3.6787351e-05  2.6668191e-05
  2.0714982e-05  1.9729567e-05]
Iteration: 300  Norm: 0.15042029321193695  theta: [-1.5498588e+00 -2.3918414e-01 -1.3093596e+00 -9.2919540e-01
 -3.6591575e-01  9.2118945e-08  5.0928367e-07  7.0738949e-08
  1.1398337e-06 -2.2799252e-07]
Convergence reached after 324 iterations. 
Time: 48.29884672164917 seconds. Norm: 0.09304982423782349


In [230]:
print(f'Theta Normal: {theta_normal[:5]}')
print(f'Variance matrix: {jnp.sqrt(jnp.abs(theta_normal[5:]))}')

Theta Normal: [-1.5499992  -0.23954648 -1.3094931  -0.9296078  -0.36575806]
Variance matrix: [0.0002194  0.00047981 0.00019497 0.00072211 0.0003106 ]
