In [None]:
from os import environ
# Give TOKENIZERS_PARALLELISM an explicit value before anything from transformers/tokenizers
# runs; an explicit setting silences the tokenizers "process just got forked" warning.
environ.update(TOKENIZERS_PARALLELISM="false")

from os import path
import re
from json import load, dump
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from langchain_text_splitters import RecursiveCharacterTextSplitter

from utils.storage import list_processed_mmd_files, list_sparse_vector_files, download_plain_text, upload_sparse_vectors

In [None]:
MODEL_NAME = "naver/splade-cocondenser-ensembledistil"
# Half precision only pays off on an accelerator. With device_map="auto" the model may land
# on CPU, where fp16 is slow and (on older PyTorch versions) unsupported for some ops, so
# fall back to fp32 there.
_HAS_ACCELERATOR = torch.cuda.is_available() or torch.backends.mps.is_available()
MODEL_DTYPE = torch.float16 if _HAS_ACCELERATOR else torch.float32

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, device_map="auto", torch_dtype=MODEL_DTYPE)
model.eval()  # inference only

model.device

In [None]:
# Character-based chunking: ~700-character chunks, with 100 characters shared between
# neighbouring chunks so text near a boundary is represented in both.
splitter = RecursiveCharacterTextSplitter(
	is_separator_regex=False,  # the (default) separators are matched as literal strings
	length_function=len,       # chunk length is measured in characters, not tokens
	chunk_overlap=100,
	chunk_size=700,
)

In [None]:
# Every processed .mmd document in storage is a candidate for sparse encoding.
files = list_processed_mmd_files()
len(files)  # shown as the cell output

In [None]:
# Download every document's plain text concurrently. executor.map yields results in the
# order of `files`, so plain[k] is the text of files[k].
with ThreadPoolExecutor() as executor:
	pending = executor.map(download_plain_text, files)
	plain = list(tqdm(pending, total=len(files)))

In [None]:
BATCH_SIZE = 128  # chunks per model forward pass; lower this if the GPU runs out of memory

In [None]:
def _encode_batch(batch):
	"""Run the SPLADE model on a list of strings.

	Returns a 2-D tensor (one row per input string, one column per vocabulary id) holding
	max-over-sequence of log(1 + relu(logits)), with padding positions masked out.
	The forward pass lives in a function so its large logits tensor is released as soon as
	the pooled result is returned, instead of lingering at module level.
	"""
	tokens = tokenizer(batch, return_tensors='pt', padding=True, truncation=True, max_length=512)
	tokens = {k: v.to(model.device) for k, v in tokens.items()}
	with torch.no_grad():
		logits = model(**tokens).logits
	# log1p is the numerically accurate form of log(1 + x), which matters in fp16.
	weights = torch.log1p(torch.relu(logits)) * tokens['attention_mask'].unsqueeze(-1)
	# No .squeeze(): a one-chunk batch must keep its batch axis so rows index correctly.
	return weights.max(dim=1).values


def _splade_rows(vectors):
	"""Convert a 2-D (rows x vocab) activation tensor into one {token_id: weight} dict per row.

	Only non-zero entries are kept; an all-zero row yields an empty dict.
	"""
	rows = []
	for row in vectors.cpu():  # one device->host copy per batch instead of per row
		token_ids = row.nonzero(as_tuple=True)[0]
		rows.append(dict(zip(token_ids.tolist(), row[token_ids].tolist())))
	return rows


vectorFiles = list_sparse_vector_files()
print(f"Found {len(vectorFiles)} vector files")
# Set lookup keeps the per-file skip check constant time.
already_encoded = set(vectorFiles)

for filename, text in tqdm(zip(files, plain), total=len(files)):
	if filename in already_encoded:
		continue

	chunks = splitter.split_text(text)
	output = {}

	for start in range(0, len(chunks), BATCH_SIZE):
		batch = chunks[start:start + BATCH_SIZE]
		for offset, row in enumerate(_splade_rows(_encode_batch(batch))):
			if not row:
				continue  # no active vocabulary terms for this chunk
			# Key by chunk position so output[k] always describes chunks[k], even after a skip.
			output[start + offset] = row

	upload_sparse_vectors(filename, output)