In [5]:
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import gaussian_kde
import sys

# ============================================================
# PATH TO THE RUN
# ============================================================

run_path = Path("/Users/franco/Desktop/CoSyBio/Multi-Dim/TM-Landscape/runs/AK1_test02")  # <<< CHANGE
fig_dir = run_path / "figs_fixed"
fig_dir.mkdir(exist_ok=True)


# ============================================================
# LOAD PARAMETERS FROM THE RUN
# ============================================================

params = pd.read_csv(run_path / "data" / "params.tsv", sep="\t", header=None, index_col=0)[1].to_dict()

sequence = params["sequence"]
source = params["source"]
size = params["size"]

print("Loaded run parameters:")
print("  sequence:", sequence)
print("  source:", source)
print("  size:", size)


# ============================================================
# IMPORT PROJECT MODULES TO LOAD Z_ref
# ============================================================

# replicate script logic
THIS_NOTEBOOK = Path().resolve()
PROJECT_ROOT = THIS_NOTEBOOK.parents[2]     # same as script: parents[2]
sys.path.append(str(PROJECT_ROOT))

from scripts.utils.load_tmvec_embeddings import load_tmvec_embeddings


# ============================================================
# LOAD ORIGINAL TM-VEC MANIFOLD (Z_ref)
# ============================================================

print("Loading TM-Vec reference manifold...")
Z_ref, metadata = load_tmvec_embeddings(source=source, size=size, frac=1.0)
print("Z_ref shape:", Z_ref.shape)


# ============================================================
# LOAD RUN DATA (variants_summary and TSNE)
# ============================================================

df = pd.read_csv(run_path / "data" / "variants_summary.csv")
df_tsne = pd.read_csv(run_path / "data" / "tsne_2d_variants.csv")

# also load embeddings.npz to get query embeddings
emb_npz = np.load(run_path / "data" / "embeddings.npz")

Z_WT = emb_npz["WT"].reshape(1, -1)
Z_masked = emb_npz["masked"]
Z_ala = emb_npz["ala"]
Z_mix = emb_npz["mix"]
Z_del = emb_npz["dele"]

query_embeddings = np.vstack([Z_WT, Z_masked, Z_ala, Z_mix, Z_del])
print("Query embeddings shape:", query_embeddings.shape)


# ============================================================
# GLOBAL TSNE LIMITS (BASED ON FULL Z_ref + QUERY EMBEDDINGS)
# ============================================================

Z_all = np.vstack([Z_ref, query_embeddings])

# We DO NOT recompute t-SNE â†’ we compute the TRUE absolute coordinate range
xmin, xmax = Z_all.min(), Z_all.max()   # This is WRONG: Z is 512-dim
# Instead: use TSNE projection MIN/MAX of all points that participated in TSNE.

# Real fix:
tsne_xmin = df_tsne["TSNE1"].min()
tsne_xmax = df_tsne["TSNE1"].max()
tsne_ymin = df_tsne["TSNE2"].min()
tsne_ymax = df_tsne["TSNE2"].max()

# pad slightly
pad_x = 0.02 * (tsne_xmax - tsne_xmin)
pad_y = 0.02 * (tsne_ymax - tsne_ymin)

tsne_xmin -= pad_x
tsne_xmax += pad_x
tsne_ymin -= pad_y
tsne_ymax += pad_y

print("t-SNE limits:")
print("  X:", tsne_xmin, tsne_xmax)
print("  Y:", tsne_ymin, tsne_ymax)


# ============================================================
# ORIGINAL PALETTE
# ============================================================

PALETTE = {
    "WT": "#1f77b4",
    "masked": "#ff7f0e",
    "ala": "#2ca02c",
    "mix": "#9467bd",
    "del": "#e377c2",
    "neighbor": "#111111",
    "background": "#d3d3d3",
}


# ============================================================
# ENERGY KERNEL
# ============================================================

