```
docker stop es-local-dev kibana-local-dev ; curl -fsSL https://elastic.co/start-local | sh
```

# Data setup

In [1]:
import pandas as pd
# Load the hero catalogue and keep a lookup from hero name to its description
# (a pandas Series indexed by the `name` column).
df = pd.read_csv('superheroes.csv')
hero_dict = df.set_index('name')['description']

# Alternative wording for each hero's description, keyed by hero name.
# These strings are used elsewhere in the notebook as search queries and as
# SPLADE input; they appear to be paraphrases of the CSV descriptions that
# avoid reusing the same words.
hero_dict_alt = {
    "Spider-Man": "An adolescent scholar affected by an irradiated arachnid",
    "Batman": "A wealthy entrepreneur and humanitarian",
    "Wonder Woman": "A mythical female combatant from a secluded isle",
    "Iron Man": "A brilliant innovator and corporate magnate",
    "Superman": "An extraterrestrial being from a distant celestial body",
    "Black Panther": "Monarch of an imaginary technologically advanced realm",
    "The Flash": "A criminal investigator possessing extraordinary velocity",
    "Captain America": "A mid-20th century enhanced combatant",
    "Green Lantern": "An aviator selected by an intergalactic peacekeeping force",
    "Thor": "A Norse deity commanding atmospheric phenomena",
    "Hulk": "An academic transformed by electromagnetic emissions",
    "Wolverine": "A genetic anomaly with rapid recuperation and metallic appendages",
    "Black Widow": "An expertly trained covert operative",
    "Doctor Strange": "A brain surgeon transformed into a mystical guardian",
    "Deadpool": "A hired gun with rapid cellular regeneration",
    "Captain Marvel": "A former military aviator with extraterrestrial abilities",
    "Scarlet Witch": "A genetic anomaly capable of warping existence",
    "Ant-Man": "A reformed burglar capable of altering his dimensions",
    "Daredevil": "A visually impaired attorney with heightened perception",
    "Aquaman": "The semi-terrestrial sovereign of an underwater civilization",
    "Green Arrow": "A wealthy masked bowman",
    "Cyborg": "A former sportsman transformed into a mechanized defender",
    "Hawkeye": "An expert marksman and ex-carnival entertainer",
    "Black Canary": "A combatant with ultrasonic vocal capabilities",
    "Vision": "A synthetic being crafted by an artificial intelligence, energized by a cosmic gem",
    "Martian Manhunter": "A metamorphosing extraterrestrial from a neighboring planet",
    "Storm": "A genetic anomaly capable of atmospheric manipulation",
    "Nightwing": "The inaugural juvenile assistant who became an autonomous guardian",
    "Jean Grey": "A formidable psychic with telekinetic capabilities",
    "Shazam": "A juvenile who metamorphoses into an adult champion",
    "Beast": "A brilliant academic with feral characteristics",
    "Batgirl": "A technologically adept masked information specialist",
    "Gambit": "A genetic anomaly capable of imbuing objects with kinetic potential",
    # NOTE(review): duplicate key. "Green Lantern" already appears earlier in
    # this literal; in a dict literal the later value wins, so the earlier
    # description is silently discarded. Confirm which entry is intended, or
    # how the CSV distinguishes the two heroes, and give them distinct keys.
    "Green Lantern": "A construction professional and ex-military serviceman",
    "Wasp": "A couturier with mass-altering capabilities",
    "Zatanna": "An illusionist with authentic arcane abilities",
    "Cyclops": "The commander of genetic anomalies with ocular energy projection",
    "Supergirl": "A female relative of an extraterrestrial champion",
    "Falcon": "A former aerial rescue specialist with mechanical appendages",
    "Batwoman": "A former armed forces commander turned masked vigilante",
    "Luke Cage": "An individual with impenetrable epidermis and extraordinary vigor",
    "Starfire": "An otherworldly royal with luminous capabilities",
    "Quicksilver": "A genetic anomaly with extraordinary velocity",
    "Raven": "A semi-infernal empath with shadowy capabilities",
    "Moon Knight": "A masked guardian with dissociative identity disorder",
    "Firestorm": "Amalgamated into an atomic-powered champion",
    "She-Hulk": "A legal professional with emerald-hued capabilities",
    "Atom": "An academic capable of reducing to microscopic proportions",
    "Nova": "A member of a cosmic peacekeeping organization",
    "Plastic Man": "A reformed lawbreaker with malleable physiology",
    "Ghost Rider": "A motorcycle daredevil merged with an infernal entity"
}

# SPLADE setup

In [2]:
from transformers import AutoModelForMaskedLM, AutoTokenizer

import torch

# Checkpoint identifier passed to both `from_pretrained` calls below.
model_id = 'naver/splade-cocondenser-ensembledistil'

