This notebook was automatically generated with our custom script (mir.scripts.ipynb_compiler)


# Multimedia Information Retrieval - Project

This project was developed by "The Karate Kid" team:

- [Ettore Ricci](https://github.com/Etto48)
- [Paolo Palumbo](https://github.com/paolpal)
- [Zahra Omrani](https://github.com/zahra-omrani)
- [Erni Deliallisi](https://github.com/erni-de)

The whole codebase can be found on [GitHub](https://github.com/Etto48/MIRProject).



In [1]:
# install dependencies
%pip install pandas tqdm iprogress ipywidgets unidecode nltk more_itertools python-terrier torch transformers psutil
!git clone https://github.com/facebookresearch/contriever
!echo -e '[project]\nname = "src"\nversion = "0.1.0"\ndescription = "contriever"\ndependencies = ["beir", "torch", "transformers",]\n\n[project.license]\nfile = "LICENSE"\n\n[tool.setuptools.package-dir]\nsrc = "src"' > contriever/pyproject.toml
!pip install -e contriever/

fatal: destination path 'contriever' already exists and is not an empty directory.
Obtaining file:///content/contriever
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: src
  Building editable for src (pyproject.toml) ... [?25l[?25hdone
  Created wheel for src: filename=src-0.1.0-0.editable-py3-none-any.whl size=15133 sha256=189c33f886c2633edbfb172bc0647b0490b279b78f42e0cb141e78fa9a370252
  Stored in directory: /tmp/pip-ephem-wheel-cache-cv06vy6c/wheels/56/bb/cf/c177ac9deab8a3be8a77a89f9eeb27836d32bfe2c2e961775a
Successfully built src
Installing collected packages: src
  Attempting uninstall: src
    Found existing installation: src 0.1.0
    Uninstalling src-0.1.0:
      Successfully uninstalled src-0.1.0
Successfully installed src-0.1.0


In [2]:
# define __file__ and set env variable
import os
__file__ = os.path.abspath('colab.ipynb')
os.environ['MIR_NOTEBOOK'] = __file__


In [3]:
#%% === mir ===

import os
import pandas as pd
pd.options.mode.copy_on_write = True

if os.getenv("COLAB_RELEASE_TAG"):
    COLAB = True
else:
    COLAB = False

if COLAB or os.getenv("MIR_NOTEBOOK") is not None:
    PROJECT_DIR = os.path.abspath("./")
else:
    PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATA_DIR = os.path.join(PROJECT_DIR, 'data')

if not os.path.exists(DATA_DIR):
    os.mkdir(DATA_DIR)


In [4]:
#%% === mir.ir.document_contents ===



class DocumentContents:
    def __init__(self, author: str, title: str, body: str, **kwargs):
        self.author = author
        self.title = title
        self.body = body
        self.__dict__.update(kwargs)

    def add_field(self, field: str, value: str):
        self.__dict__[field] = value

    def set_score(self, score: float):
        self.score = score


In [5]:
#%% === mir.ir.token_ir ===

from dataclasses import dataclass
from enum import Enum


class TokenLocation(Enum) :
    QUERY = 0
    AUTHOR = 1
    TITLE = 2
    BODY = 3

@dataclass
class Token:
    text: str
    location: TokenLocation


In [6]:
#%% === mir.ir.tokenizer ===

from abc import abstractmethod
from typing import Protocol

# from mir.ir.document_contents import DocumentContents
# from mir.ir.token_ir import Token


class Tokenizer(Protocol):
    @abstractmethod
    def tokenize_query(self, query: str) -> list[Token]:
        """
        Tokenize a query.

        # Parameters
        - query (str): The query to tokenize.

        # Returns
        - list[Token]: The tokens of the query.
        """
    @abstractmethod
    def tokenize_document(self, doc: DocumentContents) -> list[Token]:
        """
        Tokenize a document.

        # Parameters
        - doc (DocumentContents): The document to tokenize.

        # Returns
        - list[Token]: The tokens of the document.
        """


In [7]:
#%% === mir.ir.document_info ===

# from mir.ir.document_contents import DocumentContents
# from mir.ir.token_ir import TokenLocation
# from mir.ir.tokenizer import Tokenizer


class DocumentInfo:
    def __init__(self, id: int, lengths: list[int]):
        assert len(lengths) == 3, "Lengths must have 3 elements, [author, title, body]"
        self.id = id
        self.lengths = lengths

    @staticmethod
    def from_document_contents(id: int, doc: DocumentContents, tokenizer: Tokenizer) -> "DocumentInfo":
        tokens = tokenizer.tokenize_document(doc)
        tokens_for_field = [0,0,0]
        for token in tokens:
            match token.location:
                case TokenLocation.AUTHOR:
                    field_offset = 0
                case TokenLocation.TITLE:
                    field_offset = 1
                case TokenLocation.BODY:
                    field_offset = 2
                case _:
                    raise ValueError(f"Invalid token location {token.location}")
            tokens_for_field[field_offset] += 1
        return DocumentInfo(id, tokens_for_field)


In [8]:
#%% === mir.ir.posting ===

from typing import Optional

class Posting:
    def __init__(self, doc_id: int, term_id: int, occurrences: Optional[dict[str, int]] = None):
        self.term_id = term_id
        self.doc_id = doc_id
        self.occurrences = occurrences if occurrences is not None else {"author": 0, "title": 0, "body": 0}

    def __repr__(self) -> str:
        return f"Posting(doc_id={self.doc_id}, term_id={self.term_id}, occurrences={self.occurrences})"


In [9]:
#%% === mir.ir.term ===

class Term:
    def __init__(self, term: str, id: int, **kwargs):
        self.term = term
        self.id = id
        self.info = kwargs


In [10]:
#%% === mir.ir.scoring_function ===

from typing import Any, Callable, Optional, Protocol

# from mir.ir.document_info import DocumentInfo
# from mir.ir.posting import Posting
# from mir.ir.term import Term


class ScoringFunction(Protocol):
    batched_call: Optional[Callable[["ScoringFunction",list[str],str], list[float]]] = None
    def __call__(self, document_info: DocumentInfo, postings: list[Posting], query: list[Term], **kwargs: dict[str, Any]) -> float:
        """
        Score a document based on the postings and the query.

        # Parameters
        - document_info (DocumentInfo): The document info relative to the document to score.
        - postings (list[Posting]): The postings related to the document and the query.
        - query (list[Term]): The query terms.
        - **kwargs (dict[str, Any]): Additional arguments for the scoring function.

        # Returns
        - float: The score of the document.
        """


In [11]:
#%% === mir.ir.impls.bm25f_scoring ===

from typing import List, Dict
# from mir.ir.document_info import DocumentInfo
# from mir.ir.posting import Posting
# from mir.ir.scoring_function import ScoringFunction
# from mir.ir.term import Term
import math


class BM25FScoringFunction(ScoringFunction):
    def __init__(self, k1: float = 1.5, b: float = 0.75, field_weights: Dict[str, float] = None):
        self.k1 = k1
        self.b = b
        self.field_weights = field_weights if field_weights is not None else {'title': 2.0, 'body': 1.0, 'author': 0.5}

    def _build_postings_dict(self, postings: List[Posting]) -> Dict[int, Posting]:
        return {posting.term_id: posting for posting in postings}

    def __call__(self, document: DocumentInfo, postings: List[Posting], query: List[Term], *, num_docs: int, avg_field_lengths: dict[str, int], **_) -> float:
        postings_dict = self._build_postings_dict(postings)
        score = 0.0
        for term in query:
            if term.id in postings_dict:
                score += self._rsv(term, document, num_docs, postings_dict, avg_field_lengths)
        return score

    def _rsv(self, term: Term, document: DocumentInfo, num_docs: int, postings_dict: dict[int, Posting], avg_field_lengths: dict[str, int]) -> float:
        tfd = self._wtf(term, document, postings_dict, avg_field_lengths)
        idf = math.log(num_docs / term.info['document_frequency'])

        if tfd > 0:
            return (tfd / (self.k1 + tfd)) * idf
        return 0.0

    def _wtf(self, term: Term, document: DocumentInfo, postings_dict: dict[int, Posting], avg_field_lengths: dict[str, int]) -> float:
        tfd = 0.0
        field_indices = {"author": 0, "title": 1, "body": 2}

        if term.id not in postings_dict:
            return 0.0

        posting = postings_dict[term.id]
        for field, weight in self.field_weights.items():
            tf = posting.occurrences.get(field, 0)
            if tf == 0:
                continue
            field_index = field_indices[field]
            avg_dlf = avg_field_lengths[field]
            bb = 1 - self.b + self.b * document.lengths[field_index] / avg_dlf
            tfd += weight * tf / bb

        return tfd


In [12]:
#%% === mir.ir.impls.count_scoring_function ===

# from mir.ir.scoring_function import ScoringFunction


class CountScoringFunction(ScoringFunction):
    def __call__(self, document, postings, query, **kwargs):
        return len(postings) / len(query)


In [13]:
#%% === mir.utils.sized_generator ===

from collections.abc import Generator, Sized
from typing import TypeVar, Generic

T = TypeVar('T')
P = TypeVar('P')
Q = TypeVar('Q')


class SizedGenerator(Generic[T, P, Q], Generator[T, P, Q], Sized):
    def __init__(self, generator: Generator[T, P, Q], length: int):
        self.generator = generator
        self.length = length

    def __iter__(self):
        return self.generator

    def __len__(self):
        return self.length

    def send(self, value):
        return self.generator.send(value)

    def throw(self, typ, val=None, tb=None):
        return self.generator.throw(typ, val, tb)


# Index

This interface is the component that actually holds the inverted index and exposes the methods to interact with it.



In [14]:
#%% === mir.ir.index ===

from abc import abstractmethod
from collections.abc import Generator
from typing import Any, Optional, Protocol
from tqdm.auto import tqdm

# from mir.ir.document_info import DocumentInfo
# from mir.ir.document_contents import DocumentContents
# from mir.ir.posting import Posting
# from mir.ir.term import Term
# from mir.ir.tokenizer import Tokenizer
# from mir.utils.sized_generator import SizedGenerator


class Index(Protocol):
    def get_global_info(self) -> dict[str, Any]:
        """
        Get global info from the index.

        # Returns
        - dict[str, int]: A dictionary with global info.
        """
        return {}

    @abstractmethod
    def get_postings(self, term_id: int) -> Generator[Posting, None, None]:
        """
        Get a generator of postings for a term_id.
        MUST be sorted by doc_id.

        # Parameters
        - term_id (int): The term_id.

        # Yields
        - Posting: A posting from the posting list related to the term_id.
        """

    @abstractmethod
    def get_document_info(self, doc_id: int) -> DocumentInfo:
        """
        Get document info from a doc_id.

        # Parameters
        - doc_id (int): The doc_id.

        # Returns
        - DocumentInfo: The document info related to the doc_id.
        """

    def get_document_contents(self, doc_id: int) -> DocumentContents:
        """
        Get document contents from a doc_id.

        # Parameters
        - doc_id (int): The doc_id.

        # Returns
        - DocumentContents: The document contents related to the doc_id.
        """

    @abstractmethod
    def get_term(self, term_id: int) -> Term:
        """
        Get term info from a term_id.

        # Parameters
        - term_id (int): The term_id.

        # Returns
        - Term: The term related to the term_id.
        """

    @abstractmethod
    def get_term_id(self, term: str) -> Optional[int]:
        """
        Get term_id from a term in string format.
        Returns None if the term is not in the index.

        # Parameters
        - term (str): The term in string format.

        # Returns
        - Optional[int]: The term_id related to the term or None if the term is not in the index.
        """

    @abstractmethod
    def __len__(self) -> int:
        """
        Get the number of documents in the index.

        # Returns
        - int: The number of documents in the index.
        """

    @abstractmethod
    def index_document(self, doc: DocumentContents, tokenizer: Tokenizer) -> None:
        """
        Add a document to the index.

        # Parameters
        - doc (DocumentContents): The document to add to the index.
        - tokenizer (Tokenizer): The tokenizer to use to tokenize the document.
        """

    def bulk_index_documents(self, docs: SizedGenerator[DocumentContents, None, None], tokenizer: Tokenizer, verbose: bool = False) -> None:
        """
        Add multiple documents to the index, this calls index_document for each document.

        # Parameters
        - docs (SizedGenerator[DocumentContents, None, None]): A generator of documents to add to the index.
        - tokenizer (Tokenizer): The tokenizer to use to tokenize the documents.
        - verbose (bool): Whether to show a progress bar.
        """
        for doc in tqdm(docs, desc="Indexing documents", disable=not verbose, total=len(docs)):
            self.index_document(doc, tokenizer)


In [15]:
#%% === mir.ir.impls.default_index ===

from collections import OrderedDict
from collections.abc import Generator
# import os
import pickle
from typing import Any, Optional
# from mir.ir.document_info import DocumentInfo
# from mir.ir.document_contents import DocumentContents
# from mir.ir.index import Index
# from mir.ir.posting import Posting
# from mir.ir.term import Term
# from mir.ir.token_ir import TokenLocation
# from mir.ir.tokenizer import Tokenizer
# from mir.utils.sized_generator import SizedGenerator


class DefaultIndex(Index):
    def __init__(self, path: Optional[str] = None):
        super().__init__()
        self.postings: list[OrderedDict[Posting]] = []
        self.document_info: list[DocumentInfo] = []
        self.document_contents: list[DocumentContents] = []
        self.terms: list[Term] = []
        self.term_lookup: dict[str, int] = {}
        self.path = None
        self.total_field_lengths = {
            "author": 0,
            "title": 0,
            "body": 0
        }
        if path is not None:
            self.path = path
            if os.path.exists(path):
                self.load()

    def get_postings(self, term_id: int) -> Generator[Posting, None, None]:
        for doc_id, posting in self.postings[term_id].items():
            yield posting

    def get_document_info(self, doc_id: int) -> DocumentInfo:
        return self.document_info[doc_id]

    def get_document_contents(self, doc_id: int) -> DocumentContents:
        return self.document_contents[doc_id]

    def get_term(self, term_id: int) -> Term:
        return self.terms[term_id]

    def get_term_id(self, term: str) -> Optional[int]:
        return self.term_lookup.get(term)

    def get_global_info(self) -> dict[str, Any]:
        return {
            "avg_field_lengths": {
                "author": self.total_field_lengths["author"] / len(self.document_info),
                "title": self.total_field_lengths["title"] / len(self.document_info),
                "body": self.total_field_lengths["body"] / len(self.document_info)
            },
            "num_docs": len(self.document_info)
        }

    def __len__(self) -> int:
        return len(self.document_info)

    def index_document(self, doc: DocumentContents, tokenizer: Tokenizer) -> None:
        terms = tokenizer.tokenize_document(doc)
        author_length = sum(1 for term in terms if term.location == TokenLocation.AUTHOR)
        title_length = sum(1 for term in terms if term.location == TokenLocation.TITLE)
        body_length = sum(1 for term in terms if term.location == TokenLocation.BODY)
        self.total_field_lengths["author"] += author_length
        self.total_field_lengths["title"] += title_length
        self.total_field_lengths["body"] += body_length
        term_ids = []
        for term in terms:
            if term.text not in self.term_lookup:
                term_id = len(self.terms)
                self.terms.append(Term(term.text, term_id))
                self.term_lookup[term.text] = term_id
            else:
                term_id = self.term_lookup[term.text]
            term_ids.append(term_id)
        doc_id = len(self.document_info)
        self.document_info.append(DocumentInfo.from_document_contents(doc_id, doc, tokenizer))
        self.document_contents.append(doc)
        for term_id in term_ids:
            if term_id >= len(self.postings):
                self.postings.append(OrderedDict())
            self.postings[term_id][doc_id] = Posting(doc_id, term_id)

    def bulk_index_documents(self, docs: SizedGenerator[DocumentContents, None, None], tokenizer: Tokenizer, verbose: bool = False) -> None:
        super().bulk_index_documents(docs, tokenizer, verbose)
        if self.path is not None:
            self.save()

    def load(self):
        if self.path is not None:
            try:
                with open(self.path, "rb") as f:
                    postings, document_info, document_contents, terms, term_lookup = pickle.load(f)
                assert isinstance(postings, list)
                assert isinstance(document_info, list)
                assert isinstance(document_contents, list)
                assert isinstance(terms, list)
                assert isinstance(term_lookup, dict)
            except Exception as e:
                pass
            else:
                self.postings = postings
                self.document_info = document_info
                self.document_contents = document_contents
                self.terms = terms
                self.term_lookup = term_lookup
        else:
            raise ValueError("Path not set for index.")

    def save(self):
        if self.path is not None:
            with open(self.path, "wb") as f:
                pickle.dump((self.postings, self.document_info, self.document_contents, self.terms, self.term_lookup), f)
        else:
            raise ValueError("Path not set for index.")


In [16]:
#%% === mir.ir.impls.default_tokenizers ===

import string
import nltk
import nltk.corpus
import unidecode
# from mir import DATA_DIR
# from mir.ir.document_contents import DocumentContents
# from mir.ir.tokenizer import Tokenizer
# from mir.ir.token_ir import Token, TokenLocation


class DefaultTokenizer(Tokenizer):
    def __init__(self):
        download_dir = f"{DATA_DIR}/nltk_data"
        nltk.download("stopwords", quiet=True, download_dir=download_dir,)
        stopwords_from_path = nltk.data.find("corpora/stopwords/english", [download_dir])
        with open(stopwords_from_path) as f:
            self.stopwords = frozenset(f.read().splitlines())

        self.stemmer = nltk.SnowballStemmer("english")
        self.remove_punctuation = str.maketrans(string.punctuation, " " * len(string.punctuation))
        self.separate_numbers = str.maketrans({key: f" {key} " for key in string.digits})

    def preprocess(self, text: str):
        # normalize unicode
        text = unidecode.unidecode(text, errors="replace", replace_str=" ")
        # replace punctuation with space
        text = text.translate(self.remove_punctuation).lower()
        # separate numbers with a space
        text = text.translate(self.separate_numbers)
        # split text into words
        words = text.split()
        # remove stopwords
        words = [word for word in words if word not in self.stopwords]
        # stem words
        words: list[str] = [self.stemmer.stem(word) for word in words]
        return words


    def tokenize_query(self, query: str) -> list[Token]:
        query_list = self.preprocess(query)
        token_list = [Token(word, TokenLocation.QUERY) for word in query_list]

        return token_list

    def tokenize_document(self, doc: DocumentContents) -> list[Token]:
        author_list = self.preprocess(doc.author)
        title_list = self.preprocess(doc.title)
        body_list = self.preprocess(doc.body)

        token_list = \
            [Token(aword, TokenLocation.AUTHOR) for aword in author_list] + \
            [Token(tword, TokenLocation.TITLE) for tword in title_list] + \
            [Token(bword, TokenLocation.BODY) for bword in body_list]

        return token_list


In [17]:
#%% === mir.neural_relevance.dataset ===

from typing import Literal
import torch
from torch import nn
# import pandas as pd
from tqdm.auto import tqdm

# from mir import DATA_DIR

class MSMarcoDataset(torch.utils.data.Dataset):
    def __init__(self, collection_path: str, queries_path: str, qrels_path: str):
        self.collection = self.load_collection(collection_path)
        self.queries = self.load_queries(queries_path)
        self.qrels = self.load_qrels(qrels_path)

    @staticmethod
    def load(mode: Literal["train", "valid", "test"]):
        collection_path = f"{DATA_DIR}/msmarco/collection.tsv"
        match mode:
            case "train":
                queries_path = f"{DATA_DIR}/msmarco/queries.train.tsv"
                qrels_path = f"{DATA_DIR}/msmarco/qrels.train.tsv"
            case "valid":
                queries_path = f"{DATA_DIR}/msmarco/msmarco-test2019-queries.tsv"
                qrels_path = f"{DATA_DIR}/msmarco/2019qrels-pass.txt"
            case "test":
                raise NotImplementedError(f"Mode {mode} not implemented.")
            case _:
                raise ValueError(f"Invalid mode {mode}.")
        return MSMarcoDataset(collection_path, queries_path, qrels_path)

    def load_collection(self, collection_path: str):
        collection = pd.read_csv(collection_path, sep='\t', header=None, names=['docid', 'text'], index_col='docid')
        return collection

    def load_queries(self, queries_path: str):
        queries = pd.read_csv(queries_path, sep='\t', header=None, names=['qid', 'text'], index_col='qid')
        return queries

    def load_qrels(self, qrels_path: str):
        sep = ' ' if qrels_path.endswith(".txt") else '\t'
        qrels = pd.read_csv(qrels_path, sep=sep, header=None, names=['qid', 'Q0', 'docid', 'relevance'])
        return qrels

    def __len__(self):
        return len(self.qrels)

    def __getitem__(self, idx):
        qid = self.qrels.iloc[idx]['qid']
        docid = self.qrels.iloc[idx]['docid']
        relevance = self.qrels.iloc[idx]['relevance']
        query = self.queries.loc[qid]['text']
        doc = self.collection.loc[docid]['text']
        return query, doc, relevance

    @staticmethod
    def collate_fn(batch):
        queries, docs, relevances = zip(*batch)
        return queries, docs, torch.tensor(relevances, dtype=torch.float32)


In [18]:
# import os
import requests
# import torch
from torch import nn
from tqdm.auto import tqdm
import transformers
from src.contriever import Contriever

contriever_ = Contriever.from_pretrained("facebook/contriever-msmarco")
tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/contriever-msmarco")

sentences = [
    "Where was Marie Curie born?",
    "Maria Sklodowska, later known as Marie Curie, was born on November 7, 1867.",
    "Born in Paris on 15 May 1859, Pierre Curie was the son of Eugène Curie, a doctor of French Catholic origin from Alsace."
]

inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
embeddings = contriever_(**inputs)

x = torch.triu(embeddings @ embeddings.T, diagonal = 1)
x = x[x != 0]
x.squeeze()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of the model checkpoint at facebook/contriever-msmarco were not used when initializing Contriever: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing Contriever from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Contriever from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenc

tensor([1.7735, 1.5486, 1.5952], grad_fn=<SqueezeBackward0>)

In [36]:
#%% === mir.neural_relevance.model ===

# import os
import requests
#import torch
from torch import nn
from tqdm.auto import tqdm
import transformers

# from mir import DATA_DIR
# from mir.neural_relevance.dataset import MSMarcoDataset

class NeuralRelevance(nn.Module):
    def __init__(self):
        super().__init__()
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        model_name = "facebook-contriever"
        self.tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/contriever-msmarco")
        self.model = Contriever.from_pretrained("facebook/contriever-msmarco").to(self.device)
        for param in self.model.parameters():
            param.requires_grad = False

        bert_embedding_size = self.model.config.hidden_size
        self.similairty_head = nn.Sequential(
            nn.Linear(bert_embedding_size, 1, device=self.device),
            nn.Sigmoid()
        ).to(self.device)


    def forward(self, q_tokens: dict, d_tokens: dict) -> torch.Tensor:
        with torch.no_grad():
            query_embeddings = self.model(**q_tokens)
            doc_embeddings = self.model(**d_tokens)
        x = torch.triu(query_embeddings @ doc_embeddings.T, diagonal=1)
        x = x[x != 0]
        print(x.shape)
        return x

    def preprocess(self, queries: list[str], documents: list[str]):
        q_tokens = self.tokenizer(queries, return_tensors="pt", truncation=True, padding=True).to(self.device)
        d_tokens = self.tokenizer(documents, return_tensors="pt", truncation=True, padding=True).to(self.device)
        return q_tokens, d_tokens

    def forward_queries_and_documents(self, queries: list[str], documents: list[str]) -> torch.Tensor:
        q_tokens, d_tokens = self.preprocess(queries, documents)
        return self.forward(q_tokens, d_tokens)

    def save(self, path: str):
        torch.save(self.state_dict(), path)

    def return_model(self):
      return self.model

    @staticmethod
    def from_pretrained():
        if not os.path.exists(f"{DATA_DIR}/neural_relevance.pt"):
            url = "https://huggingface.co/Etto48/MIRProject/resolve/main/neural_relevance.pt"
            weights_request = requests.get(url)
            weights_request.raise_for_status()
            with tqdm(total=int(weights_request.headers["Content-Length"]), unit="B", unit_scale=True, desc="Downloading weights") as pbar:
                with open(f"{DATA_DIR}/neural_relevance.pt", "wb") as f:
                    for chunk in weights_request.iter_content(chunk_size=1024):
                        f.write(chunk)
                        pbar.update(len(chunk))
        model = NeuralRelevance.load(f"{DATA_DIR}/neural_relevance.pt")
        return model


In [37]:
#%% === mir.ir.impls.neural_scoring_function ===

import numpy as np
# import torch
from tqdm.auto import tqdm

# from mir import DATA_DIR
# from mir.neural_relevance.model import NeuralRelevance
# from mir.ir.document_info import DocumentInfo
# from mir.ir.posting import Posting
# from mir.ir.scoring_function import ScoringFunction
# from mir.ir.term import Term
# from mir.neural_relevance.dataset import MSMarcoDataset


class NeuralScoringFunction(ScoringFunction):
    def __init__(self):
        # Load the model
        self.model = NeuralRelevance()
        self.model.eval()

    def __call__(self, document: DocumentInfo, postings: list[Posting], query: list[Term], *, document_content: str, query_content: str, **kwargs) -> float:
        if len(document_content) == 0 or len(query_content) == 0:
            return 0.0
        with torch.no_grad():
            score = self.model.forward_queries_and_documents([query_content], [document_content])
        return score.item()
    """
    def batched_call(self, document_contents: list[str], query_contents: str) -> list[float]:
        scores = []
        with torch.no_grad():
            scores = self.model.forward_queries_and_documents([query_contents]*len(document_contents), document_contents)
        return scores.tolist()"""


In [21]:
#%% === mir.ir.impls.sqlite_index ===

from collections.abc import Generator
# import os
import sqlite3
import sys
from typing import Any, Optional

import psutil
# from mir.ir.document_info import DocumentInfo
# from mir.ir.document_contents import DocumentContents
# from mir.ir.impls.default_tokenizers import DefaultTokenizer
# from mir.ir.index import Index
# from mir.ir.posting import Posting
# from mir.ir.term import Term
# from mir.ir.token_ir import TokenLocation
# from mir.ir.tokenizer import Tokenizer


class SqliteIndex(Index):
    def __init__(self, path: Optional[str] = None):
        super().__init__()

        self.connection = sqlite3.connect(
            path if path is not None else ":memory:",
            check_same_thread=False,
            cached_statements=1024,)

        assert sys.version_info.major == 3, "Python 2 is not supported"
        assert sys.version_info.minor >= 10, "Python <3.10 is not supported"

        legacy_mode = sys.version_info.minor == 10
        if legacy_mode:
            self.isolation_level = None
        else:
            self.connection.autocommit = True

        self.connection.execute("pragma synchronous = off")
        self.connection.execute(f"pragma threads = {os.cpu_count()}")
        self.connection.execute("pragma journal_mode = WAL")
        cache_memory = psutil.virtual_memory().total // 1024 // 2
        self.connection.execute(f"pragma cache_size = {-cache_memory}")
        self.connection.execute(f"pragma mmap_size = {1024*1024*1024 * 16}")
        self.connection.execute("pragma temp_store = memory")

        if legacy_mode:
            self.connection.isolation_level = "DEFERRED"
        else:
            self.connection.autocommit = False


        self.connection.execute(
            "create table if not exists postings "
            "(term_id integer references terms(term_id) not null, "
            "doc_id integer references document_info(doc_id) not null, "
            "occurrences_author integer not null, "
            "occurrences_title integer not null, "
            "occurrences_body integer not null, "
            "primary key (term_id, doc_id))")
        self.connection.execute(
            "create table if not exists document_info "
            "(doc_id integer not null primary key autoincrement, "
            "author_len integer not null, "
            "title_len integer not null, "
            "body_len integer not null)")
        self.connection.execute(
            "create table if not exists document_contents "
            "(doc_id integer not null primary key references document_info(doc_id), "
            "author text, "
            "title text, "
            "body text)")
        self.connection.execute(
            "create table if not exists terms "
            "(term_id integer not null primary key autoincrement, "
            "term text unique not null, "
            "document_frequency integer not null)")

        self.connection.execute("create table if not exists global_info (key text not null primary key, value integer)")
        # add global info default values if not present
        self.connection.execute("insert or ignore into global_info values ('total_author_len', 0)")
        self.connection.execute("insert or ignore into global_info values ('total_title_len', 0)")
        self.connection.execute("insert or ignore into global_info values ('total_body_len', 0)")
        self.connection.execute("insert or ignore into global_info values ('num_docs', 0)")

        self.connection.execute("pragma optimize")

        self.connection.commit()
        self.global_info_dirty = True
        self.cached_global_info = None

    def get_postings(self, term_id: int) -> Generator[Posting, None, None]:
        cursor = self.connection.cursor()
        cursor.execute(
            "select doc_id, occurrences_author, occurrences_title, occurrences_body from postings where term_id = ? "
            "order by doc_id", (term_id,))
        def row_factory(_cursor, row):
            return Posting(row[0], term_id, {"author": row[1], "title": row[2], "body": row[3]})
        cursor.row_factory = row_factory
        yield from cursor

    def get_document_info(self, doc_id: int) -> DocumentInfo:
        cursor = self.connection.cursor()
        cursor.execute("select author_len, title_len, body_len from document_info where doc_id = ?", (doc_id,))
        author_len, title_len, body_len = cursor.fetchone()
        return DocumentInfo(doc_id, [author_len, title_len, body_len])

    def get_document_contents(self, doc_id: int) -> DocumentContents:
        cursor = self.connection.cursor()
        cursor.execute("select author, title, body from document_contents where doc_id = ?", (doc_id,))
        author, title, body = cursor.fetchone()
        return DocumentContents(author, title, body)

    def get_term(self, term_id: int) -> Term:
        cursor = self.connection.cursor()
        cursor.execute("select term, document_frequency from terms where term_id = ?", (term_id,))
        term, document_frequency = cursor.fetchone()
        return Term(term, term_id, document_frequency=document_frequency)

    def get_term_id(self, term: str) -> Optional[int]:
        cursor = self.connection.cursor()
        cursor.execute("select term_id from terms where term = ?", (term,))
        result = cursor.fetchone()
        return result[0] if result is not None else None

    def get_global_info(self) -> dict[str, Any]:
        if self.global_info_dirty:
            cursor = self.connection.cursor()
            cursor.execute("select key, value from global_info")
            global_info = cursor.fetchall()
            global_info = {key: value for key, value in global_info}
            self.cached_global_info = {
                "avg_field_lengths": {
                    "author": global_info["total_author_len"] / global_info["num_docs"],
                    "title": global_info["total_title_len"] / global_info["num_docs"],
                    "body": global_info["total_body_len"] / global_info["num_docs"]
                },
                "num_docs": global_info["num_docs"]
            }
            self.global_info_dirty = False
        return self.cached_global_info

    def __len__(self) -> int:
        cursor = self.connection.cursor()
        cursor.execute("select value from global_info where key = 'num_docs'")
        return cursor.fetchone()[0]

    def _increment_field_lengths(self, author_len: int, title_len: int, body_len: int) -> None:
        cursor = self.connection.cursor()
        cursor.execute("update global_info set value = value + ? where key = 'total_author_len'", (author_len,))
        cursor.execute("update global_info set value = value + ? where key = 'total_title_len'", (title_len,))
        cursor.execute("update global_info set value = value + ? where key = 'total_body_len'", (body_len,))

    def _create_or_get_term_id(self, term: str) -> int:
        cursor = self.connection.cursor()
        cursor.execute("insert or ignore into terms(term, document_frequency) values (?, 0)", (term,))
        cursor.execute("select term_id from terms where term = ?", (term,))
        return cursor.fetchone()[0]

    def _new_document(self, doc: DocumentContents, author_len: int, title_len: int, body_len: int) -> int:
        cursor = self.connection.cursor()
        if doc.__dict__.get("doc_id") is not None:
            cursor.execute("insert into document_info(doc_id, author_len, title_len, body_len) values (?, ?, ?, ?)", (doc.doc_id, author_len, title_len, body_len))
        else:
            cursor.execute("insert into document_info(author_len, title_len, body_len) values (?, ?, ?)", (author_len, title_len, body_len))
        doc_id = cursor.lastrowid
        cursor.execute("insert into document_contents(doc_id, author, title, body) values (?, ?, ?, ?)", (doc_id, doc.author, doc.title, doc.body))
        cursor.execute("update global_info set value = value + 1 where key = 'num_docs'")
        return doc_id

    def _update_postings(self, term_id: int, doc_id: int, location: TokenLocation) -> None:
        increments = {
            "author": 1 if location == TokenLocation.AUTHOR else 0,
            "title": 1 if location == TokenLocation.TITLE else 0,
            "body": 1 if location == TokenLocation.BODY else 0
        }
        cursor = self.connection.cursor()
        cursor.execute("select occurrences_author, occurrences_title, occurrences_body from postings where term_id = ? and doc_id = ?", (term_id, doc_id))
        result = cursor.fetchone()
        if result is None:
            cursor.execute(
                "insert into postings(term_id, doc_id, occurrences_author, occurrences_title, occurrences_body) "
                "values (?, ?, ?, ?, ?)", (term_id, doc_id, increments["author"], increments["title"], increments["body"]))
        else:
            cursor.execute(
                "update postings set occurrences_author = occurrences_author + ?, occurrences_title = occurrences_title + ?, "
                "occurrences_body = occurrences_body + ? where term_id = ? and doc_id = ?",
                (increments["author"], increments["title"], increments["body"], term_id, doc_id))

    def _contains_document(self, doc_id: int) -> bool:
        cursor = self.connection.cursor()
        cursor.execute("select count(*) from document_info where doc_id = ?", (doc_id,))
        ret = cursor.fetchone()[0]
        return ret > 0

    def _increment_document_frequency(self, term_id: int) -> None:
        cursor = self.connection.cursor()
        cursor.execute("update terms set document_frequency = document_frequency + 1 where term_id = ?", (term_id,))

    def index_document(self, doc: DocumentContents, tokenizer: Tokenizer) -> None:

        if doc.__dict__.get("doc_id") is not None:
            if self._contains_document(doc.doc_id):
                return
        self.global_info_dirty = True

        terms = tokenizer.tokenize_document(doc)
        author_length = sum(1 for term in terms if term.location == TokenLocation.AUTHOR)
        title_length = sum(1 for term in terms if term.location == TokenLocation.TITLE)
        body_length = sum(1 for term in terms if term.location == TokenLocation.BODY)

        self._increment_field_lengths(author_length, title_length, body_length)

        encountered_terms = set()
        term_ids_and_locations = []
        for term in terms:
            term_id = self._create_or_get_term_id(term.text)
            if term_id not in encountered_terms:
                self._increment_document_frequency(term_id)
            encountered_terms.add(term_id)
            term_ids_and_locations.append((term_id, term.location))
        doc_id = self._new_document(doc, author_length, title_length, body_length)
        for term_id, location in term_ids_and_locations:
            self._update_postings(term_id, doc_id, location)
        self.connection.commit()

    def bulk_index_documents(self, docs, tokenizer, verbose = False):
        super().bulk_index_documents(docs, tokenizer, verbose)
        self.connection.execute("pragma optimize")
        self.connection.commit()


In [22]:
#%% === mir.ir.priority_queue ===

import heapq
from typing import Iterable, Optional, Sized

class PriorityQueue(Iterable[tuple[float, int]], Sized):
    def __init__(self, max_size: int):
        """
        Create a priority queue with a maximum size.
        """
        self.heap = []
        self.finalised = False
        self.max_size = max_size

    def push(self, doc_id: int, score: float) -> Optional[int]:
        """
        Add an item with a given score to the priority queue

        Returns the doc_id of the item that was popped, if any. If the new item was not added returns its doc_id.
        """
        if len(self) == self.max_size:
            if score > self.heap[0][0]:
                popped = heapq.heappushpop(self.heap, (score, doc_id))
                return popped[1]
            else:
                return doc_id
        else:
            heapq.heappush(self.heap, (score, doc_id))
            return None

    def finalise(self):
        """
        Call this after all items have been pushed to the priority queue.
        """
        self.heap.sort(reverse=True)
        self.finalised = True

    def __iter__(self) -> Iterable[tuple[float, int]]:
        """
        Iterate over the items in the priority queue.
        """
        if not self.finalised:
            raise ValueError("Priority queue must be finalised before iterating")
        return iter(self.heap)

    def __len__(self) -> int:
        """
        Get the number of items in the priority queue.
        """
        return len(self.heap)


# Ir system

This class is the core of the project. All the components of the system are needed in order to construct an instance of this class.
It uses the components to perform the indexing and search operations.



In [23]:
#%% === mir.ir.ir ===

from collections.abc import Generator
# import string
import time
from typing import Optional

# import pandas as pd
from tqdm.auto import tqdm
from more_itertools import peekable

# from mir.ir.document_contents import DocumentContents
# from mir.ir.impls.default_index import DefaultIndex
# from mir.ir.impls.count_scoring_function import CountScoringFunction
# from mir.ir.impls.default_tokenizers import DefaultTokenizer
# from mir.ir.index import Index
# from mir.ir.priority_queue import PriorityQueue
# from mir.ir.scoring_function import ScoringFunction
# from mir.ir.tokenizer import Tokenizer
# from mir.utils.sized_generator import SizedGenerator

class Ir:
    def __init__(self, index: Optional[Index] = None, tokenizer: Optional[Tokenizer] = None, scoring_functions: Optional[list[tuple[int, ScoringFunction]]] = None):
        """
        Create an IR system.

        # Parameters
        - index (Index): The index to use. If None, a DefaultIndex is used.
        - tokenizer (Tokenizer): The tokenizer to use. If None, a DefaultTokenizer is used.
        - scoring_functions (Optional[list[tuple[int, ScoringFunction]]]): A list of scoring functions to use, with their respective top_k results to keep.
        If None CountScoringFunction is used.
        """
        self.index: Index = index if index is not None else DefaultIndex()
        self.tokenizer: Tokenizer = tokenizer if tokenizer is not None else DefaultTokenizer()
        self.scoring_functions: list[tuple[int, ScoringFunction]] = scoring_functions if scoring_functions is not None else [
            (1000, CountScoringFunction())
        ]

    def __len__(self) -> int:
        """
        Get the number of documents in the index.
        """
        return len(self.index)

    def index_document(self, doc: DocumentContents) -> None:
        """
        Index a document.

        # Parameters
        - doc (DocumentContents): The document to index.
        """
        self.index.index_document(doc, self.tokenizer)

    def bulk_index_documents(self, docs: SizedGenerator[DocumentContents, None, None], verbose: bool = False) -> None:
        """
        Bulk index documents.

        # Parameters
        - docs (SizedGenerator[DocumentContents, None, None]): A generator of documents to index.
        - verbose (bool): Whether to show a progress bar.
        """
        self.index.bulk_index_documents(docs, self.tokenizer, verbose)

    def search(self, query: str) -> Generator[DocumentContents, None, None]:
        """
        Search for documents based on a query.
        Uses document-at-a-time scoring.

        # Parameters
        - query (str): The query to search for.
        - scoring_functions (list[ScoringFunction]): A list of scoring functions to use.

        # Yields
        - DocumentContents: A document that matches the query. In decreasing order of score.
        It also has a score attribute with the score of the document.
        """

        assert len(self.scoring_functions) > 0, "At least one scoring function must be provided"

        ks, scoring_functions = zip(*self.scoring_functions)
        scoring_functions: list[ScoringFunction] = list(scoring_functions)
        ks: list[int] = list(ks)[::-1]

        terms = self.tokenizer.tokenize_query(query)
        term_ids = [term_id for term in terms if (
            term_id := self.index.get_term_id(term.text)) is not None]
        terms = [self.index.get_term(term_id) for term_id in term_ids]
        posting_generators = [
            peekable(self.index.get_postings(term_id)) for term_id in term_ids]

        priority_queue = PriorityQueue(ks[-1])
        first_scoring_function = scoring_functions[0]
        postings_cache = {}

        while True:
            # find the lowest doc_id among all the posting lists
            # doing this avoids having to iterate over all the doc_ids
            # we only take into account the doc_ids that are present in the posting lists
            lowest_doc_id = None
            empty_posting_lists = []
            for i, posting in enumerate(posting_generators):
                try:
                    doc_id = posting.peek().doc_id
                    if lowest_doc_id is None or doc_id < lowest_doc_id:
                        lowest_doc_id = doc_id
                except StopIteration:
                    empty_posting_lists.append(i)
            # all the posting lists are empty
            if lowest_doc_id is None:
                break

            # remove the empty posting lists
            for i in reversed(empty_posting_lists):
                posting_generators.pop(i)
                term_ids.pop(i)

            postings = []
            # get all the postings with the current doc_id, and advance their iterators
            for i, posting in enumerate(posting_generators):
                if posting.peek().doc_id == lowest_doc_id:
                    next_posting = next(posting)
                    postings.append(next_posting)
            postings_cache[lowest_doc_id] = postings
            # now that we have all the info about the current document, we can score it
            global_info = self.index.get_global_info()
            document_info = self.index.get_document_info(lowest_doc_id)
            score = first_scoring_function(document_info, postings, terms, **global_info)
            # we add the score and doc_id to the priority queue
            popped_doc_id = priority_queue.push(lowest_doc_id, score)
            # if the priority queue is full, we remove the lowest score
            if popped_doc_id is not None:
                del postings_cache[popped_doc_id]

        priority_queue.finalise()

        for scoring_function in scoring_functions[1:]:
            ks.pop()
            resorted_documents = []
            if scoring_function.batched_call is not None:
                scores: list[float] = scoring_function.batched_call(
                    [self.index.get_document_contents(doc_id).body for _, doc_id in priority_queue.heap[:ks[-1]]],
                    query
                )
                for i, (score, doc_id) in enumerate(priority_queue.heap[:ks[-1]]):
                    new_score = scores[i]
                    resorted_documents.append((new_score + score, doc_id))
            else:
                for score, doc_id in priority_queue.heap[:ks[-1]]:
                    postings = postings_cache[doc_id]
                    global_info = self.index.get_global_info()
                    global_info["document_content"] = self.index.get_document_contents(doc_id).body
                    global_info["query_content"] = query
                    new_score = scoring_function(self.index.get_document_info(doc_id), postings, terms, **global_info)
                    # we add the old score to maintain monotonicity
                    resorted_documents.append((new_score + score, doc_id))

            resorted_documents.sort(key=lambda x: x[0], reverse=True)
            priority_queue.heap = resorted_documents + priority_queue.heap[ks[-1]:]

        for score, doc_id in priority_queue:
            ret = self.index.get_document_contents(doc_id)
            ret.add_field("id", doc_id)
            ret.set_score(score)
            yield ret

    def get_run(self, queries: pd.DataFrame, verbose: bool = False, pyterrier_compatible: bool = False) -> pd.DataFrame:
        """
        Generate a run file for the given queries in the form of a pandas DataFrame.
        You can encode it to a file using a tab separator and the to_csv method.

        # Parameters
        - queries (pd.DataFrame): A DataFrame with the queries to run.
        It must have the columns "query_id" and "text".
        - verbose (bool): Whether to show a progress bar.

        # Returns
        - pd.DataFrame: The run file. It has the columns
        "query_id", "Q0", "document_no", "rank", "score", "run_id".
        If pyterrier_compatible is True, the columns are "qid", "docid", "docno", "rank", "score", "query".
        """

        run = []
        for _, query_row in tqdm(queries.iterrows(), desc="Running queries", disable=not verbose, total=len(queries)):
            query_id = query_row["query_id"]
            query = query_row["text"]
            for rank, doc in enumerate(self.search(query), start=0):
                if pyterrier_compatible:
                    run.append(
                        {"qid": query_id, "docid": doc.id, "docno": doc.id, "rank": rank, "score": doc.score, "query": query})
                else:
                    run.append(
                        {"query_id": query_id, "Q0": "Q0", "doc_id": doc.id, "rank": rank, "score": doc.score, "run_id": self.__class__.__name__})

        run = pd.DataFrame(run)
        return run


In [24]:
#%% === mir.utils.dataset ===

from collections.abc import Generator
import gzip
import tarfile
# import numpy as np
# import requests
# from mir import DATA_DIR, COLAB
# import pandas as pd
# import os
import json
# import unidecode
from tqdm.auto import tqdm

# from mir.ir.document_contents import DocumentContents
# from mir.utils.sized_generator import SizedGenerator

def get_msmarco_dataset(verbose: bool = False):
    """
    Downloads the MS MARCO dataset to the data directory.
    """
    corpus = "https://msmarco.z22.web.core.windows.net/msmarcoranking/collection.tar.gz"
    queries = "https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz"
    queries_valid = "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz"
    queries_test = "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test2020-queries.tsv.gz"
    qrels_train = "https://msmarco.z22.web.core.windows.net/msmarcoranking/qrels.train.tsv"
    # trec link is usually down so I'm using my own link to the same files
    # qrels_valid = "https://trec.nist.gov/data/deep/2019qrels-pass.txt"
    # qrels_test = "https://trec.nist.gov/data/deep/2020qrels-pass.txt"
    qrels_valid = "https://huggingface.co/Etto48/MIRProject/resolve/main/2019qrels-pass.txt"
    qrels_test = "https://huggingface.co/Etto48/MIRProject/resolve/main/2020qrels-pass.txt"

    urls = [corpus, queries, queries_valid, queries_test, qrels_train, qrels_valid, qrels_test]
    dataset_dir = f"{DATA_DIR}/msmarco"
    os.makedirs(dataset_dir, exist_ok=True)
    for url in urls:
        file_name = url.split("/")[-1]
        path = f"{dataset_dir}/{file_name}"
        if not os.path.exists(path):
            response = requests.get(url, stream=True)
            response.raise_for_status()
            file_size = int(response.headers.get("content-length", 0))
            block_size = 1024
            try:
                with tqdm(total=file_size, unit="B", unit_scale=True, desc=f"Downloading {file_name}", disable=not verbose) as pbar:
                    with open(path, "wb") as f:
                        for data in response.iter_content(block_size):
                            f.write(data)
                            pbar.update(len(data))
            except (KeyboardInterrupt, Exception) as e:
                os.remove(path)
                raise e
        decompressed_path = path.replace(".tar.gz", "")
        decompressed_path = decompressed_path.replace(".gz", "")
        if file_name.endswith(".tar.gz") and not os.path.exists(f"{decompressed_path}.tsv"):
            if verbose:
                print(f"Decompressing {file_name}...")
            with tarfile.open(path, "r:gz") as tar:
                tar.extractall(dataset_dir, filter="fully_trusted")
        elif not file_name.endswith(".tar.gz") and \
                file_name.endswith(".gz") and \
                not os.path.exists(decompressed_path):
            if verbose:
                print(f"Decompressing {file_name}...")
            with gzip.open(path, "rb") as f_in:
                with open(decompressed_path, "wb") as f_out:
                    f_out.write(f_in.read())

def msmarco_dataset_to_contents(corpus: pd.DataFrame, verbose: bool = False) -> SizedGenerator[DocumentContents, None, None]:
    """
    Returns the number of documents and a generator of DocumentContents from the test corpus.
    """
    def inner() -> Generator[DocumentContents, None, None]:
        for _, row in corpus.iterrows():
            yield DocumentContents(author="", title="", body=row['text'], doc_id=int(row['docno']))
    return SizedGenerator(inner(), len(corpus))


In [25]:
#%% === mir.utils.download_and_extract ===

# import os
# import tarfile
# import requests
from tqdm.auto import tqdm

def download_and_extract(url: str, path: str, desc: str = ""):
    stream = requests.get(url, stream=True)
    total_size = int(stream.headers.get('content-length', 0))
    tgz_path = f"{path}.tar.gz"
    output_dir = f"{path}"
    if not os.path.exists(tgz_path):
        with tqdm(total=total_size, unit='B', unit_scale=True, unit_divisor=1024, desc=f"Downloading {desc}") as pbar:
            with open(tgz_path, 'wb') as f:
                for chunk in stream.iter_content(chunk_size=1024):
                    f.write(chunk)
                    pbar.update(len(chunk))
    if not os.path.exists(output_dir):
        with tarfile.open(tgz_path, 'r:gz') as tar:
            members = tqdm(tar.getmembers(), desc=f"Extracting {desc}")
            tar.extractall(output_dir, members)


# Demo

Now we will use the system to index ms-marco and run the test queries.
Then we will compare the results with the ones of PyTerrier.



In [38]:
#%% === mir.scripts.demo ===

# import os
import re
import pyterrier as pt
from pyterrier import IndexFactory
# import pandas as pd
from tqdm.auto import tqdm

# from mir import DATA_DIR
# from mir.ir.impls.bm25f_scoring import BM25FScoringFunction
# from mir.ir.impls.neural_scoring_function import NeuralScoringFunction
# from mir.ir.impls.sqlite_index import SqliteIndex
# from mir.ir.ir import Ir
# from mir.utils.dataset import get_msmarco_dataset, msmarco_dataset_to_contents
# from mir.utils.download_and_extract import download_and_extract


get_msmarco_dataset(verbose=True)
dataset_csv = f"{DATA_DIR}/msmarco/collection.tsv"
index_path = f"{DATA_DIR}/msmarco-pyterrier-index/data.properties"
msmarco_pyterrier_index_url = "https://huggingface.co/Etto48/MIRProject/resolve/main/msmarco-pyterrier-index.tar.gz"
msmarco_sqlite_index_url = "https://huggingface.co/Etto48/MIRProject/resolve/main/msmarco-sqlite-index.db.tar.gz"
# download pyterrier index
download_and_extract(msmarco_pyterrier_index_url, DATA_DIR, desc="PyTerrier Index")
# download sqlite index
download_and_extract(msmarco_sqlite_index_url, DATA_DIR, desc="SQLite Index")

indexer = pt.terrier.IterDictIndexer(f"{DATA_DIR}/msmarco-pyterrier-index")
if not os.path.exists(index_path):
    dataset = pd.read_csv(dataset_csv, sep='\t', header=None, names=['docno', 'text'], dtype={'docno': str, 'text': str})
    indexref = indexer.index(tqdm(dataset.to_dict(orient='records'), desc="Indexing"))
    del dataset
else:
    indexref = pt.IndexRef.of(index_path)
index = IndexFactory.of(indexref)


topics_path = f"{DATA_DIR}/msmarco/msmarco-test2020-queries.tsv"
qrels_path = f"{DATA_DIR}/msmarco/2020qrels-pass.txt"

topics = pd.read_csv(topics_path, sep='\t', header=None, names=['qid', 'query'], dtype={'qid': str, 'query': str})
qrels = pd.read_csv(qrels_path, sep=' ', header=None, names=['qid', 'Q0', 'docno', 'relevance'], dtype={'qid': str, 'Q0': str, 'docno': str, 'relevance': int})

def preprocess_query(query):
    query = re.sub(r'[^\w\s]', '', query)
    query = query.lower()
    return query

topics['query'] = topics['query'].apply(preprocess_query)

my_ir = Ir(SqliteIndex(f"{DATA_DIR}/msmarco-sqlite-index.db"), scoring_functions=[
    (100, BM25FScoringFunction(1.2, 0.8)),
    (10, NeuralScoringFunction())
])
if len(my_ir.index) == 0:
    dataset = pd.read_csv(dataset_csv, sep='\t', header=None, names=['docno', 'text'], dtype={'docno': str, 'text': str})
    sized_generator = msmarco_dataset_to_contents(dataset)
    my_ir.bulk_index_documents(sized_generator, verbose=True)

my_topics = pd.read_csv(topics_path, sep='\t', header=None, names=['query_id', 'text'], dtype={'query_id': int, 'text': str})
my_run = my_ir.get_run(my_topics, verbose=True, pyterrier_compatible=True)

bm25 = pt.terrier.Retriever(index, wmodel="BM25")
dfree = pt.terrier.Retriever(index, wmodel="DFRee")
pyterrier_models = {
    "BM25": bm25 % 100,
    "BM25+DFRee": (bm25 % 100) >> dfree
}
pyterrier_runs = {}
for model_name, model in pyterrier_models.items():
    print(f"Running PyTerrier {model_name}")
    pyterrier_runs[model_name] = model.transform(topics)

test_runs = [my_run, *pyterrier_runs.values()]
names = ["MyIR", *pyterrier_models.keys()]

metrics = ["map", "ndcg", "recip_rank", "P.10", "recall.10", ]
res = pt.Experiment(test_runs, topics, qrels, metrics, names=names)
print(res)


Decompressing queries.tar.gz...


Some weights of the model checkpoint at facebook/contriever-msmarco were not used when initializing Contriever: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing Contriever from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Contriever from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Running queries:   0%|          | 0/200 [00:00<?, ?it/s]

torch.Size([0])


RuntimeError: a Tensor with 0 elements cannot be converted to Scalar