In [None]:
import pickle
from io import BytesIO
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
# Download the de novo result archive from Zenodo record 16438770, unpack it
# into dev/results/ (`unzip -o` overwrites existing files without prompting)
# and delete the archive afterwards.
!mkdir -p dev/results
!wget -O dev/results/de_novo_results.zip \
    "https://zenodo.org/records/16438770/files/de_novo_results.zip?download=1"
!unzip -o dev/results/de_novo_results.zip -d dev/results/
!rm dev/results/de_novo_results.zip

# Same steps for the fragmentation result archive of the same Zenodo record.
# NOTE(review): the second `mkdir -p` is redundant but harmless.
!mkdir -p dev/results
!wget -O dev/results/fragmentation_results.zip \
    "https://zenodo.org/records/16438770/files/fragmentation_results.zip?download=1"
!unzip -o dev/results/fragmentation_results.zip -d dev/results/
!rm dev/results/fragmentation_results.zip

In [None]:
# Root directory of the unpacked result archives; figures are written to a
# "figures" sub-directory that is created on demand.
base_folder = Path("dev/results")
figures_dir = base_folder.joinpath("figures")
figures_dir.mkdir(parents=True, exist_ok=True)

## Data loading

In [None]:
def _load_result_frames(folder, fragments):
    """Unpickle every entry of ``folder`` and tag it with a ``fragments`` flag.

    Entries are read in sorted path order so that the row order of the
    concatenated frame (and therefore any later first-match tie-breaking such
    as ``idxmax``) does not depend on the filesystem's listing order.

    Args:
        folder: Directory (``pathlib.Path``) whose entries are pickled objects
            exposing ``.assign`` (presumably pandas DataFrames -- confirm).
        fragments: Value stored in the ``fragments`` column of every frame.

    Returns:
        A list with one frame per entry of ``folder``.
    """
    # SECURITY: unpickling executes arbitrary code contained in the file. Only
    # run this on archives obtained from a trusted source.
    return [
        pickle.loads(path.read_bytes()).assign(fragments=fragments)
        for path in sorted(folder.iterdir())
    ]


# All results in one frame: de novo runs carry fragments=False, fragmentation
# runs carry fragments=True. CASMI dataset identifiers are replaced by their
# display names; other dataset values are left untouched.
df = (
    pd.concat(
        _load_result_frames(base_folder / "de_novo_results", fragments=False)
        + _load_result_frames(base_folder / "fragmentation_results", fragments=True)
    )
    .reset_index(drop=True)
    .replace(
        {
            "dataset": {
                "casmi_2016": "CASMI 2016",
                "casmi_2017": "CASMI 2017",
                "casmi_2022": "CASMI 2022",
            },
        }
    )
)

In [None]:
# De novo rows (no fragment given) of the pretrained model only.
is_pretrained_de_novo = df["model"].eq("seismiq_pretrained") & ~df["fragments"]

# Best Tanimoto per challenge; a challenge counts as solved ("perfect" == 1)
# when that best value is at least 0.999.
best_per_challenge = (
    df.loc[is_pretrained_de_novo]
    .groupby(["model", "dataset", "challenge"], dropna=True)
    .agg({"tanimoto": "max"})
    .reset_index()
)
best_per_challenge["perfect"] = 1 * (best_per_challenge["tanimoto"] >= 0.999)
pdf = best_per_challenge.replace({"model": {"seismiq_pretrained": "SEISMiQ"}})

# Hard-coded accuracy values for the other methods, one row per method.
# NOTE(review): the origin of these numbers is not recorded in this notebook --
# presumably taken from the respective publications; confirm and cite.
reference_accuracies = pd.DataFrame(
    {
        "model": ["MSNovelist", "MS2Mol", "MassGenie\n(Train)", "MADGEN\n(Oracle)"],
        "dataset": ["CASMI 2016", "CASMI 2022", "CASMI 2017", "MassSpecGym"],
        "perfect": [0.57, 0.10, 0.53, 0.386],
    }
)

