In [21]:
import argparse
from pathlib import Path
import pandas as pd
import torch
import json
import random
import numpy as np
from typing import List, Tuple
import tiktoken
from dataclasses import dataclass
import os
import time
from pymatgen.core.composition import Composition
from pymatgen.core.structure import Structure


In [22]:
def contains_elements(comp, target_elements):
    """Check if composition contains all target elements (Ag, O) regardless of ratio."""
    return all(el in comp.elements for el in target_elements)

def matches_composition(comp, target_elements, target_ratio):
    """Check if composition contains all target elements (Ag, O) regardless of ratio."""
    if not all(el in comp.elements for el in target_elements):
        return False
    reduced_comp, _ = comp.get_reduced_composition_and_factor()
    return all(abs(reduced_comp[el] - amt) <= 1e-6 for el, amt in target_ratio.items())


In [None]:
seed_structures_df = pd.read_csv('resources/band_gap_processed.csv')
seed_structures_df['structure'] = seed_structures_df['structure'].apply(lambda x: Structure.from_str(x, fmt='json') if pd.notna(x) else None)
seed_structures_df['composition'] = [s.composition for s in seed_structures_df['structure']]
seed_structures_df['composition_str'] = [s.composition.formula for s in seed_structures_df['structure']]

In [None]:
comps = ['Ag6O2', 'Bi2F8', 'Co2Sb2', 'Co4B2', 'Cr4Si4', 'KZnF3', "Sr2O4", "YMg3"]
for comp in comps:
    target_comp = Composition(comp)
    target_elements = set(target_comp.elements)
    comp_df = seed_structures_df[seed_structures_df['composition'].apply(lambda comp: contains_elements(comp, target_elements))]
    print(comp, len(comp_df))

In [None]:
import importlib
import utils
importlib.reload(utils.stability)
from utils.stability import StabilityCalculator
# from utils.e_hull_calculator import EHullCalculator
# ehull_model = EHullCalculator('resources/2023-02-07-ppd-mp.pkl.gz')
stability_calculator_chgnet = StabilityCalculator(mlip='chgnet', ppd_path='resources/2023-02-07-ppd-mp.pkl.gz')
stability_calculator_orbv3 = StabilityCalculator(mlip='orb-v3', ppd_path='resources/2023-02-07-ppd-mp.pkl.gz')
stability_calculator_sevenet = StabilityCalculator(mlip='sevenet', ppd_path='resources/2023-02-07-ppd-mp.pkl.gz')

Initialize EHullCalcul with patched phase diagram.


In [12]:
save_label = "poscar_70b_csp_KZnF3"
generations_df = pd.read_csv(f'../results/{save_label}/generations.csv')
generations_df.head()


