In [1]:
import collections
import io
import time
import math
import pickle
import os
import pyndri
import pyndri.compat
import logging
import sys
import numpy as np
import gensim
import pandas as pd

from copy import deepcopy
from sklearn.linear_model import LogisticRegression

In [2]:
def load_pickle(fpath):
    """Deserialize and return the object stored in the pickle file at ``fpath``.

    Only use this on trusted files: unpickling can execute arbitrary code.
    """
    with open(fpath, mode='rb') as handle:
        loaded = pickle.load(handle)
    return loaded
    
def save_pickle(obj, fpath):
    """Serialize ``obj`` with pickle and write it to ``fpath``, replacing any existing file."""
    payload = pickle.dumps(obj)
    with open(fpath, mode='wb') as handle:
        handle.write(payload)

In [3]:
# Inputs for DataLoader: a mapping from query ID to candidate internal document
# IDs (presumably the tf-idf top-1000 per query — TODO confirm), and the pyndri
# index used to translate internal IDs into external ones.
TFIDF_PICKLE_PATH = '../pickles/prepro_doc_col_q10_top1000_tfidf.pkl'
INDEX_PATH = '../index/'

tfidf_data = dict(load_pickle(TFIDF_PICKLE_PATH))
index = pyndri.Index(INDEX_PATH)

