Initialize the environment

In [None]:
import argparse

import torch

# Environment sanity check: report the PyTorch build, then whether a CUDA
# device is usable from this kernel.
for env_fact in (torch.__version__, torch.cuda.is_available()):
    print(env_fact)

In [None]:
# from TamGen_Demo import TamGenDemo, prepare_pdb_data
# import os

# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# worker = TamGenDemo(
#     data="./TamGen_Demo_Data",
#     ckpt="checkpoints/crossdocked_model/checkpoint_best.pt",
#     use_conditional=False
# )

Set the PDB ID of the target protein for which you want to generate compounds

In [None]:
# pdb_id = "7d1m" # "5dzk, 7d1m" "7te0"
# prepare_pdb_data(pdb_id,)
# worker.reload_data(subset="gen_" + pdb_id.lower())

# hyps, ref = worker.sample(
#     m_sample=5000, 
#     maxseed=30,
# )

In [None]:
from fairseq.molecule_utils.external_tools.autodock_smina import AutoDockSmina

# Sanity check: show which smina executable the AutoDockSmina wrapper resolves.
smina_binary = AutoDockSmina.find_binary()
print("Discovered binary:", smina_binary)

In [None]:
import time
from fairseq.molecule_utils.basic.run_gnina_docking import docking

smiles = 'CC(C)CCCC(C)C1CCC2C3CCC4CC(O)CC[C@]4(C)C3CC[C@]12C'
pdb = '3ny8'

# Dock the same ligand twice and compare the timings. This is presumably meant to
# show whether receptor preparation / downloads are reused between calls (confirm
# in run_gnina_docking). perf_counter is monotonic and meant for measuring
# durations, unlike the wall-clock time.time().
for label in ("First", "Second"):
    start = time.perf_counter()
    _ = docking(pdb_id=pdb, ligand_smiles=smiles)
    print(f"{label} run took {time.perf_counter() - start:.2f}s")

In [None]:
from rdkit import Chem
from rdkit.Chem import AllChem, rdmolfiles

# Fixed seed so the generated 3D conformers (and hence ligands.sdf) are reproducible.
EMBED_SEED = 42

smiles_list = ['CC(C)CCCC(C)C1CCC2C3CCC4CC(O)CC[C@]4(C)C3CC[C@]12C']

# Write one MMFF-optimised 3D conformer per SMILES to ligands.sdf.
# The context manager guarantees the writer is flushed/closed even if a molecule fails.
with rdmolfiles.SDWriter("ligands.sdf") as writer:
    for smi in smiles_list:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            # MolFromSmiles returns None for unparsable input; AddHs would crash on it.
            print(f"Skipping unparsable SMILES: {smi}")
            continue
        mol = Chem.AddHs(mol)
        # EmbedMolecule returns the new conformer id, or -1 when embedding fails;
        # optimising a molecule without a conformer raises "Bad Conformer Id".
        if AllChem.EmbedMolecule(mol, randomSeed=EMBED_SEED) == -1:
            print(f"Skipping SMILES that could not be embedded in 3D: {smi}")
            continue
        AllChem.MMFFOptimizeMolecule(mol)
        writer.write(mol)

In [None]:
import os
import sys

# Hard-restart the kernel: os.execv replaces the current Python process with a
# new interpreter started from the same executable and command-line arguments.
# Every variable and imported module in this kernel is discarded, and on success
# the call never returns.
# NOTE(review): presumably used to force re-import of edited project modules or to
# release GPU memory — confirm. Cells queued by "Run All" after this one are likely
# lost; prefer Kernel → Restart (or %autoreload) for a reproducible notebook.
os.execv(sys.executable, [sys.executable] + sys.argv)


In [None]:
from fairseq.molecule_utils.basic import run_gnina_docking

# Dock the reference ligand against PDB 3NY8 and show whatever
# run_gnina_docking.docking returns (presumably the predicted affinity — confirm units).
ligand = 'CC(C)CCCC(C)C1CCC2C3CCC4CC(O)CC[C@]4(C)C3CC[C@]12C'
affinity = run_gnina_docking.docking(pdb_id='3ny8', ligand_smiles=ligand)

print(affinity)

