# LambdaRank Implementation in PyTorch

## Key formulation of LambdaRank

Formulation of pairwise ranking, for document $i$ and $j$ - Ranknet Loss function  

\begin{equation}
\begin{split}
L(y, s) &= \sum_{i=1}^{n}\sum_{j=1}^{n}\mathop{\mathbb{I}_{y_i > y_j}} \log_2(1 + e^{-\sigma(s_i - s_j)}) \\
& = \sum_{y_i > y_j} \log_2(1+e^{-\sigma(s_i - s_j)})
\end{split}
\end{equation}

#### Ranking Metrics - NDGC
\begin{equation}
\text{NDCG} = \frac{1}{\text{maxDCG}} \sum_{i=1}^{n} \frac{2^{y_i} - 1}{\log_2(1+i)} = \sum_{i=1}^{n}\frac{G_i}{D_i}
\end{equation}
where
\begin{equation}
G_i = \frac{2^{y_i} - 1}{\text{maxDCG}}, D_i = \log_2(1+i)
\end{equation}

- $G_i$ is the gain function
- $D_i$ is the discount functions
- $\text{maxDCG}$ is a constant factor per query

#### LambdaRank - Dynamically adjust the loss function during the training based on ranking metrics

Define the change of NDCG
\begin{equation}
\Delta\text{NDCG}(i,j) = |G_i - G_j||\frac{1}{D_i} -  \frac{1}{D_j}|
\end{equation}

Loss function
\begin{equation}
L(y,s) = \sum_{y_i>y_j}\Delta\text{NDCG}(i,j) log_2(1+e^{-\sigma(s_i-s_j)})
\end{equation}

### Preview of dataset

The dataset is structued with these information:
- Relevance score of the result to the query
- The query ID
- The features


| Relevance | qid | features (1-136) |
|-----------|-----|------------------|
| 2         | 1   | ...              |
| 3         | 1   | ...              |
| 2         | 1   | ...              |

In [1]:
import gc
import math
import torch
import numpy as np
import random

from collections import defaultdict
from itertools import combinations
from torch.utils.data import Dataset, DataLoader

In [2]:
def _get_format_data(data_path):
        """
        Extract data from data path
        Args:
            data_path (str): Path of the data file
        """
        
        labels = []
        features = []
        query_ids = []
        
        print("Getting data from %s" % data_path)
        
        def _extract_features(toks):
            """Extract features from tokens (e.g. 1: 0 -> 0)"""
            features = []
            for tok in toks:
                features.append(float(tok.split(":")[1]))
            return features

        def _extract_query_data(tok):
            """Extract query features (e.g. qid: 10 -> 10)"""
            # qid
            query_features = [tok.split(":")[1]]
            return query_features
        
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                data, _, comment = line.rstrip().partition("#")
                toks = data.split()

                labels.append(int(toks[0]))                  # label - The relevance score
                features.append(_extract_features(toks[2:]))    # doc features
                query_ids.append(int(_extract_query_data(toks[1])[0]))  # qid
        
        return labels, features, query_ids

In [3]:
class MSLR10KDataset(Dataset):
    """MSLR 10K Pairs Dataset"""
    
    def __init__(self, path, test_ds=False):
        """
        Args:
            path (str)
        """
        
        self.path = path
        self.features = []
        self.labels = []
        self.query_ids = []
        self.test_ds = test_ds
        
        self.dataset = defaultdict(list)
        
        # Generate dataset
        self._extract_raw_data(self.path)
        self._long_to_wide_transform()
        
        self.qids = list(self.dataset)
        

    def _extract_raw_data(self, data_path):
        """
        Extract data from data path
        Args:
            data_path (str): Path of the data file
        """
        
        print("Getting data from %s" % data_path)
        
        def _extract_features(toks):
            """Extract features from tokens (e.g. 1: 0 -> 0)"""
            features = []
            for tok in toks:
                features.append(float(tok.split(":")[1]))
            return features

        def _extract_query_data(tok):
            """Extract query features (e.g. qid: 10 -> 10)"""
            # qid
            query_features = [tok.split(":")[1]]
            return query_features
        
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                data, _, comment = line.rstrip().partition("#")
                toks = data.split()

                if not self.test_ds:
                    self.labels.append(int(toks[0]) + 1)             # label - The relevance score. +1 to make sure no 0 score
                
                self.features.append(_extract_features(toks[2:]))    # doc features
                self.query_ids.append(int(_extract_query_data(toks[1])[0]))  # qid
                
    def _long_to_wide_transform(self):
        """
        Transform long dataset to wide dataset
        """
        for item in zip(self.query_ids, self.features, self.labels):
            self.dataset[item[0]].append((torch.Tensor(np.array(item[1])), item[2]))
        
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        qid = self.qids[idx]
        sample = {"qid": qid,
                  "records": self.dataset[qid]}
        
        return sample

In [4]:
dataset = MSLR10KDataset(path="./data/MSLR-WEB10K/Fold1/train.txt")

Getting data from ./data/MSLR-WEB10K/Fold1/train.txt


