# SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking

This notebook gives a minimal example usage of SPLADE.

* In this repo, we provide weights for 2 models (in the `weights` folder)
* See [Naver Labs Europe website](https://europe.naverlabs.com/research/machine-learning-and-optimization/splade-models/) for more up-to-date models under various settings
* We also provide two new models via Hugging Face (https://huggingface.co/naver)

| model | MRR@10 (MS MARCO dev) | recall@1000 (MS MARCO dev) | expected FLOPS | ~ avg q length | ~ avg d length | 
| --- | --- | --- | --- | --- | --- |
| `splade_max` (**v2**) | 34.0 | 96.5 | 1.32 | 18 | 92 |
| `distilsplade_max` (**v2**) | 36.8 | 97.9 | 3.82 | 25 | 232 |
| `naver/splade-cocondenser-selfdistil` (**v2bis**, [HF](https://huggingface.co/naver/splade-cocondenser-selfdistil))| 37.6 | 98.4 | 2.32 | 56 | 134 |
| `naver/splade-cocondenser-ensembledistil` (**v2bis**, [HF](https://huggingface.co/naver/splade-cocondenser-ensembledistil)) | 38.3 | 98.3  | 1.85 | 44 | 120 |

In [5]:
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
from splade.models.transformer_rep import Splade, PriorSpladeV2

In [1]:
import multiprocessing

multiprocessing.cpu_count()

40

In [4]:
# from transformers import AutoTokenizer, AutoModelForMaskedLM

# tokenizer = AutoTokenizer.from_pretrained("naver/splade-cocondenser-ensembledistil")

# model = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-ensembledistil")
# model.eval()
# reverse_voc = {v: k for k, v in tokenizer.vocab.items()}


In [6]:
# set the dir for trained weights

##### v2
# model_type_or_dir = "weights/splade_max"
# model_type_or_dir = "weights/distilsplade_max"

### v2bis, directly download from Hugging Face
# model_type_or_dir = "naver/splade-cocondenser-selfdistil"
model_type_or_dir = "naver/splade-cocondenser-ensembledistil"
# model_type_or_dir = "naver/efficient-splade-V-large-doc"
model_type_or_dir ="/home/taoyang/research/research_everyday/projects/DR/splade/splade/experiments/idf/checkpoint/model"

In [217]:
# model_type_or_dir="distilbert-base-uncased"

In [4]:
import pickle
idfpkl="output/idf-token.pkl"
with open(idfpkl, 'rb') as f:
    idftoken = pickle.load(f)

In [4]:
idftoken["too"]

134309

In [5]:
idftoken["stress"]

38178

In [6]:
idfpkl="output/idf-tokenid.pkl"
with open(idfpkl, 'rb') as f:
    idf = pickle.load(f)

In [7]:
idfTensor=torch.tensor(list(idf.values()),dtype=torch.float32)
idfTensor=torch.clip(idfTensor,10,10**6)

In [6]:
class PriorSplade(Splade):
    """SPLADE model
    """

    def __init__(self,idfTensor,*param,**kwparam):
        super().__init__(*param,**kwparam)
        self.idf=torch.nn.parameter.Parameter(idfTensor,requires_grad=False)
        # self.output_dim = self.transformer_rep.transformer.config.vocab_size  # output dim = vocab size = 30522 for BERT
        # assert agg in ("sum", "max")
        # self.agg = agg

    def encode(self, tokens, is_q):
        out = self.encode_(tokens, is_q)["logits"]  # shape (bs, pad_len, voc_size)
        
        if self.agg == "sum":
            return torch.sum(torch.log(1 + torch.relu(out)) * tokens["attention_mask"].unsqueeze(-1), dim=1)
        else:
            values, _ = torch.max(torch.log(1 + torch.relu(out)) * tokens["attention_mask"].unsqueeze(-1), dim=1)
            values=values/torch.log(self.idf)
            return values

In [35]:
class PriorSpladeV2(Splade):
    """SPLADE model
    """

    def __init__(self,idfpkl="output/idf-tokenid.pkl",*param,**kwparam):
        super().__init__(*param,**kwparam)
        idfpkl="output/idf-tokenid.pkl"
        with open(idfpkl, 'rb') as f:
            idf = pickle.load(f)
        self.idf=torch.nn.parameter.Parameter(idfTensor,requires_grad=False)
        # self.output_dim = self.transformer_rep.transformer.config.vocab_size  # output dim = vocab size = 30522 for BERT
        # assert agg in ("sum", "max")
        # self.agg = agg

    def encode(self, tokens, is_q):
        out = self.encode_(tokens, is_q)["logits"]  # shape (bs, pad_len, voc_size)
        selectedIdf=self.idf[tokens["input_ids"]]
        out=out/torch.log(selectedIdf[:,:,None])
        if self.agg == "sum":
            return torch.sum(torch.log(1 + torch.relu(out)) * tokens["attention_mask"].unsqueeze(-1), dim=1)
        else:
            values, _ = torch.max(torch.log(1 + torch.relu(out)) * tokens["attention_mask"].unsqueeze(-1), dim=1)
            values=values/torch.log(self.idf)*torch.log(self.idf.min())
            return values

In [7]:
# loading model and tokenizer
model = PriorSpladeV2(model_type_or_dir=model_type_or_dir, agg="max")
# model = PriorSpladeV2(idfTensor=idfTensor,model_type_or_dir=model_type_or_dir, agg="max")
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_type_or_dir)
reverse_voc = {v: k for k, v in tokenizer.vocab.items()}

In [8]:
# example document from MS MARCO passage collection (doc_id = 8003157)

doc = "Glass and Thermal Stress. Thermal Stress is created when one area of a glass pane gets hotter than an adjacent area. If the stress is too great then the glass will crack. The stress level at which the glass will break is governed by several factors."

In [9]:
# now compute the document representation
with torch.no_grad():
    doc_rep = model(d_kwargs=tokenizer(doc, return_tensors="pt"))["d_rep"].squeeze()  # (sparse) doc rep in voc space, shape (30522,)

# get the number of non-zero dimensions in the rep:
col = torch.nonzero(doc_rep).squeeze().cpu().tolist()
print("number of actual dimensions: ", len(col))

# now let's inspect the bow representation:
weights = doc_rep[col].cpu().tolist()
d = {k: v for k, v in zip(col, weights)}
sorted_d = {k: v for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)}
bow_rep = []
for k, v in sorted_d.items():
    bow_rep.append((reverse_voc[k], round(v, 2)))
print("SPLADE BOW rep:\n", bow_rep)

number of actual dimensions:  3815
SPLADE BOW rep:


In [272]:
for i in bow_rep:
    if i[0]=="if":
        print(i)
        break

('if', 0.05)


In [10]:
# example document from MS MARCO passage collection (doc_id = 8003157)

query = "Is caffeine dangerous during pregnancy?"
# query = "are you old?"
# query = "mold"

In [11]:
# now compute the document representation
with torch.no_grad():
    doc_rep = model(q_kwargs=tokenizer(query, return_tensors="pt"))["q_rep"].squeeze()  # (sparse) doc rep in voc space, shape (30522,)

# get the number of non-zero dimensions in the rep:
col = torch.nonzero(doc_rep).squeeze().cpu().tolist()
print("number of actual dimensions: ", len(col))

# now let's inspect the bow representation:
weights = doc_rep[col].cpu().tolist()
d = {k: v for k, v in zip(col, weights)}
sorted_d = {k: v for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)}
bow_rep = []
for k, v in sorted_d.items():
    bow_rep.append((reverse_voc[k], round(v, 2)))
print("SPLADE BOW rep:\n", bow_rep)

number of actual dimensions:  2659
SPLADE BOW rep:


In [142]:
for key in tokenizer.vocab:
    if "mold" in key:
        print(key)

mold
molded
moldova


In [12]:
model

Splade(
  (transformer_rep): TransformerRep(
    (transformer): BertForMaskedLM(
      (bert): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertSelfOutput(
             

In [13]:
import os,pickle
collection_picked_file="/raid/datasets/shared/MSMARCO/collection.pickle"
if os.path.exists(collection_picked_file):
    with open(collection_picked_file, 'rb') as f:
        corpus = pickle.load(f)

In [191]:
corpus[1]

KeyError: 1

In [192]:
len(corpus)

8841823

In [88]:
listCorp=[]
for i in range(10**6):
    listCorp.append(corpus[str(i)])

In [90]:
tokenCop=tokenizer(listCorp, return_tensors="np")['input_ids']

In [91]:
len(tokenizer.vocab.keys())

30522

In [92]:
tokenSet=tokenizer.vocab.values()

In [160]:
# tokenSet=tokenizer.vocab.values()
InverseVocab={key[1]:key[0] for ind,key in enumerate(tokenizer.vocab.items())}

In [136]:
numToken=len(tokenizer.vocab.keys())
idf={ind:0 for ind in range(numToken)}
idfToken={key:0 for key in tokenizer.vocab}

In [152]:
tokenizer.vocab["mold"]

18282

In [151]:
token=18282
InverseVocab[token]

'wakefield'

In [213]:
numToken=len(tokenizer.vocab.keys())
idf={ind:0 for ind in range(numToken)}
idfToken={key:0 for key in tokenizer.vocab}
moldPass=[]
moldPassU=0
for passageId,tokenSent in enumerate(tokenCop):
    SetTokenSent=set(tokenSent)
    for token in SetTokenSent:
        idf[token]+=1
        idfToken[InverseVocab[token]]+=1
        # if "mold" == InverseVocab[token]:
        #     moldPassU+=1
    # if "mold" in listCorp[passageId]:
    #     moldPass.append(listCorp[passageId])
        # print(listCorp[passageId])

In [214]:
corpusLen=len(corpus)

In [205]:
ee=list(tokenizer(listCorp[idx*batch:(idx+1)*batch], return_tensors="np")['input_ids'])

In [206]:
len(ee)

10000

In [209]:
from tqdm.auto import tqdm
batch=10000
numBatch=6
tokenCop=[]
for idx in tqdm(range(0,numBatch)):
    if idx*batch>corpusLen:
        break
    tokenCop+=list(tokenizer(listCorp[idx*batch:(idx+1)*batch], return_tensors="np")['input_ids'])

  0%|          | 0/6 [00:00<?, ?it/s]

In [211]:
len(tokenCop)

60000

In [215]:
sorted(idfToken.items(), key=lambda x:x[1])

[('ি', 0),
 ('も', 0),
 ('[unused765]', 0),
 ('disrepair', 0),
 ('squinted', 0),
 ('[unused51]', 0),
 ('[unused373]', 0),
 ('##⁻', 0),
 ('nestor', 0),
 ('[unused172]', 0),
 ('[unused909]', 0),
 ('##₃', 0),
 ('བ', 0),
 ('[unused164]', 0),
 ('##জ', 0),
 ('[unused691]', 0),
 ('##⇒', 0),
 ('ₛ', 0),
 ('[unused723]', 0),
 ('メ', 0),
 ('ᵐ', 0),
 ('秀', 0),
 ('[unused317]', 0),
 ('[unused74]', 0),
 ('retracted', 0),
 ('serie', 0),
 ('spaceship', 0),
 ('slung', 0),
 ('whispering', 0),
 ('##ᵏ', 0),
 ('[unused344]', 0),
 ('deportivo', 0),
 ('ζ', 0),
 ('[unused210]', 0),
 ('##ே', 0),
 ('[unused514]', 0),
 ('sprawled', 0),
 ('義', 0),
 ('shuddered', 0),
 ('dioceses', 0),
 ('italianate', 0),
 ('haynes', 0),
 ('ᄏ', 0),
 ('##դ', 0),
 ('##長', 0),
 ('##阝', 0),
 ('[unused826]', 0),
 ('[unused256]', 0),
 ('[unused257]', 0),
 ('[unused53]', 0),
 ('##■', 0),
 ('harmonica', 0),
 ('[unused671]', 0),
 ('[unused757]', 0),
 ('##п', 0),
 ('[unused840]', 0),
 ('[unused927]', 0),
 ('戦', 0),
 ('[unused85]', 0),
 ('grima

In [None]:
# Is during pregnancy caffeine dangerous?

In [181]:
idf=np.array(list(idfToken.values()))

In [185]:
idf[idf>0].min()

1

In [187]:
(idf==0).sum()

2908

In [188]:
(idf==1).sum()

35

In [216]:
sorted(idfToken.items(), key=lambda x:-x[1])

[('[SEP]', 60000),
 ('[CLS]', 60000),
 ('.', 58975),
 ('the', 52594),
 (',', 50504),
 ('of', 44417),
 ('and', 43502),
 ('a', 40833),
 ('to', 38390),
 ('in', 36382),
 ('is', 31998),
 ('for', 22325),
 ('-', 20306),
 ('that', 17595),
 (')', 17060),
 ('(', 16945),
 ('or', 16408),
 ('on', 16330),
 ('with', 15429),
 ('as', 15224),
 ('are', 15190),
 ('it', 14810),
 ('##s', 14246),
 ('by', 13002),
 (':', 12751),
 ("'", 12622),
 ('you', 12062),
 ('from', 11992),
 ('an', 11364),
 ('be', 11339),
 ('this', 10821),
 ('s', 10727),
 ('1', 10616),
 ('at', 9957),
 ('can', 9953),
 ('your', 8881),
 ('2', 8625),
 ('have', 8181),
 ('not', 7589),
 ('was', 7404),
 ('one', 6814),
 ('which', 6624),
 ('if', 6438),
 ('3', 6279),
 ('more', 6173),
 ('has', 6160),
 ('all', 6072),
 ('but', 5956),
 ('will', 5841),
 ('when', 5713),
 ('also', 5650),
 ('most', 5206),
 ('other', 5121),
 ('##a', 5066),
 ('they', 4896),
 ('/', 4778),
 ('about', 4726),
 ('may', 4636),
 ('##as', 4476),
 ('there', 4436),
 ('up', 4352),
 ('i',