In [1]:
import itertools, collections, json, string, re
from transformers import AutoTokenizer, AutoModelForMaskedLM

In [2]:
model_name = "stevhliu/my_awesome_eli5_mlm_model"
# model_name = "microsoft/deberta-v3-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

KeyboardInterrupt: 

In [None]:
def tokenize_and_preserve(sentence, text_labels=None):
    if type(sentence) == str:
        sentence = sentence.translate({ord(c): " " for c in string.punctuation}).split()
    if text_labels is None:
        text_labels = itertools.count()
    tokenized_sentence = []
    labels = []

    for word, label in zip(sentence, text_labels):
        # Tokenize the word and count # of subwords the word is broken into
        tokenized_word = tokenizer.tokenize(word)
        n_subwords = len(tokenized_word)

        # Add the tokenized word to the final tokenized word list
        tokenized_sentence.extend(tokenized_word)

        # Add the same label to the new list of labels `n_subwords` times
        labels.extend([label] * n_subwords)
    cnt = itertools.count()
    return [
        (k, [(next(cnt), t, tokenizer.convert_tokens_to_ids(t)) for i, t in g])
        for k, g in itertools.groupby(zip(labels, tokenized_sentence), lambda x: x[0])
    ]

In [None]:
tokenize_and_preserve("my name is bert")

[(0, [(0, 'my', 4783)]),
 (1, [(1, 'name', 13650)]),
 (2, [(2, 'is', 354)]),
 (3, [(3, 'bert', 6747)])]

In [None]:
tokenizer.convert_tokens_to_ids("rt")

9713

In [None]:
txt = "my name is bert"

In [None]:
def mask_expansion(txt, k=10):
    ret = collections.defaultdict(list)
    X = tokenizer.encode(txt, return_tensors="pt")
    words = tokenize_and_preserve(txt)
    for wi, lst in words:
        X_m = X.clone()
        for mask_token_index, token, _ in lst:
            ti = mask_token_index
            if tokenizer.bos_token:
                ti += 1
            X_m[0, ti] = tokenizer.mask_token_id
        logits = model(X_m).logits
        for mask_token_index, token, _ in lst:
            mask_token_logits = logits[0, mask_token_index, :]
            max_ids = np.argsort(mask_token_logits.to("cpu").detach().numpy())[::-1][:k]
            max_tokens = tokenizer.convert_ids_to_tokens(max_ids)
            ret[wi].extend(max_tokens)
    ret = dict(ret)
    if tokenizer.bos_token:
        del ret[0]
    ret = list(ret.values())
    return ret

In [None]:
me = mask_expansion("my name is bert")
me

[['my', 'My', 'the', 'Ġmy', 'm', 'by', 's', 'MY', 'your', 'y'],
 ['Ġname',
  'ĠName',
  'Ġnickname',
  'name',
  'Ġtitle',
  'Ġstart',
  'Ġn',
  'Ġand',
  'Ġam',
  'Ġy'],
 ['Ġis', 'Ġwas', 'Ġam', 'Ġa', ':', 'Ġare', 'ĠIs', 'ĠIS', 'is', 'Ġhas']]

In [None]:
def only_alpha(txt):
    return "".join(c for c in txt if c in string.ascii_letters)


def elastic_format(expanded_list):
    ret = []
    for words in expanded_list:
        words = set(only_alpha(w).lower() for w in words)
        t = "("
        t += " OR ".join(words)
        t += ")"
        ret.append(t)
    return " ".join(ret)

In [None]:
def elastic_splade(txt):
    me = mask_expansion(txt)
    ret = elastic_format(me)
    return ret


elastic_splade("My name is John")

'(his OR a OR the OR s OR my OR i OR this OR our) (by OR name OR title OR id OR time OR m OR am OR start OR named) ( OR is OR was OR a OR s OR are OR am)'

# TODO:

1. Take in to account the logit values and use the `^` parameter for weights
1. Deploy a PYPI package

In [1]:
# add path to simple_splade
import sys

sys.path.append("../simple_splade")

from elastic_splade import splade

In [2]:
# model_name = "stevhliu/my_awesome_eli5_mlm_model"
model_name = "bert-base-uncased"

In [3]:
spalde_model = splade(model_name, model_name)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
test_texts = [
    "My name is John",
    "The quick brown fox jumps over the lazy dog",
    "I like to eat apples",
]

In [5]:
for test_text in test_texts:
    print(test_text)
    print(spalde_model.splade_it(test_text))

My name is John
(my^0.25 OR his^0.13 OR her^0.11 OR the^0.09 OR your^0.09 OR their^0.08 OR its^0.07 OR our^0.07 OR last^0.06 OR another^0.06) (name^0.3 OR father^0.09 OR husband^0.08 OR dad^0.08 OR brother^0.08 OR surname^0.08 OR nickname^0.07 OR title^0.07 OR boyfriend^0.07 OR son^0.07) (is^0.33 OR ^0.27 OR was^0.13 OR means^0.08 OR says^0.06 OR are^0.06 OR goes^0.06)
The quick brown fox jumps over the lazy dog
(the^0.29 OR a^0.15 OR one^0.09 OR some^0.08 OR little^0.07 OR this^0.07 OR his^0.06 OR another^0.06 OR no^0.06 OR my^0.06) (lazy^0.21 OR little^0.12 OR fat^0.09 OR young^0.09 OR big^0.09 OR great^0.08 OR hungry^0.08 OR small^0.08 OR large^0.08 OR old^0.08) (thinking^0.11 OR little^0.1 OR old^0.1 OR ^0.1 OR talking^0.1 OR y^0.1 OR ie^0.1 OR ing^0.1 OR en^0.09 OR ening^0.09) (dog^0.2 OR cat^0.1 OR ie^0.1 OR bear^0.09 OR man^0.09 OR one^0.09 OR boy^0.09 OR girl^0.08 OR guy^0.08 OR wolf^0.08) (took^0.12 OR takes^0.11 OR watched^0.11 OR watches^0.1 OR loomed^0.1 OR looked^0.1 OR ra