In [None]:
# Clone OrthoLM and ESM (run once per runtime)
# NOTE(review): this cell is not guarded against being run twice in the same
# runtime (the clones and the chdir below would be repeated) - TODO confirm/guard.
!git clone https://github.com/ThomasGTHB/OrthoLM.git -q
!git clone https://github.com/facebookresearch/esm.git -q

import os
os.chdir("OrthoLM")  # safer than %cd for Kaggle too
os.makedirs("Results", exist_ok=True)

# Install the third-party packages (versions are not pinned).
!pip -q install biopython matplotlib-venn fair-esm seaborn tqdm joblib


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m3.3/3.3 MB[0m [31m109.1 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m70.7 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/93.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# ===== Imports =====
import os, time, json, csv, copy, random, itertools, pickle
import numpy as np
import torch, torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

import esm
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA  # (unused here but kept for compatibility)
from sklearn.metrics import adjusted_mutual_info_score
import joblib


In [None]:
# Mount Google Drive at /content/drive (Colab only); RESULTS_ROOT in the config
# cell points below this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# ===== Config =====
# Folder holding "<species>.fasta" and "<species>_emb_<model>/" (see load_species_matrix).
DATA_ROOT   = "/content/OrthoLM/Datasets/"

TRAIN_SPEC  = ["drer_xtro", "mmus_hsap"]   # pooled train
TEST_SPEC   = "pfal_pber"                  # held-out test

MODEL_NAME  = "esm2_t36_3B_UR50D"
EMB_LAYER   = 36                           # key into each file's 'mean_representations'

# AE
LATENT_DIM  = 128
LAMBDA_L1   = 7.5e-4   # weight of the mean(|z|) penalty added to the reconstruction loss
NOISE_SIGMA = 0.0      # std of Gaussian input noise; the training cell sets it to 0.0 again
DROPOUT_P   = 0.10
LR          = 1e-3
WD          = 1e-5     # passed to Adam as weight_decay
BATCH_SIZE  = 512
EPOCHS      = 120      # upper bound; training stops early after PATIENCE stale epochs
PATIENCE    = 10

# Reconstruction loss used by the AE training cell: "mse" or "cosine"
LOSS_TYPE: str = "cosine"            # current run: cosine reconstruction loss

# KMeans
KMEANS_NINIT   = 50
KMEANS_MAXITER = 500
KMEANS_ALGO    = "elkan"
RAND_STATE     = 0

# Variant: also evaluate a Top-k sparse version of latents
# NOTE(review): the top-k code paths in the later cells are commented out, so
# these two values currently only affect the experiment name, a print message
# and the saved metadata.
USE_TOPK = False
TOPK     = 32

# Reproducibility & device
SEED = 0
np.random.seed(SEED); random.seed(SEED)
torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

# ===== Experiment folder =====
# One directory per run, named after train/test species, model, AE and KMeans
# hyperparameters and a timestamp. Every component appears exactly once: the
# name previously repeated the model/hyperparameter/timestamp segment, which
# doubled the length of the directory name without adding information.
ts = time.strftime("%Y%m%d-%H%M%S")
loss_tag = f"loss-{LOSS_TYPE}"
topk_tag = f"topk{TOPK}" if USE_TOPK else "topk0"
exp_name = (
    f"TRAIN[{'+'.join(TRAIN_SPEC)}]_TEST[{TEST_SPEC}]_"
    f"{MODEL_NAME}_AE{LATENT_DIM}_{loss_tag}_L1-{LAMBDA_L1}_SIG{NOISE_SIGMA}_DO{DROPOUT_P}_"
    f"ninit{KMEANS_NINIT}_iter{KMEANS_MAXITER}_{topk_tag}_{ts}"
)
# NOTE(review): the folder name hard-codes the L1 weight ("lamdaL1_7.5e-4");
# keep it in sync with LAMBDA_L1 by hand when running another ablation.
RESULTS_ROOT = "/content/drive/MyDrive/ae_ablations/lamdaL1_7.5e-4"
EXP_DIR = os.path.join(RESULTS_ROOT, exp_name)
os.makedirs(EXP_DIR, exist_ok=True)
print("Saving to:", EXP_DIR)


device: cuda
Saving to: /content/drive/MyDrive/ae_ablations/lamdaL1_7.5e-4/TRAIN[drer_xtro+mmus_hsap]_TEST[pfal_pber]_esm2_t36_3B_UR50D_AE128_L1-0.00075_SIG0.0_DO0.1_ninit50_iter500_topk0_20250913-144019_esm2_t36_3B_UR50D_AE128_loss-cosine_L1-0.00075_SIG0.0_DO0.1_ninit50_iter500_topk0_20250913-144019


In [None]:
def load_species_matrix(species, model_name, emb_layer, data_root):
    """Load the per-protein mean embeddings of one species dataset.

    Reads the headers of ``<data_root>/<species>.fasta`` and, for every header
    that has a matching ``<data_root>/<species>_emb_<model_name>/<header>.pt``
    file, takes ``['mean_representations'][emb_layer]`` from that file.
    Headers without an embedding file are skipped silently.

    Args:
        species: dataset name used for the fasta file and the embedding folder.
        model_name: suffix of the embedding folder.
        emb_layer: key selecting the layer inside 'mean_representations'.
        data_root: directory containing the fasta file and embedding folder.

    Returns:
       Xs  : (N, D) float32 mean embeddings
       ids : list[str] fasta headers (order matches Xs rows)
       meta: list[list[str]] split by '|' → [species_code, protein_code, ..., OG]

    Raises:
        RuntimeError: if no header of the fasta file has an embedding file.
    """
    fasta_path = os.path.join(data_root, f"{species}.fasta")
    emb_dir    = os.path.join(data_root, f"{species}_emb_{model_name}")
    Xs, ids, meta = [], [], []
    for header, _seq in esm.data.read_fasta(fasta_path):
        pt = os.path.join(emb_dir, f"{header}.pt")
        if not os.path.isfile(pt):
            continue
        # Always materialize on CPU: the tensors are converted to NumPy below,
        # which fails for tensors restored onto a GPU (and loading GPU-saved
        # files fails outright on a CPU-only runtime without map_location).
        embs = torch.load(pt, map_location="cpu")
        Xs.append(embs['mean_representations'][emb_layer])
        ids.append(header)
        meta.append(header.split('|'))
    if not Xs:
        # torch.stack([]) would raise an opaque error; say what is missing.
        raise RuntimeError(
            f"No embeddings found for '{species}': no '<header>.pt' file in "
            f"{emb_dir} matches a header of {fasta_path}"
        )
    Xs = torch.stack(Xs, dim=0).numpy().astype(np.float32)
    return Xs, ids, meta

def count_OGs(prot_meta):
    """Return how many distinct values appear at index 4 (the OG id) of the records."""
    distinct_ogs = {record[4] for record in prot_meta}
    return len(distinct_ogs)

def save_json(obj, path):
    """Write `obj` to `path` as JSON indented by two spaces (overwrites the file)."""
    with open(path, mode="w") as handle:
        json.dump(obj, fp=handle, indent=2)


In [None]:
# ----- Pooled train -----
# Load every training dataset, save its header order next to the results, and
# stack all embeddings into one matrix.
# NOTE(review): meta_train_list is filled here but not read by the visible
# later cells.
Xs_train_list, meta_train_list = [], []
for sp in TRAIN_SPEC:
    Xs_sp, ids_sp, meta_sp = load_species_matrix(sp, MODEL_NAME, EMB_LAYER, DATA_ROOT)
    # object-dtype array: reading it back needs np.load(..., allow_pickle=True)
    np.save(os.path.join(EXP_DIR, f"{sp}_ids.npy"), np.array(ids_sp, dtype=object))
    Xs_train_list.append(Xs_sp)
    meta_train_list += meta_sp
    print(f"[train] {sp}: X={Xs_sp.shape}, OGs={count_OGs(meta_sp)}")

Xs_train_pool = np.concatenate(Xs_train_list, axis=0).astype(np.float32)
print("[train pool]", Xs_train_pool.shape, "from", TRAIN_SPEC)

# Fit scaler on pooled train (save it); the same scaler is applied to the test
# embeddings later, so no test statistics enter the standardization.
scaler = StandardScaler().fit(Xs_train_pool)
joblib.dump(scaler, os.path.join(EXP_DIR, "scaler.joblib"))

# ----- Test -----
# Held-out dataset: only loaded here, never used to fit the scaler.
Xs_test, ids_test, meta_test = load_species_matrix(TEST_SPEC, MODEL_NAME, EMB_LAYER, DATA_ROOT)
np.save(os.path.join(EXP_DIR, f"{TEST_SPEC}_ids.npy"), np.array(ids_test, dtype=object))
print(f"[test] {TEST_SPEC}: X={Xs_test.shape}, OGs={count_OGs(meta_test)}")


[train] drer_xtro: X=(61670, 2560), OGs=18374
[train] mmus_hsap: X=(45087, 2560), OGs=16659
[train pool] (106757, 2560) from ['drer_xtro', 'mmus_hsap']
[test] pfal_pber: X=(10263, 2560), OGs=5315


In [None]:
def saving_from_kmeans(Xs_train_pca, prot_names_and_group_train, kmeans):
    """Bundle a fitted KMeans with the data it labelled.

    Returns a 6-item list: cluster count, sample count (rows of
    `Xs_train_pca`), the protein records, the KMeans labels, the KMeans object
    itself, and a string naming the first five positions.
    """
    legend = "n_clusters, n_samples, prot_names_and_group_train, X_labels, kmeans"
    bundle = [
        kmeans.n_clusters,
        Xs_train_pca.shape[0],
        prot_names_and_group_train,
        kmeans.labels_,
        kmeans,
        legend,
    ]
    return bundle

def measure_pairwise_performance(saved_results, Xs_train_pca):
    """Extract candidate ortholog pairs/groups from a KMeans clustering.

    Args:
        saved_results: list produced by `saving_from_kmeans`
            ([n_clusters, n_samples, protein records, labels, kmeans, ...]).
            Field 0 of each protein record is used as its species.
        Xs_train_pca: the feature matrix the KMeans was fitted on; it is passed
            to `kmeans.transform` to rank the members of each cluster.

    Returns:
        [naive, distance_based, one_to_one] where
        - naive: every cross-species pair of records inside a cluster,
        - distance_based: per cluster, the member closest to the centroid
          paired with the next-closest member of a different species,
        - one_to_one: the full member list of clusters holding exactly one
          record per species for all species.
        Members are ordered by increasing distance to their centroid.
        Clusters with fewer than two members contribute nothing.
    """
    n_clusters = saved_results[0]
    prot_names_and_group_train = saved_results[2]
    X_labels   = saved_results[3]
    kmeans     = saved_results[4]

    # Squared output of kmeans.transform: one row per sample, one column per cluster.
    X_dist = kmeans.transform(Xs_train_pca)**2
    orthologs_naiveSearch = []
    orthologs_distanceBasedSearch = []
    orthologs_1_to_1 = []

    n_species_total = len({prot[0] for prot in prot_names_and_group_train})

    # Group sample indices by cluster in a single pass instead of rescanning
    # all labels once per cluster. Indices stay in ascending order per cluster.
    members_by_cluster = [[] for _ in range(n_clusters)]
    for i_sample, label in enumerate(X_labels):
        if 0 <= label < n_clusters:
            members_by_cluster[label].append(i_sample)

    for cluster, ind_fromCluster in enumerate(members_by_cluster):
        # Empty and singleton clusters cannot yield a pair (an empty cluster
        # used to raise IndexError further down).
        if len(ind_fromCluster) < 2:
            continue

        ind_sorted = np.argsort(X_dist[ind_fromCluster, cluster])
        all_prots = [prot_names_and_group_train[ind_fromCluster[i]] for i in ind_sorted]
        all_specs = [prot[0] for prot in all_prots]

        # Naive: all cross-species pairs
        if len(set(all_specs)) > 1:
            for i1 in range(len(all_specs) - 1):
                for i2 in range(i1 + 1, len(all_specs)):
                    if all_specs[i1] != all_specs[i2]:
                        orthologs_naiveSearch.append([all_prots[i1], all_prots[i2]])

        # Distance-based: top of cluster + nearest with different species
        partner = next(
            (j for j in range(1, len(all_specs)) if all_specs[j] != all_specs[0]),
            None,
        )
        if partner is not None:
            orthologs_distanceBasedSearch.append([all_prots[0], all_prots[partner]])

        # 1-to-1: one per species in the cluster
        if len(all_specs) == len(set(all_specs)) == n_species_total:
            orthologs_1_to_1.append(all_prots)

    return [orthologs_naiveSearch, orthologs_distanceBasedSearch, orthologs_1_to_1]

def measure_group_performance(saved_results):
    """Score a clustering against the orthologous groups (field 4 of each record).

    Args:
        saved_results: list produced by `saving_from_kmeans`
            ([n_clusters, n_samples, protein records, labels, ...]).

    Returns:
        [family_completeness, AMI, exact_over_total] where
        - family_completeness: sum over groups of the largest number of group
          members found in a single cluster, divided by n_samples,
        - AMI: adjusted mutual information between group ids and cluster labels,
        - exact_over_total: number of groups that sit in exactly one cluster
          which contains no other group, divided by n_clusters.
    """
    n_clusters = saved_results[0]
    n_samples  = saved_results[1]
    prot_names_and_group_train = saved_results[2]
    X_labels   = saved_results[3]

    sample_groups = [prot[4] for prot in prot_names_and_group_train]
    # Distinct groups in first-seen order (deterministic row order), plus a
    # dict for O(1) row lookup instead of list.index inside the counting loop.
    list_all_groups = list(dict.fromkeys(sample_groups))
    group_row = {group: i_row for i_row, group in enumerate(list_all_groups)}

    # Count matrix: rows = groups, columns = clusters.
    list_OG_count_in_cluster = np.zeros((len(list_all_groups), n_clusters))
    for i_label in range(n_samples):
        list_OG_count_in_cluster[group_row[sample_groups[i_label]], X_labels[i_label]] += 1
    if not (np.sum(list_OG_count_in_cluster) == n_samples):
        print("Error: missing some sequences in the count matrix")

    # Family completeness
    family_complet_stat = float(list_OG_count_in_cluster.max(axis=1).sum()) / n_samples

    # AMI
    AMI = adjusted_mutual_info_score(sample_groups, X_labels)

    # % exact matches: the group occupies a single cluster, and that cluster
    # holds a single group.
    occupied = list_OG_count_in_cluster != 0
    clusters_per_group = occupied.sum(axis=1)
    groups_per_cluster = occupied.sum(axis=0)
    single_cluster_groups = np.flatnonzero(clusters_per_group == 1)
    home_cluster = occupied[single_cluster_groups].argmax(axis=1)
    i_count_success_exactMatch = int(np.count_nonzero(groups_per_cluster[home_cluster] == 1))
    exact_over_total = i_count_success_exactMatch / n_clusters

    return [family_complet_stat, AMI, exact_over_total]

# utils.py
import numpy as np
import torch

def l2_normalize_rows(x: np.ndarray, eps=1e-9):
    """Divide each row of a 2-D array by its L2 norm plus `eps`.

    `eps` keeps the division finite for all-zero rows, which stay zero.
    """
    row_norms = np.linalg.norm(x, axis=1, keepdims=True)
    return np.divide(x, row_norms + eps)

@torch.no_grad()
def torch_l2_normalize_rows(x: torch.Tensor, eps=1e-8):
    """Divide each row of a 2-D tensor by its norm plus `eps`, without tracking gradients."""
    row_norms = x.norm(dim=1, keepdim=True)
    return torch.div(x, row_norms + eps)

# Small helper: cosine reconstruction loss (row-wise)
def cosine_recon_loss(x_hat: torch.Tensor, x: torch.Tensor, eps: float = 1e-8):
    """Return the batch mean of (1 - cosine similarity) between rows of `x_hat` and `x`.

    Each row is divided by its norm plus `eps` before the dot product, so the
    loss ignores the magnitude of the reconstruction.
    """
    unit_hat = x_hat / (x_hat.norm(dim=1, keepdim=True) + eps)
    unit_ref = x / (x.norm(dim=1, keepdim=True) + eps)
    cos_sim = torch.sum(unit_hat * unit_ref, dim=1)
    return torch.mean(1.0 - cos_sim)
# class CosineReconLoss(torch.nn.Module):
#     def __init__(self, eps=1e-8):
#         super().__init__()
#         self.eps = eps
#     def forward(self, x_hat, x):
#         xh = x_hat / (x_hat.norm(dim=1, keepdim=True) + self.eps)
#         x  = x     / (x.norm(dim=1, keepdim=True)     + self.eps)
#         # want high cosine similarity -> minimize (1 - cos)
#         return (1.0 - (xh * x).sum(dim=1)).mean()

def compute_k_from_token(k_token: str, N: int) -> int:
    """Resolve a cluster-count token to an integer.

    An int is returned unchanged (despite the `str` annotation). The string
    "N2" (any case) means half of `N`, but at least 2.

    Raises:
        ValueError: for any other string, or for a value that is neither int nor str.
    """
    if isinstance(k_token, int):
        return k_token
    if not isinstance(k_token, str):
        raise ValueError(f"Unsupported k type: {type(k_token)}")
    if k_token.upper() != "N2":
        raise ValueError(f"Unknown k token: {k_token}")
    return max(2, N // 2)

In [None]:
# --- AE (MSE vs Cosine) training cell ---------------------------------------
import os, csv, copy, torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# -------------------- choose the loss here -----------------------------------
# LOSS_TYPE ("mse" or "cosine") comes from the config cell and is not
# overridden in this cell.
# -----------------------------------------------------------------------------

# Input noise is forced off for this run. The L1 latent penalty is NOT off: the
# training loop below adds LAMBDA_L1 * mean(|z|) with LAMBDA_L1 from the config cell.
NOISE_SIGMA = 0.0      # force OFF



class AE(nn.Module):
    """Fully connected autoencoder with a mirrored decoder.

    Encoder: Linear → ReLU → Dropout per hidden width, then a Linear layer to
    `latent_dim`. Decoder: Linear → ReLU per hidden width in reverse order (no
    dropout), then a Linear layer back to `in_dim`.

    Args:
        in_dim: size of the input (and reconstructed) vectors.
        latent_dim: size of the latent code.
        dropout: dropout probability used after each encoder hidden layer.
        hidden_dims: encoder hidden widths, outermost first. The default
            (512, 128) builds the same layers, in the same order and with the
            same state_dict keys, as the previously hard-coded architecture.
    """
    def __init__(self, in_dim, latent_dim, dropout=0.1, hidden_dims=(512, 128)):
        super().__init__()
        hidden_dims = tuple(hidden_dims)

        # Encoder layers are created before decoder layers so parameter
        # initialization consumes the RNG in the same order as before.
        enc_layers, width = [], in_dim
        for hidden in hidden_dims:
            enc_layers += [nn.Linear(width, hidden), nn.ReLU(), nn.Dropout(dropout)]
            width = hidden
        enc_layers.append(nn.Linear(width, latent_dim))
        self.enc = nn.Sequential(*enc_layers)

        dec_layers, width = [], latent_dim
        for hidden in reversed(hidden_dims):
            dec_layers += [nn.Linear(width, hidden), nn.ReLU()]
            width = hidden
        dec_layers.append(nn.Linear(width, in_dim))
        self.dec = nn.Sequential(*dec_layers)

    def forward(self, x):
        """Return `(x_hat, z)`: the reconstruction and the latent code of `x`."""
        z = self.enc(x)
        x_hat = self.dec(z)
        return x_hat, z

# -------------------- data split & loaders -----------------------------------
# expects: scaler, Xs_train_pool, BATCH_SIZE, device, LATENT_DIM, LR, WD, EPOCHS, PATIENCE, DROPOUT_P, EXP_DIR
# (also reads LOSS_TYPE and LAMBDA_L1 from the config cell)
# Standardize the pooled training embeddings with the already fitted scaler,
# then hold out a random 10% of the rows for validation / early stopping.
Xs_train_scaled = scaler.transform(Xs_train_pool).astype(np.float32)
X_tensor = torch.from_numpy(Xs_train_scaled)
perm = torch.randperm(X_tensor.size(0))   # drawn from torch's global RNG (seeded in the config cell)
n_val = int(0.10 * len(perm))
val_idx, train_idx = perm[:n_val], perm[n_val:]

# Datasets stay on the CPU; batches are moved to `device` inside the loops.
tr_dl = DataLoader(TensorDataset(X_tensor[train_idx]), batch_size=BATCH_SIZE, shuffle=True)
va_dl = DataLoader(TensorDataset(X_tensor[val_idx]),   batch_size=BATCH_SIZE, shuffle=False)

# -------------------- model/optim/loss ---------------------------------------
ae  = AE(Xs_train_scaled.shape[1], LATENT_DIM, dropout=DROPOUT_P).to(device)
opt = torch.optim.Adam(ae.parameters(), lr=LR, weight_decay=WD)
mse = nn.MSELoss()   # only used when LOSS_TYPE == "mse"

# Early-stopping state: best validation loss so far, a copy of the weights that
# achieved it, and the number of consecutive epochs without improvement.
best_val = float('inf'); best_state = None; no_improve = 0
log_rows = []

print(f"[AE] training start (loss={LOSS_TYPE}, latent={LATENT_DIM}, L1={LAMBDA_L1}, noise={NOISE_SIGMA})")
# One pass over the training loader and one over the validation loader per
# epoch; stops early after PATIENCE epochs without a validation improvement.
for ep in range(1, EPOCHS+1):
    ae.train(); tr_loss = tr_recon = 0.0
    for (xb,) in tr_dl:
        xb = xb.to(device)

        # Optional Gaussian input noise; with NOISE_SIGMA == 0 the input is used as-is.
        noisy = xb + NOISE_SIGMA * torch.randn_like(xb) if NOISE_SIGMA > 0 else xb

        # The reconstruction target is always the clean batch xb.
        x_hat, z = ae(noisy)
        if LOSS_TYPE == "mse":
            recon = mse(x_hat, xb)
        elif LOSS_TYPE == "cosine":
            recon = cosine_recon_loss(x_hat, xb)
        else:
            raise ValueError("LOSS_TYPE must be 'mse' or 'cosine'")

        # Total loss = reconstruction + LAMBDA_L1 * mean absolute latent activation.
        spars = z.abs().mean()              # L1 on latent activations
        loss  = recon + LAMBDA_L1 * spars   # apply the L1 weight here

        opt.zero_grad(); loss.backward(); opt.step()

        # Weight batch means by batch size so the epoch value is a per-sample mean.
        tr_loss  += loss.item() * xb.size(0)
        tr_recon += recon.item() * xb.size(0)

    tr_loss  /= len(tr_dl.dataset)
    tr_recon /= len(tr_dl.dataset)

    # ---------------- validation ----------------
    # eval() disables dropout; same loss definition as in training.
    ae.eval(); va_loss = va_recon = 0.0
    with torch.no_grad():
        for (xb,) in va_dl:
            xb = xb.to(device)
            x_hat, z = ae(xb)
            if LOSS_TYPE == "mse":
                recon = mse(x_hat, xb)
            else:
                # any invalid LOSS_TYPE already raised in the training pass above
                recon = cosine_recon_loss(x_hat, xb)
            spars = z.abs().mean()
            loss  = recon + LAMBDA_L1 * spars

            va_loss  += loss.item()  * xb.size(0)
            va_recon += recon.item() * xb.size(0)

    va_loss  /= len(va_dl.dataset)
    va_recon /= len(va_dl.dataset)

    # Keep a deep copy of the weights whenever the total validation loss
    # improves by more than 1e-6.
    tag = ""
    if va_loss < best_val - 1e-6:
        best_val = va_loss
        best_state = {"ae": copy.deepcopy(ae.state_dict())}
        tag = "**best**"; no_improve = 0
    else:
        no_improve += 1

    print(f"[AE] ep{ep:03d}  train: total={tr_loss:.5f} ({LOSS_TYPE}={tr_recon:.5f}) | "
      f"val: total={va_loss:.5f} ({LOSS_TYPE}={va_recon:.5f}) | {tag}")
    log_rows.append([ep, tr_loss, tr_recon, va_loss, va_recon])
    if no_improve >= PATIENCE:
        break

# -------------------- save best weights + log (tagged by loss) ---------------
# NOTE: this rebinds the global `loss_tag` (the config cell set it to
# f"loss-{LOSS_TYPE}"); the latent-export and evaluation cells below use the
# value assigned here together with `lam_tag`.
loss_tag = LOSS_TYPE.lower()
lam_tag  = f"lam{LAMBDA_L1:g}"

# Two files with the same weights: a {"ae": state_dict} checkpoint and the bare
# state_dict. Nothing is written if no epoch ever improved on the initial best.
if best_state is not None:
    torch.save(best_state,       os.path.join(EXP_DIR, f"ae_checkpoint_{loss_tag}_{lam_tag}.pt"))
    torch.save(best_state["ae"], os.path.join(EXP_DIR, f"ae_state_{loss_tag}_{lam_tag}.pth"))

# Per-epoch training log: total and reconstruction-only loss for train and validation.
with open(os.path.join(EXP_DIR, f"trainlog_{loss_tag}_{lam_tag}.csv"), "w", newline="") as f:
    cw = csv.writer(f)
    cw.writerow(["epoch","tr_total",f"tr_{loss_tag}","va_total",f"va_{loss_tag}"])
    cw.writerows(log_rows)

print(f"[saved] AE + logs (loss={loss_tag}) → {EXP_DIR}")


[AE] training start (loss=cosine, latent=128, L1=0.00075, noise=0.0)
[AE] ep001  train: total=0.39104 (cosine=0.39003) | val: total=0.28412 (cosine=0.28323) | **best**
[AE] ep002  train: total=0.27072 (cosine=0.26981) | val: total=0.23410 (cosine=0.23324) | **best**
[AE] ep003  train: total=0.23971 (cosine=0.23881) | val: total=0.21124 (cosine=0.21039) | **best**
[AE] ep004  train: total=0.22421 (cosine=0.22334) | val: total=0.19829 (cosine=0.19747) | **best**
[AE] ep005  train: total=0.21448 (cosine=0.21362) | val: total=0.19172 (cosine=0.19091) | **best**
[AE] ep006  train: total=0.20791 (cosine=0.20707) | val: total=0.18463 (cosine=0.18384) | **best**
[AE] ep007  train: total=0.20338 (cosine=0.20255) | val: total=0.18083 (cosine=0.18006) | **best**
[AE] ep008  train: total=0.20000 (cosine=0.19918) | val: total=0.17800 (cosine=0.17724) | **best**
[AE] ep009  train: total=0.19704 (cosine=0.19624) | val: total=0.17599 (cosine=0.17524) | **best**
[AE] ep010  train: total=0.19476 (cosine

In [None]:

# Try to restore from file if present.
# The training cell writes "ae_checkpoint_<loss>_lam<L1>.pt"; the name is rebuilt
# here from the config values so that the saved file is actually found (the
# previous name, "ae_checkpoint_<LOSS_TYPE>.pt", never matched what was saved).
ckpt_path = os.path.join(
    EXP_DIR, f"ae_checkpoint_{LOSS_TYPE.lower()}_lam{LAMBDA_L1:g}.pt"
)
# `best_state` and `ae` only exist if the training cell ran in this kernel.
best_state = globals().get("best_state")
if os.path.exists(ckpt_path):
    if 'ae' not in globals():
        # re-create the AE with the same dims/hparams you trained
        ae = AE(Xs_train_pool.shape[1], LATENT_DIM, dropout=DROPOUT_P).to(device)
    best_state = torch.load(ckpt_path, map_location=device)

if 'ae' not in globals():
    raise RuntimeError(
        f"No trained AE in memory and no checkpoint at {ckpt_path}; "
        "run the training cell first."
    )

# restore best (if training never recorded a best state, the current weights are kept)
if best_state is not None:
    ae.load_state_dict(best_state["ae"])
ae.eval()

def encode_latents(ae, scaler, X, batch_size=512):
    """Encode raw embeddings into AE latent codes.

    Args:
        ae: module whose forward returns `(x_hat, z)`; put it in eval mode
            beforehand if it contains dropout.
        scaler: object with a `transform` method, applied to `X` first.
        X: (N, D) array of raw embeddings.
        batch_size: number of rows encoded per forward pass.

    Returns:
        (N, latent_dim) float32 NumPy array of latent codes, in input order.
    """
    Xs = scaler.transform(X).astype(np.float32)
    # Run on whatever device the model lives on instead of relying on a
    # notebook-global `device` variable; parameter-free modules run on CPU.
    first_param = next(ae.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    Zs = []
    with torch.no_grad():
        for i in range(0, Xs.shape[0], batch_size):
            xb = torch.from_numpy(Xs[i:i+batch_size]).to(model_device)
            _, z = ae(xb)
            Zs.append(z.cpu().numpy())
    Z = np.vstack(Zs).astype(np.float32)
    return Z

def topk_sparsify(Z, k):
    """Return a copy of `Z` keeping only the `k` largest-|value| entries per row.

    All other entries are set to 0. `Z` itself is not modified.
    `k <= 0` yields an all-zero copy and `k >= Z.shape[1]` an unchanged copy
    (previously `k == 0` left the rows untouched and `k > Z.shape[1]` raised).
    """
    Zs = Z.copy()
    n_cols = Zs.shape[1]
    if k >= n_cols:
        return Zs
    if k <= 0:
        Zs[:] = 0.0
        return Zs
    # zero all but top-k |values| per row: after the partition, the last k
    # positions of each row hold the indices of the k largest magnitudes.
    idx = np.argpartition(np.abs(Zs), -k, axis=1)[:, :-k]
    rows = np.arange(Zs.shape[0])[:, None]
    Zs[rows, idx] = 0.0
    return Zs

# plain latents
# Encode the held-out test embeddings, then scale every latent row to unit L2
# norm (same formula as l2_normalize_rows) before saving and clustering.
Z_plain = encode_latents(ae, scaler, Xs_test)        # produces (N, latent_dim)
Z_plain = Z_plain / (np.linalg.norm(Z_plain, axis=1, keepdims=True) + 1e-9)

# NOTE(review): `loss_tag` and `lam_tag` are the values assigned at the end of
# the training cell; without running that cell `lam_tag` is undefined and
# `loss_tag` still holds the config cell's "loss-<type>" form.
lat_fn = f"{TEST_SPEC}_{MODEL_NAME}_AE{LATENT_DIM}_{loss_tag}_{lam_tag}_latents_plain.npy"
np.save(os.path.join(EXP_DIR, lat_fn), Z_plain)

# Top-k variant (currently disabled; Z_topk is therefore never created):
# if USE_TOPK:
#     Z_topk = topk_sparsify(Z_plain, TOPK)
#     Z_topk = Z_topk / (np.linalg.norm(Z_topk, axis=1, keepdims=True) + 1e-9)
#     np.save(os.path.join(EXP_DIR, f"{TEST_SPEC}_{MODEL_NAME}_AE{LATENT_DIM}_{LOSS_TYPE}_latents_topk{TOPK}.npy"), Z_topk)

print("[saved] latents (plain and top-k)" if USE_TOPK else "[saved] latents (plain)")


[saved] latents (plain)


In [None]:
def count_OGs(meta_list):
    """Count distinct OG ids; each entry carries its OG id at index 4."""
    seen_ogs = set()
    for entry in meta_list:
        seen_ogs.add(entry[4])
    return len(seen_ogs)
# Alias for the test-set protein records under the name used by the
# evaluation helpers (prot_names_and_group_*).
prot_names_and_group_test = meta_test

def compute_k_list(meta, N):
    """Return `(ks, ogs)`: the cluster counts to evaluate and the number of OGs.

    The only candidate is currently `N // 2`; candidates are kept when they are
    unique and satisfy `1 < k <= N`.
    """
    ogs = count_OGs(meta)
    candidates = [N//2] # , ogs, max(2, ogs-500)
    out = []
    for k in candidates:
        in_range = 1 < k <= N
        if in_range and k not in out:
            out.append(k)
    return out, ogs

def run_kmeans_eval(X_feat, prot_meta, variant_tag):
    """Cluster `X_feat` with KMeans and score the clusters against the OG labels.

    For every k from `compute_k_list`, fits KMeans, derives ortholog candidates
    (`measure_pairwise_performance`) and group-level scores
    (`measure_group_performance`), then writes to EXP_DIR:
      - "<TEST_SPEC>_<variant_tag>_<LOSS_TYPE>_kmeans.pkl": the feature matrix
        plus, per k, the `saving_from_kmeans` bundle and the fit time,
      - "<TEST_SPEC>_<variant_tag>_<LOSS_TYPE>_metrics.csv": one row per k.

    Args:
        X_feat: (N, D) feature matrix, one row per protein.
        prot_meta: protein records aligned with the rows of `X_feat`
            (field 0 = species, field 4 = OG id).
        variant_tag: label used in the log line and in the output file names.

    Returns:
        list of dicts, one per k, with the counts and scores written to the CSV
        (without the derived percentage columns).
    """
    X_feat = X_feat.astype(np.float32, copy=False)
    N = X_feat.shape[0]
    k_list, ogs = compute_k_list(prot_meta, N)
    print(f"[{variant_tag}] N={N}, #OGs={ogs}, ks={k_list}")

    # Lookups that do not depend on k, computed once instead of per call/pair.
    og_sizes = {}
    for p in prot_meta:
        og_sizes[p[4]] = og_sizes.get(p[4], 0) + 1
    n_species = len({p[0] for p in prot_meta})

    def count_pairs(pairs_list, require_one_to_one=False):
        """Return (n_correct, n_total): entries whose records share a single OG.

        With `require_one_to_one`, the shared OG must additionally have exactly
        `n_species` members in the whole dataset.
        """
        n_corr = 0
        for prots in pairs_list:
            if len({p[4] for p in prots}) != 1:
                continue
            if require_one_to_one and og_sizes.get(prots[0][4], 0) != n_species:
                continue
            n_corr += 1
        return n_corr, len(pairs_list)

    def pct(n_corr, n_tot):
        """Percentage of correct entries; 0.0 when there are none to score."""
        return (100.0 * n_corr / n_tot) if n_tot else 0.0

    kmeans_saving = {"Xs_train_pca": X_feat}   # name kept for compatibility
    rows = []

    for k in k_list:
        print(f"  k={k}")
        t0 = time.time()
        km = KMeans(n_clusters=k, n_init=KMEANS_NINIT, max_iter=KMEANS_MAXITER,
                    algorithm=KMEANS_ALGO, random_state=RAND_STATE).fit(X_feat)
        elapsed = time.time() - t0
        kmeans_saving[f"n_clusters{k}"] = saving_from_kmeans(X_feat, prot_meta, km)
        kmeans_saving[f"time_k{k}"]     = f"{elapsed:.2f}s"

        # Pair-level eval
        naive_list, dist_list, one2one_list = measure_pairwise_performance(
            kmeans_saving[f"n_clusters{k}"], X_feat
        )
        n_corr_naive, n_tot_naive = count_pairs(naive_list)
        n_corr_dist,  n_tot_dist  = count_pairs(dist_list)
        n_corr_121,   n_tot_121   = count_pairs(one2one_list, require_one_to_one=True)

        # Group-level eval
        fam, ami, exact = measure_group_performance(kmeans_saving[f"n_clusters{k}"])

        rows.append({
            "k": k,
            "naive_correct": n_corr_naive, "naive_total": n_tot_naive,
            "dist_correct":  n_corr_dist,  "dist_total":  n_tot_dist,
            "one2one_correct": n_corr_121, "one2one_total": n_tot_121,
            "family": float(fam), "AMI": float(ami), "exact_pct": float(exact*100.0),
            "kmeans_time": kmeans_saving[f"time_k{k}"]
        })

    # Save PKL (full kmeans_saving, includes meta inside each entry)
    with open(os.path.join(EXP_DIR, f"{TEST_SPEC}_{variant_tag}_{LOSS_TYPE}_kmeans.pkl"), "wb") as f:
        pickle.dump(kmeans_saving, f, protocol=pickle.HIGHEST_PROTOCOL)

    # Save CSV metrics
    csv_path = os.path.join(EXP_DIR, f"{TEST_SPEC}_{variant_tag}_{LOSS_TYPE}_metrics.csv")
    with open(csv_path, "w", newline="") as f:
        cw = csv.writer(f)
        cw.writerow(["k",
                     "naive_correct","naive_total","naive_pct",
                     "dist_correct","dist_total","dist_pct",
                     "one2one_correct","one2one_total","one2one_pct",
                     "family","AMI","exact_pct","kmeans_time"])
        for r in rows:
            cw.writerow([
                r["k"],
                r["naive_correct"], r["naive_total"],
                pct(r["naive_correct"], r["naive_total"]),
                r["dist_correct"],  r["dist_total"],
                pct(r["dist_correct"], r["dist_total"]),
                r["one2one_correct"], r["one2one_total"],
                pct(r["one2one_correct"], r["one2one_total"]),
                r["family"], r["AMI"], r["exact_pct"], r["kmeans_time"]
            ])
    print(f"[saved] {csv_path}")
    return rows

# ---- Run for plain latents (the top-k run below is disabled) ----
# NOTE(review): `loss_tag` and `lam_tag` come from the end of the training cell.
# run_kmeans_eval appends LOSS_TYPE to the file names as well, so the loss name
# appears twice in the saved file names.
rows_plain = run_kmeans_eval(
    Z_plain, prot_names_and_group_test,
    variant_tag=f"AE{LATENT_DIM}_{loss_tag}_{lam_tag}_plain"
)
# if USE_TOPK:
#     rows_topk  = run_kmeans_eval(Z_topk,  prot_names_and_group_test, variant_tag=f"AE{LATENT_DIM}_topk{TOPK}")


[AE128_cosine_lam0.00075_plain] N=10263, #OGs=5315, ks=[5131]
  k=5131
[saved] /content/drive/MyDrive/ae_ablations/lamdaL1_7.5e-4/TRAIN[drer_xtro+mmus_hsap]_TEST[pfal_pber]_esm2_t36_3B_UR50D_AE128_L1-0.00075_SIG0.0_DO0.1_ninit50_iter500_topk0_20250913-144019_esm2_t36_3B_UR50D_AE128_loss-cosine_L1-0.00075_SIG0.0_DO0.1_ninit50_iter500_topk0_20250913-144019/pfal_pber_AE128_cosine_lam0.00075_plain_cosine_metrics.csv


In [None]:
# Record the settings of this run next to its results. The reconstruction loss
# ("loss_type") is included because it is the variable this ablation changes and
# was previously missing from the saved metadata.
meta = {
  "train_pool": TRAIN_SPEC,
  "test": TEST_SPEC,
  "model": MODEL_NAME,
  "emb_layer": EMB_LAYER,
  "ae": {
    "loss_type": LOSS_TYPE,
    "latent_dim": LATENT_DIM, "lambda_l1": LAMBDA_L1,
    "noise_sigma": NOISE_SIGMA, "dropout": DROPOUT_P,
    "lr": LR, "weight_decay": WD, "batch_size": BATCH_SIZE,
    "epochs": EPOCHS, "patience": PATIENCE
  },
  "kmeans": {
    "n_init": KMEANS_NINIT, "max_iter": KMEANS_MAXITER,
    "algorithm": KMEANS_ALGO, "random_state": RAND_STATE
  },
  "topk_enabled": USE_TOPK, "topk": TOPK,
  "exp_dir": EXP_DIR, "timestamp": ts, "device": str(device)
}
save_json(meta, os.path.join(EXP_DIR, "experiment_meta.json"))
print("[saved] experiment_meta.json")


[saved] experiment_meta.json