Unnamed: 0,Iteration,Structure,ParentStructures,Objective,Composition,DeltaE,EHullDistance,BulkModulus,StructureRelaxed,BulkModulusRelaxed
0,1,"{""@class"": ""Structure"", ""@module"": ""pymatgen.c...","[""{\""@class\"": \""Structure\"", \""@module\"": \""p...",inf,F3 K1 Zn1,-1.874561,0.795208,0.0,"{""@class"": ""Structure"", ""@module"": ""pymatgen.c...",0.0
1,1,"{""@class"": ""Structure"", ""@module"": ""pymatgen.c...","[""{\""@class\"": \""Structure\"", \""@module\"": \""p...",inf,F3 K1 Zn1,-0.917131,0.373527,0.0,"{""@class"": ""Structure"", ""@module"": ""pymatgen.c...",0.0
2,1,"{""@class"": ""Structure"", ""@module"": ""pymatgen.c...","[""{\""@class\"": \""Structure\"", \""@module\"": \""p...",inf,F3 K1 Zn1,0.0,1.717362,0.0,"{""@class"": ""Structure"", ""@module"": ""pymatgen.c...",0.0
3,1,"{""@class"": ""Structure"", ""@module"": ""pymatgen.c...","[""{\""@class\"": \""Structure\"", \""@module\"": \""p...",inf,F3 Zn1 K1,-1.814619,0.661457,0.0,"{""@class"": ""Structure"", ""@module"": ""pymatgen.c...",0.0
4,1,"{""@class"": ""Structure"", ""@module"": ""pymatgen.c...","[""{\""@class\"": \""Structure\"", \""@module\"": \""p...",inf,K1 Zn1 F3,-2.028812,0.581382,0.0,"{""@class"": ""Structure"", ""@module"": ""pymatgen.c...",0.0


In [None]:
structures = generations_df['Structure'].apply(lambda x: Structure.from_str(x, fmt='json') if pd.notna(x) else None)
stability_results = structures.apply(lambda x: stability_calculator_orbv3.compute_stability([x])[0])
generations_df['EHullDistance_orbv3'] = stability_results.apply(lambda x: x.e_hull_distance if x is not None else None)


generations_df['StructureRelaxed'] = stability_results.apply(lambda x: json.dumps(x.structure_relaxed.as_dict(), sort_keys=True) if x is not None else None)
generations_df['EHullDistance'] = stability_results.apply(lambda x: x.e_hull_distance if x is not None else None)
generations_df['DeltaE'] = stability_results.apply(lambda x: x.delta_e if x is not None else None)

In [None]:
structures = generations_df['Structure'].apply(lambda x: Structure.from_str(x, fmt='json') if pd.notna(x) else None)
stability_results = structures.apply(lambda x: stability_calculator_chgnet.compute_stability([x])[0])
generations_df['EHullDistance_chgnet'] = stability_results.apply(lambda x: x.e_hull_distance if x is not None else None)


In [None]:
structure = structures[0]
relaxation = stability_calculator_chgnet.relax_structure(structure)
print(relaxation['trajectory'])

In [15]:
structures = generations_df['Structure'].apply(lambda x: Structure.from_str(x, fmt='json') if pd.notna(x) else None)
stability_results = structures.apply(lambda x: stability_calculator_sevenet.compute_stability([x])[0])
generations_df['EHullDistance_sevenet'] = stability_results.apply(lambda x: x.e_hull_distance if x is not None else None)


      Step     Time          Energy          fmax
BFGS:    0 10:37:11       -9.790998       13.944713
BFGS:    1 10:37:12      -15.393703        5.386726
BFGS:    2 10:37:16      -16.615000        3.509117
BFGS:    3 10:37:16      -17.996862        3.571410
BFGS:    4 10:37:17      -18.333149        3.029196
BFGS:    5 10:37:17      -18.531142        1.719498
BFGS:    6 10:37:18      -18.628388        1.056752
BFGS:    7 10:37:18      -18.742896        1.148373
BFGS:    8 10:37:20      -18.786866        1.263015
BFGS:    9 10:37:20      -18.860932        1.118379
BFGS:   10 10:37:20      -18.930048        0.658703
BFGS:   11 10:37:20      -18.957711        0.134170
BFGS:   12 10:37:21      -18.958362        0.029463
Relaxation error: 'StabilityCalculator' object has no attribute 'adaptor'
      Step     Time          Energy          fmax
BFGS:    0 10:37:21      -16.484656        1.935385
BFGS:    1 10:37:21      -16.589857        1.843249
BFGS:    2 10:37:21      -16.891786        1.8

In [16]:
generations_df.to_csv(f'../results/{save_label}/generations_compare.csv', index=False)

In [17]:
stats_df = pd.DataFrame({
    'mean': [generations_df['EHullDistance_sevenet'].mean(),
             generations_df['EHullDistance_chgnet'].mean(),
             generations_df['EHullDistance_orbv3'].mean()],
    'min': [generations_df['EHullDistance_sevenet'].min(),
            generations_df['EHullDistance_chgnet'].min(),
            generations_df['EHullDistance_orbv3'].min()],
    'max': [generations_df['EHullDistance_sevenet'].max(),
            generations_df['EHullDistance_chgnet'].max(),
            generations_df['EHullDistance_orbv3'].max()]
})
stats_df.index = ['sevenet', 'chgnet', 'orbv3']

# Print statistics
print("Statistics for EHull Distance values:")
print(stats_df)

# Show a few examples
print("\nExample values (first 5 rows):")
examples = generations_df[['EHullDistance_sevenet', 
                          'EHullDistance_chgnet', 
                          'EHullDistance_orbv3']].head()
print(examples)

Statistics for EHull Distance values:
             mean       min       max
sevenet       NaN       NaN       NaN
chgnet        NaN       NaN       NaN
orbv3    0.757425  0.373461  2.780483

Example values (first 5 rows):
  EHullDistance_sevenet EHullDistance_chgnet  EHullDistance_orbv3
0                  None                 None             0.795278
1                  None                 None             0.373461
2                  None                 None             1.717378
3                  None                 None             0.661424
4                  None                 None             0.547062


In [60]:
generations_df = pd.read_csv(f'../results/{save_label}/generations.csv')
generations_df['EHullDistance']-generations_df_['EHullDistance']
# generations_df['DeltaE']-generations_df_['DeltaE']

0     0.633691
1     0.277096
2     1.561936
3     0.530927
4     0.364662
        ...   
70    0.335799
71    0.416198
72    0.661585
73    0.633039
74    0.363765
Name: EHullDistance, Length: 75, dtype: float64

In [59]:
structure = structures[6]
energy = stability_calculator_chgnet.compute_energy_per_atom(structure)
print('energy chgnet', energy)
energy_ = stability_calculator_orbv3.compute_energy_per_atom(structure)
print('energy orbv3', energy_)
relaxation = stability_calculator_chgnet.relax_structure(structure)
print('relaxation', relaxation)
relaxation_ = stability_calculator_orbv3.relax_structure(structure)
print('relaxation', relaxation_)
structure_relaxed = relaxation['final_structure']
structure_relaxed_ = relaxation_['final_structure']
energy_relaxed = stability_calculator_chgnet.compute_energy_per_atom(structure_relaxed)
print('energy_relaxed', energy_relaxed)
energy_relaxed_ = stability_calculator_orbv3.compute_energy_per_atom(structure_relaxed_)
print('energy_relaxed', energy_relaxed_)

e_hull_distance = stability_calculator_chgnet.compute_ehull_dist(structure_relaxed, energy_relaxed) 
e_hull_distance_ = stability_calculator_orbv3.compute_ehull_dist(structure_relaxed_, energy_relaxed_) 
print('e_hull_distance', e_hull_distance, e_hull_distance_)

energy chgnet -2.0863840579986572
energy orbv3 -1.7150611877441406
      Step     Time          Energy          fmax
FIRE:    0 23:51:37      -10.429727       24.679861
FIRE:    1 23:51:37      -14.791532        4.302103
FIRE:    2 23:51:37      -14.962103        5.026131


  atoms = ExpCellFilter(atoms)


FIRE:    3 23:51:37      -15.100472        4.801128
FIRE:    4 23:51:37      -15.380871        4.861904
FIRE:    5 23:51:37      -15.813762        4.482694
FIRE:    6 23:51:37      -16.275381        3.753484
FIRE:    7 23:51:37      -16.670322        3.221899
FIRE:    8 23:51:37      -16.964349        4.696006
FIRE:    9 23:51:37      -17.337703        4.891908
FIRE:   10 23:51:37      -17.876228        5.005354
FIRE:   11 23:51:37      -18.457884        4.042971
FIRE:   12 23:51:37      -18.943471        2.843278
FIRE:   13 23:51:38      -19.383278        2.758730
FIRE:   14 23:51:38      -19.789118        2.557762
FIRE:   15 23:51:38      -20.109429        2.808764
FIRE:   16 23:51:38      -20.433445        2.635607
FIRE:   17 23:51:38      -20.661502        3.393600
FIRE:   18 23:51:38      -20.743437        2.083626
FIRE:   19 23:51:38      -20.802956        1.515449
FIRE:   20 23:51:38      -20.811725        1.386216
FIRE:   21 23:51:38      -20.825481        1.190453
FIRE:   22 2

  warn("Using UFloat objects with std_dev==0 may give unexpected results.")
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1622.55it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4429.04it/s]

e_hull_distance 0.11256942460009434 0.41828365036913695





In [10]:
stability_results = seed_structures_df[:3]['structure'].apply(lambda x: stability_calculator_chgnet.compute_stability([x])[0])


Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at /pytorch/aten/src/ATen/native/Cross.cpp:62.)
  volumes.append(torch.dot(lattice[0], torch.cross(lattice[1], lattice[2])))
  atoms = ExpCellFilter(atoms)


      Step     Time          Energy          fmax
FIRE:    0 22:28:01     -109.424622        0.758916
FIRE:    1 22:28:01     -109.432205        0.717667
FIRE:    2 22:28:01     -109.416626        2.688207
FIRE:    3 22:28:01     -109.429115        1.080351
FIRE:    4 22:28:01     -109.434593        0.791901
FIRE:    5 22:28:01     -109.431854        0.460174
FIRE:    6 22:28:01     -109.433754        0.456066
FIRE:    7 22:28:01     -109.437897        0.389377
FIRE:    8 22:28:01     -109.436249        0.354471
FIRE:    9 22:28:01     -109.434708        0.334662
FIRE:   10 22:28:01     -109.435349        0.294138
FIRE:   11 22:28:02     -109.438332        0.160065
FIRE:   12 22:28:02     -109.431656        0.200378
FIRE:   13 22:28:02     -109.432053        0.174881
FIRE:   14 22:28:02     -109.433456        0.214812
FIRE:   15 22:28:02     -109.434082        0.235061
FIRE:   16 22:28:02     -109.434731        0.236065
FIRE:   17 22:28:02     -109.438965        0.201615
FIRE:   18 22:

  warn("Using UFloat objects with std_dev==0 may give unexpected results.")
  0%|                                                                                                                                                                                                            | 0/1 [00:00<?, ?it/s]

E-hull computation error: Unable to get decomposition for None ComputedStructureEntry - K4 Mn4 O8    (KMnO2)
Energy (Uncorrected)     = -109.4375 eV (-6.8398  eV/atom)
Correction               = 0.0000    eV (0.0000   eV/atom)
Energy (Final)           = -109.4375 eV (-6.8398  eV/atom)
Energy Adjustments:
  None
Parameters:
Data:
      Step     Time          Energy          fmax
FIRE:    0 22:28:03      -33.942574        0.540111
FIRE:    1 22:28:03      -33.939793        0.549961





FIRE:    2 22:28:03      -33.940987        0.298507
FIRE:    3 22:28:03      -33.941803        0.097068


  0%|                                                                                                                                                                                                            | 0/1 [00:00<?, ?it/s]

E-hull computation error: Unable to get decomposition for None ComputedStructureEntry - Cr3 Ni1      (Cr3Ni)
Energy (Uncorrected)     = -33.9413  eV (-8.4853  eV/atom)
Correction               = 0.0000    eV (0.0000   eV/atom)
Energy (Final)           = -33.9413  eV (-8.4853  eV/atom)
Energy Adjustments:
  None
Parameters:
Data:
      Step     Time          Energy          fmax
FIRE:    0 22:28:03       -7.270300        0.518828
FIRE:    1 22:28:03       -7.277307        0.383347
FIRE:    2 22:28:03       -7.280029        0.115481
FIRE:    3 22:28:03       -7.280678        0.238777
FIRE:    4 22:28:03       -7.281534        0.208204





FIRE:    5 22:28:03       -7.281169        0.174603
FIRE:    6 22:28:03       -7.281131        0.119441
FIRE:    7 22:28:03       -7.282486        0.069151


  0%|                                                                                                                                                                                                            | 0/1 [00:00<?, ?it/s]

E-hull computation error: Unable to get decomposition for None ComputedStructureEntry - Cs1 Rb1 As1  (CsRbAs)
Energy (Uncorrected)     = -7.2821   eV (-2.4274  eV/atom)
Correction               = 0.0000    eV (0.0000   eV/atom)
Energy (Final)           = -7.2821   eV (-2.4274  eV/atom)
Energy Adjustments:
  None
Parameters:
Data:





In [62]:
import os
import pandas as pd
from pymatgen.io.cif import CifWriter

output_dir = "visualization"
os.makedirs(output_dir, exist_ok=True)

# Remove None values and sort by EHullDistance
valid_structures = generations_df.dropna(subset=['EHullDistance'])
sorted_structures = valid_structures.sort_values(by='EHullDistance')

# Define how many of the best structures to save
num_structures = 20  # Adjust this number as needed

# Save top structures to CIF files
for i, (index, row) in enumerate(sorted_structures.head(num_structures).iterrows()):
    structure = Structure.from_str(row['StructureRelaxed'], fmt='json')
    formula = structure.composition.reduced_formula
    e_hull = row['EHullDistance']
    
    # Create a descriptive filename
    filename = f"{i+1:03d}_{formula}_ehull_{e_hull:.4f}.cif"
    filepath = os.path.join("visualization", filename)
    
    # Write the structure to a CIF file
    cif_writer = CifWriter(structure)
    cif_writer.write_file(filepath)
    
    print(f"Saved structure {i+1}: {formula} with E-hull = {e_hull:.4f} eV/atom to {filepath}")


Saved structure 1: KZnF3 with E-hull = 0.3735 eV/atom to visualization/001_KZnF3_ehull_0.3735.cif
Saved structure 2: KZnF3 with E-hull = 0.4183 eV/atom to visualization/002_KZnF3_ehull_0.4183.cif
Saved structure 3: KZnF3 with E-hull = 0.4994 eV/atom to visualization/003_KZnF3_ehull_0.4994.cif
Saved structure 4: KZnF3 with E-hull = 0.5066 eV/atom to visualization/004_KZnF3_ehull_0.5066.cif
Saved structure 5: KZnF3 with E-hull = 0.5172 eV/atom to visualization/005_KZnF3_ehull_0.5172.cif
Saved structure 6: KZnF3 with E-hull = 0.5242 eV/atom to visualization/006_KZnF3_ehull_0.5242.cif
Saved structure 7: KZnF3 with E-hull = 0.5267 eV/atom to visualization/007_KZnF3_ehull_0.5267.cif
Saved structure 8: KZnF3 with E-hull = 0.5814 eV/atom to visualization/008_KZnF3_ehull_0.5814.cif
Saved structure 9: KZnF3 with E-hull = 0.5849 eV/atom to visualization/009_KZnF3_ehull_0.5849.cif
Saved structure 10: KZnF3 with E-hull = 0.6147 eV/atom to visualization/010_KZnF3_ehull_0.6147.cif
Saved structure 11:

In [61]:
def generation_dedup(df) :
    unique_structures = []
    keep_indices = []
    reduced_formulas = []

    for idx, struct_str in enumerate(df['StructureRelaxed']):
        try:
            structure = Structure.from_dict(json.loads(struct_str))
            reduced_formula = structure.composition.reduced_formula
            df.at[idx, 'Composition'] = structure.composition
            df.at[idx, 'composition_str'] = structure.composition.formula
            is_duplicate = any(
                structure.matches(unique_struct, scale=True, attempt_supercell=False)
                for unique_struct, formula in zip(unique_structures, reduced_formulas)
                if formula == reduced_formula
            )
            if not is_duplicate:
                unique_structures.append(structure)
                reduced_formulas.append(reduced_formula)
                keep_indices.append(idx)
            else:
                print(idx, df.at[idx, 'EHullDistance'])
                for j, (unique_struct, formula) in enumerate(zip(unique_structures, reduced_formulas)):
                    if structure.matches(unique_struct, scale=True, attempt_supercell=False):
                        print(idx,j,keep_indices[j])
        except Exception as e:
            print(f"Error processing structure at index {idx}: {e}")
            continue
    return df.iloc[keep_indices].reset_index(drop=True)


def valid_value(x):
    return (x is not None and not np.isinf(x) and not np.isnan(x) and x != 0)


def valid_mean(values):
    valid_values = [x for x in values if valid_value(x)]
    return np.mean(valid_values) if valid_values else 0.0
# generations_df = pd.read_csv(f'../results/{save_label}/generations_chgnet.csv')

generations_df = generation_dedup(generations_df)

7 0.5601926774687467
7 6 6
8 0.4524274797148404
8 1 1
9 0.4070173234648404
9 1 1
9 6 6
10 0.4080041856474574
10 1 1
10 6 6
12 0.5316343278593716
12 3 3
14 0.4271181077665984
14 6 6
26 0.374398228616208
26 1 1
27 0.377571675271481
27 1 1
31 0.3738821000761687
31 1 1
33 0.4999204606718717
33 1 1
35 0.5943899125761685
35 23 32
37 0.5535833329863249
37 1 1
37 6 6
37 20 28
38 0.539777943582028
38 1 1
39 1.0131635636992151
39 19 25
40 0.5671838731474574
40 15 21
41 0.5419568986601528
41 3 3
41 15 21
42 0.5998018235869109
42 4 4
43 0.6164617509550747
43 8 13
44 0.5299604386992156
44 1 1
48 0.4041578263945276
48 1 1
49 0.4234739274687467
49 1 1
51 0.6442007035917934
51 27 46
52 0.7821348161406219
52 1 1
52 6 6
53 1.2003122300810514
53 1 1
53 6 6
56 0.6532911271757778
56 15 21
57 0.6953706712431607
57 1 1
57 6 6
57 20 28
58 0.4178335160917932
58 1 1
58 6 6
58 20 28
59 0.4042860002226529
59 1 1
60 0.5272088975615201
60 23 32
61 0.3819986314482388
61 1 1
63 0.4042516679472623
63 1 1
64 0.46080379

In [55]:
s1 = Structure.from_str(generations_df.at[10, 'StructureRelaxed'], fmt='json')
s2 = Structure.from_str(generations_df.at[6, 'StructureRelaxed'], fmt='json')
s_mp.matches(s2)

True

In [58]:
s_mp = Structure.from_file("visualization/KZnF3.cif")
s_mp.matches(s2)

False