In [1]:
import os
import pandas as pd
import numpy as np
from whoosh.fields import Schema, TEXT, ID
from whoosh.index import create_in, open_dir, exists_in
from whoosh.qparser import QueryParser, OrGroup
from whoosh.query import Variations
from whoosh.scoring import BM25F, TF_IDF
from whoosh.analysis import RegexTokenizer, LowercaseFilter, StopFilter, Filter
from whoosh.analysis import Filter

import nltk
nltk.download('wordnet')

from nltk.stem import WordNetLemmatizer

import editdistance

[nltk_data] Downloading package wordnet to /home/oberon/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [2]:
lemmatizer = WordNetLemmatizer()

class LemmatizeFilter(Filter):
    def __call__(self, tokens):
        for token in tokens:
            token.text = lemmatizer.lemmatize(token.text)
            yield token

In [3]:
wiki_df = pd.read_pickle("./data/wiki.pkl")
wiki_redirects_df = pd.read_pickle("./data/wiki_redirects.pkl")
questions_df = pd.read_pickle("./data/questions.pkl")

In [4]:
redirect_lookups = {}
for _, row in wiki_redirects_df.iterrows():
    if row.redirect_index in redirect_lookups:
        redirect_lookups[row.redirect_index].append(row.title)
    else:
        redirect_lookups[row.redirect_index] = [row.title]

In [11]:
class Watson:
    def __init__(self):
        self.Q = len(questions_df.index)
        self._analyzer = self._build_analyzer()
        self._index = self._build_index()
        self._parser = self._build_parser()

    def _build_analyzer(self):
        return RegexTokenizer() | LowercaseFilter() | StopFilter() | LemmatizeFilter()
    
    def _build_index(self):
        schema = Schema(titles=ID(stored=True), content=TEXT(analyzer=self._analyzer))
        if exists_in(".index"):
            ix = open_dir(".index")
        else:
            os.mkdir(".index")
            ix = create_in(".index", schema)
            writer = ix.writer()
            for i, row in wiki_df.iterrows():
                titles = [row.title]
                if i in redirect_lookups:
                    titles += redirect_lookups[i]
                writer.add_document(titles=titles, content=row.text)
            writer.commit()
        return ix
    
    def _build_parser(self):
        og = OrGroup.factory(0.9)
        return QueryParser("content", schema=self._index.schema, group=og)
    
    def _get_category_query(self, category):
        qp = QueryParser("content", schema=self._index.schema, termclass=Variations)
        return qp.parse(category)
    
    def search(self, category, question, scorer=BM25F):
        query = self._parser.parse(question)
        with self._index.searcher(weighting=scorer()) as searcher:
            # results = searcher.search(query, filter=self._get_category_query(category), limit=None)
            # if results.scored_length() == 0:
            results = searcher.search(query, limit=None)
            if results.scored_length() == 0:
                return None
            return [(r["titles"], r.rank+1) for r in results]

    @staticmethod
    def _get_rank(results, answer):
        for answer_variant in answer.split("|"):
            for (doc_titles, rank) in results:
                if "Arlington Cemetery" in answer:
                    print(doc_titles)
                for doc_title in doc_titles:
                    if doc_title.lower() == answer_variant.lower():
                        return rank
        # if "Arlington Cemetery" in answer:
        #     return 0 # in redirects
        # if answer == "Panda|Giant panda":
        #     return 0 # in redirects
        # if "World Wide Fund" in answer:
        #     return 0 # in redirects
        return 0
        raise Exception(f"Document not found for: {answer}")
    
    # TODO use MRR for development, and precision at one for the final report
    def test(self, scorer=BM25F):
        mrr = 0.0
        for _, row in questions_df.iterrows():
            results = self.search(row.category, row.question, scorer)
            rank = Watson._get_rank(results, row.answer)
            if rank > 0:
                mrr += 1 / rank
        mrr /= self.Q
        return mrr

In [12]:
watson = Watson()
watson.test()

['Uss Arizona Memorial']
['Thomas W. Lawson (Ship)']
['Hindenburg Disaster Newsreel Footage']
['Margate Surfboat']
['Brig']
['Uss Grenadier (Ss-210)']
['American Submarine Nr-1']
['Uss Bonhomme Richard (Lhd-6)']
['George Henry Preble']
['Uss Cod (Ss-224)']
['Uscgc Mohawk (Wpg-78)']
['Barque']
['Uss Minneapolis–Saint Paul (Ssn-708)']
['Robert Ballard']
['Uss Sabine (1855)']
['Identity And Change']
['German Destroyer Lütjens (D185)']
['Jamestown (Ship)']
['Jibe']
['Frank Newcomb']
['Uss John F. Kennedy (Cv-67)']
['Italian Training Ship Amerigo Vespucci']
['Loch Ard Gorge']
['Uss Congress (1776)']
['Schooner']
['Full-Rigged Ship']
['Uss Memphis (Ssn-691)']
['Soviet Submarine K-3 Leninsky Komsomol']
['Honda Point Disaster']
['Santa María (Ship)']
['Richard Henry Dana, Jr.']
['Uss Annapolis (Ssn-760)']
['Edwin M. Shepard']
['Uss Cumberland (1842)']
['John Rodgers (Naval Officer, World War I)']
['French Frigate Surcouf (F711)']
['Tora! Tora! Tora!']
['Sailing Ship']
['Hms Endeavour']
['Edwin

0.3150642535704151

## Ideas for increasing score
- Add redirects to index
- Use a more specific analyzer
- Boost important terms in the query
- Filter documents with the category


In [35]:
wiki_df = pd.read_pickle("./data/wiki.pkl")

def build_index():
    schema = Schema(title=ID(stored=True), content=TEXT)
    if exists_in(".index"):
        ix = open_dir(".index")
    else:
        os.mkdir(".index")
        ix = create_in(".index", schema)
        writer = ix.writer()
        for _, row in wiki_df.iterrows():
            writer.add_document(title=row.title, content=row.text)
        # TODO add redirects?
        writer.commit()
    return ix


train_df = pd.read_pickle("./data/questions_train.pkl")
test_df = pd.read_pickle("./data/questions_test.pkl")
ix = build_index()

In [40]:
question = "The dominant paper in our nation's capital, it's among the top 10 U.S. papers in circulation"

og = OrGroup.factory(0.9)
parser = QueryParser("content", schema=ix.schema, group=og)
query = parser.parse(question)

query

with ix.searcher() as searcher:
    results = searcher.search(query, limit=20, terms=True)
    for hit in results:
        print(hit["title"])
        print(hit.matched_terms())
        print()

Virginia
[('content', b'10'), ('content', b'among'), ('content', b'top'), ('content', b'u.s'), ('content', b'capital'), ('content', b'our'), ('content', b'nation'), ('content', b'paper'), ('content', b'papers'), ('content', b'dominant'), ('content', b'circulation')]

Confederate States of America
[('content', b'10'), ('content', b'among'), ('content', b'top'), ('content', b'u.s'), ('content', b'capital'), ('content', b'our'), ('content', b'nation'), ('content', b'paper'), ('content', b'papers'), ('content', b'dominant'), ('content', b'circulation')]

Globalization
[('content', b'10'), ('content', b'among'), ('content', b'top'), ('content', b'u.s'), ('content', b'capital'), ('content', b'our'), ('content', b'nation'), ('content', b'paper'), ('content', b'papers'), ('content', b'dominant'), ('content', b'circulation')]

