In [None]:
!pip install transformers elasticsearch 

import numpy as np 
from transformers import AutoTokenizer, AutoModel 
from elasticsearch import Elasticsearch 
import torch 

# Connection settings for the Elasticsearch cluster (placeholder values).
ES_HOSTS = ['https://host:port']
ES_CREDENTIALS = ('username', 'password')

# Client object used by every index and search call in this script.
# NOTE(review): verify_certs=False presumably disables TLS certificate
# verification -- fine for a local/self-signed test cluster, confirm before
# pointing this at anything shared.
es = Elasticsearch(ES_HOSTS, http_auth=ES_CREDENTIALS, verify_certs=False)
 

# Mapping for the index: one dense-vector field that stores each joke's
# embedding. No other field is declared in this mapping.
mapping = {
    'properties': {
        'embedding': {
            'type': 'dense_vector',
            'dims': 768,  # must equal the length of the vectors stored in this field
            'index': True,  # index the vectors so the field can be used in kNN search
            'similarity': 'cosine',
        }
    }
}

# Create the index only when it is missing, so re-running this cell does not
# fail with resource_already_exists_exception.
if not es.indices.exists(index='jokes-index'):
    es.indices.create(index='jokes-index', body={'mappings': mapping})

# Joke corpus as (text, category) pairs; expanded below into one dict per joke.
# NOTE: a few jokes appear more than once (sometimes under a different
# category); the repeats are kept exactly as they are.
_JOKE_ROWS = (
    ('Why do cats make terrible storytellers? Because they only have one tail.', 'cat'),
    ('What did the cat say when he lost all his money? I am paw.', 'cat'),
    ("Why don't cats play poker in the jungle? Too many cheetahs.", 'cat'),
    ('Why did the tomato turn red? Because it saw the salad dressing!', 'vegetable'),
    ('Why did the scarecrow win an award? Because he was outstanding in his field.', 'farm'),
    ('Why did the hipster burn his tongue? Because he drank his coffee before it was cool.', 'hipster'),
    ('Why did the tomato turn red? Because it saw the salad dressing!', 'food'),
    ('Why did the scarecrow win an award? Because he was out-standing in his field!', 'puns'),
    ('What do you call a fake noodle? An impasta!', 'food'),
    ('What do you call a belt made out of watches? A waist of time!', 'puns'),
    ('Why did the math book look sad? Because it had too many problems!', 'math'),
    ("Why did the gym close down? It just didn't work out!", 'exercise'),
    ("Why don't scientists trust atoms? Because they make up everything!", 'science'),
    ('What do you call a fake noodle? An impasta!', 'food'),
    ('Why did the chicken cross the playground? To get to the other slide!', 'kids'),
    ('Why did the frog call his insurance company? He had a jump in his car!', 'puns'),
)

# One document per joke, keyed by 'text' and 'category' (in that order).
jokes = [{'text': text, 'category': category} for text, category in _JOKE_ROWS]

# Load the BERT tokenizer and model once; both are reused for the jokes and
# for the query below.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('bert-base-uncased')


def embed_text(text):
    """Return a 1-D NumPy embedding for a single string *text*.

    The text is tokenized (with truncation enabled), passed through the model
    with gradient tracking disabled, and the last hidden state is averaged
    over dimension 1 before the leading batch dimension is squeezed away.
    """
    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        pooled = model(**inputs).last_hidden_state.mean(dim=1)
        return pooled.squeeze(0).numpy()


# Attach an embedding to every joke, stored as a plain list so the document
# can be serialized when it is sent to Elasticsearch.
for joke in jokes:
    joke['embedding'] = embed_text(joke['text']).tolist()

# Index the jokes in Elasticsearch
for joke in jokes:
    es.index(index='jokes-index', body=joke)

# Refresh so the documents just indexed are visible to the search that
# follows; without it an immediate search can miss them.
es.indices.refresh(index='jokes-index')

# Embed the query text with the same model used for the jokes.
# query_vector stays a NumPy array; the search request converts it to a list.
query = "What do you get when you cross a snowman and a shark?"
query_vector = embed_text(query)

# kNN clause: search the 'embedding' field with the query vector, asking for
# k=3 results with num_candidates=100.
knn_clause = {
    "field": "embedding",
    "query_vector": query_vector.tolist(),
    "k": 3,
    "num_candidates": 100,
}

# Full request body: the kNN clause plus a request for the 'text' field.
search = {"knn": knn_clause, "fields": ["text"]}

# Run the search and print the joke text of every returned hit.
response = es.search(index='jokes-index', body=search)
for result in response['hits']['hits']:
    joke_text = result['_source']['text']
    print(f"Joke: {joke_text}")
