In [32]:
import torch
from pathlib import Path
import yaml
import sys
import numpy as np
sys.path.append("..")
from ai4mat.models.megnet_pytorch.megnet_pytorch import MEGNet
from ai4mat.models.megnet_pytorch.struct2graph import (
    SimpleCrystalConverter, GaussianDistanceConverter,
    FlattenGaussianDistanceConverter, AtomFeaturesExtractor)

In [30]:
# Trained model to inspect. `model_name` is shared by the trial config
# (../trials/<model_name>.yaml) and the checkpoint directory, so the two
# locations are derived from one value and cannot drift apart.
model_name = "megnet_pytorch/sparse/05-12-2022_19-50-53/d6b7ce45"
trials_dir = Path("../trials")
checkpoints_dir = Path("../datasets/checkpoints/combined_mixed_all_train/formation_energy_per_site")
checkpoint_path = checkpoints_dir / model_name / "0.pth"

with open(trials_dir / f"{model_name}.yaml", "r") as f:
    config = yaml.safe_load(f)

# map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only kernel.
# torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location="cpu")

In [None]:
class MEGNetOnStructures(torch.nn.Module):
    """MEGNet network bundled with the structure-to-graph converter it needs.

    ``predict`` converts a structure with ``self.converter`` and feeds the
    resulting graph to ``self.model``.

    Args:
        config: trial configuration dict. ``config["model"]`` is read for
            ``cutoff``, ``edge_embed_size``, ``add_z_bond_coord``,
            ``atom_features``, ``embedding_size``, ``nblocks``,
            ``state_input_shape``, ``vertex_aggregation``,
            ``global_aggregation`` and (optional, default ``False``)
            ``add_eos_features``.
    """

    def __init__(self, config):
        # nn.Module.__init__ must run before a submodule is assigned to an
        # attribute; without it ``self.model = MEGNet(...)`` raises
        # "cannot assign module before Module.__init__() call".
        super().__init__()
        self.config = config
        model_config = self.config["model"]

        # Evenly spaced centers in [0, cutoff], used by whichever distance
        # converter is selected below.
        centers = np.linspace(0, model_config["cutoff"], model_config["edge_embed_size"])
        if model_config["add_z_bond_coord"]:
            bond_converter = FlattenGaussianDistanceConverter(centers=centers)
        else:
            bond_converter = GaussianDistanceConverter(centers=centers)
        atom_converter = AtomFeaturesExtractor(model_config["atom_features"])
        use_eos = model_config.get("add_eos_features", False)

        self.converter = SimpleCrystalConverter(
            bond_converter=bond_converter,
            atom_converter=atom_converter,
            cutoff=model_config["cutoff"],
            add_z_bond_coord=model_config["add_z_bond_coord"],
            add_eos_features=use_eos,
        )
        # Input widths are taken from the converters so the network matches
        # the features the converter produces.
        self.model = MEGNet(
            edge_input_shape=bond_converter.get_shape(eos=use_eos),
            node_input_shape=atom_converter.get_shape(),
            embedding_size=model_config["embedding_size"],
            n_blocks=model_config["nblocks"],
            state_input_shape=model_config["state_input_shape"],
            vertex_aggregation=model_config["vertex_aggregation"],
            global_aggregation=model_config["global_aggregation"],
        )

    def load(self, checkpoint_file_name, map_location=None):
        """Load trained MEGNet weights into ``self.model``.

        Args:
            checkpoint_file_name: path to a ``torch.save`` file; the state dict
                is read from ``checkpoint['model_params']['model']``.
            map_location: forwarded to ``torch.load`` (``None`` is its default);
                pass ``"cpu"`` to load a GPU-saved checkpoint without CUDA.
        """
        # torch.load unpickles the file -- only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_file_name, map_location=map_location)
        self.model.load_state_dict(checkpoint['model_params']['model'])

    def predict(self, sparse_structure):
        """Convert ``sparse_structure`` to a graph and run MEGNet on it.

        Returns whatever ``self.model`` returns. This runs in the module's
        current train/eval mode with autograd enabled; call ``.eval()`` and wrap
        the call in ``torch.no_grad()`` for pure inference.
        """
        graph = self.converter.convert(sparse_structure)
        return self.model(graph)

In [29]:
# MEGNet's input widths are not stored in the trial YAML (hence the former
# KeyError: 'node_input_shape'); they are derived from the feature converters,
# mirroring MEGNetOnStructures.__init__. The config key is "nblocks" while
# MEGNet's keyword is "n_blocks".
# NOTE(review): assumes config["model_params"]["model"] has the same schema that
# MEGNetOnStructures reads from config["model"] -- confirm against the YAML.
model_config = config["model_params"]["model"]
distance_centers = np.linspace(0, model_config["cutoff"], model_config["edge_embed_size"])
if model_config["add_z_bond_coord"]:
    bond_converter = FlattenGaussianDistanceConverter(centers=distance_centers)
else:
    bond_converter = GaussianDistanceConverter(centers=distance_centers)
atom_converter = AtomFeaturesExtractor(model_config["atom_features"])

model = MEGNet(
    edge_input_shape=bond_converter.get_shape(eos=model_config.get("add_eos_features", False)),
    node_input_shape=atom_converter.get_shape(),
    embedding_size=model_config["embedding_size"],
    n_blocks=model_config["nblocks"],
    state_input_shape=model_config["state_input_shape"],
    vertex_aggregation=model_config["vertex_aggregation"],
    global_aggregation=model_config["global_aggregation"],
)