In [1]:
from hydra.experimental import compose
from hydra import initialize_config_dir
import hydra
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from torch.nn import functional as F
import cdvae
import os

#### Load cdvae model for mp_20

In [2]:
# All three variables point at the same local checkout.
# NOTE(review): presumably these are read by the CDVAE/Hydra configuration;
# the spelling "WABDB_DIR" is kept exactly as written - confirm against the configs.
os.environ.update(
    dict.fromkeys(
        ("PROJECT_ROOT", "HYDRA_JOBS", "WABDB_DIR"),
        "/mnt/c/Users/Lenovo/Downloads/cdvae2",
    )
)

In [3]:
model_path = Path("/mnt/c/Users/Lenovo/Downloads/cdvae2/singlerun/2023-05-19/mp_20")

# Pick the checkpoint of the latest epoch. The epoch number is parsed from the
# file name, which is expected to look like "epoch=<N>-<rest>.ckpt".
# (No Hydra context is needed for a plain directory listing.)
ckpts = list(model_path.glob('epoch*.ckpt'))
if not ckpts:
    # Fail with a clear message instead of a NameError on `ckpt` further down.
    raise FileNotFoundError(
        "No 'epoch*.ckpt' checkpoint found in {}".format(model_path))
latest_ckpt = max(
    ckpts, key=lambda path: int(path.name.split('-')[0].split('=')[1]))

checkpoint = torch.load(str(latest_ckpt), map_location=torch.device('cpu'))

# Point the decoder's `scale_file` hyper-parameter at the gemnet-dT.json that
# ships with the installed cdvae package (presumably the path stored at
# training time is not valid on this machine), then save the patched copy.
gemnet_path = Path(cdvae.__file__).parent / "pl_modules/gemnet/gemnet-dT.json"
checkpoint["hyper_parameters"]["decoder"]["scale_file"] = str(gemnet_path)

# `ckpt` is what the next cell loads: the patched checkpoint, not the original.
ckpt = model_path / "checkpoint_edit.ckpt"
torch.save(checkpoint, ckpt)

In [4]:
# Compose the Hydra config stored in the run directory, build the model from
# it, then replace it with the weights from the patched checkpoint `ckpt`.
with initialize_config_dir(str(model_path)):
    # load config
    cfg = compose(config_name='hparams')
    
    # load model
    # _recursive_=False: the nested optim/data/logging configs are passed to
    # the model as-is instead of being instantiated by Hydra.
    model = hydra.utils.instantiate(
        cfg.model,
        optim=cfg.optim,
        data=cfg.data,
        logging=cfg.logging,
        _recursive_=False,
    )
    
    # `ckpt` is model_path / "checkpoint_edit.ckpt", written by the previous cell.
    model = model.load_from_checkpoint(ckpt)
    # The scalers are stored as separate files next to the checkpoint and are
    # attached to the model by hand.
    model.lattice_scaler = torch.load(model_path / 'lattice_scaler.pt')
    model.scaler = torch.load(model_path / 'prop_scaler.pt')



#### Define functions to get a batch from a cif string

First load lattice and property scalers

In [5]:
# Module-level aliases of the model's scalers; `get_batch` below reads `scaler`
# as a global.
lattice_scaler, scaler = model.lattice_scaler, model.scaler

Now define a function that takes an ASE `Atoms` object and returns a single-structure batch

In [6]:
from cdvae.common.data_utils import build_crystal, add_scaled_lattice_prop
from torch_geometric.data import Data, Batch
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.core import Lattice, Structure
from pymatgen.analysis.graphs import StructureGraph
from pymatgen.analysis import local_env

# Shared CrystalNN neighbour-finding strategy; build_crystal_graph below passes
# it to StructureGraph.with_local_env_strategy to construct the bond graph.
CrystalNN = local_env.CrystalNN(
    distance_cutoffs=None, x_diff_weight=-1, porous_adjustment=False)

