In [3]:
import io
import os
from dataclasses import dataclass
from pathlib import Path
from zipfile import ZipFile

import ase.io
import hydra
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pymatgen.core.structure import Structure
from pymatgen.io.ase import AseAtomsAdaptor
from tqdm import tqdm

from mattergen.common.data.chemgraph import ChemGraph
from mattergen.common.data.collate import collate
from mattergen.common.data.condition_factory import ConditionLoader
from mattergen.common.data.num_atoms_distribution import NUM_ATOMS_DISTRIBUTIONS
from mattergen.common.data.types import TargetProperty
from mattergen.common.utils.data_utils import lattice_matrix_to_params_torch
from mattergen.common.utils.eval_utils import (
    MatterGenCheckpointInfo,
    get_crystals_list,
    load_model_diffusion,
    make_structure,
    save_structures,
)
from mattergen.common.utils.globals import DEFAULT_SAMPLING_CONFIG_PATH, get_device
from mattergen.diffusion.lightning_module import DiffusionLightningModule
from mattergen.diffusion.sampling.pc_sampler import PredictorCorrector
from mattergen.generator import *


In [5]:
def draw_samples_from_sampler(
    sampler: PredictorCorrector,
    condition_loader: ConditionLoader,
    properties_to_condition_on: TargetProperty | None = None,
    output_path: Path | None = None,
    cfg: DictConfig | None = None,
    record_trajectories: bool = True,
) -> list[Structure]:
    """Run ``sampler`` over every batch of ``condition_loader`` and convert the results to structures.

    Args:
        sampler: Predictor-corrector sampler; its ``diffusion_module.model`` must expose
            ``cond_fields_model_was_trained_on``.
        condition_loader: Iterable yielding ``(conditioning_data, mask)`` pairs; the sampler
            is called once per pair.
        properties_to_condition_on: Properties to condition on. Every key must be one the
            model was trained to condition on. ``None`` is treated as an empty mapping.
        output_path: If given, the generated structures are saved there via
            ``save_structures`` and, when ``record_trajectories`` is true, the recorded
            trajectories are written via ``dump_trajectories``.
        cfg: Required (non-``None``) whenever ``output_path`` is given; not otherwise read here.
        record_trajectories: If true, sample with ``sampler.sample_with_record`` and collect
            the intermediate samples as per-sample trajectories.

    Returns:
        The structures built by ``structure_from_model_output`` from all collected samples.

    Raises:
        AssertionError: If a requested property is not one the model was trained on, or if
            ``output_path`` is given without ``cfg``.
    """
    properties_to_condition_on = properties_to_condition_on or {}

    # We cannot conditionally sample on a field the model was not trained to condition on.
    # Look the trained fields up once and report exactly which requested keys are unsupported.
    trained_fields = sampler.diffusion_module.model.cond_fields_model_was_trained_on  # type: ignore
    unsupported = [key for key in properties_to_condition_on.keys() if key not in trained_fields]
    assert not unsupported, (
        f"Cannot condition on {unsupported}: the model was trained to condition on {trained_fields}"
    )

    all_samples_list = []
    all_trajs_list = []
    for conditioning_data, mask in tqdm(condition_loader, desc="Generating samples"):
        # Only the `mean` output is kept as the generated sample; the first return value is unused.
        if record_trajectories:
            _, mean, intermediate_samples = sampler.sample_with_record(conditioning_data, mask)
            all_trajs_list.extend(list_of_time_steps_to_list_of_trajectories(intermediate_samples))
        else:
            _, mean = sampler.sample(conditioning_data, mask)
        all_samples_list.extend(mean.to_data_list())

    all_samples = collate(all_samples_list)
    assert isinstance(all_samples, ChemGraph)
    # structure_from_model_output consumes lattice lengths/angles, so derive them from the cell.
    lengths, angles = lattice_matrix_to_params_torch(all_samples.cell)
    all_samples = all_samples.replace(lengths=lengths, angles=angles)

    generated_strucs = structure_from_model_output(
        all_samples["pos"].reshape(-1, 3),
        all_samples["atomic_numbers"].reshape(-1),
        all_samples["lengths"].reshape(-1, 3),
        all_samples["angles"].reshape(-1, 3),
        all_samples["num_atoms"].reshape(-1),
    )

    if output_path is not None:
        assert cfg is not None
        save_structures(output_path, generated_strucs)

        if record_trajectories:
            dump_trajectories(
                output_path=output_path,
                all_trajs_list=all_trajs_list,
            )

    return generated_strucs

In [18]:
# Sampling configuration: edit these values to change what is generated.
output_path = "results"
pretrained_name = "mattergen_base"
batch_size = 16
num_batches = 1
properties_to_condition_on = {}  # empty dict -> no properties to condition on
sampling_config_path = None
sampling_config_name = "default"
sampling_config_overrides = []
record_trajectories = True
diffusion_guidance_factor = 0.0
target_compositions = []
seed = 42  # RNG seed so re-running the notebook reproduces the same samples

checkpoint_info = MatterGenCheckpointInfo.from_hf_hub(
    pretrained_name, config_overrides=[]
)
# NOTE(review): CrystalGenerator presumably comes from `from mattergen.generator import *`;
# if so, it calls mattergen.generator's own draw_samples_from_sampler, not the copy
# defined in the previous cell. Confirm before relying on edits to that copy.
generator = CrystalGenerator(
    checkpoint_info=checkpoint_info,
    properties_to_condition_on=properties_to_condition_on,
    batch_size=batch_size,
    num_batches=num_batches,
    sampling_config_name=sampling_config_name,
    sampling_config_path=sampling_config_path,
    sampling_config_overrides=sampling_config_overrides,
    record_trajectories=record_trajectories,
    diffusion_guidance_factor=diffusion_guidance_factor,
    target_compositions_dict=target_compositions,
)
# Diffusion sampling is stochastic: seed torch (CPU and CUDA generators) right before
# generating so results are repeatable, modulo nondeterministic GPU kernels.
torch.manual_seed(seed)
generated_structures = generator.generate(output_dir=Path(output_path))

INFO:mattergen.common.utils.eval_utils:Loading model from checkpoint: /home/ubt/.cache/huggingface/hub/models--microsoft--mattergen/snapshots/2092423e1f1ab5f7d792142257f0477a57628105/checkpoints/mattergen_base/checkpoints/last.ckpt



Model config:
auto_resume: true
checkpoint_path: null
data_module:
  _recursive_: true
  _target_: mattergen.common.data.datamodule.CrystDataModule
  average_density: 0.05771451654022283
  batch_size:
    train: 32
    val: 32
  max_epochs: 2200
  num_workers:
    train: 0
    val: 0
  properties:
  - dft_bulk_modulus
  - dft_band_gap
  - dft_mag_density
  - ml_bulk_modulus
  - hhi_score
  - space_group
  - energy_above_hull
  root_dir: datasets/cache/alex_mp_20/
  train_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: datasets/cache/alex_mp_20/train
    properties:
    - dft_bulk_modulus
    - dft_band_gap
    - dft_mag_density
    - ml_bulk_modulus
    - hhi_score
    - space_group
    - energy_above_hull
    transforms:
    - _partial_: true
      _target_: mattergen.common.data.transform.symmetrize_lattice
    - _partial_: true
      _target_: mattergen.common.data.transform.set_chemical_system_string
  transforms:
  - _partial_: 

Generating samples:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Generating samples: 100%|██████████| 1/1 [05:21<00:00, 321.92s/it]
