This notebook calculates the average cosine similarity between the sentence embeddings of three datasets: the original TruthfulQA misconceptions dataset, the crafted dataset, and the generated dataset. Optionally, the comparison also embeds each question's answer options together with the question text.


## Settings


In [92]:
# Embedding mode: True embeds each question together with its answer choices;
# False embeds the question text on its own.
EMBED_QUESTION_ANSWERS: bool = True

## Utilities


In [93]:
# Standard to handle notebooks being stored in a subdirectory: make modules that
# live in the parent of the working directory (e.g. `truthfulqa_dataset`)
# importable. Assumes the kernel's working directory is the notebook's folder.
import os
import sys

# The original `os.path.abspath("__file__")` used a string literal, which is
# just "<cwd>/__file__"; taking its dirname twice is the parent of cwd.
NOTEBOOK_PARENT_DIR = os.path.dirname(os.getcwd())

# Guard so that re-running this cell does not append duplicate entries.
if NOTEBOOK_PARENT_DIR not in sys.path:
    sys.path.append(NOTEBOOK_PARENT_DIR)

In [94]:
from typing import Optional

import datasets
import numpy as np
import prettytable
import sentence_transformers

from truthfulqa_dataset import load_truthfulqa

In [95]:
# Sentence-embedding model used by `embed` and
# `get_truthfulqa_dataset_embeddings`; pinned to the CPU.
embedding_model = sentence_transformers.SentenceTransformer(
    model_name_or_path="all-mpnet-base-v2",
    device="cpu",
)


def embed(text: str) -> np.ndarray:
    """
    Embed a single piece of text with the shared sentence-embedding model.

    Args:
        text (str): The text to embed.

    Returns:
        np.ndarray: The embedding vector as a NumPy array.
    """
    # Request NumPy output directly rather than building a torch tensor and
    # calling `.numpy()` on it, which would fail if the model were on a GPU.
    return embedding_model.encode(text, convert_to_numpy=True)


def get_truthfulqa_dataset_embeddings(
    truthfulqa_dataset: datasets.Dataset,
    exclude_choices: Optional[bool] = None,
) -> np.ndarray:
    """
    Embed elements from a dataset that uses the TruthfulQA structure.

    Every item needs a "question" field. Unless choices are excluded, every
    item also needs an "mc1_targets" mapping with a "choices" list.

    Args:
        truthfulqa_dataset (datasets.Dataset): The dataset to embed.
        exclude_choices (bool, optional): If True, only the questions are
            embedded. If False, each question is embedded together with its
            answer choices (one per line, in sorted string order). If None
            (the default), `not EMBED_QUESTION_ANSWERS` is used. The setting
            is read at call time, so toggling it takes effect without
            re-running this cell.

    Returns:
        np.ndarray: One embedding row per dataset item.
    """
    # Resolve the default lazily: a default of `not EMBED_QUESTION_ANSWERS`
    # in the signature would be frozen at definition time.
    if exclude_choices is None:
        exclude_choices = not EMBED_QUESTION_ANSWERS

    if exclude_choices:
        texts = list(truthfulqa_dataset["question"])
    else:
        texts = [
            # Choices are sorted, presumably so their stored order (which may
            # reveal the correct answer) does not affect the embedding.
            # TODO confirm.
            "\n".join([item["question"]] + sorted(item["mc1_targets"]["choices"]))
            for item in truthfulqa_dataset
        ]
    return embedding_model.encode(texts)

In [104]:
# Example: sanity-check the embedding model by comparing s1 against s2..s5.
# In the recorded output the paraphrase (s3) scores highest and the unrelated
# statement (s5) scores lowest.
s1 = "What is the capital of France?"
s2 = "What is the capital of Germany?"
s3 = "Which is the largest city of the French nation?"
s4 = "When did mankind first step foot on the moon?"
s5 = "The number π appears in many formulae across mathematics and physics."

e1, e2, e3, e4, e5 = embed(s1), embed(s2), embed(s3), embed(s4), embed(s5)

# Use `cos_sim`, the same function the similarity helpers use;
# `pytorch_cos_sim` is only an alias for it.
for other_embedding in (e2, e3, e4, e5):
    print(sentence_transformers.util.cos_sim(e1, other_embedding))

tensor([[0.6752]])
tensor([[0.8227]])
tensor([[0.2015]])
tensor([[0.0133]])


In [97]:
def nondiag_mean(arr2d: np.ndarray) -> float:
    """
    Mean of the off-diagonal entries of a square 2-D array.

    Used to average pairwise similarities while ignoring each item's
    similarity with itself (the diagonal).

    Args:
        arr2d (np.ndarray): A square (n, n) array; any array-like is accepted.

    Returns:
        float: The mean of the n * (n - 1) off-diagonal entries, or NaN when
        n < 2 (there is nothing off the diagonal to average).

    Raises:
        ValueError: If `arr2d` is not a square 2-D array.
    """
    arr2d = np.asarray(arr2d)
    if arr2d.ndim != 2 or arr2d.shape[0] != arr2d.shape[1]:
        raise ValueError(f"expected a square 2-D array, got shape {arr2d.shape}")
    n = arr2d.shape[0]
    if n < 2:
        # Avoid numpy's "Mean of empty slice" RuntimeWarning.
        return float("nan")
    return arr2d[~np.eye(n, dtype=bool)].mean()