def atoms_to_structure(atoms):
    """Convert an ASE ``Atoms`` object into a pymatgen ``Structure``.

    The cell of ``atoms`` becomes the lattice, and the Cartesian positions are
    handed over unchanged (``coords_are_cartesian=True``).
    """
    return Structure(
        Lattice(atoms.cell),
        atoms.get_chemical_symbols(),
        atoms.get_positions(),
        coords_are_cartesian=True,
    )

def build_crystal_graph(crystal, graph_method='crystalnn'):
    """Collect the arrays describing a crystal and, optionally, its bond graph.

    Args:
        crystal: structure exposing ``frac_coords``, ``atomic_numbers`` and
            ``lattice.parameters`` (a pymatgen ``Structure`` in this notebook).
        graph_method: ``'crystalnn'`` builds bonds with the module-level
            ``CrystalNN`` strategy; ``'none'`` skips graph construction.

    Returns:
        Tuple ``(frac_coords, atom_types, lengths, angles, edge_indices,
        to_jimages, num_atoms)``. Every bond is stored in both directions; the
        second direction carries the negated periodic-image offset.

    Raises:
        NotImplementedError: for any other ``graph_method``.
    """
    if graph_method == 'crystalnn':
        crystal_graph = StructureGraph.with_local_env_strategy(
            crystal, CrystalNN)
    elif graph_method != 'none':
        raise NotImplementedError

    frac_coords = crystal.frac_coords
    atom_types = np.array(crystal.atomic_numbers)
    cell_params = crystal.lattice.parameters

    edge_indices, to_jimages = [], []
    if graph_method != 'none':
        for head, tail, image in crystal_graph.graph.edges(data='to_jimage'):
            edge_indices.extend(([tail, head], [head, tail]))
            to_jimages.extend((image, tuple(-shift for shift in image)))

    return (
        frac_coords,
        atom_types,
        np.array(cell_params[:3]),
        np.array(cell_params[3:]),
        np.array(edge_indices),
        np.array(to_jimages),
        atom_types.shape[0],
    )

def process_one(atoms, graph_method="crystalnn", formation_energy_per_atom=0, material_id=0):
    """Build the per-structure record consumed by ``get_batch``.

    The ASE ``atoms`` are converted to a pymatgen structure, and the returned
    dict holds its id, CIF text, graph arrays (see ``build_crystal_graph``)
    and the given formation energy per atom.
    """
    structure = atoms_to_structure(atoms)
    arrays = build_crystal_graph(structure, graph_method)
    record = dict(mp_id=material_id, cif=structure.to(fmt="cif"))
    record['graph_arrays'] = arrays
    record["formation_energy_per_atom"] = formation_energy_per_atom
    return record

def get_batch(atoms, **process_kwargs):
    """Turn one ASE ``Atoms`` object into a single-graph ``Batch``.

    ``process_kwargs`` are forwarded to ``process_one``. The target property
    is transformed with the module-level ``scaler``.
    """
    records = [process_one(atoms, **process_kwargs)]
    # NOTE(review): presumably annotates each record with a scaled-lattice
    # entry; nothing it adds is read here - confirm against cdvae.common.data_utils.
    add_scaled_lattice_prop(records, "scale_length")
    record = records[0]
    (frac_coords, atom_types, lengths, angles, edge_indices,
     to_jimages, num_atoms) = record['graph_arrays']

    prop = scaler.transform(record["formation_energy_per_atom"])
    fields = dict(
        frac_coords=torch.Tensor(frac_coords),
        atom_types=torch.LongTensor(atom_types),
        lengths=torch.Tensor(lengths).view(1, -1),
        angles=torch.Tensor(angles).view(1, -1),
        # shape (2, num_edges)
        edge_index=torch.LongTensor(edge_indices.T).contiguous(),
        to_jimages=torch.LongTensor(to_jimages),
        num_atoms=num_atoms,
        num_bonds=edge_indices.shape[0],
        # special attribute used for batching in pytorch geometric
        num_nodes=num_atoms,
        y=prop.view(1, -1),
    )
    return Batch.from_data_list([Data(**fields)])

