In [1]:
from hydra.experimental import compose
from hydra import initialize_config_dir
import hydra
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from torch.nn import functional as F
import cdvae
import os

#### Load cdvae model for mp_20

In [2]:
os.environ["PROJECT_ROOT"] = "/mnt/c/Users/Lenovo/Downloads/cdvae2"
os.environ["HYDRA_JOBS"] = "/mnt/c/Users/Lenovo/Downloads/cdvae2"
os.environ["WABDB_DIR"] = "/mnt/c/Users/Lenovo/Downloads/cdvae2"

In [3]:
model_path = Path("/mnt/c/Users/Lenovo/Downloads/cdvae2/singlerun/2023-05-19/mp_20")

with initialize_config_dir(str(model_path)):
    ckpts = list(model_path.glob('epoch*.ckpt'))
    if len(ckpts) > 0:
        ckpt_epochs = np.array(
            [int(ckpt.parts[-1].split('-')[0].split('=')[1]) for ckpt in ckpts])
        ckpt = str(ckpts[ckpt_epochs.argsort()[-1]])
        
checkpoint = torch.load(ckpt, map_location=torch.device('cpu'))

gemnet_path = Path(cdvae.__file__).parent / "pl_modules/gemnet/gemnet-dT.json"
checkpoint["hyper_parameters"]["decoder"]["scale_file"] = str(gemnet_path)
ckpt = model_path / "checkpoint_edit.ckpt"
torch.save(checkpoint, model_path / "checkpoint_edit.ckpt")

In [4]:
with initialize_config_dir(str(model_path)):
    # load config
    cfg = compose(config_name='hparams')
    
    # load model
    model = hydra.utils.instantiate(
        cfg.model,
        optim=cfg.optim,
        data=cfg.data,
        logging=cfg.logging,
        _recursive_=False,
    )
    
    model = model.load_from_checkpoint(ckpt)
    model.lattice_scaler = torch.load(model_path / 'lattice_scaler.pt')
    model.scaler = torch.load(model_path / 'prop_scaler.pt')



#### Define functions to get a batch from a cif string

First load lattice and property scalers

In [5]:
lattice_scaler = model.lattice_scaler
scaler = model.scaler

Now define function that takes in a string and returns the batch

In [6]:
from cdvae.common.data_utils import build_crystal, build_crystal_graph, add_scaled_lattice_prop
from torch_geometric.data import Data, Batch
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.core import Lattice, Structure

def atoms_to_structure(atoms):
    lattice = Lattice.from_parameters(*atoms.cell.cellpar())
    symbols = atoms.get_chemical_symbols()
    positions = atoms.get_positions()
    return Structure(lattice, symbols, positions, coords_are_cartesian=True)


def process_one(atoms, graph_method="crystalnn", formation_energy_per_atom=0, material_id=0):
    crystal = atoms_to_structure(atoms)
    graph_arrays = build_crystal_graph(crystal, graph_method)
    return {
        'mp_id': material_id,
        'cif': crystal.to(fmt="cif"),
        'graph_arrays': graph_arrays,
        "formation_energy_per_atom": formation_energy_per_atom,
    }

def get_batch(atoms, **process_kwargs):
    d = [process_one(atoms, **process_kwargs)]
    add_scaled_lattice_prop(d, "scale_length")
    (frac_coords, atom_types, lengths, angles, edge_indices,
     to_jimages, num_atoms) = d[0]['graph_arrays']

    prop = scaler.transform(d[0]["formation_energy_per_atom"])
    data = Data(
        frac_coords=torch.Tensor(frac_coords),
        atom_types=torch.LongTensor(atom_types),
        lengths=torch.Tensor(lengths).view(1, -1),
        angles=torch.Tensor(angles).view(1, -1),
        edge_index=torch.LongTensor(
            edge_indices.T).contiguous(),  # shape (2, num_edges)
        to_jimages=torch.LongTensor(to_jimages),
        num_atoms=num_atoms,
        num_bonds=edge_indices.shape[0],
        num_nodes=num_atoms,  # special attribute used for batching in pytorch geometric
        y=prop.view(1, -1),
    )

    return Batch.from_data_list([data])

#### Ground truth latent space of Ag generation

In [17]:
from pymatgen.core import Structure, Lattice

