In [11]:
import io
import os
from dataclasses import dataclass
from pathlib import Path
from zipfile import ZipFile

import ase.io
import hydra
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pymatgen.core.structure import Structure
from pymatgen.io.ase import AseAtomsAdaptor
from tqdm import tqdm

from mattergen.common.data.chemgraph import ChemGraph
from mattergen.common.data.collate import collate
from mattergen.common.data.condition_factory import ConditionLoader
from mattergen.common.data.num_atoms_distribution import NUM_ATOMS_DISTRIBUTIONS
from mattergen.common.data.types import TargetProperty
from mattergen.common.utils.data_utils import lattice_matrix_to_params_torch
from mattergen.common.utils.eval_utils import (
    MatterGenCheckpointInfo,
    get_crystals_list,
    load_model_diffusion,
    make_structure,
    save_structures,
)
from mattergen.common.utils.globals import DEFAULT_SAMPLING_CONFIG_PATH, get_device
from mattergen.diffusion.lightning_module import DiffusionLightningModule
from mattergen.diffusion.sampling.pc_sampler import PredictorCorrector
from mattergen.generator import *


In [12]:
def draw_samples_from_sampler(
    sampler: PredictorCorrector,
    condition_loader: ConditionLoader,
    properties_to_condition_on: TargetProperty | None = None,
    output_path: Path | None = None,
    cfg: DictConfig | None = None,
    record_trajectories: bool = True,
) -> list:
    """Run the sampler over every conditioning batch and convert the results to structures.

    Args:
        sampler: Provides ``sample`` and ``sample_with_record``; its
            ``diffusion_module.model.cond_fields_model_was_trained_on`` is consulted to
            validate the requested conditioning properties.
        condition_loader: Iterable yielding ``(conditioning_data, mask)`` pairs.
        properties_to_condition_on: Conditioning properties; a falsy value is treated as ``{}``.
        output_path: When not ``None``, the structures (and, if recorded, the trajectories)
            are written there.
        cfg: Must not be ``None`` whenever ``output_path`` is given; it is not read otherwise.
        record_trajectories: If true, intermediate samples are collected via
            ``sample_with_record`` and dumped alongside the structures.

    Returns:
        The value produced by ``structure_from_model_output`` for all collated samples.

    Raises:
        AssertionError: If a requested property is not among the fields the model was trained
            on, if the collated batch is not a ``ChemGraph``, or if ``output_path`` is given
            without ``cfg``.
    """
    if not properties_to_condition_on:
        properties_to_condition_on = {}
    # Conditioning is only valid on fields the model saw as conditions during training.
    assert all(
        name in sampler.diffusion_module.model.cond_fields_model_was_trained_on  # type: ignore
        for name in properties_to_condition_on.keys()
    )

    final_graphs = []
    trajectories = []
    progress = tqdm(condition_loader, desc="Generating samples")
    for conditioning_data, mask in progress:
        if not record_trajectories:
            _, batch_mean = sampler.sample(conditioning_data, mask)
        else:
            _, batch_mean, per_step_samples = sampler.sample_with_record(conditioning_data, mask)
            trajectories.extend(list_of_time_steps_to_list_of_trajectories(per_step_samples))
        # Only the "mean" output of the sampler is kept; the raw sample is discarded.
        final_graphs.extend(batch_mean.to_data_list())

    batch = collate(final_graphs)
    assert isinstance(batch, ChemGraph)
    # Derive lattice lengths/angles from the cell matrix and attach them to the batch.
    cell_lengths, cell_angles = lattice_matrix_to_params_torch(batch.cell)
    batch = batch.replace(lengths=cell_lengths, angles=cell_angles)

    positions = batch["pos"].reshape(-1, 3)
    species = batch["atomic_numbers"].reshape(-1)
    flat_lengths = batch["lengths"].reshape(-1, 3)
    flat_angles = batch["angles"].reshape(-1, 3)
    atoms_per_structure = batch["num_atoms"].reshape(-1)
    generated_strucs = structure_from_model_output(
        positions, species, flat_lengths, flat_angles, atoms_per_structure
    )

    if output_path is None:
        return generated_strucs

    assert cfg is not None
    save_structures(output_path, generated_strucs)
    if record_trajectories:
        dump_trajectories(output_path=output_path, all_trajs_list=trajectories)
    return generated_strucs

In [None]:
# Generate structures from a pretrained checkpoint fetched by name via `from_hf_hub`.

