# Simple MolForge evaluation with RDKit
_Minimal notebook for quickly checking the **precision (top‑1, Tc=1.0)** of an already-trained MolForge model._

This notebook:
1) Takes a **list of SMILES** and an input **fingerprint type**.
2) Converts each SMILES to a fingerprint with **RDKit** and tokenizes it.
3) Calls **MolForge** to decode fingerprint → SMILES (or SELFIES).
4) Computes the **precision** as the percentage of cases with **Tc=1.0** (exact Tanimoto match), using **Morgan sparse r=1** as the evaluation fingerprint.

**Code comments are in English.**

## 0) Device: automatic CPU/GPU detection
<small>The environment is already set up with `environment.yml`. Here we detect the **device** automatically. You can force it with `preferred="cpu"` or through the `DEVICE` environment variable.</small>

In [None]:

# Device detection (import from src if available; otherwise define here)
# Optional[...] is used instead of `str | None`: that annotation is evaluated
# at definition time and fails on Python < 3.10 unless
# `from __future__ import annotations` is already active in the session.
from typing import Optional

try:
    from src.utils_device import pick_device  # prefer project utility if present
except Exception:
    def pick_device(preferred: Optional[str] = None) -> str:
        """Return 'cuda', 'mps' (Apple), or 'cpu'. Supports override with preferred."""
        if preferred in {"cpu", "cuda", "mps"}:
            return preferred
        try:
            import torch
        except Exception:
            return "cpu"  # PyTorch is not importable: CPU is the only option
        try:
            if torch.cuda.is_available():
                return "cuda"
        except Exception:
            pass
        try:
            if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                return "mps"
        except Exception:
            pass
        return "cpu"

import os
device = pick_device(os.getenv("DEVICE"))  # set DEVICE=cpu|cuda|mps to override
print("Using device:", device)
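
As written, `pick_device` ignores any override it does not recognize (for example `DEVICE=CUDA` or `DEVICE=gpu`) and quietly falls back to auto-detection. The sketch below shows one way to normalize and validate the variable before passing it on. `read_device_override` and `VALID_DEVICES` are hypothetical names introduced here for illustration; they are not part of the project.

```python
import os
from typing import Mapping, Optional

# Same three values that pick_device() accepts as an override
VALID_DEVICES = {"cpu", "cuda", "mps"}

def read_device_override(env: Mapping[str, str] = os.environ) -> Optional[str]:
    """Return a normalized DEVICE override, or None when it is unset or blank."""
    raw = env.get("DEVICE", "").strip().lower()
    if not raw:
        return None
    if raw not in VALID_DEVICES:
        # Fail loudly instead of silently falling back to auto-detection
        raise ValueError(f"DEVICE must be one of {sorted(VALID_DEVICES)}, got {raw!r}")
    return raw

print(read_device_override({"DEVICE": "CUDA"}))  # cuda
print(read_device_override({}))                  # None
```

The result can be handed to `pick_device(...)` in place of the raw `os.getenv("DEVICE")` value.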


## 1) Input (edit here)
<small>Define the test SMILES, the **input** fingerprint (e.g., `ECFP4`), and the MolForge model configuration (name/checkpoint).</small>

In [None]:

# --- Dummy SMILES to quickly run once you connect MolForge ---
SMILES_LIST = [
    "CCO",                       # ethanol
    "c1ccccc1",                  # benzene
    "CC(=O)OC1=CC=CC=C1C(=O)O",  # aspirin
    "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",  # caffeine (all three N-methyl groups present)
    "CC(=O)O",                   # acetic acid
]

# Fingerprint to FEED MolForge (input side). Common options: 'ECFP4', 'AEs', 'TT', 'HashAP', 'RDK4', 'MACCS', 'FCFP4', ...
INPUT_FP = "ECFP4"

# What MolForge will generate: 'SMILES' or 'SELFIES'
OUTPUT_REPR = "SMILES"

# MolForge model configuration (fill these with your real paths/names)
MODEL_NAME = "ecfp4_to_smiles"                 # adapt to your MolForge model naming
MOLFORGE_CHECKPOINT = "/path/to/model.ckpt"    # put your real checkpoint path

# Bit length for hashed fingerprints (used by ECFP/FCFP/HashAP/RDK4; ignored by sparse FPs and MACCS)
HASHED_NBITS = 2048

## 2) Environment assumptions and connection
<small>This notebook assumes you have created and activated the environment from `environment.yml` and that **MolForge** is already installed via `pip install git+https://github.com/knu-lcbc/MolForge.git`. There is no need to repeat the installation here. Just make sure you have the correct **checkpoint**; the model can then run on `cpu` or `cuda`, depending on what was detected above.</small>
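
A quick way to confirm these assumptions before running the rest of the notebook is to check that the required packages can be located. This is a minimal sketch using only the standard library; `missing_modules` is a hypothetical helper, and `"molforge"` is an assumed import name that you should adapt to however the installed package is actually exposed.

```python
from importlib.util import find_spec

def missing_modules(names):
    """Return the top-level module names that cannot be located for import."""
    return [name for name in names if find_spec(name) is None]

# "molforge" is an assumed import name; adjust it to your installation
required = ["rdkit", "molforge"]
missing = missing_modules(required)
print("Missing modules:", missing or "none")
```

`find_spec` only locates a module without importing it, so the check is cheap and has no side effects.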

## 3) Imports
<small>We import RDKit and a few utilities. If the output is SELFIES, install `selfies` so predictions can be converted to SMILES for evaluation.</small>

In [None]:

from __future__ import annotations
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple

from rdkit import Chem
from rdkit.Chem import AllChem, MACCSkeys, rdMolDescriptors as rdDesc
from rdkit import DataStructs

try:
    import selfies as sf
    SELFIES_OK = True
except Exception:
    SELFIES_OK = False


## 4) Basic utilities (parsing and canonicalization)
<small>Small functions to go from SMILES to RDKit molecules and to obtain canonical SMILES.</small>

In [None]:

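def to_mol(smiles: str):
    """Parse SMILES into a sanitized RDKit Mol; return None on failure or empty input."""
    if not isinstance(smiles, str) or not smiles.strip():
        return None  # MolFromSmiles("") yields an empty Mol rather than None
    try:
        # MolFromSmiles sanitizes by default and returns None when parsing fails
        return Chem.MolFromSmiles(smiles)
    except Exception:
        return None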

def canon_smiles(mol, isomeric: bool = True) -> str:
    """Return canonical SMILES (isomeric if requested)."""
    if mol is None:
        return ""
    return Chem.MolToSmiles(mol, isomericSmiles=isomeric, canonical=True)

def selfies_to_smiles(s: str) -> str:
    """Convert SELFIES -> SMILES if selfies is installed; otherwise return input."""
    if not SELFIES_OK:
        return s
    try:
        return sf.decoder(s)
    except Exception:
        return s


## 5) RDKit fingerprints (input and evaluation)
<small>We implement the **input** fingerprints as well as the **evaluation** fingerprint, Morgan sparse r=1, used to compute Tc.</small>
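
The two fingerprint families are tokenized differently: hashed bit vectors become the sorted list of active bit indices, while sparse count vectors become their keys repeated by count. The sketch below illustrates that mapping on plain Python containers, without RDKit. The function names and the fragment IDs are made up for the example.

```python
from typing import Dict, Iterable, List

def tokens_from_on_bits(on_bits: Iterable[int]) -> List[int]:
    """Hashed bit vector: each active bit index becomes one token."""
    return sorted(int(b) for b in on_bits)

def tokens_from_counts(counts: Dict[int, int]) -> List[int]:
    """Sparse count vector: each key is repeated `count` times (a multiset)."""
    tokens: List[int] = []
    for key in sorted(counts):
        tokens.extend([int(key)] * int(counts[key]))
    return tokens

# Illustrative values only; real IDs come from the RDKit fingerprint objects
print(tokens_from_on_bits([1024, 33, 80]))              # [33, 80, 1024]
print(tokens_from_counts({98513984: 2, 864942730: 1}))  # [98513984, 98513984, 864942730]
```

Whether a given MolForge checkpoint expects exactly this layout depends on how it was trained, so treat the mapping as something to verify against the training pipeline.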

In [None]:

@dataclass
class FPResult:
    kind: str
    obj: Any
    is_sparse: bool
    nbits: Optional[int] = None

# ---- Input FPs (to FEED MolForge) ----
def fp_morgan_hashed(mol, radius: int, nBits: int, useFeatures: bool = False) -> FPResult:
    bv = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=nBits, useFeatures=useFeatures)
    kind = f"{'FCFP' if useFeatures else 'ECFP'}{2*radius}"
    return FPResult(kind=kind, obj=bv, is_sparse=False, nbits=nBits)

def fp_morgan_sparse(mol, radius: int) -> FPResult:
    siv = AllChem.GetMorganFingerprint(mol, radius=radius)  # UIntSparseIntVect
    kind = "AEs" if radius == 1 else f"MorganSparse_r{radius}"
    return FPResult(kind=kind, obj=siv, is_sparse=True)

def fp_tt_sparse(mol) -> FPResult:
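    # rdMolDescriptors exposes this as GetTopologicalTorsionFingerprint
    # (the "...AsIntVect" alias lives in rdkit.Chem.AtomPairs.Torsions)
    vec = rdDesc.GetTopologicalTorsionFingerprint(mol)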
    return FPResult(kind="TT", obj=vec, is_sparse=True)

def fp_ap_hashed(mol, nBits: int) -> FPResult:
    bv = rdDesc.GetHashedAtomPairFingerprintAsBitVect(mol, nBits=nBits)
    return FPResult(kind="HashAP", obj=bv, is_sparse=False, nbits=nBits)

def fp_rdk4(mol, branchedPaths: bool = True, nBits: int = 2048) -> FPResult:
    bv = Chem.RDKFingerprint(mol, fpSize=nBits, minPath=2, maxPath=4, branchedPaths=branchedPaths)
    return FPResult(kind="RDK4" if branchedPaths else "RDK4-L", obj=bv, is_sparse=False, nbits=nBits)

def fp_maccs(mol) -> FPResult:
    bv = MACCSkeys.GenMACCSKeys(mol)  # 167 bits
    return FPResult(kind="MACCS", obj=bv, is_sparse=False, nbits=167)

def compute_input_fp(mol, fp_name: str, nBits: int = 2048) -> FPResult:
    name = fp_name.upper()
    if name == "ECFP0":
        return fp_morgan_hashed(mol, radius=0, nBits=nBits, useFeatures=False)
    if name == "ECFP2":
        return fp_morgan_hashed(mol, radius=1, nBits=nBits, useFeatures=False)
    if name == "ECFP4":
        return fp_morgan_hashed(mol, radius=2, nBits=nBits, useFeatures=False)
    if name == "FCFP2":
        return fp_morgan_hashed(mol, radius=1, nBits=nBits, useFeatures=True)
    if name == "FCFP4":
        return fp_morgan_hashed(mol, radius=2, nBits=nBits, useFeatures=True)
    if name == "AES":  # 'AEs' after upper-casing
        return fp_morgan_sparse(mol, radius=1)
    if name == "TT":
        return fp_tt_sparse(mol)
    if name == "HASHAP":
        return fp_ap_hashed(mol, nBits=nBits)
    if name == "RDK4":
        return fp_rdk4(mol, branchedPaths=True, nBits=nBits)
    if name == "RDK4-L":
        return fp_rdk4(mol, branchedPaths=False, nBits=nBits)
    if name == "MACCS":
        return fp_maccs(mol)
    raise ValueError(f"Unknown input FP: {fp_name}")

def fp_to_tokens(fp: FPResult) -> list[int]:
    """Map an RDKit fingerprint to token IDs for a Transformer.

    - Hashed bit vectors: return active bit indices
    - Sparse vectors: return explicit keys, repeated by count (multiset)
    """
    if fp.is_sparse:
        elems = fp.obj.GetNonzeroElements()  # dict[id] = count
        toks = []
        for k, c in elems.items():
            toks.extend([int(k)] * int(c))
        toks.sort()
        return toks
    else:
        return sorted(list(fp.obj.GetOnBits()))

# ---- Evaluation FP (Morgan sparse r=1 for Tc) ----
def eval_fp_morgan_sparse_r1(mol) -> FPResult:
    return fp_morgan_sparse(mol, radius=1)

def tanimoto(fp_a: FPResult, fp_b: FPResult) -> float:
    return DataStructs.TanimotoSimilarity(fp_a.obj, fp_b.obj)
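
For sparse count fingerprints such as Morgan sparse r=1, Tc is computed on counts rather than on bits. The sketch below spells out the count-based Tanimoto formula, sum of minimums divided by the two totals minus that sum, on plain dictionaries. It is an illustration of the formula under the assumption that RDKit's sparse-vector similarity behaves this way, not a replacement for `DataStructs.TanimotoSimilarity`; `tanimoto_counts` is a hypothetical name, and returning 0.0 for two empty vectors is a convention chosen for this sketch.

```python
from typing import Dict

def tanimoto_counts(a: Dict[int, int], b: Dict[int, int]) -> float:
    """Count-based Tanimoto: sum(min) / (sum(a) + sum(b) - sum(min))."""
    shared = sum(min(count, b.get(key, 0)) for key, count in a.items())
    denom = sum(a.values()) + sum(b.values()) - shared
    return shared / denom if denom else 0.0

same = {1: 2, 2: 1}
print(tanimoto_counts(same, dict(same)))                    # 1.0
print(round(tanimoto_counts(same, {1: 1, 2: 1}), 3))        # 0.667
```

Note that Tc=1.0 here means the two molecules have identical radius-1 environment counts. That is a strong but not strictly sufficient condition for the structures to be identical, so comparing canonical SMILES as well can serve as a stricter check.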


## 6) MolForge adapter (connect your model here)
<small>Implement two methods: `load()` (loads the model, checkpoint, and tokenizer) and `predict_batch()` (decodes a list of token sequences into SMILES/SELFIES).
This notebook **does not include a demo mode**: MolForge must be working before you can run the cell in section 8.
</small>

In [None]:

class MolForgeDecoder:
    """Minimal adapter for the MolForge model.

    You MUST implement .load() and .predict_batch() to call the real MolForge API.
    """
    def __init__(self, model_name: str, checkpoint: str, device: str = "cpu",
                 tokenizer_vocab_json: Optional[str] = None, output_repr: str = "SMILES"):
        self.model_name = model_name
        self.checkpoint = checkpoint
        self.device = device
        self.tokenizer_vocab_json = tokenizer_vocab_json
        self.output_repr = output_repr
        self._model = None
        self._tokenizer = None

    def load(self):
        """Load your MolForge model + tokenizer here.

        Example (pseudo-code):
            from molforge.models import load_model, load_tokenizer
            self._model = load_model(self.model_name, self.checkpoint, device=self.device)
            self._tokenizer = load_tokenizer(self.model_name, vocab_json=self.tokenizer_vocab_json)
        """
        raise NotImplementedError("Implement MolForgeDecoder.load() to load the actual model.")

    def predict_batch(self, list_of_token_lists: list[list[int]]) -> list[str]:
        """Decode a batch of token sequences into strings (SMILES or SELFIES).

        Return one string per input sequence, in the same order.
        Example (pseudo-code):
            return self._model.generate(list_of_token_lists, output_repr=self.output_repr)
        """
        raise NotImplementedError("Implement MolForgeDecoder.predict_batch() to call the actual MolForge decoding.")


## 7) Minimal evaluation function (top‑1 precision, Tc=1.0)
<small>Computes the **precision** as the percentage of (GT, pred) pairs whose **evaluation fingerprint** (`Morgan sparse r=1`) matches exactly (**Tc=1.0**). It also returns how many pairs were invalid (not parseable).</small>

In [None]:

def evaluate_precision_tc1(
    smiles_list: list[str],
    input_fp: str,
    model_name: str,
    checkpoint: str,
    output_repr: str = "SMILES",
    device: str = "cpu",
    hashed_nbits: int = 2048,
    tokenizer_vocab_json: Optional[str] = None,
) -> dict:
    """Compute top-1 exactness (Tc=1.0) using Morgan sparse r=1 as evaluation FP.

    Return a small dict with precision and simple counts.
    """
    # 1) Parse GT mols and build input fingerprints (tokens) to FEED MolForge
    gt_mols = [to_mol(s) for s in smiles_list]
    input_tokens = []
    for m in gt_mols:
        if m is None:
            input_tokens.append([])
            continue
        fp = compute_input_fp(m, input_fp, nBits=hashed_nbits)
        toks = fp_to_tokens(fp)
        input_tokens.append(toks)

    # 2) Decode with MolForge
    decoder = MolForgeDecoder(
        model_name=model_name,
        checkpoint=checkpoint,
        device=device,
        tokenizer_vocab_json=tokenizer_vocab_json,
        output_repr=output_repr,
    )
    decoder.load()  # <-- implement inside the class
    preds = decoder.predict_batch(input_tokens)  # <-- implement inside the class

    # 3) Evaluate Tc=1.0 (Morgan sparse r=1) and invalids
    exact = 0
    invalid = 0
    total = len(smiles_list)
    if len(preds) != total:
        # zip() would silently drop unmatched items and skew the counts
        raise ValueError(f"Expected {total} predictions, got {len(preds)}")

    for gt_m, pr in zip(gt_mols, preds):  # reuse the GT mols parsed in step 1
        # Convert SELFIES -> SMILES if needed for evaluation
        if output_repr.upper() == "SELFIES":
            pr = selfies_to_smiles(pr)

        pr_m = to_mol(pr)
        if gt_m is None or pr_m is None:
            invalid += 1  # unparseable GT or prediction; still counted in total
            continue

        fp_gt = eval_fp_morgan_sparse_r1(gt_m)
        fp_pr = eval_fp_morgan_sparse_r1(pr_m)
        tc = tanimoto(fp_gt, fp_pr)
        if tc >= 1.0 - 1e-12:
            exact += 1

    precision = exact / total if total else float("nan")
    return {
        "n": total,
        "precision_tc1": precision,  # e.g., ~0.93 means 93%
        "num_exact": exact,
        "num_invalid": invalid,
    }
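
Because the denominator is the full input count, invalid pairs count as misses rather than being excluded. The following sketch works through that arithmetic on hypothetical counts (they are not results from any model) and contrasts it with the rate over valid pairs only. `summarize_counts` and the `precision_valid_only` key are names introduced for this illustration.

```python
def summarize_counts(num_exact: int, num_invalid: int, total: int) -> dict:
    """Relate the counts returned by the evaluation to two possible rates."""
    valid = total - num_invalid
    return {
        # What the notebook reports: invalid pairs stay in the denominator
        "precision_tc1": num_exact / total if total else float("nan"),
        # Alternative view: exactness among parseable pairs only
        "precision_valid_only": num_exact / valid if valid else float("nan"),
    }

# Hypothetical run: 5 inputs, 3 exact matches, 1 invalid prediction
summary = summarize_counts(num_exact=3, num_invalid=1, total=5)
print(summary["precision_tc1"])         # 0.6
print(summary["precision_valid_only"])  # 0.75
```

Reporting `num_invalid` next to the precision, as the function above does, lets a reader tell a model that decodes poorly from one that decodes valid but wrong structures.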


## 8) Run the evaluation on the test SMILES
<small>**IMPORTANT**: This cell requires that you have implemented `MolForgeDecoder.load()` and `predict_batch()`, and that `MODEL_NAME`/`MOLFORGE_CHECKPOINT` point to your real model.</small>

In [None]:

results = evaluate_precision_tc1(
    smiles_list=SMILES_LIST,
    input_fp=INPUT_FP,
    model_name=MODEL_NAME,
    checkpoint=MOLFORGE_CHECKPOINT,
    output_repr=OUTPUT_REPR,
    device=device,
    hashed_nbits=HASHED_NBITS,
    tokenizer_vocab_json=None,
)
print("n               :", results["n"])
print("precision_tc1   :", results["precision_tc1"])  # fraction; multiply by 100 for %
print("num_exact       :", results["num_exact"])
print("num_invalid     :", results["num_invalid"])


## 9) How to connect MolForge to the adapter (quick summary)
<small>
1. Install the MolForge repo and make sure you can import its modules (e.g., `import molforge`).
2. In `MolForgeDecoder.load()`, load the **model** and the **tokenizer** using your checkpoint (`MOLFORGE_CHECKPOINT`) and `MODEL_NAME`.
3. In `MolForgeDecoder.predict_batch()`, convert `list_of_token_lists` to the type the model expects and call its **decoding/generation** method (e.g., `model.generate(...)`). Return a list of strings (SMILES or SELFIES) in the same order.
4. Check that your **input tokenization** matches the one used during training (for hashed fingerprints: indices of active bits; for sparse fingerprints: explicit fragment IDs with multiplicity).
5. Run the cell in section 8 to obtain the **precision** (Tc=1.0) of the pre-trained model on your SMILES.
</small>
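
Converting `list_of_token_lists` to what the model expects (step 3) usually involves turning ragged sequences into a rectangular batch. The sketch below shows one generic way to do that in plain Python, with padding plus an attention-style mask. It is an assumption about a typical Transformer input, not MolForge's actual API: `pad_batch` and `pad_id` are hypothetical, and a real checkpoint will most likely need the tokens mapped through its own vocabulary first.

```python
from typing import List, Sequence, Tuple

def pad_batch(
    seqs: Sequence[Sequence[int]], pad_id: int = -1
) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad token sequences to a common length and build a 0/1 mask."""
    max_len = max((len(s) for s in seqs), default=0)
    padded, mask = [], []
    for s in seqs:
        n_pad = max_len - len(s)
        padded.append(list(s) + [pad_id] * n_pad)
        mask.append([1] * len(s) + [0] * n_pad)
    return padded, mask

# pad_id=-1 is chosen because raw bit indices start at 0, so 0 is a real token
padded, mask = pad_batch([[3, 17, 42], [5], []])
print(padded)  # [[3, 17, 42], [5, -1, -1], [-1, -1, -1]]
print(mask)    # [[1, 1, 1], [1, 0, 0], [0, 0, 0]]
```

The empty third sequence corresponds to an input SMILES that failed to parse, which the evaluation function passes through as an empty token list.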