# Learning to rank notebook

Reference:   
[1] [From RankNet to LambdaRank to LambdaMART: An Overview](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.180.634&rep=rep1&type=pdf)

In [None]:
import pandas as pd
import numpy as np

from collections import defaultdict

## Set up a data loader

In [None]:
import torch

from torch.utils.data import Dataset, DataLoader

In [None]:
class MSLR10KDataset(Dataset):
    """Pairwise view of an MSLR-WEB10K style ranking file.

    Every line of the file is expected to look like
    ``<label> qid:<id> 1:<v1> 2:<v2> ...``, optionally followed by a
    ``# comment``. Every two documents that belong to the same query form one
    sample, ordered so that the first document of the pair is at least as
    relevant as the second one.
    """

    def __init__(self, path):
        """
        Load the file and build all document pairs eagerly.

        Args:
            path (str): Path of the data file

        Attributes:
            features (list of list): Feature values of every document
            labels (list of int): Relevance label of every document
            query_ids (list of list): One-element list with the qid of every document
            pairs (list of tuples): Document indices (i, j) of every compared pair
            scores (list of int): 1 if document i is more relevant than j, 0 on a tie
            i_features (list of list): Feature list of the ith document of every pair
            j_features (list of list): Feature list of the jth document of every pair
        """

        self.path = path
        self.features = []
        self.labels = []
        self.query_ids = []

        # Generate dataset
        self._get_format_data(self.path)
        self.pairs, self.scores, self.i_features, self.j_features = \
            self._get_pair_doc_data(self.labels, self.query_ids)

    def _get_format_data(self, data_path):
        """
        Parse the data file into ``self.labels``, ``self.query_ids`` and
        ``self.features`` (one entry per document, in file order).

        Args:
            data_path (str): Path of the data file
        """

        print("Getting data from %s" % data_path)

        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Everything after "#" is a trailing comment
                toks = line.partition("#")[0].split()

                # Blank and comment-only lines carry no document
                if not toks:
                    continue

                # label - The relevance score
                self.labels.append(int(toks[0]))
                # qid (e.g. qid:10 -> ["10"]), kept as a one-element list
                self.query_ids.append([toks[1].split(":")[1]])
                # doc features (e.g. 1:0.5 -> 0.5)
                self.features.append(
                    [float(tok.split(":")[1]) for tok in toks[2:]])

    def _get_pair_doc_data(self, y_train, query_id):
        """
        Get pairs data

        Documents of one query must be stored contiguously: the scan for
        partners of document i stops at the first document with another qid.

        Args:
            y_train (list): List of relevance score
            query_id (list): List of query_id

        Returns:
            tuple: (pairs, scores, i_features, j_features), four lists of equal
            length. In every pair the more relevant document comes first
            (score 1); equally relevant documents keep file order (score 0).
        """
        pairs = []
        scores = []
        i_features = []
        j_features = []

        n_docs = len(query_id)
        for i in range(n_docs - 1):
            for j in range(i + 1, n_docs):

                # Make sure the documents are for the same query id
                if query_id[i][0] != query_id[j][0]:
                    break

                # Put the more relevant document first; ties keep file order
                if y_train[i] < y_train[j]:
                    better, worse = j, i
                else:
                    better, worse = i, j

                pairs.append((better, worse))
                i_features.append(self.features[better])
                j_features.append(self.features[worse])
                scores.append(0 if y_train[i] == y_train[j] else 1)

        return pairs, scores, i_features, j_features

    def __len__(self):
        # One sample per pair: __getitem__ indexes the pair lists, so the
        # length has to be the number of pairs, not the number of documents.
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the idx-th pair as a dict of its indices, features and score."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = {"pairs": self.pairs[idx],
                  "i_features": torch.tensor(np.array(self.i_features[idx])),
                  "j_features": torch.tensor(np.array(self.j_features[idx])),
                  "scores": torch.tensor(self.scores[idx])}

        return sample