# Tokenizer and masked-LM model are loaded from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id)

# Reverse vocabulary (token id -> token string), used to turn vocabulary
# positions back into readable tokens.
vocab = tokenizer.get_vocab()
id2token = dict(zip(vocab.values(), vocab.keys()))

  from .autonotebook import tqdm as notebook_tqdm
BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


In [3]:
# The weighting used here (max over positions of log(1 + ReLU(logit))) is
# explained in the SPLADE paper; the core computation is adapted from
# https://www.pinecone.io/learn/splade/#SPLADE-Embeddings
def get_splade_embedding(text, num_tokens=50):
    """Return the highest-weighted SPLADE expansion tokens for ``text``.

    The text is run through the module-level ``tokenizer`` and masked-LM
    ``model``. Each vocabulary entry gets the weight
    ``max_over_positions(log(1 + relu(logit)) * attention_mask)``.

    Args:
        text: the string to expand.
        num_tokens: maximum number of tokens to return.

    Returns:
        A list of token strings (looked up through ``id2token``) ordered by
        descending weight. Only vocabulary entries with a non-zero weight are
        considered, and entries whose id occurs in the tokenized input are
        excluded, so the result contains expansion terms only.
    """
    encoded = tokenizer(text, return_tensors='pt')

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(**encoded).logits

    # Zero out masked positions, then keep the strongest activation per
    # vocabulary entry across all positions.
    mask = encoded['attention_mask'].unsqueeze(-1)
    weights = torch.max(torch.log(1 + torch.relu(logits)) * mask, dim=1)[0].squeeze()
    weights_np = weights.detach().cpu().numpy()

    # A plain Python set makes the "is this an input token?" check O(1);
    # testing membership against the tensor runs a tensor op per candidate.
    input_ids = set(encoded['input_ids'][0].tolist())

    scored = [
        (id2token[idx], weights_np[idx])
        for idx in weights_np.nonzero()[0].tolist()
        if idx not in input_ids
    ]

    # Stable sort, strongest first; ties keep vocabulary-id order.
    scored.sort(key=lambda pair: pair[1], reverse=True)

    return [token for token, _ in scored[:num_tokens]]

def get_tokens_as_text(text):
    """Tokenize ``text`` and return its tokens joined by single spaces.

    Token ids come from the module-level ``tokenizer`` and are mapped back to
    strings through ``id2token``. The first and last tokens are dropped
    (presumably the special start/end markers the tokenizer adds).
    """
    token_ids = tokenizer(text, return_tensors='pt').input_ids[0].tolist()
    pieces = [id2token[token_id] for token_id in token_ids]
    return ' '.join(pieces[1:-1])

    
# Quick sanity check on a sample sentence: print its tokens, then up to 100
# SPLADE expansion tokens for it.
text = "marry had a little lamb, it's fleece was white as snow"
print(get_tokens_as_text(text))
print(get_splade_embedding(text, num_tokens=100))



marry had a little lamb , it ' s flee ##ce was white as snow
['marriage', 'married', 'winter', 'song', 'wedding', 'have', 'sheep', 'whites', 'baby', 'like', 'color', 'wearing', 'film', 'character', 'murder', 'said', 'england', 'gay', 'story', 'horse', 'went', 'gypsy', 'were', 'snowfall', 'chorus', 'clothing', 'dance', 'got', 'the']


In [4]:
# Exploratory check: for each hero, expand the alternative description with
# SPLADE (with and without the hero name appended) and count how many of the
# expansion tokens occur in the original description.
# NOTE(review): both print calls are commented out, so this cell currently
# produces no output; `splade_tokens_w_hero` and `num_included` are computed
# but never used.
for hero, description in hero_dict.items():
    splade_tokens = get_splade_embedding(hero_dict_alt[hero],10)
    splade_tokens_w_hero = get_splade_embedding(hero_dict_alt[hero] + ' ' + hero,10)
    # print(hero, '|', description, '|', hero_dict_alt[hero], '|', splade_tokens, '|', splade_tokens_w_hero, "\n")
    num_included = 0
    for token in splade_tokens:
        # `in` is a substring test here, assuming description is a str
        # (presumably it is; the values come from the CSV description column).
        if token in description:
            num_included += 1
    # print(f'Number of included tokens: {num_included}')


# Elasticsearch setup

In [5]:
import json
from datetime import datetime
import math
import os

from elasticsearch import Elasticsearch
import pandas as pd

# print pwd using python
from pathlib import Path

