In [1]:
import torch

### RankNet implementation

In [2]:
class RankNet(torch.nn.Module):
    def __init__(self, num_input_features, hidden_dim=10):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.model = torch.nn.Sequential(
            torch.nn.Linear(num_input_features, self.hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_dim, 1)
        )
        
        self.out_activation = torch.nn.Sigmoid()
        
    def forward(self, input_1, input_2):
        logits_1 = self.predict(input_1)
        logits_2 = self.predict(input_2)
        
        logits_diff = logits_1 - logits_2
        out = self.out_activation(logits_diff)
        
        return out
    
    def predict(self, inp):
        logits = self.model(inp)
        return logits

In [3]:
ranknet_model = RankNet(num_input_features=10)

In [4]:
inp_1, inp_2 = torch.rand(4, 10), torch.rand(4, 10)
# batch_size x input_dim

In [5]:
preds = ranknet_model(inp_1, inp_2)
preds

tensor([[0.5412],
        [0.5100],
        [0.5285],
        [0.4553]], grad_fn=<SigmoidBackward0>)

In [6]:
first_linear_layer = ranknet_model.model[0]

In [9]:
first_linear_layer.weight.grad

In [10]:
criterion = torch.nn.BCELoss()
loss = criterion(preds, torch.ones_like(preds))
loss.backward()

In [11]:
first_linear_layer.weight.grad

