In [1]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [2]:
import re, string, unicodedata
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import pandas as pd
import numpy as np
# import smart_open
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

import math
from math import log
from numpy.random import default_rng, rand


In [226]:
train_data = pd.read_csv("../part2/train_data.tsv", sep='\t')
validation_data = pd.read_csv("../part2/validation_data.tsv", sep='\t')


In [227]:
train_data.head(10)

Unnamed: 0,qid,pid,queries,passage,relevancy
0,188714,1000052,foods and supplements to lower blood sugar,Watch portion sizes: ■ Even healthy foods will...,0.0
1,995526,1000094,where is the federal penitentiary in ind,It takes THOUSANDS of Macy's associates to bri...,0.0
2,660957,1000115,what foods are good if you have gout?,The good news is that you will discover what g...,0.0
3,837202,1000252,what is the nutritional value of oatmeal,"Oats make an easy, balanced breakfast. One cup...",0.0
4,130825,1000268,definition for daring,Such a requirement would have three desirable ...,0.0
5,408149,1000288,is dhgate a scam,If you think you ve been targeted by a counter...,0.0
6,1019649,1000419,what study for mets to brain,Sorry he's having so much pain. The reason tha...,0.0
7,1099065,1000436,how far deep to plant beet early wonder,"The simplest way, and my preference, is to roa...",0.0
8,1084910,1000466,what disease do roof rats cause,1 A cage trap baited with peanut butter or a s...,0.0
9,959083,1000479,when was niagara falls created,"Bulbar Onset – ALS. ALS is like Niagara Falls,...",0.0


In [228]:
train_data.shape

(4364339, 5)

In [229]:
validation_data.head(10)

Unnamed: 0,qid,pid,queries,passage,relevancy
0,1082792,1000084,what does the golgi apparatus do to the protei...,"Start studying Bonding, Carbs, Proteins, Lipid...",0.0
1,995825,1000492,where is the graphic card located in the cpu,"For example, a “PC Expansion Card” maybe the j...",0.0
2,995825,1000494,where is the graphic card located in the cpu,The Common Cards & Buses. The most common type...,0.0
3,1091246,1000522,property premises meaning,The occurrence of since tells us that the firs...,0.0
4,1047854,1000585,what is printing mechanism,Windows desktop applications Develop Desktop t...,0.0
5,991832,1000599,who discovered the element carbon,1. 1 a nonmetallic element existing in the th...,0.0
6,185299,1000647,fastest cell phone processor,Tips for calling a cell phone in Greece: To ca...,0.0
7,574730,1000663,what are the three monetary policy tools of th...,"Federal Reserve updates including rates, news ...",0.0
8,1085008,1000675,what did maria theresa do for the serfs,"In this feudal system, the king awarded land g...",0.0
9,609628,1000771,what county is mitchell south dakota in,South Dakota: According to our research of Sou...,0.0


In [230]:
train_data.shape
validation_data.shape

(4364339, 5)

(1103039, 5)

# Preprocessing

In [231]:
def remove_punctuation(tokens):
    new_tokens = []
    for token in tokens:
        new_token = re.sub(r'[^\w\s]', '', token)
        if new_token != '':
            new_tokens.append(new_token)
    return new_tokens

def remove_stopwords(tokens):
    new_tokens = []
    stopword_set = set(stopwords.words('english'))
    for token in tokens:
        if token not in stopword_set:
            new_tokens.append(token)
    return new_tokens

def lemmatize_verbs(tokens):
    lemmatizer = WordNetLemmatizer()
    root_words = []
    for token in tokens:
        root_word = lemmatizer.lemmatize(token, pos='v')
#         root_word = lemmatizer.lemmatize(token, pos='n')
#         root_word = lemmatizer.lemmatize(token, pos='a')
        root_words.append(root_word)
    return root_words

def remove_numbers(tokens):
    """Replace all interger occurrences in list of tokenized words with textual representation"""
    new_tokens = []
    for token in tokens:
        if token.isdigit():
            pass
        else:
            new_tokens.append(token)
    return new_tokens

def preprocessing(passage):
    passage = passage.lower()
    tokens = nltk.word_tokenize(passage)
    tokens = remove_punctuation(tokens)
    tokens = remove_stopwords(tokens)
    tokens = lemmatize_verbs(tokens)
    tokens = remove_numbers(tokens)
    return tokens

# Subtask 1

## BM25

In [163]:
validation_data_no_dup_passages = validation_data.drop_duplicates(subset=['pid'], inplace=False)


In [164]:
# validation_data_no_dup_passages.head(20)
validation_data_no_dup_passages.shape

(955211, 5)

In [165]:
def get_passage_average_length_and_total_word_occurences_corpus():
    number_of_passages = len(validation_data_no_dup_passages)
    count_total_length = 0
    for idx, row in validation_data_no_dup_passages.iterrows():
        count_total_length += len(preprocessing(row['passage']))
    return count_total_length, count_total_length/number_of_passages 


In [166]:
# total_word_occurences, avdl = get_passage_average_length_and_total_word_occurences_corpus()
total_word_occurences = 30757932 # for validation data
avdl = 32.200144261320276 # for validation data

In [167]:
# total_word_occurences
# avdl

In [168]:
from math import log

k1 = 1.2
k2 = 100
b = 0.75
R = 0
r = 0
N = len(validation_data_no_dup_passages)

def K_cal(dl):
    return k1 * ((1-b) + b * (float(dl)/float(avdl)) )


def BM25_cal(query, passage):
    query_tokens = preprocessing(query)
    passage_tokens = preprocessing(passage)
    query_length = len(query_tokens)
    query_token_freq_dict = nltk.FreqDist(query_tokens)
    passage_token_freq_dict = nltk.FreqDist(passage_tokens)
    dl = len(passage_tokens)
    K = K_cal(dl)
    score = 0
    for token in query_tokens:
        try:
            n = len(inverted_index[token])
        except:
            n = 0
        f = passage_token_freq_dict[token]
        qf = query_token_freq_dict[token]
        first_term = log( ( (r + 0.5) / (R - r + 0.5) ) / ( (n - r + 0.5) / (N - n - R + r + 0.5)) )
        second_term = ((k1 + 1) * f) / (K + f)
        third_term = ((k2+1) * qf) / (k2 + qf)
        score += first_term * second_term * third_term
    return score
    

In [169]:
bm25_rankings = []
for idx, row in validation_data.iterrows():
#     print('count:', idx+1)
    query = row['queries']
    passage = row['passage']
    bm25_rankings.append(BM25_cal(query, passage))

    

In [170]:
ranking_k = 100
results_bm25 = np.array(bm25_rankings).argsort()[-ranking_k:][::-1]
results_bm25