def smooth_energy_field(x, y, E, grid_res=200, bw=0.25):
    xi = np.linspace(tsne_xmin, tsne_xmax, grid_res)
    yi = np.linspace(tsne_ymin, tsne_ymax, grid_res)
    xi, yi = np.meshgrid(xi, yi)

    weights = np.exp(-E)
    weights /= np.sum(weights)

    kde = gaussian_kde(np.vstack([x, y]), weights=weights, bw_method=bw)
    p = kde(np.vstack([xi.ravel(), yi.ravel()])).reshape(xi.shape)
    p = np.clip(p, 1e-12, None)

    E_eff = -np.log(p)
    E_eff -= E_eff.min()
    return xi, yi, E_eff


# ============================================================
# FIGURES (all fixed limits)
# ============================================================

# --- 1) TM histogram ---
plt.figure(figsize=(6, 4))
sns.histplot(df, x="Approx_TM_NN", hue="Variant_Type",
             element="step", stat="density", common_norm=False, palette=PALETTE)
plt.xlim(0, 1)
plt.savefig(fig_dir / "tm_score_hist_FIXED.png", dpi=300)
plt.close()

# --- 2) Distance to WT ---
plt.figure(figsize=(6, 4))
sns.boxplot(df[df["Variant_Type"].isin(["masked", "ala", "mix", "del"])],
            x="Variant_Type", y="Cosine_Distance_to_WT", palette=PALETTE)
plt.ylim(0, 1)
plt.savefig(fig_dir / "distance_boxplot_FIXED.png", dpi=300)
plt.close()

# --- 3) TM vs Density ---
plt.figure(figsize=(6, 4))
sns.scatterplot(df, x="Approx_TM_NN", y="Log_Density",
                hue="Variant_Type", palette=PALETTE)
plt.xlim(0, 1)
plt.savefig(fig_dir / "tm_vs_density_FIXED.png", dpi=300)
plt.close()

# --- 4) t-SNE ---
plt.figure(figsize=(7, 6))
for label, sub in df_tsne.groupby("Label"):
    color = PALETTE.get(label, "#cccccc")
    alpha = 0.25 if label == "background" else 0.75
    size = 10 if label == "background" else 20
    plt.scatter(sub["TSNE1"], sub["TSNE2"], s=size, alpha=alpha, color=color, label=label)

plt.xlim(tsne_xmin, tsne_xmax)
plt.ylim(tsne_ymin, tsne_ymax)
plt.legend()
plt.savefig(fig_dir / "tsne_2d_variants_FIXED.png", dpi=300)
plt.close()

# --- 5) 2D energy landscape ---
df_var = df_tsne[df_tsne["Label"].isin(["WT", "masked", "ala", "mix", "del"])].copy()
df_var["Energy"] = df["Energy"].values

x = df_var["TSNE1"]
y = df_var["TSNE2"]
E = df_var["Energy"]

X2, Y2, F2 = smooth_energy_field(x, y, E)

plt.figure(figsize=(8, 7))
plt.contourf(X2, Y2, F2, levels=14, cmap="turbo")
plt.scatter(x, y, s=10, c="black", alpha=0.3)
plt.xlim(tsne_xmin, tsne_xmax)
plt.ylim(tsne_ymin, tsne_ymax)
plt.colorbar()
plt.savefig(fig_dir / "energy_landscape_2d_FIXED.png", dpi=300)
plt.close()

print("Done. Figures saved to:", fig_dir)


Loaded run parameters:
  sequence: MEEKLKKTKIIFVVGGPGSGKGTQCEKIVQKYGYTHLSTGDLLRSEVSSGSARGKKLSEIMEKGQLVPLETVLDMLRDAMVAKVNTSKGFLIDGYPREVQQGEEFERRIGQPTLLLYVDAGPETMTQRLLKRGETSGRVDDNEETIKKRLETYYKATEPVIAFYEKRGIVRKVNAEGSVDSVFSQVCTHLDALK
  source: cath
  size: large


IndexError: 2