# Entity Alignment Experiment

Problem Statement: Entities extracted from Large Language Models (LLMs) often lack consistency, leading to messy data. Our goal is to standardize these entities into a canonical form to ensure they reference the same concept.

Procedure:

1. Extract Entities: Identify and isolate entities from the data provided by the LLM.
1. Project to Semantic Space: Map these entities onto a semantic space where they can be analyzed based on meaning.
1. Define Canonical Form: Determine the canonical form of entities by measuring the semantic distance between them. This involves setting a similarity threshold manually to decide when entities are considered equivalent.

In [None]:
import altair as alt
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import cosine_similarity

from text2graph.macrostrat import get_all_lithologies, get_all_strat_names

# Lift Altair's dataset row cap so the full entity table can be charted.
# NOTE(review): the Altair docs describe a default 5000-row limit
# (MaxRowsError) - confirm for the installed version.
alt.data_transformers.disable_max_rows()

In [None]:
# # Pre-process entities
# df = pd.read_sql("SELECT * FROM entities", "sqlite:///data/entities.db")


# def flatten(x: pd.Series) -> list[str]:
#     """Split comma-separated entries and return the sorted unique values."""
#     outputs = []
#     for i in x:
#         outputs.extend([j.strip() for j in i.split(",") if j.strip()])
#     return sorted(list(set(outputs)))


# locations = flatten(df.locations)
# strats = flatten(df.stratigraphic_names)
# liths = flatten(df.lithologies)


# df = pd.DataFrame(
#     {
#         "category": ["location"] * len(locations)
#         + ["stratigraphic_name"] * len(strats)
#         + ["lithology"] * len(liths),
#         "name": locations + strats + liths,
#     }
# )

# df.to_parquet("data/llm_entities_v0.parquet", index=False)

In [None]:
# # Calculate embeddings

# model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# embeddings = model.encode(df.name.to_list())
# np.savez("data/llm_entities_v0.npz", embeddings=embeddings)

In [None]:
# Load the cached LLM-extracted entities and their embeddings (the files the
# commented-out pre-processing / embedding cells write).
# NOTE(review): rows of the two files are assumed to be in the same order.
df_llm = pd.read_parquet("data/llm_entities_v0.parquet")
# np.load on an .npz returns a lazy NpzFile that keeps the archive open; the
# context manager closes it once the array has been read into memory.
with np.load("data/llm_entities_v0.npz") as npz_llm:
    embeddings_llm = npz_llm["embeddings"]

Try to anchor the LLM entities to known Macrostrat entities (stratigraphic names and lithologies).

In [None]:
# Fetch the reference ("known") entity names through the text2graph
# Macrostrat helpers.
# NOTE(review): `long=True` presumably requests long-form stratigraphic
# names - confirm against get_all_strat_names.
macrostrat_strat_names = get_all_strat_names(long=True)
macrostrat_lithologies = get_all_lithologies()

In [None]:
# macrostrat_embeddings = model.encode(macrostrat_strat_names + macrostrat_lithologies)
# macrostrat_df = pd.DataFrame(
#     {
#         "category": ["known_stratigraphic_name"] * len(macrostrat_strat_names)
#         + ["known_lithology"] * len(macrostrat_lithologies),
#         "name": macrostrat_strat_names + macrostrat_lithologies,
#     }
# )

# np.savez("data/known_entities_v0.npz", embeddings=macrostrat_embeddings)
# macrostrat_df.to_parquet("data/known_entities_v0.parquet", index=False)

In [None]:
# Load the cached known (Macrostrat) entities and their embeddings (the files
# the commented-out cell above writes).
# NOTE(review): rows of the two files are assumed to be in the same order.
df_known = pd.read_parquet("data/known_entities_v0.parquet")
# np.load on an .npz returns a lazy NpzFile that keeps the archive open; the
# context manager closes it once the array has been read into memory.
with np.load("data/known_entities_v0.npz") as npz_known:
    embeddings_known = npz_known["embeddings"]

## Plot all LLM and known entities in 2D

In [None]:
# Stack the LLM-extracted entities on top of the known entities, keeping the
# table rows and the embedding rows in the same order (LLM first, then known).
# Uses `df_known` (loaded from parquet) rather than `macrostrat_df`, which is
# only built inside the commented-out cell and is therefore undefined here.
df = pd.concat([df_llm, df_known]).reset_index(drop=True)
embeddings = np.concatenate([embeddings_llm, embeddings_known])

