In [1]:
from m3gnet.models import M3GNet, Potential, M3GNetCalculator
from m3gnet.trainers import PotentialTrainer
from pymatgen.core import Structure, Lattice
from m3gnet.models import Relaxer
import tensorflow as tf

m3gnet = M3GNet.from_dir('callbacks/f_t-2022-05-19_19-43-46/00031-0.161686-0.064913-0.088447-0.083264')
potential = Potential(model=m3gnet)

relaxer = Relaxer(potential='callbacks/f_t-2022-05-19_19-43-46/00031-0.161686-0.064913-0.088447-0.083264')
# this loads the default model



In [4]:
import json, bz2
data = json.load(bz2.open('data/spg20_geo_opt.json.bz2',"rb"))
data.keys()

dict_keys(['./Cu/CuI/xxx_02p-00_20-2-Cu2a.I2b', './Cu/CuI/xxx_02p-00_20-2-Cu2b.I2a', './Cu/CuI/xxx_02p-00_20-2-Cu2b.I2c', './Cu/CuI/xxx_02p-00_20-2-Cu2c.I2c'])

In [5]:
import numpy as np
from pymatgen.core.structure import Structure
def get_data(data,cutoff=0.):
    structures = []
    energy = []
    stress = []
    forces = []
    counter = 0
    for el in data.values():
        for el2 in el:
            try:
                if el2['energy']<cutoff:
                    st = Structure.from_dict(el2['structure'])
                    potential.graph_converter.convert(st)
                    structures.append(st)
                    energy.append(el2['energy'])
                    stress.append(-np.array(el2['stress'])*0.1)  #kbar to GPa
                    forces.append(el2['forces'])
            except:
                counter+=1
                print(counter,"isolated sytem")
    return energy,forces,stress,structures

In [6]:
e, f,stress,struct = get_data(data)
struct

[Structure Summary
 Lattice
     abc : 4.39381809 4.39025482 18.73851009
  angles : 90.0 90.0 90.0
  volume : 361.4655045005905
       A : 4.39381809 0.0 0.0
       B : 0.0 4.39025482 0.0
       C : 0.0 0.0 18.73851009
 PeriodicSite: I (0.0000, 2.1951, 11.2385) [0.0000, 0.5000, 0.5998]
 PeriodicSite: I (2.1969, 2.1951, 7.5000) [0.5000, 0.5000, 0.4002]
 PeriodicSite: Cu (0.0000, 0.0000, 10.0912) [0.0000, 0.0000, 0.5385]
 PeriodicSite: Cu (2.1969, 0.0000, 8.6473) [0.5000, 0.0000, 0.4615],
 Structure Summary
 Lattice
     abc : 4.39185111 4.39039371 18.73851009
  angles : 90.0 90.0 90.0
  volume : 361.3151174780716
       A : 4.39185111 0.0 0.0
       B : 0.0 4.39039371 0.0
       C : 0.0 0.0 18.73851009
 PeriodicSite: I (0.0000, 2.1952, 11.2382) [0.0000, 0.5000, 0.5997]
 PeriodicSite: I (2.1959, 2.1952, 7.5003) [0.5000, 0.5000, 0.4003]
 PeriodicSite: Cu (0.0000, 0.0000, 10.0910) [0.0000, 0.0000, 0.5385]
 PeriodicSite: Cu (2.1959, 0.0000, 8.6475) [0.5000, 0.0000, 0.4615],
 Structure Summa

In [14]:
relax_results = relaxer.relax(struct[3])

final_structure = relax_results['final_structure']
final_energy = relax_results['trajectory'].energies[-1] / 2

print(f"Relaxed lattice parameter is {final_structure.lattice.abc[0]: .3f} Å")
print(f"Final energy is {final_energy.item(): .3f} eV/atom")



Tensor("mul_2:0", shape=(1, 3, 3), dtype=float32)
      Step     Time          Energy         fmax
*Force-consistent energies used in optimization.
FIRE:    0 01:10:02      -10.089573*       7.8406
FIRE:    1 01:10:02      -10.354762*       3.2166
Tensor("mul_2:0", shape=(1, 3, 3), dtype=float32)
FIRE:    2 01:10:05      -10.416962*       2.0620
FIRE:    3 01:10:05      -10.429105*       1.8233
FIRE:    4 01:10:05      -10.449032*       1.3664
FIRE:    5 01:10:06      -10.469250*       0.7178
FIRE:    6 01:10:06      -10.481853*       0.3820
FIRE:    7 01:10:06      -10.484583*       0.8165
FIRE:    8 01:10:06      -10.485381*       0.7851
FIRE:    9 01:10:06      -10.486904*       0.7240
FIRE:   10 01:10:06      -10.489009*       0.6361
FIRE:   11 01:10:06      -10.491523*       0.5259
FIRE:   12 01:10:06      -10.494272*       0.4041
FIRE:   13 01:10:06      -10.497110*       0.3499
FIRE:   14 01:10:06      -10.499952*       0.3644
FIRE:   15 01:10:06      -10.503082*       0.3793
FI



Tensor("mul_2:0", shape=(1, 3, 3), dtype=float32)
FIRE:   33 01:10:10      -10.827917*       0.3923
FIRE:   34 01:10:10      -10.901319*       0.9438
FIRE:   35 01:10:10      -10.965939*       1.9333
FIRE:   36 01:10:10      -10.979331*       0.2620
FIRE:   37 01:10:10      -10.973838*       1.7320
FIRE:   38 01:10:10      -10.977514*       1.3411
FIRE:   39 01:10:10      -10.981916*       0.6627
FIRE:   40 01:10:10      -10.983829*       0.2464
FIRE:   41 01:10:10      -10.983902*       0.2458
FIRE:   42 01:10:10      -10.984048*       0.2446
FIRE:   43 01:10:10      -10.984258*       0.2429
FIRE:   44 01:10:10      -10.984529*       0.2408
FIRE:   45 01:10:10      -10.984859*       0.2384
FIRE:   46 01:10:10      -10.985246*       0.2358
FIRE:   47 01:10:10      -10.985682*       0.2331
FIRE:   48 01:10:10      -10.986227*       0.2303
FIRE:   49 01:10:10      -10.986909*       0.2274
FIRE:   50 01:10:10      -10.987755*       0.2247
FIRE:   51 01:10:10      -10.988805*       0.2223
