In [None]:
import os
import pickle

import numpy as np
import pandas as pd
import tqdm

from rxnutils.routes.readers import SynthesisRoute
from rxnutils.routes.comparison import simple_route_similarity

import seaborn as sns
import matplotlib.pylab as plt

In [None]:
# Root directory: one sub-directory of result tables per run. Entries whose
# name ends in "pickle" live here too and are not result directories.
root = "multistep/last_epoch/output"

# Run name -> path of its sub-directory. Anything that is not a directory
# (stray or hidden files) is skipped as well; passing such an entry to
# os.listdir below would raise NotADirectoryError.
run_dirs: dict[str, str] = {
    k: f"{root}/{k}"
    for k in os.listdir(root)
    if not k.endswith("pickle") and os.path.isdir(f"{root}/{k}")
}

# Run name -> the ".json.gz" result files found directly inside its directory.
# The listing order is deliberately left untouched (not sorted).
files = {
    k: [
        f"{run_dir}/{f}"
        for f in os.listdir(run_dir)
        if os.path.isfile(os.path.join(run_dir, f)) and f.endswith(".json.gz")
    ]
    for k, run_dir in run_dirs.items()
}

# One table per run: the per-file tables are concatenated in listing order.
dfs: dict[str, pd.DataFrame] = {
    k: pd.concat(pd.read_json(f, orient="table") for f in fs)
    for k, fs in files.items()
}

In [None]:
# Attach the pickled trees of every run to its table as SynthesisRoute objects.
# NOTE(review): pickle.load can execute arbitrary code from the file - only
# load pickles that were produced by a trusted pipeline.
for k in list(dfs):
    with open(f"{root}/{k}_trees.pickle", "rb") as fileobj:
        trees = pickle.load(fileobj)
    route_objs = [SynthesisRoute(tree) for tree in trees]
    dfs[k] = dfs[k].assign(route=route_objs)
    # Trees are matched to rows purely by position, so verify that the
    # top-level "smiles" entry of each tree equals the target in the same row.
    tree_smiles = [r.reaction_tree["smiles"] for r in dfs[k].route]
    assert dfs[k].target.tolist() == tree_smiles

In [None]:
# Targets that appear in the table of every single run.
target_sets = [set(frame.target) for frame in dfs.values()]
targets = set.intersection(*target_sets) if target_sets else set()
len(targets)

In [None]:
# Restrict every run's table to the rows whose target is shared by all runs.
for key, frame in list(dfs.items()):
    dfs[key] = frame.loc[frame.target.isin(targets)]

In [None]:
# Per-run lookups keyed by target: the route object and the is_solved flag.
target2route = {}
target2solved = {}
for key, frame in dfs.items():
    target2route[key] = dict(zip(frame.target, frame.route))
    target2solved[key] = dict(zip(frame.target, frame.is_solved))

In [None]:
# For every shared target, compare the routes of all runs with each other.
# Rows and columns belonging to runs that did not solve the target are set to
# NaN so that they can be ignored when the matrices are averaged.
mats = []
run_keys = list(dfs.keys())
for target in tqdm.tqdm(targets):
    routes = [target2route[key][target] for key in run_keys]
    mask = [not target2solved[key][target] for key in run_keys]
    dm = simple_route_similarity(routes)
    for axis_index in ((mask, slice(None)), (slice(None), mask)):
        dm[axis_index] = np.nan
    mats.append(dm)

In [None]:
# Run name -> label used on the plots. The insertion order of this dict is
# also the display order: `mapping` ranks a run name by its position here.
# Several run names intentionally share one label (the "_50k" variants and
# their un-suffixed counterparts).
name_map = {
    "multistep_template": "Template-based",
    "multistep_multi": "Multi-expansion",

    "multistep_scratch_original_50k": "RandomInit-Original",
    "multistep_scratch_50k": "RandomInit-baseline",
    "multistep_scratch_acc_50k": "RandomInit-OptAcc",
    "multistep_scratch_rcs_50k": "RandomInit-OptRCS",

    "multistep_chemformer_original_50k": "FT-Zinc-Original",
    "multistep_chemformer_50k": "FT-Zinc-baseline",
    "multistep_pretrain_acc_50k": "KD-Zinc-OptAcc",
    "multistep_pretrain_rcs_50k": "KD-Zinc-OptRCS",
    "multistep_chemformer_acc_50k": "KD-Chemformer-OptAcc",
    "multistep_chemformer_rcs_50k": "KD-Chemformer-OptRCS",

    "multistep_scratch": "RandomInit-baseline",
    "multistep_scratch_acc": "RandomInit-OptAcc",
    "multistep_scratch_rcs": "RandomInit-OptRCS",
    "multistep_chemformer": "FT-Zinc-baseline",
}


