# Inference demo for QUARC

## Overview

QUARC is a multi-stage condition recommendation framework that predicts:
1. **Stage 1**: Agents
2. **Stage 2**: Temperature
3. **Stage 3**: Reactant amounts (equivalence ratio)
4. **Stage 4**: Agent amounts (equivalence ratio)

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from chemprop.featurizers import CondensedGraphOfReactionFeaturizer

import warnings
warnings.filterwarnings("ignore")

from quarc.models.modules.agent_encoder import AgentEncoder
from quarc.models.modules.agent_standardizer import AgentStandardizer
from quarc.models.modules.rxn_encoder import ReactionClassEncoder
from quarc.data.datapoints import AgentRecord, ReactionDatum
from quarc.data.eval_datasets import EvaluationDatasetFactory
from quarc.data.binning import BinningConfig
from quarc.utils.smiles_utils import parse_rxn_smiles

from quarc.predictors.model_factory import load_models_from_yaml
from quarc.predictors.multistage_predictor import EnumeratedPredictor
from quarc.predictors.base import PredictionList

from quarc.settings import load as load_settings

cfg = load_settings()

In [3]:
# Load supporting components
a_enc = AgentEncoder(class_path=cfg.processed_data_dir / "agent_encoder/agent_encoder_list.json")
a_standardizer = AgentStandardizer(
    conv_rules=cfg.processed_data_dir / "agent_encoder/agent_rules_v1.json",
    other_dict=cfg.processed_data_dir / "agent_encoder/agent_other_dict.json",
)
rxn_encoder = ReactionClassEncoder(
    class_path=cfg.pistachio_namerxn_path
)
featurizer = CondensedGraphOfReactionFeaturizer(mode_="REAC_DIFF")

## Prepare input data

The full version of QUARC (the paper version) requires:
- **Reaction SMILES**: The GNN models require atom-mapped reaction SMILES in the form `reactants>agents>products`. If atom mapping is unavailable, use the FFN models instead.
- **NameRxn code**: A hierarchical reaction classification (e.g., "3.1.1" for Bromo Suzuki coupling).
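Because the GNN and FFN pipelines have different input requirements, it can help to check each reaction before choosing a pipeline. The helpers below are a hypothetical sketch, not part of QUARC (the notebook itself uses `parse_rxn_smiles` for splitting). They assume plain reaction SMILES without CXSMILES extensions, split on `>` and `.`, and detect atom mapping by looking for map numbers such as `[CH3:1]` on both sides of the reaction. The returned YAML names match the two pipeline configs loaded later in this notebook.

```python
import re

# Atom-map numbers appear as ":<digits>" right before a closing bracket, e.g. "[CH3:1]".
ATOM_MAP_PATTERN = re.compile(r":\d+\]")


def split_rxn_smiles(rxn_smiles: str) -> tuple[list[str], list[str], list[str]]:
    """Split 'reactants>agents>products' into lists of component SMILES (hypothetical helper)."""
    parts = rxn_smiles.split(">")
    if len(parts) != 3:
        raise ValueError(f"expected 'reactants>agents>products', got {len(parts)} field(s)")
    # Empty fields (e.g. no agents) become empty lists rather than [''].
    return tuple([s for s in part.split(".") if s] for part in parts)


def has_atom_mapping(rxn_smiles: str) -> bool:
    """Return True if both the reactant and product sides carry atom-map numbers."""
    reactants, _, products = rxn_smiles.split(">")
    return bool(ATOM_MAP_PATTERN.search(reactants)) and bool(ATOM_MAP_PATTERN.search(products))


def choose_pipeline(rxn_smiles: str) -> str:
    """GNN models need atom-mapped input; otherwise fall back to the FFN pipeline."""
    return "gnn_pipeline.yaml" if has_atom_mapping(rxn_smiles) else "ffn_pipeline.yaml"


print(choose_pipeline("CCO>>CC=O"))  # ffn_pipeline.yaml
```

The regex only needs to match `:<digits>]`, since outside an atom map a digit cannot be immediately followed by a closing bracket in valid SMILES.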

In [4]:
# Input reactions should have atom-mapped SMILES and the corresponding NameRxn code
example_inputs = [
    ('[CH3:1][O:2][C:3](=[O:4])[c:5]1[cH:6][cH:7][c:8]2[nH:9][c:10]([c:11]([c:12]2[cH:13]1)[CH2:14][c:15]1[cH:16][cH:17][c:18]([cH:19][c:20]1[Cl:21])I)[CH3:22].[CH3:23][C:24]([CH3:25])([CH3:26])[SH:27]>c1ccc(cc1)[P](c1ccccc1)(c1ccccc1)[Pd]([P](c1ccccc1)(c1ccccc1)c1ccccc1)([P](c1ccccc1)(c1ccccc1)c1ccccc1)[P](c1ccccc1)(c1ccccc1)c1ccccc1.CCCCN(CCCC)CCCC.CN(C)C=O>[CH3:1][O:2][C:3](=[O:4])[c:5]1[cH:6][cH:7][c:8]2[nH:9][c:10]([c:11]([c:12]2[cH:13]1)[CH2:14][c:15]1[cH:16][cH:17][c:18]([cH:19][c:20]1[Cl:21])[S:27][C:24]([CH3:23])([CH3:25])[CH3:26])[CH3:22]', '1.8.7'),
    ('Cl.[O:1]=[C:2]1[CH2:3][CH2:4][CH:5]([C:6]([NH:7]1)=[O:8])[N:9]1[CH2:10][c:11]2[c:12]([cH:13][cH:14][cH:15][c:16]2[O:17][CH2:18][c:19]2[cH:20][cH:21][cH:22][c:23]([cH:24]2)[CH2:25]Br)[C:26]1=[O:27].[F:28][c:29]1[cH:30][cH:31][c:32]([cH:33][cH:34]1)[CH:35]1[CH2:36][CH2:37][NH:38][CH2:39][CH2:40]1>CC(C)N(CC)C(C)C.CC#N>[O:1]=[C:2]1[CH2:3][CH2:4][CH:5]([C:6]([NH:7]1)=[O:8])[N:9]1[CH2:10][c:11]2[c:16]([cH:15][cH:14][cH:13][c:12]2[C:26]1=[O:27])[O:17][CH2:18][c:19]1[cH:20][cH:21][cH:22][c:23]([cH:24]1)[CH2:25][N:38]1[CH2:37][CH2:36][CH:35]([CH2:40][CH2:39]1)[c:32]1[cH:31][cH:30][c:29]([cH:34][cH:33]1)[F:28]', '1.6.2'),
    ('CC(C)(C)OC(=O)O[C:1](=[O:2])[O:3][C:4]([CH3:5])([CH3:6])[CH3:7].[CH3:8][c:9]1[cH:10][c:11]([nH:12][cH:13]1)[CH:14]=[O:15]>CN(C)c1ccncc1.CC#N>[CH3:5][C:4]([CH3:6])([CH3:7])[O:3][C:1](=[O:2])[n:12]1[cH:13][c:9]([cH:10][c:11]1[CH:14]=[O:15])[CH3:8]', '5.1.1'),
]

reactions = []
for i, (smi, rxn_class) in enumerate(example_inputs):
    reactants, agents, products = parse_rxn_smiles(smi)

    reactions.append(
        ReactionDatum(
            rxn_smiles=smi, # required for GNN models
            reactants=[AgentRecord(smiles=r, amount=None) for r in reactants], # required for FFN models
            agents=[AgentRecord(smiles=a, amount=None) for a in agents], # required for FFN models
            products=[AgentRecord(smiles=p, amount=None) for p in products], # required for FFN models
            rxn_class=rxn_class,
            document_id=f"demo_{i}",
            date=None,
            temperature=None,
        )
    )

# unified dataset for both FFN and GNN models
dataset = EvaluationDatasetFactory.for_inference(
    data=reactions,
    agent_standardizer=a_standardizer,
    agent_encoder=a_enc,
    rxn_encoder=rxn_encoder,
    featurizer=featurizer,
)

## Load trained models 

The QUARC inference workflow chains the four individually trained models together. First, the agent model predicts agents from the reaction and its reaction class. The predicted agents, together with the reaction information, are then used to predict temperature, reactant amounts, and agent amounts simultaneously.

The predictor generates 80 (10 × 2 × 2 × 2) candidate conditions by enumerating combinations of:
- Top 10 predicted agents
- Top 2 predicted temperatures  
- Top 2 predicted reactant amount groups
- Top 2 predicted agent amount groups
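The count of 80 follows directly from the top-k sizes listed above. The sketch below illustrates the enumeration with `itertools.product`, assuming (as the workflow description states) that stages 2-4 are conditioned on each predicted agent set. `enumerate_candidates` and the `*_for` callables are hypothetical stand-ins for the trained models, not the internals of QUARC's `EnumeratedPredictor`.

```python
from itertools import product

# Top-k sizes per stage, taken from the list above.
TOP_AGENTS, TOP_TEMPS, TOP_REACTANT_AMOUNTS, TOP_AGENT_AMOUNTS = 10, 2, 2, 2


def enumerate_candidates(agent_sets, temps_for, reactant_amounts_for, agent_amounts_for):
    """Yield (agents, temperature, reactant_amounts, agent_amounts) candidates.

    Stages 2-4 are conditioned on the predicted agents, so their ranked
    predictions are requested separately for each agent set. The *_for
    arguments are hypothetical callables standing in for the trained models.
    """
    for agents in agent_sets[:TOP_AGENTS]:
        for temp, reactant_amts, agent_amts in product(
            temps_for(agents)[:TOP_TEMPS],
            reactant_amounts_for(agents)[:TOP_REACTANT_AMOUNTS],
            agent_amounts_for(agents)[:TOP_AGENT_AMOUNTS],
        ):
            yield agents, temp, reactant_amts, agent_amts


# Dummy ranked predictions, only used to count combinations.
agent_sets = [f"agent_set_{i}" for i in range(12)]  # more than 10, so truncation applies
candidates = list(
    enumerate_candidates(
        agent_sets,
        temps_for=lambda agents: ["temp_1", "temp_2", "temp_3"],
        reactant_amounts_for=lambda agents: ["reactant_grp_1", "reactant_grp_2"],
        agent_amounts_for=lambda agents: ["agent_grp_1", "agent_grp_2"],
    )
)
print(len(candidates))  # 80
```

Extra predictions beyond each top-k cutoff (the 11th and 12th agent sets, the third temperature) are simply dropped before the product is taken.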

These candidates are then ranked using weights empirically optimized for either top-5 or top-10 accuracy.
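Ranking needs one score per candidate. The `Meta` lines in the prediction output report a score per stage (`s1_score` to `s4_score`), and the predictor accepts `weights` and `use_geometric` arguments, but the exact combination formula is not shown in this notebook. As an assumption, this hypothetical sketch treats `use_geometric` as switching between a weighted sum and a weighted product of the stage scores; the weights are made up for illustration and are not QUARC's tuned values.

```python
import math


def combine_scores(stage_scores, weights, use_geometric=False):
    """Combine per-stage scores into one ranking score (illustrative formula).

    Weighted sum:     sum_i w_i * s_i
    Weighted product: prod_i s_i ** w_i
    """
    if len(stage_scores) != len(weights):
        raise ValueError("need exactly one weight per stage")
    if use_geometric:
        return math.prod(s ** w for s, w in zip(stage_scores, weights))
    return sum(w * s for s, w in zip(stage_scores, weights))


def rank_candidates(scored_candidates, weights, use_geometric=False):
    """Sort (label, stage_scores) pairs by combined score, highest first."""
    return sorted(
        scored_candidates,
        key=lambda c: combine_scores(c[1], weights, use_geometric),
        reverse=True,
    )


weights = [0.4, 0.2, 0.2, 0.2]  # hypothetical weights, one per stage
scored_candidates = [
    ("cand_a", [0.2, 0.7, 1.0, 0.6]),  # weak stage-1 score, strong elsewhere
    ("cand_b", [0.5, 0.5, 0.5, 0.5]),  # uniform scores
]
print([label for label, _ in rank_candidates(scored_candidates, weights)])
print([label for label, _ in rank_candidates(scored_candidates, weights, use_geometric=True)])
```

With these illustrative weights, the weighted sum ranks `cand_a` first (0.54 vs 0.50), while the weighted product ranks `cand_b` first (about 0.44 vs 0.50), because a product penalizes the low stage-1 score more heavily. The choice of combination mode can therefore change the final ranking, not just the score values.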

Available pre-configured pipelines:
- **FFN Pipeline**: FFNs for all stages (default)
- **GNN Pipeline**: GNNs for all stages
<!-- - **Hybrid Pipeline**: Stage 1 (GNN) + Stages 2-4 (FFN) -->


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
models, model_types, weights = load_models_from_yaml(cfg.checkpoints_dir / "ffn_pipeline.yaml", device)
# models, model_types, weights = load_models_from_yaml(cfg.checkpoints_dir / "gnn_pipeline.yaml", device)


## Setup predictor and run inference

The `EnumeratedPredictor` combines all four stages into a single prediction pipeline:


In [6]:
predictor = EnumeratedPredictor(
    agent_model=models["agent"],
    temperature_model=models["temperature"],
    reactant_amount_model=models["reactant_amount"],
    agent_amount_model=models["agent_amount"],
    model_types=model_types,
    agent_encoder=a_enc,
    device=device,
    weights=weights['use_top_5'],  # weights optimized for top-5 accuracy; top-10-optimized weights are the alternative
    use_geometric=weights['use_geometric'],
)

In [7]:
for reaction in dataset:
    predictions = predictor.predict(reaction, top_k=2)
    print(predictions)
    print("-" * 30)


PredictionList:
  Doc ID: demo_0
  Reaction class: 1.8.7
  SMILES: [CH3:1][O:2][C:3](=[O:4])[c:5]1[cH:6][cH:7][c:8]2[...
  Predictions (2):
    [1]
      Agents: [8, 16, 34, 54]
      Temperature bin: 20
      Reactant bins: [1, 7]
      Agent amount bins: [(8, 23), (16, 8), (34, 2), (54, 1)]
      Score: 0.5819
      Meta: s1_score: 0.2197, s2_score: 0.7057, s3_score: 1.0000, s4_score: 0.6470

    [2]
      Agents: [8, 16, 34, 54]
      Temperature bin: 20
      Reactant bins: [1, 7]
      Agent amount bins: [(8, 23), (16, 8), (34, 2), (54, 2)]
      Score: 0.5751
      Meta: s1_score: 0.2197, s2_score: 0.7057, s3_score: 1.0000, s4_score: 0.6206
------------------------------
PredictionList:
  Doc ID: demo_1
  Reaction class: 1.6.2
  SMILES: Cl.[O:1]=[C:2]1[CH2:3][CH2:4][CH:5]([C:6]([NH:7]1)...
  Predictions (2):
    [1]
      Agents: [10, 12]
      Temperature bin: 13
      Reactant bins: [3, 1, 3]
      Agent amount bins: [(10, 10), (12, 26)]
      Score: 0.8379
      Meta: s1_score

### Annotate predictions with binning labels
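The predictions store bin indices rather than values; `BinningConfig.get_bin_labels` provides labels indexed by bin number, which is why `temp_labels[pred.temp_bin]` works in the cell below. Assuming each label describes a half-open interval `[lo, hi)` between consecutive bin edges (consistent with printed labels such as `[90.00, 100.00)`), this hypothetical sketch shows how such labels and the reverse value-to-bin lookup could be built with `bisect`. The edge values and helper names are illustrative, not QUARC's `BinningConfig.default()`.

```python
from bisect import bisect_right


def make_bin_labels(edges):
    """Turn sorted bin edges into half-open interval labels like '[0.95, 1.05)'."""
    return [f"[{lo:.2f}, {hi:.2f})" for lo, hi in zip(edges[:-1], edges[1:])]


def value_to_bin(value, edges):
    """Return the index i such that edges[i] <= value < edges[i + 1]."""
    if not edges[0] <= value < edges[-1]:
        raise ValueError(f"{value} is outside [{edges[0]}, {edges[-1]})")
    return bisect_right(edges, value) - 1


# Hypothetical reactant equivalence-ratio edges, for illustration only.
edges = [0.0, 0.95, 1.05, 1.25, 1.75, 2.25]
labels = make_bin_labels(edges)
idx = value_to_bin(1.0, edges)
print(idx, labels[idx])  # 1 [0.95, 1.05)
```

`bisect_right` places a value equal to an edge into the bin that starts at that edge, which matches the half-open `[lo, hi)` convention.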

In [9]:
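def format_predictions(predictions: PredictionList, agent_encoder: AgentEncoder, top_k: int = 5):
    """Print the top_k predictions with agent indices decoded to SMILES and bin indices mapped to labels."""
    binning_config = BinningConfig.default()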
    temp_labels = binning_config.get_bin_labels("temperature")
    reactant_labels = binning_config.get_bin_labels("reactant")
    agent_labels = binning_config.get_bin_labels("agent")

    print(f"doc id: {predictions.doc_id}")
    print(f"rxn class: {predictions.rxn_class}")
    print(f"rxn smiles: {predictions.rxn_smiles}")
    print("=" * 60)

    reactants_smiles, _, _ = parse_rxn_smiles(predictions.rxn_smiles)

    for i, pred in enumerate(predictions.predictions[:top_k]):
        agent_smiles = agent_encoder.decode(pred.agents)
        temp_label = temp_labels[pred.temp_bin]
        reactant_labels_list = [reactant_labels[bin_idx] for bin_idx in pred.reactant_bins]

        agent_amounts = []
        for agent_idx, bin_idx in pred.agent_amount_bins:
            agent_smi = agent_encoder.decode([agent_idx])[0]
            amount_label = agent_labels[bin_idx]
            agent_amounts.append(f"{agent_smi} -> {amount_label}")

        reactant_amounts = []
        for reactant_smi, reactant_label in zip(reactants_smiles, reactant_labels_list):
            reactant_amounts.append(f"{reactant_smi} -> {reactant_label}")

        print(f"\nrank {i+1} (score: {pred.score:.4f})")
        print("-" * 30)
        print(f"agents: {', '.join(agent_smiles)}")
        print(f"temperature: {temp_label}")
        print("reactant amounts:")
        for reactant_amount in reactant_amounts:
            print(f"  - {reactant_amount}")
        print("agent amounts:")
        for agent_amount in agent_amounts:
            print(f"  - {agent_amount}")


for reaction in dataset:
    predictions = predictor.predict(reaction, top_k=3)
    format_predictions(predictions, a_enc, top_k=3)
    print("=" * 30)

doc id: demo_0
rxn class: 1.8.7
rxn smiles: [CH3:1][O:2][C:3](=[O:4])[c:5]1[cH:6][cH:7][c:8]2[nH:9][c:10]([c:11]([c:12]2[cH:13]1)[CH2:14][c:15]1[cH:16][cH:17][c:18]([cH:19][c:20]1[Cl:21])I)[CH3:22].[CH3:23][C:24]([CH3:25])([CH3:26])[SH:27]>c1ccc(cc1)[P](c1ccccc1)(c1ccccc1)[Pd]([P](c1ccccc1)(c1ccccc1)c1ccccc1)([P](c1ccccc1)(c1ccccc1)c1ccccc1)[P](c1ccccc1)(c1ccccc1)c1ccccc1.CCCCN(CCCC)CCCC.CN(C)C=O>[CH3:1][O:2][C:3](=[O:4])[c:5]1[cH:6][cH:7][c:8]2[nH:9][c:10]([c:11]([c:12]2[cH:13]1)[CH2:14][c:15]1[cH:16][cH:17][c:18]([cH:19][c:20]1[Cl:21])[S:27][C:24]([CH3:23])([CH3:25])[CH3:26])[CH3:22]

rank 1 (score: 0.5819)
------------------------------
agents: C1COCCO1, O=C(O[Cs])O[Cs], CC(=O)O[Pd]OC(C)=O, CC1(C)c2cccc(P(c3ccccc3)c3ccccc3)c2Oc2c(P(c3ccccc3)c3ccccc3)cccc21
temperature: [90.00, 100.00)
reactant amounts:
  - COC(=O)c1ccc2[nH]c(C)c(Cc3ccc(I)cc3Cl)c2c1 -> [0.95, 1.05)
  - CC(C)(C)S -> [1.75, 2.25)
agent amounts:
  - C1COCCO1 -> [65.50, 75.50)
  - O=C(O[Cs])O[Cs] -> [1.75, 2.25)
  - CC(=