# Evaluate Embedding Model for use in Retrieval

In [1]:
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer
from scipy.stats import cramervonmises_2samp, ks_2samp, skew

model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")

  from tqdm.autonotebook import tqdm, trange


The following information is scraped from the Wikipedia articles for various topics. OpenAI's GPT4o has generated short sentences and queries about the articles.

In [2]:
llama_statements = [
    "The llama (/ˈlɑːmə/; Spanish pronunciation: [ˈʎama] or [ˈʝama]) is a domesticated South American camelid.",
    "Llamas have been used by Andean cultures as meat and pack animals since the pre-Columbian era.",
    "Llamas are social animals that live in herds.",
    "Their wool is soft and contains only a small amount of lanolin.",
    "Llamas can learn simple tasks after a few repetitions.",
    "They can carry about 25 to 30% of their body weight for 8 to 13 km (5–8 miles).",
    "The name 'llama' was adopted by European settlers from native Peruvians.",
    "Llamas' ancestors originated from the Great Plains of North America about 40 million years ago.",
    "They migrated to South America about three million years ago.",
    "This migration occurred during the Great American Interchange.",
    "Camelids became extinct in North America by the end of the last ice age (10,000–12,000 years ago).",
    "As of 2007, there were over seven million llamas and alpacas in South America.",
    "The United States and Canada had over 158,000 llamas and 100,000 alpacas in 2007.",
    "These North American populations descended from llamas imported in the late 20th century.",
    "Llamas are used for carrying loads in Andean cultures.",
    "The word 'llama' was sometimes spelled 'lama' or 'glama' in the past.",
    "Llamas are capable of transporting loads over long distances.",
    "Llamas are known for their soft wool.",
    "European settlers learned about llamas from the indigenous people of Peru.",
    "The llama's migration to South America was a significant event in its evolutionary history.",
    "Llamas belong to the species Lama glama.",
    "They are widely used as pack animals in the Andes.",
    "Llamas can live and work in harsh mountainous environments.",
    "The soft wool of llamas is prized for its quality.",
    "Llamas are capable of adapting to various tasks.",
    "They play an important role in Andean agriculture and transportation.",
    "Llamas communicate using a series of vocalizations.",
    "The llama's scientific name is Lama glama.",
    "They have been domesticated for thousands of years.",
    "Llamas are closely related to alpacas, vicuñas, and guanacos.",
    "They have a lifespan of about 15 to 25 years.",
    "Llamas are known for their gentle and calm demeanor.",
    "Their ancestors once roamed North America.",
    "Llamas have a unique split-upper lip.",
    "They can thrive on a diet of grass and hay.",
    "Llamas are often used in therapy and educational programs.",
    "Their wool comes in a variety of natural colors.",
    "Llamas have been depicted in ancient Andean art and artifacts.",
    "They are an integral part of Andean cultural heritage.",
    "Llamas can form strong bonds with humans and other animals.",
    "Llamas can spit as a defense mechanism.",
    "They have padded feet that are well-suited for mountainous terrain.",
    "Llamas are known for their long necks and large eyes.",
    "They can weigh between 130 and 200 kilograms (290–440 pounds).",
    "Llamas are herbivores, primarily eating grass and hay.",
    "They have a three-compartment stomach for efficient digestion.",
    "Llamas are often used in trekking and hiking tours.",
    "They have a strong sense of curiosity.",
    "Llamas can be trained to pull carts.",
    "Their wool is hypoallergenic.",
    "Llamas have a gestation period of about 11.5 months.",
    "A baby llama is called a cria.",
    "Llamas are important to the rural economy in the Andes.",
    "They can survive with minimal water in arid environments.",
    "Llamas have a strong herding instinct.",
    "They can recognize individual humans and animals.",
    "Llamas' wool is often used for making textiles and clothing.",
    "They are sometimes used as guard animals for livestock.",
    "Llamas have been featured in various films and media.",
    "They play a significant role in traditional Andean festivals and rituals.",
]

