<a href="https://colab.research.google.com/github/ansonkwokth/PlackettLuceModel/blob/main/example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [85]:
# !git clone https://github.com/ansonkwokth/PlackettLuceModel.git

In [86]:
# !python -m unittest plackett_luce/tests/test_utils.py

In [88]:
from plackett_luce import datasets as ds
from plackett_luce.model import PlackettLuceModel
from plackett_luce.utils import EarlyStopper

import torch
from torch import nn

torch.manual_seed(0);


In [89]:
# Custom neural network model for flexible scoring
class NaiveNN(nn.Module):
    """Single-hidden-layer MLP that reduces the feature dimension to one score.

    The network is ``Linear -> ReLU -> Linear`` and is applied along the last
    dimension of the input, so any leading dimensions are preserved.

    Args:
        input_dim: Size of the last (feature) dimension of the input.
        hidden_dim: Width of the hidden layer. Defaults to 16, which is the
            width that was previously hard-coded.

    Raises:
        ValueError: If ``hidden_dim`` is not a positive number.
    """

    def __init__(self, input_dim, hidden_dim=16):
        # Fail early with a clear message instead of building a degenerate layer.
        if hidden_dim <= 0:
            raise ValueError(f"hidden_dim must be positive, got {hidden_dim}")
        super().__init__()
        # Keep the attribute name and layer order so state_dict keys
        # ("network.0.*", "network.2.*") stay compatible with saved weights.
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),  # 1D output for scoring
        )

    def forward(self, x):
        """Apply the network to ``x`` and return a tensor whose last dimension is 1."""
        return self.network(x)





