In [12]:
import itertools, collections, json, string, re
from transformers import AutoTokenizer, AutoModelForMaskedLM

In [120]:
model_name = "stevhliu/my_awesome_eli5_mlm_model"
# model_name = "microsoft/deberta-v3-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

Downloading (…)okenizer_config.json:   0%|          | 0.00/386 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/279 [00:00<?, ?B/s]

In [121]:
def tokenize_and_preserve(sentence, text_labels=None):
    if type(sentence)==str:
        sentence = sentence.translate({ord(c):" " for c in string.punctuation}).split()
    if text_labels is None:
        text_labels = itertools.count()
    tokenized_sentence = []
    labels = []

    for word, label in zip(sentence, text_labels):

        # Tokenize the word and count # of subwords the word is broken into
        tokenized_word = tokenizer.tokenize(word)
        n_subwords = len(tokenized_word)

        # Add the tokenized word to the final tokenized word list
        tokenized_sentence.extend(tokenized_word)

        # Add the same label to the new list of labels `n_subwords` times
        labels.extend([label] * n_subwords)
    cnt = itertools.count()
    return [(k,[(next(cnt),t,tokenizer.convert_tokens_to_ids(t)) for i,t in g]) for k,g in itertools.groupby(zip(labels,tokenized_sentence),lambda x:x[0])]

In [122]:
tokenize_and_preserve("my name is bert")

[(0, [(0, 'my', 4783)]),
 (1, [(1, 'name', 13650)]),
 (2, [(2, 'is', 354)]),
 (3, [(3, 'bert', 6747)])]

In [123]:
tokenizer.convert_tokens_to_ids('rt')

9713

In [124]:
txt = "my name is bert"


In [152]:
def mask_expansion(txt,k=10):
    ret = collections.defaultdict(list)
    X = tokenizer.encode(txt, return_tensors='pt')
    words = tokenize_and_preserve(txt)
    for wi,lst in words:
        X_m=X.clone()
        for mask_token_index,token,_ in lst:
            ti = mask_token_index
            if tokenizer.bos_token:
                ti+=1
            X_m[0,ti]=tokenizer.mask_token_id
        logits = model(X_m).logits
        for mask_token_index,token,_ in lst:
            mask_token_logits = logits[0, mask_token_index, :]
            max_ids = np.argsort(mask_token_logits.to("cpu").detach().numpy())[::-1][:k]
            max_tokens = tokenizer.convert_ids_to_tokens(max_ids)
            ret[wi].extend(max_tokens)
    ret = dict(ret)
    if tokenizer.bos_token:
        del ret[0]
    ret = list(ret.values())
    return ret
    

In [153]:
me = mask_expansion("my name is bert")
me

[['my', 'My', 'the', 'Ġmy', 'm', 'by', 's', 'MY', 'your', 'y'],
 ['Ġname',
  'ĠName',
  'Ġnickname',
  'name',
  'Ġtitle',
  'Ġstart',
  'Ġn',
  'Ġand',
  'Ġam',
  'Ġy'],
 ['Ġis', 'Ġwas', 'Ġam', 'Ġa', ':', 'Ġare', 'ĠIs', 'ĠIS', 'is', 'Ġhas']]

In [167]:
def only_alpha(txt):
    return "".join(c for c in txt if c in string.ascii_letters)

def elastic_format(expanded_list):
    ret = []
    for words in expanded_list:
        words = set(only_alpha(w).lower() for w in words)
        t="("
        t+=" OR ".join(words)
        t+=")"
        ret.append(t)
    return " ".join(ret)

In [168]:
def elastic_splade(txt):
    me = mask_expansion(txt)
    ret = elastic_format(me)
    return ret

elastic_splade("My name is John")

'(his OR a OR the OR s OR my OR i OR this OR our) (by OR name OR title OR id OR time OR m OR am OR start OR named) ( OR is OR was OR a OR s OR are OR am)'

# TODO:

1. Take in to account the logit values and use the `^` parameter for weights
1. Deploy a PYPI package

In [1]:
# add path to simple_splade
import sys

sys.path.append("../simple_splade")

from elastic_splade import splade

In [2]:
model_name = "stevhliu/my_awesome_eli5_mlm_model"

In [3]:
spalde_model = splade(model_name, model_name)

In [4]:
test_text = "My name is John"

In [5]:
spalde_model.splade_it(test_text)

'(his OR my OR this OR i OR a OR the OR s OR our) (am OR m OR id OR name OR named OR title OR by OR start OR time) ( OR am OR was OR is OR s OR a OR are)'