alpaca_statements = [
    "The alpaca (Lama pacos) is a South American camelid mammal.",
    "Alpacas traditionally graze on the Andes' high plateaus.",
    "They are found in Southern Peru, Western Bolivia, Ecuador, and Northern Chile.",
    "Today, alpacas are raised globally on farms and ranches.",
    "Thousands of alpacas are born annually in North America, Europe, and Australia.",
    "There are two main breeds of alpaca: the Suri and the Huacaya.",
    "Suri alpacas produce straight 'locks' of fiber.",
    "Huacaya alpacas have crimped, wavy wool.",
    "Both types of alpaca fiber are highly valued.",
    "Alpaca fiber is used for knitted and woven items.",
    "Alpacas are similar in appearance to llamas.",
    "Alpacas are shorter than llamas.",
    "Alpacas are primarily bred for their wool.",
    "Llamas are often used as livestock guardians.",
    "Llamas also serve as pack animals.",
    "All four South American camelids are closely related.",
    "Alpacas and llamas can interbreed.",
    "Alpacas are believed to be domesticated from the vicuña.",
    "Llamas are thought to descend from the guanaco.",
    "Domestication of alpacas and llamas occurred 5,000 to 6,000 years ago.",
]

cattle_statements = [
    "Cattle (Bos taurus) are large, domesticated, bovid ungulates widely kept as livestock.",
    "They are prominent modern members of the subfamily Bovinae.",
    "Cattle are the most widespread species of the genus Bos.",
    "Mature female cattle are called cows.",
    "Mature male cattle are called bulls.",
    "Young female cattle are called heifers.",
    "Young male cattle are called oxen or bullocks.",
    "Castrated male cattle are known as steers.",
    "Cattle are commonly raised for meat.",
    "They are also raised for dairy products.",
    "Cattle are raised for leather as well.",
    "As draft animals, cattle pull carts and farm implements.",
    "In India, cattle are sacred animals within Hinduism.",
    "Cattle may not be killed in Hinduism.",
    "Small breeds like the miniature Zebu are kept as pets.",
    "Taurine cattle are widely distributed across Europe and temperate areas of Asia.",
    "Zebus are found mainly in India and tropical areas of Asia, America, and Australia.",
    "Sanga cattle are found primarily in sub-Saharan Africa.",
    "These types of cattle are sometimes classified as separate species or subspecies.",
    "There are over 1,000 recognized breeds of cattle.",
]

plastic_statements = [
    "A TPU resin consists of linear polymeric chains in block-structures.",
    "These chains contain low polarity segments which are rather long, called soft segments.",
    "Alternating with the soft segments are shorter, high polarity segments known as hard segments.",
    "Both types of segments in TPU are linked by covalent links.",
    "The segments form block-copolymers in TPU.",
    "The miscibility of the hard and soft segments depends on their glass transition temperature (Tg).",
    "Tg occurs at the onset of micro-Brownian segmental motion.",
    "It can be identified by dynamic mechanical spectra.",
    "For an immiscible TPU, the loss modulus spectrum typically shows double peaks.",
    "Each peak is assigned to the Tg of one component.",
    "If the two components are miscible, TPU shows a single broad peak.",
    "This peak's position lies between that of the two original Tg peaks.",
    "The polarity of the hard segments creates a strong attraction between them.",
    "This attraction causes a high degree of aggregation and order in this phase.",
    "Crystalline or pseudo-crystalline areas form in a soft and flexible matrix.",
    "Phase separation between blocks can vary based on polarity and molecular weight.",
    "The crystalline areas act as physical cross-links, giving TPU its high elasticity.",
    "The flexible chains provide elongation characteristics to the polymer.",
    "Pseudo cross-links disappear under heat, allowing for classical processing methods.",
    "TPU scrap can be reprocessed due to the thermal properties of the material.",
]

