In [21]:
from itertools import combinations
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd


df = pd.read_csv('train_cleaned_small.csv')
df_val = pd.read_csv('val_cleaned_small.csv')

In [22]:
def dcg(scores):
    scores = np.array(scores,dtype = float)
    num = 2**scores-1
    for i in range(len(num)):
        num[i] /= np.log2(i+2)
    return np.sum(num)


def ndcg_k(scores, k):
    top_k = scores[:k]
    ideal_top_k = sorted(scores)[::-1][:k]
    ndcg = dcg(top_k)
    indcg = dcg(ideal_top_k)
    return ndcg/indcg

In [23]:
def clean_data(x):
    query_ids = sorted(list(set(x['query_id'])))
    ys_train = np.array(x[x['query_id']==query_ids[0]]['relevance_label'].tolist())
    ys_train_final = []
    ys_train_final.append(ys_train)
    for i in range(len(query_ids)):
        if i == 0:
            continue

        y_new = np.array(x[x['query_id']==query_ids[i]]['relevance_label'].tolist())
        ys_train_final.append(y_new)
    

    ys_train = torch.tensor(ys_train_final,dtype=torch.float32)
    
    X = np.array(x[x['query_id']==query_ids[0]].iloc[:,2:])
    x_train_final = []
    x_train_final.append(X)
    for i in range(len(query_ids)):
        if i == 0:
            continue

        x_new = np.array(x[x['query_id']==query_ids[i]].iloc[:,2:])
        x_train_final.append(x_new)
    

    X_train = torch.tensor(np.array(x_train_final),dtype=torch.float32)
    
    return X_train,ys_train


In [25]:
X_train,ys_train = clean_data(df)

In [26]:
import torch
import torch.nn as nn
import torch.optim as optim

class Pointwise(nn.Module):
    def __init__(self, input_dim, hidden_dim, hidden_dim2, output_dim):
        super(Pointwise, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim,hidden_dim2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim2, output_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        return x


In [27]:
class_counts = torch.bincount(torch.tensor(ys_train.flatten(),dtype=torch.int64))
class_weights = 1.0 / class_counts.float()

  class_counts = torch.bincount(torch.tensor(ys_train.flatten(),dtype=torch.int64))


In [33]:
class_weights = class_weights / class_weights.sum()
class_weights

tensor([0.0065, 0.0146, 0.0319, 0.2841, 0.6629])

In [37]:
df['relevance_label'].value_counts()

0    5489
1    2458
2    1123
3     126
4      54
Name: relevance_label, dtype: int64

In [38]:
input_dim = 136
hidden_dim = 512
hidden_dim2 = 256
output_dim = 5
learning_rate = 0.01
num_epochs = 1

# Initialize model and optimizer
model = Pointwise(input_dim, hidden_dim,hidden_dim2, output_dim)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss(weight = class_weights)


# Train model
for q in range(X_train.shape[0]): ## For each query
    X = X_train[q]
    Y = torch.tensor(ys_train[q],dtype = torch.long)
    for epoch in range(1):
        optimizer.zero_grad()
        output = model(X)
        loss = loss_fn(output,Y)
        loss.backward()
        optimizer.step()
            
    

  Y = torch.tensor(ys_train[q],dtype = torch.long)


In [39]:
n = X_train.shape[0]
ndcg_list = []
max_ndcg = 0
for i in range(n): ## over queries
    scores_list = []
    Y = ys_train[i]
    for j in range(X_train[i].shape[0]):  ## over documents
        output = torch.argmax(model(X_train[i][j]))
        scores_list.append(float(output))
    rank_pred = np.argsort(scores_list)[::-1].tolist()
    rank_score = Y[rank_pred]
    ndcg = ndcg_k(rank_score, 10)
    ndcg_list.append(ndcg)

tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)


tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)


tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)


tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)


tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)


tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)


tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)


  return ndcg/indcg


tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)


tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)


tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(2)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(2)
tensor(2)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)


In [40]:
np.nanmean(ndcg_list)

0.3278127127822598