# The Red Bus / Blue Bus Paradox and Fixing it with Matrix-Completion (Mixed) Logit

This example illustrates the famous Independence of Irrelevant Alternatives (IIA)
problem in standard multinomial logit models and shows how low-rank logit avoids it.

Setup:
- 3 initial products: Car, Train, Red Bus
- Each product is distinct and has its own baseline market share
- We then add Blue Bus (very similar to Red Bus)

Expected behavior:
- Blue Bus should primarily cannibalize Red Bus (they're substitutes)
- Car and Train shares should remain relatively stable

Standard MNL problem (IIA):
- Adding Blue Bus → all existing products lose share proportionally
- This is unrealistic: why would Car and Train lose the same fraction of their share as Red Bus?
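The proportionality follows directly from the MNL formula. With deterministic utilities $v_k$ for the $J$ incumbent products, adding a new alternative $J+1$ rescales every incumbent's probability by the same factor:

$$
P_i = \frac{e^{v_i}}{\sum_{k=1}^{J} e^{v_k}}, \qquad
P_i' = \frac{e^{v_i}}{\sum_{k=1}^{J} e^{v_k} + e^{v_{J+1}}} = P_i \left(1 - P_{J+1}'\right)
$$

So every incumbent loses exactly the fraction $P_{J+1}'$ of its share, and the odds ratio $P_i / P_j = e^{v_i - v_j}$ does not depend on which other alternatives are offered.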

Low-rank logit solution:
- Captures heterogeneous preferences (bus-lovers, train-lovers, car-lovers)
- Blue Bus cannibalizes Red Bus primarily, as expected
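A minimal pure-Python sketch of why mixing helps. The two segments and their utilities are hypothetical, chosen only for illustration. Each segment is a plain MNL and obeys IIA internally, but the population-level shares do not, because Blue Bus draws almost entirely from the bus-loving segment.

```python
import math

def mnl_shares(utils):
    """Logit choice probabilities for a dict {item: utility}."""
    m = max(utils.values())  # subtract the max for numerical stability
    exps = {k: math.exp(v - m) for k, v in utils.items()}
    total = sum(exps.values())
    return {k: e / total for k, e in exps.items()}

def mixed_shares(segments, items):
    """Market shares from a mixture of segments, each a (weight, utilities) pair."""
    shares = {k: 0.0 for k in items}
    for weight, utils in segments:
        probs = mnl_shares({k: utils[k] for k in items})  # only offered items
        for k in items:
            shares[k] += weight * probs[k]
    return shares

# Hypothetical segment utilities (illustrative numbers, not fitted values)
segments = [
    (0.5, {"Car": 3.0, "Red Bus": 0.0, "Blue Bus": 0.0}),  # car lovers
    (0.5, {"Car": 0.0, "Red Bus": 3.0, "Blue Bus": 3.0}),  # bus lovers
]
before = mixed_shares(segments, ["Car", "Red Bus"])
after = mixed_shares(segments, ["Car", "Red Bus", "Blue Bus"])
for k in ["Car", "Red Bus"]:
    print(f"{k}: loses {1 - after[k] / before[k]:.1%} of its share")
```

With these numbers Car loses about 6.6% of its share and Red Bus about 46.7%, whereas a single MNL would force both to lose the same fraction.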

Alternatives - Nested Logit, Multinomial Probit with full covariance:
- Nested Logit requires pre-defined nests (e.g., grouping the buses) and has an unwieldy likelihood
- Multinomial Probit with full covariance has no closed-form likelihood and requires bespoke, brittle simulation-based estimation (e.g., the GHK simulator)
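For comparison, a minimal sketch of two-level nested logit choice probabilities in pure Python. The helper `nested_logit_shares` is hypothetical and not part of `torchonometrics`. Note that the analyst must hand-specify the nests (here both buses share one nest); with the dissimilarity parameter `lam` equal to 1 for every nest, it reduces to MNL.

```python
import math

def nested_logit_shares(utils, nests, lam):
    """Two-level nested logit choice probabilities (no log-sum-exp stabilization).

    utils: {item: systematic utility} for the offered items
    nests: list of lists of items; each offered item belongs to exactly one nest
    lam:   {nest index: dissimilarity parameter in (0, 1]}; lam = 1 recovers MNL
    """
    inclusive = {}  # inclusive value I_n = log sum_{j in n} exp(v_j / lam_n)
    for n, members in enumerate(nests):
        offered = [i for i in members if i in utils]
        if offered:
            inclusive[n] = math.log(sum(math.exp(utils[i] / lam[n]) for i in offered))
    denom = sum(math.exp(lam[n] * iv) for n, iv in inclusive.items())
    shares = {}
    for n, iv in inclusive.items():
        p_nest = math.exp(lam[n] * iv) / denom  # P(nest)
        for i in nests[n]:
            if i in utils:
                # P(i) = P(nest) * P(i | nest)
                shares[i] = p_nest * math.exp(utils[i] / lam[n]) / math.exp(iv)
    return shares

nests = [["Car"], ["Train"], ["Red Bus", "Blue Bus"]]  # analyst-specified nests
lam = {0: 1.0, 1: 1.0, 2: 0.1}  # small lam: the buses are close substitutes
base = {"Car": 0.0, "Train": 0.0, "Red Bus": 0.0}
before = nested_logit_shares(base, nests, lam)
after = nested_logit_shares({**base, "Blue Bus": 0.0}, nests, lam)
print({k: round(v, 3) for k, v in after.items()})
```

Here Car's share falls only from 33.3% to about 32.6%, while Red Bus falls to about 17.4%; with `lam[2] = 1.0` every product would drop to 25%.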

In [1]:
import torch
import numpy as np
from torchonometrics.choice import MultinomialLogit, LowRankLogit

# Set seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## DGP: heterogeneous preferences and price sensitivity

In [2]:
n_users = 100
n_items = 4  # Car, Train, Red Bus, Blue Bus (but we'll only use 3 initially)
rank = 2

# Create user preference structure
# Dimension 1: Public transit vs. Private (Car)
# Dimension 2: Rail vs. Road (Train vs. Bus)
A = torch.zeros(n_users, rank)
B = torch.zeros(n_items, rank)

# User types (approximate):
# - 40% prefer private transport (positive dim 1)
# - 60% prefer public transport (negative dim 1)
# Within public transport:
#   - 50% prefer rail (positive dim 2)
#   - 50% prefer road/bus (negative dim 2)

# public/private dim
A[:40, 0] = torch.randn(40) * 0.3 + 1.5   # Car lovers (private)
A[40:70, 0] = torch.randn(30) * 0.3 - 1.0  # Public transit (train lovers)
A[70:, 0] = torch.randn(30) * 0.3 - 1.0    # Public transit (bus lovers)

# rail/bus dim
A[40:70, 1] = torch.randn(30) * 0.3 + 1.0  # Train preference
A[70:, 1] = torch.randn(30) * 0.3 - 1.5    # Bus preference

# Product characteristics
# [Private/Public, Rail/Road]
B[0] = torch.tensor([2.0, 0.0])    # Car: Private
B[1] = torch.tensor([-1.5, 2.0])   # Train: Public, Rail
B[2] = torch.tensor([-1.5, -2.0])  # Red Bus: Public, Road
B[3] = torch.tensor([-1.5, -1.8])  # Blue Bus: Public, Road (very similar to Red Bus)

# Add price as an additional utility component
# Prices: Car is expensive, Train is moderate, Buses are cheap and similar
product_prices = torch.tensor([10.0, 5.0, 2.0, 2.5])  # Car, Train, Red Bus, Blue Bus

# Price sensitivity: users are heterogeneously price sensitive
# Sensitivity is drawn independently of segment, so expensive products (Car) tend to be chosen by less price-sensitive users
price_sensitivity = torch.randn(n_users) * 0.5 - 1.0  # Mean -1.0 (negative = prefer lower prices)

# Create utility matrix: base preferences + price effect
true_theta = A @ B.T
true_theta = true_theta + price_sensitivity.unsqueeze(1) * product_prices.unsqueeze(0)
true_theta = true_theta - true_theta.mean(dim=1, keepdim=True)  # Zero-sum constraint

print("\nUser segments:")
print("  Car lovers: 40 users")
print("  Train lovers: 30 users")
print("  Bus lovers: 30 users")
print("\nProduct characteristics (from low-rank structure):")
print(f"  Car: Private transport, Price=${product_prices[0]:.2f}")
print(f"  Train: Public transport, Rail, Price=${product_prices[1]:.2f}")
print(f"  Red Bus: Public transport, Road, Price=${product_prices[2]:.2f}")
print(f"  Blue Bus: Public transport, Road, Price=${product_prices[3]:.2f} (similar to Red Bus)")
print(f"\nNote: Red Bus and Blue Bus have similar prices (${product_prices[2]:.2f} vs ${product_prices[3]:.2f})")




User segments:
  Car lovers: 40 users
  Train lovers: 30 users
  Bus lovers: 30 users

Product characteristics (from low-rank structure):
  Car: Private transport, Price=$10.00
  Train: Public transport, Rail, Price=$5.00
  Red Bus: Public transport, Road, Price=$2.00
  Blue Bus: Public transport, Road, Price=$2.50 (similar to Red Bus)

Note: Red Bus and Blue Bus have similar prices ($2.00 vs $2.50)
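As a sanity check on the DGP, the hand computation below evaluates utility for each segment's prototype (the segment mean of `A` in the cell above, ignoring the noise) at the mean price sensitivity of -1. The names are illustrative. It suggests why Car's training share is only about 10% even though 40% of users are "car lovers": at mean price sensitivity, the price gap (10.00 vs 2.00) outweighs their private-transport preference, so even the car-lover prototype has higher utility for Red Bus than for Car.

```python
# Segment prototypes: the mean factor vectors used in the DGP (noise ignored)
prototypes = {
    "car lover": (1.5, 0.0),
    "train lover": (-1.0, 1.0),
    "bus lover": (-1.0, -1.5),
}
items = ["Car", "Train", "Red Bus", "Blue Bus"]
B = [(2.0, 0.0), (-1.5, 2.0), (-1.5, -2.0), (-1.5, -1.8)]  # item factors from the DGP
prices = [10.0, 5.0, 2.0, 2.5]
beta = -1.0  # mean price sensitivity in the DGP

def utility(a, b, price):
    """Low-rank part a . b plus the price term beta * price."""
    return a[0] * b[0] + a[1] * b[1] + beta * price

for name, a in prototypes.items():
    row = {item: round(utility(a, b, p), 2) for item, b, p in zip(items, B, prices)}
    print(name, row)
```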


### Training data with varying assortments

In [3]:
# ============================================================================
# STEP 2: Generate training data with varying assortments
# ============================================================================
n_samples = n_users * 50
user_indices = torch.randint(0, n_users, (n_samples,))

# Create varying assortments:
# - 70% of the time: Car, Train, Red Bus only
# - 30% of the time: Car, Train, Blue Bus only (Red Bus unavailable)
# This way the model learns about Blue Bus, yet the two buses are never offered
# together, so the 4-product market remains a genuine counterfactual
assortments_train = torch.zeros(n_samples, n_items)
for i in range(n_samples):
    if torch.rand(1).item() < 0.7:
        # Baseline: Car, Train, Red Bus
        assortments_train[i, [0, 1, 2]] = 1
    else:
        # Alternative: Car, Train, Blue Bus (Red unavailable)
        assortments_train[i, [0, 1, 3]] = 1

# Generate choices
item_indices_train = torch.zeros(n_samples, dtype=torch.long)
for i in range(n_samples):
    user = user_indices[i]
    available = assortments_train[i].bool()
    utilities = true_theta[user, available]
    probs = torch.nn.functional.softmax(utilities, dim=0)
    available_items = torch.where(available)[0]
    chosen_idx = torch.multinomial(probs, 1).item()
    item_indices_train[i] = available_items[chosen_idx]

# Compute empirical market shares in training data
product_names = ["Car", "Train", "Red Bus", "Blue Bus"]
empirical_shares = torch.bincount(item_indices_train, minlength=4).float() / n_samples
print("\nEmpirical market shares in training data (with varying assortments):")
for i in range(4):
    print(f"  {product_names[i]}: {empirical_shares[i]:.1%}")

print("\nNote: Red Bus and Blue Bus never appear together in training data")
print("      This simulates introducing a new product variant.")



Empirical market shares in training data (with varying assortments):
  Car: 10.0%
  Train: 22.3%
  Red Bus: 47.9%
  Blue Bus: 19.8%

Note: Red Bus and Blue Bus never appear together in training data
      This simulates introducing a new product variant.
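The choice step in the loop above is a softmax over the offered items only. A minimal pure-Python sketch of that rule follows (the helper name `choice_probs` is hypothetical). It also shows that shifting all utilities by a constant, as the zero-sum centering in the DGP does, leaves the probabilities unchanged.

```python
import math

def choice_probs(utilities, available):
    """Softmax restricted to the offered assortment.

    Unavailable items get probability 0, which is equivalent to setting
    their utility to -inf before a standard softmax.
    """
    offered = [u for u, a in zip(utilities, available) if a]
    m = max(offered)  # subtract the max over offered items for stability
    exps = [math.exp(u - m) if a else 0.0 for u, a in zip(utilities, available)]
    total = sum(exps)
    return [e / total for e in exps]

theta = [0.5, 1.0, 2.0, 1.9]  # illustrative utilities for Car, Train, Red Bus, Blue Bus
p = choice_probs(theta, [1, 1, 1, 0])  # Blue Bus not offered
p_shift = choice_probs([u + 10.0 for u in theta], [1, 1, 1, 0])
print([round(x, 3) for x in p])
```

In a vectorized torch version, the same effect could come from `utilities.masked_fill(~available, float("-inf"))` followed by `torch.softmax`, which avoids the Python loop over samples.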


## Model fits

### MNL

In [4]:
# ============================================================================
# STEP 3: Fit Standard Multinomial Logit with Prices
# ============================================================================
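# Standard MNL learns alternative-specific constants plus price effects
# Features: [intercept, price_car, price_train, price_redbus, price_bluebus]
# Each non-reference alternative gets its own coefficient for every feature.
# Caveat: prices are identical in every row, so the price columns are perfectly
# collinear with the intercept and their coefficients are not separately identified.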
mnl_model = MultinomialLogit()

# Prepare feature matrix: intercept + all product prices
# Each observation has the same prices (they're product attributes, not user attributes)
X_mnl = torch.zeros(n_samples, 5)  # intercept + 4 prices
X_mnl[:, 0] = 1  # intercept
X_mnl[:, 1:5] = product_prices.repeat(n_samples, 1)  # prices for all products

# Create one-hot encoded choices for training data
y_mnl = torch.nn.functional.one_hot(item_indices_train, num_classes=4).to(torch.float32)

# Initialize parameters for MNL (4 alternatives; Blue Bus is the reference)
# Shape: (5 features, 3 non-reference alternatives)
init_params_mnl = torch.randn(5, 3) * 0.01

print("\nFitting standard MNL model with price features...")
print("  Features: intercept + prices of all 4 products")
print("  Note: MNL learns separate price coefficients for each alternative")
mnl_model.fit(X_mnl, y_mnl, init_params=init_params_mnl, verbose=False)
print(f"MNL converged in {mnl_model.iterations_run} iterations")

# Predict probabilities over all 4 products (the features carry no assortment information)
X_test_mnl = torch.zeros(n_users, 5)
X_test_mnl[:, 0] = 1  # intercept
X_test_mnl[:, 1:5] = product_prices.repeat(n_users, 1)  # prices

mnl_probs_full = mnl_model.predict_proba(X_test_mnl)  # All 4 products

# For 3-product scenario (Car, Train, Red Bus), we need to renormalize
# This is the key: under IIA, relative odds P(i)/P(j) stay constant
mnl_probs_3prod = mnl_probs_full[:, :3].clone()
mnl_probs_3prod = mnl_probs_3prod / mnl_probs_3prod.sum(dim=1, keepdim=True)
mnl_shares_3prod = mnl_probs_3prod.mean(dim=0)

print("\nStandard MNL predicted market shares (Car, Train, Red Bus):")
for i in [0, 1, 2]:
    print(f"  {product_names[i]}: {mnl_shares_3prod[i]:.1%}")

# Show learned coefficients (for interpretation)
mnl_coefs = mnl_model.params["coef"]  # Shape: (5 features, 3 alternatives)
print("\nMNL learned coefficients (reference = Blue Bus):")
print("  Feature          Car        Train      Red Bus")
print("  " + "-" * 50)
feature_names = ["Intercept", "Price(Car)", "Price(Train)", "Price(RedBus)", "Price(BlueBus)"]
for j, fname in enumerate(feature_names):
    print(f"  {fname:15s} {mnl_coefs[j, 0]:8.3f}   {mnl_coefs[j, 1]:8.3f}   {mnl_coefs[j, 2]:8.3f}")



Fitting standard MNL model with price features...
  Features: intercept + prices of all 4 products
  Note: MNL learns separate price coefficients for each alternative
MNL converged in 12 iterations

Standard MNL predicted market shares (Car, Train, Red Bus):
  Car: 12.5%
  Train: 27.8%
  Red Bus: 59.7%

MNL learned coefficients (reference = Blue Bus):
  Feature          Car        Train      Red Bus
  --------------------------------------------------
  Intercept         -0.006     -0.008      0.008
  Price(Car)        -0.057      0.005      0.058
  Price(Train)      -0.007      0.008      0.033
  Price(RedBus)     -0.023      0.006      0.031
  Price(BlueBus)    -0.009      0.011      0.027
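Because only each alternative's total linear index is identified here, a quick check is to plug the prices into the printed coefficients. The sketch below uses the rounded values from the output above and compares each fitted index with the log of that alternative's training share relative to Blue Bus. They agree to within rounding, i.e. the MNL has simply reproduced the pooled training shares.

```python
import math

# Values copied from the printed output above (coefficients rounded to 3 decimals)
x = [1.0, 10.0, 5.0, 2.0, 2.5]  # intercept, then prices of Car, Train, Red Bus, Blue Bus
coefs = {
    "Car":     [-0.006, -0.057, -0.007, -0.023, -0.009],
    "Train":   [-0.008,  0.005,  0.008,  0.006,  0.011],
    "Red Bus": [ 0.008,  0.058,  0.033,  0.031,  0.027],
}
train_shares = {"Car": 0.100, "Train": 0.223, "Red Bus": 0.479, "Blue Bus": 0.198}

for alt, c in coefs.items():
    index = sum(ci * xi for ci, xi in zip(c, x))  # fitted log-odds vs Blue Bus
    target = math.log(train_shares[alt] / train_shares["Blue Bus"])
    print(f"{alt:8s} fitted {index:+.3f}   log share ratio {target:+.3f}")
```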


### Low-Rank Mixed Logit

In [5]:
# Fit low-rank model on training data (with varying assortments)
lr_logit = LowRankLogit(rank=rank, n_users=n_users, n_items=n_items, lam=0.05)
lr_logit.fit(user_indices, item_indices_train, assortments_train, verbose=True)
print(f"Low-rank logit fitted in {lr_logit.iterations_run} iterations")

# Predict on 3-product scenario (Car, Train, Red Bus)
test_users = torch.arange(n_users)
assort_3prod_test = torch.zeros(n_users, n_items)
assort_3prod_test[:, [0, 1, 2]] = 1  # Only Car, Train, Red Bus
lr_probs_3prod = lr_logit.predict_proba(test_users, assort_3prod_test)
lr_shares_3prod = lr_probs_3prod.mean(dim=0)

print("\nLow-rank logit predicted market shares (Car, Train, Red Bus):")
for i in [0, 1, 2]:
    print(f"  {product_names[i]}: {lr_shares_3prod[i]:.1%}")

Convergence tolerance 0.0001 met at iteration 11.
Low-rank logit fitted in 12 iterations

Low-rank logit predicted market shares (Car, Train, Red Bus):
  Car: 9.7%
  Train: 21.7%
  Red Bus: 68.6%


## Red Bus / Blue Bus

In [6]:
# ============================================================================
# STEP 5: Add Blue Bus - Standard MNL Prediction (WRONG!)
# ============================================================================
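# MNL with all 4 products. Because the features are constant, these equal the pooled
# training shares, in which Blue Bus was offered in only about 30% of choice situations.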
mnl_shares_4prod = mnl_probs_full.mean(dim=0)

print("\nStandard MNL predicted market shares (all 4 products):")
for i in range(4):
    print(f"  {product_names[i]}: {mnl_shares_4prod[i]:.1%}")

print("\nMNL-IIA substitution pattern when adding Blue Bus:")
mnl_change = mnl_shares_4prod[:3] - mnl_shares_3prod
for i in range(3):
    pct_loss = -mnl_change[i]/mnl_shares_3prod[i] if mnl_shares_3prod[i] > 0 else 0
    print(f"  {product_names[i]}: {mnl_change[i]:+.1%} (loses {pct_loss:.1%} of its share)")
print(f"  {product_names[3]}: +{mnl_shares_4prod[3]:.1%} (new entrant)")

print("\n⚠️  PROBLEM: All products lose share proportionally!")
print("   - Car loses share to Blue Bus (doesn't make sense)")
print("   - Train loses share to Blue Bus (doesn't make sense)")
print("   - Red Bus loses share to Blue Bus (makes sense, but same as others)")
print("   This is the RED BUS / BLUE BUS PARADOX")


Standard MNL predicted market shares (all 4 products):
  Car: 10.0%
  Train: 22.3%
  Red Bus: 47.9%
  Blue Bus: 19.8%

MNL-IIA substitution pattern when adding Blue Bus:
  Car: -2.5% (loses 19.8% of its share)
  Train: -5.5% (loses 19.8% of its share)
  Red Bus: -11.8% (loses 19.8% of its share)
  Blue Bus: +19.8% (new entrant)

⚠️  PROBLEM: All products lose share proportionally!
   - Car loses share to Blue Bus (doesn't make sense)
   - Train loses share to Blue Bus (doesn't make sense)
   - Red Bus loses share to Blue Bus (makes sense, but same as others)
   This is the RED BUS / BLUE BUS PARADOX


In [7]:
# Counterfactual: Car, Train, Red Bus, Blue Bus (never seen together in training!)
assort_4prod_test = torch.ones(n_users, n_items)
lr_probs_4prod = lr_logit.predict_proba(test_users, assort_4prod_test)
lr_shares_4prod = lr_probs_4prod.mean(dim=0)

print("\nLow-rank logit predicted market shares (all 4 products together):")
for i in range(4):
    print(f"  {product_names[i]}: {lr_shares_4prod[i]:.1%}")

print("\nLow-rank logit substitution pattern (vs 3-product baseline):")
lr_change = lr_shares_4prod[:3] - lr_shares_3prod[:3]
for i in range(3):
    pct_loss = -lr_change[i]/lr_shares_3prod[i] if lr_shares_3prod[i] > 0 else 0
    print(f"  {product_names[i]}: {lr_change[i]:+.1%} (loses {pct_loss:.1%} of its share)")
print(f"  {product_names[3]}: +{lr_shares_4prod[3]:.1%} (new entrant)")

print("\n✓ CORRECT BEHAVIOR:")
print(f"  - Red Bus loses MOST share: {abs(lr_change[2]):.1%} (similar products!)")
print(f"  - Car impact: {abs(lr_change[0]):.1%} (minimal)")
print(f"  - Train impact: {abs(lr_change[1]):.1%} (minimal)")
print("  - Blue Bus primarily cannibalizes Red Bus, not unrelated products")
print("  - This matches real-world intuition about substitution patterns!")



Low-rank logit predicted market shares (all 4 products together):
  Car: 7.9%
  Train: 19.3%
  Red Bus: 52.3%
  Blue Bus: 20.5%

Low-rank logit substitution pattern (vs 3-product baseline):
  Car: -1.7% (loses 18.0% of its share)
  Train: -2.4% (loses 11.1% of its share)
  Red Bus: -16.3% (loses 23.8% of its share)
  Blue Bus: +20.5% (new entrant)

✓ CORRECT BEHAVIOR:
  - Red Bus loses MOST share: 16.3% (similar products!)
  - Car impact: 1.7% (minimal)
  - Train impact: 2.4% (minimal)
  - Blue Bus primarily cannibalizes Red Bus, not unrelated products
  - This matches real-world intuition about substitution patterns!
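Another way to read these numbers is as diversion ratios: the fraction of Blue Bus's share that came from each incumbent. Under MNL the diversion ratios equal the incumbents' 3-product shares (12.5%, 27.8%, 59.7%), which is IIA restated, while the low-rank model diverts roughly 80% from Red Bus. A small sketch using the rounded share changes printed above:

```python
# Share changes copied from the printed outputs above (percentage points)
mnl = {"loss": {"Car": 2.5, "Train": 5.5, "Red Bus": 11.8}, "blue": 19.8}
low_rank = {"loss": {"Car": 1.7, "Train": 2.4, "Red Bus": 16.3}, "blue": 20.5}

def diversion(model):
    """Fraction of the entrant's share that came from each incumbent."""
    return {k: v / model["blue"] for k, v in model["loss"].items()}

for name, model in [("MNL", mnl), ("Low-rank", low_rank)]:
    d = diversion(model)
    print(name, {k: f"{v:.0%}" for k, v in d.items()})
```

The ratios sum to about 1 (up to rounding in the printed shares), since every Blue Bus chooser was previously choosing one of the three incumbents.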


In [8]:
results = lr_logit.counterfactual(
    user_indices=test_users,
    baseline_assortments=assort_3prod_test,
    counterfactual_assortments=assort_4prod_test,
)

print("\nFormal counterfactual analysis:")
print("  Baseline: {Car, Train, Red Bus}")
print("  Counterfactual: {Car, Train, Red Bus, Blue Bus}")

# Analyze probability shifts (more informative than argmax)
baseline_probs = results['baseline_probs']
cf_probs = results['counterfactual_probs']

# Users who had significant probability on Red Bus in baseline
redbus_users = (baseline_probs[:, 2] > 0.3)  # >30% prob of choosing Red Bus
print(f"\n{redbus_users.sum().item()} users had >30% probability of choosing Red Bus in baseline")
print("  Average Red Bus probability for these users:")
print(f"    Baseline (3 products): {baseline_probs[redbus_users, 2].mean():.1%}")
print(f"    Counterfactual (4 products): {cf_probs[redbus_users, 2].mean():.1%}")
print(f"  Average Blue Bus probability for these users in counterfactual: {cf_probs[redbus_users, 3].mean():.1%}")

# Simulate actual choices to show switching patterns
torch.manual_seed(43)  # Different seed for sampling
baseline_samples = torch.multinomial(baseline_probs, 1).squeeze(1)
cf_samples = torch.multinomial(cf_probs, 1).squeeze(1)

switchers_from_redbus = ((baseline_samples == 2) & (cf_samples == 3)).sum().item()
switchers_from_train = ((baseline_samples == 1) & (cf_samples == 3)).sum().item()
switchers_from_car = ((baseline_samples == 0) & (cf_samples == 3)).sum().item()
total_bluebus = (cf_samples == 3).sum().item()

print("\nSimulated choice switching (sampling from choice probabilities):")
print(f"Of {total_bluebus} users who chose Blue Bus in 4-product scenario:")
if total_bluebus > 0:
    print(f"  {switchers_from_redbus} switched from Red Bus ({switchers_from_redbus/total_bluebus:.1%})")
    print(f"  {switchers_from_train} switched from Train ({switchers_from_train/total_bluebus:.1%})")
    print(f"  {switchers_from_car} switched from Car ({switchers_from_car/total_bluebus:.1%})")

print("\n✓ KEY INSIGHT: Most Blue Bus users come from Red Bus!")
print("  This demonstrates heterogeneous preferences: bus-lovers switch")
print("  to Blue Bus, while car-lovers and train-lovers stick with their")
print("  preferred modes. The low-rank structure captures these segments.")


Formal counterfactual analysis:
  Baseline: {Car, Train, Red Bus}
  Counterfactual: {Car, Train, Red Bus, Blue Bus}

76 users had >30% probability of choosing Red Bus in baseline
  Average Red Bus probability for these users:
    Baseline (3 products): 86.5%
    Counterfactual (4 products): 65.4%
  Average Blue Bus probability for these users in counterfactual: 24.7%

Simulated choice switching (sampling from choice probabilities):
Of 23 users who chose Blue Bus in 4-product scenario:
  21 switched from Red Bus (91.3%)
  2 switched from Train (8.7%)
  0 switched from Car (0.0%)

✓ KEY INSIGHT: Most Blue Bus users come from Red Bus!
  This demonstrates heterogeneous preferences: bus-lovers switch
  to Blue Bus, while car-lovers and train-lovers stick with their
  preferred modes. The low-rank structure captures these segments.
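The simulated switching counts come from one random draw over 100 users, so they are noisy. Expected counts can be computed directly from the probability matrices instead: in a random-utility model, adding an alternative only pulls choosers toward it, so the drop in each incumbent's probability for a given user is that incumbent's expected outflow to Blue Bus. A minimal sketch (the helper name and the two-user inputs are hypothetical):

```python
def expected_switchers(p_base, p_cf, new_item):
    """Expected number of users moving from each incumbent to `new_item`.

    p_base, p_cf: per-user probability lists with the same item order.
    The drop in each incumbent's probability is its outflow to the entrant.
    """
    n_items = len(p_base[0])
    flows = [0.0] * n_items
    for base, cf in zip(p_base, p_cf):
        for j in range(n_items):
            if j != new_item:
                flows[j] += base[j] - cf[j]
    return flows

# Hypothetical probabilities over [Car, Train, Red Bus, Blue Bus], chosen so that
# each user's probabilities obey within-user logit proportionality
p_base = [[0.05, 0.05, 0.90, 0.00],   # bus-leaning user
          [0.80, 0.10, 0.10, 0.00]]   # car-leaning user
p_cf = [[0.03, 0.03, 0.54, 0.40],
        [0.76, 0.095, 0.095, 0.05]]
flows = expected_switchers(p_base, p_cf, new_item=3)
expected_blue = sum(cf[3] for cf in p_cf)
print([round(f, 3) for f in flows], round(expected_blue, 3))
```

With the notebook's tensors one could pass `baseline_probs.tolist()` and `cf_probs.tolist()`, or compute the same flows vectorized as `(baseline_probs - cf_probs).sum(dim=0)`.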


## Summary

In [9]:

# ============================================================================
# STEP 8: Summary
# ============================================================================
print("\n" + "=" * 80)
print("SUMMARY: Why Low-Rank Logit Solves the IIA Problem")
print("=" * 80)

print("""
THE PROBLEM: Independence of Irrelevant Alternatives (IIA)
---------------------------------------------------------
Standard multinomial logit models assume that adding a new alternative
causes PROPORTIONAL substitution from all existing alternatives.

In this example:
  MNL predicts: Adding Blue Bus → every existing product loses the same fraction (19.8%) of its share
  Reality: Blue Bus should primarily cannibalize Red Bus (similar products)

This is the famous RED BUS / BLUE BUS PARADOX.


THE SOLUTION: Low-Rank Mixed Multinomial Logit
-----------------------------------------------
Low-rank structure θ = AB^T captures heterogeneous user preferences:
  - Bus lovers: high affinity for both Red and Blue Bus (dim 2 < 0)
  - Train lovers: prefer rail transport (dim 2 > 0)
  - Car lovers: prefer private transport (dim 1 > 0)

When Blue Bus is added, the model correctly predicts:
""")

print(f"  Red Bus loses: {abs(lr_change[2]):.1%} of market share")
print(f"  Blue Bus captures: {lr_shares_4prod[3]:.1%} market share")
print(f"  Car loses: {abs(lr_change[0]):.1%} (minimal)")
print(f"  Train loses: {abs(lr_change[1]):.1%} (minimal)")

if total_bluebus > 0:
    print(f"\n  {switchers_from_redbus/total_bluebus:.0%} of Blue Bus users came from Red Bus")




SUMMARY: Why Low-Rank Logit Solves the IIA Problem

THE PROBLEM: Independence of Irrelevant Alternatives (IIA)
---------------------------------------------------------
Standard multinomial logit models assume that adding a new alternative
causes PROPORTIONAL substitution from all existing alternatives.

In this example:
  MNL predicts: Adding Blue Bus → every existing product loses the same fraction (19.8%) of its share
  Reality: Blue Bus should primarily cannibalize Red Bus (similar products)

This is the famous RED BUS / BLUE BUS PARADOX.


THE SOLUTION: Low-Rank Mixed Multinomial Logit
-----------------------------------------------
Low-rank structure θ = AB^T captures heterogeneous user preferences:
  - Bus lovers: high affinity for both Red and Blue Bus (dim 2 < 0)
  - Train lovers: prefer rail transport (dim 2 > 0)
  - Car lovers: prefer private transport (dim 1 > 0)

When Blue Bus is added, the model correctly predicts:

  Red Bus loses: 16.3% of market share
  Blue Bus captures: 20.5% market share
  Car loses: 1