In [1]:
from sentence_transformers import SentenceTransformer
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

import seaborn as sns
import os
import umap

from bokeh.plotting import figure, show, output_notebook
from bokeh.models import HoverTool, ColumnDataSource, CategoricalColorMapper
from bokeh.palettes import Category10 as BokehPalette
from bokeh.resources import INLINE
import bokeh.io

In [2]:
# Render Bokeh plots inside the notebook, using the INLINE resources object
# imported above.
output_notebook(resources=INLINE)

In [None]:
# --- CT24 data -------------------------------------------------------------
# Only `ct24_dataset` is a versioned artifact name; the other three are plain
# keys used for the additional CT24 splits.
ct24_dataset = "ct24:latest"
ct24_dev_dataset = "ct24-dev"
ct24_dev_test_dataset = "ct24-dev-test"
ct24_test_dataset = "ct24-test"

# --- Synthetic data --------------------------------------------------------
synth_example = "ct24_synthetic_only:latest"  # Synthetic data generated with example prompt method
synth_10k = "ct24_synth_0_10k:latest"

# One synthetic set per generator model.
synth_gpt_4_1 = "synthetic_gpt_4_1:latest"
synth_gpt_4o = "synthetic_gpt_4o:latest"
synth_gpt_o3_mini = "synthetic_gpt_o3_mini:latest"
synth_gpt_o4_mini = "synthetic_gpt_o4_mini:latest"

# --- Other -----------------------------------------------------------------
gc_filtered = "general_claim_filtered:latest"

In [4]:
import wandb

# Client used to look up and fetch the dataset artifacts.
api = wandb.Api()
project = "redstag/thesis"

artifact_names = [
    ct24_dataset,
    synth_example,
    synth_10k,
    gc_filtered,
    synth_gpt_4_1,
    synth_gpt_4o,
    synth_gpt_o3_mini,
    synth_gpt_o4_mini,
]

# Map every artifact name to the local directory returned by its download.
artifacts = {}
for name in artifact_names:
    artifacts[name] = api.artifact(f"{project}/{name}").download()