df_all = pd.concat([pdf, reference_accuracies])

## Figure 2

In [None]:
# Figure 2: bar plot of the mean of "perfect" per model, one panel per dataset
# (standard-error bars where a model has more than one row).
dataset_panels = ["CASMI 2016", "CASMI 2017", "CASMI 2022", "MassSpecGym"]

g = sns.catplot(
    df_all.reset_index(drop=True),
    kind="bar",
    x="model",
    y="perfect",
    hue="model",
    errorbar="se",
    col="dataset",
    col_order=dataset_panels,
    col_wrap=4,
    sharex=False,
    height=3,
    aspect=0.8,
)
g.set(xlabel="", ylabel="Accuracy")

# Default facet titles read "dataset = <name>"; keep only the part after "=".
for panel_ax in g.axes:
    full_title = panel_ax.get_title()
    panel_ax.set_title(full_title.rpartition("=")[2].strip())

figure_2_path = base_folder / "figures" / "figure_2.png"
g.figure.tight_layout()
g.figure.savefig(figure_2_path)

## Fragmentation data

In [None]:
# Columns that identify one prediction task; NaN keys are kept as their own
# groups (dropna=False).
task_keys = [
    "model",
    "dataset",
    "challenge",
    "fragments",
    "bond_idx",
    "dummy_idx",
    "given_atoms",
    "missing_atoms",
]
kept_columns = ["tanimoto", "perplexity", "generation_count"]


def _highest_tanimoto_row(group):
    """Return the kept columns of the row with the largest tanimoto in ``group``."""
    return group.loc[group["tanimoto"].idxmax(), kept_columns]


df_best = (
    df.groupby(task_keys, dropna=False)
    .apply(_highest_tanimoto_row, include_groups=False)
    .reset_index()
)

# Similarity thresholds on the best Tanimoto of each task.
best_tanimoto = df_best["tanimoto"]
df_best = df_best.assign(
    perfect=best_tanimoto > 0.999,
    excellent=best_tanimoto > 0.850,
    close=best_tanimoto > 0.675,
    # Rows without a missing-atom count get the placeholder value 1000.
    missing_atoms=np.where(df_best["missing_atoms"].isna(), 1000, df_best["missing_atoms"]),
)

# Bin the (already filled) missing-atom counts into right-closed intervals.
# NOTE(review): the lowest bin is (1, 5], so a count of exactly 1 falls outside
# every bin and becomes NaN -- confirm that such rows cannot occur.
atom_bins = pd.cut(df_best["missing_atoms"], [1, 5, 10, 20, 30, 45, 90, 1100])

# The last interval, which contains the 1000 placeholder, is labelled "(all)".
placeholder_interval = pd.Interval(90, 1100, closed="right")
bin_labels = {}
for interval in atom_bins.cat.categories:
    bin_labels[interval] = "(all)" if interval == placeholder_interval else interval

df_best["missing_atoms_bin"] = atom_bins.cat.rename_categories(bin_labels)

In [None]:
# Total atom count per challenge (given + missing). Rows where the sum is NaN
# are dropped, as are duplicate (dataset, challenge, total_atoms) rows.
challenge_atoms = df_best[["dataset", "challenge"]].copy()
challenge_atoms["total_atoms"] = df_best["given_atoms"] + df_best["missing_atoms"]
challenge_atoms = challenge_atoms.drop_duplicates().dropna()

# Attach the totals (merge on the shared columns) and, for rows whose
# missing_atoms exceeds 200 (this includes the 1000 placeholder), use the
# challenge's total atom count instead.
with_totals = pd.merge(df_best, challenge_atoms)
with_totals["missing_atoms"] = np.where(
    with_totals["missing_atoms"] > 200,
    with_totals["total_atoms"],
    with_totals["missing_atoms"],
)
pretrained_rows = with_totals.loc[with_totals["model"].eq("seismiq_pretrained")]