array([ 218534, 1076853,   95986,   95987,  831065,  725920,  504442,
        654647,  341262,  281695,  950916,  586050,  871938,  793356,
        944762,  714508,  900185,  206973,  629358,  701682,  248618,
        259324,  702792,  286327,  205094,   98221,  519566,  242694,
        705667,  922032,  197484,  431189,  419681,  237993,  502543,
        793493,  646679,  559150,  303774,  638714,  778563,  155707,
         45281,  132857,  213755,  299424,  252540,    9338, 1075869,
        139233,  733889,  401291,  373812,  784890, 1008311,  911301,
        909449,  558879,  857483,  602553,  171827,  217514,  857646,
        539660,  195886, 1078616,  554316,  872234,  346989,   28250,
       1026344,  330458, 1026153,  612051,  395691,  530538,  313678,
        693646,  175037,  751525,  216400,  663797,   99840,  928996,
          5558,  428329,  546185,   46827,  112665,  118094,  752555,
       1017110,  455444,  910607,  648719,  793681, 1063650,  212436,
        415429,  654

In [171]:
ranking_list_df = validation_data.loc[results_bm25]
ranking_list_df

Unnamed: 0,qid,pid,queries,passage,relevancy
218534,1007691,7251254,"when allocating service department costs, the ...",Direct method allocates each service departmen...,1.0
1076853,1007691,7251259,"when allocating service department costs, the ...",The direct method is the most widely-used meth...,0.0
95986,1007691,7251251,"when allocating service department costs, the ...",service department provides a large amount of ...,0.0
95987,1007691,7251253,"when allocating service department costs, the ...","The rows sum to 100%, so that all services pro...",0.0
831065,1007691,7251255,"when allocating service department costs, the ...",The most defensible sequence is to start with ...,0.0
...,...,...,...,...,...
793681,1007691,4814576,"when allocating service department costs, the ...",Service Members | Veterans | Both. Military On...,0.0
1063650,1007691,6395207,"when allocating service department costs, the ...",that hospital emergency department services ar...,0.0
212436,1007691,6872353,"when allocating service department costs, the ...",Yelp Customer Service customer service phone n...,0.0
415429,1007691,4114248,"when allocating service department costs, the ...","A service fee, service charge, or surcharge is...",0.0


In [172]:
validation_data[(validation_data['relevancy'] < 1.0) & (validation_data['relevancy'] > 0)]


Unnamed: 0,qid,pid,queries,passage,relevancy


In [149]:
def average_precision_cal(ranking_list_df):
    ranking_list_df = ranking_list_df.reset_index(drop=True, inplace=False)
    total_relevant_retrieved = 0
    precision_sum = 0
    for idx, row in ranking_list_df.iterrows():
        relevancy = row['relevancy']
        if (relevancy):
#             isRelevant = True
            total_relevant_retrieved += 1
            precision = total_relevant_retrieved / (idx + 1)
            precision_sum += precision
    result = precision_sum / len(ranking_list_df)
    return result

In [174]:
average_precision_cal(ranking_list_df)

0.011818181818181818

In [175]:
ranking_list_df.sort_values(by=['relevancy'], ascending=False)


Unnamed: 0,qid,pid,queries,passage,relevancy
218534,1007691,7251254,"when allocating service department costs, the ...",Direct method allocates each service departmen...,1.0
950916,1089945,7079883,the __________ test is a quick and dirty test ...,• The Smell Test is familiar ground in most bu...,1.0
539660,1007691,423230,"when allocating service department costs, the ...",IT Service (ITILv3): A Service provided to one...,0.0
612051,1007691,994382,"when allocating service department costs, the ...",All UK telephone numbers beginning with the di...,0.0
1026153,1007691,3941750,"when allocating service department costs, the ...",Respite (Out-of-Home) Services [edit]. Respite...,0.0
...,...,...,...,...,...
431189,1007691,5146501,"when allocating service department costs, the ...",The Department offers service coordination and...,0.0
197484,1007691,5904987,"when allocating service department costs, the ...",If you are a delinquent juror who has been ins...,0.0
922032,1007691,5220119,"when allocating service department costs, the ...",There are different types of customer service ...,0.0
705667,1007691,7088169,"when allocating service department costs, the ...",Us Postal Service Customer Service Phone Numbe...,0.0


In [176]:
ranking_list_df

Unnamed: 0,qid,pid,queries,passage,relevancy
218534,1007691,7251254,"when allocating service department costs, the ...",Direct method allocates each service departmen...,1.0
1076853,1007691,7251259,"when allocating service department costs, the ...",The direct method is the most widely-used meth...,0.0
95986,1007691,7251251,"when allocating service department costs, the ...",service department provides a large amount of ...,0.0
95987,1007691,7251253,"when allocating service department costs, the ...","The rows sum to 100%, so that all services pro...",0.0
831065,1007691,7251255,"when allocating service department costs, the ...",The most defensible sequence is to start with ...,0.0
...,...,...,...,...,...
793681,1007691,4814576,"when allocating service department costs, the ...",Service Members | Veterans | Both. Military On...,0.0
1063650,1007691,6395207,"when allocating service department costs, the ...",that hospital emergency department services ar...,0.0
212436,1007691,6872353,"when allocating service department costs, the ...",Yelp Customer Service customer service phone n...,0.0
415429,1007691,4114248,"when allocating service department costs, the ...","A service fee, service charge, or surcharge is...",0.0


In [150]:
def get_IDCG(ranking_list_df):
    ranking_list_df_sorted = ranking_list_df.sort_values(by=['relevancy'], ascending=False)
#     ranking_list_df_sorted = ranking_list_df.reset_index(drop=True, inplace=False)
    ranking_list_df_sorted = ranking_list_df_sorted.reset_index().reindex(ranking_list_df_sorted.columns, axis=1)
    ideal_discounted_gain_sum = 0
    for idx, row in ranking_list_df_sorted.iterrows():
        index = idx + 1
        relevance_score = row['relevancy']
        gain = 2 ** relevance_score - 1
        discounted_gain = gain / math.log2(index + 1)
        ideal_discounted_gain_sum += discounted_gain
    return ideal_discounted_gain_sum

In [151]:
def get_NDCG(ranking_list_df):
    ranking_list_df = ranking_list_df.reset_index(drop=True, inplace=False)
    discounted_gain_sum = 0
    for idx, row in ranking_list_df.iterrows():
        index = idx + 1
        relevance_score = row['relevancy']
        gain = 2 ** relevance_score - 1
        discounted_gain = gain / math.log2(index + 1)
        discounted_gain_sum += discounted_gain
    IDCG = get_IDCG(ranking_list_df)
    
    nDCG = discounted_gain_sum / IDCG
    return nDCG

In [179]:
get_NDCG(ranking_list_df)

0.7841802768331765

# Subtask 2

In [22]:
# train_data = pd.read_csv("../part2/train_data.tsv", sep='\t')
# validation_data = pd.read_csv("../part2/validation_data.tsv", sep='\t')

# Under Sampling

In [232]:
import random

def sample_together(n, X, y):
    rows = random.sample(np.arange(0,len(X.index)).tolist(),n)
    rows_major = random.sample(np.arange(0,len(y.index)).tolist(),n * 2)
    return X.iloc[rows,], y.iloc[rows_major,]

In [233]:
def undersample(X, y, under = 1):
    y_min = y[y.relevancy == under]
    y_max = y[y.relevancy != under]
    X_min = X.filter(y_min.index,axis = 0)
    X_max = X.filter(y_max.index,axis = 0)

    X_under, y_under = sample_together(len(y_min.index), X_max, y_max)
    
    X = pd.concat([X_under, X_min])
    y = pd.concat([y_under, y_min])
    return X, y



In [234]:
train_data[['relevancy']]

Unnamed: 0,relevancy
0,0.0
1,0.0
2,0.0
3,0.0
4,0.0
...,...
4364334,0.0
4364335,0.0
4364336,0.0
4364337,0.0


In [235]:
X_train, y_train = undersample(train_data, train_data[['relevancy']])

In [236]:
train_data = X_train
train_data.shape

(9594, 5)

In [237]:
X_val, y_val = undersample(validation_data, validation_data[['relevancy']])

In [238]:
validation_data = X_val
validation_data.shape

(2416, 5)

## Preprocessing

In [239]:
# passages = validation_data.passage.values[:1000]
# validation_data_temp = validation_data[:1000]
# validation_data_temp['passage_cleaned']=validation_data_temp.passage.apply(lambda x: preprocessing(x))
# validation_data_temp['queries_cleaned']=validation_data_temp.queries.apply(lambda x: preprocessing(x))

In [240]:
train_data['passage_cleaned']=train_data.passage.apply(lambda x: preprocessing(x))


In [241]:
train_data['query_cleaned']=train_data.queries.apply(lambda x: preprocessing(x))

In [242]:
train_data

Unnamed: 0,qid,pid,queries,passage,relevancy,passage_cleaned,query_cleaned
1026739,559018,1920711,what are forests sinks for?,The average temperature in temperate deciduous...,0.0,"[average, temperature, temperate, deciduous, f...","[forest, sink]"
624132,999550,3266224,"where is andrews, tx",Andrews Air Force Base Study Area. Andrews Air...,0.0,"[andrews, air, force, base, study, area, andre...","[andrews, tx]"
1837576,632192,7309994,what does annual deductible mean,Another is to set up a child care center at yo...,0.0,"[another, set, child, care, center, business, ...","[annual, deductible, mean]"
350667,868487,6784062,what kind of gems are pink,Boat Fishing. There are two main methods of fi...,0.0,"[boat, fish, two, main, methods, fish, pink, s...","[kind, gems, pink]"
1973560,1006911,1563031,define podiatry,(a) Identify and understand the organization’s...,0.0,"[identify, understand, organization, priority,...","[define, podiatry]"
...,...,...,...,...,...,...,...
4343822,401287,860900,is a written prescription required for hydroco...,Refills of hydrocodone combination products wi...,1.0,"[refill, hydrocodone, combination, products, r...","[write, prescription, require, hydrocodone]"
4352626,541272,876066,was wilson a good president,"Woodrow Wilson (1856-1924), the 28th U.S. pres...",1.0,"[woodrow, wilson, 28th, us, president, serve, ...","[wilson, good, president]"
4356455,845529,882642,what is the salary range of a dentist,Dentist Salary. (United States). The average s...,1.0,"[dentist, salary, unite, state, average, salar...","[salary, range, dentist]"
4360007,850361,926854,what is the temperature in washington,July is the hottest month in Washington DC wit...,1.0,"[july, hottest, month, washington, dc, average...","[temperature, washington]"


In [243]:
validation_data['passage_cleaned']=validation_data.passage.apply(lambda x: preprocessing(x))


In [244]:
validation_data['query_cleaned']=validation_data.queries.apply(lambda x: preprocessing(x))

In [245]:
validation_data

Unnamed: 0,qid,pid,queries,passage,relevancy,passage_cleaned,query_cleaned
907485,1003015,4247300,when was ft mcnair established?,Canadian football has a larger field (450 ft o...,0.0,"[canadian, football, larger, field, ft, yards,...","[ft, mcnair, establish]"
1011693,1097461,3005820,how many calories in cooked roast beef lean,Cooked ground beef that's 70-percent lean prov...,0.0,"[cook, grind, beef, 70percent, lean, provide, ...","[many, calories, cook, roast, beef, lean]"
839197,86624,7755924,causes of getting winded easily,"Matthew Gray Gubler lists his top babes, remai...",0.0,"[matthew, gray, gubler, list, top, babes, rema...","[cause, get, wind, easily]"
107533,910375,7985238,what type of automatic transmission does the f...,Buying a used 2008 Ford Expedition transmissio...,0.0,"[buy, use, ford, expedition, transmission, non...","[type, automatic, transmission, ferrari, f430,..."
68786,718444,5486308,what is another word for comma,Re: keyboard question. I see the comma that yo...,0.0,"[keyboard, question, see, comma, speak, press,...","[another, word, comma]"
...,...,...,...,...,...,...,...
1088617,137508,7994418,definition of purse,noun. verb. The definition of a purse is a sma...,1.0,"[noun, verb, definition, purse, small, bag, po...","[definition, purse]"
1088695,1090395,7999912,sterility meaning,sterility noun [U] us / stəˈrɪl·ɪ·t̬i / › the ...,1.0,"[sterility, noun, u, us, stəˈrɪlɪti, condition...","[sterility, mean]"
1088766,291248,8003933,how many ounces of liquid is in one gallon,There are 128 fluid ounces in 1 US gallon. The...,1.0,"[fluid, ounces, us, gallon, fluid, ounces, uk,...","[many, ounces, liquid, one, gallon]"
1088815,20520,8007226,arbitrage definition,Full Definition of ARBITRAGE. 1. : the nearly ...,1.0,"[full, definition, arbitrage, nearly, simultan...","[arbitrage, definition]"


In [247]:
train_data = train_data.sample(frac = 1).reset_index(drop=True, inplace=False)
validation_data = validation_data.sample(frac = 1).reset_index(drop=True, inplace=False)


In [249]:
validation_data

Unnamed: 0,qid,pid,queries,passage,relevancy,passage_cleaned,query_cleaned
0,1047088,7913269,who played peter in the eddy duchin story,Start reading while you listen to a rendition ...,1.0,"[start, read, listen, rendition, take, manhatt...","[play, peter, eddy, duchin, story]"
1,931726,4717882,what's the difference between jambalaya and pa...,This recipe was given to ... the myriad of jam...,0.0,"[recipe, give, myriad, jambalaya, recipes, try...","[difference, jambalaya, paella]"
2,678176,6065658,what is a colonel in the air force,Every Air Force One flight is considered as a ...,0.0,"[every, air, force, one, flight, consider, mil...","[colonel, air, force]"
3,782417,6440626,what is peeling feet causes,Bacterial and fungal infections—referred as At...,1.0,"[bacterial, fungal, infectionsreferred, athlet...","[peel, feet, cause]"
4,1099903,7212883,hiouchi ca is in what county,Hiouchi is a census-designated place in Del No...,1.0,"[hiouchi, censusdesignated, place, del, norte,...","[hiouchi, ca, county]"
...,...,...,...,...,...,...,...
2411,814699,7848421,what is the definition of an erection,Webster Dictionary(0.00 / 0 votes)Rate this de...,1.0,"[webster, dictionary, vote, rate, definition, ...","[definition, erection]"
2412,906692,6533173,what to do in seattle on memorial day,Less than a year after the publication of this...,0.0,"[less, year, publication, comic, columbus, day...","[seattle, memorial, day]"
2413,1008968,7066529,what's the limit on dolphin fish in florida,"Minks swim enough to get the fish they need, b...",0.0,"[minks, swim, enough, get, fish, need, swim, a...","[limit, dolphin, fish, florida]"
2414,215603,4851610,how did the uluru rock form,Some marble also forms by contact metamorphism...,0.0,"[marble, also, form, contact, metamorphism, ho...","[uluru, rock, form]"


## Loading word2vec embedding

In [25]:
# loading pre-trained embeddings, each word is represented as a 300 dimensional vector
import gensim
from gensim.scripts.glove2word2vec import glove2word2vec
W2V_PATH="../glove.6B.300d.txt"





In [26]:
import gensim
import os
import shutil
import hashlib
from sys import platform

def getFileLineNums(filename):
    f = open(filename, 'r', encoding='utf-8')
    count = 0
    for line in f:
        count += 1
    return count


def prepend_line(infile, outfile, line):
    with open(infile, 'r') as old:
        with open(outfile, 'w') as new:
            new.write(str(line) + "\n")
            shutil.copyfileobj(old, new)

def prepend_slow(infile, outfile, line):
    with open(infile, 'r', encoding='utf-8') as fin:
        with open(outfile, 'w', encoding='utf-8') as fout:
            fout.write(line + "\n")
            for line in fin:
                fout.write(line)

def load(filename):
    num_lines = getFileLineNums(filename)
    gensim_file = 'glove_model.txt'
    gensim_first_line = "{} {}".format(num_lines, 300)
    # Prepends the line.
    if platform == "linux" or platform == "linux2":
        prepend_line(filename, gensim_file, gensim_first_line)
    else:
        prepend_slow(filename, gensim_file, gensim_first_line)

    model = gensim.models.KeyedVectors.load_word2vec_format(gensim_file)
    return model
model_w2v = load(W2V_PATH)

In [27]:
model_w2v

<gensim.models.keyedvectors.KeyedVectors at 0x20bbac3d788>

# Embedding documents and queries

In [45]:
# # dictionary with list object in values
# details = {
#     'Name' : ['Ankit', 'Aishwarya', 'Shaurya','Shivangi'],
#     'Age' : [23, 21, 22,21],
#     'University' : ['BHU', 'JNU', 'DU', 'BHU'],
# }
   
# # creating a Dataframe object 
# df = pd.DataFrame(details,columns = ['Name','Age','University'])

In [49]:
# check = [1,2]

In [51]:
# df.drop(df.index[check], inplace=True)

In [52]:
# df

Unnamed: 0,Name,Age,University
0,Ankit,23,BHU
3,Shivangi,21,BHU


In [31]:
# df.reset_index(drop=True, inplace=True)

In [32]:
# df

## removing rows that is not defined in the embeddings

### train data

In [250]:
remove_idx = []
for i, row in train_data.iterrows():
    passage = row['passage_cleaned']
    query = row['query_cleaned']
    passage_length = len(passage)
    query_length = len(query)
    for j in range(passage_length):
        word = passage[j]
        if not word in model_w2v:
            remove_idx.append(i)
            break
    for k in range(query_length):
        word = query[k]
        if not word in model_w2v:
            remove_idx.append(i)
            break


In [251]:
remove_idx = np.array(remove_idx)
remove_idx = np.unique(remove_idx)

In [252]:
remove_idx.shape

(4752,)

In [253]:
train_data.drop(train_data.index[remove_idx], inplace=True)

In [254]:
train_data.reset_index(drop=True, inplace=True)

### validation data

In [255]:
remove_idx = []
for i, row in validation_data.iterrows():
    passage = row['passage_cleaned']
    query = row['query_cleaned']
    passage_length = len(passage)
    query_length = len(query)
    for j in range(passage_length):
        word = passage[j]
        if not word in model_w2v:
            remove_idx.append(i)
            break
    for k in range(query_length):
        word = query[k]
        if not word in model_w2v:
            remove_idx.append(i)
            break


In [256]:
remove_idx = np.array(remove_idx)
remove_idx = np.unique(remove_idx)

In [257]:
validation_data.drop(validation_data.index[remove_idx], inplace=True)

In [258]:
validation_data.reset_index(drop=True, inplace=True)

In [259]:
validation_data

Unnamed: 0,qid,pid,queries,passage,relevancy,passage_cleaned,query_cleaned
0,1047088,7913269,who played peter in the eddy duchin story,Start reading while you listen to a rendition ...,1.0,"[start, read, listen, rendition, take, manhatt...","[play, peter, eddy, duchin, story]"
1,931726,4717882,what's the difference between jambalaya and pa...,This recipe was given to ... the myriad of jam...,0.0,"[recipe, give, myriad, jambalaya, recipes, try...","[difference, jambalaya, paella]"
2,678176,6065658,what is a colonel in the air force,Every Air Force One flight is considered as a ...,0.0,"[every, air, force, one, flight, consider, mil...","[colonel, air, force]"
3,907538,6969920,what to expect with alcohol withdrawal,Delirium tremens is a severe form of alcohol w...,0.0,"[delirium, tremens, severe, form, alcohol, wit...","[expect, alcohol, withdrawal]"
4,1090054,7078842,temperature diode,5. Diode Temperature Sensors . The ordinary se...,1.0,"[diode, temperature, sensors, ordinary, semico...","[temperature, diode]"
...,...,...,...,...,...,...,...
1212,1082448,2254080,what does unequal distribution of wealth mean ...,Although the response is more fully developed ...,0.0,"[although, response, fully, develop, jacob, ri...","[unequal, distribution, wealth, mean, great, d..."
1213,814699,7848421,what is the definition of an erection,Webster Dictionary(0.00 / 0 votes)Rate this de...,1.0,"[webster, dictionary, vote, rate, definition, ...","[definition, erection]"
1214,906692,6533173,what to do in seattle on memorial day,Less than a year after the publication of this...,0.0,"[less, year, publication, comic, columbus, day...","[seattle, memorial, day]"
1215,1008968,7066529,what's the limit on dolphin fish in florida,"Minks swim enough to get the fish they need, b...",0.0,"[minks, swim, enough, get, fish, need, swim, a...","[limit, dolphin, fish, florida]"


## embedding documents

### training data

In [260]:
# tokenize and pad every document to make them of the same size
# passage_tokenizer=Tokenizer()
# passage_tokenizer.fit_on_texts(train_data.passage_cleaned)
passage_max_length = 128 # document length including padding
query_max_length = 64 # query length including padding

In [261]:
train_data

Unnamed: 0,qid,pid,queries,passage,relevancy,passage_cleaned,query_cleaned
0,985259,5805959,definition of private nuisance,Allergic rhinitis: Your nose knows. It's easy ...,0.0,"[allergic, rhinitis, nose, know, easy, dismiss...","[definition, private, nuisance]"
1,995756,1602468,definition of ally,"ally (Verb) To unite, or form a connection bet...",1.0,"[ally, verb, unite, form, connection, families...","[definition, ally]"
2,1088221,7095129,"weather in shirdi, india",Monsoon Season in Shirdi. June marks the begin...,1.0,"[monsoon, season, shirdi, june, mark, begin, m...","[weather, shirdi, india]"
3,897981,4744445,what state is the city evergreen,"Vancouver, Washington. For the Canadian city, ...",0.0,"[vancouver, washington, canadian, city, see, v...","[state, city, evergreen]"
4,203458,1962641,hippo denotation definition,Not to be confused with detonation. Denotation...,0.0,"[confuse, detonation, denotation, translation,...","[hippo, denotation, definition]"
...,...,...,...,...,...,...,...
4837,813953,2737649,what is the current time in lagos nigeria,"What is current time in Christchurch, New Zeal...",0.0,"[current, time, christchurch, new, zealand, ch...","[current, time, lagos, nigeria]"
4838,1100933,7108413,economic function of profits and losses,Question: 50) The economic function of profits...,1.0,"[question, economic, function, profit, losses,...","[economic, function, profit, losses]"
4839,728110,119352,what is carvedilol used for,Carvedilol is used to treat high blood pressur...,1.0,"[carvedilol, use, treat, high, blood, pressure...","[carvedilol, use]"
4840,639545,7794678,what does increment of,Opposite Of Praise; Which Part Of Your Body Do...,0.0,"[opposite, praise, part, body, like, male, fem...",[increment]


In [262]:
train_data_length = train_data.shape[0]
document_word_embeddings_train=np.zeros((train_data_length, passage_max_length,300)) # 64 == padding
passages = train_data.passage_cleaned
for i in range(len(passages)):
    passage = passages[i]
    passage_length = len(passage)
    for j in range(passage_length): 
        word = passage[j]
#         print(word)
#         if word in model_w2v:
        document_word_embeddings_train[i][j] = model_w2v[word]
#         else:
#             remove_idx.append(i)
#             break
document_word_embeddings_train.shape

(4842, 128, 300)

In [263]:
average_vector_list_length = document_word_embeddings_train.shape[0]
average_document_vectors_train = np.zeros((average_vector_list_length,300))
for i in range(average_vector_list_length):
    average_document_vectors_train[i] = np.mean(document_word_embeddings_train[i], axis=0)
average_document_vectors_train.shape

(4842, 300)

### validation data

In [264]:
validation_data_length = validation_data.shape[0]
document_word_embeddings_val=np.zeros((validation_data_length, passage_max_length,300)) # 64 == padding
passages = validation_data.passage_cleaned
for i in range(len(passages)):
    passage = passages[i]
    passage_length = len(passage)
    for j in range(passage_length): 
        word = passage[j]
#         if word in model_w2v:
        document_word_embeddings_val[i][j] = model_w2v[word]
document_word_embeddings_val.shape

(1217, 128, 300)

In [265]:
average_vector_list_length = document_word_embeddings_val.shape[0]
average_document_vectors_val = np.zeros((average_vector_list_length,300))
for i in range(average_vector_list_length):
    average_document_vectors_val[i] = np.mean(document_word_embeddings_val[i], axis=0)
average_document_vectors_val.shape

(1217, 300)

## embedding quries

## 문제점:
1. 특정 단어가 word embedding에 없는 경우가 있음
-> query_vector 의 모든 entry가 0임

-> cosine similarity 계산할 때 nan 이 나옴

temp solution:
cosine similarity 가 0일떄 결과값에 0 을 넣어줌

### training data

In [266]:
train_data_length = train_data.shape[0]
query_word_embeddings_train=np.zeros((train_data_length, query_max_length,300)) # 64 == padding
queries = train_data.query_cleaned
for i in range(len(queries)):
    query = queries[i]
    query_length = len(query)
    for j in range(query_length): 
        word = query[j]
#         if word in model_w2v:
        query_word_embeddings_train[i][j] = model_w2v[word]
query_word_embeddings_train.shape

(4842, 64, 300)

In [267]:
average_vector_list_length = query_word_embeddings_train.shape[0]
average_query_vectors_train = np.zeros((average_vector_list_length,300))
for i in range(average_vector_list_length):
    average_query_vectors_train[i] = np.mean(query_word_embeddings_train[i], axis=0)
average_query_vectors_train.shape

(4842, 300)

### validation data

In [268]:
# tokenize and pad every document to make them of the same size
# query_tokenizer=Tokenizer()
# query_tokenizer.fit_on_texts(validation_data.query_cleaned)


In [269]:
validation_data_length = validation_data.shape[0]
query_word_embeddings_val=np.zeros((validation_data_length, query_max_length,300)) # 64 == padding
queries = validation_data.query_cleaned
for i in range(len(queries)):
    query = queries[i]
    query_length = len(query)
    for j in range(query_length): 
        word = query[j]
#         if word in model_w2v:
        query_word_embeddings_val[i][j] = model_w2v[word]
query_word_embeddings_val.shape

(1217, 64, 300)

In [270]:
average_vector_list_length = query_word_embeddings_val.shape[0]
average_query_vectors_val = np.zeros((average_vector_list_length,300))
for i in range(average_vector_list_length):
    average_query_vectors_val[i] = np.mean(query_word_embeddings_val[i], axis=0)
average_query_vectors_val.shape

(1217, 300)

## creating training data(adding features)

In [271]:
num_of_features = 3 # cosine_sim, bm25, Word Mover's Distance

In [272]:
x_train = np.zeros((average_query_vectors_train.shape[0], num_of_features))
x_val = np.zeros((average_query_vectors_val.shape[0], num_of_features))
x_train.shape

(4842, 3)

In [273]:
x_train[0][0]

0.0

### Cosine Simliarity

In [274]:
def cosine_sim_formula(a, b):
    cos_sim = np.dot(a, b)/(np.linalg.norm(a)*np.linalg.norm(b))
    if np.isnan(cos_sim):
        cos_sim = 0
    return cos_sim

In [275]:
for i in range(x_train.shape[0]):
    query_vector = average_query_vectors_train[i]
    passage_vector = average_document_vectors_train[i]
    x_train[i][0] = cosine_sim_formula(query_vector, passage_vector)
    if np.isnan(x_train[i][0]):
        x_train[i][0] = 0


In [276]:
train_data['co_similarity'] = x_train[:,0]

In [277]:
train_data

Unnamed: 0,qid,pid,queries,passage,relevancy,passage_cleaned,query_cleaned,co_similarity
0,985259,5805959,definition of private nuisance,Allergic rhinitis: Your nose knows. It's easy ...,0.0,"[allergic, rhinitis, nose, know, easy, dismiss...","[definition, private, nuisance]",0.458360
1,995756,1602468,definition of ally,"ally (Verb) To unite, or form a connection bet...",1.0,"[ally, verb, unite, form, connection, families...","[definition, ally]",0.574709
2,1088221,7095129,"weather in shirdi, india",Monsoon Season in Shirdi. June marks the begin...,1.0,"[monsoon, season, shirdi, june, mark, begin, m...","[weather, shirdi, india]",0.567133
3,897981,4744445,what state is the city evergreen,"Vancouver, Washington. For the Canadian city, ...",0.0,"[vancouver, washington, canadian, city, see, v...","[state, city, evergreen]",0.703817
4,203458,1962641,hippo denotation definition,Not to be confused with detonation. Denotation...,0.0,"[confuse, detonation, denotation, translation,...","[hippo, denotation, definition]",0.411904
...,...,...,...,...,...,...,...,...
4837,813953,2737649,what is the current time in lagos nigeria,"What is current time in Christchurch, New Zeal...",0.0,"[current, time, christchurch, new, zealand, ch...","[current, time, lagos, nigeria]",0.583818
4838,1100933,7108413,economic function of profits and losses,Question: 50) The economic function of profits...,1.0,"[question, economic, function, profit, losses,...","[economic, function, profit, losses]",0.860066
4839,728110,119352,what is carvedilol used for,Carvedilol is used to treat high blood pressur...,1.0,"[carvedilol, use, treat, high, blood, pressure...","[carvedilol, use]",0.395157
4840,639545,7794678,what does increment of,Opposite Of Praise; Which Part Of Your Body Do...,0.0,"[opposite, praise, part, body, like, male, fem...",[increment],0.015316


In [278]:
for i in range(len(x_val)):
    query_vector = average_query_vectors_val[i]
    passage_vector = average_document_vectors_val[i]
    x_val[i][0] = cosine_sim_formula(query_vector, passage_vector)
    if np.isnan(x_val[i][0]):
        x_val[i][0] = 0

In [279]:
validation_data['co_similarity'] = x_val[:,0]

In [280]:
validation_data

Unnamed: 0,qid,pid,queries,passage,relevancy,passage_cleaned,query_cleaned,co_similarity
0,1047088,7913269,who played peter in the eddy duchin story,Start reading while you listen to a rendition ...,1.0,"[start, read, listen, rendition, take, manhatt...","[play, peter, eddy, duchin, story]",0.667414
1,931726,4717882,what's the difference between jambalaya and pa...,This recipe was given to ... the myriad of jam...,0.0,"[recipe, give, myriad, jambalaya, recipes, try...","[difference, jambalaya, paella]",0.311800
2,678176,6065658,what is a colonel in the air force,Every Air Force One flight is considered as a ...,0.0,"[every, air, force, one, flight, consider, mil...","[colonel, air, force]",0.794860
3,907538,6969920,what to expect with alcohol withdrawal,Delirium tremens is a severe form of alcohol w...,0.0,"[delirium, tremens, severe, form, alcohol, wit...","[expect, alcohol, withdrawal]",0.735714
4,1090054,7078842,temperature diode,5. Diode Temperature Sensors . The ordinary se...,1.0,"[diode, temperature, sensors, ordinary, semico...","[temperature, diode]",0.762792
...,...,...,...,...,...,...,...,...
1212,1082448,2254080,what does unequal distribution of wealth mean ...,Although the response is more fully developed ...,0.0,"[although, response, fully, develop, jacob, ri...","[unequal, distribution, wealth, mean, great, d...",0.658310
1213,814699,7848421,what is the definition of an erection,Webster Dictionary(0.00 / 0 votes)Rate this de...,1.0,"[webster, dictionary, vote, rate, definition, ...","[definition, erection]",0.495840
1214,906692,6533173,what to do in seattle on memorial day,Less than a year after the publication of this...,0.0,"[less, year, publication, comic, columbus, day...","[seattle, memorial, day]",0.685713
1215,1008968,7066529,what's the limit on dolphin fish in florida,"Minks swim enough to get the fish they need, b...",0.0,"[minks, swim, enough, get, fish, need, swim, a...","[limit, dolphin, fish, florida]",0.669959


In [281]:
y_train = train_data.relevancy.values
y_val = validation_data.relevancy.values

### BM25

In [282]:
train_data_no_dup_passages = train_data.drop_duplicates(subset=['pid'], inplace=False)
validation_data_no_dup_passages = validation_data.drop_duplicates(subset=['pid'], inplace=False)

In [283]:
train_data_no_dup_passages.shape
validation_data_no_dup_passages.shape
N_train = train_data_no_dup_passages.shape[0]
N_val = validation_data_no_dup_passages.shape[0]

(4837, 8)

(1216, 8)

In [284]:
def get_passage_average_length(dataframe_no_dup_passages):
    number_of_passages = validation_data_no_dup_passages.shape[0]
    count_total_length = 0
    for idx, row in dataframe_no_dup_passages.iterrows():
        count_total_length += len(row['passage_cleaned'])
    return count_total_length/number_of_passages 


In [285]:
# total_word_occurences, avdl = get_passage_average_length_and_total_word_occurences_corpus()
avdl_train = get_passage_average_length(train_data_no_dup_passages)
avdl_val = get_passage_average_length(validation_data_no_dup_passages)

In [286]:
from math import log

k1 = 1.2
k2 = 100
b = 0.75
R = 0
r = 0
# N = len(validation_data_no_dup_passages)

def K_cal(dl, avdl):
    return k1 * ((1-b) + b * (float(dl)/float(avdl)) )


def BM25_cal_for_preprocessed_words(query_tokens, passage_tokens, N, avdl):
    query_length = len(query_tokens)
    query_token_freq_dict = nltk.FreqDist(query_tokens)
    passage_token_freq_dict = nltk.FreqDist(passage_tokens)
    dl = len(passage_tokens)
    K = K_cal(dl, avdl)
    score = 0
    for token in query_tokens:
        try:
            n = len(inverted_index[token])
        except:
            n = 0
        f = passage_token_freq_dict[token]
        qf = query_token_freq_dict[token]
        first_term = log( ( (r + 0.5) / (R - r + 0.5) ) / ( (n - r + 0.5) / (N - n - R + r + 0.5)) )
        second_term = ((k1 + 1) * f) / (K + f)
        third_term = ((k2+1) * qf) / (k2 + qf)
        score += first_term * second_term * third_term
    return score
    

In [287]:
for idx, row in train_data.iterrows():
    query_cleaned = row['query_cleaned']
    passage_cleaned = row['passage_cleaned']
    bm25_score = BM25_cal_for_preprocessed_words(query_cleaned, passage_cleaned, N_train, avdl_train)
    x_train[idx][1] = bm25_score

In [288]:
for idx, row in validation_data.iterrows():
    query_cleaned = row['query_cleaned']
    passage_cleaned = row['passage_cleaned']
    bm25_score = BM25_cal_for_preprocessed_words(query_cleaned, passage_cleaned, N_val, avdl_val)
    x_val[idx][1] = bm25_score

In [289]:
train_data['bm25'] = x_train[:,1]
validation_data['bm25'] = x_val[:,1]

In [290]:
train_data

Unnamed: 0,qid,pid,queries,passage,relevancy,passage_cleaned,query_cleaned,co_similarity,bm25
0,985259,5805959,definition of private nuisance,Allergic rhinitis: Your nose knows. It's easy ...,0.0,"[allergic, rhinitis, nose, know, easy, dismiss...","[definition, private, nuisance]",0.458360,13.845754
1,995756,1602468,definition of ally,"ally (Verb) To unite, or form a connection bet...",1.0,"[ally, verb, unite, form, connection, families...","[definition, ally]",0.574709,16.750030
2,1088221,7095129,"weather in shirdi, india",Monsoon Season in Shirdi. June marks the begin...,1.0,"[monsoon, season, shirdi, june, mark, begin, m...","[weather, shirdi, india]",0.567133,18.024893
3,897981,4744445,what state is the city evergreen,"Vancouver, Washington. For the Canadian city, ...",0.0,"[vancouver, washington, canadian, city, see, v...","[state, city, evergreen]",0.703817,33.303577
4,203458,1962641,hippo denotation definition,Not to be confused with detonation. Denotation...,0.0,"[confuse, detonation, denotation, translation,...","[hippo, denotation, definition]",0.411904,16.269547
...,...,...,...,...,...,...,...,...,...
4837,813953,2737649,what is the current time in lagos nigeria,"What is current time in Christchurch, New Zeal...",0.0,"[current, time, christchurch, new, zealand, ch...","[current, time, lagos, nigeria]",0.583818,34.712475
4838,1100933,7108413,economic function of profits and losses,Question: 50) The economic function of profits...,1.0,"[question, economic, function, profit, losses,...","[economic, function, profit, losses]",0.860066,64.461813
4839,728110,119352,what is carvedilol used for,Carvedilol is used to treat high blood pressur...,1.0,"[carvedilol, use, treat, high, blood, pressure...","[carvedilol, use]",0.395157,29.165468
4840,639545,7794678,what does increment of,Opposite Of Praise; Which Part Of Your Body Do...,0.0,"[opposite, praise, part, body, like, male, fem...",[increment],0.015316,13.133287


