In [None]:
# Add the parent directory to the module search path so that packages located
# one level up can be imported by the cells below.
# NOTE(review): presumably this is for a local checkout of `cace` — confirm.
import sys
sys.path.append('../')

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

In [None]:
import os

In [None]:
import ase
from ase import Atoms
from ase.optimize import FIRE
from ase.constraints import ExpCellFilter
from ase.io import read,write

In [None]:
from pymatgen.core.structure import Structure
from pymatgen.core.lattice import Lattice
from pymatgen.analysis.structure_matcher import StructureMatcher

In [None]:
# Function to convert ASE to Pymatgen Structure
def ase_to_pymatgen(ase_atoms):
    """Build a pymatgen ``Structure`` from an ASE ``Atoms`` object.

    The cell, the chemical symbols and the Cartesian positions are read from
    ``ase_atoms`` and passed to ``Structure`` with
    ``coords_are_cartesian=True``.

    Parameters
    ----------
    ase_atoms
        Object exposing ``cell``, ``get_chemical_symbols()`` and
        ``get_positions()``.

    Returns
    -------
    Structure
        The newly constructed pymatgen structure.
    """
    return Structure(
        Lattice(ase_atoms.cell),
        ase_atoms.get_chemical_symbols(),
        ase_atoms.get_positions(),
        coords_are_cartesian=True,
    )

In [None]:
# Shared StructureMatcher, used by the final cell to compare every relaxed
# candidate with the reference structures.
# Tolerances set here: ltol=0.3, stol=0.5, angle_tol=10.
# NOTE(review): these look looser than pymatgen's documented defaults
# (ltol=0.2, stol=0.3, angle_tol=5) — confirm against the installed version.
matcher = StructureMatcher(stol=0.5, angle_tol=10, ltol=0.3)

In [None]:
from copy import deepcopy

In [None]:
from matplotlib import cm

In [None]:
import cace
from cace.representations.cace_representation import Cace

In [None]:
from cace.tools import to_numpy

In [None]:
# Notebook-level settings.
# NOTE(review): neither name is referenced by any other cell visible in this
# notebook; units of `cutoff` are not stated here — presumably Angstrom, confirm.
cutoff = 4.0
batch_size = 10

In [None]:
# Device handle for the CPU; used below as `map_location` when loading the
# model and as the target of `cace_nnp.to(...)`.
device = cace.tools.init_device('cpu')

In [None]:
# Load the object stored in 'best_model.pth' and move it to `device`.
# SECURITY: torch.load unpickles the file, which can execute arbitrary code —
# only load checkpoints from a trusted source.
# NOTE(review): newer PyTorch releases default to weights_only=True, which may
# refuse a fully pickled model object — confirm against the installed version.
cace_nnp = torch.load('best_model.pth', map_location=device)
cace_nnp.to(device)
# NumPy copy of the sender node-embedding weights; not used by any other cell
# visible in this notebook.
ew = to_numpy(cace_nnp.representation.node_embedding_sender.embedding_weights)

In [None]:
# Count the scalar entries of every parameter that has requires_grad set.
trainable_params = 0
for param in cace_nnp.parameters():
    if param.requires_grad:
        trainable_params += param.numel()
print(f"Number of trainable parameters: {trainable_params}")

In [None]:
from cace.calculators import CACECalculator

In [None]:
# ASE calculator wrapping the in-memory model loaded above. The original cell
# also carried a commented-out checkpoint name as an alternative `model_path`:
# 'water-model-n-3-r-6-12-c4.5-mp-1.pth'.
calculator = CACECalculator(
    model_path=cace_nnp,
    device='cpu',
    compute_stress=True,
    energy_key='CACE_energy',
    forces_key='CACE_forces',
    stress_key='CACE_stress',
)

