# MolForge Evaluation Playground (RDKit + MolForge)

Modular, executable notebook to
- **take SMILES** (e.g. from CoCoGraph),
- **convert them to fingerprints** with RDKit,
- **decode them with MolForge** (fingerprint → SMILES/SELFIES), and
- **evaluate** the results with the **metrics** used in the MolForge paper.

**Code comments and explanations are in English** for clarity.

---
### Environment assumptions
- **Environment created via `environment.yml`**.
- **RDKit and MolForge** are already available (installed via `environment.yml` + pip from GitHub).
- No need to clone MolForge: it has been installed directly from GitHub.
- (Optional) **SELFIES** is already included in `environment.yml`.

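This notebook includes an optional **demo mode** (`DEMO_MODE = True`) that simulates predictions by echoing the canonical ground-truth SMILES, so you can run it end to end and inspect the metrics.
When MolForge is ready, set `DEMO_MODE = False` in the Input section, fill in `MOLFORGE_CHECKPOINT`/`MODEL_NAME`, and implement `MolForgeDecoder.load()`/`predict_batch()` (section 6).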

## 0) Device: automatic CPU/GPU detection
<small>The environment is already set up with `environment.yml`. Here we detect the **device** automatically. You can force it with `preferred="cpu"` or with the `DEVICE` environment variable.</small>

In [None]:

# Device detection (import from src if available; otherwise define here)
try:
    from src.utils_device import pick_device  # prefer project utility if present
except Exception:
    def pick_device(preferred: str | None = None) -> str:
        """Return 'cuda', 'mps' (Apple), or 'cpu'. Supports override with preferred."""
        if preferred in {"cpu", "cuda", "mps"}:
            return preferred
        try:
            import torch
        except Exception:
            return "cpu"
        try:
            if torch.cuda.is_available():
                return "cuda"
        except Exception:
            pass
        try:
            # torch is already imported above (otherwise we would have returned "cpu")
            if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                return "mps"
        except Exception:
            pass
        return "cpu"

import os
device = pick_device(os.getenv("DEVICE"))  # set DEVICE=cpu|cuda|mps to override
print("Using device:", device)
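The fallback `pick_device` above only honors the exact lowercase strings `cpu`, `cuda` and `mps`. Any other `DEVICE` value (for example `CUDA` or `gpu`) is silently ignored, and auto-detection runs instead. The following is a minimal sketch of a normalizer you could put in front of it. `normalize_device` is a hypothetical helper and is not part of the project or of PyTorch:

```python
from typing import Optional

VALID_DEVICES = {"cpu", "cuda", "mps"}

def normalize_device(value: Optional[str]) -> Optional[str]:
    """Return a lowercase device name if recognized, else None (so auto-detection runs)."""
    if value is None:
        return None
    cleaned = value.strip().lower()
    return cleaned if cleaned in VALID_DEVICES else None

print(normalize_device(" CUDA "))  # cuda
print(normalize_device("gpu"))     # None
```

To use it, assuming the cell above has run: `device = pick_device(normalize_device(os.getenv("DEVICE")))`. Unrecognized values still fall back to auto-detection. A stricter alternative would be to raise an error instead.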


## About metric 'bias' in evaluation
You want to **evaluate an already-trained model** on arbitrary molecules, and that is fine. What we call 'metric bias' here **has nothing to do** with how the model was trained. It concerns **how we measure similarity**.
- Different fingerprints have **different Tc scales**. If you evaluate with **only one** fingerprint, the number can look higher or lower simply because of that fingerprint's scale.
- That is why the paper reports **multi-fingerprint** results and a **CDF threshold at p=0.01**: to make readings more comparable across representations.
- **It is not mandatory**. This notebook reports **all the metrics**, and you can switch multi-fingerprint evaluation on or off as needed.
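To make the scale argument concrete, the sketch below computes the empirical CDF at the same raw Tc for two fingerprints. The Tc lists are made-up toy values, not MolForge results, and `empirical_cdf` is a hypothetical helper. The notebook's own `cdf_threshold` works in the opposite direction: it returns the Tc value at a given CDF level p.

```python
from typing import Dict, List

def empirical_cdf(values: List[float], x: float) -> float:
    """Fraction of values that are <= x."""
    return sum(1 for v in values if v <= x) / len(values)

# Toy Tc values for two hypothetical fingerprints with different scales
toy_tc: Dict[str, List[float]] = {
    "FP_A": [0.20, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60, 0.70, 0.90],
    "FP_B": [0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95, 1.00],
}

for name, vals in toy_tc.items():
    # Same raw Tc = 0.6, very different position in each distribution
    print(name, empirical_cdf(vals, 0.6))
# FP_A 0.8
# FP_B 0.2
```

In this toy case, a Tc of 0.6 sits at the 80th percentile of FP_A but only the 20th of FP_B. This is why a single-fingerprint number is hard to interpret on its own.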

## 1) Input: parameters and test SMILES
Edit this cell to choose fingerprints, output representation, checkpoints, etc.

In [None]:

# ==============================
# INPUT SECTION (edit here)
# ==============================

# --- Dummy SMILES (real molecules) to let you run the notebook immediately ---
SMILES_LIST = [
    "CCO",                       # ethanol
    "c1ccccc1",                  # benzene
    "CC(=O)OC1=CC=CC=C1C(=O)O",  # aspirin
    "CN1C=NC2=C1C(=O)N(C(=O)N2)C",  # caffeine
    "CC(=O)O",                   # acetic acid
]

# --- Fingerprint used as *input* to MolForge (choose one) ---
# Common: 'ECFP4', 'AEs', 'TT', 'HashAP', 'RDK4', 'MACCS', 'FCFP4', etc.
INPUT_FP = "ECFP4"

# --- Output representation from MolForge ---
# 'SMILES' or 'SELFIES' (SELFIES avoids invalid strings but you need selfies installed)
OUTPUT_REPR = "SMILES"

# --- Multi-fingerprint evaluation (recommended but optional) ---
# If True, we compute Tc with a panel of ~15 fingerprints; if False, we only use a single eval FP.
USE_MULTI_FP = True

# --- If not using multi-fp, pick a single eval fingerprint for Tc ---
SINGLE_EVAL_FP = "ECFP4"

# --- MolForge model config (fill when you have the real repo installed) ---
MODEL_NAME = "ecfp4_to_smiles"       # adapt to your MolForge model naming
MOLFORGE_CHECKPOINT = "/path/to/molforge_ckpt.pt"  # <-- put your real path here

# --- Device for MolForge: detected automatically in cell 0 (override with the DEVICE env variable) ---

# --- Tokenizer vocab for sparse modes (AEs/TT) if your MolForge needs it ---
TOKENIZER_VOCAB_JSON = None

# --- Bit length for hashed fingerprints (used for ECFP/FCFP/HashAP/HashTT/RDK4/Pattern/Layered) ---
HASHED_NBITS = 2048

# --- DEMO MODE: If True, no real MolForge is required; MolForgeDecoder echoes the canonical ground-truth SMILES.
# Set to False when you are ready to use the real MolForge model.
DEMO_MODE = False
# Note: the 'device' was detected in cell 0 and is passed to the decoder automatically.
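`TOKENIZER_VOCAB_JSON` matters for the sparse modes (AEs/TT) because their raw feature IDs are large hashed integers rather than bit positions in a fixed range. Below is a minimal sketch of mapping raw IDs to contiguous token IDs and saving the mapping as JSON. The reserved IDs, the helper names and the JSON layout are all hypothetical, and the real MolForge vocabulary format may differ:

```python
import json
from typing import Dict, Iterable, List

PAD, UNK = 0, 1  # hypothetical reserved token ids

def build_vocab(token_lists: Iterable[List[int]]) -> Dict[int, int]:
    """Map raw fingerprint feature ids to contiguous ids, starting after the reserved ones."""
    vocab: Dict[int, int] = {}
    for toks in token_lists:
        for raw in toks:
            if raw not in vocab:
                vocab[raw] = len(vocab) + 2  # 0 and 1 are reserved
    return vocab

def encode(toks: List[int], vocab: Dict[int, int]) -> List[int]:
    """Replace raw ids by vocab ids; ids never seen during build_vocab become UNK."""
    return [vocab.get(raw, UNK) for raw in toks]

train_tokens = [[98513984, 98513984, 2245384272], [864662311]]  # made-up raw ids
vocab = build_vocab(train_tokens)

# JSON object keys must be strings, so raw ids are stringified on save and converted back on load
saved = json.dumps({str(k): v for k, v in vocab.items()})
restored = {int(k): v for k, v in json.loads(saved).items()}
print(encode([98513984, 864662311, 42], restored))  # [2, 4, 1]
```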


## 2) Imports and environment checks
The environment has already been created with `environment.yml`. Here we only check whether **RDKit**, **Avalon** and **SELFIES** are available and report the result.

In [None]:

# Core imports
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import math

# RDKit
try:
    from rdkit import Chem
    from rdkit.Chem import AllChem, MACCSkeys, rdMolDescriptors as rdDesc
    from rdkit import DataStructs
    RDKit_OK = True
except Exception as e:
    RDKit_OK = False
    print("RDKit not available. Please install RDKit (conda recommended): conda install -c rdkit rdkit")

# Avalon (optional; only if you intend to use Avalon fingerprints)
try:
    from rdkit.Avalon import pyAvalonTools as avalon
    AVALON_OK = True
except Exception:
    AVALON_OK = False

# SELFIES (optional; only if OUTPUT_REPR='SELFIES')
try:
    import selfies as sf
    SELFIES_OK = True
except Exception:
    SELFIES_OK = False

print("RDKit_OK:", RDKit_OK, "| AVALON_OK:", AVALON_OK, "| SELFIES_OK:", SELFIES_OK)


## 3) Helpers: parsing, canonicalization and SELFIES conversion

In [None]:

# ------------------------------
# Helper functions (parsing, canonicalization, selfies)
# ------------------------------

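def to_mol(smiles: str):
    """Parse SMILES into a sanitized RDKit Mol; return None on failure or for empty input."""
    if not RDKit_OK or not smiles:
        return None
    try:
        mol = Chem.MolFromSmiles(smiles, sanitize=True)  # sanitizes; returns None on failure
    except Exception:
        return None
    # MolFromSmiles("") yields a Mol with zero atoms; treat it as invalid, not as a match
    if mol is None or mol.GetNumAtoms() == 0:
        return None
    return mol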

def canon_smiles(mol, isomeric: bool = True) -> str:
    """Return canonical SMILES (isomeric if requested)."""
    if not RDKit_OK or mol is None:
        return ""
    return Chem.MolToSmiles(mol, isomericSmiles=isomeric, canonical=True)

def selfies_to_smiles(s: str) -> str:
    """Convert SELFIES -> SMILES if selfies is installed; otherwise return input."""
    if not SELFIES_OK:
        return s
    try:
        return sf.decoder(s)
    except Exception:
        return s


## 4) RDKit fingerprints: definitions and code
Here we build all the fingerprints used both as **input** to the model and for **evaluation** (Tc).

In [None]:

from dataclasses import dataclass

@dataclass
class FPResult:
    kind: str           # name
    obj: Any            # RDKit ExplicitBitVect or a sparse count vector (UInt/Int/LongSparseIntVect)
    is_sparse: bool
    nbits: Optional[int] = None

def morgan_hashed(mol, radius: int, nBits: int, useFeatures: bool = False) -> FPResult:
    bv = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=nBits, useFeatures=useFeatures)
    kind = f"{'FCFP' if useFeatures else 'ECFP'}{2*radius}"
    return FPResult(kind=kind, obj=bv, is_sparse=False, nbits=nBits)

def morgan_sparse(mol, radius: int) -> FPResult:
    siv = AllChem.GetMorganFingerprint(mol, radius=radius)  # UIntSparseIntVect (counts)
    kind = "AEs" if radius == 1 else f"MorganSparse_r{radius}"
    return FPResult(kind=kind, obj=siv, is_sparse=True)

def tt_sparse(mol) -> FPResult:
    # Unhashed topological torsions as a sparse count vector (LongSparseIntVect)
    vec = rdDesc.GetTopologicalTorsionFingerprint(mol)
    return FPResult(kind="TT", obj=vec, is_sparse=True)

def tt_hashed(mol, nBits: int) -> FPResult:
    bv = rdDesc.GetHashedTopologicalTorsionFingerprintAsBitVect(mol, nBits=nBits)
    return FPResult(kind="HashTT", obj=bv, is_sparse=False, nbits=nBits)

def ap_sparse(mol) -> FPResult:
    vec = rdDesc.GetAtomPairFingerprint(mol)
    return FPResult(kind="AP", obj=vec, is_sparse=True)

def ap_hashed(mol, nBits: int) -> FPResult:
    bv = rdDesc.GetHashedAtomPairFingerprintAsBitVect(mol, nBits=nBits)
    return FPResult(kind="HashAP", obj=bv, is_sparse=False, nbits=nBits)

def rdk4(mol, branchedPaths: bool = True, nBits: int = 2048) -> FPResult:
    bv = Chem.RDKFingerprint(mol, fpSize=nBits, minPath=2, maxPath=4, branchedPaths=branchedPaths)
    return FPResult(kind="RDK4" if branchedPaths else "RDK4-L", obj=bv, is_sparse=False, nbits=nBits)

def avalon_fp(mol, nBits: int = 512) -> Optional[FPResult]:
    if not AVALON_OK:
        return None
    bv = avalon.GetAvalonFP(mol, nBits=nBits)
    return FPResult(kind="Avalon", obj=bv, is_sparse=False, nbits=nBits)

def maccs(mol) -> FPResult:
    bv = MACCSkeys.GenMACCSKeys(mol)  # 167 bits
    return FPResult(kind="MACCS", obj=bv, is_sparse=False, nbits=167)

def pattern_fp(mol, nBits: int = 2048) -> FPResult:
    bv = Chem.PatternFingerprint(mol, fpSize=nBits)
    return FPResult(kind="PatternFP", obj=bv, is_sparse=False, nbits=nBits)

def layered_fp(mol, nBits: int = 2048) -> FPResult:
    bv = Chem.LayeredFingerprint(mol, fpSize=nBits)
    return FPResult(kind="LayeredFP", obj=bv, is_sparse=False, nbits=nBits)

FP_ALIASES = {
    "ECFP0": "ECFP0", "ECFP2": "ECFP2", "ECFP4": "ECFP4",
    "FCFP2": "FCFP2", "FCFP4": "FCFP4",
    "AEs": "AEs",
    "TT": "TT", "HashTT": "HashTT",
    "AP": "AP", "HashAP": "HashAP",
    "RDK4": "RDK4", "RDK4-L": "RDK4-L",
    "Avalon": "Avalon",
    "MACCS": "MACCS",
    "PatternFP": "PatternFP",
    "LayeredFP": "LayeredFP",
}

DEFAULT_EVAL_FPS = [
    "ECFP0", "ECFP2", "ECFP4", "FCFP2", "FCFP4",
    "AEs", "TT", "HashTT", "AP", "HashAP",
    "RDK4", "RDK4-L", "Avalon", "MACCS", "PatternFP"
]

def compute_fp(mol, fp_name: str, nBits: int = 2048) -> FPResult:
    name = FP_ALIASES.get(fp_name, fp_name)
    if name == "ECFP0":
        return morgan_hashed(mol, radius=0, nBits=nBits, useFeatures=False)
    if name == "ECFP2":
        return morgan_hashed(mol, radius=1, nBits=nBits, useFeatures=False)
    if name == "ECFP4":
        return morgan_hashed(mol, radius=2, nBits=nBits, useFeatures=False)
    if name == "FCFP2":
        return morgan_hashed(mol, radius=1, nBits=nBits, useFeatures=True)
    if name == "FCFP4":
        return morgan_hashed(mol, radius=2, nBits=nBits, useFeatures=True)
    if name == "AEs":
        return morgan_sparse(mol, radius=1)
    if name == "TT":
        return tt_sparse(mol)
    if name == "HashTT":
        return tt_hashed(mol, nBits=nBits)
    if name == "AP":
        return ap_sparse(mol)
    if name == "HashAP":
        return ap_hashed(mol, nBits=nBits)
    if name == "RDK4":
        return rdk4(mol, branchedPaths=True, nBits=nBits)
    if name == "RDK4-L":
        return rdk4(mol, branchedPaths=False, nBits=nBits)
    if name == "Avalon":
        res = avalon_fp(mol, nBits=512)
        if res is None:
            raise RuntimeError("Avalon requested but Avalon toolkit is not available.")
        return res
    if name == "MACCS":
        return maccs(mol)
    if name == "PatternFP":
        return pattern_fp(mol, nBits=nBits)
    if name == "LayeredFP":
        return layered_fp(mol, nBits=nBits)
    raise ValueError(f"Unknown fingerprint: {fp_name}")

def fp_to_tokens(fp: FPResult) -> List[int]:
    """Map an RDKit fingerprint to token IDs for a Transformer.

    - hashed bit vectors: return active bit indices
    - sparse vectors: return explicit keys, repeated by count (multiset)
    """
    if fp.is_sparse:
        elems = fp.obj.GetNonzeroElements()  # dict[id] = count
        toks = []
        for k, c in elems.items():
            toks.extend([int(k)] * int(c))
        toks.sort()
        return toks
    else:
        on_bits = list(fp.obj.GetOnBits())
        return sorted(on_bits)
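The two branches of `fp_to_tokens` can be mirrored on plain Python data, which makes the token format easy to check without RDKit. In this sketch, the hypothetical helpers take either a set of on-bits or a `{feature_id: count}` dict (the shape returned by `GetNonzeroElements()`). The feature IDs are made up.

```python
from typing import Dict, Iterable, List

def bits_to_tokens(on_bits: Iterable[int]) -> List[int]:
    """Hashed bit vector -> sorted indices of the set bits."""
    return sorted(int(b) for b in on_bits)

def counts_to_tokens(counts: Dict[int, int]) -> List[int]:
    """Sparse count vector -> sorted multiset of feature IDs (each repeated `count` times)."""
    toks: List[int] = []
    for feature_id, count in counts.items():
        toks.extend([int(feature_id)] * int(count))
    return sorted(toks)

print(bits_to_tokens({1024, 7, 311}))
# [7, 311, 1024]
print(counts_to_tokens({98513984: 2, 2245384272: 1, 864662311: 3}))
# [98513984, 98513984, 864662311, 864662311, 864662311, 2245384272]
```

Because both outputs are sorted, the token sequence does not depend on RDKit's internal iteration order. Count information survives only in the sparse branch.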


## 5) Metrics: Tc, breakdown and statistics

In [None]:

# ------------------------------
# Similarity & metrics
# ------------------------------

def tanimoto(fp_a: FPResult, fp_b: FPResult) -> float:
    return DataStructs.TanimotoSimilarity(fp_a.obj, fp_b.obj)

def percentile(vals, q: float):
    if not vals:
        return float("nan")
    v = sorted(vals)
    k = min(max(int(round((q/100.0)*(len(v)-1))), 0), len(v)-1)
    return v[k]

def cdf_threshold(vals, p: float = 0.01):
    if not vals:
        return float("nan")
    v = sorted(vals)
    idx = max(int(math.floor(p * (len(v)-1))), 0)
    return v[idx]

from dataclasses import dataclass

@dataclass
class BreakdownCounts:
    string_exact: int = 0
    stereo_only: int = 0
    no_canonical: int = 0
    invalid: int = 0
    other_mismatch: int = 0

    def to_dict(self):
        return {
            "string_exact": self.string_exact,
            "stereo_only": self.stereo_only,
            "no_canonical": self.no_canonical,
            "invalid": self.invalid,
            "other_mismatch": self.other_mismatch,
        }

def classify_pair(gt_smiles: str, pred_str: str, output_repr: str = "SMILES"):
    """Classify prediction outcome and return (category, canon_gt_iso, canon_pred_iso)."""
    gt_m = to_mol(gt_smiles)
    gt_iso = canon_smiles(gt_m, isomeric=True)

    if output_repr.upper() == "SELFIES":
        pred_smiles = selfies_to_smiles(pred_str)
    else:
        pred_smiles = pred_str

    pred_m = to_mol(pred_smiles)
    if pred_m is None:
        return "invalid", gt_iso, ""

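    pred_iso = canon_smiles(pred_m, isomeric=True)

    # Unparsable ground truth: nothing to compare against (the Morgan check below would fail on None)
    if gt_m is None:
        return "other_mismatch", gt_iso, pred_iso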

    # String-exact (isomeric)
    if pred_iso == gt_iso:
        return "string_exact", gt_iso, pred_iso

    # Stereo-only (non-isomeric match)
    gt_noniso = canon_smiles(gt_m, isomeric=False)
    pred_noniso = canon_smiles(pred_m, isomeric=False)
    if gt_noniso == pred_noniso:
        return "stereo_only", gt_iso, pred_iso

    # No-canonical (Tc==1.0 wrt Morgan sparse r=1 but strings differ)
    fp_gt = morgan_sparse(gt_m, radius=1)
    fp_pd = morgan_sparse(pred_m, radius=1)
    if tanimoto(fp_gt, fp_pd) >= 1.0 - 1e-12:
        return "no_canonical", gt_iso, pred_iso

    return "other_mismatch", gt_iso, pred_iso
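`tanimoto` delegates to RDKit, and the result depends on the vector type. For bit vectors it is |A ∩ B| / |A ∪ B|. For count vectors, such as the `GetMorganFingerprint` output used in `classify_pair`, RDKit (to our understanding) uses Σmin / (Σa + Σb − Σmin), which equals Σmin / Σmax.

Below is a pure-Python sketch of both. The helpers are hypothetical, and returning 0.0 for two empty inputs is our own convention, which may not match RDKit's:

```python
from typing import Dict, Set

def tanimoto_bits(a: Set[int], b: Set[int]) -> float:
    """Tanimoto on sets of on-bits: len(a & b) / len(a | b)."""
    union = len(a | b)
    return len(a & b) / union if union else 0.0

def tanimoto_counts(a: Dict[int, int], b: Dict[int, int]) -> float:
    """Count-based Tanimoto: sum of element-wise minima over sum of element-wise maxima."""
    keys = set(a) | set(b)
    num = sum(min(a.get(k, 0), b.get(k, 0)) for k in keys)
    den = sum(max(a.get(k, 0), b.get(k, 0)) for k in keys)
    return num / den if den else 0.0

print(tanimoto_bits({1, 2, 3, 4}, {3, 4, 5}))           # 0.4
# Same features, different counts: the binarized Tc is 1.0, the count Tc is not
print(tanimoto_bits({10, 20}, {10, 20}))                # 1.0
print(tanimoto_counts({10: 3, 20: 1}, {10: 1, 20: 1}))  # 0.5
```

This is why the `no_canonical` check uses a count vector: a Tc of 1.0 there requires identical feature counts, not just the same set of features.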


## 6) MolForge Decoder: adapter
This class wraps model loading and decoding.
**Assumption**: MolForge is already installed and importable (`import molforge`).
Once your repo/config is ready, implement `load()` and `predict_batch()` according to the real API, and make sure the model is **moved to the detected `device`** (`.to(self.device)`).

In [None]:

class MolForgeDecoder:
    def __init__(self, model_name: str, checkpoint: Optional[str], device: str = "cpu",
                 tokenizer_vocab_json: Optional[str] = None, output_repr: str = "SMILES",
                 demo_mode: bool = False):
        self.model_name = model_name
        self.checkpoint = checkpoint
        self.device = device
        self.tokenizer_vocab_json = tokenizer_vocab_json
        self.output_repr = output_repr
        self.demo_mode = demo_mode
        self._model = None
        self._tokenizer = None

    def load(self):
        if self.demo_mode:
            self._model = object()  # mark as 'loaded'
            return
        # TODO: replace with real MolForge API
        try:
            import importlib
            importlib.import_module("molforge")
            # Example pseudo-code:
            # from molforge.models import load_model, load_tokenizer
            # self._model = load_model(self.model_name, self.checkpoint, device=self.device)
            # self._tokenizer = load_tokenizer(self.model_name, vocab_json=self.tokenizer_vocab_json)
            self._model = object()
        except Exception as e:
            raise RuntimeError("MolForge not found. Install the repo and adapt MolForgeDecoder.load().") from e

    def predict_batch(self, list_of_token_lists: List[List[int]], gt_smiles: Optional[List[str]] = None) -> List[str]:
        if self._model is None:
            raise RuntimeError("Decoder not loaded. Call load().")
        if self.demo_mode:
            # In demo: echo canonical SMILES of the GT (if provided), otherwise return empty strings
            out = []
            if gt_smiles is None:
                return ["" for _ in list_of_token_lists]
            for s in gt_smiles:
                m = to_mol(s)
                out.append(canon_smiles(m, isomeric=True))
            return out
        # TODO: replace with real generation call
        raise NotImplementedError("Implement MolForge decoding (model.generate/decode) for your repo setup.")
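`predict_batch` receives one variable-length token list per molecule, including empty lists for unparsable ground-truth SMILES. Transformer decoders usually expect rectangular, padded input plus an attention mask. The real MolForge API may handle this internally, so the sketch below is a hypothetical preprocessing step, not MolForge code. `pad_batch` and `PAD_ID = 0` are assumptions. The sketch also assumes token IDs have already been shifted so that 0 is free, because raw hashed bit index 0 is a valid bit.

```python
from typing import List, Optional, Tuple

PAD_ID = 0  # hypothetical padding id; use the one defined by your tokenizer

def pad_batch(
    batch: List[List[int]], pad_id: int = PAD_ID, max_len: Optional[int] = None
) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad (and optionally truncate) token lists; return (ids, attention_mask)."""
    longest = max((len(toks) for toks in batch), default=0)
    width = longest if max_len is None else min(longest, max_len)
    ids, mask = [], []
    for toks in batch:
        kept = toks[:width]
        n_pad = width - len(kept)
        ids.append(kept + [pad_id] * n_pad)
        mask.append([1] * len(kept) + [0] * n_pad)  # 1 = real token, 0 = padding
    return ids, mask

ids, mask = pad_batch([[5, 6, 7], [8], []])
print(ids)   # [[5, 6, 7], [8, 0, 0], [0, 0, 0]]
print(mask)  # [[1, 1, 1], [1, 0, 0], [0, 0, 0]]
```

Inside a real `predict_batch`, these lists would then become tensors on the model's device, e.g. `torch.tensor(ids, device=self.device)`. Whether the model accepts an all-padding row (an empty fingerprint) depends on the model.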


## 7) Main function: end-to-end evaluation

In [None]:

def evaluate_molforge(
    smiles_list: List[str],
    input_fp: str = "ECFP4",
    output_repr: str = "SMILES",
    molforge_model_name: str = "ecfp4_to_smiles",
    molforge_checkpoint: Optional[str] = None,
    eval_fps: Optional[List[str]] = None,
    device: str = "cpu",
    tokenizer_vocab_json: Optional[str] = None,
    hashed_nbits: int = 2048,
    return_predictions: bool = True,
    use_multi_fp: bool = True,
    single_eval_fp: str = "ECFP4",
    demo_mode: bool = False,
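):
    """End-to-end evaluation mirroring MolForge paper metrics.

    Returns a dict with breakdown counts, per-FP stats, and averages.
    """
    if eval_fps is None:
        eval_fps = list(DEFAULT_EVAL_FPS) if use_multi_fp else [single_eval_fp]
    if not AVALON_OK and "Avalon" in eval_fps:
        # Skip Avalon instead of crashing when the Avalon toolkit is missing
        print("Avalon toolkit not available: skipping the 'Avalon' eval fingerprint.")
        eval_fps = [fp for fp in eval_fps if fp != "Avalon"]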

    # 1) Parse GT molecules and build *input* fingerprints -> tokens
    gt_mols = [to_mol(s) for s in smiles_list]
    input_tokens_batch: List[List[int]] = []
    for m in gt_mols:
        if m is None:
            input_tokens_batch.append([])  # keep alignment
            continue
        fp = compute_fp(m, input_fp, nBits=hashed_nbits)
        toks = fp_to_tokens(fp)
        input_tokens_batch.append(toks)

    # 2) Decode with MolForge
    decoder = MolForgeDecoder(
        model_name=molforge_model_name,
        checkpoint=molforge_checkpoint,
        device=device,
        tokenizer_vocab_json=tokenizer_vocab_json,
        output_repr=output_repr,
        demo_mode=demo_mode,
    )
    decoder.load()
    predictions = decoder.predict_batch(input_tokens_batch, gt_smiles=smiles_list)

    # 3) Breakdown classification
    breakdown = BreakdownCounts()
    validity_mask: List[bool] = []
    for gt, pred in zip(smiles_list, predictions):
        cat, _, canon_pred = classify_pair(gt, pred, output_repr=output_repr)
        if cat == "string_exact":
            breakdown.string_exact += 1; validity_mask.append(True)
        elif cat == "stereo_only":
            breakdown.stereo_only += 1; validity_mask.append(True)
        elif cat == "no_canonical":
            breakdown.no_canonical += 1; validity_mask.append(True)
        elif cat == "invalid":
            breakdown.invalid += 1; validity_mask.append(False)
        else:
            breakdown.other_mismatch += 1; validity_mask.append(bool(canon_pred))

    total_valid = sum(1 for v in validity_mask if v)

    # 4) Multi-fingerprint Tc computation
    per_fp_tc = {fpn: [] for fpn in eval_fps}
    top1_counts = {fpn: 0 for fpn in eval_fps}

    for (gt, pred, is_valid) in zip(smiles_list, predictions, validity_mask):
        if not is_valid:
            continue
        gt_m = to_mol(gt)
        pred_m = to_mol(pred if output_repr.upper() == "SMILES" else selfies_to_smiles(pred))
        if gt_m is None or pred_m is None:
            continue
        for fp_name in eval_fps:
            gt_fp = compute_fp(gt_m, fp_name, nBits=hashed_nbits)
            pd_fp = compute_fp(pred_m, fp_name, nBits=hashed_nbits)
            tc = tanimoto(gt_fp, pd_fp)
            per_fp_tc[fp_name].append(float(tc))
            if tc >= 1.0 - 1e-12:
                top1_counts[fp_name] += 1

    # 5) Aggregate stats per FP
    eval_stats = {}
    for fp_name, tcs in per_fp_tc.items():
        if tcs:
            eval_stats[fp_name] = {
                "top1": top1_counts[fp_name] / max(total_valid, 1.0),
                "mean_tc": float(sum(tcs) / len(tcs)),
                "p01_threshold": float(cdf_threshold(tcs, p=0.01)),
                "p50": float(percentile(tcs, 50)),
                "p90": float(percentile(tcs, 90)),
                "p99": float(percentile(tcs, 99)),
                "n_pairs": int(len(tcs)),
            }
        else:
            eval_stats[fp_name] = {
                "top1": float('nan'), "mean_tc": float('nan'),
                "p01_threshold": float('nan'), "p50": float('nan'),
                "p90": float('nan'), "p99": float('nan'), "n_pairs": 0
            }

    # 6) Mean across eval FPs
    valid_fps = [fp for fp, st in eval_stats.items() if not math.isnan(st["top1"])]
    top1_mean = float(sum(eval_stats[fp]["top1"] for fp in valid_fps) / max(len(valid_fps), 1))
    mean_tc_mean = float(sum(eval_stats[fp]["mean_tc"] for fp in valid_fps) / max(len(valid_fps), 1))

    out = {
        "n": len(smiles_list),
        "breakdown": breakdown.to_dict(),
        "eval_fps": eval_stats,
        "eval_fps_average": {
            "top1_mean_over_fps": top1_mean,
            "mean_tc_mean_over_fps": mean_tc_mean,
            "n_eval_fps": len(valid_fps),
        },
    }
    if return_predictions:
        out["predictions"] = predictions
    return out


## 8) Run: quick test with the dummy SMILES

In [None]:

# Execute evaluation with current INPUTS
if not RDKit_OK:
    raise RuntimeError("RDKit not available. Please install RDKit to run this notebook.")

eval_fps = None  # use default ~15 if USE_MULTI_FP, else SINGLE_EVAL_FP
if not USE_MULTI_FP:
    eval_fps = [SINGLE_EVAL_FP]

results = evaluate_molforge(
    smiles_list=SMILES_LIST,
    input_fp=INPUT_FP,
    output_repr=OUTPUT_REPR,
    molforge_model_name=MODEL_NAME,
    molforge_checkpoint=MOLFORGE_CHECKPOINT,
    eval_fps=eval_fps,
    device=device,
    tokenizer_vocab_json=TOKENIZER_VOCAB_JSON,
    hashed_nbits=HASHED_NBITS,
    return_predictions=True,
    use_multi_fp=USE_MULTI_FP,
    single_eval_fp=SINGLE_EVAL_FP,
    demo_mode=DEMO_MODE,
)

print("# Inputs:")
print("n =", results["n"]) 
print("input_fp =", INPUT_FP, "| output_repr =", OUTPUT_REPR)
print("demo_mode =", DEMO_MODE)
print("\n# Breakdown:")
for k, v in results["breakdown"].items():
    print(f"{k:14s} : {v}")
print("\n# Eval FPS Average:")
for k, v in results["eval_fps_average"].items():
    print(f"{k:24s}: {v}")
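To keep a run for later comparison, `results` can be written to JSON. However, fingerprints without valid pairs carry `float('nan')`, and `json.dumps` writes that as `NaN`, which is not valid strict JSON. The sketch below converts NaN to `null` first; `nan_to_none` is a hypothetical helper:

```python
import json
import math
from typing import Any

def nan_to_none(obj: Any) -> Any:
    """Recursively replace float NaN with None so the output is strict JSON."""
    if isinstance(obj, float) and math.isnan(obj):
        return None
    if isinstance(obj, dict):
        return {k: nan_to_none(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [nan_to_none(v) for v in obj]
    return obj

example = {"n": 2, "eval_fps": {"ECFP4": {"top1": float("nan"), "n_pairs": 0}}}
text = json.dumps(nan_to_none(example), allow_nan=False)  # allow_nan=False: fail loudly on leftovers
print(text)  # {"n": 2, "eval_fps": {"ECFP4": {"top1": null, "n_pairs": 0}}}
```

For example: `with open("molforge_results.json", "w") as f: json.dump(nan_to_none(results), f, indent=2)`. The filename is only an illustration.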


## 9) Metrics table per evaluation fingerprint

In [None]:

import pandas as pd
from IPython.display import display

df = pd.DataFrame(results["eval_fps"]).T
df = df[["n_pairs", "top1", "mean_tc", "p01_threshold", "p50", "p90", "p99"]]
df["n_pairs"] = df["n_pairs"].astype(int)  # transposing mixed columns turns counts into floats
df = df.sort_values(by=["top1", "mean_tc"], ascending=False)

display(df)  # render the full per-FP table in Jupyter/IPython


## 10) Optional plot: Tc histogram for one evaluation fingerprint
Pick a fingerprint from `results['eval_fps']` to visualize its Tc distribution.

In [None]:

import matplotlib.pyplot as plt

# Choose an FP present in results to plot distribution
fp_to_plot = SINGLE_EVAL_FP if not USE_MULTI_FP else ("ECFP4" if "ECFP4" in results["eval_fps"] else list(results["eval_fps"].keys())[0])

# Recompute per-pair Tc list for the chosen FP to plot
tcs = []

for gt, pred in zip(SMILES_LIST, results["predictions"]):
    gt_m = to_mol(gt); pred_m = to_mol(pred if OUTPUT_REPR.upper()=="SMILES" else selfies_to_smiles(pred))
    if gt_m is None or pred_m is None:
        continue
    tc = tanimoto(compute_fp(gt_m, fp_to_plot, nBits=HASHED_NBITS), compute_fp(pred_m, fp_to_plot, nBits=HASHED_NBITS))
    tcs.append(tc)

plt.figure()
plt.hist(tcs, bins=10)
plt.title(f"Tc distribution for {fp_to_plot}")
plt.xlabel("Tc (Tanimoto)"); plt.ylabel("Count")
plt.show()
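`plt.hist(tcs, bins=10)` derives the bin edges from the data range. As a result, histograms of different fingerprints do not share an x-axis, and in demo mode (every Tc equal to 1.0) the automatic range is built around that single value. Passing explicit edges fixes the axis to [0, 1], e.g. `plt.hist(tcs, bins=[i / 10 for i in range(11)])`.

The pure-Python counter below is a hypothetical helper that shows which bin each Tc lands in. It follows NumPy's convention of half-open bins with a closed last bin, up to floating-point rounding at the edges:

```python
from typing import List

def fixed_bin_counts(values: List[float], n_bins: int = 10) -> List[int]:
    """Count values in n_bins equal-width bins over [0, 1]; the last bin includes 1.0."""
    counts = [0] * n_bins
    for v in values:
        if not 0.0 <= v <= 1.0:
            continue  # Tc should lie in [0, 1]; skip anything else
        idx = min(int(v * n_bins), n_bins - 1)  # 1.0 goes into the last bin
        counts[idx] += 1
    return counts

print(fixed_bin_counts([1.0, 1.0, 0.05, 0.55]))  # [1, 0, 0, 0, 0, 1, 0, 0, 0, 2]
```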


## 11) Quick usage notes
- RDKit and MolForge are already installed via the environment.
- Use the detected `device` to move the model to GPU/CPU.
- There is no need to repeat installation instructions here.