In [None]:
import os
import time
import numpy as np
import logging
from feedback.centroid_optimizer import centroid_shift_optimize
from TamGen_Demo import TamGenDemo, prepare_pdb_data

# Closed-loop latent optimisation:
#   sample molecules with TamGenDemo -> read latent vectors from LATENT_PATH ->
#   score and shift them with centroid_shift_optimize -> write the shifted vectors
#   back to LATENT_PATH and the per-molecule rewards to rewards_iter_<N>.tsv.
# NOTE(review): this cell assumes demo.sample() writes LATENT_PATH (one row per
# generated molecule) and reads it back on the next call — confirm in TamGen_Demo.

# === Setup Logging ===
os.makedirs("latent_logs", exist_ok=True)
# NOTE(review): basicConfig is a no-op if the root logger already has handlers;
# if debug_latent.log stays empty, pass force=True (Python 3.8+).
logging.basicConfig(
    filename="latent_logs/debug_latent.log",
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s"
)
logging.info("🚀 Feedback loop started.")

# === Initialize TamGen ===
print("📦 Initializing TamGen...")
demo = TamGenDemo(
    data="TamGen_Demo_Data",
    ckpt="checkpoints/crossdocked_model/checkpoint_best.pt",
    use_conditional=True
)

pdb_id = "3ny8"
print(f"📄 Preparing PDB: {pdb_id}")
prepare_pdb_data(pdb_id)
demo.reload_data(subset="gen_" + pdb_id.lower())

# === Configuration ===
LATENT_PATH = "latent_logs/latent_vectors.tsv"
NUM_ITER = 5
M_SAMPLE = 100   # m_sample passed to demo.sample each iteration
MAXSEED = 20     # maxseed passed to demo.sample each iteration
LATENT_DIM = 256
ALPHA = 0.5
TOP_K = 50
LAMBDA_SAS = 0.3
LAMBDA_LOGP = 0.1
LAMBDA_MW = 0.1


def _load_aligned_latents(path, smiles, latent_dim):
    """Load latent vectors from ``path`` and align them row-wise with ``smiles``.

    Args:
        path: whitespace-delimited text file, one latent vector per row.
        smiles: SMILES strings produced by the current sampling round.
        latent_dim: expected number of columns per row.

    Returns:
        ``(z_vectors, smiles)`` truncated to a common length; ``z_vectors`` is a
        2-D array of shape ``(n, latent_dim)``.

    Raises:
        ValueError: if the file does not hold a 2-D array with ``latent_dim`` columns.
    """
    # ndmin=2 keeps a single-row file 2-D; plain np.loadtxt returns a 1-D array
    # in that case and the shape check below would wrongly reject it.
    z_vectors = np.loadtxt(path, ndmin=2)  # use default whitespace splitting
    if z_vectors.ndim != 2 or z_vectors.shape[1] != latent_dim:
        raise ValueError(f"❌ Latent vector file malformed: expected {latent_dim} columns, got {z_vectors.shape}")

    if len(z_vectors) != len(smiles):
        # Rows are assumed to correspond positionally to the SMILES — TODO confirm.
        print("⚠️  Warning: Latent vector count mismatch. Truncating to match.")
        min_len = min(len(z_vectors), len(smiles))
        z_vectors = z_vectors[:min_len]
        smiles = smiles[:min_len]
    return z_vectors, smiles