[34m[1mwandb[0m:   4 of 4 files downloaded.  
[34m[1mwandb[0m:   4 of 4 files downloaded.  
[34m[1mwandb[0m:   6 of 6 files downloaded.  
[34m[1mwandb[0m:   4 of 4 files downloaded.  
[34m[1mwandb[0m:   6 of 6 files downloaded.  
[34m[1mwandb[0m:   6 of 6 files downloaded.  
[34m[1mwandb[0m:   6 of 6 files downloaded.  
[34m[1mwandb[0m:   6 of 6 files downloaded.  


In [5]:
# Every downloaded artifact contributes its train.csv, keyed by artifact name.
dfs = {
    name: pd.read_csv(os.path.join(dir_path, "train.csv"))
    for name, dir_path in artifacts.items()
}

# Add CT24 eval, dev-test and test-gold: these extra splits are read from the
# CT24 artifact directory and registered under their own keys.
ct24_dir = artifacts[ct24_dataset]
for split_key, split_file in (
    (ct24_dev_dataset, "dev.csv"),
    (ct24_dev_test_dataset, "dev-test.csv"),
    (ct24_test_dataset, "test.csv"),
):
    dfs[split_key] = pd.read_csv(os.path.join(ct24_dir, split_file))

In [6]:
# Sentence encoder shared by all embedding and similarity cells below.
model = SentenceTransformer(model_name_or_path="all-mpnet-base-v2")

In [7]:
# Encode the "Text" column of every loaded frame; keys mirror those of `dfs`.
embeddings = {
    name: model.encode(frame["Text"].tolist(), show_progress_bar=True)
    for name, frame in dfs.items()
}

Batches:   0%|          | 0/704 [00:00<?, ?it/s]

Batches:   0%|          | 0/704 [00:00<?, ?it/s]

Batches:   0%|          | 0/313 [00:00<?, ?it/s]

Batches:   0%|          | 0/1549 [00:00<?, ?it/s]

Batches:   0%|          | 0/16 [00:00<?, ?it/s]

Batches:   0%|          | 0/16 [00:00<?, ?it/s]

Batches:   0%|          | 0/16 [00:00<?, ?it/s]

Batches:   0%|          | 0/16 [00:00<?, ?it/s]

Batches:   0%|          | 0/33 [00:00<?, ?it/s]

Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Batches:   0%|          | 0/11 [00:00<?, ?it/s]

In [8]:
all_datasets = [
    ct24_dev_dataset,
    ct24_dev_test_dataset,
    ct24_test_dataset,
    synth_example,
    synth_10k,
    gc_filtered,
]
filter_label = "Yes"

# Reference side of every comparison: the CT24 train embeddings, restricted to
# rows whose class_label equals filter_label when a filter is set.
reference_embeddings = embeddings[ct24_dataset]
if filter_label:
    reference_embeddings = reference_embeddings[
        (dfs[ct24_dataset]["class_label"] == filter_label).values
    ]

# Pairwise similarity between the reference and each dataset in all_datasets.
sims = {}
for ds in all_datasets:
    candidate_embeddings = embeddings[ds]
    if filter_label:
        candidate_embeddings = candidate_embeddings[
            (dfs[ds]["class_label"] == filter_label).values
        ]
    sims[ds] = model.similarity(reference_embeddings, candidate_embeddings)

In [9]:
# Summary statistics (min / max / mean) of each similarity matrix.
for ds, s in sims.items():
    print(
        f"Similarities {ds}:",
        f"min: {s.min()}",
        f"max: {s.max()}",
        f"mean: {s.mean()}",
        "-----------------",
        sep="\n",
    )

Similarities ct24-dev:
min: -0.2065959870815277
max: 1.0000001192092896
mean: 0.1746121048927307
-----------------
Similarities ct24-dev-test:
min: -0.22296826541423798
max: 0.8025941848754883
mean: 0.15725839138031006
-----------------
Similarities ct24-test:
min: -0.25562357902526855
max: 0.7282082438468933
mean: 0.1478351205587387
-----------------
Similarities ct24_synthetic_only:latest:
min: -0.2824961245059967
max: 0.9844942092895508
mean: 0.1499294489622116
-----------------
Similarities ct24_synth_0_10k:latest:
min: -0.272604376077652
max: 0.9122903347015381
mean: 0.08540401607751846
-----------------
Similarities general_claim_filtered:latest:
min: -0.2737082540988922
max: 0.9999999403953552
mean: 0.07583823055028915
-----------------


In [10]:
def plot_umap(
        umap_candidates: list[str],
        dataframes: dict[str, pd.DataFrame],
        embeddings: dict[str, np.ndarray], 
        max_samples: int|None = None,
        filter_label: str|None = None,
        balance_classes: bool = False,
        dot_size: float = 6
        ):
    """
    umap_candidates: List with names of datasets to include in UMAP projection.
    dataframes: All datasets.
    embeddings: All embeddings.
    max_samples: The max number of samples per dataframe to plot.
    filter_label: Filter samples per class label ("Yes"/"No").
    balance_classes: Class balance the samples before plotting. Only considered when filter_label is None.
    dot_size: Scatter plot dot size.
    """

    if filter_label and balance_classes:
        raise ValueError("balance_classes only possible when filter_label is None")
    
    embeddings_subset = []
    umap_text = []
    dataset_source_column = []
    label_column = []

    # Process all candidates for plot
    for name in umap_candidates:
        df = dataframes[name]
        e = embeddings[name]
        slice_end = min(len(e), max_samples) if max_samples else len(e)
        
        if filter_label:
            mask = df["class_label"] == filter_label
            embedding_slice = e[mask.values][:slice_end]
            umap_text += df[mask.values][:slice_end]["Text"].to_list()
            label_column += [filter_label] * len(embedding_slice)
        else:
            if balance_classes:
                # Only consider "Yes" and "No" as possible labels
                labels = ["Yes", "No"]
                class_indices = [df[df["class_label"] == label].index for label in labels]
                min_class_count = min(len(idx) for idx in class_indices)
                # Sample equally from both classes, up to max_samples//2 in total
                samples_per_class = min(min_class_count, max_samples // len(labels))
                balanced_indices = np.hstack([
                    np.random.choice(idx, samples_per_class, replace=False) for idx in class_indices
                ])
                embedding_slice = e[balanced_indices]
                umap_text += df.loc[balanced_indices, "Text"].to_list()
                label_column += df.loc[balanced_indices, "class_label"].to_list()
            else:
                embedding_slice = e[:slice_end]
                umap_text += df[:slice_end]["Text"].to_list()
                label_column += df[:slice_end]["class_label"].to_list()
        
        embeddings_subset.append(embedding_slice)
        dataset_source_column += [name.replace(":latest", "")] * len(embedding_slice)

    # Create a column that combines the dataset name and the label
    dataset_source_label_column = [f'{ds} {l}' for ds, l in zip(dataset_source_column, label_column)]

    embeddings_cat = np.concat(embeddings_subset)

    # Apply UMAP
    reducer = umap.UMAP()
    umapped = reducer.fit_transform(embeddings_cat)

    # Setup tooltip
    scatter_df = pd.DataFrame(umapped, columns=('x', 'y'))
    scatter_df['text'] = umap_text
    scatter_df['dataset'] = dataset_source_column
    scatter_df['label'] = label_column
    scatter_df['dataset_label'] = dataset_source_label_column

    palette = BokehPalette[max(len(umap_candidates), 3)][:len(umap_candidates)]
    factors = scatter_df['dataset_label'].unique().tolist()

    if not filter_label:
        # Duplicate the palette by making a darker variant
        def darken_color(hex_color, amount=0.35):
            rgb = mcolors.hex2color(hex_color)
            dark_rgb = tuple([c * amount for c in rgb])
            return mcolors.to_hex(dark_rgb)

        palette = list(palette) + [darken_color(c) for c in palette]

        # Arrange the dataset_source - (Yes/No) labels with the original/darkened sub-palette
        factors.sort()
        factors.reverse()
        factors = list(factors[::2]) + list(factors[1::2])

    datasource = ColumnDataSource(scatter_df)
        
    color_mapping = CategoricalColorMapper(factors=factors,
                                        palette=palette)

    plot_figure = figure(
        title=f'UMAP projection of claim datasets' + (f' (label: {filter_label})' if filter_label else ' (label: all)'),
        width=750,
        height=750,
        tools=('pan, wheel_zoom, reset')
    )

    plot_figure.add_tools(HoverTool(
        attachment="vertical",
        tooltips="""
    <div>
        <div>
            <span style='font-size: 13px; color: #224499'>Dataset:</span>
            <span style='font-size: 13px'>@dataset_label</span>
            <br>
            <span style='font-size: 13px; color: #224499'>Label:</span>
            <span style='font-size: 13px'>@label</span>
            <br>
            <span style='font-size: 13px; color: #224499'>Text:</span>
            <span style='font-size: 13px'>@text</span>
        </div>
    </div>
    """))

    plot_figure.scatter(
        'x',
        'y',
        source=datasource,
        color=dict(field='dataset_label', transform=color_mapping),
        legend_field='dataset_label',
        line_alpha=0.6,
        fill_alpha=0.6,
        size=dot_size
    )
    show(plot_figure)

In [13]:
# Plot the whole CT24 train set: no sample cap, no label filter, no balancing
# (those three options are left at their defaults).
umap_candidates = [ct24_dataset]

plot_umap(umap_candidates, dfs, embeddings, dot_size=4)

In [12]:
# Plot a class-balanced sample (at most 2500 points in total) of the o3-mini
# synthetic set; filter_label stays at its default of None.
umap_candidates = [synth_gpt_o3_mini]

plot_umap(
    umap_candidates,
    dfs,
    embeddings,
    max_samples=2500,
    balance_classes=True,
    dot_size=6,
)