Using LowRankLogit for Assortment Optimization

This example demonstrates how the low-rank MMNL model avoids IIA problems
and enables rich counterfactual analysis for revenue management.

In [1]:
import torch
from torchonometrics.choice import LowRankLogit

In [2]:
# Simulate data: 100 users, 20 products, rank-3 preference structure
n_users = 100
n_items = 20
rank = 3
n_samples = n_users * 50

# Generate heterogeneous user preferences (low-rank structure)
torch.manual_seed(42)
true_A = torch.randn(n_users, rank)
true_B = torch.randn(n_items, rank)
true_theta = true_A @ true_B.T
true_theta = true_theta - true_theta.mean(dim=1, keepdim=True)

# Generate training data with varying assortments
user_indices = torch.randint(0, n_users, (n_samples,))
assortments = torch.zeros(n_samples, n_items)
for i in range(n_samples):
    # Each user sees random subset of 10 products
    available = torch.randperm(n_items)[:10]
    assortments[i, available] = 1

# Generate choices
item_indices = torch.zeros(n_samples, dtype=torch.long)
for i in range(n_samples):
    user = user_indices[i]
    available = assortments[i].bool()
    utilities = true_theta[user, available]
    probs = torch.nn.functional.softmax(utilities, dim=0)
    available_items = torch.where(available)[0]
    chosen_idx = torch.multinomial(probs, 1).item()
    item_indices[i] = available_items[chosen_idx]

### model fits

In [3]:

# Fit the model
print("Fitting low-rank logit model...")
model = LowRankLogit(rank=rank, n_users=n_users, n_items=n_items, lam=0.01)
model.fit(user_indices, item_indices, assortments, verbose=True)
print(f"Model converged in {model.iterations_run} iterations\n")

Fitting low-rank logit model...
Convergence tolerance 0.0001 met at iteration 11.
Model converged in 12 iterations



In [4]:
# === Counterfactual Analysis 1: Product Line Extension ===
print("SCENARIO 1: Adding a premium product to the assortment")
# Current assortment: products 0-14 (standard line)
# Counterfactual: add product 15 (premium product)
test_users = torch.arange(n_users)
baseline = torch.zeros(n_users, n_items)
baseline[:, :15] = 1  # Current 15 products

counterfactual = baseline.clone()
counterfactual[:, 15] = 1  # Add premium product 15

# Assume premium product has higher revenue
item_revenues = torch.ones(n_items) * 10.0  # Standard products: $10
item_revenues[15] = 25.0  # Premium product: $25

results = model.counterfactual(
    test_users, baseline, counterfactual, item_revenues
)

print(f"\nProduct 15 market share: {results['counterfactual_market_share'][15]:.2%}")
print(f"Revenue change: ${results['revenue_change']:.2f} ({results['revenue_change_pct']:.2%})")
print("\nTop 3 products losing market share:")
losers = torch.argsort(results['market_share_change'])[:3]
for idx in losers:
    print(f"  Product {idx.item()}: {results['market_share_change'][idx]:.2%}")

SCENARIO 1: Adding a premium product to the assortment

Product 15 market share: 8.62%
Revenue change: $129.30 (12.93%)

Top 3 products losing market share:
  Product 0: -3.05%
  Product 3: -1.41%
  Product 14: -0.67%


In [6]:
# === Counterfactual Analysis 2: Removing Low-Margin Products ===
print("\n" + "=" * 60)
print("SCENARIO 2: Removing low-margin products")
print("=" * 60)

# Current: all 20 products
# Counterfactual: remove products 16-19 (low-margin)
baseline_full = torch.ones(n_users, n_items)
counterfactual_trim = baseline_full.clone()
counterfactual_trim[:, 16:20] = 0

# Low-margin products have lower revenue
item_revenues_margin = torch.ones(n_items) * 15.0  # Standard: $15
item_revenues_margin[16:20] = 5.0  # Low-margin: $5

results2 = model.counterfactual(
    test_users, baseline_full, counterfactual_trim, item_revenues_margin
)

print(f"\nRevenue change: ${results2['revenue_change']:.2f} ({results2['revenue_change_pct']:.2%})")
print("\nTop 3 products gaining market share:")
gainers = torch.argsort(results2['market_share_change'], descending=True)[:3]
for idx in gainers:
    if idx < 16:  # Exclude removed products
        print(f"  Product {idx.item()}: {results2['market_share_change'][idx]:.2%}")


SCENARIO 2: Removing low-margin products

Revenue change: $193.75 (14.83%)

Top 3 products gaining market share:
  Product 4: 2.13%
  Product 2: 1.87%
  Product 14: 1.84%


In [7]:
# === Counterfactual Analysis 3: User Heterogeneity (No IIA!) ===
print("SCENARIO 3: Heterogeneous substitution patterns")

# Look at how different user types respond to removing product 5
specific_users = torch.tensor([0, 1, 2, 3, 4])  # 5 different users
baseline_single = torch.ones(5, n_items)
counterfactual_single = baseline_single.clone()
counterfactual_single[:, 5] = 0  # Remove product 5

results3 = model.counterfactual(specific_users, baseline_single, counterfactual_single)

print("\nWhen product 5 is removed, different users substitute differently:")
print("(This demonstrates the model avoids IIA - standard MNL would show identical substitution)")
for u in range(5):
    baseline_choice = results3['baseline_choices'][u].item()
    cf_choice = results3['counterfactual_choices'][u].item()

    if baseline_choice == 5:
        print(f"  User {u}: was choosing product 5, now chooses product {cf_choice}")
    else:
        print(f"  User {u}: still chooses product {baseline_choice}")

print("\n" + "=" * 60)
print("Key insight: The low-rank structure captures heterogeneous preferences,")
print("allowing flexible substitution patterns that avoid the IIA problem.")
print("=" * 60)


SCENARIO 3: Heterogeneous substitution patterns

When product 5 is removed, different users substitute differently:
(This demonstrates the model avoids IIA - standard MNL would show identical substitution)
  User 0: still chooses product 15
  User 1: still chooses product 4
  User 2: still chooses product 16
  User 3: still chooses product 14
  User 4: still chooses product 4

Key insight: The low-rank structure captures heterogeneous preferences,
allowing flexible substitution patterns that avoid the IIA problem.
