In [1]:
from pathlib import Path

run_id = "uirlcewv"
# Directory of the lltrainer run identified by `run_id`.
run_dir = Path(
    f"/net/csefiles/coc-fung-cluster/nima/shared/experiment-data/lltrainer/{run_id}/"
)
# Raise a real exception rather than `assert` (which `python -O` strips);
# `is_dir()` is already False for paths that do not exist.
if not run_dir.is_dir():
    raise NotADirectoryError(
        f"run_dir: {run_dir} does not exist or is not a directory"
    )
print(run_dir)

/net/csefiles/coc-fung-cluster/nima/shared/experiment-data/lltrainer/uirlcewv


In [2]:
def make_config():
    """Build the finalized ``MatbenchDiscoveryConfig`` for the JMP-S MPTrj recipe.

    Applies the JMP-S preset, MPTrj datasets, optimizer/scheduler, per-layer LN,
    direct-force backbone, Allegro energy heads plus stress/force heads and
    position-noise augmentation, so that a model with the matching architecture
    can be constructed before the trained weights are loaded.

    Returns:
        The result of ``config.finalize()``.
    """
    from collections.abc import Callable
    from pathlib import Path
    from typing import Literal

    import nshtrainer as nt
    from jmppeft.configs.finetune.jmp_s import jmp_s_ft_config_
    from jmppeft.modules import loss
    from jmppeft.tasks.config import AdamWConfig
    from jmppeft.tasks.finetune import base, output_head
    from jmppeft.tasks.finetune import matbench_discovery as M
    from jmppeft.utils.param_specific_util import (
        make_parameter_specific_optimizer_config,
        parameter_specific_optimizer_config,
    )

    # Pretrained JMP-S checkpoint used to initialize the backbone (see `jmp_s_`).
    jmp_s_ckpt_path = Path(
        "/net/csefiles/coc-fung-cluster/nima/shared/checkpoints/jmp-s.pt"
    )

    # Set this to None if you want the run logs to be saved in the current directory
    project_root: Path | None = Path(
        "/net/csefiles/coc-fung-cluster/nima/shared/experiment-data/"
    )
    # Guard the mkdir: with `project_root = None` (the opt-out described above)
    # an unconditional `.mkdir` would raise AttributeError.
    if project_root is not None:
        project_root.mkdir(exist_ok=True, parents=True)

    def jmp_s_(config: base.FinetuneConfigBase):
        """Apply the JMP-S fine-tuning preset and point it at the JMP-S checkpoint.

        Mutates ``config`` in place: runs ``jmp_s_ft_config_``, loads the
        pretrained checkpoint with ``ema=True``, records ``meta["jmp_kind"] = "s"``
        and appends ``"jmps"`` to ``name_parts``.
        """
        assert jmp_s_ckpt_path.exists(), f"Checkpoint not found: {jmp_s_ckpt_path}"

        jmp_s_ft_config_(config)
        pretrained = base.PretrainedCheckpointConfig(path=jmp_s_ckpt_path, ema=True)
        config.ckpt_load.checkpoint = pretrained

        # Read back by `parameter_specific_optimizers_` to pick the LR table.
        config.meta["jmp_kind"] = "s"
        config.name_parts.append("jmps")

    def parameter_specific_optimizers_(config: base.FinetuneConfigBase):
        """Append per-block learning-rate multiplier groups for the JMP backbone.

        The multiplier table is selected by ``config.meta["jmp_kind"]`` ("l" or
        "s"); any other value raises ``ValueError``. Groups already present in
        ``config.parameter_specific_optimizers`` are kept, and new ones go after them.
        """
        if config.parameter_specific_optimizers is None:
            config.parameter_specific_optimizers = []

        lr_multipliers_by_kind = {
            "l": {
                "embedding": 0.3,
                "blocks_0": 0.55,
                "blocks_1": 0.40,
                "blocks_2": 0.30,
                "blocks_3": 0.40,
                "blocks_4": 0.55,
                "blocks_5": 0.625,
            },
            "s": {
                "embedding": 0.3,
                "blocks_0": 0.30,
                "blocks_1": 0.40,
                "blocks_2": 0.55,
                "blocks_3": 0.625,
            },
        }

        jmp_kind = config.meta["jmp_kind"]
        if jmp_kind not in lr_multipliers_by_kind:
            raise ValueError(f"Invalid jmp_kind: {config.meta['jmp_kind']}")

        config.parameter_specific_optimizers.extend(
            make_parameter_specific_optimizer_config(
                config,
                config.backbone.num_blocks,
                lr_multipliers_by_kind[jmp_kind],
            )
        )

    def parameter_specific_optimizers_energy_references_(
        config: base.FinetuneConfigBase,
        lr_multiplier: float = 0.1,
    ):
        """Give energy-reference / Allegro scale parameters their own LR multiplier.

        If any graph target is a ``ReferencedScalarTargetConfig``, adds one group per
        such head covering its ``references`` parameters. Otherwise, if any graph
        target is an ``AllegroScalarTargetConfig``, adds one group per head covering
        its per-atom scales and shifts (plus pairwise scales when
        ``edge_level_energies`` is set). Only one of the two cases is applied.

        Raises:
            ValueError: if neither kind of head is present.
        """
        if not config.parameter_specific_optimizers:
            config.parameter_specific_optimizers = []

        head_prefix = "graph_outputs._module_dict.ft_mlp_"

        referenced_heads = [
            target
            for target in config.graph_targets
            if isinstance(target, output_head.ReferencedScalarTargetConfig)
        ]
        if referenced_heads:
            groups = [
                {
                    "name": f"{head.name}.ref",
                    "lr_multiplier": lr_multiplier,
                    "parameter_patterns": [f"{head_prefix}{head.name}.references.*"],
                }
                for head in referenced_heads
            ]
        else:
            allegro_heads = [
                target
                for target in config.graph_targets
                if isinstance(target, output_head.AllegroScalarTargetConfig)
            ]
            if not allegro_heads:
                raise ValueError("No energy reference or allegro heads found")

            groups = []
            for head in allegro_heads:
                patterns = [
                    f"{head_prefix}{head.name}.per_atom_scales.*",
                    f"{head_prefix}{head.name}.per_atom_shifts.*",
                ]
                if head.edge_level_energies:
                    patterns.append(f"{head_prefix}{head.name}.pairwise_scales.*")
                groups.append(
                    {
                        "name": f"{head.name}.scales",
                        "lr_multiplier": lr_multiplier,
                        "parameter_patterns": patterns,
                    }
                )

        config.parameter_specific_optimizers.extend(
            parameter_specific_optimizer_config(config, groups)
        )

    def direct_(config: base.FinetuneConfigBase):
        """Turn on energy and force regression with ``direct_forces=True`` on the backbone."""
        backbone = config.backbone
        backbone.regress_forces = True
        backbone.direct_forces = True
        backbone.regress_energy = True
        config.name_parts.append("direct")

    def ln_(
        config: base.FinetuneConfigBase,
        *,
        lr_multiplier: float | None,
    ):
        """Enable per-layer LayerNorm (and scale-factor-to-LN) in the backbone.

        Args:
            config: config to mutate in place.
            lr_multiplier: when not None, an optimizer group named ``"ln"`` covering
                the LN parameters is placed *before* any existing entries of
                ``config.parameter_specific_optimizers``.
        """
        config.backbone.ln_per_layer = True
        config.backbone.scale_factor_to_ln = True

        if lr_multiplier is not None:
            if config.parameter_specific_optimizers is None:
                config.parameter_specific_optimizers = []

            ln_group = {
                "name": "ln",
                "lr_multiplier": lr_multiplier,
                "parameter_patterns": [
                    "backbone.h_lns.*",
                    "backbone.m_lns.*",
                    "backbone.*.scale*.ln.*",
                ],
            }
            ln_optimizers = parameter_specific_optimizer_config(config, [ln_group])
            # Prepend: the LN group goes ahead of previously registered groups.
            config.parameter_specific_optimizers = list(ln_optimizers) + list(
                config.parameter_specific_optimizers
            )

        config.name_parts.append("ln")

    def pos_aug_(config: base.FinetuneConfigBase, *, std: float):
        """Enable position-noise augmentation with noise standard deviation ``std``.

        Uses fixed corruption probabilities of 0.75 per system and 0.5 per atom.
        """
        augmentation = base.PositionNoiseAugmentationConfig(
            system_corrupt_prob=0.75,
            atom_corrupt_prob=0.5,
            noise_std=std,
        )
        config.pos_noise_augmentation = augmentation
        config.name_parts.append(f"posaug_std{std}")

    def data_config_(
        config: M.MatbenchDiscoveryConfig,
        *,
        batch_size: int,
        reference: bool,
    ):
        """Configure MPTrj train/val/test datasets, batch size and data loading.

        Args:
            config: config to mutate in place.
            batch_size: per-step batch size (also recorded in ``name_parts``).
            reference: if True, train on the ``*_referenced`` energy columns;
                otherwise on the raw corrected total energies.
        """
        config.batch_size = batch_size
        config.name_parts.append(f"bsz{batch_size}")

        if reference:
            energy_columns = {
                "y": "corrected_total_energy_referenced",
                "y_relaxed": "corrected_total_energy_relaxed_referenced",
            }
        else:
            energy_columns = {
                "y": "corrected_total_energy",
                "y_relaxed": "corrected_total_energy_relaxed",
            }

        def dataset_fn(split: Literal["train", "val", "test"]):
            # A fresh mapping per split so the dataset configs share no mutable state.
            return base.FinetuneMPTrjHuggingfaceDatasetConfig(
                split=split,
                energy_column_mapping=dict(energy_columns),
            )

        for split in ("train", "val", "test"):
            setattr(config, f"{split}_dataset", dataset_fn(split))

        config.name_parts.append("linrefenergy" if reference else "totalenergy")

        # Data loading
        config.num_workers = 7

        # Balanced batch sampler; the trainer's distributed sampler is turned off
        # (presumably because the balanced sampler handles sharding — confirm).
        config.use_balanced_batch_sampler = True
        config.trainer.use_distributed_sampler = False

    def output_heads_config_(
        config: M.MatbenchDiscoveryConfig,
        *,
        relaxed_energy: bool,
        mace_energy_loss: bool,
        mace_force_loss: bool,
        energy_coefficient: float,
        force_coefficient: float,
        stress_coefficient: float,
    ):
        """Add energy (optionally relaxed-energy), stress and force output heads.

        Energy heads are ``AllegroScalarTargetConfig`` with sum reduction and
        edge-level energies. The relaxed-energy head is weighted at half the energy
        coefficient. All losses are Huber with ``delta=0.01``; the MACE variants are
        used for energy/force when the corresponding flags are set.
        """
        if mace_energy_loss:
            energy_loss = loss.MACEHuberEnergyLossConfig(delta=0.01)
            config.name_parts.append("maceenergy")
        else:
            energy_loss = loss.HuberLossConfig(delta=0.01)

        if mace_force_loss:
            force_loss = loss.MACEHuberLossConfig(delta=0.01)
            config.name_parts.append("maceforce")
        else:
            force_loss = loss.HuberLossConfig(delta=0.01)

        def energy_head(name: str, coefficient: float):
            # Each head gets its own copy of the loss config.
            return output_head.AllegroScalarTargetConfig(
                name=name,
                loss_coefficient=coefficient,
                loss=energy_loss.model_copy(),
                reduction="sum",
                max_atomic_number=config.backbone.num_elements,
                edge_level_energies=True,
            )

        config.graph_targets.append(energy_head("y", energy_coefficient))
        if relaxed_energy:
            config.graph_targets.append(
                energy_head("y_relaxed", energy_coefficient / 2.0)
            )
            config.name_parts.append("rele")

        config.graph_targets.append(
            output_head.DirectStressTargetConfig(
                name="stress",
                loss_coefficient=stress_coefficient,
                loss=loss.HuberLossConfig(delta=0.01),
                reduction="mean",
            )
        )
        config.node_targets.append(
            output_head.NodeVectorTargetConfig(
                name="force",
                loss_coefficient=force_coefficient,
                loss=force_loss,
                reduction="sum",
            )
        )

        config.name_parts.extend(
            [
                f"ec{energy_coefficient}",
                f"fc{force_coefficient}",
                f"sc{stress_coefficient}",
            ]
        )

    def optimization_config_(
        config: M.MatbenchDiscoveryConfig,
        *,
        lr: float,
    ):
        """Configure AdamW, the warmup-cosine + RLP LR schedule, and gradient clipping.

        Args:
            config: config to mutate in place.
            lr: base learning rate (also recorded in ``name_parts``).
        """
        optimizer = AdamWConfig(
            lr=lr,
            amsgrad=False,
            betas=(0.9, 0.95),
            weight_decay=0.1,
        )
        scheduler = base.WarmupCosRLPConfig(
            warmup_epochs=1,
            warmup_start_lr_factor=1.0e-1,
            should_restart=False,
            max_epochs=128,
            min_lr_factor=0.5,
            rlp=base.RLPConfig(patience=5, factor=0.8),
        )
        clipping = nt.model.GradientClippingConfig(value=2.0, algorithm="value")

        config.optimizer = optimizer
        config.lr_scheduler = scheduler
        config.trainer.optimizer.gradient_clipping = clipping

        config.name_parts.append(f"lr{lr}")

    def create_config(config_fn: Callable[[M.MatbenchDiscoveryConfig], None]):
        """Draft a base config, let ``config_fn`` customize it, then apply shared settings.

        Sets mixed precision and matmul precision, project/name, backbone
        ``qint_tags``, and the primary metric (force MAE, minimized). When
        ``project_root`` is set, it also applies the project root.
        """
        draft = M.MatbenchDiscoveryConfig.draft()

        # Trainer numerics
        draft.trainer.precision = "16-mixed-auto"
        draft.trainer.set_float32_matmul_precision = "medium"

        # Run identity
        draft.project = "jmp_mptrj"
        draft.name = "mptrj"

        config_fn(draft)
        draft.backbone.qint_tags = [0, 1, 2]

        draft.primary_metric = nt.MetricConfig(
            name="matbench_discovery/force_mae",
            mode="min",
        )

        if project_root:
            draft.with_project_root_(project_root)
        return draft

    # Assemble the run config. Call order matters: helpers append to
    # `name_parts` in call order, and they add optimizer groups to
    # `parameter_specific_optimizers` (`ln_` prepends, the others extend).
    config = create_config(jmp_s_)
    # Start from an empty list of parameter-specific optimizer groups.
    config.parameter_specific_optimizers = []
    config.max_neighbors = M.MaxNeighbors(main=25, aeaint=20, aint=1000, qint=8)
    config.cutoffs = M.Cutoffs.from_constant(12.0)
    data_config_(config, reference=True, batch_size=40)
    optimization_config_(config, lr=8.0e-5)
    ln_(config, lr_multiplier=1.5)
    direct_(config=config)
    output_heads_config_(
        config,
        relaxed_energy=True,
        mace_energy_loss=True,
        mace_force_loss=True,
        energy_coefficient=5.0,
        force_coefficient=10.0,
        stress_coefficient=100.0,
    )
    # Per-block backbone LR multipliers; the table is chosen by meta["jmp_kind"],
    # which `jmp_s_` set to "s".
    parameter_specific_optimizers_(config)
    # Must run after `output_heads_config_`: it searches `graph_targets` for
    # referenced/Allegro energy heads and raises ValueError if it finds none.
    parameter_specific_optimizers_energy_references_(config, lr_multiplier=0.1)
    pos_aug_(config, std=0.01)
    config.per_graph_radius_graph = True
    config.ignore_graph_generation_errors = False

    # NOTE(review): batch size / energy coefficient here (bsz40, ec5.0) differ from
    # the checkpointed run (bsz16, ec2.0 per its hparams path); the hparams-merge
    # cell copies graph_targets, datasets, backbone etc. from the run to compensate.
    config = config.finalize()

    return config