In [3]:
def compute_metric(
    model: SentenceTransformer, 
    query: list[str], 
    query_label: str,
    documents: list[list[str]], 
    document_labels: list[str],
):

    query_embedding = model.encode(query)
    document_embeddings = [
        model.encode(text)
        for text in documents
    ]

    self_distance = (model.similarity(query_embedding, query_embedding) / 100.0).flatten()
    cross_distances = [
        (model.similarity(query_embedding, document_embedding) / 100.0).flatten()
        for document_embedding in document_embeddings
    ]

    cvm_metrics = [
        cramervonmises_2samp(self_distance, dist)
        for dist in cross_distances
    ]
    ks_metrics = [
        ks_2samp(self_distance, dist)
        for dist in cross_distances
    ]

    stats = {
        "std": torch.std(self_distance),
        "mean": torch.mean(self_distance),
        "skew": skew(self_distance)
    }
    metrics = {
        query_label: {
            label: {
                "CvM": {
                    "statistic": cvm_metrics[i].statistic,
                    "pvalue": cvm_metrics[i].pvalue,
                },
                "KS": {
                    "statistic": ks_metrics[i].statistic,
                    "pvalue": ks_metrics[i].pvalue,
                },
            }
            for i, label in enumerate(document_labels)
        }
    }

    return (metrics, stats)

def compute_document_metrics(
    model: SentenceTransformer,
    documents: list[list[str]],
    labels: list[str],
):
    cvm_stats = []
    cvm_pvalues = []
    ks_stats = []
    ks_pvalues = []
    distribution_stats = {}
    for document, label in zip(documents, labels):
        metrics, stats = compute_metric(
            model=model,
            query=document,
            query_label=label,
            documents=documents,
            document_labels=labels,
        )
        cvm_stats.append(
            [
                metrics[label][other]["CvM"]["statistic"]
                for other in labels
            ]
        )
        cvm_pvalues.append(
            [
                metrics[label][other]["CvM"]["pvalue"]
                for other in labels
            ]
        )
        ks_stats.append(
            [
                metrics[label][other]["KS"]["statistic"]
                for other in labels
            ]
        )
        ks_pvalues.append(
            [
                metrics[label][other]["KS"]["pvalue"]
                for other in labels
            ]
        )
        distribution_stats[label] = stats

        
    return {
        "cvm": {
            "statistic": cvm_stats,
            "pvalue": cvm_pvalues,
        },
        "ks": {
            "statistic": ks_stats,
            "pvalue": ks_pvalues,
        },
        "distribution": distribution_stats,
    }


In [4]:
documents = [llama_statements, alpaca_statements, cattle_statements, plastic_statements]
labels = ["llama", "alpaca", "cattle", "plastic"]

metrics = compute_document_metrics(model, documents, labels)


Cramer-von Mises

In [5]:
metric = metrics["cvm"]["statistic"]

col_ix = pd.MultiIndex.from_product([['Document'], ['Llama', 'Alpaca', 'Cattle', 'Plastic']]) 
row_ix = pd.MultiIndex.from_product([['Query'], ['Llama', 'Alpaca', 'Cattle', 'Plastic']])
query_metrics_df = pd.DataFrame(metric)
query_metrics_df = query_metrics_df.set_index(row_ix)
query_metrics_df.columns = col_ix
print(query_metrics_df)
print()

metric = metrics["cvm"]["pvalue"]

col_ix = pd.MultiIndex.from_product([['Document'], ['Llama', 'Alpaca', 'Cattle', 'Plastic']]) 
row_ix = pd.MultiIndex.from_product([['Query'], ['Llama', 'Alpaca', 'Cattle', 'Plastic']])
query_metrics_df = pd.DataFrame(metric)
query_metrics_df = query_metrics_df.set_index(row_ix)
query_metrics_df.columns = col_ix
print(query_metrics_df)


                Document                                  
                   Llama     Alpaca     Cattle     Plastic
Query Llama     0.000000   4.004641  47.313574  226.543824
      Alpaca   15.236223   0.000000  38.555197   62.627884
      Cattle   42.814904  27.632434   0.000000   63.055916
      Plastic  93.141139  58.507616  63.662609    0.000000

                   Document                                          
                      Llama        Alpaca        Cattle       Plastic
Query Llama    1.000000e+00  7.021759e-10  2.320660e-08  5.550082e-08
      Alpaca   2.543085e-09  1.000000e+00  6.047391e-09  1.505867e-08
      Cattle   5.144078e-09  1.595944e-09  1.000000e+00  1.665279e-08
      Plastic  2.929343e-08  2.340761e-08  1.916224e-08  1.000000e+00