def mapping(y):
    """Translate run names into their position within ``name_map``.

    Parameters
    ----------
    y : iterable of str
        Run names; each must be a key of ``name_map``.

    Returns
    -------
    map
        Lazy iterator yielding, for every name in ``y``, the zero-based
        insertion position of that name in ``name_map``.

    Raises
    ------
    KeyError
        While iterating the result, for a name that is not in ``name_map``.
    """
    # Build the name -> position table once per call instead of once per
    # element; name_map is read at call time so later edits are picked up.
    position = {key: idx for idx, key in enumerate(name_map)}
    return map(position.__getitem__, y)

In [None]:
# Element-wise mean over all targets, ignoring the NaN entries that mark
# unsolved runs.
mat_mean = np.nanmean(mats, axis=0)
run_keys = list(dfs.keys())
df = pd.DataFrame(mat_mean, columns=run_keys, index=run_keys)


def _display_rank(labels):
    """Sort key: position of each run name in ``name_map``."""
    # Rank the labels that sort_index actually passes in, rather than closing
    # over df.index, so the row and the column sort are each keyed on their
    # own axis.
    return list(mapping(labels))


# Order rows and columns as listed in name_map, then switch to plot labels.
df = (
    df.sort_index(key=_display_rank)
    .sort_index(key=_display_rank, axis=1)
    .rename(index=name_map, columns=name_map)
)
sns.heatmap(df)

In [None]:
def insert_row(df, position, row_name):
    """Return a copy of ``df`` with an all-NaN row inserted.

    The new row is labelled ``row_name`` and placed so that it ends up at
    integer position ``position``; ``df`` itself is not modified.
    """
    n_cols = df.shape[1]
    blank = pd.DataFrame([[np.nan] * n_cols], columns=df.columns, index=[row_name])
    above = df.iloc[:position]
    below = df.iloc[position:]
    return pd.concat([above, blank, below])

def insert_col(df, position, col_name):
    """Return a copy of ``df`` with an all-NaN column inserted.

    The new column is labelled ``col_name`` and placed so that it ends up at
    integer position ``position``; ``df`` itself is not modified.
    """
    n_rows = df.shape[0]
    blank = pd.DataFrame({col_name: [np.nan] * n_rows}, index=df.index)
    left = df.iloc[:, :position]
    right = df.iloc[:, position:]
    return pd.concat([left, blank, right], axis=1)

def insert_both(df, position, name):
    """Insert an all-NaN row and an all-NaN column, both labelled ``name``.

    The row is added first, then the column, each at integer position
    ``position``. Returns the new frame.
    """
    with_row = insert_row(df, position, name)
    return insert_col(with_row, position, name)

# Work on a copy so the plain mean matrix in `df` stays available.
df_temp = df.copy()

# Insert an empty (all-NaN, ""-labelled) row and column in front of every pair
# of entries so the heatmap shows visual gaps between the pairs.
for i in range(len(df) // 2):
    # Each earlier step added one separator in front of two data rows, hence
    # the stride of three.
    pos = i * 3
    df_temp = insert_both(df_temp, pos, "")

sns.set_theme(
    style="whitegrid",
    font="serif",
    # font_scale=1.5,
    context="paper"
)
sns.set_palette("colorblind")
palette = sns.color_palette("colorblind")
# The former bare `sns.despine()` call was dropped: executed before any figure
# exists, it acts on an implicitly created empty figure (a stray blank output)
# and never touches the heatmap drawn below.

# Hide the diagonal and the upper triangle of the matrix.
mask = np.triu(np.ones_like(df_temp, dtype=bool))
fig = plt.figure(dpi=360)
ax = fig.gca()
# Draw on the axes created above instead of relying on the implicit
# "current axes".
ax = sns.heatmap(df_temp, mask=mask, linewidths=0.5, ax=ax)
ax.grid(False)