## Word Mover's Distance

In [291]:
x_train

array([[4.58360448e-01, 1.38457543e+01, 0.00000000e+00],
       [5.74708553e-01, 1.67500303e+01, 0.00000000e+00],
       [5.67133391e-01, 1.80248929e+01, 0.00000000e+00],
       ...,
       [3.95157173e-01, 2.91654682e+01, 0.00000000e+00],
       [1.53157989e-02, 1.31332868e+01, 0.00000000e+00],
       [8.35775371e-01, 3.02564727e+01, 0.00000000e+00]])

In [292]:
for i, row in train_data.iterrows():
    query = row['query_cleaned']
    passage = row['passage_cleaned']
#     print(model_w2v.wmdistance(query, passage))
    x_train[i][2] = model_w2v.wmdistance(query, passage)


In [293]:
x_train

array([[4.58360448e-01, 1.38457543e+01, 1.18668625e+00],
       [5.74708553e-01, 1.67500303e+01, 1.08585083e+00],
       [5.67133391e-01, 1.80248929e+01, 1.06007027e+00],
       ...,
       [3.95157173e-01, 2.91654682e+01, 1.12293186e+00],
       [1.53157989e-02, 1.31332868e+01, 1.38670092e+00],
       [8.35775371e-01, 3.02564727e+01, 9.71023516e-01]])