# --- Sampling parameters (edit these) ---
output_path = "results"  # directory passed to generator.generate(output_dir=...)
pretrained_name = "mattergen_base"  # checkpoint name handed to from_hf_hub
batch_size = 16
num_batches = 1
properties_to_condition_on = {}  # empty: no property conditioning requested
sampling_config_path = None
sampling_config_name = "default"
sampling_config_overrides = []
record_trajectories = True
# NOTE(review): a guidance factor of 5.0 is combined with an empty
# `properties_to_condition_on`; the custom-checkpoint cell further down uses 0.0 for the
# same unconditional setup. Confirm that 5.0 is intended here.
diffusion_guidance_factor = 5.0
target_compositions = []

# --- Resolve the checkpoint, build the generator, and sample ---
checkpoint_info = MatterGenCheckpointInfo.from_hf_hub(
    pretrained_name, config_overrides=[]
)
generator = CrystalGenerator(
    checkpoint_info=checkpoint_info,
    properties_to_condition_on=properties_to_condition_on,
    batch_size=batch_size,
    num_batches=num_batches,
    sampling_config_name=sampling_config_name,
    sampling_config_path=sampling_config_path,
    sampling_config_overrides=sampling_config_overrides,
    record_trajectories=record_trajectories,
    diffusion_guidance_factor=diffusion_guidance_factor,
    target_compositions_dict=target_compositions,
)
generated_structures = generator.generate(output_dir=Path(output_path))

last.ckpt:   0%|          | 0.00/513M [00:00<?, ?B/s]

config.yaml:   0%|          | 0.00/7.24k [00:00<?, ?B/s]


Model config:
adapter:
  adapter:
    _target_: mattergen.adapter.GemNetTAdapter
    atom_type_diffusion: mask
    denoise_atom_types: true
    gemnet:
      _target_: mattergen.common.gemnet.gemnet_ctrl.GemNetTCtrl
      atom_embedding:
        _target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
        emb_size: 512
        with_mask_type: true
      condition_on_adapt:
      - space_group
      cutoff: 7.0
      emb_size_atom: 512
      emb_size_edge: 512
      latent_dim: 512
      max_cell_images_per_dim: 5
      max_neighbors: 50
      num_blocks: 4
      num_targets: 1
      otf_graph: true
      regress_stress: true
      scale_file: /scratch/amlt_code/mattergen/common/gemnet/gemnet-dT.json
    hidden_dim: 512
    property_embeddings: {}
    property_embeddings_adapt:
      space_group:
        _target_: mattergen.property_embeddings.PropertyEmbedding
        conditional_embedding_module:
          _target_: mattergen.property_embeddings.SpaceGroupEmbeddingV

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with hydra.initialize_config_dir(os.path.abspath(str(sampling_config_path))):
INFO:mattergen.common.utils.eval_utils:Loading model from checkpoint: /home/ubt/.cache/huggingface/hub/models--microsoft--mattergen/snapshots/2092423e1f1ab5f7d792142257f0477a57628105/checkpoints/space_group/checkpoints/last.ckpt



Sampling config:
sampler_partial:
  _target_: mattergen.diffusion.sampling.classifier_free_guidance.GuidedPredictorCorrector.from_pl_module
  'N': 1000
  eps_t: 0.001
  _partial_: true
  guidance_scale: 5.0
  remove_conditioning_fn:
    _target_: mattergen.property_embeddings.SetUnconditionalEmbeddingType
  keep_conditioning_fn:
    _target_: mattergen.property_embeddings.SetConditionalEmbeddingType
  predictor_partials:
    pos:
      _target_: mattergen.diffusion.wrapped.wrapped_predictors_correctors.WrappedAncestralSamplingPredictor
      _partial_: true
    cell:
      _target_: mattergen.common.diffusion.predictors_correctors.LatticeAncestralSamplingPredictor
      _partial_: true
    atomic_numbers:
      _target_: mattergen.diffusion.d3pm.d3pm_predictors_correctors.D3PMAncestralSamplingPredictor
      predict_x0: true
      _partial_: true
  corrector_partials:
    pos:
      _target_: mattergen.diffusion.wrapped.wrapped_predictors_correctors.WrappedLangevinCorrector
      _par