Kolmgorov-Smirnov

In [6]:
metric = metrics["ks"]["statistic"]

col_ix = pd.MultiIndex.from_product([['Document'], ['Llama', 'Alpaca', 'Cattle', 'Plastic']]) 
row_ix = pd.MultiIndex.from_product([['Query'], ['Llama', 'Alpaca', 'Cattle', 'Plastic']])
query_metrics_df = pd.DataFrame(metric)
query_metrics_df = query_metrics_df.set_index(row_ix)
query_metrics_df.columns = col_ix
print(query_metrics_df)
print()

metric = metrics["ks"]["pvalue"]

col_ix = pd.MultiIndex.from_product([['Document'], ['Llama', 'Alpaca', 'Cattle', 'Plastic']]) 
row_ix = pd.MultiIndex.from_product([['Query'], ['Llama', 'Alpaca', 'Cattle', 'Plastic']])
query_metrics_df = pd.DataFrame(metric)
query_metrics_df = query_metrics_df.set_index(row_ix)
query_metrics_df.columns = col_ix
print(query_metrics_df)

               Document                              
                  Llama    Alpaca    Cattle   Plastic
Query Llama    0.000000  0.127778  0.408056  0.753611
      Alpaca   0.333333  0.000000  0.650000  0.900000
      Cattle   0.579167  0.562500  0.000000  0.907500
      Plastic  0.882500  0.845000  0.912500  0.000000

                    Document                                             
                       Llama         Alpaca         Cattle        Plastic
Query Llama     1.000000e+00   3.052217e-13  7.706754e-135  1.057300e-321
      Alpaca    5.161628e-30   1.000000e+00   5.142712e-80  6.429024e-172
      Cattle    4.326166e-94   1.188149e-58   1.000000e+00  8.613687e-176
      Plastic  3.434451e-251  2.927706e-146  1.963082e-178   1.000000e+00


General Distribution

In [7]:
metric = metrics["distribution"]
pd.DataFrame(metric)


Unnamed: 0,llama,alpaca,cattle,plastic
std,tensor(0.0784),tensor(0.0664),tensor(0.0695),tensor(0.0662)
mean,tensor(0.1893),tensor(0.2249),tensor(0.2070),tensor(0.1846)
skew,0.436476,0.908672,1.158553,1.827339


# Query Distribution

In [8]:
llama_questions = [
    "What is the scientific name for a llama?",
    "Where are llamas originally from?",
    "How long is the typical lifespan of a llama?",
    "What do llamas primarily eat?",
    "How much does a fully grown llama typically weigh?",
    "What is the difference between a llama and an alpaca?",
    "How many stomach compartments do llamas have?",
    "What is the average height of a llama at the shoulder?",
    "How do llamas communicate with each other?",
    "What are the main uses of llamas by humans?",
    "How do llamas react when they feel threatened?",
    "What type of climate is best suited for llamas?",
    "What is the gestation period for a llama?",
    "How often do llamas need to be sheared?",
    "What is a group of llamas called?",
    "How can you tell the age of a llama?",
    "What are some common health issues that llamas face?",
    "How do llamas contribute to their ecosystems?",
    "Can llamas be used as guard animals?",
    "What are the main differences in behavior between wild and domesticated llamas?",
]

alpaca_questions = [
    "What is the average lifespan of an alpaca?",
    "What are the two types of alpacas, and how do they differ?",
    "What is the primary use of alpacas in farming?",
    "Where do alpacas originate from?",
    "What is the typical diet of an alpaca?",
    "How often should alpacas be sheared?",
    "What are some common health issues that alpacas face?",
    "How do alpacas communicate with each other?",
    "What is the gestation period for an alpaca?",
    "How can you distinguish between a male and a female alpaca?",
    "What are the different colors of alpaca fleece?",
    "How does the quality of alpaca fleece compare to sheep's wool?",
    "What are some products made from alpaca fleece?",
    "How much fleece does an alpaca produce annually?",
    "What is the average weight of an adult alpaca?",
    "How do alpacas behave in social settings with other alpacas?",
    "What are some signs of a healthy alpaca?",
    "How can you tell if an alpaca is stressed or unhappy?",
    "What kind of shelter do alpacas need?",
    "How do alpacas contribute to the ecosystem?",
]

