<a href="https://githubtocolab.com/NirantK/lightsplade/blob/master/nbs/01-create-spladev3lexical.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

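# SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking

This notebook does three things:

- Encodes the BeIR SciFact corpus with the `nirantk/splade-v3-lexical` SPLADE checkpoint.
- Inspects the sparse bag-of-words representation of one example passage.
- Uploads the resulting per-document sparse vectors to the Hugging Face Hub as `nirantk/scifact-sparse-vectors`.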

In [1]:
!pip install git+https://github.com/naver/splade.git datasets --quiet

  Preparing metadata (setup.py) ... done


In [2]:
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
from splade.models.transformer_rep import Splade
from datasets import load_dataset
from tqdm.auto import tqdm

In [3]:
# Hugging Face Hub id (or local directory) of the trained SPLADE weights to load.
# Alternative checkpoints (kept for reference):
#   v2:    "naver/splade_v2_max", "naver/splade_v2_distil"
#   v2bis: "naver/splade-cocondenser-selfdistil"
model_type_or_dir = "nirantk/splade-v3-lexical"

In [4]:
# Load only the "corpus" split of BeIR/SciFact (columns: _id, title, text).
corpus = load_dataset("BeIR/scifact", "corpus", split="corpus")

In [5]:
# Build the SPLADE encoder from the checkpoint above (agg="max" selects the
# max-aggregation variant). Switch it to inference mode (no dropout), move it
# to the GPU, and display the module structure.
model = Splade(model_type_or_dir, agg="max")
model = model.eval().to("cuda")
model

Splade(
  (transformer_rep): TransformerRep(
    (transformer): BertForMaskedLM(
      (bert): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-11): 12 x BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertSelfOutput(
     

In [6]:
# Tokenizer matching the checkpoint, plus an id -> token lookup used to print
# human-readable bag-of-words representations.
tokenizer = AutoTokenizer.from_pretrained(model_type_or_dir)
vocab = tokenizer.vocab  # token -> id
reverse_voc = {token_id: token for token, token_id in vocab.items()}

In [7]:
# Example inputs. The first passage comes from the MS MARCO passage collection
# (doc_id = 8003157); the second is a short toy string. Only the first one
# (`doc`) is encoded and inspected in the next cell.
docs = ["Glass and Thermal Stress. Thermal Stress is created when one area of a glass pane gets hotter than an adjacent area. If the stress is too great then the glass will crack. The stress level at which the glass will break is governed by several factors.", "Hello World from Qdrant!"]
doc = docs[0]

In [8]:
# Encode the single example document and inspect its sparse bag-of-words (BOW)
# representation in vocabulary space.
with torch.no_grad():
    doc_rep = model(d_kwargs=tokenizer(doc, return_tensors="pt").to("cuda"))

doc_rep = doc_rep["d_rep"].cpu().squeeze()  # (sparse) doc rep in voc space, shape (30522,)

# Indices of the non-zero vocabulary dimensions. Use squeeze(-1), not a bare
# squeeze(): it keeps the result 1-D, so a rep with exactly one non-zero entry
# still yields a list rather than a bare int (which would break zip() below).
col = torch.nonzero(doc_rep).squeeze(-1).tolist()
print("number of actual dimensions: ", len(col))

# Pair each vocabulary id with its weight, then list them from heaviest to lightest.
weights = doc_rep[col].tolist()
d = dict(zip(col, weights))
sorted_d = dict(sorted(d.items(), key=lambda item: item[1], reverse=True))
bow_rep = []
print("SPLADE BOW rep:\n")
for k, v in sorted_d.items():
    bow_rep.append((reverse_voc[k], round(v, 2)))
    print(f"{bow_rep[-1]}")

number of actual dimensions:  270
SPLADE BOW rep:

('glass', 1.6)
('thermal', 1.5)
('stress', 1.49)
('glasses', 1.3)
('heat', 1.08)
('crack', 1.08)
('pan', 1.07)
('cracking', 0.99)
('cracked', 0.98)
('stresses', 0.98)
('stressed', 0.92)
('cracks', 0.92)
('strain', 0.91)
('##glass', 0.88)
('windshield', 0.83)
('break', 0.83)
('windows', 0.8)
('hot', 0.78)
('window', 0.78)
('breaking', 0.73)
('hotter', 0.72)
('causes', 0.69)
('heated', 0.69)
('breaks', 0.68)
('heats', 0.68)
('temperature', 0.67)
('tension', 0.67)
('brittle', 0.63)
('creates', 0.63)
('shatter', 0.6)
('happens', 0.57)
('te', 0.57)
('heating', 0.57)
('##dia', 0.57)
('broken', 0.56)
('fracture', 0.54)
('pressure', 0.53)
('##thermal', 0.51)
('factors', 0.49)
('why', 0.49)
('adjacent', 0.48)
('cause', 0.47)
('create', 0.47)
('lens', 0.45)
('melt', 0.44)
('fatigue', 0.44)
('tensions', 0.44)
('##e', 0.43)
('created', 0.43)
('broke', 0.42)
('how', 0.42)
('##rm', 0.41)
('ceramic', 0.41)
('warm', 0.4)
('caused', 0.4)
('shattering',

In [9]:
# Inspect the loaded SciFact corpus: its columns and number of rows.
corpus

Dataset({
    features: ['_id', 'title', 'text'],
    num_rows: 5183
})

In [10]:
def get_docs_rep(docs, tokenizer, model, device="cuda"):
    """Encode a batch of documents with a SPLADE model.

    Args:
        docs: A string or a list of strings to encode.
        tokenizer: A Hugging Face tokenizer matching ``model``.
        model: A SPLADE model, called as ``model(d_kwargs=...)``.
        device: Device the tokenized inputs are moved to. It must match the
            device ``model`` lives on. Defaults to ``"cuda"``.

    Returns:
        The model output for the document inputs. Elsewhere in this notebook
        its ``"d_rep"`` entry is read and iterated per document.
    """
    # Pad to the longest document in the batch and truncate over-long ones so
    # the batch fits the model's maximum input length.
    tokens = tokenizer(docs, return_tensors="pt", padding=True, truncation=True).to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        return model(d_kwargs=tokens)

In [11]:
def batch_iterator(iterable, batch_size=128):
    """
    Iterates over an iterable in batches of a given size.

    Sized, sliceable inputs (lists, ranges, dataset columns, ...) are sliced,
    so each batch has whatever type the input's slicing returns; for example,
    ``range(10)`` yields ``range`` objects. Any other iterable, such as a
    generator, is consumed lazily and batched into lists.

    Args:
        iterable: An iterable object.
        batch_size: The size of each batch; must be >= 1.

    Yields:
        Batches of at most ``batch_size`` items from the iterable. Only the
        last batch may be shorter.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (raised when
            iteration starts, since this is a generator).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

    if hasattr(iterable, "__len__") and hasattr(iterable, "__getitem__"):
        l = len(iterable)
        for ndx in range(0, l, batch_size):
            yield iterable[ndx:min(ndx + batch_size, l)]
        return

    # No len()/slicing available (e.g. a generator): pull items lazily instead.
    from itertools import islice

    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch

# Example usage:

for batch in batch_iterator(range(10), 12):
    print(batch)

range(0, 10)


In [12]:
# Encode the whole corpus in batches. Each batch's "d_rep" is moved to CPU right
# away so GPU memory does not grow with the corpus size; "d_rep" is the only
# entry the next cell reads. The loop variable is named `batch_docs` so it does
# not overwrite the example `docs` defined earlier.
encode_batch_size = 16
corpus_texts = corpus["text"]
n_batches = (len(corpus_texts) + encode_batch_size - 1) // encode_batch_size  # ceil, gives tqdm a total
doc_representations = []
for batch_docs in tqdm(batch_iterator(corpus_texts, batch_size=encode_batch_size), total=n_batches):
    batch_rep = get_docs_rep(batch_docs, tokenizer, model)
    doc_representations.append({"d_rep": batch_rep["d_rep"].cpu()})

0it [00:00, ?it/s]

In [13]:
# Convert every per-document row of the batch representations into a sparse
# {vocab_id: weight} dict that keeps only the non-zero entries.
sparse_vectors = []
for batch_rep in doc_representations:
    for doc_rep in batch_rep["d_rep"]:
        row = doc_rep.squeeze()
        # Use squeeze(-1) rather than squeeze(): it keeps the index list 1-D even
        # when a row has exactly one non-zero entry. A bare squeeze() would give
        # a 0-d tensor, .tolist() would return an int, and zip() would fail.
        col = torch.nonzero(row).squeeze(-1).cpu().tolist()
        weights = row[col].cpu().tolist()
        d = dict(zip(col, weights))
        sparse_vectors.append(d)

In [14]:
# Sanity check: exactly one sparse vector was produced per corpus document.
n_docs = len(corpus)
assert len(sparse_vectors) == n_docs, "number of sparse vectors should be equal to number of documents"

In [16]:
import json
import pandas as pd
from datasets import Dataset

In [27]:
# Serialise each {vocab_id: weight} dict to a JSON string so it can be stored as
# a plain text column (JSON turns the integer keys into strings, e.g. "1009").
str_sparse_vectors = list(map(json.dumps, sparse_vectors))

In [40]:
# One row per document, with the sparse vector stored as a JSON string.
# NOTE(review): "spalde-v3-lexical" looks like a typo of "splade". It is kept
# as-is because it is already the column name of the dataset pushed below.
df = pd.DataFrame({"spalde-v3-lexical": str_sparse_vectors})

In [41]:
# Attach the corpus document ids (row order matches the encoding order).
df = df.assign(_id=corpus["_id"])

In [42]:
# Attach the raw document text.
df = df.assign(text=corpus["text"])

In [43]:
# Attach the document titles.
df = df.assign(title=corpus["title"])

In [44]:
# Preview the first rows of the table that will be uploaded.
df.head(n=5)

Unnamed: 0,spalde-v3-lexical,_id,text,title
0,"{""1009"": 0.20660123229026794, ""1011"": 0.036437...",4983,Alterations of the architecture of cerebral wh...,Microstructural development of human newborn c...
1,"{""1011"": 0.12642613053321838, ""1055"": 0.427995...",5836,Myelodysplastic syndromes (MDS) are age-depend...,Induction of myelodysplasia by myeloid-derived...
2,"{""1011"": 0.20421554148197174, ""1015"": 0.602221...",7912,ID elements are short interspersed elements (S...,"BC1 RNA, the transcript from a master gene for..."
3,"{""1014"": 0.004870930220931768, ""1052"": 0.83494...",18670,DNA methylation plays an important role in bio...,The DNA Methylome of Human Peripheral Blood Mo...
4,"{""1011"": 0.35403475165367126, ""1017"": 0.100269...",19238,Two human Golli (for gene expressed in the oli...,The human myelin basic protein gene is include...


In [45]:
# Convert the pandas table into a Hugging Face Dataset (default index handling).
sparse_vectors_dataset = Dataset.from_pandas(df, preserve_index=None)

In [47]:
import getpass  # for interactive token input
import os

# Upload the table to the Hugging Face Hub. Prefer a token from the HF_TOKEN
# environment variable; otherwise prompt for one so it is never hard-coded.
# (The original call was missing its closing parenthesis, a SyntaxError.)
hf_token = os.environ.get("HF_TOKEN") or getpass.getpass("Enter your Hugging Face API token: ")
sparse_vectors_dataset.push_to_hub("nirantk/scifact-sparse-vectors", token=hf_token)

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/406 [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/datasets/nirantk/scifact-sparse-vectors/commit/9ef5ac0679a05775006e9c5cb5b5c2a18858c2c3', commit_message='Upload dataset', commit_description='', oid='9ef5ac0679a05775006e9c5cb5b5c2a18858c2c3', pr_url=None, pr_revision=None, pr_num=None)