In [None]:
def compute_average_Vs(
    atom_list: "list[Atoms]", zs
):
    """
    Fit an average volume per atom for each chemical element.

    Solves the linear least-squares problem ``A @ V0 = B`` where ``A[i, j]``
    is the number of atoms with atomic number ``zs[j]`` in structure ``i`` and
    ``B[i]`` is the cell volume of structure ``i``.

    Parameters
    ----------
    atom_list
        Sequence of structures; each must provide ``get_volume()`` and
        ``get_atomic_numbers()``.
    zs
        Iterable of atomic numbers to fit a volume for.

    Returns
    -------
    dict
        Maps every entry of ``zs`` to its fitted volume. If the least-squares
        solver raises ``numpy.linalg.LinAlgError`` a warning is logged and
        every entry is ``0.0``.
    """
    # Imported locally: no cell of this notebook imports `logging`, so the
    # fallback branch used to fail with NameError instead of warning.
    import logging

    zs = list(zs)  # allow any iterable, not only sequences with len()
    n_structures = len(atom_list)

    A = np.zeros((n_structures, len(zs)))
    B = np.zeros(n_structures)
    for i, atoms in enumerate(atom_list):
        B[i] = atoms.get_volume()
        # Fetch once per structure rather than once per (structure, element).
        numbers = np.asarray(atoms.get_atomic_numbers())
        for j, z in enumerate(zs):
            A[i, j] = np.count_nonzero(numbers == z)

    try:
        V0s = np.linalg.lstsq(A, B, rcond=None)[0]
    except np.linalg.LinAlgError:
        logging.warning(
            "Failed to compute average atomic volumes using least squares regression, using 0.0 for all elements"
        )
        return {z: 0.0 for z in zs}
    return {z: V0s[i] for i, z in enumerate(zs)}

In [None]:
import os
import pickle

# Per-element average volumes: fit them from the training structures on the
# first run and cache the result in 'avgV0.pkl'; later runs read the cache.
# NOTE: pickle.load can execute arbitrary code — only use a cache file that
# this notebook wrote itself.
if not os.path.exists('avgV0.pkl'):
    test_xyz = ase.io.read('../more-datasets/mp_20/mp20-train.xyz', ':')
    zs = list(range(1, 95))
    atomic_number_to_volume = compute_average_Vs(test_xyz, zs)
    # Store the fitted dictionary for the next run.
    with open('avgV0.pkl', 'wb') as f:
        pickle.dump(atomic_number_to_volume, f)
else:
    with open('avgV0.pkl', 'rb') as f:
        atomic_number_to_volume = pickle.load(f)

In [None]:
# Chemical symbol -> atomic number for elements 1 (H) through 118 (Og).
# The table below lists the symbols in order of atomic number, one period per
# line; numbering starts at 1.
element_to_atomic_number = {
    symbol: number
    for number, symbol in enumerate(
        """
        H He
        Li Be B C N O F Ne
        Na Mg Al Si P S Cl Ar
        K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr
        Rb Sr Y Zr Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe
        Cs Ba La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn
        Fr Ra Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr Rf Db Sg Bh Hs Mt Ds Rg Cn Nh Fl Mc Lv Ts Og
        """.split(),
        start=1,
    )
}

In [None]:
# Co2Sb2 Sr2O4 AlAg4 YMg3 Cr4Si4
#Sn4Pd4 Ag6O2 Co4B2 Ba2Cd6 Bi2F8
#KZnF3 Cr3CuO8 Bi4S4Cl4 Si2(CN2)4 Hg2S2O8

# https://proceedings.neurips.cc/paper_files/paper/2023/file/38b787fc530d0b31825827e2cc306656-Paper-Conference.pdf

In [None]:
# Target compositions as {element symbol: atom count}. The key order of each
# entry matters: the generation cell builds the structure name by iterating
# over the keys in this order.
all_test_compositions = [
    # two-element compositions
    dict(Co=2, Sb=2),
    dict(Sr=2, O=4),
    dict(Al=1, Ag=4),
    dict(Y=1, Mg=3),
    dict(Cr=4, Si=4),
    # two-element compositions (continued)
    dict(Sn=4, Pd=4),
    dict(Ag=6, O=2),
    dict(Co=4, B=2),
    dict(Ba=2, Cd=6),
    dict(Bi=2, F=8),
    # three-element compositions
    dict(K=1, Zn=1, F=3),
    dict(Cr=3, Cu=1, O=8),
    dict(Bi=4, S=4, Cl=4),
    dict(Si=2, C=4, N=8),
    dict(Hg=2, S=2, O=8),
]