# pull in environment variables
from dotenv import load_dotenv
# Location of the .env file, resolved relative to the notebook's working
# directory (three levels up, then elastic-start-local/elastic-start-local).
path = Path.cwd().parent.parent.parent / 'elastic-start-local' / 'elastic-start-local' / '.env'
load_dotenv(path, override=True)
print(path)
# Never echo the secret itself: cell output is saved with the notebook and
# would leak the password to anyone the notebook is shared with.
print(f"ES_LOCAL_PASSWORD set: {os.getenv('ES_LOCAL_PASSWORD') is not None}")


/Users/johnberryman/Dropbox/Notebooks/elastic-start-local/elastic-start-local/.env
[redacted]


In [6]:
# Client for the Elasticsearch node on localhost:9200, authenticating as the
# `elastic` user with the password read from the ES_LOCAL_PASSWORD variable.
es = Elasticsearch(
    hosts="http://localhost:9200",
    basic_auth=("elastic", os.environ.get("ES_LOCAL_PASSWORD")),
)

# create enum for splade
from enum import Enum
class Splade(Enum):
    """Selects which text, if any, is expanded with SPLADE when a hero is indexed."""
    # No expansion: the document's `splade` field is an empty list.
    NONE = 1
    # Expand the description only.
    WITHOUT_HERO = 2
    # Expand the description followed by the hero name.
    WITH_HERO = 3
    # Expand the description, the hero name and the superpowers text.
    WITH_HERO_AND_SUPERPOWERS = 4


def _splade_source_text(doc, splade):
    """Return the text to expand with SPLADE for ``doc``, or None for no expansion."""
    if splade == Splade.WITHOUT_HERO:
        return doc['description']
    if splade == Splade.WITH_HERO:
        return doc['description'] + ' ' + doc['name']
    if splade == Splade.WITH_HERO_AND_SUPERPOWERS:
        return doc['description'] + ' ' + doc['name'] + ' ' + doc['superpowers']
    return None


def reindex_superheroes(splade=Splade.NONE, num_tokens=50):
    """Drop and rebuild the ``superheroes`` index from ``superheroes.csv``.

    Args:
        splade: ``Splade`` member selecting which text is expanded into the
            ``splade`` field; any other value (including ``Splade.NONE``)
            stores an empty list.
        num_tokens: maximum number of expansion tokens stored per document.

    Side effects:
        Deletes the index if it exists, recreates it with an explicit mapping,
        indexes every CSV row under ids 1..N, refreshes the index so the new
        documents are immediately searchable, and prints progress messages.
    """
    index_name = "superheroes"
    # Only the three mapped fields are indexed for search; the remaining CSV
    # columns are kept in _source but not indexed ("dynamic": "false").
    mappings = {
        "mappings": {
            "dynamic": "false",
            "properties": {
                "name": {"type": "text"},
                "description": {
                    "type": "text",
                    "analyzer": "english",
                },
                "splade": {
                    "type": "text",
                }
            }
        }
    }

    # delete and recreate the index
    if es.indices.exists(index=index_name):
        es.indices.delete(index=index_name)
        print(f"Index '{index_name}' deleted successfully.")
    else:
        print(f"Index '{index_name}' does not exist.")

    es.indices.create(index=index_name, body=mappings)
    print(f"Index '{index_name}' created successfully.")

    df = pd.read_csv('superheroes.csv')
    # Each CSV row already carries its own `name` column, so the row dict is
    # the document. (Prepending the DataFrame index as `name` only added a
    # duplicate label holding the integer row number.)
    for doc_id, doc in enumerate(df.to_dict(orient='records'), start=1):
        source_text = _splade_source_text(doc, splade)
        if source_text is None:
            doc['splade'] = []
        else:
            doc['splade'] = get_splade_embedding(source_text, num_tokens)
        es.index(index=index_name, id=doc_id, body=doc)

    # Make the new documents visible to searches right away, so callers do
    # not have to sleep and wait for the periodic refresh.
    es.indices.refresh(index=index_name)

    print(f"Indexed {len(df)} superheroes.")

# Rebuild the index, storing up to 10 SPLADE tokens per hero expanded from
# the description plus the hero name.
reindex_superheroes(splade=Splade.WITH_HERO, num_tokens=10)

Index 'superheroes' deleted successfully.
Index 'superheroes' created successfully.
Indexed 50 superheroes.


In [16]:

def search_superheroes(description, size, splade):
    """Search the ``superheroes`` index and return the matching documents.

    Args:
        description: free-text query matched against the ``description`` field.
        size: maximum number of hits to request.
        splade: when truthy, the query's tokens (from ``get_tokens_as_text``)
            are also matched against the ``splade`` field, and the two clauses
            are combined in a ``bool``/``should`` query.

    Returns:
        The ``_source`` dict of every hit, in the order returned by
        Elasticsearch.
    """
    def field_clause(text, field):
        # Single-field multi_match clause, shared by both query shapes.
        return {"multi_match": {"query": text, "fields": [field]}}

    description_clause = field_clause(description, "description")
    if not splade:
        request = {"query": description_clause}
    else:
        token_clause = field_clause(get_tokens_as_text(description), "splade")
        request = {"query": {"bool": {"should": [description_clause, token_clause]}}}
    request['size'] = size

    response = es.search(index="superheroes", body=request)
    return [hit['_source'] for hit in response['hits']['hits']]