# Custom neural network model for flexible scoring
class LessNaiveNN(nn.Module):
    """MLP scorer with a configurable stack of ReLU-activated hidden layers.

    Each hidden width contributes a ``Linear -> ReLU`` pair, followed by a
    final ``Linear`` layer with a single output. The network is applied along
    the last dimension of the input, so any leading dimensions are preserved.

    Args:
        input_dim: Size of the last (feature) dimension of the input.
        hidden_dims: Iterable of hidden-layer widths, in order. Defaults to
            ``(16, 4)``, which reproduces the previously hard-coded
            architecture. An empty iterable yields a single linear layer.

    Raises:
        ValueError: If any hidden width is not a positive number.
    """

    def __init__(self, input_dim, hidden_dims=(16, 4)):
        widths = tuple(hidden_dims)
        # Fail early with a clear message instead of building a degenerate layer.
        if any(width <= 0 for width in widths):
            raise ValueError(f"hidden_dims must all be positive, got {widths}")
        super().__init__()

        # Layers are created in forward order so the default configuration
        # keeps the same indices ("network.0", "network.2", "network.4").
        layers = []
        in_width = input_dim
        for out_width in widths:
            layers.append(nn.Linear(in_width, out_width))
            layers.append(nn.ReLU())
            in_width = out_width
        layers.append(nn.Linear(in_width, 1))  # 1D output for scoring
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x`` and return a tensor whose last dimension is 1."""
        return self.network(x)


In [90]:
# Dataset sizes
num_samples_train = 1_000
num_samples_test = 1_000
num_items = 14

# Synthetic data: the training set is drawn first, then the test set
print("Generating training and testing data...")
X_train, rankings_train = ds.generate_data(num_samples_train, num_items)
X_test, rankings_test = ds.generate_data(num_samples_test, num_items)
num_features = X_train.shape[-1]  # size of the last dimension of the generated features

# Item masks with one row per sample and one column per item; all ones here
item_mask_train = torch.ones(num_samples_train, num_items)
item_mask_test = torch.ones(num_samples_test, num_items)
# Optional experiment: zero out roughly 20% of the mask entries at random,
# presumably to simulate samples that contain fewer than `num_items` items
# item_mask_train[torch.rand(num_samples_train, num_items) < 0.2] = 0  # Randomly mask some items
# item_mask_test[torch.rand(num_samples_test, num_items) < 0.2] = 0


Generating training and testing data...


In [91]:

# Scoring network handed to the Plackett-Luce model.
# Alternative (smaller) network:
# custom_nn = NaiveNN(input_dim=num_features)
custom_nn = LessNaiveNN(input_dim=num_features)

# Early stopper configured with patience=20 and min_delta=0.01
custom_early_stopper = EarlyStopper(min_delta=0.01, patience=20)
model = PlackettLuceModel(early_stopper=custom_early_stopper, score_model=custom_nn)

# Training
print("Training the model...")

# Keyword options forwarded to `fit`; the commented call below is the same
# fit without `top_k`.
fit_options = {"lr": 0.01, "epochs": 500, "top_k": 3}
model.fit(X_train, rankings_train, **fit_options)
# model.fit(X_train, rankings_train, lr=0.01, epochs=500)


Training the model...
Epoch 10/500, Negative Log-Likelihood: 7.3778
Epoch 20/500, Negative Log-Likelihood: 6.6657
Epoch 30/500, Negative Log-Likelihood: 5.6260
Epoch 40/500, Negative Log-Likelihood: 4.5597
Epoch 50/500, Negative Log-Likelihood: 3.6237
Epoch 60/500, Negative Log-Likelihood: 3.2466
Epoch 70/500, Negative Log-Likelihood: 3.0972
Epoch 80/500, Negative Log-Likelihood: 2.9562
Epoch 90/500, Negative Log-Likelihood: 2.8693
Epoch 100/500, Negative Log-Likelihood: 2.7968
Epoch 110/500, Negative Log-Likelihood: 2.7390
Epoch 120/500, Negative Log-Likelihood: 2.6941
Epoch 130/500, Negative Log-Likelihood: 2.6578
Epoch 140/500, Negative Log-Likelihood: 2.6328
Epoch 150/500, Negative Log-Likelihood: 2.6103
Epoch 160/500, Negative Log-Likelihood: 2.5915
Epoch 170/500, Negative Log-Likelihood: 2.5708
Epoch 180/500, Negative Log-Likelihood: 2.5522
Epoch 190/500, Negative Log-Likelihood: 2.5406
Early stopping at epoch 200 with NLL 2.5301


In [93]:

# Test the model
print("\nTesting the model...\n")
predicted_rankings = model.predict(X_test)

# Hit counters, one per reported metric
top1_correct = 0
top2_correct = 0
top3_correct = 0
top1in3_correct = 0
top2in3_correct = 0
top1or2in3_correct = 0

# Number of (prediction, truth) pairs actually scored. `zip` stops at the
# shorter input, so this is the correct denominator even if the two lengths
# ever differ from `num_samples_test`.
num_evaluated = 0

print_first_few = 10
for i, (pred, true) in enumerate(zip(predicted_rankings, rankings_test.tolist())):
    num_evaluated += 1

    if i < print_first_few:
        print(f"Sample {i + 1}:")
        print(f"  Predicted Ranking: {pred}")
        print(f"  True Ranking:      {true}")

    # Exact-order agreement on the first 1, 2 and 3 positions
    if pred[0] == true[0]:
        top1_correct += 1
    if pred[:2] == true[:2]:
        top2_correct += 1
    if pred[:3] == true[:3]:
        top3_correct += 1

    # Membership of the predicted 1st / 2nd item in the true top three
    true_top3 = true[:3]
    first_in_top3 = pred[0] in true_top3
    second_in_top3 = pred[1] in true_top3
    if first_in_top3:
        top1in3_correct += 1
    if second_in_top3:
        top2in3_correct += 1
    if first_in_top3 or second_in_top3:
        top1or2in3_correct += 1

# Compute percentages over the samples that were scored; the `max` guard
# reports 0% instead of raising ZeroDivisionError when nothing was scored.
accuracy_denominator = max(num_evaluated, 1)
top1_accuracy = top1_correct / accuracy_denominator * 100
top2_accuracy = top2_correct / accuracy_denominator * 100
top3_accuracy = top3_correct / accuracy_denominator * 100
top1in3_accuracy = top1in3_correct / accuracy_denominator * 100
top2in3_accuracy = top2in3_correct / accuracy_denominator * 100
top1or2in3_accuracy = top1or2in3_correct / accuracy_denominator * 100

print(f"\nTop-1 or 2 in 3 Accuracy: {top1or2in3_accuracy:.2f}%")
print(f"Top-1 in 3 Accuracy: {top1in3_accuracy:.2f}%")
print(f"Top-2 in 3 Accuracy: {top2in3_accuracy:.2f}%")
print(f"Top-1 Accuracy: {top1_accuracy:.2f}%")
print(f"Top-2 Accuracy: {top2_accuracy:.2f}%")
print(f"Top-3 Accuracy: {top3_accuracy:.2f}%")



Testing the model...

Sample 1:
  Predicted Ranking: [0, 12, 2, 6, 1, 13, 9, 3, 10, 4, 7, 11, 5, 8]
  True Ranking:      [0, 2, 12, 6, 13, 1, 9, 10, 3, 5, 4, 11, 7, 8]
Sample 2:
  Predicted Ranking: [13, 12, 11, 0, 6, 3, 8, 9, 2, 5, 4, 7, 1, 10]
  True Ranking:      [13, 12, 11, 9, 6, 2, 0, 8, 3, 4, 10, 1, 7, 5]
Sample 3:
  Predicted Ranking: [0, 4, 6, 3, 12, 7, 8, 13, 11, 1, 10, 9, 2, 5]
  True Ranking:      [0, 4, 6, 12, 7, 13, 3, 11, 1, 9, 8, 5, 10, 2]
Sample 4:
  Predicted Ranking: [8, 4, 12, 2, 0, 3, 11, 7, 5, 13, 1, 9, 6, 10]
  True Ranking:      [8, 4, 12, 0, 13, 3, 2, 7, 11, 5, 10, 1, 6, 9]
Sample 5:
  Predicted Ranking: [9, 6, 0, 7, 8, 13, 4, 12, 11, 3, 1, 10, 2, 5]
  True Ranking:      [9, 0, 7, 6, 12, 13, 11, 4, 8, 1, 5, 2, 10, 3]
Sample 6:
  Predicted Ranking: [10, 5, 1, 0, 9, 2, 3, 8, 13, 7, 6, 4, 11, 12]
  True Ranking:      [5, 10, 1, 2, 9, 0, 3, 6, 7, 8, 11, 12, 4, 13]
Sample 7:
  Predicted Ranking: [4, 0, 2, 1, 11, 13, 3, 9, 5, 12, 7, 8, 10, 6]
  True Ranking:      [4