cattle_questions = [
    "What are the main differences between beef cattle and dairy cattle?",
    "How long is the gestation period for a cow?",
    "What is the average lifespan of a domesticated cow?",
    "What are the primary breeds of beef cattle?",
    "How much milk does a dairy cow produce on average per day?",
    "What is the purpose of dehorning cattle?",
    "How do farmers typically identify and track their cattle?",
    "What are the common health issues faced by cattle?",
    "How is IVF used in cattle breeding?",
    "What are the nutritional requirements for a growing calf?",
    "What is the significance of the rumen in a cow's digestive system?",
    "How does climate change impact cattle farming?",
    "What are the benefits of rotational grazing for cattle?",
    "How do cattle contribute to greenhouse gas emissions?",
    "What are the key indicators of a cow's health and well-being?",
    "How has the domestication of cattle influenced human agriculture?",
    "What are the ethical considerations in cattle farming?",
    "How do cattle farmers ensure the welfare of their animals during transport?",
    "What are the common methods used for weaning calves from their mothers?",
    "How is biotechnology used to improve cattle breeds and productivity?",
]

plastic_questions = [
    "What does TPU stand for in the context of plastics?",
    "What are the primary uses of TPU plastic in manufacturing?",
    "How does the flexibility of TPU plastic compare to other plastics?",
    "What are the main advantages of using TPU plastic over traditional rubber?",
    "Can TPU plastic be recycled, and if so, how?",
    "What are the key chemical properties of TPU plastic?",
    "How is TPU plastic produced?",
    "What industries commonly use TPU plastic for their products?",
    "How does TPU plastic perform in high-temperature environments?",
    "What are the environmental impacts of producing TPU plastic?",
    "Is TPU plastic considered biodegradable?",
    "How does the hardness of TPU plastic vary with different formulations?",
    "What are some common consumer products made from TPU plastic?",
    "How does the transparency of TPU plastic compare to other thermoplastics?",
    "What are the tensile strength and elongation properties of TPU plastic?",
    "How does TPU plastic react to exposure to UV radiation?",
    "What are the differences between TPU plastic and TPE plastic?",
    "How does TPU plastic handle exposure to chemicals and oils?",
    "What are some limitations or disadvantages of TPU plastic?",
    "How does the cost of TPU plastic compare to other types of plastics?",
]

In [9]:
def compute_query_metrics(
    model: SentenceTransformer,
    queries: list[list[str]],
    documents: list[list[str]],
    labels: list[str],
):
    cvm_stats = []
    cvm_pvalues = []
    ks_stats = []
    ks_pvalues = []
    distribution_stats = {}
    for query, label in zip(queries, labels):
        metrics, stats = compute_metric(
            model=model,
            query=query,
            query_label=label,
            documents=documents,
            document_labels=labels,
        )
        cvm_stats.append(
            [
                metrics[label][other]["CvM"]["statistic"]
                for other in labels
            ]
        )
        cvm_pvalues.append(
            [
                metrics[label][other]["CvM"]["pvalue"]
                for other in labels
            ]
        )
        ks_stats.append(
            [
                metrics[label][other]["KS"]["statistic"]
                for other in labels
            ]
        )
        ks_pvalues.append(
            [
                metrics[label][other]["KS"]["pvalue"]
                for other in labels
            ]
        )
        distribution_stats[label] = stats

        
    return {
        "cvm": {
            "statistic": cvm_stats,
            "pvalue": cvm_pvalues,
        },
        "ks": {
            "statistic": ks_stats,
            "pvalue": ks_pvalues,
        },
        "distribution": distribution_stats,
    }

In [10]:
queries = [llama_questions, alpaca_questions, cattle_questions, plastic_questions]
documents = [llama_statements, alpaca_statements, cattle_statements, plastic_statements]
labels = ["llama", "alpaca", "cattle", "plastic"]

query_metrics = compute_query_metrics(model, queries, documents, labels)

Cramer-von Mises

