# RankNet

In [2]:
import torch

In [3]:
class RankNet(torch.nn.Module):
    """Pairwise model: scores two inputs with one shared network.

    The shared scorer is a two-layer perceptron that ends in a single
    output unit. ``forward`` returns the sigmoid of the difference between
    the scores of its two inputs; ``predict`` returns the raw score of one
    input.
    """

    def __init__(
        self,
        num_input_features: int,
        hidden_dim: int = 10,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Shared scorer: features -> hidden -> ReLU -> one output unit.
        scorer_layers = [
            torch.nn.Linear(num_input_features, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1),
        ]
        self.model = torch.nn.Sequential(*scorer_layers)
        self.out_activation = torch.nn.Sigmoid()

    def forward(
        self,
        input_1: torch.Tensor,
        input_2: torch.Tensor,
    ):
        """Return ``sigmoid(predict(input_1) - predict(input_2))``."""
        score_gap = self.predict(input_1) - self.predict(input_2)
        return self.out_activation(score_gap)

    def predict(self, x):
        """Return the raw (pre-sigmoid) score the shared network gives ``x``."""
        return self.model(x)

# ListNet

In [8]:
import sys
import numpy as np

from itertools import combinations

sys.path.append("../week01_metrics")
from metrics import (
    compute_gain,
    dcg,
    ndcg,
    precission_at_k,
    reciprocal_rank,
    p_found,
    num_swapped_pairs
)

In [9]:
class ListNet(torch.nn.Module):
    """Two-layer perceptron that maps each input row to one raw score.

    The network returns unnormalised logits; turning them into a
    distribution over a list (e.g. with a softmax) is left to the loss.
    """

    def __init__(
        self,
        num_input_features: int,
        hidden_dim: int = 10,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        # features -> hidden -> ReLU -> one output unit.
        scorer_layers = [
            torch.nn.Linear(num_input_features, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1),
        ]
        self.model = torch.nn.Sequential(*scorer_layers)

    def forward(
        self,
        input_1
    ):
        """Return the raw scores the network assigns to ``input_1``."""
        return self.model(input_1)

In [10]:
def listnet_ce_loss(y_i, z_i):
    """
    Cross-entropy between the softmax distributions of targets and preds.

    y_i: (n_i, 1) GT
    z_i: (n_i, 1) preds

    Both inputs are normalised with a softmax over dim 0 and the result is
    ``-sum(P_y * log(P_z))`` as a scalar tensor.
    """

    P_y_i = torch.softmax(y_i, dim=0)
    # log_softmax stays finite where softmax would underflow to exactly 0;
    # log(softmax(z)) would then produce -inf and an inf/NaN loss.
    log_P_z_i = torch.log_softmax(z_i, dim=0)
    return -torch.sum(P_y_i * log_P_z_i)

def listnet_kl_loss(y_i, z_i):
    """
    KL divergence KL(P_y || P_z) between softmax distributions.

    y_i: (n_i, 1) GT
    z_i: (n_i, 1) preds

    Both inputs are normalised with a softmax over dim 0 and the result is
    ``sum(P_y * (log(P_y) - log(P_z)))`` as a scalar tensor, which equals
    the original ``-sum(P_y * log(P_z / P_y))``.
    """
    # Work in log space: when a probability underflows to 0 the ratio
    # P_z / P_y becomes inf and 0 * log(inf) turns the whole loss into NaN.
    log_P_y_i = torch.log_softmax(y_i, dim=0)
    log_P_z_i = torch.log_softmax(z_i, dim=0)
    P_y_i = torch.softmax(y_i, dim=0)
    return torch.sum(P_y_i * (log_P_y_i - log_P_z_i))


def make_dataset(N_train, N_valid, vector_dim):
    """Generate a synthetic dataset with graded relevance labels.

    Features are drawn with ``torch.randn``. A random weight column gives
    each row a linear score, ``torch.randn_like`` noise is added, and the
    noisy score is bucketed with ``np.digitize`` at the edges
    ``[-1, 0, 1, 2]``, which produces the labels 0..4 stored as floats.

    Returns ``(X_train, X_valid, ys_train_rel, ys_valid_rel)``. Feature
    tensors have shape ``(N, vector_dim)`` and label tensors ``(N, 1)``.
    """
    # Four edges give 5 relevance grades; [-1, 1] would give 3 instead.
    relevance_edges = [-1, 0, 1, 2]

    def to_relevance(scores):
        # Bucket the continuous scores into relevance grades.
        bucket_ids = np.digitize(
            scores.clone().detach().numpy(), bins=relevance_edges
        )
        return torch.Tensor(bucket_ids)

    # Random draws happen in a fixed order (weights, train features, valid
    # features, train noise, valid noise) so a seeded run is reproducible.
    true_weights = torch.randn(vector_dim, 1)
    X_train = torch.randn(N_train, vector_dim)
    X_valid = torch.randn(N_valid, vector_dim)

    train_scores = torch.mm(X_train, true_weights)
    train_scores += torch.randn_like(train_scores)

    valid_scores = torch.mm(X_valid, true_weights)
    valid_scores += torch.randn_like(valid_scores)

    return X_train, X_valid, to_relevance(train_scores), to_relevance(valid_scores)

In [11]:
# Experiment configuration: dataset sizes, feature width and schedule.
N_train, N_valid = 1000, 500
vector_dim = 100
epochs = 2
batch_size = 16

# Build the synthetic data before the model so the order of random draws
# (dataset first, then layer initialisation) stays the same.
X_train, X_valid, ys_train, ys_valid = make_dataset(
    N_train=N_train,
    N_valid=N_valid,
    vector_dim=vector_dim,
)

# ListNet scorer trained with Adam at its default hyper-parameters.
net = ListNet(vector_dim)
opt = torch.optim.Adam(params=net.parameters())


In [12]:
# Train ListNet with the KL listwise loss; each mini-batch is passed to the
# loss as a single list. Validation metrics are printed every 10 iterations.
for epoch in range(epochs):
    # Reshuffle the training set at the start of every epoch.
    idx = torch.randperm(N_train)

    X_train = X_train[idx]
    ys_train = ys_train[idx]

    cur_batch = 0
    # Only full batches are visited: the trailing N_train % batch_size rows
    # of the shuffled data are skipped, so no batch can be empty.
    for it in range(N_train // batch_size):
        batch_X = X_train[cur_batch: cur_batch + batch_size]
        batch_ys = ys_train[cur_batch: cur_batch + batch_size]
        cur_batch += batch_size

        opt.zero_grad()
        batch_pred = net(batch_X)
        batch_loss = listnet_kl_loss(batch_ys, batch_pred)
        # Alternative objective: listnet_ce_loss(batch_ys, batch_pred)
        # Every iteration runs a fresh forward pass, so the autograd graph
        # does not need to be retained after backward.
        batch_loss.backward()
        opt.step()

        if it % 10 == 0:
            with torch.no_grad():
                valid_pred = net(X_valid)
                # NOTE(review): the recorded run prints 0 swapped pairs at
                # every checkpoint while nDCG keeps changing. Confirm that
                # num_swapped_pairs accepts (N, 1)-shaped tensors.
                valid_swapped_pairs = num_swapped_pairs(ys_valid, valid_pred)
                ndcg_score = ndcg(ys_valid, valid_pred)
            print(f"epoch: {epoch + 1}.\tNumber of swapped pairs: "
                  f"{valid_swapped_pairs}/{N_valid * (N_valid - 1) // 2}\t"
                  f"nDCG: {ndcg_score:.4f}")

epoch: 1.	Number of swapped pairs: 0/124750	nDCG: 0.8010
epoch: 1.	Number of swapped pairs: 0/124750	nDCG: 0.8233
epoch: 1.	Number of swapped pairs: 0/124750	nDCG: 0.8345
epoch: 1.	Number of swapped pairs: 0/124750	nDCG: 0.8499
epoch: 1.	Number of swapped pairs: 0/124750	nDCG: 0.8631
epoch: 1.	Number of swapped pairs: 0/124750	nDCG: 0.8792
epoch: 1.	Number of swapped pairs: 0/124750	nDCG: 0.8911
epoch: 2.	Number of swapped pairs: 0/124750	nDCG: 0.8929
epoch: 2.	Number of swapped pairs: 0/124750	nDCG: 0.9058
epoch: 2.	Number of swapped pairs: 0/124750	nDCG: 0.9174
epoch: 2.	Number of swapped pairs: 0/124750	nDCG: 0.9301
epoch: 2.	Number of swapped pairs: 0/124750	nDCG: 0.9403
epoch: 2.	Number of swapped pairs: 0/124750	nDCG: 0.9482
epoch: 2.	Number of swapped pairs: 0/124750	nDCG: 0.9535
