# Scalable Autoregressive 3D Molecule Generation

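This notebook samples molecules from Quetzal models pretrained on the QM9 and GEOM datasets. The setup cell is written for a Google Colab runtime; sampling uses a GPU automatically when one is available and falls back to the CPU otherwise.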

[Paper](https://arxiv.org/abs/2505.13791) | [GitHub](https://github.com/aspuru-guzik-group/quetzal)

## Setup

In [1]:
!git clone https://github.com/aspuru-guzik-group/quetzal.git
import os
os.chdir('quetzal')
os.environ['PYTHONPATH'] = '/env/python:/content/quetzal'
!pip uninstall datasets -y # name collision
!pip install -q rdkit py3Dmol lightning==2.5.0.post0

Cloning into 'quetzal'...
remote: Enumerating objects: 98, done.[K
remote: Counting objects: 100% (98/98), done.[K
remote: Compressing objects: 100% (85/85), done.[K
remote: Total 98 (delta 11), reused 95 (delta 11), pack-reused 0 (from 0)[K
Receiving objects: 100% (98/98), 15.74 MiB | 16.45 MiB/s, done.
Resolving deltas: 100% (11/11), done.
Found existing installation: datasets 4.0.0
Uninstalling datasets-4.0.0:
  Successfully uninstalled datasets-4.0.0
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.4/40.4 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m815.2/815.2 kB[0m [31m25.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m36.1/36.1 MB[0m [31m52.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m65.5/65.5 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [

In [2]:
# Download pretrained checkpoints
!wget https://huggingface.co/auhcheng/quetzal/resolve/main/original.ckpt # best qm9 model
# !wget https://huggingface.co/auhcheng/quetzal/resolve/main/geom.ckpt # best geom model

--2025-09-06 04:02:07--  https://huggingface.co/auhcheng/quetzal/resolve/main/original.ckpt
Resolving huggingface.co (huggingface.co)... 3.175.34.113, 3.175.34.8, 3.175.34.95, ...
Connecting to huggingface.co (huggingface.co)|3.175.34.113|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.hf.co/repos/9c/25/9c2583bd7c1978d019177492ed9db53806548c10d6ed9d29add054597facac94/3bca624b49b45bc93fc8fc46525f76b4caa2b8e1bcfa3f22732a1981199954a1?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27original.ckpt%3B+filename%3D%22original.ckpt%22%3B&Expires=1757134927&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc1NzEzNDkyN319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmhmLmNvL3JlcG9zLzljLzI1LzljMjU4M2JkN2MxOTc4ZDAxOTE3NzQ5MmVkOWRiNTM4MDY1NDhjMTBkNmVkOWQyOWFkZDA1NDU5N2ZhY2FjOTQvM2JjYTYyNGI0OWI0NWJjOTNmYzhmYzQ2NTI1Zjc2YjRjYWEyYjhlMWJjZmEzZjIyNzMyYTE5ODExOTk5NTRhMT9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW

## Generate molecules

In [4]:
import sys

import torch

sys.path.append(".")  # the cloned repo (cwd after setup) contains train.py / draw.py
# NOTE(review): `Config` is not referenced below; presumably it must be in the
# notebook namespace for load_from_checkpoint to unpickle saved hparams — confirm before removing.
from train import Config, LitQuetzal

# --- Sampling configuration ---------------------------------------------------
# Single toggle for the checkpoint: "original" = best QM9 model,
# "geom" = best GEOM model (uncomment its wget line in the download cell first).
ckpt_name = "original"
MAX_LEN_BY_CKPT = {"original": 32, "geom": 192}  # `max_len` passed to generate() per checkpoint
DIFF_STEPS = 60  # passed to generate() as `num_steps`
SEED = 0  # fixed seed so re-running the cell draws the same batch (on the same device/library versions)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

kwargs = {
    "bsz": 25,
    "device": DEVICE,
    "num_steps": DIFF_STEPS,
    "pbar": True,
    "max_len": MAX_LEN_BY_CKPT[ckpt_name],
}

# --- Load weights and sample ----------------------------------------------------
ckpt = f"{ckpt_name}.ckpt"
lit = LitQuetzal.load_from_checkpoint(ckpt, map_location=DEVICE)
# NOTE(review): `ema` presumably wraps an exponential-moving-average copy of the network.
model = lit.ema.module
model.eval();

torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators
out = model.generate(**kwargs)
samples, all_traj = out  # keep `out` intact as well; the trajectory cell takes it whole

 72%|███████▏  | 23/32 [00:10<00:03,  2.28it/s]


In [5]:
from draw import show_grid

# Render the generated batch as a grid of py3Dmol views.
# NOTE(review): presumably (rows, cols); 5 x 5 matches the batch size of 25 — confirm in draw.py.
grid_dims = (5, 5)
show_grid(samples, *grid_dims)

<py3Dmol.view at 0x7addfa07a360>

In [6]:
from draw import show_traj

b_idx = 0  # index into the generated batch; 0 = the first molecule
# `interval=10` is forwarded as-is; presumably a frame stride/delay — confirm in draw.py.
show_traj(out, interval=10, b_idx=b_idx)

<py3Dmol.view at 0x7ade0075b740>