tensor([[-0.0415, -0.0078, -0.0061,  0.0021,  0.0235, -0.0311, -0.0283,  0.0069,
         -0.0027, -0.0071],
        [ 0.0221, -0.0022, -0.0170, -0.0241, -0.0385, -0.0046, -0.0084, -0.0124,
         -0.0289,  0.0154],
        [-0.0672, -0.0125, -0.0098,  0.0035,  0.0379, -0.0504, -0.0458,  0.0112,
         -0.0044, -0.0115],
        [-0.0093,  0.0084,  0.0110, -0.0041,  0.0143, -0.0028,  0.0029,  0.0048,
          0.0039,  0.0021],
        [ 0.0128,  0.0024,  0.0019, -0.0007, -0.0072,  0.0096,  0.0087, -0.0021,
          0.0008,  0.0022],
        [ 0.0420,  0.0281,  0.0188,  0.0062, -0.0091,  0.0371,  0.0296,  0.0127,
          0.0197,  0.0173],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0357,  0.0067,  0.0052, -0.0018, -0.0202,  0.0268,  0.0243, -0.0060,
          0.0023,  0.0061],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  

In [12]:
ranknet_model.zero_grad()

### ListNet

In [71]:
from torch import Tensor
from itertools import combinations
import numpy as np

def compute_gain(y_value: float, gain_scheme: str) -> float:
    if gain_scheme == 'exp2':
        gain = 2**y_value -  1
    elif gain_scheme == 'const':
        gain = y_value
    else:
        raise ValueError(f'{gain_scheme} method not supported, only exp2 and const')
    return gain

from math import log2

def dcg(ys_true: Tensor, ys_pred: Tensor, gain_scheme: str) -> float:
    _, argsort = sort(ys_pred, descending=True,dim=0)
    ys_true_sorted = ys_true[argsort]
    ret = 0
    for idx, cur_y in enumerate(ys_true_sorted, 1):
        gain = compute_gain(cur_y.item(), gain_scheme)
        ret += gain / log2(idx + 1)
    return ret

def num_swapped_pairs(ys_true: Tensor, ys_pred: Tensor) -> int:
    
    ys_pred_sorted, argsort = sort(ys_pred, descending=True, dim = 0)
    ys_true_sorted = ys_true[argsort]
    
    num_objects = ys_true_sorted.shape[0]
    swapped_cnt = 0
    for curr_obj in range(num_objects-1):
        for next_obj in range(curr_obj+1, num_objects):
            if ys_true_sorted[curr_obj] < ys_true_sorted[next_obj]:
                if ys_pred_sorted[curr_obj] > ys_pred_sorted[next_obj]:
                    swapped_cnt += 1
            #elif ys_true_sorted[cur_obj] > ys_true_sorted[next_obj]:
            #    if ys_pred_sorted[curr_obj] < ys_pred_sorted[next_obj]:
            #        swapped_cnt += 1
    
    return swapped_cnt

def ndcg(ys_true: Tensor,  ys_pred: Tensor, gain_scheme: str = 'const') -> float:
    dcg_ = dcg(ys_true, ys_pred, gain_scheme)
    ideal_dcg =  dcg(ys_true, ys_true, gain_scheme)
    return dcg_ / ideal_dcg

In [67]:
class ListNet(torch.nn.Module):
    def __init__(self, num_input_features, hidden_dim=10):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.model = torch.nn.Sequential(
            torch.nn.Linear(num_input_features, self.hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_dim, 1)
        )
        
    def forward(self, input_1):
        logits = self.model(input_1)
        return logits

In [40]:
test_arr = np.array([-2, 0.7, 2, 3])
np.digitize(test_arr, [-1, 1])

array([0, 1, 2, 2])

In [41]:
def listnet_ce_loss(y_i, z_i):
    '''
    y_i: (n_i, 1) GT
    z_i: (n_i, 1) preds
    '''
    P_y_i = torch.softmax(y_i, dim=0)
    P_z_i = torch.softmax(z_i, dim=0)
    return -torch.sum(P_y_i * torch.log(P_z_i/P_y_i))

def listnet_kl_loss(y_i, z_i):
    '''
    y_i: (n_i, 1) GT
    z_i: (n_i, 1) preds
    '''
    P_y_i = torch.softmax(y_i, dim=0)
    P_z_i = torch.softmax(z_i, dim=0)
    return -torch.sum(P_y_i * torch.log(P_z_i/P_y_i))

def make_dataset(N_train, N_valid, vector_dim):
    fake_weights = torch.randn(vector_dim, 1)
    
    X_train = torch.randn(N_train, vector_dim)
    X_valid = torch.randn(N_valid, vector_dim)
    
    ys_train_score = torch.mm(X_train, fake_weights)
    ys_train_score += torch.randn_like(ys_train_score)
    
    ys_valid_score = torch.mm(X_train, fake_weights)
    ys_valid_score += torch.randn_like(ys_valid_score)
    
    bins = [-1, 1] # 3 relevances
    # bins = [-1, 0, 1, 2] - 5 relevances
    ys_train_rel = torch.Tensor(
         np.digitize(ys_train_score.clone().detach().numpy(), bins=bins)
    )
    ys_valid_rel = torch.Tensor(
        np.digitize(ys_valid_score.clone().detach().numpy(), bins=bins)
    )
    
    return X_train, X_valid, ys_train_rel, ys_valid_rel

In [42]:
N_train = 1000
N_valid = 500

vector_dim = 100
epochs = 2

batch_size = 16

X_train, X_valid, ys_train, ys_valid = make_dataset(N_train, N_valid, vector_dim)

net = ListNet(num_input_features=vector_dim)
opt = torch.optim.Adam(net.parameters())

In [43]:
torch.unique(ys_train)

tensor([0., 1., 2.])

In [44]:
torch.randperm(3)

tensor([2, 1, 0])

In [45]:
from torch import sort

In [72]:
for epoch in range(epochs):
    idx = torch.randperm(N_train)
    
    X_train = X_train[idx]
    ys_train = ys_train[idx]
    
    cur_batch=0
    for it in range(N_train // batch_size):
        batch_X = X_train[cur_batch:cur_batch + batch_size]
        batch_ys = ys_train[cur_batch: cur_batch + batch_size]
        cur_batch += batch_size
        
        opt.zero_grad()
        if len(batch_X) > 0:
            batch_pred = net(batch_X)
            batch_loss = listnet_ce_loss(batch_ys, batch_pred)
            batch_loss.backward(retain_graph=True)
            opt.step()
            
        if it % 10 == 0:
            with torch.no_grad():
                valid_pred = net(X_valid)
                valid_swapped_pairs = num_swapped_pairs(ys_valid, valid_pred)
                ndcg_score = ndcg(ys_valid, valid_pred)
            print(f'epoch: {epoch+1}.\tNumber of swapped pairs: ',
                 f'{valid_swapped_pairs}/{N_valid*(N_valid-1) // 2}\t',
                 f'nDCG: {ndcg_score:.4f}')

epoch: 1.	Number of swapped pairs:  35553/124750	 nDCG: 0.4751
epoch: 1.	Number of swapped pairs:  35518/124750	 nDCG: 0.4739
epoch: 1.	Number of swapped pairs:  35376/124750	 nDCG: 0.4741
epoch: 1.	Number of swapped pairs:  35332/124750	 nDCG: 0.4735
epoch: 1.	Number of swapped pairs:  35125/124750	 nDCG: 0.4759
epoch: 1.	Number of swapped pairs:  35109/124750	 nDCG: 0.4763
epoch: 1.	Number of swapped pairs:  35079/124750	 nDCG: 0.4766
epoch: 2.	Number of swapped pairs:  35089/124750	 nDCG: 0.4766
epoch: 2.	Number of swapped pairs:  35092/124750	 nDCG: 0.4762
epoch: 2.	Number of swapped pairs:  35145/124750	 nDCG: 0.4756
epoch: 2.	Number of swapped pairs:  35119/124750	 nDCG: 0.4758
epoch: 2.	Number of swapped pairs:  35037/124750	 nDCG: 0.4760
epoch: 2.	Number of swapped pairs:  34953/124750	 nDCG: 0.4758
epoch: 2.	Number of swapped pairs:  34848/124750	 nDCG: 0.4759