ag_cif = "# generated using pymatgen\ndata_Ag\n_symmetry_space_group_name_H-M   'P 1'\n_cell_length_a   4.10435636\n_cell_length_b   4.10435636\n_cell_length_c   4.10435636\n_cell_angle_alpha   90.00000000\n_cell_angle_beta   90.00000000\n_cell_angle_gamma   90.00000000\n_symmetry_Int_Tables_number   1\n_chemical_formula_structural   Ag\n_chemical_formula_sum   Ag4\n_cell_volume   69.14092475\n_cell_formula_units_Z   4\nloop_\n _symmetry_equiv_pos_site_id\n _symmetry_equiv_pos_as_xyz\n  1  'x, y, z'\nloop_\n _atom_site_type_symbol\n _atom_site_label\n _atom_site_symmetry_multiplicity\n _atom_site_fract_x\n _atom_site_fract_y\n _atom_site_fract_z\n _atom_site_occupancy\n  Ag  Ag0  1  0.00000000  0.00000000  0.00000000  1.0\n  Ag  Ag1  1  0.50000000  0.50000000  0.00000000  1.0\n  Ag  Ag2  1  0.50000000  0.00000000  0.50000000  1.0\n  Ag  Ag3  1  0.00000000  0.50000000  0.50000000  1.0\n"
ag_structure = Structure.from_str(ag_cif, fmt="cif")
ag_atoms = AseAtomsAdaptor.get_atoms(ag_structure)

In [7]:
ag_structure = Structure.from_file("/mnt/c/Users/Lenovo/Downloads/cdvae2/structures_GA/ag_cif/Ag_989737.cif")
ag_atoms = AseAtomsAdaptor.get_atoms(ag_structure)



In [8]:
batch = get_batch(ag_atoms)
mu_t, log_t, z_t = model.encode(batch)

#### Relax function

In [9]:
import ase
from ase.ga import set_raw_score
from chgnet.model import StructOptimizer
from ase.calculators.singlepoint import SinglePointCalculator

def finalize(atoms, energy=None, forces=None, stress=None):
    # Finalizes the atoms by attaching a SinglePointCalculator
    # and setting the raw score as the negative of the total energy
    atoms.wrap()
    calc = SinglePointCalculator(atoms, energy=energy, forces=forces,
                                 stress=stress)
    atoms.calc = calc
    raw_score = atoms.get_potential_energy()
    set_raw_score(atoms, raw_score)

relaxer = StructOptimizer()
def relax(atoms):
    # Relax our candidates using chgnet

    result = relaxer.relax(atoms_to_structure(atoms), verbose=False)
    result = relaxer.relax(result["final_structure"], verbose=False)
    relaxed_atoms = AseAtomsAdaptor.get_atoms(result["final_structure"])
    # Compare cosine similarity of the relaxed structures with ground truth
    relaxed_atoms.info = atoms.info
    
    batch = get_batch(relaxed_atoms)
    mu_x, log_x, z_x = model.encode(batch)
    
    input1 = torch.cat([mu_t,log_t],dim=1)
    input2 = torch.cat((mu_x,log_x),dim=1)
    cos_sim = F.cosine_similarity(input1[0], input2[0], dim=0, eps=1e-8)
    
    finalize(
        relaxed_atoms,
        energy = (float(cos_sim) + 1) ** 2,
        forces = result["trajectory"].forces[-1],
        stress = result["trajectory"].stresses[-1],
    )  
    return relaxed_atoms

CHGNet initialized with 400,438 parameters
CHGNet will run on cpu


### Genetic algorithm

In [10]:
from ase import Atoms
from ase.data import atomic_numbers
from ase.ga.utilities import closest_distances_generator, CellBounds, get_all_atom_types
from ase.ga.startgenerator import StartGenerator
from ase.ga.data import PrepareDB, DataConnection
from ase.io import write,cif
from ase.ga import get_raw_score, set_raw_score
from ase.ga.population import Population
from ase.ga.ofp_comparator import OFPComparator
from ase.ga.offspring_creator import OperationSelector
from ase.ga.standardmutations import StrainMutation, RotationalMutation, RattleMutation, MirrorMutation
from ase.ga.soft_mutation import SoftMutation
from ase.ga.cutandsplicepairing import CutAndSplicePairing
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.analysis.structure_matcher import StructureMatcher
from pymatgen.core import Structure
import io