print("⚙️  Starting closed-loop optimization...")
for iteration in range(NUM_ITER):
    iter_no = iteration + 1
    print(f"\n🚀 Iteration {iter_no}/{NUM_ITER}")
    start_time = time.perf_counter()  # monotonic clock for elapsed time

    # 1. Sampling
    print("🔍 Generating candidates...")
    results, _ = demo.sample(m_sample=M_SAMPLE, maxseed=MAXSEED)
    smiles_list = list(results.keys())
    print(f"   ✔ Generated {len(smiles_list)} molecules.")

    # 2. Load Latent Vectors
    print("📈 Loading latent vectors...")
    z_vectors, smiles_list = _load_aligned_latents(LATENT_PATH, smiles_list, LATENT_DIM)

    # 3. Placeholder Docking Scores
    # NOTE(review): no docking is run here; every score is None, so any docking term
    # in the reward presumably has no effect — confirm in centroid_shift_optimize.
    docking_scores = [None] * len(smiles_list)

    # 4. Optimization
    print("📊 Optimizing latent space...")
    z_shifted, rewards, metrics = centroid_shift_optimize(
        z_vectors,
        smiles_list,
        docking_scores,
        latent_dim=LATENT_DIM,
        top_k=TOP_K,
        shift_alpha=ALPHA,
        lambda_sas=LAMBDA_SAS,
        lambda_logp=LAMBDA_LOGP,
        lambda_mw=LAMBDA_MW,
    )
    print("   ✔ Optimization complete.")

    # 5. Save Outputs (space-delimited latents, tab-separated "SMILES<TAB>reward" lines)
    print("💾 Saving latent vectors and rewards...")
    np.savetxt(LATENT_PATH, np.array(z_shifted), fmt="%.5f")
    with open(f"latent_logs/rewards_iter_{iter_no}.tsv", "w") as f:
        f.writelines(f"{smi}\t{r:.4f}\n" for smi, r in zip(smiles_list, rewards))

    logging.info("✅ Completed Iteration %d in %.2fs", iter_no, time.perf_counter() - start_time)
    print(f"✅ Iteration {iter_no} complete.")

print("\n🎉 Feedback loop finished. Ready for SGDS optimization.")

In [None]:
import torch

# Return cached-but-unused GPU memory from PyTorch's caching allocator and collect
# memory released through CUDA IPC. Tensors that are still referenced are not freed.
# ipc_collect() lazily initialises CUDA and raises on a machine without a GPU,
# so both calls are guarded.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

In [None]:
from TamGen_RL import TamGenRL
from utils import prepare_pdb_data, prepare_pdb_data_center, filter_generated_cmpd
import torch

# Free cached GPU memory before loading the model. ipc_collect() lazily initialises
# CUDA and raises when no GPU is present, so guard both calls.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# === Setup TamGenRL ===
# Note: this rebinds prepare_pdb_data to the version from `utils`.
pdb_id = "3ny8"
print(f"📄 Preparing PDB: {pdb_id}")
prepare_pdb_data(pdb_id)

demo = TamGenRL(
    data="TamGen_Demo_Data",
    ckpt="checkpoints/crossdocked_model/checkpoint_best.pt",
    use_conditional=True
)
demo.reload_data(subset="gen_" + pdb_id.lower())

# === Run Closed-Loop Optimization ===
# NOTE(review): the result is iterated below as a collection of SMILES strings —
# confirm against TamGenRL.sample.
final_smiles = demo.sample(
    m_sample=100,         # Number of molecules per iteration
    num_iter=5,           # Number of closed-loop optimization steps
    latent_dim=256,       # Latent space dimensionality (set to your model's config)
    alpha=0.5,            # Centroid shift parameter
    top_k=50,             # How many top molecules to use for shifting
    lambda_sas=0.3,       # Reward hyperparameters
    lambda_logp=0.1,
    lambda_mw=0.1,
    maxseed=20,           # Number of random seeds (first iteration)
    use_cuda=torch.cuda.is_available(),  # request CUDA only when a device is visible
)

# === Save or Analyze Results ===
print(f"\nFinal set of SMILES ({len(final_smiles)} molecules):")
for smi in final_smiles:
    print(smi)

In [None]:
import glob
import re
import numpy as np
import matplotlib.pyplot as plt


def _iteration_number(path):
    """Return N from a ``rewards_iter_N.tsv`` path, or -1 if the name does not match."""
    match = re.search(r"rewards_iter_(\d+)\.tsv$", path)
    return int(match.group(1)) if match else -1


# Sort numerically: a plain lexicographic sort would put rewards_iter_10
# before rewards_iter_2 and scramble the curves once there are 10+ iterations.
reward_files = sorted(glob.glob('latent_logs/rewards_iter_*.tsv'), key=_iteration_number)
iterations, means, maxs, medians = [], [], [], []

for f in reward_files:
    # Each line is "SMILES<TAB>reward"; malformed lines are skipped.
    rewards = []
    with open(f) as fin:
        for line in fin:
            parts = line.strip().split('\t')
            if len(parts) == 2:
                rewards.append(float(parts[1]))
    if rewards:
        rewards = np.array(rewards)
        # Record the real iteration number so files with no rewards do not
        # shift later points left on the x-axis.
        iterations.append(_iteration_number(f))
        means.append(rewards.mean())
        maxs.append(rewards.max())
        medians.append(np.median(rewards))