Generating samples:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Generating samples:   0%|          | 0/1 [04:22<?, ?it/s]
  """For each field, regardless of whether the corruption process is SDE or D3PM, we guide the score in the same way here,


KeyboardInterrupt: 

## Preprocess a custom dataset

In [None]:
from mattergen.scripts import csv_to_dataset
import sys

# csv_to_dataset.main() is driven through a command line here, so a fake one is installed in
# sys.argv for the duration of the call. The previous sys.argv is restored afterwards (even if
# main() raises) so that this cell does not leak its arguments into later cells.
_saved_argv = sys.argv
sys.argv = [
    "notebook",  # dummy script name (argv[0])
    "--csv-folder", "test/raw1_processed",  # relative to the kernel's working directory
    "--dataset-name", "raw1",
    "--cache-folder", "datasets/cache",
]
try:
    csv_to_dataset.main()
finally:
    sys.argv = _saved_argv

Processing test/raw1_processed/val.csv


Parsing CIFs:   0%|          | 0/15 [00:00<?, ?it/s]

Converting structures to numpy:   0%|          | 0/15 [00:00<?, ?it/s]

Storing cached dataset in datasets/cache/raw1/val.
Processing test/raw1_processed/test.csv


Parsing CIFs:   0%|          | 0/15 [00:00<?, ?it/s]

Converting structures to numpy:   0%|          | 0/15 [00:00<?, ?it/s]

Storing cached dataset in datasets/cache/raw1/test.
Processing test/raw1_processed/train.csv


Parsing CIFs:   0%|          | 0/70 [00:00<?, ?it/s]

Converting structures to numpy:   0%|          | 0/70 [00:00<?, ?it/s]

Storing cached dataset in datasets/cache/raw1/train.


## Training from a custom dataset

### Create the `data_module/raw1.yaml` config file (batch size, epoch limit)

In [14]:
from mattergen.scripts.run import mattergen_main
import sys

# mattergen_main() is configured through Hydra-style command-line overrides, so a fake command
# line is installed in sys.argv for the duration of the call. The previous sys.argv is restored
# afterwards (even if training raises or is interrupted) so later cells are unaffected.
_saved_argv = sys.argv
sys.argv = [
    "notebook",                 # dummy script name (argv[0])
    "data_module=raw1",         # use the raw1 data-module config
    "~trainer.logger",          # "~key" is Hydra's syntax for deleting a config entry
    "trainer.strategy=auto",
    "trainer.num_nodes=1",
    "trainer.accelerator=gpu",  # requires a GPU on this machine
]
try:
    mattergen_main()
finally:
    sys.argv = _saved_argv

provider=hydra, path=pkg://hydra.conf
provider=main, path=file:///home/ubt/Downloads/mattergen/mattergen/conf
provider=schema, path=structured://
provider=hydra, path=pkg://hydra.conf
provider=main, path=file:///home/ubt/Downloads/mattergen/mattergen/conf
provider=schema, path=structured://
data_module:
  _target_: mattergen.common.data.datamodule.CrystDataModule
  _recursive_: true
  properties: []
  dataset_transforms:
  - _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
    _partial_: true
  transforms:
  - _target_: mattergen.common.data.transform.symmetrize_lattice
    _partial_: true
  - _target_: mattergen.common.data.transform.set_chemical_system_string
    _partial_: true
  average_density: 0.07
  root_dir: /home/ubt/Downloads/mattergen/mattergen/../datasets/cache/raw1
  train_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: /home/ubt/Downloads/mattergen/mattergen/../datasets/cache/raw1/train
    prop

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/ubt/miniconda3/envs/pyg/lib/python3.13/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type            | Params | Mode 
-------------------------------------------------------------
0 | diffusion_module | DiffusionModule | 44.6 M | train
-------------------------------------------------------------
44.6 M    Trainable params
22        Non-trainable params
44.6

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/ubt/miniconda3/envs/pyg/lib/python3.13/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/ubt/miniconda3/envs/pyg/lib/python3.13/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


## Load the checkpoint trained on the custom dataset and run the generative model

In [15]:
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal, Optional
import numpy as np
from hydra import compose, initialize_config_dir
from omegaconf import DictConfig
import os


@dataclass
class MatterGenCheckpointInfo:
    model_path: Path
    config_path: Path
    load_epoch: Optional[int | Literal["best", "last"]] = "last"
    config_overrides: list[str] = field(default_factory=list)
    strict_checkpoint_loading: bool = True  # <-- reintroduced for compatibility

    def __post_init__(self):
        self.model_path = Path(self.model_path).resolve()
        self.config_path = Path(self.config_path).resolve()

        if not self.model_path.is_dir():
            raise FileNotFoundError(f"Model directory does not exist: {self.model_path}")
        if not (self.config_path / "config.yaml").exists():
            raise FileNotFoundError(f"'config.yaml' not found in {self.config_path}")

    @property
    def config(self) -> DictConfig:
        from hydra.core.global_hydra import GlobalHydra
        if GlobalHydra.instance().is_initialized():
            GlobalHydra.instance().clear()

        with initialize_config_dir(config_dir=str(self.config_path), version_base="1.1"):
            return compose(config_name="config", overrides=self.config_overrides)

    @property
    def checkpoint_path(self) -> Path:
        ckpts = sorted(self.model_path.glob("*.ckpt"))
        if not ckpts:
            raise FileNotFoundError(f"No checkpoints found in {self.model_path}")

        if self.load_epoch == "last":
            last_ckpt = [ckpt for ckpt in ckpts if "last.ckpt" in ckpt.name]
            if not last_ckpt:
                raise ValueError("No 'last.ckpt' file found.")
            return last_ckpt[0]

        if self.load_epoch == "best":
            ckpts_named = [ckpt for ckpt in ckpts if "val_loss" in ckpt.name]
            if not ckpts_named:
                raise ValueError("No checkpoint with validation loss in filename.")
            ckpts_named.sort(key=lambda x: float(x.name.split("val_loss=")[-1].split(".")[0]))
            return ckpts_named[0]

        if isinstance(self.load_epoch, int):
            for ckpt in ckpts:
                if f"epoch={self.load_epoch}" in ckpt.name:
                    return ckpt
            raise ValueError(f"Checkpoint for epoch={self.load_epoch} not found.")

        raise ValueError(f"Invalid load_epoch: {self.load_epoch}")


In [18]:
# Generate structures from a locally trained checkpoint.

# --- Parameters (edit these) ---
# Run directory holding config.yaml; its `checkpoints/` subfolder holds the *.ckpt files.
# NOTE(review): the dated path is specific to one training run — presumably it must be
# updated after every retraining; confirm.
cfg_path="./outputs/singlerun/2025-05-07/05-10-07/lightning_logs/version_0"
chk_path=f"{cfg_path}/checkpoints"
output_path = "results"  # directory passed to generator.generate(output_dir=...)
batch_size = 64
num_batches = 1
properties_to_condition_on = {}  # empty: no property conditioning requested
sampling_config_path = None
sampling_config_name = "default"
sampling_config_overrides = []
record_trajectories = True
diffusion_guidance_factor = 0.0
# Example: [{'Na': 4, 'Cl': 4}]
# NOTE(review): keys are presumably element symbols, so a molecule such as 'H2O' would not be
# a valid key — confirm against CrystalGenerator.
target_compositions = []

# Uses the notebook-local MatterGenCheckpointInfo dataclass (plain constructor, no hub download).
checkpoint_info = MatterGenCheckpointInfo(
    model_path= Path(chk_path),
    config_path=Path(cfg_path)
)

# Build the generator from the parameters above.
generator = CrystalGenerator(
    checkpoint_info=checkpoint_info,
    properties_to_condition_on=properties_to_condition_on,
    batch_size=batch_size,
    num_batches=num_batches,
    sampling_config_name=sampling_config_name,
    sampling_config_path=sampling_config_path,
    sampling_config_overrides=sampling_config_overrides,
    record_trajectories=record_trajectories,
    diffusion_guidance_factor=diffusion_guidance_factor,
    target_compositions_dict=target_compositions,
)

# Run generation; results are written under `output_path`.
generated_structures = generator.generate(output_dir=Path(output_path))



Model config:
provider=hydra, path=pkg://hydra.conf
provider=main, path=file:///home/ubt/Downloads/mattergen/outputs/singlerun/2025-05-07/05-10-07/lightning_logs/version_0
provider=schema, path=structured://
auto_resume: true
checkpoint_path: null
data_module:
  _recursive_: true
  _target_: mattergen.common.data.datamodule.CrystDataModule
  average_density: 0.07
  batch_size:
    train: 1
    val: 1
  dataset_transforms:
  - _partial_: true
    _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
  max_epochs: 100
  num_workers:
    train: 0
    val: 0
  properties: []
  root_dir: /home/ubt/Downloads/mattergen/mattergen/../datasets/cache/raw1
  train_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: /home/ubt/Downloads/mattergen/mattergen/../datasets/cache/raw1/train
    dataset_transforms:
    - _partial_: true
      _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
    properties: []
  

Generating samples:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Generating samples: 100%|██████████| 1/1 [04:56<00:00, 296.02s/it]