In [11]:
metric = query_metrics["cvm"]["statistic"]

col_ix = pd.MultiIndex.from_product([['Document'], ['Llama', 'Alpaca', 'Cattle', 'Plastic']]) 
row_ix = pd.MultiIndex.from_product([['Query'], ['Llama', 'Alpaca', 'Cattle', 'Plastic']])
query_metrics_df = pd.DataFrame(metric)
query_metrics_df = query_metrics_df.set_index(row_ix)
query_metrics_df.columns = col_ix
print(query_metrics_df)
print()

metric = query_metrics["cvm"]["pvalue"]

col_ix = pd.MultiIndex.from_product([['Document'], ['Llama', 'Alpaca', 'Cattle', 'Plastic']]) 
row_ix = pd.MultiIndex.from_product([['Query'], ['Llama', 'Alpaca', 'Cattle', 'Plastic']])
query_metrics_df = pd.DataFrame(metric)
query_metrics_df = query_metrics_df.set_index(row_ix)
query_metrics_df.columns = col_ix
print(query_metrics_df)

                 Document                                 
                    Llama     Alpaca     Cattle    Plastic
Query Llama     41.192646  40.566891  66.617784  66.667172
      Alpaca    88.068517  24.654297  65.140622  66.667172
      Cattle    91.245890  61.086872  20.368016  66.662192
      Plastic  100.000154  66.667172  66.667172  62.580972

                   Document                                          
                      Llama        Alpaca        Cattle       Plastic
Query Llama    1.658194e-08  1.320490e-08  3.663468e-08  9.426975e-09
      Alpaca   4.225650e-08  2.418346e-09  2.668820e-08  9.426975e-09
      Cattle   2.182857e-08  1.036360e-08  1.180766e-09  9.416236e-09
      Plastic  2.618679e-08  9.426975e-09  9.426975e-09  1.489229e-08


Kolmgorov-Smirnov

In [12]:
metric = query_metrics["ks"]["statistic"]

col_ix = pd.MultiIndex.from_product([['Document'], ['Llama', 'Alpaca', 'Cattle', 'Plastic']]) 
row_ix = pd.MultiIndex.from_product([['Query'], ['Llama', 'Alpaca', 'Cattle', 'Plastic']])
query_metrics_df = pd.DataFrame(metric)
query_metrics_df = query_metrics_df.set_index(row_ix)
query_metrics_df.columns = col_ix
print(query_metrics_df)
print()

metric = query_metrics["ks"]["pvalue"]

col_ix = pd.MultiIndex.from_product([['Document'], ['Llama', 'Alpaca', 'Cattle', 'Plastic']]) 
row_ix = pd.MultiIndex.from_product([['Query'], ['Llama', 'Alpaca', 'Cattle', 'Plastic']])
query_metrics_df = pd.DataFrame(metric)
query_metrics_df = query_metrics_df.set_index(row_ix)
query_metrics_df.columns = col_ix
print(query_metrics_df)

               Document                        
                  Llama  Alpaca  Cattle Plastic
Query Llama    0.594167  0.7300  0.9925  1.0000
      Alpaca   0.859167  0.5225  0.9575  1.0000
      Cattle   0.873333  0.8875  0.4425  0.9975
      Plastic  1.000000  1.0000  1.0000  0.9150

                    Document                                             
                       Llama         Alpaca         Cattle        Plastic
Query Llama     2.032046e-99  1.415561e-103  9.041959e-232  1.063590e-239
      Alpaca   7.839997e-234   3.227642e-50  5.674041e-205  1.063590e-239
      Cattle   3.387513e-244  1.097288e-165   1.336195e-35  8.508717e-237
      Plastic  4.940656e-324  1.063590e-239  1.063590e-239  8.969697e-180


General Distribution

In [13]:
metric = query_metrics["distribution"]
pd.DataFrame(metric)

Unnamed: 0,llama,alpaca,cattle,plastic
std,tensor(0.0194),tensor(0.0253),tensor(0.0322),tensor(0.0318)
mean,tensor(0.2244),tensor(0.2161),tensor(0.1988),tensor(0.2629)
skew,2.002509,1.464772,1.600186,1.529174
