# SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking

This notebook gives a minimal example of how to use SPLADE.

* In this repo, we provide weights for 2 models (in the `weights` folder)
* See [Naver Labs Europe website](https://europe.naverlabs.com/research/machine-learning-and-optimization/splade-models/) for more up-to-date models under various settings
* We also provide two new models via Hugging Face (https://huggingface.co/naver)

| model | MRR@10 (MS MARCO dev) | recall@1000 (MS MARCO dev) | expected FLOPS | ~ avg q length | ~ avg d length | 
| --- | --- | --- | --- | --- | --- |
| `splade_max` (**v2**) | 34.0 | 96.5 | 1.32 | 18 | 92 |
| `distilsplade_max` (**v2**) | 36.8 | 97.9 | 3.82 | 25 | 232 |
| `naver/splade-cocondenser-selfdistil` (**v2bis**, [HF](https://huggingface.co/naver/splade-cocondenser-selfdistil))| 37.6 | 98.4 | 2.32 | 56 | 134 |
| `naver/splade-cocondenser-ensembledistil` (**v2bis**, [HF](https://huggingface.co/naver/splade-cocondenser-ensembledistil)) | 38.3 | 98.3  | 1.85 | 44 | 120 |
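
The *expected FLOPS* column estimates how many multiply-adds an inverted index needs, on average, to score one query against one document. It is commonly defined as $\sum_j \bar{p}^{(q)}_j \, \bar{p}^{(d)}_j$, where $\bar{p}^{(q)}_j$ (resp. $\bar{p}^{(d)}_j$) is the fraction of queries (resp. documents) in which vocabulary term $j$ has a non-zero weight. The "avg q/d length" columns are presumably the average number of non-zero dimensions. The sketch below computes both on toy sparse representations stored as `{term: weight}` dicts. It assumes this definition; the exact query and document samples behind the numbers in the table are not stated here.

```python
from collections import Counter
from typing import Dict, List

SparseRep = Dict[str, float]  # term -> weight (only non-zero terms need to be stored)


def activation_rates(reps: List[SparseRep]) -> Dict[str, float]:
    """Fraction of representations in which each term has a non-zero weight."""
    counts = Counter(term for rep in reps for term, w in rep.items() if w != 0)
    return {term: c / len(reps) for term, c in counts.items()}


def expected_flops(queries: List[SparseRep], docs: List[SparseRep]) -> float:
    """Sum over terms of P(term active in a query) * P(term active in a document)."""
    p_q = activation_rates(queries)
    p_d = activation_rates(docs)
    return sum(p * p_d.get(term, 0.0) for term, p in p_q.items())


def avg_length(reps: List[SparseRep]) -> float:
    """Average number of non-zero dimensions per representation."""
    return sum(sum(1 for w in rep.values() if w != 0) for rep in reps) / len(reps)


# Toy representations (hypothetical weights, for illustration only)
queries = [{"tea": 2.1, "green": 1.6}, {"tea": 1.0, "coffee": 0.8}]
docs = [{"tea": 1.5, "leaf": 0.7}, {"glass": 2.2, "stress": 2.3}]
print(expected_flops(queries, docs))  # 0.5: "tea" is active in every query and in half the documents
print(avg_length(queries), avg_length(docs))  # 2.0 2.0
```

Only terms that are active on both sides contribute. The FLOPS regularizer used during training is a differentiable relaxation of this quantity: it replaces each activation probability with the term's mean weight over a batch.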

In [2]:
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
from splade.models.transformer_rep import Splade

In [6]:
from splade.models.transformer_rep import *
import pickle

In [4]:
# set the dir for trained weights

##### v2
# model_type_or_dir = "weights/splade_max"
# model_type_or_dir = "weights/distilsplade_max"

### v2bis, downloaded directly from Hugging Face
# model_type_or_dir = "naver/splade-cocondenser-selfdistil"
model_type_or_dir = "naver/splade-cocondenser-ensembledistil"

### base MLM checkpoints without SPLADE fine-tuning (uncomment to compare)
# model_type_or_dir = "Luyu/co-condenser-marco"
# model_type_or_dir = "distilbert-base-uncased"

In [8]:
with open('output/log-doc-tf-topk-10.pkl', 'rb') as f:
    doc_tf = pickle.load(f)

with open('../output/corpus-tf-log-tensor-dict.pkl', 'rb') as f:
    cortf = pickle.load(f)


In [80]:
# loading model and tokenizer

model = DiversitySplade(model_type_or_dir, agg="max")
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_type_or_dir)
reverse_voc = {v: k for k, v in tokenizer.vocab.items()}
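
`DiversitySplade` appears to be a custom class defined in this fork, so its internals are not documented here. For reference, upstream SPLADE with `agg="max"` turns the MLM logits $z_{ij}$ (token position $i$, vocabulary entry $j$) into term weights $w_j = \max_i \log\bigl(1 + \mathrm{ReLU}(z_{ij})\bigr)$, with padding positions masked out. The plain-Python sketch below implements that pooling on toy logits. It also returns, for each vocabulary entry, the position that attains the maximum. We *assume* this is roughly what the `"argmax"` output used later in the notebook contains; that is not confirmed by the code shown here.

```python
import math
from typing import List, Tuple


def splade_max_pool(
    logits: List[List[float]], attention_mask: List[int]
) -> Tuple[List[float], List[int]]:
    """Max-pool SPLADE term weights over token positions (illustrative sketch).

    logits[i][j] is the MLM logit of vocabulary entry j at token position i;
    attention_mask[i] is 1 for real tokens and 0 for padding.
    Returns (weights, argmax): weights[j] = max_i log(1 + relu(logits[i][j]))
    over unmasked positions, and argmax[j] = the position attaining that max
    (left at 0 for terms whose weight stays 0).
    """
    vocab_size = len(logits[0])
    weights = [0.0] * vocab_size
    argmax = [0] * vocab_size
    for i, (row, keep) in enumerate(zip(logits, attention_mask)):
        if not keep:
            continue  # padding positions never contribute
        for j, z in enumerate(row):
            w = math.log1p(max(z, 0.0))  # log-saturated ReLU activation
            if w > weights[j]:
                weights[j] = w
                argmax[j] = i
    return weights, argmax


# Toy example: 3 positions (the last one is padding) over a 3-term vocabulary
logits = [[2.0, -1.0, 0.5], [0.1, 3.0, -2.0], [5.0, 5.0, 5.0]]
weights, argmax = splade_max_pool(logits, attention_mask=[1, 1, 0])
print([round(w, 3) for w in weights])  # [1.099, 1.386, 0.405]
print(argmax)  # [0, 1, 0]
```

Skipping padded positions is equivalent to multiplying by the attention mask before the max, because every activation is non-negative.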

In [20]:
# example document from MS MARCO passage collection (doc_id = 8003157)

doc = "Glass and Thermal Stress. Thermal Stress is created when one area of a glass pane gets hotter than an adjacent area. If the stress is too great then the glass will crack. The stress level at which the glass will break is governed by several factors."

In [21]:
tokenizer

PreTrainedTokenizerFast(name_or_path='naver/splade-cocondenser-ensembledistil', vocab_size=30522, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [22]:
# now compute the document representation
with torch.no_grad():
    doc_rep = model(d_kwargs=tokenizer(doc, return_tensors="pt"))["d_rep"].squeeze()  # (sparse) doc rep in voc space, shape (30522,)

# get the number of non-zero dimensions in the rep:
col = torch.nonzero(doc_rep).squeeze().cpu().tolist()
print("number of actual dimensions: ", len(col))

# now let's inspect the bow representation:
weights = doc_rep[col].cpu().tolist()
d = {k: v for k, v in zip(col, weights)}
sorted_d = {k: v for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)}
bow_rep = []
for k, v in sorted_d.items():
    bow_rep.append((reverse_voc[k], round(v, 2)))
print("SPLADE BOW rep:\n", bow_rep)

number of actual dimensions:  126
SPLADE BOW rep:
 [('stress', 2.25), ('glass', 2.23), ('thermal', 2.18), ('glasses', 1.65), ('pan', 1.62), ('heat', 1.56), ('stressed', 1.42), ('crack', 1.31), ('break', 1.12), ('cracked', 1.1), ('hot', 0.93), ('created', 0.9), ('factors', 0.81), ('broken', 0.73), ('caused', 0.71), ('too', 0.71), ('damage', 0.69), ('if', 0.68), ('hotter', 0.65), ('governed', 0.61), ('heating', 0.59), ('temperature', 0.59), ('adjacent', 0.59), ('cause', 0.58), ('effect', 0.57), ('fracture', 0.56), ('bradford', 0.55), ('strain', 0.53), ('hammer', 0.51), ('brian', 0.48), ('error', 0.47), ('windows', 0.45), ('will', 0.45), ('reaction', 0.42), ('create', 0.42), ('windshield', 0.41), ('heated', 0.41), ('factor', 0.4), ('cracking', 0.39), ('failure', 0.38), ('mechanical', 0.38), ('when', 0.38), ('formed', 0.38), ('bolt', 0.38), ('mechanism', 0.37), ('warm', 0.37), ('areas', 0.36), ('area', 0.36), ('energy', 0.34), ('disorder', 0.33), ('barry', 0.33), ('shock', 0.32), ('determi
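
Several high-weight terms in the document representation above, such as `glasses`, `heat` and `stressed`, never occur in the passage. They are expansion terms contributed by the MLM head. The sketch below separates a BOW list like `bow_rep` into lexical matches and expansions. `split_bow` is a hypothetical helper, not part of SPLADE. The token list is a hand-written subset of the passage's wordpieces; in the notebook it could instead come from `tokenizer.tokenize(doc)`.

```python
from typing import Iterable, List, Tuple


def split_bow(
    bow: List[Tuple[str, float]], input_tokens: Iterable[str]
) -> Tuple[List[Tuple[str, float]], List[Tuple[str, float]]]:
    """Hypothetical helper: partition (term, weight) pairs into lexical matches and expansions."""
    present = set(input_tokens)
    lexical = [(t, w) for t, w in bow if t in present]
    expansion = [(t, w) for t, w in bow if t not in present]
    return lexical, expansion


# First entries of the document BOW printed above, plus a subset of the passage's wordpieces
bow = [("stress", 2.25), ("glass", 2.23), ("thermal", 2.18), ("glasses", 1.65),
       ("heat", 1.56), ("stressed", 1.42), ("crack", 1.31)]
tokens = ["glass", "and", "thermal", "stress", "is", "created", "when", "crack"]
lexical, expansion = split_bow(bow, tokens)
print(lexical)    # [('stress', 2.25), ('glass', 2.23), ('thermal', 2.18), ('crack', 1.31)]
print(expansion)  # [('glasses', 1.65), ('heat', 1.56), ('stressed', 1.42)]
```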

In [25]:
# example query (kept in the variable `doc` so the cells below can reuse it)

doc = "how much caffeine is in peach green tea"

In [53]:
tokken[0][argmax]

tensor([ 2172,  2172,  2129,  2665,  3170,  2172,  2172,  2172, 18237, 24689,
         1999, 24689,  2003,  5572, 18237,  2129,   101,  7959,   101, 24689,
          101,  7959,  2665,  2172,  1999,  1999,  2172,   101,   101,   101,
        18237,  2003,  1999, 24689])

In [83]:
# now compute the query representation and the source token of each term
tokken = tokenizer(doc, return_tensors="pt")["input_ids"][:, 1:-1]  # wordpiece ids without [CLS]/[SEP]
with torch.no_grad():
    out = model(q_kwargs=tokenizer(doc, return_tensors="pt"))  # single forward pass
    doc_rep = out["q_rep"].squeeze()  # (sparse) query rep in voc space, shape (30522,)
    argmax = out["argmax"].squeeze()  # per vocab entry: index of the token position that produced it
# get the number of non-zero dimensions in the rep:
col = torch.nonzero(doc_rep).squeeze().cpu().tolist()
print("number of actual dimensions: ", len(col))

# now let's inspect the bow representation as (term, source token, weight):
weights = doc_rep[col].cpu().tolist()
argmax = argmax[col].cpu().tolist()
source = tokken[0][argmax].cpu().tolist()
d = {k: (v, j) for k, v, j in zip(col, weights, source)}
sorted_d = {k: v for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)}
bow_rep = []
for k, v in sorted_d.items():
    bow_rep.append((reverse_voc[k], reverse_voc[v[1]], round(v[0], 2)))
print("SPLADE BOW rep:\n", bow_rep)

number of actual dimensions:  29
SPLADE BOW rep:
 [('peach', 'peach', 2.42), ('tea', 'tea', 2.15), ('caf', 'caf', 1.7), ('green', 'green', 1.63), ('##fe', '##fe', 1.49), ('georgia', 'peach', 1.48), ('##ine', '##ine', 1.24), ('##ories', 'is', 1.21), ('weight', 'much', 1.1), ('content', 'in', 1.07), ('mg', 'much', 0.88), ('total', 'how', 0.87), ('much', 'much', 0.86), ('amount', 'much', 0.83), ('coffee', 'caf', 0.82), ('drink', 'caf', 0.54), ('cal', 'caf', 0.43), ('smooth', 'peach', 0.42), ('dose', 'much', 0.39), ('fruit', 'in', 0.35), ('%', 'much', 0.29), ('max', 'much', 0.29), ('herb', 'in', 0.27), ('##lor', '##fe', 0.2), ('contain', 'is', 0.19), ('greene', 'green', 0.16), ('rating', 'how', 0.15), ('ingredient', 'in', 0.13), ('pepper', 'in', 0.01)]
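
Each entry of the query representation above is a triple (vocabulary term, query wordpiece at the argmax position, weight). Grouping by the middle element shows which query token "generated" which expansion terms. For example, `much` yields `mg`, `amount` and `dose`. `group_by_source` below is a hypothetical helper operating on a list shaped like `bow_rep`; the sample entries are copied from the printed output.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def group_by_source(
    bow: List[Tuple[str, str, float]]
) -> Dict[str, List[Tuple[str, float]]]:
    """Hypothetical helper: map each source wordpiece to its (term, weight) pairs, highest first."""
    groups: Dict[str, List[Tuple[str, float]]] = defaultdict(list)
    for term, source, weight in bow:
        groups[source].append((term, weight))
    return {src: sorted(pairs, key=lambda p: p[1], reverse=True) for src, pairs in groups.items()}


# A few entries copied from the query representation printed above
bow = [("peach", "peach", 2.42), ("georgia", "peach", 1.48), ("weight", "much", 1.1),
       ("mg", "much", 0.88), ("much", "much", 0.86), ("coffee", "caf", 0.82)]
for source, terms in group_by_source(bow).items():
    print(source, "->", terms)
# peach -> [('peach', 2.42), ('georgia', 1.48)]
# much -> [('weight', 1.1), ('mg', 0.88), ('much', 0.86)]
# caf -> [('coffee', 0.82)]
```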


In [64]:
# print the query's wordpieces; .tolist() yields plain ints, matching the int keys of reverse_voc
for t in tokken[0].tolist():
    print(reverse_voc[t])

[CLS]
how
much
caf
##fe
##ine
is
in
peach
green
tea
[SEP]
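
The printed tokens show that the uncased BERT vocabulary splits `caffeine` into `caf`, `##fe` and `##ine`. This is why those three pieces appear as separate dimensions in the query representation. BERT's WordPiece tokenizer segments each word greedily, taking the longest matching vocabulary entry first. The sketch below reproduces that procedure on a tiny hypothetical vocabulary (not the real 30,522-entry one). It omits the real tokenizer's pre-tokenization steps (lower-casing, punctuation splitting) and its maximum-word-length check.

```python
from typing import List, Set


def wordpiece(word: str, vocab: Set[str], unk: str = "[UNK]") -> List[str]:
    """Greedy longest-match-first WordPiece segmentation of a single word."""
    pieces: List[str] = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, shrinking until it is in the vocab
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece  # continuation pieces carry the ## prefix
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk]  # no segmentation found: the whole word becomes [UNK]
        pieces.append(match)
        start = end
    return pieces


toy_vocab = {"caf", "##fe", "##ine", "##f", "peach", "green", "tea"}
print(wordpiece("caffeine", toy_vocab))  # ['caf', '##fe', '##ine']
print(wordpiece("peach", toy_vocab))     # ['peach']
```

Because `##fe` is tried before the shorter `##f`, the longest match wins.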


In [59]:
reverse_voc

{20012: 'glover',
 21369: '##asi',
 19714: '640',
 10706: 'wrestler',
 14305: 'extensions',
 13173: 'kara',
 441: '[unused436]',
 7496: 'ease',
 25746: '##52',
 12956: '##ante',
 4541: 'explained',
 14791: 'jasper',
 20134: 'longitudinal',
 22066: 'linden',
 26225: 'arden',
 5055: 'birds',
 4591: 'qualified',
 13526: 'skeleton',
 18193: 'mandy',
 20230: 'putnam',
 24594: '##29',
 25349: 'draught',
 6193: '1890',
 6820: '##ru',
 19577: 'oxidation',
 1552: '₅',
 7682: 'bearing',
 15043: 'ached',
 27628: 'carousel',
 16177: '##ν',
 24225: '##edge',
 14474: 'gaza',
 328: '[unused323]',
 17845: 'byrd',
 17876: 'elongated',
 291: '[unused286]',
 862: '[unused857]',
 182: '[unused177]',
 14009: 'observers',
 23753: 'greenish',
 25469: '##dity',
 26004: 'valuation',
 28625: 'astro',
 9773: 'nominee',
 11524: 'nsw',
 9067: '##mal',
 27798: '##hiti',
 16326: 'absorption',
 19530: 'acquisitions',
 13496: 'teens',
 16458: '##cross',
 25756: 'hammered',
 27953: '##anal',
 4654: 'ex',
 23614: 'pedia

In [58]:
v[1]

tensor(18237)

In [48]:
tokken=tokenizer(doc, return_tensors="pt")['input_ids']

In [49]:
argmax = argmax[col].cpu().tolist()

In [50]:
argmax

[2,
 2,
 1,
 9,
 5,
 2,
 2,
 2,
 8,
 3,
 7,
 3,
 6,
 10,
 8,
 1,
 0,
 4,
 0,
 3,
 0,
 4,
 9,
 2,
 7,
 7,
 2,
 0,
 0,
 0,
 8,
 6,
 7,
 3]

In [None]:
# largefiles/TaoFiles/splade/splade/evaluation/eval.py
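
`load_and_evaluate` in the next cell delegates to `mrr_k` and `evaluate` from `splade.evaluation.eval`. The sketch below assumes the run and qrel JSON files use the nested layout `{query_id: {doc_id: value}}`, with retrieval scores in the run and relevance grades in the qrels. That layout is consistent with the per-query results printed after the cell. The sketch gives textbook MRR@k and nDCG@k: linear gain and a $\log_2(\text{rank}+1)$ discount, a standard convention. These are not the repository's implementations; details such as the relevance threshold or the handling of queries missing from the run may differ.

```python
import math
from typing import Dict, List

Run = Dict[str, Dict[str, float]]   # query id -> {doc id: retrieval score}
Qrel = Dict[str, Dict[str, int]]    # query id -> {doc id: relevance grade}


def ranked(docs: Dict[str, float], k: int) -> List[str]:
    """Doc ids sorted by descending score, truncated to the top k."""
    return sorted(docs, key=docs.get, reverse=True)[:k]


def mrr_at_k(run: Run, qrel: Qrel, k: int = 10) -> float:
    """Mean reciprocal rank of the first relevant doc (grade > 0) within the top k."""
    total = 0.0
    for qid, rels in qrel.items():
        for rank, doc_id in enumerate(ranked(run.get(qid, {}), k), start=1):
            if rels.get(doc_id, 0) > 0:
                total += 1.0 / rank
                break
    return total / len(qrel)


def ndcg_at_k(run: Run, qrel: Qrel, k: int = 10) -> Dict[str, float]:
    """Per-query nDCG@k with gain = grade and discount = log2(rank + 1)."""
    scores = {}
    for qid, rels in qrel.items():
        dcg = sum(rels.get(d, 0) / math.log2(r + 1)
                  for r, d in enumerate(ranked(run.get(qid, {}), k), start=1))
        ideal = sorted(rels.values(), reverse=True)[:k]
        idcg = sum(g / math.log2(r + 1) for r, g in enumerate(ideal, start=1))
        scores[qid] = dcg / idcg if idcg > 0 else 0.0
    return scores


# Toy run and qrels (hypothetical ids and scores)
run = {"q1": {"d1": 3.0, "d2": 2.0, "d3": 1.0}, "q2": {"d4": 5.0, "d5": 4.0}}
qrel = {"q1": {"d2": 1}, "q2": {"d4": 2, "d5": 1}}
print(mrr_at_k(run, qrel))   # 0.75: relevant doc at rank 2 for q1, rank 1 for q2
print(ndcg_at_k(run, qrel))  # q1: 1/log2(3), q2: 1.0 (already ideally ordered)
```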

In [12]:
import json
from splade.evaluation.eval import mrr_k, evaluate


def load_and_evaluate(qrel_file_path, run_file_path, metric, *args, **kwargs):
    """Load a qrel file and a run file (both JSON) and compute `metric` on them."""
    with open(qrel_file_path) as reader:
        qrel = json.load(reader)
    with open(run_file_path) as reader:
        run = json.load(reader)
    # for TREC, qrel_binary.json should be used for recall etc., and qrel.json for nDCG;
    # if qrel.json is used for binary metrics, the binary 'splits' are not correct
    if "TREC" in qrel_file_path:
        assert ("binary" not in qrel_file_path) == (metric == "ndcg" or metric == "ndcg_cut")
    if metric == "mrr_10":
        res = mrr_k(run, qrel, k=10)
        print("MRR@10:", res)
        return {"mrr_10": res}
    else:
        res = evaluate(run, qrel, metric=metric, *args, **kwargs)
        print(metric, "==>", res)
        return res


run = "experiments/cocondenser_ensemble_distil/out2/TREC_DL_2019/run.json"
qrels = "data/msmarco/TREC_DL_2019/qrel.json"
metric = "ndcg_cut"
# agg=False: keep per-query scores instead of averaging them
res = load_and_evaluate(qrels, run, metric, agg=False)


ndcg_cut ==> {'156493': {'ndcg_cut_5': 0.8550077055393716, 'ndcg_cut_10': 0.8347727710785315, 'ndcg_cut_15': 0.8690005115907, 'ndcg_cut_20': 0.8892363940828747, 'ndcg_cut_30': 0.8903459863148391, 'ndcg_cut_100': 0.8456686930864262, 'ndcg_cut_200': 0.8616580574038185, 'ndcg_cut_500': 0.9080776185302061, 'ndcg_cut_1000': 0.9353519759599007}, '1110199': {'ndcg_cut_5': 0.8008157892980455, 'ndcg_cut_10': 0.8559744731298657, 'ndcg_cut_15': 0.8633110478805224, 'ndcg_cut_20': 0.7593214811183402, 'ndcg_cut_30': 0.6812161237801484, 'ndcg_cut_100': 0.7513523120199341, 'ndcg_cut_200': 0.8071607930105618, 'ndcg_cut_500': 0.8173835461142229, 'ndcg_cut_1000': 0.8265029855898373}, '1063750': {'ndcg_cut_5': 0.7345577848667817, 'ndcg_cut_10': 0.839758543431562, 'ndcg_cut_15': 0.8702436132800725, 'ndcg_cut_20': 0.8750719642322907, 'ndcg_cut_30': 0.8590294134716118, 'ndcg_cut_100': 0.6047392488251477, 'ndcg_cut_200': 0.46676931377556363, 'ndcg_cut_500': 0.4990821187824705, 'ndcg_cut_1000': 0.5552268914382

{'156493': {'ndcg_cut_5': 0.8550077055393716,
  'ndcg_cut_10': 0.8347727710785315,
  'ndcg_cut_15': 0.8690005115907,
  'ndcg_cut_20': 0.8892363940828747,
  'ndcg_cut_30': 0.8903459863148391,
  'ndcg_cut_100': 0.8456686930864262,
  'ndcg_cut_200': 0.8616580574038185,
  'ndcg_cut_500': 0.9080776185302061,
  'ndcg_cut_1000': 0.9353519759599007},
 '1110199': {'ndcg_cut_5': 0.8008157892980455,
  'ndcg_cut_10': 0.8559744731298657,
  'ndcg_cut_15': 0.8633110478805224,
  'ndcg_cut_20': 0.7593214811183402,
  'ndcg_cut_30': 0.6812161237801484,
  'ndcg_cut_100': 0.7513523120199341,
  'ndcg_cut_200': 0.8071607930105618,
  'ndcg_cut_500': 0.8173835461142229,
  'ndcg_cut_1000': 0.8265029855898373},
 '1063750': {'ndcg_cut_5': 0.7345577848667817,
  'ndcg_cut_10': 0.839758543431562,
  'ndcg_cut_15': 0.8702436132800725,
  'ndcg_cut_20': 0.8750719642322907,
  'ndcg_cut_30': 0.8590294134716118,
  'ndcg_cut_100': 0.6047392488251477,
  'ndcg_cut_200': 0.46676931377556363,
  'ndcg_cut_500': 0.499082118782470