In [4]:
class DataLoader(object):
    """Assemble a learning-to-rank feature table from retrieval run files.

    Each row is a (query, document) pair taken from ``tfidf_data`` and is
    labelled ``"query_id~ext_doc_id"``. Every run file loaded contributes one
    score column. Rows missing any score are dropped, and a
    ``relevance_label`` column is then filled from a qrels file.
    """

    def __init__(self, tfidf_data: dict, index: 'pyndri.Index', models: list, rel_file: str):
        """Build the DataFrame and load all model scores and relevance labels.

        Args:
            tfidf_data: mapping of query ID to its candidate documents; iterating
                each value must yield internal document IDs.
            index: object whose ``document(int_doc_id)`` returns a tuple whose
                first element is the external document ID (a ``pyndri.Index``).
            models: names of the run files to load (``<run_dir>/<name>.run``).
            rel_file: path to the relevance-label (qrels) file.
        """
        self.tfidf_data = tfidf_data
        self.index = index
        self.df = None

        self.index_list = []
        self.full_index_list = []

        self.create_df()
        self.load_data(models_list=models, relevance_file=rel_file)

    def get_indices_lists(self):
        """Create the index lists based on query ID and document IDs.

        Fills ``index_list`` with ``"query~ext_doc"`` labels and
        ``full_index_list`` with ``"query~int_doc~ext_doc"`` strings, one entry
        per candidate pair, in ``tfidf_data`` iteration order.
        """
        self.query_ids = list(self.tfidf_data.keys())
        # Reset so that calling create_df() again does not duplicate rows.
        self.index_list = []
        self.full_index_list = []
        for query_id, int_doc_ids in self.tfidf_data.items():
            for int_doc_id in int_doc_ids:
                # Use the index this loader was built with, not a notebook global.
                ext_doc_id, _ = self.index.document(int_doc_id)
                self.index_list.append('~'.join((str(query_id), str(ext_doc_id))))
                self.full_index_list.append('~'.join((str(query_id), str(int_doc_id), str(ext_doc_id))))

    def create_df(self):
        """Create the initial DataFrame with string ID columns for each pair."""
        self.get_indices_lists()
        # maxsplit=2 keeps any '~' inside the external ID intact.
        parts = [entry.split('~', 2) for entry in self.full_index_list]
        self.df = pd.DataFrame(index=self.index_list)
        self.df['query_id'] = [p[0] for p in parts]
        self.df['int_doc_id'] = [p[1] for p in parts]
        self.df['ext_doc_id'] = [p[2] for p in parts]
        print("DataFrame created.")

    def load_data_from_file(self, model_name, run_dir='../retrievals'):
        """Load model scores from a run file.

        Each non-empty line must have six whitespace-separated fields:
        ``query_id <ignored> ext_doc_id <ignored> score run_tag``. For every
        (query, document) pair already in the DataFrame, the score is stored in
        a column named after ``run_tag``, which is presumably equal to
        ``model_name``. The column is (re)written; pairs absent from the file
        get NaN. Lines for unknown pairs are ignored.

        Args:
            model_name: name of the model; the file read is ``<run_dir>/<model_name>.run``.
            run_dir: directory containing the run files.
        """
        print("Loading data for model {}".format(model_name))
        retrieval_start_time = time.time()

        # Checking the full pair label also filters queries, whatever the
        # length of their IDs.
        known_pairs = set(self.df.index)
        scores = collections.defaultdict(dict)
        with open('{}/{}.run'.format(run_dir, model_name)) as file:
            for line in file:
                fields = line.split()
                if not fields:
                    continue
                query_id, _, ext_doc_id, __, score, model = fields
                idx = '~'.join((query_id, ext_doc_id))
                if idx in known_pairs:
                    scores[model][idx] = float(score)

        # Assign each score column in one step instead of cell by cell.
        for model, model_scores in scores.items():
            self.df[model] = [model_scores.get(idx, np.nan) for idx in self.df.index]

        print("Data loaded in {} seconds.".format(time.time() - retrieval_start_time))

    def load_relevance_label(self, file_path):
        """Load relevance labels from file into the ``relevance_label`` column.

        Each non-empty line must have four whitespace-separated fields:
        ``query_id <ignored> ext_doc_id relevance``. Pairs without a matching
        line are labelled 0. The column always exists afterwards, with int values.

        Args:
            file_path: path to the qrel_test file.
        """
        print("Loading relevance labels.")
        retrieval_start_time = time.time()

        known_pairs = set(self.df.index)
        labels = {}
        with open(file_path) as file:
            for line in file:
                fields = line.split()
                if not fields:
                    continue
                query_id, _, ext_doc_id, relevance = fields
                idx = '~'.join((query_id, ext_doc_id))
                if idx in known_pairs:
                    labels[idx] = int(relevance)

        # Plain column assignment; chained ``fillna(inplace=True)`` is unreliable
        # under pandas Copy-on-Write.
        self.df['relevance_label'] = [labels.get(idx, 0) for idx in self.df.index]
        print("Labels loaded in {} seconds.".format(time.time() - retrieval_start_time))

    def drop_rows_with_null(self):
        """Drop every row that contains at least one null value."""
        null_rows = self.df.isnull().any(axis=1)
        # Positional mask (safe with duplicate labels); copy to avoid
        # SettingWithCopy warnings when columns are added later.
        self.df = self.df.loc[~null_rows.to_numpy()].copy()

        print("{} rows dropped. DataFrame length:".format(int(null_rows.sum())), end="")
        print(self.data_length)

    def load_data(self, models_list, relevance_file):
        """Wrapper method to load all models scores and relevance labels.

        Rows lacking a score from any model are dropped before labels are added.

        Args:
            models_list: list of model names.
            relevance_file: path to the file with relevance labels.
        """
        for model in models_list:
            self.load_data_from_file(model)
        self.drop_rows_with_null()
        self.load_relevance_label(relevance_file)

    def data_has_nulls(self):
        """Return True if any cell of the DataFrame is null."""
        return bool(self.df.isnull().values.any())

    def column_has_nulls(self, col_name):
        """Check whether a column has any null values.

        Args:
            col_name: name of the column.
        """
        return bool(self.df[col_name].isnull().any())

    def count_null_values(self, col_name):
        """Retrieve the count of null values in a column.

        Args:
            col_name: name of the column.
        """
        return int(self.df[col_name].isnull().sum())

    def save_dataframe(self, fpath):
        """Save DataFrame object to file with pickle.

        Args:
            fpath: file path to save.
        """
        with open(fpath, 'wb') as file:
            pickle.dump(self.df, file)

    @property
    def data(self):
        """Retrieve DataFrame object, creating it first if needed."""
        if self.df is None:
            self.create_df()

        return self.df

    @property
    def data_length(self):
        """Retrieve DataFrame object length, creating it first if needed."""
        if self.df is None:
            self.create_df()

        return len(self.df)

