In [8]:
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

In [2]:

# Checkpoint identifier passed to both from_pretrained() calls below.
model_id = "naver/splade-cocondenser-ensembledistil"

# Load the tokenizer first, then the masked-LM model, from the same checkpoint
# so that token IDs and the model's output vocabulary line up.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_id),
    AutoModelForMaskedLM.from_pretrained(model_id),
)

In [24]:
# Example input that is encoded into a sparse vector in the cells below.
text = "Histochemical specificity of cholinesterases to phenylthioacetate"


In [25]:
# Tokenize the input text; return_tensors='pt' yields PyTorch tensors that can
# be unpacked straight into the model call.
tokens = tokenizer(text, return_tensors='pt')

# This is inference only, so run the forward pass with autograd disabled.
# Otherwise the logits carry a grad_fn and every intermediate activation is
# kept alive for a backward pass that never happens.
with torch.no_grad():
    output = model(**tokens)
output

MaskedLMOutput(loss=None, logits=tensor([[[ -6.0833,  -8.0931,  -7.5358,  ...,  -7.4679,  -7.2170,  -4.7996],
         [ -9.4616, -10.0497,  -9.9420,  ...,  -9.9228, -10.1299,  -8.1783],
         [ -7.0096,  -8.4798,  -8.2369,  ...,  -8.2445,  -7.7094,  -5.3451],
         ...,
         [ -6.3693,  -7.8741,  -7.4466,  ...,  -7.6402,  -7.3879,  -5.2072],
         [ -7.3810,  -8.6121,  -8.0993,  ...,  -8.0730,  -8.0244,  -5.6252],
         [-20.5680, -16.4817, -16.0739,  ..., -15.9652, -15.0908, -17.0160]]],
       grad_fn=<ViewBackward0>), hidden_states=None, attentions=None)

In [26]:
# Show the dimensions of the logits tensor returned by the forward pass.
output.logits.shape

torch.Size([1, 20, 30522])

In [27]:


# Build the sparse vector from the masked-LM logits:
#   1. log(1 + relu(logits)) keeps only positive activations and dampens them,
#   2. the attention mask (broadcast over the last axis) zeroes masked positions,
#   3. max over dim=1 keeps the largest weight along that axis,
#   4. squeeze() drops the remaining size-1 dimensions.
vec = torch.max(
    tokens.attention_mask.unsqueeze(-1) * torch.log(torch.relu(output.logits) + 1),
    dim=1,
).values.squeeze()

vec.shape

torch.Size([30522])

In [28]:
# Display the pooled vector; the recorded output shows mostly zeros.
vec

tensor([0., 0., 0.,  ..., 0., 0., 0.], grad_fn=<SqueezeBackward0>)

In [29]:


# extract non-zero positions
# nonzero() returns a (k, 1) index tensor for a 1-D input. squeeze(-1) removes
# only that trailing dimension, so a vector with exactly one non-zero entry
# still produces a 1-element list. A bare squeeze() would collapse it to a
# scalar, and len(cols) / zip(cols, weights) below would then fail.
nonzero_idx = vec.nonzero().squeeze(-1)
cols = nonzero_idx.cpu().tolist()
print(len(cols))

# extract the non-zero values (index with the tensor so an empty result is
# handled the same way as a populated one)
weights = vec[nonzero_idx].cpu().tolist()
# use to create a dictionary of token ID to weight
sparse_dict = dict(zip(cols, weights))
sparse_dict



49


{1998: 0.03310779854655266,
 2000: 0.6177812814712524,
 2010: 1.005732536315918,
 2168: 0.08255530148744583,
 2193: 0.011573554016649723,
 2368: 1.0316563844680786,
 2504: 0.17538274824619293,
 2791: 0.19038361310958862,
 3012: 1.4128109216690063,
 3231: 0.19589324295520782,
 3276: 0.08402110636234283,
 3366: 0.5248109698295593,
 3401: 0.6450810432434082,
 3406: 0.7270060777664185,
 3563: 1.797067642211914,
 3739: 0.018558084964752197,
 4179: 1.004936695098877,
 4668: 0.20239974558353424,
 4742: 0.43958884477615356,
 4962: 0.31664133071899414,
 5072: 0.7334992289543152,
 5250: 0.07509680837392807,
 5648: 0.017283421009778976,
 5783: 0.40568697452545166,
 6370: 0.37177905440330505,
 6463: 0.1956566423177719,
 6494: 0.28630903363227844,
 6693: 0.5737876892089844,
 6887: 0.882270097732544,
 7730: 0.13819989562034607,
 8516: 1.0287171602249146,
 8583: 0.29585856199264526,
 9007: 0.7726047039031982,
 10441: 0.7270888090133667,
 10788: 0.19046343863010406,
 11460: 0.07744210958480835,
 12115

In [30]:


# Invert the tokenizer vocabulary (token -> ID) into an ID -> token lookup.
idx2token = dict(
    (token_id, piece) for piece, token_id in tokenizer.get_vocab().items()
)



In [31]:


# Translate each non-zero token ID into its vocabulary string and round the
# weight to two decimals for display.
sparse_dict_tokens = {
    idx2token[token_id]: round(score, 2)
    for token_id, score in zip(cols, weights)
}
# Re-order by weight, largest first. dicts keep insertion order, so the sorted
# pairs can be fed straight back into dict().
sparse_dict_tokens = dict(
    sorted(sparse_dict_tokens.items(), key=lambda pair: pair[1], reverse=True)
)

In [32]:
# Display the token -> weight mapping, sorted by weight in descending order.
sparse_dict_tokens

{'specific': 1.8,
 '##ity': 1.41,
 '##tera': 1.32,
 '##chemical': 1.23,
 'cho': 1.06,
 '##en': 1.03,
 '##yl': 1.03,
 '##thi': 1.02,
 'his': 1.01,
 '##line': 1.0,
 '##tate': 0.99,
 'ph': 0.88,
 'inhibitor': 0.82,
 'enzyme': 0.77,
 '##to': 0.73,
 'chemical': 0.73,
 '##oa': 0.73,
 '##ce': 0.65,
 'to': 0.62,
 'concentration': 0.57,
 '##se': 0.52,
 'signal': 0.44,
 '##lines': 0.44,
 'element': 0.41,
 'chemistry': 0.37,
 'similarity': 0.34,
 'gene': 0.32,
 'mutation': 0.31,
 '##ses': 0.3,
 '##tra': 0.29,
 'molecule': 0.25,
 'marker': 0.23,
 'test': 0.2,
 'reaction': 0.2,
 'ratio': 0.2,
 '##ness': 0.19,
 'detection': 0.19,
 'level': 0.18,
 '##ivity': 0.14,
 'hormone': 0.1,
 'same': 0.08,
 'relationship': 0.08,
 'protein': 0.08,
 'mg': 0.08,
 'and': 0.03,
 'presence': 0.02,
 'acid': 0.02,
 'number': 0.01,
 '##chrome': 0.0}