In [294]:
train_data['word_mover_distance'] = x_train[:,2]

In [295]:
for i, row in validation_data.iterrows():
    query = row['query_cleaned']
    passage = row['passage_cleaned']
#     print(model_w2v.wmdistance(query, passage))
    x_val[i][2] = model_w2v.wmdistance(query, passage)


In [296]:
validation_data['word_mover_distance'] = x_val[:,2]

In [297]:
validation_data

Unnamed: 0,qid,pid,queries,passage,relevancy,passage_cleaned,query_cleaned,co_similarity,bm25,word_mover_distance
0,1047088,7913269,who played peter in the eddy duchin story,Start reading while you listen to a rendition ...,1.0,"[start, read, listen, rendition, take, manhatt...","[play, peter, eddy, duchin, story]",0.667414,34.087137,1.035701
1,931726,4717882,what's the difference between jambalaya and pa...,This recipe was given to ... the myriad of jam...,0.0,"[recipe, give, myriad, jambalaya, recipes, try...","[difference, jambalaya, paella]",0.311800,8.747635,1.218251
2,678176,6065658,what is a colonel in the air force,Every Air Force One flight is considered as a ...,0.0,"[every, air, force, one, flight, consider, mil...","[colonel, air, force]",0.794860,24.430865,0.919003
3,907538,6969920,what to expect with alcohol withdrawal,Delirium tremens is a severe form of alcohol w...,0.0,"[delirium, tremens, severe, form, alcohol, wit...","[expect, alcohol, withdrawal]",0.735714,17.991335,1.058040
4,1090054,7078842,temperature diode,5. Diode Temperature Sensors . The ordinary se...,1.0,"[diode, temperature, sensors, ordinary, semico...","[temperature, diode]",0.762792,24.713307,0.991616
...,...,...,...,...,...,...,...,...,...,...
1212,1082448,2254080,what does unequal distribution of wealth mean ...,Although the response is more fully developed ...,0.0,"[although, response, fully, develop, jacob, ri...","[unequal, distribution, wealth, mean, great, d...",0.658310,17.338328,1.129944
1213,814699,7848421,what is the definition of an erection,Webster Dictionary(0.00 / 0 votes)Rate this de...,1.0,"[webster, dictionary, vote, rate, definition, ...","[definition, erection]",0.495840,15.504570,1.187488
1214,906692,6533173,what to do in seattle on memorial day,Less than a year after the publication of this...,0.0,"[less, year, publication, comic, columbus, day...","[seattle, memorial, day]",0.685713,22.929246,0.994379
1215,1008968,7066529,what's the limit on dolphin fish in florida,"Minks swim enough to get the fish they need, b...",0.0,"[minks, swim, enough, get, fish, need, swim, a...","[limit, dolphin, fish, florida]",0.669959,18.218457,1.082485


## Saving/loading pickle

In [None]:
train_data.to_pickle('train_data.pkl')
validation_data.to_pickle('validation_data.pkl')

## Load

In [80]:
train_data = pd.read_pickle('train_data.pkl')
validation_data = pd.read_pickle('validation_data.pkl')
num_of_features = 3

In [81]:
x_train = np.zeros((train_data.shape[0], num_of_features))

In [112]:
x_train = train_data[['co_similarity', 'bm25', 'word_mover_distance']].values

In [113]:
x_train.shape

(4830, 3)

In [101]:
y_train = train_data['relevancy'].values

In [102]:
x_val = np.zeros((validation_data.shape[0], num_of_features))

In [103]:
x_val = validation_data[['co_similarity', 'bm25', 'word_mover_distance']].values

In [104]:
x_val.shape

(1219, 3)

In [105]:
y_val = validation_data['relevancy'].values
y_val

array([1., 0., 0., ..., 1., 0., 0.])

## Convert data into metrics form

In [114]:
x_train.shape

(4830, 3)

In [115]:
x_train = x_train.reshape(x_train.shape[0],x_train.shape[1], 1)
x_val = x_val.reshape(x_val.shape[0],x_val.shape[1], 1)
x_train.shape
x_val.shape

(4830, 3, 1)

(1219, 3, 1)

In [116]:
# x_train = x_train.squeeze(2)
# x_val = x_val.squeeze(2)
# x_train.shape

## Logisitc Regression

## 현재 문제 
1. 딥러닝 과제처럼, 한 data씩 처리 할건지, 아니면 인터넷 예제처럼 모든 value를 metrics 에 넣어서 한번에 처리할건지
2. 딥러닝 과제처럼 처리했을 경우, gradient descent function은 무엇인지, 인터넷 예제처럼 했을 때도, 인터넷에 있는 gradient descent function 이 어떻게 derive 됬는지 알기