In [2]:
def run_ga(probability):
    # Take mutation operators as variable and return highest cosine similarity
    Path('/mnt/c/Users/Lenovo/Downloads/cdvae2/Ag.db').unlink()
    # Number of random initial structures to generate
    N = 20

    # Target cell volume for the initial structures, in angstrom^3
    volume = 162.46

    natoms = 9
    # We may also write:
    blocks = ['Ag'] * natoms

    # Define the composition of the atoms to optimize
    Z = atomic_numbers['Ag']
    blmin = closest_distances_generator(atom_numbers=[Z],
                                        ratio_of_covalent_radii=0.5)

    # Specify reasonable bounds on the minimal and maximal
    # cell vector lengths (in angstrom) and angles (in degrees)
    cellbounds = CellBounds(bounds={'phi': [25, 155], 'chi': [25, 155],
                                    'psi': [25, 155], 'a': [3, 50],
                                    'b': [3, 50], 'c': [3, 50]})

    # Choose an (optional) 'cell splitting' scheme which basically controls the level of translational symmetry (within the unit cell) of the randomly generated structures. 
    # Here a 1:1 ratio of splitting factors 2 and 1 is used:
    splits = {(2,): 1, (1,): 1}
    # There will hence be a 50% probability that a can didateis constructed by repeating an randomly generated Ag12 structure along a randomly chosen axis. 
    # In the other 50% of cases, no cell cell splitting will be applied.

    # The 'slab' object: a template in the creation of new structures, which inherit the slab's atomic positions (if any), 
    # cell vectors (if specified), and periodic boundary conditions.
    # Here only the last property is relevant:
    slab = Atoms('', pbc=True)

    # Initialize the random structure generator
    sg = StartGenerator(slab, blocks, blmin, box_volume=volume,
                        number_of_variable_cell_vectors=3,
                        splits=splits, cellbounds=cellbounds)

    # Create the database
    da = PrepareDB(db_file_name='Ag.db',
                   stoichiometry=[Z] * natoms)

    # Generate N random structures and add them to the database
    for i in range(N):
        a = sg.get_new_candidate()
        da.add_unrelaxed_candidate(a)
    
    # Connect to the database and retrieve some information
    da = DataConnection('Ag.db')
    slab = da.get_slab()
    atom_numbers_to_optimize = da.get_atom_numbers_to_optimize()
    n_top = len(atom_numbers_to_optimize)

    # Use Oganov's fingerprint functions to decide whether
    # two structures are identical or not
    comp = OFPComparator(n_top=n_top, dE=1.0, cos_dist_max=1e-3, rcut=10.,
                         binwidth=0.05, pbc=[True, True, True],
                         sigma=0.05, nsigma=4, recalculate=False)

    # Define the cell and interatomic distance bounds
    # that the candidates must obey
    blmin = closest_distances_generator(atom_numbers_to_optimize, 0.5)

    # Define a pairing operator with 100% (0%) chance that the first
    # (second) parent will be randomly translated, and with each parent
    # contributing to at least 15% of the child's scaled coordinates
    pairing = CutAndSplicePairing(slab, n_top, blmin, p1=1., p2=0., minfrac=0.15,
                                  number_of_variable_cell_vectors=3,
                                  cellbounds=cellbounds, use_tags=False)

    # Define a strain mutation with a typical standard deviation of 0.7
    # for the strain matrix elements (drawn from a normal distribution)
    strainmut = StrainMutation(blmin, stddev=0.7, cellbounds=cellbounds,
                               number_of_variable_cell_vectors=3,
                               use_tags=False)

    # Define a soft mutation
    blmin_soft = closest_distances_generator(atom_numbers_to_optimize, 0.1)
    softmut = SoftMutation(blmin_soft, bounds=[2, 5.], use_tags=False)

    # By default, the operator will update a "used_modes.json" file
    # after every mutation, listing which modes have been used so far
    # for each structure in the database. The mode indices start at 3
    # as the three lowest frequency modes are translational modes.

    # Set up the relative probabilities for the different operators
    
    operators = OperationSelector(probability,[pairing, strainmut, softmut, MirrorMutation(blmin, n_top),])

    # Relax the initial candidates
    while da.get_number_of_unrelaxed_candidates() > 0:
        a = da.get_an_unrelaxed_candidate()
    
        a = relax(a)
        a = relax(a)
        da.add_relaxed_step(a)

        cell = a.get_cell()
        if not cellbounds.is_within_bounds(cell):
            da.kill_candidate(a.info['confid'])

    # Initialize the population
    population_size = 20
    population = Population(data_connection=da,
                            population_size=population_size,
                            comparator=comp,
                            use_extinct=True)

    # Update the scaling volume used in some operators
    # based on a number of the best candidates
    current_pop = population.get_current_population()
    strainmut.update_scaling_volume(current_pop, w_adapt=0.5, n_adapt=4)
    pairing.update_scaling_volume(current_pop, w_adapt=0.5, n_adapt=4)
    
    # Test n_to_test new candidates
    n_to_test = 20

    for step in range(n_to_test):
    
        # Create a new candidate
        a3 = None
        while a3 is None:
            a1, a2 = population.get_two_candidates()
            a3, desc = operators.get_new_individual([a1, a2])

        # Save the unrelaxed candidate
        da.add_unrelaxed_candidate(a3, description=desc)

        # Relax the new candidate and save it
        a3 = relax(a3)
        a3 = relax(a3)
        da.add_relaxed_step(a3)

        # If the relaxation has changed the cell parameters
        # beyond the bounds we disregard it in the population
        cell = a3.get_cell()
        if not cellbounds.is_within_bounds(cell):
            da.kill_candidate(a3.info['confid'])

        # Update the population
        population.update()

        if step % 10 == 0:
            # Update the scaling volumes of the strain mutation
            # and the pairing operator based on the current
            # best structures contained in the population
            current_pop = population.get_current_population()
            strainmut.update_scaling_volume(current_pop, w_adapt=0.5, n_adapt=4)
            pairing.update_scaling_volume(current_pop, w_adapt=0.5, n_adapt=4)
            write('current_population.traj', current_pop)
    
    hiscore = np.sqrt(get_raw_score(current_pop[0])) - 1
    print('Highest raw score = %8.4f' % hiscore)
    
    is_match(current_pop[0])
    
    return hiscore