# t-SNE projection of all embeddings to 2-D; fixed seed for repeatable layout
tsne = TSNE(n_components=2, random_state=0)
x_2d = tsne.fit_transform(embeddings)

In [None]:
# Attach the two t-SNE coordinates to the merged entity table as "x" and "y".
for axis, column in enumerate(("x", "y")):
    df[column] = x_2d[:, axis]

In [None]:
# Persist the merged entity table (including the t-SNE x/y columns) to
# parquet, without writing the index.
df.to_parquet("data/df_merged_v0.parquet", index=False)

In [None]:
# Interactive scatter plot of the t-SNE projection, coloured by category, with
# the entity name shown in the tooltip.
# The chart is bound to `plot` first and saved in a separate statement:
# chaining .save() onto the expression leaves `plot` holding save()'s return
# value (None) instead of the chart.
plot = (
    alt.Chart(df)
    .mark_circle()
    .encode(
        x="x",
        y="y",
        color="category",
        tooltip=["category", "name"],
    )
    .interactive()
    .properties(width=1000, height=1000)
)
plot.save("data/entities_v0_tsne.html")

- Location and lithology seem to be quite well separated.
- Strat names seem to cover quite a broad area, which could perhaps be separated into subgroups. However, the subgroups are very superficial, e.g., "granite", "sandstone", "limestone", "volcanics", " -member", " -formation".
- Somewhat of a dead end. The strat-name namespace is perhaps too sparse (in the training data) to encode any geologically meaningful information.

## Closest known entity

Will the closest known entity make sense?

In [None]:
# List the distinct category labels present in the merged table.
df.category.unique()

In [None]:
def get_closest_known_entity(
    row: pd.Series,
    embeddings: np.ndarray,
    known_df: pd.DataFrame,
    known_embeddings: np.ndarray,
) -> tuple[str | None, str | None, float | None]:
    assert len(known_df) == known_embeddings.shape[0]
    """Get the closest known entity to a given case according to its category."""

    # Get the embedding of the case
    idx = row.name
    x = embeddings[idx].reshape(1, -1)

    # Return the closest known entity in the same category
    x_category = df.iloc[idx]["category"]
    known_df_in_category = known_df[known_df["category"] == f"known_{x_category}"]

    if known_df_in_category.empty:
        return (None, None, None)

    known_embeddings_in_category = known_embeddings[known_df_in_category.index]

    similarity = cosine_similarity(x, known_embeddings_in_category).flatten()
    idx_closest = np.argmax(similarity)
    return (
        row.name,  # idx
        row["name"],
        row["category"],
        known_df_in_category.iloc[idx_closest]["name"],
        known_df_in_category.iloc[idx_closest]["category"],
        similarity[idx_closest],
    )

In [None]:
# Try the function on a random sample of 10 entities first, to save time.
is_matchable = df["category"].isin(["stratigraphic_name", "lithology"])
df_test = df[is_matchable].sample(10)
df_test

In [None]:
# Look up the closest known entity for each sampled row (one result per row).
tmp = df_test.apply(
    lambda entity: get_closest_known_entity(
        entity, embeddings_llm, df_known, embeddings_known
    ),
    axis=1,
)

In [None]:
# Expand the per-row result tuples into a labelled table and display it.
closest_columns = [
    "idx",
    "name",
    "category",
    "closest_known_entity",
    "closest_category",
    "closest_similarity",
]
closest_df = pd.DataFrame(tmp.to_list(), columns=closest_columns)
closest_df

In [None]:
# Run the lookup over every row whose category is "stratigraphic_name" or
# "lithology", then cache the matches as parquet.

df_test = df[df["category"].isin(["stratigraphic_name", "lithology"])]
tmp = df_test.apply(
    lambda entity: get_closest_known_entity(
        entity, embeddings_llm, df_known, embeddings_known
    ),
    axis=1,
)
result_columns = [
    "idx",
    "name",
    "category",
    "closest_known_entity",
    "closest_category",
    "closest_similarity",
]
closest_df = pd.DataFrame(tmp.to_list(), columns=result_columns)
closest_df.to_parquet("data/closest_known_entities_v0.parquet", index=False)

We may want to select an arbitrary threshold as the similarity cutoff. Let's see where the similarity metric breaks down, i.e., where it wrongly maps an entity to an irrelevant known entity.  

In [None]:
# Export the matches ordered from most to least similar, keeping the row
# index as the first CSV column.
ranked = closest_df.sort_values("closest_similarity", ascending=False)
ranked.to_csv("data/closest_known_entities_v0.csv", index=True)