In [5]:
val_dataset = MSLR10KDataset(path="./data/MSLR-WEB10K/Fold1/vali.txt")

Getting data from ./data/MSLR-WEB10K/Fold1/vali.txt


In [6]:
test_dataset = MSLR10KDataset(path="./data/MSLR-WEB10K/Fold1/test.txt", test_ds=True)

Getting data from ./data/MSLR-WEB10K/Fold1/test.txt


In [7]:
# dataset[0]["qid"]

In [8]:
len(dataset[0]["records"])

86

### DCG Function

In [9]:
def DCG(labels, rank=None, n=None):
    """
    Calculate DCG for the labels
    Given label, DCG = sum(2^(relevance) - 1) / log2(rank + 1)
    
    Args:
        labels (torch.Tensor) - Labels sorted in correct order
        rank (torch.Tensor) - Ranking of the labels
    """
    
    if rank is None:
        # Default ranking (1....n)
        rank = torch.arange(0, labels.size()[0]) + 1
    
    if n is not None and n <= len(labels):
        labels = labels[:n]
        rank = rank[:n]
    
    nom = (2 ** labels.view(-1, 1)) - 1
    denom = torch.log2(rank.float().view(-1, 1) + 1)
    
    return torch.sum(nom/denom)

In [10]:
# Test out DCG
records = list(zip(*dataset[0]["records"]))
labels = records[1]
DCG(torch.Tensor(labels))

tensor(52.1050)

In [11]:
print(labels)

(3, 3, 1, 3, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 3, 3, 1, 2, 1, 2, 3, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 3, 4, 1, 1, 1, 1, 1, 2, 1, 1, 1, 3, 2, 3, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 2, 2, 1)


#### Add into dataloader

In [12]:
dataloader = DataLoader(dataset, batch_size=1,
                        shuffle=True, num_workers=1)

In [13]:
validation_dataloader = DataLoader(val_dataset, batch_size=1,
                                   shuffle=True, num_workers=1)

In [38]:
# next(iter(validation_dataloader))

In [36]:
raw_val_records

'records'

## Setup the model

Reference: https://github.com/airalcorn2/RankNet/blob/master/lambdarank.py

In [14]:
import torch.nn as nn
import torch.optim as optim

In [41]:
class RankNet(nn.Module):
    """Pairwise Ranking Ranknet"""
    
    def __init__(self, num_features, num_layers=4, hidden_size=32):
        
        super().__init__()
        
        layer_list = []
        layer_list.append(nn.BatchNorm1d(num_features))
        layer_list.append(nn.Linear(num_features, hidden_size))
#         layer_list.append(nn.Dropout(0.5))
        layer_list.append(nn.ReLU())
        
        for l in range(num_layers - 1):
            layer_list.append(nn.Linear(hidden_size, hidden_size))
#             layer_list.append(nn.Dropout(0.5))
            layer_list.append(nn.ReLU())
            
        layer_list.append(nn.Linear(hidden_size, 1))
        
        self.model = nn.Sequential(*layer_list)
        self.output = nn.Sigmoid()
        
    def forward(self, input_i, input_j):
        si = self.model(input_i)
        sj = self.model(input_j)
        diff = si - sj
        prob = self.output(diff)
        return prob
    
    def predict(self, x, labels, sort=True):
        
        # Get scores
        scores = self.model(x)
        
        if sort:
            sorted_predictions, sorted_idx = torch.sort(scores, dim=0, descending=True)
            sorted_pred_labels = labels[sorted_idx]      
            return sorted_predictions, sorted_idx, sorted_pred_labels
        else:
            return scores

In [42]:
NUM_FEATURES = 136

#### Generate pairwise loss

In [27]:
def _get_scores_from_pairs(labels_i, labels_j):
    """
    If labels_i > labels_j => scores = 1
    elif labels_i == labels_j => scores = 0
    else score = -1
    
    Args:
        labels_i (torch.Tensor): Tensor of relevancy scores of sample i
        labels_j (torch.Tensor): Tensor of relevancy scores of sample j
    Return:
        (idx of i>j sets, idx of i==j sets, idx of i<j sets)
    """
    i_gt_j = (labels_i > labels_j).nonzero()
    i_eq_j = (labels_i == labels_j).nonzero()
    i_lt_j = (labels_i < labels_j).nonzero()
    
    return (i_gt_j, i_eq_j, i_lt_j)

In [28]:
def _delta_ndcg(labels_i, labels_j, rank_i, rank_j, max_dcg):
    """
    Calculate the difference in ndcg if swapping two documents
    """
    gain_i = ((2 ** (labels_i)) - 1) / max_dcg
    discount_i = torch.log2(1 + rank_i.float())

    gain_j = ((2 ** (labels_j)) - 1) / max_dcg
    discount_j = torch.log2(1 + rank_j.float())

    delta_ndcg = (torch.abs(gain_i - gain_j)) * (torch.abs((1/discount_i) - (1/discount_j)))
    
    return delta_ndcg