#### Ground-truth latent encoding of the reference Ag structure

In [7]:
from pymatgen.core import Structure, Lattice
from pymatgen.io.cif import CifFile

# Reference structure of Ag from Materials Project, id=mp-124
# (4-atom cubic cell with a = 4.104 angstrom, stored inline as CIF text)
ag_cif = "# generated using pymatgen\ndata_Ag\n_symmetry_space_group_name_H-M   'P 1'\n_cell_length_a   4.10435636\n_cell_length_b   4.10435636\n_cell_length_c   4.10435636\n_cell_angle_alpha   90.00000000\n_cell_angle_beta   90.00000000\n_cell_angle_gamma   90.00000000\n_symmetry_Int_Tables_number   1\n_chemical_formula_structural   Ag\n_chemical_formula_sum   Ag4\n_cell_volume   69.14092475\n_cell_formula_units_Z   4\nloop_\n _symmetry_equiv_pos_site_id\n _symmetry_equiv_pos_as_xyz\n  1  'x, y, z'\nloop_\n _atom_site_type_symbol\n _atom_site_label\n _atom_site_symmetry_multiplicity\n _atom_site_fract_x\n _atom_site_fract_y\n _atom_site_fract_z\n _atom_site_occupancy\n  Ag  Ag0  1  0.00000000  0.00000000  0.00000000  1.0\n  Ag  Ag1  1  0.50000000  0.50000000  0.00000000  1.0\n  Ag  Ag2  1  0.50000000  0.00000000  0.50000000  1.0\n  Ag  Ag3  1  0.00000000  0.50000000  0.50000000  1.0\n"
# Parse the CIF, then convert to ASE Atoms because get_batch expects ASE input.
ag_structure = Structure.from_str(ag_cif, fmt="cif")
ag_atoms = AseAtomsAdaptor.get_atoms(ag_structure)

# Encode the reference once. `mu_t` and `log_t` are the comparison target used
# by `relax` and by the similarity checks further down.
# NOTE(review): encode presumably returns (mean, log-variance, sampled latent)
# - confirm against the CDVAE model definition.
batch = get_batch(ag_atoms)
mu_t, log_t, z_t = model.encode(batch)

#### Relax function

In [8]:
import ase
import io
from ase.build import niggli_reduce
from ase.constraints import ExpCellFilter
from ase.optimize import FIRE
from chgnet.model import StructOptimizer
from ase.ga import set_raw_score
import contextlib
from ase.calculators.singlepoint import SinglePointCalculator

# CHGNet structure optimizer; `relax` below attaches its ASE calculator
# (`relaxer.calculator`) to each candidate.
relaxer = StructOptimizer()

def finalize(atoms, energy=None, forces=None, stress=None):
    """Freeze the given results on ``atoms`` and record its GA raw score.

    The atoms are wrapped into the cell, a ``SinglePointCalculator`` holding
    ``energy``/``forces``/``stress`` replaces the current calculator, and the
    raw score is set to the negative of the stored potential energy.
    """
    atoms.wrap()
    atoms.calc = SinglePointCalculator(
        atoms, energy=energy, forces=forces, stress=stress)
    set_raw_score(atoms, -atoms.get_potential_energy())

