## TODO
#### 0. Create the actual search engine
#### 1. Process document titles and query descriptions?
#### 2. Make solution case sensitive?
#### 3. Handle named entities?
#### 4. Computational improvements? (streaming parser, opt. cosine similarity, sparse matrices)
#### 5. Other text pre-processing
#### 6. Use OOP (particularly for inverted index, but other entities, too)
#### 7. Different b, k1 for BM25

In [11]:
from os import walk, path
from lxml import etree
import nltk
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('omw-1.4')
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import re
from tqdm import tqdm
import pickle
from collections import Counter
import numpy as np
from datetime import datetime
from abc import ABC, abstractmethod

[nltk_data] Downloading package stopwords to /home/atotev/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /home/atotev/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/atotev/nltk_data...
[nltk_data]   Unzipping corpora/omw-1.4.zip.


### Constants

In [12]:
collection_dir = '../COLLECTION'
topics_dir = '../topics'

### Create the inverted index

In [14]:
class TextProcessor:
    def __init__(self):
        self.stop_words = set(stopwords.words(InvertedIndex.LANGUAGE))
        self.lemmatizer = WordNetLemmatizer()
    
    def count_terms(self, text):
        if not text:
            return Counter()
        f = lambda word: InvertedIndex.TERM_FILTER_REGEX.match(word) and word not in self.stop_words
        lemmatized = map(self.lemmatizer.lemmatize, text.lower().split())
        return Counter(filter(f, lemmatized))

class FileLoader:
    ELMNT_DOCID = 'DOCID'
    ELMNT_TEXT = 'TEXT'
    ELMNT_QUERYID = 'QUERYID'
    ELMNT_TITLE = 'TITLE'
    
    def load_document(self, filepath):
        tree = etree.parse(filepath)
        return (tree.find(FileLoader.ELMNT_DOCID).text, tree.find(FileLoader.ELMNT_TEXT).text)

    def load_query_file(self, filepath):
        tree = etree.parse(filepath)
        return (tree.find(FileLoader.ELMNT_QUERYID).text, tree.find(FileLoader.ELMNT_TITLE).text)
                           
class Posting:
    def __init__(self, docid, count):
        self.docid = docid
        self.count = count

class InvertedIndex:
    LANGUAGE = 'english'
    TERM_FILTER_REGEX = re.compile('^[a-z][a-z_\\-]*[a-z]$') # length>1, no punctuation-only
    
    def __init__(self):
        self._document_count = 0
        self._documents_total_length = 0
        self._index = {}
        self._document_files = {}
        self._file_loader = FileLoader()
        self._test_proc = TextProcessor()
        
    def _add_postings(self, docid, text):
        tcounts = self._test_proc.count_terms(text)
        for t in tcounts:
            if t not in self._index:
                self._index[t] = []
            self._index[t].append(Posting(docid, tcounts[t]))
        return tcounts

    def add_files(self, document_dir):
        all_documents_length = 0
        basepath, _, files = next(walk(document_dir))
        for each in tqdm(files):
            filepath = path.join(document_dir, each)
            docid, text = self._file_loader.load_document(filepath)
            self._document_files[docid] = filepath
            term_counts = self._add_postings(docid, text)
            self._documents_total_length += sum(term_counts.values())
            self._document_count += 1
    
    def get_avg_document_length(self):
        return self._documents_total_length / self._document_count
    
    def get_documents_total_length(self):
        return self._documents_total_length
            
    def keys(self):
        return self._index.keys()
    
    def get_posting_list(self, term):
        return self._index[term]
    
    def get_document_file(self, docid):
        return self._document_files[docid]
    
    def get_document_count(self):
        return self._document_count
                
index = InvertedIndex()
index.add_files(collection_dir)

100%|██████████| 12208/12208 [06:31<00:00, 31.16it/s] 


In [19]:
class SearchResult:
    def __init__(self, docid):
        self.docid = docid
        self.terms = Counter()
        self.relevance = 0.
        self.custom_data = {}

class RankingStrategy(ABC):
    @abstractmethod
    def set_param(self, name, value):
        raise Exception('Abstract method call attempted')

    @abstractmethod
    def update_rank(self, search_result, query_terms, qt):
        raise Exception('Abstract method call attempted')

class Bm25Ranking(RankingStrategy):
    PARAM_B = 'bm25.b'
    PARAM_K1 = 'bm25.k1'
    
    def __init__(self, index):
        self._index = index
        self._config = { Bm25Ranking.PARAM_B: 0.75, Bm25Ranking.PARAM_K1: 1.25 }
    
    def set_param(self, name, value):
        self._config[name] = value
        
    def update_rank(self, search_result, query_terms, qt):
        if qt not in search_result.terms:
            return
        
        tf = search_result.terms[qt]
        N = self._index.get_document_count()
        n = len(self._index.get_posting_list(qt))
        dl = sum(search_result.terms.values())
        avdl = self._index.get_avg_document_length()
        k1 = self._config[Bm25Ranking.PARAM_K1]
        b = self._config[Bm25Ranking.PARAM_B]
        search_result.relevance += tf / ((k1 * (1 - b + (b * dl / avdl))) + tf) * np.log((N - n + 0.5) / (n + 0.5))
        