In [None]:
# Parse the Fold1 training file and build every same-query document pair up front.
dataset = MSLR10KDataset(path="./data/MSLR-WEB10K/Fold1/train.txt")

# Set up the model

In [None]:
import torch.nn as nn
import torch.optim as optim

In [None]:
class RankNet(nn.Module):
    """Pairwise Ranking Ranknet

    A small feed-forward network maps the feature vector of one document to a
    single score. A pair of documents is turned into the probability that
    document i ranks above document j through sigmoid(s_i - s_j).
    """

    def __init__(self, num_features, hidden_size_1=32, hidden_size_2=16,
                 dropout=0.5):
        """
        Args:
            num_features (int): Length of one document's feature vector
            hidden_size_1 (int): Width of the first hidden layer
            hidden_size_2 (int): Width of the second hidden layer
            dropout (float): Dropout probability applied after each hidden
                linear layer
        """

        super().__init__()

        # The layer order is kept as is so the positions inside the
        # Sequential (and with them the state_dict keys) do not move.
        self.model = nn.Sequential(nn.Linear(num_features, hidden_size_1),
                                   nn.Dropout(dropout),
                                   nn.ReLU(),
                                   nn.Linear(hidden_size_1, hidden_size_2),
                                   nn.Dropout(dropout),
                                   nn.ReLU(),
                                   nn.Linear(hidden_size_2, 1))
        self.output = nn.Sigmoid()

    def forward(self, input_i, input_j):
        """
        Args:
            input_i (Tensor): Features of the first document of every pair
            input_j (Tensor): Features of the second document of every pair

        Returns:
            Tensor: sigmoid(s_i - s_j), the modelled probability that
            document i should be ranked above document j
        """
        score_i = self.model(input_i)
        score_j = self.model(input_j)
        return self.output(score_i - score_j)

    def predict(self, x):
        """
        Score documents for ranking.

        The network is evaluated in eval mode so that dropout cannot perturb
        the scores; the mode the module was in before the call is restored
        afterwards. Gradients are still tracked.

        Args:
            x (Tensor): Features of the documents to score

        Returns:
            Tensor: One raw (pre-sigmoid) score per document
        """
        was_training = self.training
        self.eval()
        try:
            return self.model(x)
        finally:
            self.train(was_training)

In [None]:
# num_features has to equal the length of each feature vector fed to the network
# (136 here, presumably the MSLR-WEB10K feature count; confirm against the data file).
ranknet = RankNet(num_features=136)

In [None]:
# Mini-batches of 128 samples, reshuffled every epoch and loaded by 4 worker processes.
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

In [None]:
# Binary cross-entropy on the pair probability returned by RankNet.forward,
# minimised with Adam over all parameters of the network.
criterion = nn.BCELoss()
optimizer = optim.Adam(ranknet.parameters(), lr=1e-2)

In [None]:
# NOTE(review): nothing in the visible cells appends to this list; presumably
# it is meant to hold the loss history. Confirm against the later cells.
losses = []

In [None]:
# Train RankNet for 30 epochs on the pairwise samples.
ranknet.train()  # dropout must be active while fitting

for epoch in range(30):

    # Loss accumulated since the last progress report
    running_loss = 0.0

    for i, sample_batched in enumerate(dataloader):
        i_features = sample_batched["i_features"].float()
        j_features = sample_batched["j_features"].float()

        # The dataset yields S_ij: 1 when document i is more relevant than j
        # and 0 when both are equally relevant. RankNet's target probability
        # is (1 + S_ij) / 2, so a tie has to be trained towards 0.5; feeding
        # S_ij straight into BCELoss would teach "j beats i" for every tie.
        s_ij = sample_batched["scores"].view(-1, 1).float()
        labels = 0.5 * (1.0 + s_ij)

        # Forward pass (through __call__ so that module hooks run)
        outputs = ranknet(i_features, j_features)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 99:
            # report the mean loss of the last 100 iterations
            print("\r[%d, %5d] loss: %.3f" % (epoch + 1,
                                              i + 1,
                                              running_loss / 100), end="")
            running_loss = 0.0