def retrieve_superhero(name):
    """Return the ``_source`` of the top hit matching ``name``, or None if there are no hits."""
    response = es.search(
        index="superheroes",
        body={"query": {"match": {"name": name}}},
    )
    hits = response['hits']['hits']
    if not hits:
        return None
    return hits[0]['_source']


# Example: request the top 3 hits for a short query, matching both the
# description and splade fields.
search_superheroes("spider boy", 3, True)



[{'name': 'Spider-Man',
  'true_identity': 'Peter Parker',
  'description': ' a high school student bitten by a radioactive spider',
  'comics': 'The Amazing Spider-Man',
  'publisher': 'Marvel Comics',
  'superpowers': 'Web-slinging, superhuman strength, spider-sense',
  'splade': ['bite',
   'spiders',
   'students',
   '##man',
   'radiation',
   'murder',
   'character',
   'film',
   'radio',
   'bomb']},
 {'name': 'Shazam',
  'true_identity': 'Billy Batson',
  'description': ' a young boy who transforms into an adult superhero',
  'comics': 'Whiz Comics',
  'publisher': 'DC Comics',
  'superpowers': 'Superhuman strength, flight, lightning manipulation, wisdom of Solomon',
  'splade': ['transform',
   'transformation',
   'character',
   'boys',
   'actor',
   'hero',
   'teen',
   'become',
   'film',
   'became']}]

In [8]:
def recall_at_k(k, splade):
    """Fraction of heroes found within the top ``k`` hits for their alternative description.

    For every hero name in ``hero_dict``, the matching entry of
    ``hero_dict_alt`` is used as the search query; the hero counts as found
    when its name appears among the returned documents.

    Args:
        k: number of hits requested per query.
        splade: forwarded to ``search_superheroes`` to toggle the extra
            clause on the ``splade`` field.

    Returns:
        A float between 0 and 1, or 0.0 when ``hero_dict`` is empty
        (instead of raising ZeroDivisionError).

    Raises:
        KeyError: if a hero in ``hero_dict`` has no entry in ``hero_dict_alt``.
    """
    heroes = list(hero_dict.keys())
    if not heroes:
        return 0.0

    found = 0
    for hero in heroes:
        results = search_superheroes(hero_dict_alt[hero], k, splade)
        if any(result['name'] == hero for result in results):
            found += 1

    return found / len(heroes)

# Baseline: recall within the top 100 hits using only the description field.
recall_at_k(100, False)

{'query': {'multi_match': {'query': 'An adolescent scholar affected by an irradiated arachnid', 'fields': ['description']}}, 'size': 100}
{'query': {'multi_match': {'query': 'A wealthy entrepreneur and humanitarian', 'fields': ['description']}}, 'size': 100}
{'query': {'multi_match': {'query': 'A mythical female combatant from a secluded isle', 'fields': ['description']}}, 'size': 100}
{'query': {'multi_match': {'query': 'A brilliant innovator and corporate magnate', 'fields': ['description']}}, 'size': 100}
{'query': {'multi_match': {'query': 'An extraterrestrial being from a distant celestial body', 'fields': ['description']}}, 'size': 100}
{'query': {'multi_match': {'query': 'Monarch of an imaginary technologically advanced realm', 'fields': ['description']}}, 'size': 100}
{'query': {'multi_match': {'query': 'A criminal investigator possessing extraordinary velocity', 'fields': ['description']}}, 'size': 100}
{'query': {'multi_match': {'query': 'A mid-20th century enhanced combatant

0.1

# Demo

In [38]:
import time
# Rebuild the index with up to 50 SPLADE tokens per hero, expanded from the
# description alone.
reindex_superheroes(splade=Splade.WITHOUT_HERO, num_tokens=50)
# Explicitly refresh so the freshly indexed documents are searchable, rather
# than sleeping and hoping the periodic refresh has already run.
es.indices.refresh(index="superheroes")
k = 3
# Recall@k without vs. with the additional clause on the splade field.
print(f'{k}: {recall_at_k(k, False)} -> {recall_at_k(k, True)}')

Index 'superheroes' deleted successfully.
Index 'superheroes' created successfully.
Indexed 51 superheroes.
3: 0.28 -> 0.52