In [12]:
list_of_probability = [[3., 4., 1., 2.], [3., 3., 1., 3.], [4., 3., 3., 0.]]
all_raw_scores = []
for probability in list_of_probability:
    print(probability)
    operator_raw_scores = []
    while len(operator_raw_scores) < 5:
        try:
            raw_score = run_ga(probability)
            operator_raw_scores.append(raw_score)
        except:
            pass
    print(operator_raw_scores)
    all_raw_scores.append(operator_raw_scores)

[3.0, 4.0, 1.0, 2.0]
Highest raw score =   0.9808
Highest raw score =   0.9806
Highest raw score =   0.9828
Highest raw score =   0.9810
Highest raw score =   0.9811
[0.9808115363121033, 0.9805940985679626, 0.982781708240509, 0.9810068607330322, 0.9810959100723267]
[3.0, 3.0, 1.0, 3.0]
Highest raw score =   0.9852
Highest raw score =   0.9851
Highest raw score =   0.9807
Highest raw score =   0.9797
Highest raw score =   0.9825
[0.9851791858673096, 0.9850552082061768, 0.9806743860244751, 0.9796713590621948, 0.982513964176178]
[4.0, 3.0, 3.0, 0.0]
Highest raw score =   0.9833
Highest raw score =   0.9814
Highest raw score =   0.9830
Highest raw score =   0.9789
Highest raw score =   0.9833
[0.9833438992500305, 0.9813639521598816, 0.9830278158187866, 0.978891909122467, 0.9832966327667236]


In [13]:
print(np.mean(all_raw_scores[2]), np.std(all_raw_scores[2]))

0.9819848418235779 0.0017092316396371003


In [1]:
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.analysis.structure_matcher import StructureMatcher
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.core import Structure

def is_match(atoms, symprec=0.001):
    s = AseAtomsAdaptor.get_structure(atoms)
    sga = SpacegroupAnalyzer(s, symprec=symprec)
    ps = sga.get_conventional_standard_structure()

    sm = StructureMatcher(scale=False, primitive_cell=False)
    print(sm.fit(ag_structure, ps), sm.get_rms_dist(ag_structure, ps))

In [16]:
ps

Structure Summary
Lattice
    abc : 2.9378535333410376 2.93902641215515 4.202189888102476
 angles : 90.0 90.0 90.0
 volume : 36.28351077765953
      A : 2.9378535333410376 0.0 1.798916462984922e-16
      B : 4.726314941384558e-16 2.93902641215515 1.799634644127667e-16
      C : 0.0 0.0 4.202189888102476
    pbc : True True True
PeriodicSite: Ag (0.0000, 2.1743, 2.1011) [0.0000, 0.7398, 0.5000]
PeriodicSite: Ag (1.4689, 0.7648, 0.0000) [0.5000, 0.2602, 0.0000]

In [17]:
ag_structure

Structure Summary
Lattice
    abc : 4.10435636 4.10435636 4.10435636
 angles : 90.0 90.0 90.0
 volume : 69.14092474530555
      A : 4.10435636 0.0 2.513193439417041e-16
      B : 6.600308424860354e-16 4.10435636 2.513193439417041e-16
      C : 0.0 0.0 4.10435636
    pbc : True True True
PeriodicSite: Ag (0.0000, 0.0000, 0.0000) [0.0000, 0.0000, 0.0000]
PeriodicSite: Ag (2.0522, 2.0522, 0.0000) [0.5000, 0.5000, 0.0000]
PeriodicSite: Ag (2.0522, 0.0000, 2.0522) [0.5000, 0.0000, 0.5000]
PeriodicSite: Ag (0.0000, 2.0522, 2.0522) [0.0000, 0.5000, 0.5000]