<a href="https://colab.research.google.com/github/ansonkwokth/PlackettLuceModel/blob/main/example_dataframe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab setup: uncomment these lines on a fresh runtime to clone the repo,
# move into it, and install the package (and its dependencies) from source.
# # Clone the repository
# !git clone https://github.com/ansonkwokth/PlackettLuceModel.git

# # Navigate to the project directory
# %cd /content/PlackettLuceModel

# # Install dependencies
# !pip install .

In [2]:
# Uncomment to run the example-data generator script (presumably it writes the CSV
# loaded below from ./data/example_data/ — confirm in scripts/generate_example_data.py).
# !python ./scripts/generate_example_data.py

## Imports

In [3]:
%cd /content/PlackettLuceModel
from plackett_luce.utils import DataLoader
from plackett_luce.model import PlackettLuceModel
from plackett_luce.utils import EarlyStopper

import pandas as pd
import numpy as np
import torch
from torch import nn

torch.manual_seed(0);

/content/PlackettLuceModel


In [4]:
# Raw example dataset; one CSV read, transformed into model tensors in the next cell.
df = pd.read_csv(filepath_or_buffer="./data/example_data/example_data.csv")

In [5]:
# Convert the raw frame into features, observed rankings and the item mask.
loader = DataLoader()
X, rankings, mask = loader.transform(df)

In [38]:
# The first n_train rows form the training split (same slice applied to all three arrays).
n_train = 1000
train_rows = slice(None, n_train)
X_train, rankings_train, mask_train = (part[train_rows] for part in (X, rankings, mask))

In [39]:
# Shallow MLP scorer (one hidden layer)
class NaiveNN(nn.Module):
    """Single-hidden-layer MLP that maps each item's features to one score.

    Architecture: Linear(input_dim, 16) -> ReLU -> Linear(16, 1).

    Args:
        input_dim: number of features per item (size of the last input dimension).
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden_width = 16
        layers = [
            nn.Linear(input_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, 1),  # 1D output for scoring
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Score items: input (..., input_dim) -> output (..., 1)."""
        return self.network(x)





# Deeper MLP scorer (three hidden layers)
class LessNaiveNN(nn.Module):
    """MLP scorer: input_dim -> 16 -> 8 -> 4 -> 1, with ReLU after each hidden layer.

    Args:
        input_dim: number of features per item (size of the last input dimension).
    """

    def __init__(self, input_dim):
        super().__init__()
        widths = [input_dim, 16, 8, 4]
        layers = []
        # Layers are created in forward order, so parameter initialisation draws from
        # the RNG exactly as an explicitly written Sequential would.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), nn.ReLU()])
        layers.append(nn.Linear(widths[-1], 1))  # 1D output for scoring
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Score items: input (..., input_dim) -> output (..., 1)."""
        return self.network(x)

In [40]:
# Build the item scorer, wrap it in a Plackett-Luce model and fit it on the training split.
num_features = X_train.shape[-1]
custom_nn = LessNaiveNN(input_dim=num_features)  # or NaiveNN(input_dim=num_features) for the smaller scorer
custom_early_stopper = EarlyStopper(patience=5, min_delta=0.01)
model = PlackettLuceModel(score_model=custom_nn, early_stopper=custom_early_stopper)

n_trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"Trainable params: {n_trainable}")

print("Training the model...")
model.fit(
    X_train,
    rankings_train,
    lr=0.01,
    epochs=500,
    top_k=3,
    item_mask=mask_train,
)



Trainable params: 433
Training the model...
Epoch 10/500, Negative Log-Likelihood: 7.3377
Epoch 20/500, Negative Log-Likelihood: 6.7862
Epoch 30/500, Negative Log-Likelihood: 5.9639
Epoch 40/500, Negative Log-Likelihood: 4.7617
Epoch 50/500, Negative Log-Likelihood: 4.0406
Epoch 60/500, Negative Log-Likelihood: 3.8171
Epoch 70/500, Negative Log-Likelihood: 3.6588
Epoch 80/500, Negative Log-Likelihood: 3.5560
Epoch 90/500, Negative Log-Likelihood: 3.4771
Early stopping at epoch 97 with NLL 3.4312


In [45]:
# Evaluation split: the last |n_test| rows (n_test is negative, i.e. counted from the end).
# The start is clamped to n_train so the test rows never overlap the training rows X[:n_train].
# This matters whenever the dataset has fewer than n_train + |n_test| rows.
n_test = -1000
test_start = max(len(X) + n_test, n_train)
assert test_start < len(X), (
    f"No held-out rows left: dataset has {len(X)} rows and the first {n_train} are used for training."
)
X_test = X[test_start:]
rankings_test = rankings[test_start:]
mask_test = mask[test_start:]

In [46]:
# Predict a ranking for every test instance and keep the result as a torch tensor (masked in the next cell).
# NOTE(review): unlike fit(), predict() is not given the item mask here. If padded items can
# appear in the predicted order, check whether PlackettLuceModel.predict accepts a mask.
pred_test = torch.tensor(model.predict(X_test))

In [47]:
# Overwrite ranking slots whose mask entry is 0 with a -99 sentinel (modifies pred_test in place).
# NOTE(review): this assumes ranking positions line up with mask entries (valid items first).
# Confirm against DataLoader.transform.
PAD_SENTINEL = -99
is_padding = mask_test == 0
pred_test[is_padding] = PAD_SENTINEL

In [59]:

# Side-by-side comparison of the top-5 predicted vs. observed rankings for the first few
# test instances. The count is bounded by the test size so a small split cannot raise IndexError.
n_show = min(10, len(pred_test))
for i, (pred_row, true_row) in enumerate(zip(pred_test[:n_show], rankings_test[:n_show]), start=1):
    print("instance", i)
    print("Pred first 5", pred_row[:5].tolist())
    print("True first 5", true_row[:5].tolist())
    print()

instance 1
Pred first 5 [9, 8, 13, 6, 2]
True first 5 [9, 6, 13, 8, 3]

instance 2
Pred first 5 [10, 0, 9, 2, 4]
True first 5 [9, 10, 4, 5, 2]

instance 3
Pred first 5 [4, 5, 3, 0, 6]
True first 5 [4, 5, 2, 3, 0]

instance 4
Pred first 5 [0, 3, 8, 1, 5]
True first 5 [0, 1, 3, 5, 9]

instance 5
Pred first 5 [6, 1, 2, 10, 0]
True first 5 [6, 1, 10, 11, 2]

instance 6
Pred first 5 [8, 13, 4, 12, 10]
True first 5 [13, 8, 12, 3, 11]

instance 7
Pred first 5 [7, 2, 1, 6, 3]
True first 5 [7, 2, 1, 6, 3]

instance 8
Pred first 5 [9, 4, 10, 8, 5]
True first 5 [9, 4, 10, 8, 7]

instance 9
Pred first 5 [3, 2, 8, 4, 7]
True first 5 [8, 3, 2, 7, 9]

instance 10
Pred first 5 [2, 9, 5, 4, 3]
True first 5 [2, 5, 9, 4, 3]