fig, ax = plt.subplots()
ax.plot(iterations, means, label='Mean')
ax.plot(iterations, maxs, label='Max')
ax.plot(iterations, medians, label='Median')
ax.set_xlabel('Iteration')
ax.set_ylabel('Reward')
ax.set_title('Reward Statistics Across Iterations')
ax.legend()
plt.show()

Visualize the molecules

In [None]:
# from rdkit import Chem 
# from rdkit.Chem import Draw, AllChem, DataStructs
# from rdkit.Chem import MACCSkeys

# fp_ref = MACCSkeys.GenMACCSKeys(ref)

# gens = []

# for k,v in hyps.items():
#     fp = MACCSkeys.GenMACCSKeys(v)
#     similarity = DataStructs.FingerprintSimilarity(fp_ref, fp, metric=DataStructs.TanimotoSimilarity)
#     gens.append((v,k, similarity))

# sorted_gen = sorted(gens, key=lambda e: e[-1], reverse=True)

# # img=Draw.MolsToGridImage([e[0] for e in sorted_gen], molsPerRow=5, legends=["idx={}, morgan={:.2f}".format(ii, e[2]) for ii, e in enumerate(sorted_gen)])
# img=Draw.MolsToGridImage([e[0] for e in sorted_gen], molsPerRow=5, legends=["idx={}".format(ii) for ii in range(len(sorted_gen))])
# img

In [None]:
# from rdkit import Chem 
# from rdkit.Chem import Draw, AllChem, DataStructs
# from rdkit.Chem import MACCSkeys

# fp_ref = MACCSkeys.GenMACCSKeys(ref)

# gens = []

# for k,v in hyps.items():
#     fp = MACCSkeys.GenMACCSKeys(v)
#     similarity = DataStructs.FingerprintSimilarity(fp_ref, fp, metric=DataStructs.TanimotoSimilarity)
#     gens.append((v,k, similarity))

# sorted_gen = sorted(gens, key=lambda e: e[-1], reverse=True)

# # img=Draw.MolsToGridImage([e[0] for e in sorted_gen], molsPerRow=5, legends=["idx={}, morgan={:.2f}".format(ii, e[2]) for ii, e in enumerate(sorted_gen)])
# img=Draw.MolsToGridImage([e[0] for e in sorted_gen], molsPerRow=5, legends=["idx={}".format(ii) for ii in range(len(sorted_gen))])
# img

In [None]:
# import numpy as np
# import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D
# import umap.umap_ as umap

# # --- Load Latent Vectors ---
# latent_file = "latent_logs/latent_vectors.tsv"
# latent_vectors = np.loadtxt(latent_file, delimiter="\t")

# # Optional: check length matches your generated molecules
# if len(latent_vectors) != len(sorted_gen):
#     print(f"Warning: {len(latent_vectors)} latent vectors vs {len(sorted_gen)} molecules")
#     min_len = min(len(latent_vectors), len(sorted_gen))
#     latent_vectors = latent_vectors[:min_len]
#     sorted_gen = sorted_gen[:min_len]

# # --- Project to 3D ---
# reducer = umap.UMAP(n_components=3, n_neighbors=15, min_dist=0.1, metric="euclidean")
# latent_3d = reducer.fit_transform(latent_vectors)

# # --- Plot in 3D ---
# fig = plt.figure(figsize=(10, 8))
# ax = fig.add_subplot(111, projection='3d')
# x, y, z = latent_3d[:, 0], latent_3d[:, 1], latent_3d[:, 2]

# colors = [entry[2] for entry in sorted_gen]  # Tanimoto similarity

# p = ax.scatter(x, y, z, c=colors, cmap="viridis", s=20)
# fig.colorbar(p, ax=ax, label="Tanimoto similarity to reference")
# ax.set_title("TamGen Latent Space (3D UMAP)")
# ax.set_xlabel("UMAP-1")
# ax.set_ylabel("UMAP-2")
# ax.set_zlabel("UMAP-3")
# plt.show()