In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

In [None]:
import cace
from cace.representations.cace_representation import Cace

In [None]:
import numpy as np
import ase
from ase import Atoms
from ase.optimize import FIRE
from ase.constraints import ExpCellFilter

from ase.visualize import view
from ase.md import Langevin
from ase import units
import numpy as np
import time
from ase.io import read,write


In [None]:
from cace.calculators import CACECalculator

In [None]:
# Load the trained CACE potential, which was saved as a whole pickled nn.Module
# (its .parameters() are used below).
# - map_location='cpu': lets a checkpoint that was saved from a GPU load on a
#   CPU-only machine; the calculator is later created with device='cpu' anyway.
# - weights_only=False: PyTorch >= 2.6 defaults to weights_only=True, which
#   refuses to unpickle arbitrary nn.Module objects. The kwarg exists since
#   PyTorch 1.13.
# NOTE(review): this is a pickle load and can execute arbitrary code. Only load
# checkpoint files you trust.
cace_nnp = torch.load('best_model.pth', map_location='cpu', weights_only=False)

In [None]:
# Total number of scalar weights in the model that are marked requires_grad,
# i.e. the ones an optimizer would update.
trainable_params = sum(
    tensor.numel()
    for tensor in filter(lambda t: t.requires_grad, cace_nnp.parameters())
)
print(f"Number of trainable parameters: {trainable_params}")

In [None]:
# ASE calculator wrapping the loaded CACE model, evaluated on the CPU.
# The *_key arguments presumably name the model-output entries that hold the
# energy / forces / stress (TODO confirm against cace.calculators).
# compute_stress=True is needed because the relaxation loop optimizes the cell
# shape through ExpCellFilter, which requires stresses.
# NOTE(review): an already-loaded nn.Module is passed as `model_path`; this
# assumes CACECalculator accepts either a path or a module here. Confirm.
calculator = CACECalculator(
    model_path=cace_nnp,
    device='cpu',
    energy_key='CACE_energy',
    forces_key='CACE_forces',
    stress_key='CACE_stress',
    compute_stress=True,
)

In [None]:
# Scan a range of volumes per atom. For each volume and atom count:
#   1. build several random periodic carbon cells,
#   2. relax atomic positions and cell shape at constant volume with the CACE
#      potential (FIRE + ExpCellFilter),
#   3. append every converged structure to an extended-XYZ file, tagged with
#      its energy and volume per atom.
#
# NOTE(review): write(..., append=True) keeps appending, so re-running this
# cell duplicates frames in OUTPUT_FILE. Delete the file first for a clean run.
# NOTE(review): newer ASE versions deprecate ase.constraints.ExpCellFilter in
# favour of ase.filters (e.g. FrechetCellFilter).

min_distance = 0.7                              # min. separation when placing atoms
VOLUMES_PER_ATOM = np.linspace(3.8, 5.6, 100)   # target volume, A^3/atom
ATOM_COUNTS = [2, 4, 6, 8, 12]
N_REPEATS = 8                 # random structures per (volume, atom count)
CELL_DISTORTION = 0.5         # perturbation amplitude, in units of box_size
MAX_PLACEMENT_TRIES = 100_000 # bound on rejection sampling per structure
SEED = 0
OUTPUT_FILE = 'relaxation-final.xyz'

rng = np.random.default_rng(SEED)  # seeded so the generated set is reproducible


def _random_positions(n_atoms, box_size, min_dist, rng, max_tries=MAX_PLACEMENT_TRIES):
    """Rejection-sample `n_atoms` points in the cube [0, box_size - min_dist)^3.

    Each accepted point is at least `min_dist` from every previously accepted
    point. Shrinking the sampling region by `min_dist` also keeps
    minimum-image distances >= `min_dist` across the periodic boundary of the
    undistorted cubic box.

    Raises:
        RuntimeError: if the points cannot be placed within `max_tries`
            draws. The original unbounded loop would hang instead.
    """
    positions = []
    for _ in range(max_tries):
        if len(positions) == n_atoms:
            break
        candidate = rng.random(3) * (box_size - min_dist)
        if all(np.linalg.norm(candidate - p) >= min_dist for p in positions):
            positions.append(candidate)
    if len(positions) < n_atoms:
        raise RuntimeError(
            f"could only place {len(positions)}/{n_atoms} atoms with "
            f"min distance {min_dist} in box {box_size:.3f} after {max_tries} tries"
        )
    return np.array(positions)


def _random_cell(box_size, rng, distortion=CELL_DISTORTION):
    """Return (cubic_cell, distorted_cell); the distorted cell keeps the cube's volume.

    Every element of the cubic cell matrix is perturbed by up to
    +-distortion/2 * box_size, then the result is rescaled isotropically to
    volume box_size**3.
    """
    cubic = np.eye(3) * box_size
    distorted = cubic + (rng.random((3, 3)) - 0.5) * box_size * distortion
    det = np.linalg.det(distorted)
    # For distortion < 2/3 the matrix is strictly diagonally dominant with a
    # positive diagonal, so det > 0. The check guards larger settings.
    if det <= 0:
        raise ValueError(f"distorted cell is degenerate/inverted (det={det:.3g})")
    distorted *= (box_size**3 / det) ** (1.0 / 3.0)
    return cubic, distorted


def _relax_fixed_volume(atoms, calc, fmax=0.0005, steps=500, maxstep=0.01):
    """Relax positions and cell shape of `atoms` in place at constant volume.

    Returns the truthiness of FIRE.run's result, which is the convergence flag
    in recent ASE versions.
    """
    atoms.calc = calc  # replaces the deprecated atoms.set_calculator(calc)
    cell_filter = ExpCellFilter(atoms, constant_volume=True)
    opt = FIRE(cell_filter, logfile=None, maxstep=maxstep)
    return bool(opt.run(fmax=fmax, steps=steps))


for v_atom in VOLUMES_PER_ATOM:
    # Label says "molar_V" but the value is volume per atom (A^3/atom).
    print("molar_V: ", v_atom)

    for num_atoms in ATOM_COUNTS:
        box_size = (num_atoms * v_atom) ** (1.0 / 3.0)  # cube edge for this volume

        for rr in range(N_REPEATS):
            print("No. atoms: ", num_atoms, " R: ", rr)

            positions = _random_positions(num_atoms, box_size, min_distance, rng)
            old_cell, new_cell = _random_cell(box_size, rng)

            atoms = Atoms(["C"] * num_atoms, positions=positions, cell=old_cell, pbc=True)
            # scale_atoms=True keeps the fractional coordinates fixed while the
            # cell changes, so no manual scaled-position round-trip is needed.
            atoms.set_cell(new_cell, scale_atoms=True)

            if _relax_fixed_volume(atoms, calculator):
                atoms.info['energy_per_atom'] = atoms.get_potential_energy() / len(atoms)
                atoms.info['volume_per_atom'] = atoms.get_volume() / len(atoms)
                write(OUTPUT_FILE, atoms, append=True)