A: 결국엔 둘다 똑같은 방법인데, metrics 로 처리하는 경우는 한번에 모든 데이터를 처리하는 경우이고, 딥러닝 과제는 한 data sampling 을 하나씩 처리하는 방법이다. 결국엔 둘 다 똑같다.

In [130]:
class LogisticRegression:

    def __init__(self):
        self.lr = 0.01
        self.epoch = 100
        self.weights = None
        self.bias = None
        self.loss_train = None
        self.loss_val = None

        self.acc_train = None
        self.train_correct = None
        self.val_correct = None
        
        self.losslist_train = []
        self.losslist_val = []
        self.acclist_train = []
        self.acclist_val = []
    

        
    def fit(self, trainxs, trainys, trainxs_val, trainys_val):
        n_samples, n_features, _ = trainxs.shape
        n_samples_val = trainxs_val.shape[0]

        # init parameters
#         self.weights = np.zeros((1, n_features))
        self.weights = np.random.normal(0,1,(1,n_features))#np.random.randn(1,n_features)#np.zeros((1, n_features))
        self.bias = 0
        
        # gradient descent
        for e in range(self.epoch):
            idx = self.shuffleIdx(trainxs.shape[0])
            X = trainxs[idx]
            Y = trainys[idx]
            
            
            self.loss_train = 0
            self.loss_val = 0

            self.acc_train = 0
            self.train_correct = 0
            self.val_correct = 0
            total_weight_gradient1 = 0
            total_weight_gradient2 = 0
            total_weight_gradient3 = 0
            
            total_bias_gradient = 0
            
            for i in range(trainxs.shape[0]):
                x = X[i]
                y = Y[i]
                # approximate output variable (y) with linear combination of weights and x, plus bias
                linear_equation = np.dot(self.weights, x)[0][0] + self.bias
                # apply sigmoid function
                prediction = self.sigmoid(linear_equation)
#                 print("prediction:", prediction)
                if prediction >= 0.5:
                    yprime = 1
                else:
                    yprime = 0

                if yprime == y:
                    self.train_correct += 1

                self.loss_train += self.loss_function(y, prediction)
#                 print("x[2]:", x[2])
                # compute gradients
                dw1 = (prediction - y) * x[0]  #derivative w.r.t weight1
                dw2 = (prediction - y) * x[1] #derivative w.r.t weight2
                dw3 = (prediction - y) * x[2] #derivative w.r.t weight3
                
                db = prediction - y  #derivative w.r.t bias
                    
                    
                total_weight_gradient1 += dw1
                total_weight_gradient2 += dw2
                total_weight_gradient3 += dw3
                
                total_bias_gradient += db
               
                 # COMPUTING LOSS AND ACCURACY OF VALIDATION SET
                if (i < trainxs_val.shape[0]):
                    val_x = trainxs_val[i]
                    val_y = trainys_val[i]
                    linear_equation = np.dot(self.weights, val_x)[0][0] + self.bias
                    prediction = self.sigmoid(linear_equation)

                    if prediction  >= 0.5:
                        yprime = 1
                    else:
                        yprime = 0

                    if yprime == val_y:
                        self.val_correct += 1

                    self.loss_val += ((val_y - prediction)**2)/2

                    
            # COMPUTING THE AVERAGE OF GRADIENTS FOR EACH EPOCH SINCE WE ARE DOING        
            # FULL-BATCH GRADIENT DESCEND AND UPDATING THE PARAMETERS AFTER EACH EPOCH
            total_weight_gradient1 = total_weight_gradient1 / n_samples
            total_weight_gradient2 = total_weight_gradient2 / n_samples
            total_weight_gradient3 = total_weight_gradient3 / n_samples
            
            total_bias_gradient = total_bias_gradient / n_samples

            # update parameters
#             print(total_weight_gradient1,total_weight_gradient2,total_weight_gradient3)
            self.weights[0][0] -= self.lr * total_weight_gradient1
            self.weights[0][1] -= self.lr * total_weight_gradient2
            self.weights[0][2] -= self.lr * total_weight_gradient3
            
            self.bias -= self.lr * total_bias_gradient
            
            self.loss_train = self.loss_train/n_samples
            self.losslist_train.append(self.loss_train)

            self.loss_val = self.loss_val/(n_samples_val)
            self.losslist_val.append(self.loss_val)

            self.train_correct = self.train_correct/n_samples
            self.acclist_train.append(self.train_correct)

            self.val_correct = self.val_correct/(n_samples_val)
            self.acclist_val.append(self.val_correct)
            
            print("- Loss on Training / Validation Data at Epoch {}: {:.4f} / {:.4f}".format(e+1, self.loss_train, self.loss_val)) 
            print("- Accuracy on Training / Validation Data at Epoch {}: {} / {}\n".format(e+1, self.train_correct, self.val_correct))

            
    def shuffleIdx(self, n):
        rng = default_rng()
        rand_idx = rng.permutation(n)
        return rand_idx

    def predict(self, X):
#         linear_equation = np.dot(X, self.weights) + self.bias
        y_predicted_cls = np.zeros(X.shape[0])
        for i in range(X.shape[0]):
            x = X[i]
            linear_equation = np.dot(self.weights, x)[0][0] + self.bias
            y_predicted = self.sigmoid(linear_equation)
#             print("y_predicted:", y_predicted)
            y_predicted_cls[i] = y_predicted
#             print(y_predicted)
#             if y_predicted > 0.5:
#                 y_predicted_cls[i] = 1
#             else:
#                 y_predicted_cls[i] = 0
#             y_predicted_cls = [1 if i > 0.5 else 0 for i in y_predicted]
        return y_predicted_cls

    def sigmoid(self, x):
        return 1 / (1 + np.exp(-x))
    
    def loss_function(self, y, prediction):
        prediction = np.clip(prediction,1e-3,1-1e-3)
        ls = -log((1 - prediction)**(1 - y)) - log(prediction**y)
        return ls

In [131]:
lr = LogisticRegression()
lr.fit(x_train, y_train, x_val, y_val)
# predictions = regressor.predict(xtest)


- Loss on Training / Validation Data at Epoch 1: 3.6302 / 0.2604
- Accuracy on Training / Validation Data at Epoch 1: 0.47329192546583854 / 0.4782608695652174

- Loss on Training / Validation Data at Epoch 2: 3.6302 / 0.2604
- Accuracy on Training / Validation Data at Epoch 2: 0.47329192546583854 / 0.4782608695652174

- Loss on Training / Validation Data at Epoch 3: 3.6302 / 0.2604
- Accuracy on Training / Validation Data at Epoch 3: 0.47329192546583854 / 0.4782608695652174

- Loss on Training / Validation Data at Epoch 4: 3.6302 / 0.2604
- Accuracy on Training / Validation Data at Epoch 4: 0.47329192546583854 / 0.4782608695652174

- Loss on Training / Validation Data at Epoch 5: 3.6302 / 0.2604
- Accuracy on Training / Validation Data at Epoch 5: 0.47329192546583854 / 0.4782608695652174

- Loss on Training / Validation Data at Epoch 6: 3.6302 / 0.2602
- Accuracy on Training / Validation Data at Epoch 6: 0.47329192546583854 / 0.4782608695652174

- Loss on Training / Validation Data at 

- Loss on Training / Validation Data at Epoch 52: 0.8573 / 0.1447
- Accuracy on Training / Validation Data at Epoch 52: 0.44451345755693583 / 0.34864643150123054

- Loss on Training / Validation Data at Epoch 53: 1.2494 / 0.1752
- Accuracy on Training / Validation Data at Epoch 53: 0.5252587991718427 / 0.5209187858900738

- Loss on Training / Validation Data at Epoch 54: 1.0094 / 0.1583
- Accuracy on Training / Validation Data at Epoch 54: 0.47329192546583854 / 0.4364232977850697

- Loss on Training / Validation Data at Epoch 55: 1.4706 / 0.1877
- Accuracy on Training / Validation Data at Epoch 55: 0.5252587991718427 / 0.5209187858900738

- Loss on Training / Validation Data at Epoch 56: 0.8556 / 0.1445
- Accuracy on Training / Validation Data at Epoch 56: 0.44761904761904764 / 0.3494667760459393

- Loss on Training / Validation Data at Epoch 57: 1.2473 / 0.1748
- Accuracy on Training / Validation Data at Epoch 57: 0.5252587991718427 / 0.5209187858900738

- Loss on Training / Validatio

In [132]:
lr.bias

-0.06096328807334777

In [133]:
lr.weights.shape

(1, 3)

In [134]:
lr.predict(x_val)

array([0.89262863, 0.83402486, 0.89245691, ..., 0.96180768, 0.80324584,
       0.7697553 ])

In [139]:
validation_data_lr = validation_data.copy(deep=True)

In [145]:
validation_data_lr['score'] = lr.predict(x_val)
validation_data_lr


Unnamed: 0,qid,pid,queries,passage,relevancy,passage_cleaned,query_cleaned,co_similarity,bm25,word_mover_distance,score
0,467612,7552919,oasis hospital contact number,Contact Us Contact Us If you are unable to cal...,1.0,"[contact, us, contact, us, unable, call, oasis...","[oasis, hospital, contact, number]",0.683691,24.411400,1.050934,0.892629
1,1046952,7318074,who played luke on general hospital,It was in the oral and facial unit and it was ...,0.0,"[oral, facial, unit, dentists, room, set, like...","[play, luke, general, hospital]",0.637770,17.520037,1.099775,0.834025
2,1062190,5300247,why did bob marley die of cancer,exists and is an alternate of . Napoleon Bonap...,0.0,"[exist, alternate, napoleon, bonaparte, die, b...","[bob, marley, die, cancer]",0.572388,25.229735,0.938711,0.892457
3,946825,7963463,when does la liga espanola start,Traditional new La Liga season starts on the t...,1.0,"[traditional, new, la, liga, season, start, th...","[la, liga, espanola, start]",0.488907,26.889117,1.041251,0.902010
4,412982,6020416,is indirect labor an overhead cost,Ans: d Response A: Materials requisition slips...,0.0,"[ans, response, materials, requisition, slip, ...","[indirect, labor, overhead, cost]",0.709268,27.235843,1.016331,0.910881
...,...,...,...,...,...,...,...,...,...,...,...
1214,922335,7811374,what was the reason for the snake war,The Snake War has been widely forgotten in Uni...,1.0,"[snake, war, widely, forget, unite, state, his...","[reason, snake, war]",0.775045,26.856588,0.991801,0.910384
1215,341039,7330977,how tall is kate mckinnon,The list of complete Kate McKinnon body measur...,1.0,"[list, complete, kate, mckinnon, body, measure...","[tall, kate, mckinnon]",0.379143,15.984450,1.173580,0.805003
1216,685591,7471912,what is a google brand account?,Hello all Google brand account is that account...,1.0,"[hello, google, brand, account, account, faceb...","[google, brand, account]",0.893495,39.647229,0.657675,0.961808
1217,1097885,7229072,how long nicotine stay in your saliva,Addiction. Nicotine from cigarettes is as addi...,0.0,"[addiction, nicotine, cigarettes, addictive, h...","[long, nicotine, stay, saliva]",0.778365,14.248260,0.933692,0.803246


In [148]:
lr_ranked_list_df = validation_data_lr.sort_values(by=['score'], ascending=False)
lr_ranked_list_df = lr_ranked_list_df.reset_index().reindex(lr_ranked_list_df.columns, axis=1)
lr_ranked_list_df