class LmJmsRanking(RankingStrategy):
    PARAM_LAMBDA = 'lmjms.lambda'
    
    def __init__(self, index):
        self._index = index
        self._config = { LmJmsRanking.PARAM_LAMBDA: 0.25 }
    
    def set_param(self, name, value):
        self._config[name] = value
        
    def update_rank(self, search_result, query_terms, qt):
        Cd = search_result.terms[qt]
        sumCd = sum(search_result.terms.values())
        CD = sum(p.count for p in self._index.get_posting_list(qt))
        sumCD = self._index.get_documents_total_length()
        lmbda = self._config[LmJmsRanking.PARAM_LAMBDA]
        search_result.relevance += np.log(((1 - lmbda) * Cd / sumCd) + (lmbda * CD / sumCD))

class Search:
    PARAM_ACTIVE = 'active'

    IR_VSM = 'vsm'
    IR_BM25 = 'bm25'
    IR_LM = 'lm'
    
    RESULT_LIST_SIZE = 1000
    
    def __init__(self, index):
        self._index = index
        self._irModels = {
            Search.IR_VSM: VsmRanking(self._index),
            Search.IR_BM25: Bm25Ranking(self._index),
            Search.IR_LM: LmJmsRanking(self._index)
        }
        self._ACTIVE_RANKING = self._irModels[Search.IR_VSM]
        self._file_loader = FileLoader()
        self._text_proc = TextProcessor()
        
    def _count_terms(self, docid):
        filepath = self._index.get_document_file(docid)
        _, text = self._file_loader.load_document(filepath)
        return self._text_proc.count_terms(text)
    
    def _remove_unkown(self, term_counts):
        return Counter({x: count for x, count in term_counts.items() if x in self._index.keys()})
    
    def execute(self, query_text):
        search_results = {}
        query_terms = self._text_proc.count_terms(query_text)
        query_terms = self._remove_unkown(query_terms)
        for qt in query_terms:
            posting_list = self._index.get_posting_list(qt)
            for p in posting_list:
                if p.docid not in search_results:
                    sr = SearchResult(p.docid)
                    sr.terms = self._count_terms(p.docid)
                    search_results[p.docid] = sr
        for qt in query_terms:
            for docid in search_results:
                self._ACTIVE_RANKING.update_rank(search_results[docid], query_terms, qt)
        result = list(search_results.values())
        result.sort(reverse=True, key=lambda sr: sr.relevance)
        return result[:Search.RESULT_LIST_SIZE]

    def configure(self, param_name, param_value):
        if Search.PARAM_ACTIVE!=param_name:
            self._ACTIVE_RANKING.set_param(param_name, param_value)
        elif param_value in self._irModels:
            self._ACTIVE_RANKING = self._irModels[param_value]
        else:
            raise Exception('Unrecognized ranking: %s' % (param_value))
        return self

search = Search(index).configure(Search.PARAM_ACTIVE, Search.IR_LM)


class DocRank:
    def __init__(self, qieryid, search_result):
        self.queryid = queryid
        self.sr = search_result
        
    def to_qrel(self):
        return '%s Q0 %s rank %.6f STANDARD\n' % (self.queryid, self.sr.docid, self.sr.relevance)

queryids = set()
ranks = []
_, _, topic_files = next(walk(topics_dir))
test_topic_files = np.array(topic_files)[:5]
file_loader = FileLoader()
for each in tqdm(test_topic_files):
    queryid, query_text = file_loader.load_query_file(path.join(topics_dir, each))
    queryids.add(queryid)
    print('Processing topic %s: %s' % (queryid, query_text))
    ranks.extend(map(lambda sr: DocRank(queryid, sr), search.execute(query_text)))
        
    
def results_filepath():
    return './results_%s.txt' % (datetime.now().strftime('%d-%m-%Y_%H-%M-%S'))
results_file = results_filepath()
with open(results_file, 'w') as fp:
    fp.writelines(r.to_qrel() for r in ranks)

qrels_file = '%s.qrels' % (results_file)
with open(qrels_file, 'w') as fp:
    with open('../test_qrels.txt') as qrelf:
        fp.writelines(line for line in qrelf if any(qid in line for qid in queryids))

!{'../trec_eval-9.0.7/trec_eval -m ndcg -m map -m P %s %s' % (qrels_file, results_file)}

  0%|          | 0/5 [00:00<?, ?it/s]

Processing topic 10.2452/141-AH:  Letter Bomb for Kiesbauer


 20%|██        | 1/5 [00:13<00:55, 13.82s/it]

Processing topic 10.2452/142-AH:  Christo wraps German Reichstag


 40%|████      | 2/5 [00:21<00:30, 10.09s/it]

Processing topic 10.2452/143-AH:  Women ' s Conference Beijing


 60%|██████    | 3/5 [00:50<00:37, 18.62s/it]

Processing topic 10.2452/144-AH:  Sierra_Leone Rebellion and Diamonds


 80%|████████  | 4/5 [00:53<00:12, 12.65s/it]

Processing topic 10.2452/145-AH:  Japanese Rice Imports


100%|██████████| 5/5 [01:06<00:00, 13.37s/it]

map                   	all	0.4982
P_5                   	all	0.3200
P_10                  	all	0.2200
P_15                  	all	0.2000
P_20                  	all	0.1700
P_30                  	all	0.1133
P_100                 	all	0.0480
P_200                 	all	0.0290
P_500                 	all	0.0120
P_1000                	all	0.0062
ndcg                  	all	0.5653