In [29]:
def _extract_features_labels(records, labels=True, sort_by_label=True):
    """
    Extract features and labels (if any) from records
    """
    record_features = torch.cat(list(records[0]))
    
    if labels:
        record_labels = torch.cat(list(records[1])).float()
        
        if sort_by_label:
            sorted_labels, sorted_labels_idx = torch.sort(torch.Tensor(record_labels), descending=True)
            return record_features, record_labels, sorted_labels, sorted_labels_idx
        else:
            return record_features, record_labels
    
    return record_features

In [30]:
losses = []

In [None]:
ranknet = RankNet(NUM_FEATURES, num_layers=3, hidden_size=64)
optimizer = optim.Adam(ranknet.parameters(), lr=1e-3)
BATCH_SIZE = 64

for epoch in range(5):
    
    # print statistics
    running_loss = 0.0
    running_ndcg5 = 0.0
    counter = 0
    pairs_comparison = 0
    
    for i_batch, sample_batched in enumerate(dataloader):

        # For every time only train one query
        num_records = len(sample_batched['records'])
        
        if num_records < 2:
            continue

        # Each record is a tuple (features, label)
        # Break down into lists
        records = list(zip(*sample_batched['records']))
#         record_features = torch.cat(list(records[0]))
#         record_labels = torch.cat(list(records[1])).float()
#         sorted_labels, sorted_labels_idx = torch.sort(torch.Tensor(record_labels), descending=True)
        record_features, record_labels, sorted_labels, sorted_labels_idx = _extract_features_labels(records, labels=True, sort_by_label=True)
        
        
        # Clear the memory
        del records
        gc.collect()

        # Calculate the dcg base on a query
        max_dcg = DCG(sorted_labels)
        max_dcg_5 = DCG(sorted_labels, n=5)

        # Generate combinations on the fly
        combo_ids = list(combinations(range(num_records), 2))
        random.shuffle(combo_ids)
        num_mini_batch = math.floor(len(combo_ids) / BATCH_SIZE)
        
        for j in range(num_mini_batch):
            
            random_sample = combo_ids[j * BATCH_SIZE:j * BATCH_SIZE + BATCH_SIZE]
            sample_i_idx, sample_j_idx = tuple(zip(*random_sample))    
            sample_i_idx, sample_j_idx = torch.Tensor(sample_i_idx).long(), torch.Tensor(sample_j_idx).long()
            
            # Obtain features
            features_i = record_features[sample_i_idx]
            features_j = record_features[sample_j_idx]
            labels_i = record_labels[sample_i_idx]
            labels_j = record_labels[sample_j_idx]   

            # Calculate scores
            scores = torch.where(labels_i > labels_j, torch.ones_like(labels_i), 
                                 torch.where(labels_i == labels_j, torch.zeros_like(labels_i), -torch.ones_like(labels_i)))
            y_bar = (1/2 * (1 + scores)).view(-1, 1)

            # Sort the scores by the current model
#             predicted_scores = ranknet.predict(record_features, sort=False)
            sorted_predictions, sorted_idx, sorted_pred_labels = ranknet.predict(record_features, record_labels, sort=True)
            dcg_5 = DCG(sorted_pred_labels, n=5)

            rank_i = torch.nonzero((sample_i_idx.view(-1, 1) == sorted_idx.flatten()))[:, 1] + 1
            rank_j = torch.nonzero((sample_j_idx.view(-1, 1) == sorted_idx.flatten()))[:, 1] + 1

            # Calculate delta NDCG            
            delta_ndcg = _delta_ndcg(labels_i, labels_j, rank_i, rank_j, max_dcg)

            criterion = nn.BCELoss(weight=delta_ndcg.view(-1, 1), reduction='sum')

            # Forward pass
            logits = ranknet.forward(features_i, features_j)
            loss = criterion(logits, y_bar)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.append(loss.item())
            running_loss += loss.item()
            running_ndcg5 += (dcg_5/max_dcg_5)
            counter += 1

        if i_batch % 100 == 99:
            pairs_comparison += counter
            
            # Run prediction from validation dataset
            val_ds = next(iter(validation_dataloader))
            val_records = list(zip(*val_ds['records']))
            val_record_features, val_record_labels, val_sorted_labels, val_sorted_labels_idx = _extract_features_labels(val_records, labels=True, sort_by_label=True)
            val_max_dcg_5 = DCG(val_sorted_labels, n=5)
            
            val_sorted_predictions, val_sorted_idx, val_sorted_pred_labels = ranknet.predict(val_record_features, val_record_labels, sort=True)
            val_dcg_5 = DCG(val_sorted_pred_labels, n=5)            
            
            
            print("\r[%d, %d] loss: %.5f; nDCG@5: %.5f; val-nDCG@5: %.5f, pairs_compared: %d"
                  % (epoch + 1,
                     i_batch + 1,
                     running_loss / counter,
                     running_ndcg5 / counter,
                     val_dcg_5/val_max_dcg_5,
                     pairs_comparison), end="")
            
            running_loss = 0.0
            running_ndcg5 = 0.0
            counter = 0

[1, 100] loss: 0.03365; nDCG@5: 0.73619; val-nDCG@5: 0.37869, pairs_compared: 12736