Unnamed: 0,qid,pid,queries,passage,relevancy,passage_cleaned,query_cleaned,co_similarity,bm25,word_mover_distance,score
0,1007691,7251254,"when allocating service department costs, the ...",Direct method allocates each service departmen...,1.0,"[direct, method, allocate, service, department...","[allocate, service, department, cost, method, ...",0.958384,189.441988,0.445933,0.999999
1,1089945,7079883,the __________ test is a quick and dirty test ...,• The Smell Test is familiar ground in most bu...,1.0,"[smell, test, familiar, grind, businesses, goo...","[__________, test, quick, dirty, test, decide,...",0.954141,151.095160,0.468387,0.999986
2,1085764,7116260,what constitutional clause prohibits a state f...,What constitutional clause prohibits a State f...,1.0,"[constitutional, clause, prohibit, state, draw...","[constitutional, clause, prohibit, state, draw...",0.892600,124.247251,0.580840,0.999906
3,8701,7990420,_____ is the ability of cardiac pacemaker cell...,A skeletal muscle contracts only after it is s...,1.0,"[skeletal, muscle, contract, stimulate, nerve,...","[_____, ability, cardiac, pacemaker, cells, sp...",0.948867,109.809304,0.417478,0.999735
4,1093920,7711895,is the process in which we detect physical ene...,Sensation is the process by which we detect ph...,1.0,"[sensation, process, detect, physical, energy,...","[process, detect, physical, energy, environmen...",0.862633,95.906577,0.726514,0.999295
...,...,...,...,...,...,...,...,...,...,...,...
1214,1087566,7100934,what are hcp,"HCP, Inc. is a real estate investment trust wh...",1.0,"[hcp, inc, real, estate, investment, trust, in...",[hcp],-0.010663,7.877156,1.394468,0.673517
1215,714672,7480666,what is an fmo,Florida Marketing Organization has a dedicated...,1.0,"[florida, market, organization, dedicate, focu...",[fmo],-0.161294,6.973536,1.450475,0.648188
1216,293992,2790193,how many product lines does coca cola have,COCA-COLA ON SOCIAL. Coca-Cola Great Britain i...,1.0,"[cocacola, social, cocacola, great, britain, r...","[many, product, line, coca, cola]",0.675006,0.000000,1.120419,0.593006
1217,1029030,7792577,what is utah law on non compete,There are limited situations where a reasonabl...,0.0,"[limit, situations, reasonable, noncompete, ag...","[utah, law, non, compete]",0.620559,0.000000,1.209333,0.591462


In [152]:
average_precision_cal(lr_ranked_list_df)

0.397661437204143

In [153]:
get_NDCG(lr_ranked_list_df)

0.9581420368748004

In [135]:
from sklearn.linear_model import LogisticRegression as Logistic_sklearn

In [136]:
lr_sklearn = Logistic_sklearn()
lr_sklearn.fit(x_train.squeeze(2), y_train)

LogisticRegression()

In [137]:
pred_sklearn = lr_sklearn.predict(x_val.squeeze(2))

In [138]:
score = lr_sklearn.score(x_val.squeeze(2), y_val)
score

0.6119770303527482

# Subtask 3


In [3]:
import pandas as pd
import numpy as np
from xgboost import DMatrix,train



In [4]:
train_data = pd.read_pickle('train_data.pkl')
validation_data = pd.read_pickle('validation_data.pkl')

In [6]:
train_data
validation_data

Unnamed: 0,qid,pid,queries,passage,relevancy,passage_cleaned,query_cleaned,co_similarity,bm25,word_mover_distance
0,1094440,7670435,avranches france weather,Avranches: Annual Weather Averages July is the...,1.0,"[avranches, annual, weather, average, july, ho...","[avranches, france, weather]",0.451117,31.469813,0.999413
1,932223,7812781,what's the lizard from monsters inc name,What is the lizards name in monsters inc. kgb ...,1.0,"[lizards, name, monsters, inc, kgb, answer, ar...","[lizard, monsters, inc, name]",0.838567,66.593644,0.768262
2,694561,4500432,what is a phantom color poodle,Height: The Standard Poodle is above 15 inches...,0.0,"[height, standard, poodle, inch, usually, inch...","[phantom, color, poodle]",0.636981,34.366482,0.947913
3,1088164,7095519,weather in warwick,The lowest monthly average temperature in Warw...,1.0,"[lowest, monthly, average, temperature, warwic...","[weather, warwick]",0.470898,17.919724,1.079831
4,621419,7686585,what disney movie is love is an open door from,Love is an Open Door - Frozen HD 1080p. Love i...,1.0,"[love, open, door, freeze, hd, 1080p, love, op...","[disney, movie, love, open, door]",0.790705,47.003638,0.952079
...,...,...,...,...,...,...,...,...,...,...
4825,685177,5448202,what is a good facial cleanser,A cleanser is a facial care product that is us...,1.0,"[cleanser, facial, care, product, use, remove,...","[good, facial, cleanser]",0.647698,28.812012,1.058737
4826,1100051,4943970,health effects of styrofoam cups,To estimate independent effects of alcohol or ...,0.0,"[estimate, independent, effect, alcohol, smoke...","[health, effect, styrofoam, cup]",0.649538,31.013759,1.036648
4827,231109,7763691,how far is grand hyatt denver to the conventio...,City Center Grand Opening. The City of Pembrok...,0.0,"[city, center, grand, open, city, pembroke, pi...","[far, grand, hyatt, denver, convention, center]",0.690669,31.013759,1.001228
4828,693736,7503664,what is a parking brake,What is the Purpose of a Parking Brake. Posted...,1.0,"[purpose, park, brake, post, august, dave, joh...","[park, brake]",0.727260,33.868585,0.966780


Unnamed: 0,qid,pid,queries,passage,relevancy,passage_cleaned,query_cleaned,co_similarity,bm25,word_mover_distance
0,467612,7552919,oasis hospital contact number,Contact Us Contact Us If you are unable to cal...,1.0,"[contact, us, contact, us, unable, call, oasis...","[oasis, hospital, contact, number]",0.683691,24.411400,1.050934
1,1046952,7318074,who played luke on general hospital,It was in the oral and facial unit and it was ...,0.0,"[oral, facial, unit, dentists, room, set, like...","[play, luke, general, hospital]",0.637770,17.520037,1.099775
2,1062190,5300247,why did bob marley die of cancer,exists and is an alternate of . Napoleon Bonap...,0.0,"[exist, alternate, napoleon, bonaparte, die, b...","[bob, marley, die, cancer]",0.572388,25.229735,0.938711
3,946825,7963463,when does la liga espanola start,Traditional new La Liga season starts on the t...,1.0,"[traditional, new, la, liga, season, start, th...","[la, liga, espanola, start]",0.488907,26.889117,1.041251
4,412982,6020416,is indirect labor an overhead cost,Ans: d Response A: Materials requisition slips...,0.0,"[ans, response, materials, requisition, slip, ...","[indirect, labor, overhead, cost]",0.709268,27.235843,1.016331
...,...,...,...,...,...,...,...,...,...,...
1214,922335,7811374,what was the reason for the snake war,The Snake War has been widely forgotten in Uni...,1.0,"[snake, war, widely, forget, unite, state, his...","[reason, snake, war]",0.775045,26.856588,0.991801
1215,341039,7330977,how tall is kate mckinnon,The list of complete Kate McKinnon body measur...,1.0,"[list, complete, kate, mckinnon, body, measure...","[tall, kate, mckinnon]",0.379143,15.984450,1.173580
1216,685591,7471912,what is a google brand account?,Hello all Google brand account is that account...,1.0,"[hello, google, brand, account, account, faceb...","[google, brand, account]",0.893495,39.647229,0.657675
1217,1097885,7229072,how long nicotine stay in your saliva,Addiction. Nicotine from cigarettes is as addi...,0.0,"[addiction, nicotine, cigarettes, addictive, h...","[long, nicotine, stay, saliva]",0.778365,14.248260,0.933692


## Preprocessing the data into the form which XGboost can take

In [7]:
num_of_features = 3

## train data

In [8]:
train_data

Unnamed: 0,qid,pid,queries,passage,relevancy,passage_cleaned,query_cleaned,co_similarity,bm25,word_mover_distance
0,1094440,7670435,avranches france weather,Avranches: Annual Weather Averages July is the...,1.0,"[avranches, annual, weather, average, july, ho...","[avranches, france, weather]",0.451117,31.469813,0.999413
1,932223,7812781,what's the lizard from monsters inc name,What is the lizards name in monsters inc. kgb ...,1.0,"[lizards, name, monsters, inc, kgb, answer, ar...","[lizard, monsters, inc, name]",0.838567,66.593644,0.768262
2,694561,4500432,what is a phantom color poodle,Height: The Standard Poodle is above 15 inches...,0.0,"[height, standard, poodle, inch, usually, inch...","[phantom, color, poodle]",0.636981,34.366482,0.947913
3,1088164,7095519,weather in warwick,The lowest monthly average temperature in Warw...,1.0,"[lowest, monthly, average, temperature, warwic...","[weather, warwick]",0.470898,17.919724,1.079831
4,621419,7686585,what disney movie is love is an open door from,Love is an Open Door - Frozen HD 1080p. Love i...,1.0,"[love, open, door, freeze, hd, 1080p, love, op...","[disney, movie, love, open, door]",0.790705,47.003638,0.952079
...,...,...,...,...,...,...,...,...,...,...
4825,685177,5448202,what is a good facial cleanser,A cleanser is a facial care product that is us...,1.0,"[cleanser, facial, care, product, use, remove,...","[good, facial, cleanser]",0.647698,28.812012,1.058737
4826,1100051,4943970,health effects of styrofoam cups,To estimate independent effects of alcohol or ...,0.0,"[estimate, independent, effect, alcohol, smoke...","[health, effect, styrofoam, cup]",0.649538,31.013759,1.036648
4827,231109,7763691,how far is grand hyatt denver to the conventio...,City Center Grand Opening. The City of Pembrok...,0.0,"[city, center, grand, open, city, pembroke, pi...","[far, grand, hyatt, denver, convention, center]",0.690669,31.013759,1.001228
4828,693736,7503664,what is a parking brake,What is the Purpose of a Parking Brake. Posted...,1.0,"[purpose, park, brake, post, august, dave, joh...","[park, brake]",0.727260,33.868585,0.966780


In [9]:
train_data = train_data.sort_values(by=['queries'], ascending=False)
train_data


Unnamed: 0,qid,pid,queries,passage,relevancy,passage_cleaned,query_cleaned,co_similarity,bm25,word_mover_distance
1069,1081569,8801198,youngest actors nominated for oscar,"In 1998, Gabriel appeared on the soundtrack of...",0.0,"[gabriel, appear, soundtrack, babe, pig, city,...","[youngest, actors, nominate, oscar]",0.625069,27.390679,1.129081
4295,1080939,2137948,xfinity business customer service number,xfinity® Customer Service Phone Number: 1-800-...,1.0,"[xfinity, customer, service, phone, number, sp...","[xfinity, business, customer, service, number]",0.850198,60.572041,0.610794
31,1080939,3086536,xfinity business customer service number,Leave your name and number and we'll get back ...,0.0,"[leave, name, number, get, back, within, minut...","[xfinity, business, customer, service, number]",0.617578,56.289721,0.947025
1501,1080537,1737709,wright city mo population,"63361 is the zip code for Montgomery City, MO ...",0.0,"[zip, code, montgomery, city, mo, also, locati...","[wright, city, mo, population]",0.687016,31.745812,0.831660
2225,1080537,7964183,wright city mo population,Most / Least Educated Cities in MO. As of 2010...,1.0,"[least, educate, cities, mo, total, population...","[wright, city, mo, population]",0.633695,57.677868,0.944734
...,...,...,...,...,...,...,...,...,...,...
2244,1093507,7741722,7) how does human rights law differ from the l...,Traditional international law is a law of powe...,1.0,"[traditional, international, law, law, power, ...","[human, right, law, differ, law, war]",0.880920,115.298056,0.832027
2206,1093507,7301660,7) how does human rights law differ from the l...,"Today, freedom of speech, or the freedom of ex...",0.0,"[today, freedom, speech, freedom, expression, ...","[human, right, law, differ, law, war]",0.821726,86.749005,0.840228
1071,1215,2888358,3 levels of government in canada and their res...,1 They ruled that the Congress could not ban s...,0.0,"[rule, congress, could, ban, slavery, territor...","[level, government, canada, responsibilities]",0.771057,29.399721,1.035060
640,1215,5754270,3 levels of government in canada and their res...,Conceives of federalism as a marble cake in wh...,0.0,"[conceive, federalism, marble, cake, level, go...","[level, government, canada, responsibilities]",0.682814,32.846153,1.001763


In [10]:
lambdaMart_x_train = np.zeros((train_data.shape[0], num_of_features))

In [11]:
lambdaMart_x_train

array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       ...,
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]])

In [12]:
lambdaMart_x_train = train_data[['co_similarity', 'bm25', 'word_mover_distance']].values
lambdaMart_x_train

array([[ 0.62506872, 27.39067855,  1.12908076],
       [ 0.85019787, 60.57204123,  0.6107939 ],
       [ 0.61757767, 56.2897212 ,  0.94702547],
       ...,
       [ 0.77105719, 29.39972084,  1.03506017],
       [ 0.68281389, 32.84615267,  1.00176317],
       [ 0.70663054, 31.70384234,  0.911557  ]])

In [13]:
lambdaMart_y_train = train_data.relevancy.values

