# Entity Alignment Experiment

Problem Statement: Entities extracted by Large Language Models (LLMs) are often inconsistent, which leads to messy data. Our goal is to standardize these entities into a canonical form so that different surface forms of the same concept resolve to one reference.

Procedure:

1. Extract Entities: Identify and isolate entities from the data provided by the LLM.
1. Project to Semantic Space: Map these entities onto a semantic space where they can be analyzed based on meaning.
1. Define Canonical Form: Determine the canonical form of entities by measuring the semantic distance between them. This involves setting a similarity threshold manually to decide when entities are considered equivalent.
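
Step 3 can be sketched as a nearest-neighbour lookup with a cutoff. The snippet below is a minimal, dependency-free illustration under stated assumptions: the three-dimensional vectors are made up, and `cosine` and `align_to_canonical` are hypothetical helpers, not part of text2graph.

```python
import math
from typing import Optional


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def align_to_canonical(
    vector: list[float],
    canonical: dict[str, list[float]],
    threshold: float,
) -> Optional[str]:
    """Return the most similar canonical name, or None if below the threshold."""
    best_name, best_score = None, -1.0
    for name, ref in canonical.items():
        score = cosine(vector, ref)
        if score > best_score:
            best_name, best_score = name, score
    return best_name if best_score >= threshold else None


# Toy vectors for illustration only; real embeddings come from the model.
canonical = {"basalt": [1.0, 0.0, 0.0], "limestone": [0.0, 1.0, 0.0]}
close_match = align_to_canonical([0.9, 0.1, 0.0], canonical, threshold=0.9)
no_match = align_to_canonical([0.5, 0.5, 0.7], canonical, threshold=0.9)
```

Returning `None` below the threshold keeps "no confident match" explicit, so such entities can be handled separately instead of being forced onto the nearest known name.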

Summary:

- The performance of embedding alignment methods seems similar to basic word parsers, as embeddings primarily encode information based on the concrete words (e.g., "granite", "sandstone", "limestone", "volcanics", " -member", " -formation") themselves without much additional context from the "name" parts.
- Setting the similarity threshold at 0.9 is recommended to lower the risk of mistakenly associating new terms with known entities.
- If creating new objects in Macrostrat is a goal, it might be necessary to have a human expert review a list of entities considered high risk.
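
For comparison, the kind of "basic word parser" the first bullet refers to might look like the sketch below. This is an assumption about what such a baseline could be, not code from this experiment; the keyword tuples are illustrative and are not the Macrostrat vocabulary.

```python
# Illustrative keyword lists (hypothetical, not the Macrostrat vocabulary).
ROCK_TERMS = ("granite", "sandstone", "limestone", "basalt")
RANK_TERMS = ("member", "formation", "group")


def parse_keywords(name: str) -> dict[str, list[str]]:
    """Pick out rock-type and rank words from an entity name."""
    # Lowercase each token and strip surrounding punctuation.
    tokens = [t.strip("()'\",.?").lower() for t in name.split()]
    return {
        "rock": [t for t in tokens if t in ROCK_TERMS],
        "rank": [t for t in tokens if t in RANK_TERMS],
    }


parsed_lith = parse_keywords("highly fossiliferous limestone")
parsed_strat = parse_keywords("Waterways Formation")
```

Such a parser only sees the concrete rock and rank words and ignores the "name" part, which is roughly the behaviour the embeddings appear to show.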

To-do:

- Implement dynamic prompting for exact match scenarios.
- Implement alignment based on known entity embeddings with a 0.9 similarity threshold to capture more known entities.


In [1]:
import altair as alt
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import cosine_similarity

from text2graph.macrostrat import get_all_lithologies, get_all_strat_names

alt.data_transformers.disable_max_rows()



DataTransformerRegistry.enable('default')

In [None]:
# # Pre-process entities
# df = pd.read_sql("SELECT * FROM entities", "sqlite:///data/entities.db")


# def flatten(x: pd.Series) -> list[str]:
#     """Flatten a list of lists."""
#     outputs = []
#     for i in x:
#         outputs.extend([j.strip() for j in i.split(",") if j.strip()])
#     return sorted(list(set(outputs)))


# locations = flatten(df.locations)
# strats = flatten(df.stratigraphic_names)
# liths = flatten(df.lithologies)


# df = pd.DataFrame(
#     {
#         "category": ["location"] * len(locations)
#         + ["stratigraphic_name"] * len(strats)
#         + ["lithology"] * len(liths),
#         "name": locations + strats + liths,
#     }
# )

# df.to_parquet("data/llm_entities_v0.parquet", index=False)

In [None]:
# # Calculate embeddings

# model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# embeddings = model.encode(df.name.to_list())
# np.savez("data/llm_entities_v0.npz", embeddings=embeddings)

In [6]:
df_llm = pd.read_parquet("llm_entities_v0.parquet")
embeddings_llm = np.load("llm_entities_v0.npz")["embeddings"]
df_llm.groupby("category").count()

category,name
lithology,7123
location,10618
stratigraphic_name,4151


Try to anchor the LLM entities to the known Macrostrat entities (stratigraphic names and lithologies).

In [None]:
# macrostrat_strat_names = get_all_strat_names(long=True)
# macrostrat_lithologies = get_all_lithologies()

# macrostrat_embeddings = model.encode(macrostrat_strat_names + macrostrat_lithologies)
# macrostrat_df = pd.DataFrame(
#     {
#         "category": ["known_stratigraphic_name"] * len(macrostrat_strat_names)
#         + ["known_lithology"] * len(macrostrat_lithologies),
#         "name": macrostrat_strat_names + macrostrat_lithologies,
#     }
# )

# np.savez("data/known_entities_v0.npz", embeddings=macrostrat_embeddings)
# macrostrat_df.to_parquet("data/known_entities_v0.parquet", index=False)

In [4]:
df_known = pd.read_parquet("known_entities_v0.parquet")
embeddings_known = np.load("known_entities_v0.npz")["embeddings"]

In [5]:
df_known.groupby("category").count()

category,name
known_lithology,212
known_stratigraphic_name,47821


## Plot all LLM and known entities in 2D

In [7]:
df = pd.concat([df_llm, df_known]).reset_index(drop=True)
embeddings = np.concatenate([embeddings_llm, embeddings_known])
assert len(df) == embeddings.shape[0]

In [None]:
# t-sne projection
tsne = TSNE(n_components=2, random_state=0)
x_2d = tsne.fit_transform(embeddings)

In [None]:
df["x"] = x_2d[:, 0]
df["y"] = x_2d[:, 1]

In [None]:
df.to_parquet("data/df_merged_v0.parquet", index=False)

In [None]:
chart = (
    alt.Chart(df)
    .mark_circle()
    .encode(
        x="x",
        y="y",
        color="category",
        tooltip=["category", "name"],
    )
    .interactive()
    .properties(width=1000, height=1000)
)
# save() returns None, so keep the chart object and save it separately
chart.save("data/entities_v0_tsne.html")

- Locations and lithologies seem to be quite well separated.
- Stratigraphic names cover quite a broad area, which might split into subgroups. However, the subgroups are very superficial, e.g., "granite", "sandstone", "limestone", "volcanics", " -member", " -formation".
- Somewhat of a dead end. The stratigraphic-name namespace is perhaps too sparse (in the training data) to encode any geologically meaningful information.

## Closest known entity

Will the closest known entity make sense?

In [9]:
df.category.unique()

array(['location', 'stratigraphic_name', 'lithology',
       'known_stratigraphic_name', 'known_lithology'], dtype=object)

In [10]:
def get_closest_known_entity(
    row: pd.Series,
    embeddings: np.ndarray,
    known_df: pd.DataFrame,
    known_embeddings: np.ndarray,
    same_category: bool = True,
) -> tuple[int, str, str, str | None, str | None, float | None]:
    """Get the closest known entity to a given case according to its category.

    Returns (idx, name, category, closest name, closest category, similarity).
    """
    assert len(known_df) == known_embeddings.shape[0]

    # Get the embedding of the case; row.name is the row's position in `embeddings`
    idx = row.name
    x = embeddings[idx].reshape(1, -1)

    # Restrict the candidates to the matching known category if requested
    if same_category:
        known_df_in_category = known_df[
            known_df["category"] == f"known_{row['category']}"
        ]
    else:
        known_df_in_category = known_df

    # No candidates: keep the 6-field shape so the result DataFrame still builds
    if known_df_in_category.empty:
        return (row.name, row["name"], row["category"], None, None, None)

    # Assumes known_df has a default RangeIndex aligned with known_embeddings
    known_embeddings_in_category = known_embeddings[known_df_in_category.index]

    similarity = cosine_similarity(x, known_embeddings_in_category).flatten()
    idx_closest = np.argmax(similarity)
    return (
        row.name,  # idx
        row["name"],
        row["category"],
        known_df_in_category.iloc[idx_closest]["name"],
        known_df_in_category.iloc[idx_closest]["category"],
        similarity[idx_closest],
    )

In [11]:
# Let's test the function on 10 cases to save time
df_test = df.query("category in ['stratigraphic_name', 'lithology']").sample(10)
df_test

,category,name
18279,lithology,highly vesicular weathered and leached basalt
18276,lithology,highly fossiliferous limestone
13290,stratigraphic_name,The Lakes and Gulf waterway
18191,lithology,hard crystalline layers
17815,lithology,geophysical logs
18392,lithology,ilmenite minerals
14070,stratigraphic_name,late Eocene time
15415,lithology,Surface modification
15134,lithology,Kinmanswick lime-stone
21500,lithology,unidentified pelecypods


In [12]:
tmp_respect_category = df_test.apply(
    get_closest_known_entity,
    args=(embeddings_llm, df_known, embeddings_known, True),
    axis=1,
)

tmp_global_closest = df_test.apply(
    get_closest_known_entity,
    args=(embeddings_llm, df_known, embeddings_known, False),
    axis=1,
)

In [14]:
df_tmp_respect_category = pd.DataFrame(
    tmp_respect_category.to_list(),
    columns=[
        "idx",
        "name",
        "category",
        "closest_known_entity",
        "closest_category",
        "closest_similarity",
    ],
)
df_tmp_respect_category

,idx,name,category,closest_known_entity,closest_category,closest_similarity
0,18279,highly vesicular weathered and leached basalt,lithology,basalt,known_lithology,0.693543
1,18276,highly fossiliferous limestone,lithology,limestone,known_lithology,0.725187
2,13290,The Lakes and Gulf waterway,stratigraphic_name,Waterways Formation,known_stratigraphic_name,0.610191
3,18191,hard crystalline layers,lithology,basalt,known_lithology,0.416933
4,17815,geophysical logs,lithology,eclogite,known_lithology,0.346297
5,18392,ilmenite minerals,lithology,quartzite,known_lithology,0.639966
6,14070,late Eocene time,stratigraphic_name,Times Porphyry Formation,known_stratigraphic_name,0.457848
7,15415,Surface modification,lithology,metaigneous,known_lithology,0.305676
8,15134,Kinmanswick lime-stone,lithology,lime mudstone,known_lithology,0.648993
9,21500,unidentified pelecypods,lithology,pelmicrite,known_lithology,0.481199


In [15]:
df_tmp_global_closest = pd.DataFrame(
    tmp_global_closest.to_list(),
    columns=[
        "idx",
        "name",
        "category",
        "closest_known_entity",
        "closest_category",
        "closest_similarity",
    ],
)
df_tmp_global_closest

,idx,name,category,closest_known_entity,closest_category,closest_similarity
0,18279,highly vesicular weathered and leached basalt,lithology,Weedy Basalt,known_stratigraphic_name,0.710374
1,18276,highly fossiliferous limestone,lithology,Fossil Hill Limestone,known_stratigraphic_name,0.842122
2,13290,The Lakes and Gulf waterway,stratigraphic_name,Waterways Formation,known_stratigraphic_name,0.610191
3,18191,hard crystalline layers,lithology,Crystal Peak Formation,known_stratigraphic_name,0.50036
4,17815,geophysical logs,lithology,Log Creek Formation,known_stratigraphic_name,0.535539
5,18392,ilmenite minerals,lithology,Gilman Quartzite,known_stratigraphic_name,0.647107
6,14070,late Eocene time,stratigraphic_name,Times Porphyry Formation,known_stratigraphic_name,0.457848
7,15415,Surface modification,lithology,Shade Formation,known_stratigraphic_name,0.412898
8,15134,Kinmanswick lime-stone,lithology,Kinney Limestone,known_stratigraphic_name,0.674658
9,21500,unidentified pelecypods,lithology,pelmicrite,known_lithology,0.481199


In [16]:
# Apply to whole test set

df_test = df.query("category in ['stratigraphic_name', 'lithology']")

tmp_local = df_test.apply(
    get_closest_known_entity,
    args=(embeddings_llm, df_known, embeddings_known, True),
    axis=1,
)

tmp_global = df_test.apply(
    get_closest_known_entity,
    args=(embeddings_llm, df_known, embeddings_known, False),
    axis=1,
)

local_closest_df = pd.DataFrame(
    tmp_local.to_list(),
    columns=[
        "idx",
        "name",
        "category",
        "local_closest_known_entity",
        "local_closest_category",
        "local_closest_similarity",
    ],
)

global_closest_df = pd.DataFrame(
    tmp_global.to_list(),
    columns=[
        "idx",
        "name",
        "category",
        "global_closest_known_entity",
        "global_closest_category",
        "global_closest_similarity",
    ],
)

# Join on the shared "idx" column; drop the duplicated name/category columns first
merged_df = local_closest_df.merge(
    global_closest_df.drop(columns=["name", "category"]), on="idx"
)
# merged_df.to_parquet("data/closest_known_entities_v0.parquet", index=False)

In [19]:
local_closest_df

,idx,name,category,local_closest_known_entity,local_closest_category,local_closest_similarity
0,10618,'Bend series',stratigraphic_name,Bend Group,known_stratigraphic_name,0.697886
1,10619,'Cascade River Schist of Misch (1966)',stratigraphic_name,Cascade River Schist,known_stratigraphic_name,0.825065
2,10620,'D' coal bed,stratigraphic_name,Dandy Coal Bed,known_stratigraphic_name,0.787331
3,10621,'Leonardian-Roadian (upper Lower Permian)',stratigraphic_name,Permian Basal Breccia,known_stratigraphic_name,0.721919
4,10622,'Pg Grand Prize Formation (Lower Permian)',stratigraphic_name,Permian Siltstone Member,known_stratigraphic_name,0.671311
...,...,...,...,...,...,...
11269,21887,zoisite,lithology,dolomite,known_lithology,0.656132
11270,21888,zoisite(?),lithology,dolomite,known_lithology,0.623896
11271,21889,zone of gash veins,lithology,anthracite,known_lithology,0.262021
11272,21890,zoned plagioclase,lithology,pseudotachylite,known_lithology,0.393489


In [24]:
df_merged = local_closest_df.merge(global_closest_df.drop(columns=["name", "category"]))
df_merged.to_excel("closest_known_entities_v0.xlsx", index=False)

We may want to pick a (somewhat arbitrary) similarity cutoff. Let's see where the similarity metric breaks down, i.e., where it wrongly maps an entity to an irrelevant known entity.
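
One simple way to explore candidate cutoffs is to count how many entities would be accepted at each threshold. The sketch below uses the ten category-restricted similarities from the 10-row sample shown earlier; `accepted_counts` is a hypothetical helper, and the candidate thresholds are chosen only for illustration.

```python
# Similarities copied from the 10-row category-restricted sample output.
similarities = [
    0.693543, 0.725187, 0.610191, 0.416933, 0.346297,
    0.639966, 0.457848, 0.305676, 0.648993, 0.481199,
]


def accepted_counts(scores: list[float], thresholds: list[float]) -> dict[float, int]:
    """Map each threshold to the number of scores at or above it."""
    return {t: sum(s >= t for s in scores) for t in thresholds}


counts = accepted_counts(similarities, [0.5, 0.7, 0.9])
print(counts)  # {0.5: 5, 0.7: 1, 0.9: 0}
```

On this small sample, none of the ten entities reaches 0.9. Inspecting the matches that fall just below each candidate cutoff is a quick way to see where the mappings start to look wrong.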