In [3]:
# Import required libraries
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
import pandas as pd
import math

# Hugging Face id of the pre-trained sentence-embedding model; kept as a named
# constant so the checkpoint can be swapped in one place.
MODEL_NAME = "sentence-transformers/sentence-t5-base"
# Shared model instance used by embed_strings to encode text.
MODEL = SentenceTransformer(MODEL_NAME)

# Function to compute cosine similarity between two vectors
def cosine_similarity(a, b):
    unity = torch.Tensor([1.0]).to(b.device).type(a.type())
    if not torch.isclose(a[0].norm(), unity):  # Normalize a if not already done
        a = (a.T / a.norm(dim=1)).T
    if not torch.isclose(b[0].norm(), unity):  # Normalize b if not already done
        b = (b.T / b.norm(dim=1)).T
    return torch.matmul(a, b.T)  # Return cosine similarity as dot product

# Function to compute embeddings for given strings
def embed_strings(strings):
    return torch.Tensor(MODEL.encode(strings, normalize_embeddings=True))

# Rank keys by a 'surprise score': how unusually similar each key is to the query
# relative to that key's typical similarity to a background ensemble.
def rank_documents(keys, query, ensemble, chunk_size=1024):
    """Score each string in ``keys`` by how surprisingly similar it is to ``query``.

    For every key k, the baseline is the distribution of cos(k, e) over all
    strings e in ``ensemble``. The key's "surprise dev" is the z-score
    ``(cos(k, query) - median) / std`` of that distribution. Its "surprise" is
    the standard normal CDF of the z-score.

    Args:
        keys: sequence of strings to score.
        query: the query string.
        ensemble: sequence of background strings. It need not be aligned with
            ``keys`` (the previous version silently required
            ``ensemble == keys``). Passing the very same object as ``keys``
            reuses the key embeddings instead of encoding them twice.
        chunk_size: number of keys whose similarity rows against the ensemble
            are materialised at once. Peak memory is about
            ``chunk_size * len(ensemble)`` floats instead of a full
            ``len(ensemble) ** 2`` matrix.

    Returns:
        A list with one dict per key, in input order, with fields "word",
        "surprise", "mean", "std", "cosine" and "surprise dev". NOTE: the
        "mean" field holds the *median* of the key's ensemble similarities
        (name kept for compatibility). An empty ``keys`` returns ``[]``.

    Raises:
        ValueError: if ``chunk_size`` < 1.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")
    if len(keys) == 0:
        return []

    key_embeddings = embed_strings(keys)
    query_embedding = embed_strings([query])
    # Encoding is the expensive step; skip it when the keys double as the background.
    ensemble_embeddings = key_embeddings if ensemble is keys else embed_strings(ensemble)

    # (K,) cosine similarity of every key to the query.
    cos_scores = cosine_similarity(key_embeddings, query_embedding)[:, 0].numpy()

    # Per-key baseline statistics over the ensemble, a chunk of keys at a time.
    median_chunks, std_chunks = [], []
    for start in range(0, len(keys), chunk_size):
        sims = cosine_similarity(key_embeddings[start:start + chunk_size], ensemble_embeddings).numpy()
        median_chunks.append(np.median(sims, axis=1))
        std_chunks.append(np.std(sims, axis=1))
    all_means = np.concatenate(median_chunks)  # medians, despite the historical name
    all_stds = np.concatenate(std_chunks)

    # z-score of each key's query similarity against its own baseline.
    surprise_devs = (cos_scores - all_means) / all_stds

    # Standard normal CDF: Phi(x) = (1 + erf(x / sqrt(2))) / 2.
    surprise_scores = np.array([(1 + math.erf(dev / math.sqrt(2))) / 2 for dev in surprise_devs])

    return [
        {
            "word": word,
            "surprise": surprise,
            "mean": median,
            "std": std,
            "cosine": cosine,
            "surprise dev": dev,
        }
        for word, surprise, median, std, cosine, dev in zip(
            keys, surprise_scores, all_means, all_stds, cos_scores, surprise_devs
        )
    ]


In [4]:
# Import the English word list from the english_words library
from english_words import english_words_alpha_set

# Query word whose unusually close neighbours we want to find
query = "dog"

# Candidate words: the whole dictionary. Sorted (instead of list(set)) so the
# order, and therefore tie-breaking below, does not depend on Python's
# per-process string hash seed.
keys = sorted(english_words_alpha_set)

# Score every word against the query, using the dictionary itself as the background ensemble
ranking = rank_documents(keys=keys, query=query, ensemble=keys)

# Collect the results in a DataFrame for inspection
df = pd.DataFrame(ranking)[["word", "cosine", "surprise", "surprise dev", "mean", "std"]]

# "surprise" is the normal CDF of "surprise dev", a monotonic map, so both give
# the same ranking. However, the CDF rounds to exactly 1.0 in float64 for large
# deviations, so sort on the z-score to keep those words ordered. A stable sort
# leaves any remaining ties in (alphabetical) input order.
df.sort_values("surprise dev", ascending=False, kind="mergesort").reset_index(drop=True).head(20)

Unnamed: 0,word,cosine,surprise,surprise dev,mean,std
0,dog,1.0,1.0,9.353889,0.783604,0.023134
1,canine,0.952441,1.0,6.934583,0.79362,0.022903
2,pup,0.935472,1.0,6.278327,0.789649,0.023227
3,Doge,0.898406,1.0,6.251827,0.775615,0.019641
4,pooch,0.928522,1.0,6.04779,0.792307,0.022523
5,pug,0.907003,1.0,5.965437,0.770632,0.02286
6,doberman,0.879496,1.0,5.846106,0.751959,0.021816
7,Canis,0.885421,1.0,5.8245,0.76228,0.021142
8,hound,0.921632,1.0,5.755539,0.795787,0.021865
9,dachshund,0.886862,1.0,5.744237,0.76346,0.021483


In [5]:
# Top 20 words by raw cosine similarity to the query, for comparison with the surprise ranking.
df.sort_values(by="cosine", ascending=False, ignore_index=True).head(20)

Unnamed: 0,word,cosine,surprise,surprise dev,mean,std
0,dog,1.0,1.0,9.353889,0.783604,0.023134
1,canine,0.952441,1.0,6.934583,0.79362,0.022903
2,pup,0.935472,1.0,6.278327,0.789649,0.023227
3,animal,0.929151,1.0,5.342314,0.806774,0.022907
4,pooch,0.928522,1.0,6.04779,0.792307,0.022523
5,cat,0.922278,1.0,5.402059,0.800487,0.022545
6,hound,0.921632,1.0,5.755539,0.795787,0.021865
7,pet,0.914594,0.999999,4.890484,0.803803,0.022655
8,bark,0.91458,1.0,5.425861,0.784009,0.024065
9,pug,0.907003,1.0,5.965437,0.770632,0.02286