def relax(atoms, cellbounds=None):
    """Relax a GA candidate with CHGNet and score it against the reference Ag.

    The candidate is relaxed in place: up to 10 rounds of variable-cell FIRE
    (``ExpCellFilter``, at most 100 steps each, ``fmax=1e-3``), followed by one
    FIRE run on the atoms alone (``fmax=1e-2``, at most 100 steps). The relaxed
    structure is then encoded with the module-level ``model`` and compared with
    the reference encoding (``mu_t``, ``log_t``) by cosine similarity.

    The result is stored through ``finalize`` with ``energy = -cos_sim``, so
    the GA raw score equals the cosine similarity. The attached forces and
    stress are the CHGNet values of the relaxed structure.

    Args:
        atoms: ASE ``Atoms`` candidate; modified in place.
        cellbounds: optional object with ``is_within_bounds(cell)``. When the
            cell is out of bounds even after Niggli reduction, the candidate is
            finalized with an energy of 1e9 and the function returns early.

    Returns:
        None.
    """
    atoms.calc = relaxer.calculator  # assign model used to predict forces

    converged = False
    niter = 0
    # Swallow anything written to stdout during the variable-cell stage.
    with contextlib.redirect_stdout(io.StringIO()):
        while not converged and niter < 10:
            if cellbounds is not None:
                if not cellbounds.is_within_bounds(atoms.get_cell()):
                    niggli_reduce(atoms)
                if not cellbounds.is_within_bounds(atoms.get_cell()):
                    # Niggli reduction did not bring the unit cell
                    # within the specified bounds; this candidate should
                    # be discarded so we set an absurdly high energy
                    finalize(atoms, 1e9)
                    return

            ecf = ExpCellFilter(atoms)
            dyn = FIRE(ecf, maxmove=0.2, logfile=None, trajectory=None)
            dyn.run(fmax=1e-3, steps=100)

            converged = dyn.converged()
            niter += 1

    # Final pass on the atomic positions only (no cell filter).
    dyn = FIRE(atoms, maxmove=0.2, logfile=None, trajectory=None)
    dyn.run(fmax=1e-2, steps=100)

    f = atoms.get_forces()
    s = atoms.get_stress()

    # Only the similarity value is needed, so skip building an autograd graph
    # for every candidate the GA evaluates.
    with torch.no_grad():
        mu_x, log_x, _ = model.encode(get_batch(atoms))
        reference = torch.cat([mu_t, log_t], dim=1)
        candidate = torch.cat([mu_x, log_x], dim=1)
        cos_sim = float(
            F.cosine_similarity(reference[0], candidate[0], dim=0, eps=1e-8))

    finalize(atoms, energy=-cos_sim, forces=f, stress=s)

CHGNet initialized with 400,438 parameters
CHGNet will run on cpu


### Genetic algorithm

In [9]:
from ase import Atoms
from ase.data import atomic_numbers
from ase.ga.utilities import closest_distances_generator, CellBounds, get_all_atom_types
from ase.ga.startgenerator import StartGenerator
from ase.ga.data import PrepareDB, DataConnection
from ase.io import write,cif
from ase.ga import get_raw_score, set_raw_score
from ase.ga.population import Population
from ase.ga.ofp_comparator import OFPComparator
from ase.ga.offspring_creator import OperationSelector
from ase.ga.standardmutations import StrainMutation, RotationalMutation, RattleMutation, MirrorMutation
from ase.ga.soft_mutation import SoftMutation
from ase.ga.cutandsplicepairing import CutAndSplicePairing

import io

In [10]:
# Start from a clean GA database. On a fresh run there is no Ag.db yet, so only
# delete the file when it exists instead of raising FileNotFoundError.
# NOTE(review): the following cells open 'Ag.db' relative to the working
# directory; this removes the same file only when the notebook runs from
# /mnt/c/Users/Lenovo/Downloads/cdvae2.
_stale_db = Path('/mnt/c/Users/Lenovo/Downloads/cdvae2/Ag.db')
if _stale_db.exists():
    _stale_db.unlink()

In [11]:
# Number of random initial structures to generate
N = 20

# Target cell volume for the initial structures, in angstrom^3
volume = 69.

# Number of atoms per unit cell
natoms = 4
# Building blocks handed to the start generator: one 'Ag' entry per atom
blocks = ['Ag'] * natoms

# Define the composition of the atoms to optimize
Z = atomic_numbers['Ag']
# Minimum interatomic distances, generated with a covalent-radii ratio of 0.5
blmin = closest_distances_generator(atom_numbers=[Z],
                                    ratio_of_covalent_radii=0.5)