# Build the fine-tuning config for the model; the next cell overrides selected
# fields with the hparams saved by the trained run.
config = make_config()
config

TensorBoard/TensorBoardX not found. Disabling TensorBoardLogger. Please install TensorBoard with `pip install tensorboard` or TensorBoardX with `pip install tensorboardx` to enable TensorBoard logging.
Type checking the following modules: ('jmppeft',)


MatbenchDiscoveryConfig(id='vbio0fwy', name='mptrj', name_parts=['jmps', 'bsz40', 'linrefenergy', 'lr8e-05', 'ln', 'direct', 'maceenergy', 'maceforce', 'rele', 'ec5.0', 'fc10.0', 'sc100.0', 'posaug_std0.01'], project='jmp_mptrj', directory=DirectoryConfig(project_root=PosixPath('/net/csefiles/coc-fung-cluster/nima/shared/experiment-data')), trainer=TrainerConfig(optimizer=OptimizationConfig(log_grad_norm=True, gradient_clipping=GradientClippingConfig(value=2.0, algorithm='value')), early_stopping=EarlyStoppingConfig(patience=50, min_lr=1e-08), precision='fp16-mixed', max_epochs=500, max_time='07:00:00:00', use_distributed_sampler=False, set_float32_matmul_precision='medium'), primary_metric=MetricConfig(name='matbench_discovery/force_mae', mode='min'), meta={'jmp_kind': 's'}, train_dataset=FinetuneMPTrjHuggingfaceDatasetConfig(split='train', energy_column_mapping={'y': 'corrected_total_energy_referenced', 'y_relaxed': 'corrected_total_energy_relaxed_referenced'}), val_dataset=FinetuneM

In [3]:
import warnings

import yaml

# Locate the hparams.yaml written by the run's CSV logger. Sorting makes the
# choice deterministic if several exist; an empty result gets a clear error
# instead of a bare StopIteration.
hparams_files = sorted(run_dir.glob("./log/csv/csv/*/*/hparams.yaml"))
if not hparams_files:
    raise FileNotFoundError(f"No hparams.yaml found under {run_dir}/log/csv/csv")
if len(hparams_files) > 1:
    warnings.warn(
        f"Found {len(hparams_files)} hparams.yaml files under {run_dir}; "
        f"using {hparams_files[0]}"
    )
hparams_file = hparams_files[0]
print(hparams_file)

# Fields copied from the trained run so the architecture, output heads, datasets
# and run identity match the checkpoint even if the local recipe has drifted.
key_keys = (
    "backbone",
    "embedding",
    "output",
    "graph_targets",
    "node_targets",
    "train_dataset",
    "val_dataset",
    "test_dataset",
    "id",
    "name",
    "name_parts",
    # "predict_dataset",
)

# NOTE(review): `yaml.unsafe_load` can construct arbitrary Python objects
# (presumably needed for the serialized config types). Only use it on hparams
# files written by our own training runs.
hparams = yaml.unsafe_load(hparams_file.read_text())

# Check presence explicitly: `assert` is stripped under `python -O`, and a
# truthiness test would wrongly reject keys whose saved value is empty/falsy.
missing_keys = [key for key in key_keys if key not in hparams]
if missing_keys:
    raise KeyError(f"{missing_keys} not found in hparams")

# Apply all overrides in a single dump -> update -> strict re-validate round trip.
config_dict = config.model_dump(round_trip=True)
config_dict.update({key: hparams[key] for key in key_keys})
config_updated = config.model_validate(config_dict, strict=True)

/net/csefiles/coc-fung-cluster/nima/shared/experiment-data/lltrainer/uirlcewv/log/csv/csv/mptrj-jmps-bsz16-linrefenergy-lr8e-05-ln-direct-maceenergy-maceforce-rele-ec2.0-fc10.0-sc100.0-posaug_std0.01/uirlcewv/hparams.yaml


In [4]:
import os
import warnings

# NOTE(review): these only take effect if set before the libraries read them.
# `jmppeft` was already imported (and type-checked) by the config cell, so
# LL_DISABLE_TYPECHECKING is likely too late here — move it to the first cell.
os.environ["LL_DISABLE_TYPECHECKING"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


import torch

# CUDA_VISIBLE_DEVICES is ignored once the CUDA context has been created.
if torch.cuda.is_initialized():
    warnings.warn(
        "CUDA was initialized before CUDA_VISIBLE_DEVICES was set; "
        "the device mask will not apply."
    )

# Disable autograd globally for the rest of this evaluation notebook (this is
# exactly what the previous `torch.no_grad().__enter__()` did). The previous
# `torch.inference_mode().__enter__()` was dropped: its guard lived on a
# temporary object that was discarded immediately, so it did not persist.
torch.set_grad_enabled(False)

In [5]:
# ckpt_path = run_dir / "checkpoint" / "last.ckpt"
# Pick the most recently modified `latest_*.ckpt`. A bare `next(glob)` would
# pick arbitrarily among several matches and raise an opaque StopIteration
# when there are none.
ckpt_candidates = list(run_dir.glob("checkpoint/latest_*.ckpt"))
if not ckpt_candidates:
    raise FileNotFoundError(
        f"No checkpoint matching 'latest_*.ckpt' in {run_dir / 'checkpoint'}"
    )
ckpt_path = max(ckpt_candidates, key=lambda path: path.stat().st_mtime)
# If the file is a symlink, use its target so the real checkpoint name is shown.
if ckpt_path.is_symlink():
    print(f"Symlink found {ckpt_path} => {ckpt_path.resolve()}")
    ckpt_path = ckpt_path.resolve()

ckpt_path

PosixPath('/net/csefiles/coc-fung-cluster/nima/shared/experiment-data/lltrainer/uirlcewv/checkpoint/latest_epoch31_step724384.ckpt')

In [6]:
from jmppeft.tasks.finetune.base import FinetuneMatBenchDiscoveryIS2REDatasetConfig

# Full Matbench Discovery IS2RE evaluation set (WBM structures). It is not
# subsampled here; pass `sample_n=DatasetSampleNConfig(sample_n=16, seed=42)`
# to subsample at this stage.
dataset_config = FinetuneMatBenchDiscoveryIS2REDatasetConfig()
print(dataset_config)

dataset_og = dataset_config.create_dataset()
num_structures = len(dataset_og)
dataset_og, num_structures


Loading 'wbm_summary' from cached file at '/nethome/nsg6/.cache/matbench-discovery/1.0.0/wbm/2023-12-13-wbm-summary.csv.gz'


(<jmppeft.datasets.mpd_is2re.MatBenchDiscoveryIS2REDataset at 0x7f402814cd90>,
 256963)

In [7]:
from jmppeft.tasks.finetune import matbench_discovery as M

default_dtype = torch.float32
# Load on CPU: the checkpoint may hold more than weights (e.g. optimizer state),
# and `load_state_dict` copies the weights into the CUDA-resident model anyway.
# `weights_only=False` is explicit because newer PyTorch defaults to True, which
# rejects Lightning checkpoints that carry non-tensor hparams.
# NOTE(review): this unpickles arbitrary objects — only for our own checkpoints.
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

model = M.MatbenchDiscoveryModel(config_updated)
model = model.to(default_dtype).cuda().eval()
model.load_state_dict(ckpt["state_dict"])
model

NSHRUNNER_SESSION_DIR is not set. Skipping symlink creation.
NSHRUNNER_SESSION_DIR is not set. Skipping symlink creation.
NSHRUNNER_SESSION_DIR is not set. Skipping symlink creation.
NSHRUNNER_SESSION_DIR is not set. Skipping symlink creation.
Using regular backbone


Unrecognized arguments:  dict_keys(['name', 'learnable_rbf', 'learnable_rbf_stds', 'unique_basis_per_layer', 'dropout', 'edge_dropout', 'ln_per_layer', 'scale_factor_to_ln'])


Constructed backbone with dlora=None
Freezing 0 parameters (0.00%) out of 43,146,824 total parameters (43,146,824 trainable)


MatbenchDiscoveryModel(config=MatbenchDiscoveryConfig(name=mptrj-jmps-bsz16-linrefenergy-lr8e-05-ln-direct-maceenergy-maceforce-rele-ec2.0-fc10.0-sc100.0-posaug_std0.01, project=jmp_mptrj), device=cuda:0)

In [8]:
from collections import Counter

import jmppeft.modules.dataset.dataset_transform as DT
import nshutils
import rich
# NOTE: `torch.utils._pytree` is a private PyTorch module and may change between releases.
import torch.utils._pytree as tree
from lightning.fabric.utilities.apply_func import move_data_to_device
from torch_geometric.data import Batch, Data

# Turn on nshutils' pretty printing (presumably what produces the compact
# `tensor[...]` reprs in later outputs).
nshutils.pretty()


def data_transform(data: Data):
    """Apply the model's data transform, then cast floating-point tensors to `default_dtype`.

    Integer tensors (e.g. atomic numbers) and non-tensor values pass through
    unchanged. `tree_map` walks nested containers inside `data.to_dict()`.
    """
    transformed = model.data_transform(data)

    def _cast_floats(value):
        is_float_tensor = torch.is_tensor(value) and torch.is_floating_point(value)
        return value.type(default_dtype) if is_float_tensor else value

    return Data.from_dict(tree.tree_map(_cast_floats, transformed.to_dict()))


def composition(data: Batch):
    """Count atoms per atomic number, keyed in order of first appearance.

    Counts every atom in `data`; for a multi-structure batch this is the
    combined composition.
    """
    counts = {}
    for atomic_number in data.atomic_numbers.tolist():
        counts[atomic_number] = counts.get(atomic_number, 0) + 1
    return counts


# Evaluate on a fixed, seeded random subset of the transformed dataset.
num_items = 1024

dataset = DT.sample_n_transform(
    DT.transform(dataset_og, data_transform),
    n=num_items,
    seed=42,
)

# Inspect one example as a batch of size one.
idx = 32
data = Batch.from_data_list([dataset[idx]])
rich.print(data.to_dict(), composition(data))

In [9]:
import numpy as np
from jmppeft.modules.relaxer import ModelOutput, Relaxer
from matbench_discovery.energy import get_e_form_per_atom

# Default head for `model_fn`: False -> "y", True -> "y_relaxed".
USE_Y_RELAXED = False

# Per-element linear-reference energies, indexed by atomic number; `model_fn`
# adds their sum back to the (referenced) model energy.
_LINREF_PATH = "/net/csefiles/coc-fung-cluster/nima/shared/repositories/jmp-peft/notebooks/mptrj_linref.npy"
LINREF = np.load(_LINREF_PATH)


def model_fn(
    data,
    initial_data,
    *,
    use_y_relaxed: bool = USE_Y_RELAXED,
    apply_formation_correction: bool = False,
) -> ModelOutput:
    """Run the model on one structure and return relaxer-ready outputs.

    The predicted energy is linear-referenced. This adds the linear reference
    back to get a total energy, converts it to formation energy per atom with
    matbench-discovery's `get_e_form_per_atom`, and reshapes forces and stress.

    Args:
        data: batch to evaluate. The linref sum and `composition` run over *all*
            atoms in `data`, so this assumes the batch holds a single structure.
        initial_data: the unrelaxed input batch; only read when
            `apply_formation_correction` is True.
        use_y_relaxed: take the energy from the "y_relaxed" head instead of "y".
        apply_formation_correction: add `initial_data.y_formation_correction` to
            the predicted formation energy (previously a hard-coded `if False:`).

    Returns:
        dict with "energy" of shape (1,), "forces" of shape (N, 3) and "stress"
        of shape (1, 3, 3) if it has 9 elements, else (1, 6).
    """
    model_out = model.forward_denormalized(data)

    energy = model_out["y_relaxed"] if use_y_relaxed else model_out["y"]
    forces = model_out["force"]
    stress = model_out["stress"]

    # Undo the linref: add back the summed per-element reference energies.
    if LINREF is not None:
        energy = energy + LINREF[data.atomic_numbers.cpu().numpy()].sum()

    # The energy is now a DFT total energy; convert it to formation energy per atom.
    energy = get_e_form_per_atom(
        {
            "composition": composition(data),
            "energy": energy,
        }
    )
    assert isinstance(energy, torch.Tensor)

    if apply_formation_correction:
        energy = energy + initial_data.y_formation_correction.item()

    return {
        "energy": energy.view(1),
        "forces": forces.view(-1, 3),
        "stress": stress.view(1, 3, 3) if stress.numel() == 9 else stress.view(1, 6),
    }


# Smoke test: evaluate `model_fn` on the example structure, passing it as its own initial data.
data = move_data_to_device(data, model.device)
model_fn(data, data)

{'energy': tensor[1] cuda:0 [0.111],
 'forces': tensor[11, 3] n=33 x∈[-0.454, 0.454] μ=0.001 σ=0.216 cuda:0,
 'stress': tensor[1, 3, 3] n=9 x∈[-0.107, -0.003] μ=-0.038 σ=0.051 cuda:0 [[[-0.103, -0.003, -0.003], [-0.003, -0.105, -0.005], [-0.003, -0.005, -0.107]]]}

In [10]:
from functools import partial

import rich
from jmppeft.modules.relaxer import Relaxer, RelaxerConfig

# Use a distinct name: reassigning `config` would shadow the training config, so
# re-running the hparams-merge cell afterwards would pick up a RelaxerConfig.
relaxer_config = RelaxerConfig(
    compute_stress=True,
    stress_weight=0.1,
    optimizer="FIRE",
    fmax=0.05,
    ase_filter="exp",
)
relaxer = Relaxer(
    config=relaxer_config,
    model=partial(model_fn, use_y_relaxed=False),
    collate_fn=model.collate_fn,
    device=model.device,
)
rich.print(data.y_formation)
relax_out = relaxer.relax(data)
# rich.print(relax_out)

# The relaxer's energy is `model_fn`'s formation energy per atom, so this is the
# predicted formation energy of the relaxed structure.
energy = relax_out.atoms.get_total_energy()
rich.print(energy, data.y_formation)

      Step     Time          Energy          fmax
FIRE:    0 20:14:10        0.111085        2.534322
FIRE:    1 20:14:10       -0.033397        1.314817
FIRE:    2 20:14:10       -0.104730        0.361149
FIRE:    3 20:14:10       -0.071463        0.937753
FIRE:    4 20:14:10       -0.076991        0.889091
FIRE:    5 20:14:11       -0.086585        0.788414
FIRE:    6 20:14:11       -0.097697        0.627135
FIRE:    7 20:14:11       -0.107140        0.400099
FIRE:    8 20:14:11       -0.112394        0.335226
FIRE:    9 20:14:11       -0.112895        0.315883
FIRE:   10 20:14:11       -0.110406        0.362677
FIRE:   11 20:14:11       -0.105542        0.567738
FIRE:   12 20:14:12       -0.104622        0.653812
FIRE:   13 20:14:12       -0.113047        0.572379
FIRE:   14 20:14:12       -0.124802        0.323620
FIRE:   15 20:14:12       -0.135171        0.302931
FIRE:   16 20:14:12       -0.139262        0.339625
FIRE:   17 20:14:12       -0.141019        0.430812
FIRE:   18 20:

In [11]:
from collections import defaultdict
from functools import partial
from typing import TypedDict, cast

import numpy as np
from jmppeft.modules.relaxer import Relaxer, RelaxerConfig
from jmppeft.modules.relaxer._relaxer import RelaxationOutput
from torch.utils.data import DataLoader
from torch_geometric.data import Batch
from tqdm.auto import tqdm

use_y_relaxed = False

# Distinct name so the training `config` is not shadowed by a RelaxerConfig.
relaxer_config = RelaxerConfig(
    compute_stress=True,
    stress_weight=0.1,
    optimizer="FIRE",
    # Stricter alternative: fmax=0.01, ase_filter="frechet".
    fmax=0.05,
    ase_filter="exp",
)
relaxer = Relaxer(
    config=relaxer_config,
    model=partial(model_fn, use_y_relaxed=use_y_relaxed),
    collate_fn=model.collate_fn,
    device=model.device,
)

dl = DataLoader(
    dataset,
    batch_size=1,
    collate_fn=model.collate_fn,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)

# (prediction, target) pairs per quantity.
preds_targets: defaultdict[str, list[tuple[float, float]]] = defaultdict(list)
mae_error = 0.0
mae_count = 0


class ProblematicSample(TypedDict):
    """A sample whose absolute formation-energy error exceeded `ae_threshold`."""

    error: float
    initial_data: Batch
    relaxed_data: Batch
    relax_out: RelaxationOutput


problematic_samples: list[ProblematicSample] = []
ae_threshold = 0.1

for data in tqdm(dl, total=len(dl)):
    data = cast(Batch, data)
    data = move_data_to_device(data, model.device)
    data.y_prediction = data.y_formation
    relaxed_data, relax_out = relaxer.relax_and_return_structure(data, verbose=False)

    # The relaxer's energy is `model_fn`'s formation energy per atom.
    e_form_true = data.y_formation.item()
    e_form_pred = relax_out.atoms.get_total_energy()
    error = abs(e_form_pred - e_form_true)
    preds_targets["e_form"].append((e_form_pred, e_form_true))

    # Assumes E_above_hull and E_form differ by the same hull reference for
    # prediction and target, so the formation-energy error carries over as-is.
    e_above_hull_true = data.y_above_hull.item()
    e_above_hull_pred = e_above_hull_true + (e_form_pred - e_form_true)
    preds_targets["e_above_hull"].append((e_above_hull_pred, e_above_hull_true))

    mae_error += error
    mae_count += 1
    mae_running = mae_error / mae_count

    nsteps = len(relax_out.trajectory.frames)

    prefix = "✅"
    if error > ae_threshold:
        problematic_samples.append(
            {
                "error": error,
                "initial_data": move_data_to_device(data, "cpu"),
                "relaxed_data": move_data_to_device(relaxed_data, "cpu"),
                "relax_out": move_data_to_device(relax_out, "cpu"),
            }
        )
        prefix = "❌"

    print(
        f"{prefix} # Steps: {nsteps}; e_form: P={e_form_pred:.4f}, GT={e_form_true:.4f}, Δ={error:.4f}, MAE={mae_running:.4f}"
    )

  0%|          | 0/1024 [00:00<?, ?it/s]

✅ # Steps: 20; e_form: P=-0.8377, GT=-0.8008, Δ=0.0369, MAE=0.0369
✅ # Steps: 37; e_form: P=0.0512, GT=-0.0220, Δ=0.0732, MAE=0.0550
✅ # Steps: 32; e_form: P=-0.6289, GT=-0.6088, Δ=0.0201, MAE=0.0434
✅ # Steps: 75; e_form: P=-0.2388, GT=-0.1544, Δ=0.0844, MAE=0.0537
✅ # Steps: 42; e_form: P=-0.1500, GT=-0.1154, Δ=0.0345, MAE=0.0498
✅ # Steps: 23; e_form: P=-0.3919, GT=-0.3905, Δ=0.0014, MAE=0.0418
✅ # Steps: 14; e_form: P=-0.2634, GT=-0.2468, Δ=0.0167, MAE=0.0382
❌ # Steps: 15; e_form: P=-0.5672, GT=-0.4542, Δ=0.1130, MAE=0.0475
✅ # Steps: 13; e_form: P=-1.9540, GT=-1.9089, Δ=0.0450, MAE=0.0473
✅ # Steps: 7; e_form: P=-0.1208, GT=-0.1224, Δ=0.0016, MAE=0.0427
✅ # Steps: 5; e_form: P=-1.9096, GT=-1.8781, Δ=0.0315, MAE=0.0417
✅ # Steps: 5; e_form: P=-1.5727, GT=-1.6101, Δ=0.0373, MAE=0.0413
✅ # Steps: 13; e_form: P=-1.8422, GT=-1.8180, Δ=0.0243, MAE=0.0400


KeyboardInterrupt: 

In [17]:
import copy

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Apply seaborn's default theme to all subsequent matplotlib figures.
sns.set_theme()


def compute_relaxed(
    sample: ProblematicSample,
    *,
    use_y_relaxed: bool = True,
):
    """Re-evaluate the model on every frame of a stored relaxation trajectory.

    Yields one formation energy (float) per trajectory frame. Each value is
    computed by `model_fn` on a copy of the sample's initial structure, with
    `pos` and `cell` replaced by that frame's. The default is the "y_relaxed"
    head (the evaluation loop relaxed with `use_y_relaxed=False`), so the two
    heads can be compared along the same path.

    Args:
        sample: problematic sample holding `initial_data` and `relax_out`.
        use_y_relaxed: which energy head `model_fn` should use (previously
            hard-coded to True).

    NOTE(review): only `pos` and `cell` are overwritten; any other precomputed
    per-structure tensors keep their initial values.
    """
    frames = sample["relax_out"].trajectory.frames

    initial_data = move_data_to_device(
        copy.deepcopy(sample["initial_data"]), model.device
    )
    data = move_data_to_device(copy.deepcopy(sample["initial_data"]), model.device)

    for frame in tqdm(frames):
        data.pos = frame.pos.type_as(data.pos).reshape_as(data.pos).to(data.pos.device)
        data.cell = (
            frame.cell.type_as(data.cell).reshape_as(data.cell).to(data.cell.device)
        )

        out = model_fn(data, initial_data, use_y_relaxed=use_y_relaxed)
        yield out["energy"].item()


def plot_energy_vs_steps(
    sample: ProblematicSample,
    ax: plt.Axes | None = None,
    *,
    show_y_relaxed: bool = True,
) -> plt.Axes:
    """Plot predicted formation energy along a relaxation trajectory vs. the target.

    Args:
        sample: problematic sample holding `initial_data` and `relax_out`.
        ax: axes to draw on; a new figure is created when None.
        show_y_relaxed: also plot energies recomputed with `compute_relaxed`
            (previously a hard-coded `if True:`). This re-runs the model on
            every frame, so it is the slow part.

    Returns:
        The axes that were drawn on.
    """
    initial_data = sample["initial_data"]
    relax_out = sample["relax_out"]

    e_form_true = initial_data.y_formation.item()
    e_form_pred = [f.energy.item() for f in relax_out.trajectory.frames]
    e_form_pred_relaxed = list(compute_relaxed(sample)) if show_y_relaxed else None

    if ax is None:
        _, ax = plt.subplots()
    ax.plot(e_form_pred, label="Predicted")
    if e_form_pred_relaxed:
        ax.plot(e_form_pred_relaxed, label="Predicted (y_relaxed)")
    ax.axhline(y=e_form_true, color="r", linestyle="--", label="True")
    ax.set_xlabel("Step")
    ax.set_ylabel("Formation Energy")
    ax.legend()
    return ax


def plot_initial_structure(sample: ProblematicSample):
    """Show the sample's unrelaxed structure in an interactive NGL viewer.

    Builds a fully periodic `ase.Atoms` from the stored atomic numbers,
    positions and cell, and returns the widget from
    `ase.visualize.view(..., viewer="ngl")`. The `squeeze(0)` assumes the stored
    cell carries a leading batch dimension of 1 — confirm for other inputs.
    """
    import ase
    import ase.visualize

    initial_data = sample["initial_data"]
    structure = ase.Atoms(
        numbers=initial_data.atomic_numbers.cpu().numpy(),
        positions=initial_data.pos.cpu().numpy(),
        cell=initial_data.cell.cpu().squeeze(0).numpy(),
        pbc=[True] * 3,
    )
    return ase.visualize.view(structure, viewer="ngl")


# Visualize the first sample whose absolute error exceeded `ae_threshold`.
# (Use `plot_energy_vs_steps(problematic_samples[0])` to plot its trajectory instead.)
if not problematic_samples:
    raise RuntimeError(
        "problematic_samples is empty: run the evaluation loop first "
        "(or lower ae_threshold)."
    )
plot_initial_structure(problematic_samples[0])

HBox(children=(NGLWidget(), VBox(children=(Dropdown(description='Show', options=('All', 'Rh', 'Dy', 'H'), valu…