# Figure 3d: Tanimoto against missing atoms with a LOWESS trend line,
# one panel per dataset.
g = sns.lmplot(
    pretrained_rows,
    x="missing_atoms",
    y="tanimoto",
    col="dataset",
    col_order=["MassSpecGym", "CASMI 2016", "CASMI 2017", "CASMI 2022"],
    col_wrap=4,
    height=2.25,
    aspect=1.25,
    markers=".",
    lowess=True,
    scatter_kws={"alpha": 0.2},
    line_kws={"color": "red"},
)
g.figure.set_dpi(600)
g.set(ylabel="Tanimoto", xlabel="Missing atoms")

# Default facet titles read "dataset = <name>"; keep only the part after "=".
for panel_ax in g.axes_dict.values():
    panel_ax.set_title(panel_ax.get_title().rpartition("=")[2].strip())

g.figure.savefig(base_folder / "figures" / "figure_3d.png", dpi=600)

In [None]:
# Figure 3b: distribution of the number of missing atoms for the pretrained
# model, as a log-scaled histogram (top) and a boxen plot (bottom).
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(3.25, 2.75), height_ratios=[3, 1])

# Rows with missing_atoms above 400 are excluded; this removes the rows that
# carry the 1000 placeholder assigned when df_best was built.
fdf = df_best.loc[lambda d: d["missing_atoms"].le(400) & d["model"].eq("seismiq_pretrained")]

sns.histplot(
    fdf["missing_atoms"],
    bins=30,
    ax=ax1,
)
ax1.set_yscale("log")
ax1.set_xlabel("")
ax1.set_ylabel("")
ax1.set_xticks([])

sns.boxenplot(fdf["missing_atoms"], ax=ax2, orient="h")
ax2.set_xlabel("Missing Atoms")

# Operate on the figure created in this cell. Using the FacetGrid `g` from
# another cell here would save that cell's figure under this file name.
sns.despine(fig=fig)
fig.set_dpi(600)
fig.savefig(base_folder / "figures" / "figure_3b.png")

In [None]:
def plot_fragmentation_results(
    data, fig, ax_fragmentation, fragmentation_kws, ax_de_novo, de_novo_kws, legend_kws,
):
    """Draw two seaborn point plots side by side and hoist the legend to ``fig``.

    Rows of ``data`` whose ``missing_atoms_bin`` equals ``'(all)'`` are drawn on
    ``ax_de_novo``; every other row is drawn on ``ax_fragmentation``. Unused
    categories are removed from each subset so that each axis only shows its
    own x positions.

    Args:
        data: DataFrame with a categorical ``missing_atoms_bin`` column.
        fig: Figure that receives the shared legend.
        ax_fragmentation: Axes for the rows that are not in the ``'(all)'`` bin.
        fragmentation_kws: Extra keyword arguments for ``sns.pointplot`` on
            ``ax_fragmentation``.
        ax_de_novo: Axes for the rows in the ``'(all)'`` bin.
        de_novo_kws: Extra keyword arguments for ``sns.pointplot`` on
            ``ax_de_novo``.
        legend_kws: Keyword arguments forwarded to ``fig.legend``.
    """
    bin_column = "missing_atoms_bin"
    all_label = '(all)'

    # Subset without the "(all)" bin, with that category removed.
    partial_rows = data.loc[data[bin_column].ne(all_label)]
    df_fragmentation = partial_rows.assign(
        **{bin_column: partial_rows[bin_column].cat.remove_categories([all_label])}
    )

    # Subset with only the "(all)" bin, with every other category removed.
    complete_rows = data.loc[data[bin_column].eq(all_label)]
    other_categories = [
        category
        for category in complete_rows[bin_column].cat.categories
        if str(category) != all_label
    ]
    df_de_novo = complete_rows.assign(
        **{bin_column: complete_rows[bin_column].cat.remove_categories(other_categories)}
    )

    # Left panel: draws its own legend, which is moved to the figure below.
    sns.pointplot(
        data=df_fragmentation,
        ax=ax_fragmentation,
        legend=True,
        **fragmentation_kws,
    )
    ax_fragmentation.set_xticklabels(ax_fragmentation.get_xticklabels(), rotation=-45)
    ax_fragmentation.set_xlabel("Missing atoms")
    ax_fragmentation.set_ylabel("Accuracy")
    sns.despine(ax=ax_fragmentation)

    # Right panel: no legend, no y axis decorations.
    sns.pointplot(
        data=df_de_novo,
        ax=ax_de_novo,
        legend=False,
        **de_novo_kws
    )
    sns.despine(ax=ax_de_novo, left=True, top=True, right=True)
    ax_de_novo.set_xlabel("")
    ax_de_novo.set_yticks([])
    ax_de_novo.set_ylabel("")
    ax_de_novo.set_xticklabels(ax_de_novo.get_xticklabels(), rotation=-45)

    # Replace the axes-level legend (if one was drawn) by a figure-level one.
    axes_legend = ax_fragmentation.legend_
    if axes_legend is None:
        return
    handles, labels = ax_fragmentation.get_legend_handles_labels()
    axes_legend.remove()
    fig.legend(handles, labels, **legend_kws)