# Specify reasonable bounds on the minimal and maximal
# cell vector lengths (in angstrom) and angles (in degrees)
cellbounds = CellBounds(bounds={'phi': [35, 145], 'chi': [35, 145],
                                'psi': [35, 145], 'a': [2, 50],
                                'b': [2, 50], 'c': [2, 50]})

# Choose an (optional) 'cell splitting' scheme, which controls the level of
# translational symmetry (within the unit cell) of the randomly generated structures.
# Here a 1:1 ratio of splitting factors 2 and 1 is used:
splits = {(2,): 1, (1,): 1}
# There will hence be a 50% probability that a candidate is constructed by
# repeating a randomly generated smaller structure (presumably Ag2, i.e.
# natoms / 2) along a randomly chosen axis.
# In the other 50% of cases, no cell splitting will be applied.

# The 'slab' object: a template in the creation of new structures, which inherit the slab's atomic positions (if any),
# cell vectors (if specified), and periodic boundary conditions.
# Here only the last property is relevant:
slab = Atoms('', pbc=True)

# Initialize the random structure generator
sg = StartGenerator(slab, blocks, blmin, box_volume=volume,
                    number_of_variable_cell_vectors=3,
                    splits=splits, cellbounds=cellbounds)

# Create the database
# NOTE(review): PrepareDB presumably refuses to reuse an existing file, which
# would be why Ag.db is deleted in the previous cell - confirm.
da = PrepareDB(db_file_name='Ag.db',
               stoichiometry=[Z] * natoms)

# Generate N random structures and add them to the database
for i in range(N):
    a = sg.get_new_candidate()
    da.add_unrelaxed_candidate(a)

In [12]:
# Connect to the database and retrieve some information
da = DataConnection('Ag.db')
slab = da.get_slab()
atom_numbers_to_optimize = da.get_atom_numbers_to_optimize()
n_top = len(atom_numbers_to_optimize)

# Use Oganov's fingerprint functions to decide whether
# two structures are identical or not
# NOTE(review): candidates are finalized with energy = -cosine similarity
# (see relax), and the scores printed below lie in roughly 0.91-0.98, so
# dE=1.0 is wide compared with the spread of those values - confirm the energy
# criterion is still meant to discriminate.
comp = OFPComparator(n_top=n_top, dE=1.0, cos_dist_max=1e-3,
                     rcut=10., binwidth=0.05, pbc=[True, True, True], 
                     sigma=0.05, nsigma=4, recalculate=False)

# Define the cell and interatomic distance bounds
# that the candidates must obey
blmin = closest_distances_generator(atom_numbers_to_optimize, 0.5)

cellbounds = CellBounds(bounds={'phi': [35, 145], 'chi': [35, 145],
                                'psi': [35, 145], 'a': [2, 50],
                                'b': [2, 50], 'c': [2, 50]})

# Define a pairing operator with 100% (0%) chance that the first
# (second) parent will be randomly translated, and with each parent
# contributing to at least 15% of the child's scaled coordinates
pairing = CutAndSplicePairing(slab, n_top, blmin, p1=1., p2=0., minfrac=0.15,
                              number_of_variable_cell_vectors=3,
                              cellbounds=cellbounds, use_tags=False)

# Define a strain mutation with a typical standard deviation of 0.7
# for the strain matrix elements (drawn from a normal distribution)
strainmut = StrainMutation(blmin, stddev=0.7, cellbounds=cellbounds,
                           number_of_variable_cell_vectors=3,
                           use_tags=False)

# Set up the relative probabilities for the different operators
# (weights 4 : 3 : 3 for pairing, strain mutation and mirror mutation)
operators = OperationSelector([4., 3., 3.], 
                              [pairing, strainmut, MirrorMutation(blmin, n_top)])

# Relax the initial candidates
while da.get_number_of_unrelaxed_candidates() > 0:
    a = da.get_an_unrelaxed_candidate()

    relax(a, cellbounds=cellbounds)
    da.add_relaxed_step(a)
    # The raw score set by relax() is the cosine similarity to the reference
    # Ag encoding (or -1e9 for a candidate discarded inside relax).
    print(get_raw_score(a))

    # Candidates whose relaxed cell left the allowed bounds are stored but
    # then killed in the database.
    cell = a.get_cell()
    if not cellbounds.is_within_bounds(cell):
        da.kill_candidate(a.info['confid'])

