In [15]:
import glob
import os
from itertools import islice

import ir_datasets
import nltk
import numpy as np
from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
from nltk.tokenize import word_tokenize
from tqdm import tqdm

# Fetch the NLTK resources used by `preprocess`. Newer NLTK releases (3.9+) load
# the `punkt_tab` resource for `word_tokenize` instead of the pickled `punkt`
# model, so download both; already-present packages are reported as
# "up-to-date" and not re-fetched.
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('stopwords')

# MS MARCO passage v2 corpus with the TREC DL 2022 queries and qrels.
ds = ir_datasets.load("msmarco-passage-v2/trec-dl-2022")

[nltk_data] Downloading package punkt to /home/don/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /home/don/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [16]:
# Report corpus sizes; each count is fetched lazily right before it is printed.
for caption, count_fn in (
    ("Docs Count : ", ds.docs_count),
    ("Queries Count : ", ds.queries_count),
    ("Qrels Count : ", ds.qrels_count),
):
    print(caption, count_fn())

Docs Count :  138364198
Queries Count :  500
Qrels Count :  386416


In [17]:
stop_words = set(stopwords.words('english'))
stemmer = SnowballStemmer(language="english")

def preprocess(text):
    """Normalize raw text into a space-joined string of stemmed content words.

    Steps: lowercase, tokenize with NLTK's ``word_tokenize``, keep only purely
    alphabetic tokens, drop English stopwords, then Snowball-stem what remains.

    Args:
        text: Raw input string (document or query text).

    Returns:
        The surviving stems joined by single spaces ('' if nothing survives).
    """
    # The stopword test runs on the lowercased, unstemmed token, before stemming.
    stems = (
        stemmer.stem(token)
        for token in word_tokenize(text.lower())
        if token.isalpha() and token not in stop_words
    )
    return ' '.join(stems)

In [18]:
# Preprocess every query once, keyed by query_id, and cache the mapping on disk.
processed_queries = {
    q.query_id: preprocess(q.text)
    for q in ds.queries_iter()
}

# np.save stores the dict inside an object array, which is why both the save and
# the load need allow_pickle and the load unwraps it with .item().
with open('queries.npy', 'wb+') as f:
    np.save(f, processed_queries, allow_pickle=True)

# Round-trip check: the cached file must reproduce the in-memory mapping exactly.
loaded = np.load('queries.npy', allow_pickle=True).item()
assert loaded == processed_queries

In [19]:
# Preprocess the first MAX_DOCS passages and persist them as BATCH_SIZE-sized
# dicts {doc_id: processed_text} so later cells can reload them without
# re-tokenizing the corpus.
MAX_DOCS = 250_000
BATCH_SIZE = 50_000
BATCH_DIR = "batches"

# np.save does not create parent directories, so the first save would fail on a
# fresh checkout without this.
os.makedirs(BATCH_DIR, exist_ok=True)
# The training cell globs every docs_batch_*.npy; delete files from a previous
# run (e.g. one with a larger MAX_DOCS) so they cannot leak into the corpus.
for stale_file in glob.glob(os.path.join(BATCH_DIR, "docs_batch_*.npy")):
    os.remove(stale_file)

it = islice(ds.docs_iter(), MAX_DOCS)
batch_idx = 0
with tqdm(total=MAX_DOCS, desc="preprocessing docs", unit="doc") as progress:
    while True:
        batch = list(islice(it, BATCH_SIZE))
        if not batch:
            break

        processed = {doc.doc_id: preprocess(doc.text) for doc in batch}
        np.save(f"{BATCH_DIR}/docs_batch_{batch_idx:03d}.npy", processed, allow_pickle=True)
        batch_idx += 1
        progress.update(len(batch))

In [20]:
from gensim.models.doc2vec import TaggedDocument

def load_tagged_documents(batch_files):
    """Lazily yield a gensim ``TaggedDocument`` for every doc in the batch files.

    Each file is expected to hold a dict {doc_id: space-joined tokens} saved via
    ``np.save(..., allow_pickle=True)``. Files are loaded one at a time, in the
    order given.

    Args:
        batch_files: Iterable of paths to ``.npy`` batch files.

    Yields:
        ``TaggedDocument(words=<whitespace-split tokens>, tags=[doc_id])``.
    """
    for path in batch_files:
        # allow_pickle unpickles arbitrary objects; only use on self-produced files.
        id_to_text = np.load(path, allow_pickle=True).item()
        yield from (
            TaggedDocument(words=body.split(), tags=[doc_id])
            for doc_id, body in id_to_text.items()
        )

In [None]:
from gensim.models import Doc2Vec
import glob

# Doc2Vec hyperparameters.
vector_size = 512
window = 5
min_count = 2
epochs = 20

batch_files = sorted(glob.glob("batches/docs_batch_*.npy"))
# Fail fast with an actionable message; an empty corpus would otherwise surface
# later as a gensim error that does not mention the missing preprocessing step.
assert batch_files, "no batches/docs_batch_*.npy files found; run the batch preprocessing cell first"
# Materialized as a list because the corpus is iterated twice (build_vocab, then
# train); a bare generator would be exhausted after the first pass.
tagged_docs = list(load_tagged_documents(batch_files))
assert tagged_docs, "batch files contain no documents"

model = Doc2Vec(vector_size=vector_size, window=window, min_count=min_count, workers=4, epochs=epochs)
model.build_vocab(tagged_docs)

model.train(tagged_docs, total_examples=model.corpus_count, epochs=model.epochs)
model.save("doc2vec_model.model")

In [None]:
new_doc = "Sample text for embedding."

# Reload the model from disk instead of relying on the in-memory `model` left
# over from the training cell.
model = Doc2Vec.load("doc2vec_model.model")

# Unseen text must go through the same preprocessing as the training corpus.
tokens = preprocess(new_doc).split()
# NOTE(review): inference may be stochastic across calls — confirm against the
# gensim docs if stable vectors are required.
vector = model.infer_vector(doc_words=tokens)
print(vector)