# Figure 3c: accuracy ("perfect") of the pretrained model per missing-atom bin
# (wide left panel) next to the "(all)" bin (narrow right panel).
fig, (ax1, ax2) = plt.subplots(
    nrows=1,
    ncols=2,
    figsize=(4.76, 2.59),
    width_ratios=[5, 1],
    dpi=300,
)

# Fixed colour per dataset so both panels agree.
palette = {
    "MassSpecGym": "C0",
    "CASMI 2016": "C1",
    "CASMI 2017": "C2",
    "CASMI 2022": "C3",
}

# Point-plot arguments shared by both panels; the left panel additionally
# uses dotted connecting lines.
shared_point_kws = {
    "x": "missing_atoms_bin",
    "y": "perfect",
    "hue": "dataset",
    "palette": palette,
    "dodge": True,
}
pretrained_best = df_best.loc[df_best["model"].eq("seismiq_pretrained")]

plot_fragmentation_results(
    pretrained_best,
    fig=fig,
    ax_fragmentation=ax1,
    fragmentation_kws={**shared_point_kws, "linestyles": ":"},
    ax_de_novo=ax2,
    de_novo_kws=dict(shared_point_kws),
    legend_kws={
        "loc": "upper center",
        "bbox_to_anchor": (0.55, 1.05),
        "ncol": 2,
        "frameon": False,
    },
)

# Leave the top 10% of the figure free for the figure-level legend.
fig.tight_layout(rect=(0, 0, 1, 0.9))
fig.savefig(base_folder / "figures" / "figure_3c.png")

In [None]:
# Figure S3: accuracy per missing-atom bin for each model, one panel per dataset.
model_display_names = {
    "seismiq_pretrained": "Pretrained",
    "seismiq_finetuned_casmi": "CASMI Finetuned",
}
df_best_labelled = df_best.replace({"model": model_display_names})

g = sns.catplot(
    df_best_labelled,
    kind="point",
    x="missing_atoms_bin",
    y="perfect",
    hue="model",
    col="dataset",
    dodge=True,
    height=2.25,
    aspect=1.5,
)
g.set_xticklabels(rotation=-45)
g.set(ylabel="Accuracy", xlabel="Missing atoms")
g.legend.set_title("Model")

# Use the bare dataset name as the panel title.
for dataset_name, panel_ax in g.axes_dict.items():
    panel_ax.set_title(dataset_name)

g.figure.savefig(base_folder / "figures" / "figure_s3.png")