# Initialize the population
population_size = 20
population = Population(data_connection = da,
                        population_size = population_size,
                        comparator = comp,
                        use_extinct = True)

# Update the scaling volume used in some operators
# based on a number of the best candidates
current_pop = population.get_current_population()
# Prints the population size and the raw score of its first entry.
print(len(current_pop), get_raw_score(current_pop[0]))
strainmut.update_scaling_volume(current_pop, w_adapt=0.5, n_adapt=4)
pairing.update_scaling_volume(current_pop, w_adapt=0.5, n_adapt=4)

# Test n_to_test new candidates; in this example we need
# only few GA iterations as the global minimum (FCC Ag)
# is very easily found (typically already after relaxation
# of the initial random structures).
# Note that the quantity being maximised here is the raw score set by
# relax(), i.e. the cosine similarity to the reference Ag encoding.
n_to_test = 20

for step in range(n_to_test):
    print('Now starting configuration number {0}'.format(step))
    # Create a new candidate
    # (the selected operator can return None, so retry with new parents)
    a3 = None
    while a3 is None:
        a1, a2 = population.get_two_candidates()
        a3, desc = operators.get_new_individual([a1, a2])

    a3.set_pbc(np.array([True, True, True]))
    # Save the unrelaxed candidate
    da.add_unrelaxed_candidate(a3, description=desc)

    # Relax the new candidate and save it
    relax(a3, cellbounds=cellbounds)
    da.add_relaxed_step(a3)

    # If the relaxation has changed the cell parameters
    # beyond the bounds we disregard it in the population
    cell = a3.get_cell()
    if not cellbounds.is_within_bounds(cell):
        da.kill_candidate(a3.info['confid'])

    # Update the population
    population.update()
    current_pop = population.get_current_population()
    # Operator description, then the raw scores of both parents and the child.
    print('Step %d %s %.3f %.3f %.3f' % (step, desc, get_raw_score(a1), get_raw_score(a2), get_raw_score(a3)))
    print('Step %d highest raw score in pop: %.3f' % (step, get_raw_score(current_pop[0])))

    if step % 10 == 0:
        # Update the scaling volumes of the strain mutation
        # and the pairing operator based on the current
        # best structures contained in the population
        current_pop = population.get_current_population()
        strainmut.update_scaling_volume(current_pop, w_adapt=0.5, n_adapt=4)
        pairing.update_scaling_volume(current_pop, w_adapt=0.5, n_adapt=4)
        write('current_population.traj', current_pop)
            
print('GA finished after step %d' % step)

# Take the final population first, so the reported score and the file written
# below both come from the same list.
current_pop = population.get_current_population()
hiscore = get_raw_score(current_pop[0])
# The raw score set by relax() is a cosine similarity, which is dimensionless,
# so it is reported without the "eV" unit.
print('Highest raw score = %8.4f' % hiscore)

# Persist every relaxed candidate and the final population for later inspection.
all_candidates = da.get_all_relaxed_candidates()
write('all_candidates.traj', all_candidates)
write('current_population.traj', current_pop)



0.9531611800193787
0.962939441204071
0.9514378905296326
0.9114278554916382
0.9113357663154602
0.9629392027854919
0.911343514919281
0.9531704187393188
0.9112879037857056
0.9636750817298889
0.9531543254852295
0.9629392027854919
0.9629393219947815
0.9629393219947815
0.9111520648002625
0.9227465987205505
0.9789755344390869
0.9111471772193909
0.9789769053459167
0.9113268852233887
4 0.9789769053459167
Now starting configuration number 0
Step 0 mutation: strain 0.953 0.979 0.953
Step 0 highest raw score in pop: 0.979
Now starting configuration number 1
Step 1 pairing: 20 11 0.979 0.964 0.953
Step 1 highest raw score in pop: 0.979
Now starting configuration number 2
Step 2 mutation: strain 0.979 0.964 0.979
Step 2 highest raw score in pop: 0.979
Now starting configuration number 3
Step 3 pairing: 42 11 0.953 0.964 0.948
Step 3 highest raw score in pop: 0.979
Now starting configuration number 4
Step 4 pairing: 44 11 0.953 0.964 0.953
Step 4 highest raw score in pop: 0.979
Now starting configura