In [5]:
# Run files (read from ../retrievals/<name>.run) whose scores become the
# features of the learning-to-rank table.
models = [
    'tfidf', 'LDA', 'LSI',
    'jm_lambda_0.1', 'dp_mu_500', 'ad_delta_0.9',
    'glm',
]
rel_file = '../ap_88_89/qrel_test'

data_loader = DataLoader(tfidf_data, index, models, rel_file)
data_loader.save_dataframe('../pickles/LTR_DataFrame.pkl')

DataFrame created.
Loading data for model tfidf
Data loaded in 5.246838092803955 seconds.
Loading data for model LDA
Data loaded in 5.35348916053772 seconds.
Loading data for model LSI
Data loaded in 6.03591513633728 seconds.
Loading data for model jm_lambda_0.1
Data loaded in 2.3770501613616943 seconds.
Loading data for model dp_mu_500
Data loaded in 3.2428719997406006 seconds.
Loading data for model ad_delta_0.9
Data loaded in 2.764694929122925 seconds.
Loading data for model glm
Data loaded in 5.0609190464019775 seconds.
5031 rows dropped. DataFrame length:4126
Loading relevance labels.
Labels loaded in 0.2678070068359375 seconds.


In [6]:
# Reload the feature table saved above, so the cells below can run without
# rebuilding it with DataLoader.
LTR_DATAFRAME_PATH = '../pickles/LTR_DataFrame.pkl'
df_data = load_pickle(LTR_DATAFRAME_PATH)

### Logistic Regression

In [8]:
# One feature column per retrieval model; the target is the qrel relevance label.
feature_columns = ['tfidf', 'LDA', 'LSI', 'jm_lambda_0.1', 'dp_mu_500', 'ad_delta_0.9', 'glm']
X_train = df_data[feature_columns]
y_train = df_data[['relevance_label']]

In [9]:
# Z-score each feature column. A constant column has std 0, which would give
# NaN/inf values that LogisticRegression rejects, so divide by 1 there instead.
feature_std = X_train.std().replace(0, 1)
normalized_X = (X_train - X_train.mean()) / feature_std

In [10]:
# Pin the solver: the recorded output was produced with scikit-learn's former
# default ('liblinear', one-vs-rest). Newer releases default to 'lbfgs', which
# would silently change results. Fix the seed for reproducibility.
log_reg = LogisticRegression(solver='liblinear', random_state=0)
log_reg.fit(normalized_X.values, y_train.values.ravel())

LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
          intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
          penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
          verbose=0, warm_start=False)

In [None]:
# query_doc_pairs = collections.defaultdict(list)
# doc_query_pairs = collections.defaultdict(list)

# print("Loading relevance indicators.")
# retrieval_start_time = time.time()

# with open('../ap_88_89/qrel_test') as file:
#     for line in file.readlines():
#         query_id, _, ext_doc_id, relevance = line.split()
        
#         query_doc_pairs[query_id].append(ext_doc_id)
#         doc_query_pairs[ext_doc_id].append(query_id)
    
# print("Data loaded in {} seconds.".format(time.time() - retrieval_start_time))

In [None]:
# int_to_ext_id = collections.defaultdict(int)
# ext_to_int_id = collections.defaultdict(int)

# for int_doc_id in range(index.document_base(), index.maximum_document()):
#     ext_doc_id, _ = index.document(int_doc_id)
#     int_to_ext_id[int_doc_id] = ext_doc_id
#     ext_to_int_id[ext_doc_id] = int_doc_id