In [14]:
query_count_dict_train = train_data['queries'].value_counts().to_dict()

In [15]:
query_count_dict_train

{'definition of tour of duty in the military': 5,
 'battle of gettysburg how many troops': 5,
 'what are signs of anxiety in your chest': 5,
 'monthly average hours daylight los angeles': 5,
 "which amendment prohibited states from violating citizens' civil rights?": 5,
 'most expensive hotels in new york city': 5,
 'cad heart related': 5,
 'what is badger inc in penn yan ny': 5,
 'what do dark chocolate do for body': 5,
 'what county is wauseon oh': 5,
 'what are oxidizing and reducing agents reddit': 5,
 'what is psychic remote viewing': 5,
 'who directed the giver': 4,
 'difference between statutory and regulatory requirements': 4,
 'what is a phantom color poodle': 4,
 'percentage of white people in the us on welfare': 4,
 'effects of hyperthyroidism': 4,
 'what makes a community healthy': 4,
 'what does rough trade mean': 4,
 'who was involved in the locarno treaties': 4,
 'what is at the base of the aorta': 4,
 'fec independent expenditure definition': 4,
 'define taxonomical': 4

In [16]:
len(query_count_dict_train)

3154

In [17]:
dgroup_train = []
dgroup_train

[]

In [18]:
checker = []
for idx, row in train_data.iterrows():
    query = row['queries']
    if query in checker:
        continue
    checker.append(query)
    group_entry_count = query_count_dict_train[query]
    dgroup_train.append(group_entry_count)

In [19]:
# dgroup_train = np.array(dgroup_train)

In [20]:
dgroup_train

[1,
 2,
 2,
 1,
 2,
 1,
 2,
 2,
 3,
 1,
 1,
 1,
 2,
 2,
 1,
 2,
 1,
 2,
 1,
 3,
 2,
 2,
 2,
 3,
 2,
 1,
 1,
 1,
 1,
 1,
 2,
 2,
 1,
 2,
 1,
 4,
 1,
 2,
 1,
 1,
 1,
 3,
 1,
 1,
 1,
 1,
 1,
 2,
 3,
 1,
 2,
 3,
 1,
 1,
 2,
 1,
 3,
 2,
 2,
 1,
 1,
 1,
 1,
 2,
 1,
 1,
 2,
 2,
 1,
 1,
 4,
 1,
 1,
 3,
 1,
 1,
 1,
 1,
 1,
 1,
 2,
 1,
 1,
 1,
 1,
 2,
 1,
 1,
 2,
 1,
 1,
 1,
 2,
 1,
 1,
 1,
 1,
 1,
 2,
 2,
 1,
 1,
 2,
 3,
 1,
 1,
 2,
 3,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 2,
 1,
 3,
 1,
 1,
 1,
 2,
 2,
 1,
 3,
 1,
 1,
 2,
 1,
 1,
 1,
 1,
 1,
 1,
 2,
 2,
 1,
 1,
 1,
 1,
 1,
 4,
 3,
 2,
 1,
 3,
 2,
 3,
 2,
 1,
 1,
 1,
 2,
 1,
 1,
 1,
 2,
 1,
 1,
 1,
 1,
 1,
 2,
 2,
 3,
 1,
 3,
 1,
 1,
 4,
 1,
 1,
 1,
 1,
 1,
 3,
 2,
 2,
 1,
 1,
 1,
 1,
 2,
 1,
 1,
 1,
 1,
 1,
 2,
 1,
 2,
 1,
 2,
 1,
 2,
 1,
 1,
 2,
 2,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 4,
 1,
 1,
 1,
 2,
 1,
 1,
 3,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 2,
 1,
 2,
 1,
 3,
 1,
 1,
 1,
 2,
 3,
 1,
 2,
 1,
 1,
 5,
 1,
 1,
 1,
 1,


In [21]:
np.sum(dgroup_train)

4830

## validation data

In [22]:
validation_data = validation_data.sort_values(by=['queries'], ascending=False)
validation_data

Unnamed: 0,qid,pid,queries,passage,relevancy,passage_cleaned,query_cleaned,co_similarity,bm25,word_mover_distance
898,1078446,7081137,wine cabinets definition,How to Remove an Odor From Wooden Cabinets in ...,0.0,"[remove, odor, wooden, cabinets, kitchen, deod...","[wine, cabinets, definition]",0.536559,11.240729,1.156885
103,1078446,7865137,wine cabinets definition,wine rack noun [C]. › a ​wooden or ​metal ​fra...,1.0,"[wine, rack, noun, c, wooden, metal, frame, us...","[wine, cabinets, definition]",0.659124,21.287782,1.034665
94,1101868,7500425,willie weeks net worth,Willie Geist Net Worth is $2 Million. Willie G...,1.0,"[willie, geist, net, worth, million, willie, g...","[willie, weeks, net, worth]",0.704224,33.245087,0.897249
997,1101870,7753562,willie mays worth,The Groundhog Day concept became popular in Ca...,0.0,"[groundhog, day, concept, become, popular, can...","[willie, mays, worth]",0.355603,11.357400,1.193349
249,1101870,7443738,willie mays worth,Ward said Pixar wanted to capture what NASCAR ...,0.0,"[ward, say, pixar, want, capture, nascar, race...","[willie, mays, worth]",0.288284,6.885770,1.266545
...,...,...,...,...,...,...,...,...,...,...
603,4696,4261568,Is the Louisiana sales tax 4.75,"The combined sales tax rate for Ithaca, NY is ...",0.0,"[combine, sales, tax, rate, ithaca, ny, total,...","[louisiana, sales, tax]",0.833281,27.630300,0.790931
589,2962,7222545,Explain what a bone scan is and what it is use...,Re: Bone Cancer Signs and PET. PET scan does n...,0.0,"[bone, cancer, sign, pet, pet, scan, confirm, ...","[explain, bone, scan, use]",0.802917,21.287782,0.950971
1212,2962,5618754,Explain what a bone scan is and what it is use...,Although a bone scan is very sensitive to abno...,0.0,"[although, bone, scan, sensitive, abnormalitie...","[explain, bone, scan, use]",0.846757,23.941265,0.878039
852,2962,658625,Explain what a bone scan is and what it is use...,Why It Is Done. A bone scan is done to: 1 Fin...,1.0,"[do, bone, scan, do, find, bone, cancer, deter...","[explain, bone, scan, use]",0.845000,23.113599,0.940562


In [25]:
lambdaMart_x_val = np.zeros((validation_data.shape[0], num_of_features))

In [26]:
lambdaMart_x_val = validation_data[['co_similarity', 'bm25', 'word_mover_distance']].values
lambdaMart_x_val

array([[ 0.53655854, 11.24072902,  1.15688473],
       [ 0.65912417, 21.28778248,  1.03466535],
       [ 0.70422369, 33.24508737,  0.89724853],
       ...,
       [ 0.84675682, 23.94126531,  0.87803887],
       [ 0.84500048, 23.11359862,  0.94056153],
       [ 0.77077747, 14.89642112,  1.06300985]])

In [27]:
lambdaMart_y_val = validation_data.relevancy.values

In [28]:
query_count_dict_val = validation_data['queries'].value_counts().to_dict()

In [29]:
dgroup_val = []
dgroup_val

[]

In [30]:
checker = []
for idx, row in validation_data.iterrows():
    query = row['queries']
    if query in checker:
        continue
    checker.append(query)
    group_entry_count = query_count_dict_val[query]
    dgroup_val.append(group_entry_count)

In [31]:
# dgroup_val = np.array(dgroup_val)

In [32]:
dgroup_val

[2,
 1,
 2,
 2,
 1,
 1,
 1,
 1,
 2,
 1,
 1,
 3,
 2,
 1,
 1,
 1,
 2,
 3,
 1,
 1,
 1,
 2,
 1,
 2,
 1,
 2,
 1,
 2,
 1,
 2,
 2,
 1,
 1,
 1,
 2,
 1,
 2,
 2,
 1,
 2,
 1,
 1,
 1,
 2,
 2,
 1,
 1,
 1,
 2,
 1,
 1,
 1,
 1,
 3,
 1,
 1,
 2,
 2,
 1,
 1,
 1,
 2,
 1,
 1,
 4,
 1,
 1,
 2,
 2,
 2,
 1,
 2,
 2,
 1,
 1,
 2,
 2,
 1,
 2,
 1,
 1,
 2,
 1,
 1,
 2,
 1,
 1,
 1,
 1,
 2,
 1,
 2,
 2,
 1,
 2,
 1,
 4,
 2,
 2,
 1,
 3,
 1,
 1,
 1,
 1,
 3,
 1,
 1,
 2,
 2,
 1,
 4,
 2,
 1,
 2,
 1,
 2,
 1,
 2,
 1,
 1,
 1,
 1,
 2,
 2,
 2,
 1,
 1,
 2,
 1,
 2,
 2,
 1,
 1,
 3,
 4,
 1,
 1,
 1,
 1,
 1,
 2,
 1,
 1,
 2,
 2,
 1,
 1,
 1,
 1,
 1,
 1,
 3,
 1,
 1,
 1,
 1,
 2,
 1,
 2,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 2,
 1,
 1,
 1,
 2,
 1,
 1,
 2,
 2,
 1,
 2,
 1,
 2,
 2,
 2,
 1,
 1,
 3,
 2,
 1,
 1,
 1,
 3,
 1,
 2,
 1,
 1,
 1,
 5,
 3,
 2,
 1,
 1,
 1,
 1,
 2,
 1,
 1,
 3,
 2,
 1,
 2,
 1,
 2,
 2,
 1,
 1,
 1,
 1,
 1,
 3,
 1,
 2,
 2,
 1,
 1,
 1,
 1,
 1,
 1,
 2,
 1,
 2,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 3,
 1,
 1,
 1,
 2,
 1,
 3,
 2,
 2,


## LambdaMart Implementation

In [33]:
# xgb_rank_params1 ={    
#     'booster' : 'gbtree',
#     'eta': 0.1,
#     'gamma' : 1.0 ,
#     'min_child_weight' : 0.1,
#     'objective' : 'rank:pairwise',
#     'eval_metric' : 'merror',
#     'max_depth' : 6,
#     'num_boost_round':10,
#     'save_period' : 0 
# }

# params = {
#     'bst:max_depth':2, 
#     'bst:eta':1, 'silent':1, 
#     'objective':'rank:pairwise',
#     'nthread':4,
#     'eval_metric':'ndcg'
# }
  



In [34]:
#generate training dataset
# dtrain= lambdaMart_x_train
# dtarget= lambdaMart_y_train
# dgroup_train = dgroup_train
# dtrain= lambdaMart_x_train[:2]
# dtarget= lambdaMart_y_train[:2]
# dgroup_train = dgroup_train[:2]
# dtrain.shape
# dtarget.shape

In [35]:
lambdaMart_x_train

array([[ 0.62506872, 27.39067855,  1.12908076],
       [ 0.85019787, 60.57204123,  0.6107939 ],
       [ 0.61757767, 56.2897212 ,  0.94702547],
       ...,
       [ 0.77105719, 29.39972084,  1.03506017],
       [ 0.68281389, 32.84615267,  1.00176317],
       [ 0.70663054, 31.70384234,  0.911557  ]])

In [36]:
dgroup_train

[1,
 2,
 2,
 1,
 2,
 1,
 2,
 2,
 3,
 1,
 1,
 1,
 2,
 2,
 1,
 2,
 1,
 2,
 1,
 3,
 2,
 2,
 2,
 3,
 2,
 1,
 1,
 1,
 1,
 1,
 2,
 2,
 1,
 2,
 1,
 4,
 1,
 2,
 1,
 1,
 1,
 3,
 1,
 1,
 1,
 1,
 1,
 2,
 3,
 1,
 2,
 3,
 1,
 1,
 2,
 1,
 3,
 2,
 2,
 1,
 1,
 1,
 1,
 2,
 1,
 1,
 2,
 2,
 1,
 1,
 4,
 1,
 1,
 3,
 1,
 1,
 1,
 1,
 1,
 1,
 2,
 1,
 1,
 1,
 1,
 2,
 1,
 1,
 2,
 1,
 1,
 1,
 2,
 1,
 1,
 1,
 1,
 1,
 2,
 2,
 1,
 1,
 2,
 3,
 1,
 1,
 2,
 3,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 2,
 1,
 3,
 1,
 1,
 1,
 2,
 2,
 1,
 3,
 1,
 1,
 2,
 1,
 1,
 1,
 1,
 1,
 1,
 2,
 2,
 1,
 1,
 1,
 1,
 1,
 4,
 3,
 2,
 1,
 3,
 2,
 3,
 2,
 1,
 1,
 1,
 2,
 1,
 1,
 1,
 2,
 1,
 1,
 1,
 1,
 1,
 2,
 2,
 3,
 1,
 3,
 1,
 1,
 4,
 1,
 1,
 1,
 1,
 1,
 3,
 2,
 2,
 1,
 1,
 1,
 1,
 2,
 1,
 1,
 1,
 1,
 1,
 2,
 1,
 2,
 1,
 2,
 1,
 2,
 1,
 1,
 2,
 2,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 4,
 1,
 1,
 1,
 2,
 1,
 1,
 3,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 2,
 1,
 2,
 1,
 3,
 1,
 1,
 1,
 2,
 3,
 1,
 2,
 1,
 1,
 5,
 1,
 1,
 1,
 1,


In [37]:
# concate Train data, very import here !

train_dmatrix = DMatrix(lambdaMart_x_train, label = lambdaMart_y_train)
valid_dmatrix = DMatrix(lambdaMart_x_val, label = lambdaMart_y_val)
train_dmatrix.set_group(dgroup_train)
valid_dmatrix.set_group(dgroup_val)




In [38]:
params = {'objective': 'rank:ndcg', 'eta': 0.1, 'gamma': 1.0,
          'min_child_weight': 0.1, 'max_depth': 6}

In [39]:
lambdaRank_model = train(params, train_dmatrix, num_boost_round=4,
                      evals=[(valid_dmatrix, 'validation')])


[0]	validation-map:0.93546
[1]	validation-map:0.93884
[2]	validation-map:0.93524
[3]	validation-map:0.93718


In [40]:
# pred = lambdaRank_model.predict(test_dmatrix)

In [41]:
# # generate eval data -> validation data
# dtrain_eval= lambdaMart_x_val   
# xgbTrain_eval = DMatrix(lambdaMart_x_val, label = lambdaMart_y_val)

# evallist  = [(xgbTrain,'train'),(xgbTrain_eval, 'eval')]


In [182]:
# # train model
# # xgb_rank_params1加上 evals 这个参数会报错，还没找到原因
# # rankModel = train(xgb_rank_params1,xgbTrain,num_boost_round=10)
# rankModel = train(params,xgbTrain,num_boost_round=20,evals=evallist)



[0]	train-ndcg:1	eval-ndcg:1
[1]	train-ndcg:1	eval-ndcg:1
[2]	train-ndcg:1	eval-ndcg:1
[3]	train-ndcg:1	eval-ndcg:1
[4]	train-ndcg:1	eval-ndcg:1
[5]	train-ndcg:1	eval-ndcg:1
[6]	train-ndcg:1	eval-ndcg:1
[7]	train-ndcg:1	eval-ndcg:1
[8]	train-ndcg:1	eval-ndcg:1
[9]	train-ndcg:1	eval-ndcg:1
[10]	train-ndcg:1	eval-ndcg:1
[11]	train-ndcg:1	eval-ndcg:1
[12]	train-ndcg:1	eval-ndcg:1
[13]	train-ndcg:1	eval-ndcg:1
[14]	train-ndcg:1	eval-ndcg:1
[15]	train-ndcg:1	eval-ndcg:1
[16]	train-ndcg:1	eval-ndcg:1
[17]	train-ndcg:1	eval-ndcg:1
[18]	train-ndcg:1	eval-ndcg:1
[19]	train-ndcg:1	eval-ndcg:1


In [189]:
# #test dataset
# dtest= lambdaMart_x_val  
# dtestgroup=dgroup_val
# xgbTest = DMatrix(dtest)
# xgbTest.set_group(dgroup_train)


In [190]:
# print(rankModel.predict(xgbTest))

[0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5
 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5
 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5
 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5
 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5
 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5
 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5
 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5
 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5
 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5
 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5
 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5
 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5
 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.

## Subtask 4

In [42]:
import tensorflow as tf
from tensorflow.keras import layers, activations, losses, Model, Input
from tensorflow.nn import leaky_relu
import numpy as np
from itertools import combinations
from tensorflow.keras.utils import plot_model, Progbar
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# model architecture
class RankNet(Model):
    def __init__(self):
        super().__init__()
        self.dense = [layers.Dense(16, activation=leaky_relu), layers.Dense(8, activation=leaky_relu)]
        self.o = layers.Dense(1, activation='linear')
        self.oi_minus_oj = layers.Subtract()
    
    def call(self, inputs):
        xi, xj = inputs
        densei = self.dense[0](xi)
        densej = self.dense[0](xj)
        for dense in self.dense[1:]:
            densei = dense(densei)
            densej = dense(densej)
        oi = self.o(densei)
        oj= self.o(densej)
        oij = self.oi_minus_oj([oi, oj])
        output = layers.Activation('sigmoid')(oij)
        return output
    
    def build_graph(self):
        x = [Input(shape=(10)), Input(shape=(10))]
        return Model(inputs=x, outputs=self.call(x))

# visualize model architecture
# plot_model(RankNet().build_graph(), show_shapes=False)

In [43]:
train_data

Unnamed: 0,qid,pid,queries,passage,relevancy,passage_cleaned,query_cleaned,co_similarity,bm25,word_mover_distance
1069,1081569,8801198,youngest actors nominated for oscar,"In 1998, Gabriel appeared on the soundtrack of...",0.0,"[gabriel, appear, soundtrack, babe, pig, city,...","[youngest, actors, nominate, oscar]",0.625069,27.390679,1.129081
4295,1080939,2137948,xfinity business customer service number,xfinity® Customer Service Phone Number: 1-800-...,1.0,"[xfinity, customer, service, phone, number, sp...","[xfinity, business, customer, service, number]",0.850198,60.572041,0.610794
31,1080939,3086536,xfinity business customer service number,Leave your name and number and we'll get back ...,0.0,"[leave, name, number, get, back, within, minut...","[xfinity, business, customer, service, number]",0.617578,56.289721,0.947025
1501,1080537,1737709,wright city mo population,"63361 is the zip code for Montgomery City, MO ...",0.0,"[zip, code, montgomery, city, mo, also, locati...","[wright, city, mo, population]",0.687016,31.745812,0.831660
2225,1080537,7964183,wright city mo population,Most / Least Educated Cities in MO. As of 2010...,1.0,"[least, educate, cities, mo, total, population...","[wright, city, mo, population]",0.633695,57.677868,0.944734
...,...,...,...,...,...,...,...,...,...,...
2244,1093507,7741722,7) how does human rights law differ from the l...,Traditional international law is a law of powe...,1.0,"[traditional, international, law, law, power, ...","[human, right, law, differ, law, war]",0.880920,115.298056,0.832027
2206,1093507,7301660,7) how does human rights law differ from the l...,"Today, freedom of speech, or the freedom of ex...",0.0,"[today, freedom, speech, freedom, expression, ...","[human, right, law, differ, law, war]",0.821726,86.749005,0.840228
1071,1215,2888358,3 levels of government in canada and their res...,1 They ruled that the Congress could not ban s...,0.0,"[rule, congress, could, ban, slavery, territor...","[level, government, canada, responsibilities]",0.771057,29.399721,1.035060
640,1215,5754270,3 levels of government in canada and their res...,Conceives of federalism as a marble cake in wh...,0.0,"[conceive, federalism, marble, cake, level, go...","[level, government, canada, responsibilities]",0.682814,32.846153,1.001763


In [44]:
validation_data

Unnamed: 0,qid,pid,queries,passage,relevancy,passage_cleaned,query_cleaned,co_similarity,bm25,word_mover_distance
898,1078446,7081137,wine cabinets definition,How to Remove an Odor From Wooden Cabinets in ...,0.0,"[remove, odor, wooden, cabinets, kitchen, deod...","[wine, cabinets, definition]",0.536559,11.240729,1.156885
103,1078446,7865137,wine cabinets definition,wine rack noun [C]. › a ​wooden or ​metal ​fra...,1.0,"[wine, rack, noun, c, wooden, metal, frame, us...","[wine, cabinets, definition]",0.659124,21.287782,1.034665
94,1101868,7500425,willie weeks net worth,Willie Geist Net Worth is $2 Million. Willie G...,1.0,"[willie, geist, net, worth, million, willie, g...","[willie, weeks, net, worth]",0.704224,33.245087,0.897249
997,1101870,7753562,willie mays worth,The Groundhog Day concept became popular in Ca...,0.0,"[groundhog, day, concept, become, popular, can...","[willie, mays, worth]",0.355603,11.357400,1.193349
249,1101870,7443738,willie mays worth,Ward said Pixar wanted to capture what NASCAR ...,0.0,"[ward, say, pixar, want, capture, nascar, race...","[willie, mays, worth]",0.288284,6.885770,1.266545
...,...,...,...,...,...,...,...,...,...,...
603,4696,4261568,Is the Louisiana sales tax 4.75,"The combined sales tax rate for Ithaca, NY is ...",0.0,"[combine, sales, tax, rate, ithaca, ny, total,...","[louisiana, sales, tax]",0.833281,27.630300,0.790931
589,2962,7222545,Explain what a bone scan is and what it is use...,Re: Bone Cancer Signs and PET. PET scan does n...,0.0,"[bone, cancer, sign, pet, pet, scan, confirm, ...","[explain, bone, scan, use]",0.802917,21.287782,0.950971
1212,2962,5618754,Explain what a bone scan is and what it is use...,Although a bone scan is very sensitive to abno...,0.0,"[although, bone, scan, sensitive, abnormalitie...","[explain, bone, scan, use]",0.846757,23.941265,0.878039
852,2962,658625,Explain what a bone scan is and what it is use...,Why It Is Done. A bone scan is done to: 1 Fin...,1.0,"[do, bone, scan, do, find, bone, cancer, deter...","[explain, bone, scan, use]",0.845000,23.113599,0.940562


In [75]:
qids = train_data.qid.values
doc_features = train_data[['co_similarity', 'bm25', 'word_mover_distance']].values
doc_scores = train_data.relevancy.values

In [47]:
qids.shape

(4830,)

In [48]:
np.unique(qids).shape

(3154,)

In [58]:
# generate data
nb_query = 20
query = np.array([i+1 for i in range(nb_query) for x in range(int(np.ceil(np.abs(np.random.normal(0,scale=15))+2)))])
# doc_features = np.random.random((len(query), 10))
# doc_scores = np.random.randint(5, size=len(query)).astype(np.float32)
query

array([ 1,  1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,
        2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
        2,  2,  2,  2,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
        3,  3,  3,  3,  3,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,
        4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,
        4,  4,  4,  4,  4,  4,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
        5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  6,  6,  6,  6,  6,  6,  6,
        6,  6,  6,  6,  6,  6,  6,  7,  7,  7,  8,  8,  8,  8,  8,  8,  8,
        8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  9,  9,
        9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,
       10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
       10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 12, 12, 12, 13,
       13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14,
       14, 14, 15, 15, 15

In [59]:
np.where(query == 1)

(array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int64),)

In [78]:
# put data into pairs
xi = []
xj = []
pij = []
pair_id = []
pair_query_id = []
for qid in np.unique(qids):
    query_idx = np.where(qids == qid)[0]
#     print(query_idx)
    for pair_idx in combinations(query_idx, 2):
        pair_query_id.append(qid)
        
        pair_id.append(pair_idx)
        i = pair_idx[0]
        j = pair_idx[1]
        xi.append(doc_features[i])
        xj.append(doc_features[j])
        
        if doc_scores[i] == doc_scores[j]:
            _pij = 0.5
        elif doc_scores[i] > doc_scores[j]:
            _pij = 1
        else: 
            _pij = 0
        pij.append(_pij)
        
xi = np.array(xi)
xj = np.array(xj)
pij = np.array(pij)
pair_query_id = np.array(pair_query_id)


xi_train, xi_test, xj_train, xj_test, pij_train, pij_test, pair_id_train, pair_id_test = train_test_split(
    xi, xj, pij, pair_id, test_size=0.2)


In [79]:
# train model using compile and fit
ranknet = RankNet()
ranknet.compile(optimizer='adam', loss='binary_crossentropy')
history = ranknet.fit([xi_train, xj_train], pij_train, epochs=50, batch_size=1, validation_data=([xi_test, xj_test], pij_test))


Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Train on 1761 samples, validate on 441 samples
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50

KeyboardInterrupt: 