In [13]:
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.analysis.structure_matcher import StructureMatcher
from pymatgen.core import Structure

def is_match(atoms, symprec=0.001):
    """Print whether ``atoms`` matches the reference Ag structure.

    The candidate is reduced to its conventional standard cell (symmetry
    tolerance ``symprec``) and compared with the module-level ``ag_structure``
    using ``StructureMatcher(scale=False, primitive_cell=False)``. Prints the
    ``fit`` result followed by ``get_rms_dist``; nothing is returned.
    """
    candidate = AseAtomsAdaptor.get_structure(atoms)
    conventional = SpacegroupAnalyzer(
        candidate, symprec=symprec).get_conventional_standard_structure()

    matcher = StructureMatcher(scale=False, primitive_cell=False)
    fits = matcher.fit(ag_structure, conventional)
    rms_dist = matcher.get_rms_dist(ag_structure, conventional)
    print(fits, rms_dist)
    
# Check the first entry of the final population (the one reported above as the
# highest raw score) against the reference structure.
is_match(current_pop[0])

True (0.0, 0.0)


In [14]:
# Look for a match with the reference structure in the current population
for x in current_pop:
    print(f'raw score : {get_raw_score(x):.5f}')
    is_match(x)

raw score : 0.97898
True (0.0, 0.0)
raw score : 0.97707
False None
raw score : 0.96685
False None
raw score : 0.96627
False None
raw score : 0.96627
False None
raw score : 0.96368
False None
raw score : 0.96153
False None
raw score : 0.96153
False None
raw score : 0.95318
False None
raw score : 0.95316
False None
raw score : 0.95283
False None
raw score : 0.94459
False None
raw score : 0.91394
False None


In [15]:
# Look for a match with the reference structure among all relaxed candidates
for x in all_candidates:
    print(f'raw score : {get_raw_score(x):.5f}')
    is_match(x)

raw score : 0.97898
True (0.0, 0.0)
raw score : 0.97898
False None
raw score : 0.97898
False None
raw score : 0.97707
False None
raw score : 0.96685
False None
raw score : 0.96627
False None
raw score : 0.96627
False None
raw score : 0.96368
False None
raw score : 0.96294
True (0.0, 0.0)
raw score : 0.96294
True (0.0, 0.0)
raw score : 0.96294
False None
raw score : 0.96294
False None
raw score : 0.96294
False None
raw score : 0.96153
False None
raw score : 0.96153
False None
raw score : 0.96153
False None
raw score : 0.95318
False None
raw score : 0.95317
False None
raw score : 0.95316
False None
raw score : 0.95316
False None
raw score : 0.95316
False None
raw score : 0.95316
False None
raw score : 0.95315
False None
raw score : 0.95283
False None
raw score : 0.95144
True (0.0, 0.0)
raw score : 0.94821
True (0.0, 0.0)
raw score : 0.94459
False None
raw score : 0.92275
False None
raw score : 0.91394
False None
raw score : 0.91393
False None
raw score : 0.91143
False None
raw score : 0.

In [18]:
# Take a matching ("True") structure from all candidates and compare the
# cosine similarity of its original, conventional, and primitive cells
# with the reference Ag encoding.