In [None]:
# Maps a generated structure name (element symbols followed by their counts,
# e.g. 'Co2Sb2') to the reduced formula that is compared with the
# `pretty_formula` column of the reference table. Names that are already
# reduced (e.g. 'AlAg4', 'KZnF3') need no entry: the lookup falls back to the
# name itself.
name_2_formula = {
    'Sr2O4': 'SrO2',
    'Co2Sb2': 'CoSb',
    'Cr4Si4': 'CrSi',
    'Sn4Pd4': 'SnPd',
    'Ag6O2': 'Ag3O',
    'Co4B2': 'Co2B',
    'Ba2Cd6': 'BaCd3',
    'Bi2F8': 'BiF4',
    # The ternary names below also carry a common factor; without an entry the
    # lookup used the unreduced name and found no reference structures.
    # NOTE(review): spellings follow pymatgen's reduced-formula convention —
    # verify against the `pretty_formula` values in the CSV.
    'Bi4S4Cl4': 'BiSCl',
    'Si2C4N8': 'Si(CN2)2',
    'Hg2S2O8': 'HgSO4',
}

In [None]:
# Random structure search. For every target composition:
#   1. estimate the cell volume from the per-element average volumes,
#   2. draw random atomic positions in a cubic cell and randomly rescale the
#      cell at fixed volume,
#   3. relax with the CACE calculator (fixed volume first, then variable cell),
#   4. append every structure that converges to '<name>.xyz'.
# NOTE: np.random is not seeded here, so each run draws different candidates.
# NOTE: write(..., append=True) adds to any '<name>.xyz' left over from an
# earlier run; delete those files first if a clean set is wanted.
min_distance = 1.0
names_list = []

# Upper bound on rejection-sampling draws per candidate, so that an impossible
# packing raises instead of hanging the kernel.
max_placement_attempts = 100000