def self_similarity(embs: np.ndarray) -> float:
    """
    Average pairwise cosine similarity within one set of embeddings.

    Each embedding's similarity with itself (the diagonal of the similarity
    matrix) is excluded, so only pairs of distinct items contribute.

    Args:
        embs (np.ndarray): Embeddings, one per row.

    Returns:
        float: Mean off-diagonal cosine similarity.
    """
    # `cos_sim` returns a torch tensor. Move it to the CPU before converting
    # so this also works for tensors that live on an accelerator.
    sims = sentence_transformers.util.cos_sim(embs, embs).cpu().numpy()
    return nondiag_mean(sims)


def cross_similarity(embs1: np.ndarray, embs2: np.ndarray) -> float:
    """
    Average cosine similarity between every row of `embs1` and every row of
    `embs2`.

    Args:
        embs1 (np.ndarray): First set of embeddings, one per row.
        embs2 (np.ndarray): Second set of embeddings, one per row.

    Returns:
        float: Mean of the full len(embs1) x len(embs2) similarity matrix.
    """
    # `cos_sim` returns a torch tensor. Move it to the CPU before converting
    # so this also works for tensors that live on an accelerator.
    sims = sentence_transformers.util.cos_sim(embs1, embs2).cpu().numpy()
    return np.mean(sims)

## Analysis


In [98]:
# 1. Load datasets
# @TODO Make utilities for these.
# Each local file is loaded straight into its "train" split.

truthful_dataset = load_truthfulqa("misconceptions")
crafted_ds = datasets.load_dataset(
    "json",
    data_files="../datasets/crafted_dataset_unfiltered.jsonl",
    split="train",
)
generated_ds = datasets.load_dataset(
    "csv",
    data_files="../datasets/generated_dataset_unfiltered.csv",
    split="train",
)


def array(x, dtype=None):
    """
    Identity stand-in for `numpy.array`, used when eval-ing CSV text.

    The generated CSV stores `mc1_targets` as a string that presumably
    contains NumPy-style `array(..., dtype=...)` calls (confirm against the
    CSV). This shim makes each such call evaluate to its first argument
    unchanged; `dtype` is accepted only so those calls parse, and is ignored.
    """
    return x


# Special logic due to how the CSV stores choices as a string: the
# "mc1_targets" column is read back as text, so it is eval-ed into a Python
# object. Inside that eval, `array(...)` resolves to the identity shim `array`,
# so any `array(...)` calls in the text produce their plain argument.
# SECURITY NOTE(review): `eval` runs arbitrary code from the CSV with every
# notebook global in scope. This is presumably acceptable only because the
# file is a locally generated artifact; never point this at untrusted data.
# TODO: consider a safer parser (e.g. `ast.literal_eval` after stripping the
# `array(...)` wrappers).
generated_ds = generated_ds.map(
    lambda x: {
        "question": x["question"],
        "mc1_targets": eval(x["mc1_targets"], dict(globals(), array=array), locals()),
    }
)

# Display labels, index-aligned with `dss` below.
dss_names = ["Orig", "Craft", "Gen"]
dss = [truthful_dataset, crafted_ds, generated_ds]

# Sanity check: (rows, columns) of each dataset, in `dss_names` order.
print("Dataset shapes", list(map(lambda ds: ds.shape, dss)))

Dataset shapes [(100, 3), (24, 2), (99, 3)]


In [99]:
# 2. Embed each item
# One embedding matrix per dataset (one row per item), index-aligned with `dss`.
all_embeddings = list(map(get_truthfulqa_dataset_embeddings, dss))

print([matrix.shape for matrix in all_embeddings])

[(100, 768), (24, 768), (99, 768)]


In [100]:
# 3. Calculate similarities
# Within-dataset similarity: mean pairwise similarity excluding each item's
# similarity with itself.
self_similiarites = [self_similarity(embs) for embs in all_embeddings]

# Dataset-by-dataset similarity matrix. The diagonal reuses the within-dataset
# values above instead of recomputing each full n x n similarity matrix;
# off-diagonal entries average all cross-dataset pairs.
cross_similiarities = [
    [
        self_similiarites[i] if i == j else cross_similarity(embs1, embs2)
        for j, embs2 in enumerate(all_embeddings)
    ]
    for i, embs1 in enumerate(all_embeddings)
]

## Results


In [101]:
# 4. Print results
# Row i / column j holds the similarity between dataset i and dataset j; the
# diagonal is each dataset's within-dataset (self) similarity.
assert len(cross_similiarities) == len(dss_names), "one row per dataset expected"

print("Cross similarities:")
similarity_table = prettytable.PrettyTable(field_names=[""] + dss_names)
similarity_table.add_rows(
    [[name] + list(row) for name, row in zip(dss_names, cross_similiarities)]
)
print(similarity_table.get_string())

Cross similiarities:
+-------+------------+------------+------------+
|       |    Orig    |   Craft    |    Gen     |
+-------+------------+------------+------------+
|  Orig | 0.13207687 | 0.12603837 | 0.12873477 |
| Craft | 0.12603839 | 0.14535704 | 0.13364212 |
|  Gen  | 0.12873478 | 0.1336421  | 0.15918525 |
+-------+------------+------------+------------+