def _similarity_to_reference(structure):
    """Cosine similarity between the encodings of ``structure`` and reference Ag.

    ``structure`` is a pymatgen structure. Its encoder outputs are concatenated
    the same way as the reference outputs (``mu_t``, ``log_t``) before the
    comparison. No gradients are needed, so autograd is switched off.
    """
    with torch.no_grad():
        mu_x, log_x, _ = model.encode(
            get_batch(AseAtomsAdaptor.get_atoms(structure)))
        reference = torch.cat([mu_t, log_t], dim=1)
        candidate = torch.cat([mu_x, log_x], dim=1)
        return float(
            F.cosine_similarity(reference[0], candidate[0], dim=0, eps=1e-8))

# NOTE(review): index -15 was picked by hand from the listing printed by the
# previous cell; it will point at a different candidate after a re-run.
a = all_candidates[-15]
s = AseAtomsAdaptor.get_structure(a)
sga = SpacegroupAnalyzer(s, symprec=0.001)
cs = sga.get_conventional_standard_structure()
ps = sga.get_primitive_standard_structure()
s.to(filename='original.vasp', fmt="poscar")
cs.to(filename='conventional.vasp', fmt="poscar")
ps.to(filename='primitive.vasp', fmt="poscar")

# Does the conventional cell of the selected structure match the ground-truth structure?
sm = StructureMatcher(scale=False, primitive_cell=False)
print(sm.fit(ag_structure, cs), sm.get_rms_dist(ag_structure, cs))

# Cosine similarity with the ground-truth structure for the original,
# conventional and primitive cells, printed in that order.
for structure in (s, cs, ps):
    print(_similarity_to_reference(structure))

True (0.0, 0.0)
0.9482123255729675
0.9992227554321289
0.9617213010787964


### Test the structure invariance of CDVAE

- Try invariance with respect to periodicity (supercell)

In [19]:
# Encode a 2x2x2 supercell of the reference and compare it with the reference
# encoding (mu_t, log_t).
supercell = ag_structure.copy() * [2, 2, 2]
supercell.to(filename='supercell.vasp', fmt="poscar")

batch = get_batch(AseAtomsAdaptor.get_atoms(supercell))
mu_x, log_x, z_x = model.encode(batch)
input1, input2 = (
    torch.cat([mu_t, log_t], dim=1),
    torch.cat([mu_x, log_x], dim=1),
)
similarity = F.cosine_similarity(input1[0], input2[0], dim=0, eps=1e-8)
print(float(similarity))

0.8800023198127747


- Try invariance with respect to translation.

In [25]:
# Shift every site of a copy of the reference by the same vector, then compare
# its encoding with the reference encoding (mu_t, log_t).
supercell = ag_structure.copy()
all_sites = list(range(len(supercell)))
supercell.translate_sites(all_sites, [0.33333, 0.33333, 0.75])

batch = get_batch(AseAtomsAdaptor.get_atoms(supercell))
mu_x, log_x, z_x = model.encode(batch)
input1, input2 = (
    torch.cat([mu_t, log_t], dim=1),
    torch.cat([mu_x, log_x], dim=1),
)
similarity = F.cosine_similarity(input1[0], input2[0], dim=0, eps=1e-8)
print(float(similarity))

1.0


- Try invariance with respect to rotation

In [26]:
from pymatgen.core.operations import SymmOp

# Rotation by 45 about the x axis through the origin (the angle is passed
# without a radians flag, so the library default unit applies).
op = SymmOp.from_origin_axis_angle(
    origin = [0, 0, 0],
    axis = [1, 0, 0],
    angle = 45
)
supercell = ag_structure.copy()
# apply_operation modifies the structure in place. Do not rebind `supercell`
# to its return value: some pymatgen releases return None here, which would
# break the conversion to ASE atoms below.
supercell.apply_operation(op)

# Compare the encoding of the rotated copy with the reference encoding.
batch = get_batch(AseAtomsAdaptor.get_atoms(supercell))
mu_x, log_x, z_x = model.encode(batch)
input1 = torch.cat([mu_t,log_t],dim=1)
input2 = torch.cat((mu_x,log_x),dim=1)
print(float(F.cosine_similarity(input1[0], input2[0], dim=0, eps=1e-8)))

1.0