for test_composition in all_test_compositions[:]:
    # e.g. {'Co': 2, 'Sb': 2} -> 'Co2Sb2'; a count of 1 is left out of the name.
    name = ''.join(
        ele + (str(count) if count > 1 else '')
        for ele, count in test_composition.items()
    )
    names_list.append(name)

    for nforms in [1]:

        cell_volume = 0.0
        n_atoms = 0
        symbols = []
        for ele, count in test_composition.items():
            ele_num = element_to_atomic_number[ele]
            n_atoms += count
            cell_volume += atomic_number_to_volume[ele_num] * count
            symbols += [ele] * count

        # Replicate one formula unit `nforms` times.
        symbols *= nforms
        n_atoms *= nforms
        cell_volume *= nforms
        print(symbols)

        # A non-positive (or NaN) sampling box cannot be filled; fail early
        # with a clear message rather than looping forever below.
        if not cell_volume > min_distance**3:
            raise ValueError(
                f"Estimated cell volume {cell_volume} for {name} is too small "
                f"for a minimum distance of {min_distance}"
            )

        # Atoms are sampled in a box that is `min_distance` shorter than the
        # cell edge, so periodic images are also at least `min_distance` apart.
        box_size = cell_volume**(1./3.) - min_distance

        for rr in range(30):
            positions = []
            attempts = 0

            # Rejection sampling: keep a point only if it is at least
            # `min_distance` away from every point accepted so far.
            while len(positions) < n_atoms:
                attempts += 1
                if attempts > max_placement_attempts:
                    raise RuntimeError(
                        f"Could not place {n_atoms} atoms at least {min_distance} "
                        f"apart in a box of side {box_size} for {name}"
                    )
                new_pos = np.random.rand(3) * np.array([ box_size, box_size, box_size ])
                if all(np.linalg.norm(new_pos - p) >= min_distance for p in positions):
                    positions.append(new_pos)

            # Create ASE Atoms object in a cubic, periodic cell
            atoms = Atoms(symbols=symbols,
                        positions=positions,
                        cell=[box_size + min_distance, box_size + min_distance, box_size + min_distance],
                        pbc=True)

            # Randomly perturb the cell, then rescale it back to the original
            # volume. The perturbation is multiplied element-wise by the old
            # cell, so entries that are zero in the old cell stay zero.
            old_cell = atoms.get_cell()
            old_v = atoms.get_volume()
            new_cell = old_cell + (np.random.rand(3,3) - 0.5) * old_cell * 0.5
            new_v = np.linalg.det(new_cell)
            new_cell *= (old_v/new_v)**(1./3.)

            # Keep the fractional coordinates while the cell changes. The
            # explicit restore is kept so the final positions are computed
            # exactly as before.
            scaled_positions = atoms.get_scaled_positions()
            atoms.set_cell(new_cell, scale_atoms=True)
            atoms.set_scaled_positions(scaled_positions)

            # `Atoms.set_calculator` is deprecated in ASE; assign the attribute.
            atoms.calc = calculator

            # Stage 1: short relaxation at constant volume. The return value of
            # opt.run(...) is used as the "converged" flag.
            atoms_c = ExpCellFilter(atoms, constant_volume=True)
            opt = FIRE(atoms_c, logfile=None)
            run = opt.run(fmax=0.05, steps=100)  # Adjust fmax for convergence criteria
            if not run:
                continue

            # Stage 2: full relaxation with the volume free to change.
            atoms_c = ExpCellFilter(atoms, constant_volume=False)
            opt = FIRE(atoms_c, logfile=None)
            run = opt.run(fmax=0.01, steps=2000)  # Adjust fmax for convergence criteria

            if run:
                print(rr)
                write(name+'.xyz', atoms, append=True)

In [None]:
import pandas as pd
from pymatgen.io.cif import CifParser
import io

# Load the reference table. The cells below read its `pretty_formula` column
# (to select rows) and its `cif` column (to build the reference structures).
# NOTE(review): this path is under '../mp_20/' while the training file above is
# read from '../more-datasets/mp_20/' — confirm both locations are intended.
csv_file = '../mp_20/test.csv'
df = pd.read_csv(csv_file)

In [None]:
# For every generated name, gather the reference structures whose
# `pretty_formula` equals the name's reduced formula.
collect_struct = {}
for name in names_list:
    collect_struct[name] = references = []
    # Use the registered reduced formula, or the name itself when none exists.
    formula_now = name_2_formula.get(name, name)

    matching_rows = df[df["pretty_formula"] == formula_now]
    for cif_data in matching_rows.cif:
        # The CIF text is parsed from an in-memory buffer; only the first
        # structure returned by the parser is kept.
        parser = CifParser(io.StringIO(cif_data))
        references.append(parser.get_structures()[0])

In [None]:
# Compare every relaxed candidate with every reference structure of the same
# composition. A name is appended to `match_list` once per matching
# (reference, candidate) pair, and the candidate index is printed.
match_list = []
for name in names_list:
    references = collect_struct[name]
    if not references:
        # Nothing to compare against; as before, the .xyz file is not opened.
        continue

    xyz_path = name + '.xyz'
    if not os.path.exists(xyz_path):
        # No candidate of this composition converged, so no file was written.
        # Skip it instead of aborting the whole comparison.
        print(f"{name}: no relaxed structures found in {xyz_path}, skipping")
        continue

    # Read and convert the candidates once per composition rather than once
    # per reference structure.
    local_structures = [ase_to_pymatgen(at) for at in read(xyz_path, ':')]

    for mp_structure in references:
        for i, local_structure in enumerate(local_structures):
            is_same = matcher.fit(local_structure, mp_structure)
            if is_same:
                print(name, i)
                match_list.append(name)