In [None]:
import numpy as np
import matplotlib.pyplot as plt
from ase import Atoms, units
from ase.io import Trajectory
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from ase.md.nptberendsen import NPTBerendsen
from ase.md.nose_hoover_chain import NoseHooverChainNVT
from ase.ga.utilities import get_rdf
from ase.geometry.analysis import Analysis
from ase.data import atomic_masses, atomic_numbers
from scipy.spatial.distance import cdist
from numpy.polynomial import Polynomial
import os



class MoltenSaltSimulator:
    """
    Class for building molten salt systems and running molecular dynamics simulations.

    Typical workflow: ``set_calculator`` -> ``create_folder`` -> ``build_system`` ->
    ``run_npt_simulation`` / ``run_nvt_simulation`` -> analysis/plotting methods.

    The time-dependent analysis methods (density vs. time, diffusion, viscosity)
    convert frame indices to time with ``frame_interval_fs``. Its default (10 fs)
    matches the defaults of the ``run_*`` methods: a 1 fs timestep with one
    trajectory frame written every 10 steps.
    """

    # Unit conversions and constants used by the analysis methods.
    AMU_TO_G = 1.66054e-24        # g per atomic mass unit
    EV_TO_J = 1.60218e-19         # J per eV
    EV_PER_A3_TO_BAR = 1.60218e6  # bar per eV/Å³
    EV_PER_A3_TO_PA = 1.60218e11  # Pa per eV/Å³
    K_B = 1.380649e-23            # Boltzmann constant (J/K)

    # FAIRChem UMA checkpoint for each model_parameters["model_size"].
    FAIRCHEM_CHECKPOINTS = {"small": "uma-s-1", "medium": "uma-m-1p1"}

    # Default locations of the local MACE model files, keyed by
    # model_parameters["model_type"]; override with model_parameters["model_path"].
    MACE_MODEL_PATHS = {
        "mace-mh-1": "/home/amd18/mace_models/mace-mh-1.model",
        "SuperSalt-swa": "/home/amd18/mace_models/SuperSalt-swa.model",
        "SuperSalt": "/home/amd18/mace_models/SuperSalt.model",
    }

    # GRACEModels attribute name for each (model_size, layer) combination.
    GRACE_MODEL_NAMES = {
        ("small", 1): "GRACE_1L_OMAT",
        ("medium", 1): "GRACE_1L_OMAT_medium_base",
        ("large", 1): "GRACE_1L_OMAT_large_base",
        ("small", 2): "GRACE_2L_OMAT",
        ("medium", 2): "GRACE_2L_OMAT_medium_base",
        ("large", 2): "GRACE_2L_OMAT_large_base",
    }

    def __init__(self, model_name="uma-s-1", device="cuda"):
        """
        Initialize the simulator.

        Parameters:
        -----------
        model_name : str
            Stored for reference only; the potential is chosen by ``set_calculator``.
        device : str
            Device handed to the potentials that accept one (e.g. "cuda", "cpu").
        """
        self.model_name = model_name
        self.device = device
        self.calc = None  # set by set_calculator()
        self.results = {}

    def set_calculator(self, model_name, model_parameters):
        """
        Set up the calculator based on the chosen ML potential.

        Parameters:
        -----------
        model_name : str
            Potential family, case-insensitive: "fairchem", "mace" or "grace".
        model_parameters : dict
            Family-specific options:
            - fairchem: "model_size" ("small" | "medium"), "model_task"
            - mace: "model_type" ("mace-mh-1" | "SuperSalt-swa" | "SuperSalt"),
              optional "model_path" overriding the default model file
            - grace: "model_size" ("small" | "medium" | "large"), "layer" (1 | 2)

        Returns:
        --------
        calc : ASE calculator
            The calculator, also stored as ``self.calc``.

        Raises:
        -------
        ValueError
            If the family or the requested variant is not supported.
        """
        family = model_name.lower()
        params = model_parameters or {}

        if family == "fairchem":
            from fairchem.core import pretrained_mlip, FAIRChemCalculator

            checkpoint = self.FAIRCHEM_CHECKPOINTS.get(params.get("model_size"))
            if checkpoint is None:
                raise ValueError("This calculator type is not included in this package")
            predictor = pretrained_mlip.get_predict_unit(checkpoint, device=self.device)
            self.calc = FAIRChemCalculator(predictor, task_name=params.get("model_task"))

        elif family == "mace":
            from mace.calculators import mace_mp
            from mace.calculators import MACECalculator

            model_type = params.get("model_type")
            if model_type not in self.MACE_MODEL_PATHS:
                raise ValueError("This calculator type is not included in this package")
            model_path = params.get("model_path", self.MACE_MODEL_PATHS[model_type])
            if model_type == "mace-mh-1":
                self.calc = mace_mp(model=model_path, default_dtype="float64",
                                    device=self.device, head="omat_pbe")
            elif model_type == "SuperSalt-swa":
                # MACECalculator receives the model file through ``model_paths``.
                self.calc = MACECalculator(model_paths=[model_path], device=self.device)
            else:  # "SuperSalt"
                self.calc = MACECalculator(model_paths=[model_path], device=self.device,
                                           enable_cueq=False)

        elif family == "grace":
            from tensorpotential.calculator.foundation_models import grace_fm, GRACEModels

            key = (params.get("model_size"), params.get("layer"))
            if key not in self.GRACE_MODEL_NAMES:
                raise ValueError("This calculator type is not included in this package")
            # Look the enum member up lazily so only the requested model must exist.
            self.calc = grace_fm(getattr(GRACEModels, self.GRACE_MODEL_NAMES[key]))

        else:
            raise ValueError("This calculator is not included in this package")

        return self.calc

    def create_folder(self, folder_name):
        """
        Create ``<cwd>/<folder_name>/NPT`` and ``<cwd>/<folder_name>/NVT``.

        Returns:
        --------
        (npt_dir, nvt_dir) : tuple of str
            Absolute paths of the two sub-folders.
        """
        run_dir = os.path.join(os.getcwd(), folder_name)
        npt_dir = os.path.join(run_dir, "NPT")
        nvt_dir = os.path.join(run_dir, "NVT")
        for directory in (run_dir, npt_dir, nvt_dir):
            os.makedirs(directory, exist_ok=True)

        print(f"Saving all trajectories in: {run_dir}")
        return npt_dir, nvt_dir

    @staticmethod
    def _random_positions(n_atoms, box_length, min_distance, max_attempts=100000):
        """
        Draw uniform random positions in a cubic periodic box of edge ``box_length`` (Å)
        such that every pair is farther apart than ``min_distance`` (Å) under the
        minimum-image convention.

        Raises:
        -------
        RuntimeError
            If an atom cannot be placed within ``max_attempts`` trials, which means
            the box is too dense for the requested minimum distance.
        """
        positions = np.zeros((n_atoms, 3))
        min_distance_sq = min_distance ** 2
        for i in range(n_atoms):
            for _ in range(max_attempts):
                trial = np.random.rand(3) * box_length
                delta = positions[:i] - trial
                # Periodic box: compare with the nearest image of every placed atom.
                delta -= box_length * np.round(delta / box_length)
                if np.all(np.einsum("ij,ij->i", delta, delta) > min_distance_sq):
                    positions[i] = trial
                    break
            else:
                raise RuntimeError(
                    f"Could not place atom {i + 1} of {n_atoms} after {max_attempts} "
                    "attempts; lower the density guess or min_distance."
                )
        return positions

    def build_system(self, salt_anion, salt_cation, anion_Natoms, cation_Natoms,
                     density_guess, min_distance=1.6, max_attempts=100000):
        """
        Build a molten salt system with random initial positions.

        Parameters:
        -----------
        salt_anion : list of str
            Chemical symbols for anions
        salt_cation : list of str
            Chemical symbols for cations
        anion_Natoms : list of int
            Number of atoms for each anion type
        cation_Natoms : list of int
            Number of atoms for each cation type
        density_guess : float
            Initial density guess (g/cm³), used to size the cubic box
        min_distance : float
            Minimum interatomic distance (Å, periodic images included)
        max_attempts : int
            Maximum random trials per atom before giving up

        Returns:
        --------
        atoms : ASE Atoms object
            The constructed periodic system with ``self.calc`` attached

        Raises:
        -------
        ValueError
            If the symbol and count lists have different lengths.
        RuntimeError
            If no calculator is set or the atoms cannot be placed.
        """
        if len(salt_anion) != len(anion_Natoms) or len(salt_cation) != len(cation_Natoms):
            raise ValueError("The number of salts and their number of atoms should be the same")
        # Fail fast, before the potentially slow random placement.
        if getattr(self, "calc", None) is None:
            raise RuntimeError("Calculator not set. Use set_calculator() first.")

        # Anions first, then cations, each repeated by its count.
        symbols = []
        for element, amount_of_atoms in zip(salt_anion, anion_Natoms):
            symbols += [element] * amount_of_atoms
        for element, amount_of_atoms in zip(salt_cation, cation_Natoms):
            symbols += [element] * amount_of_atoms

        # Cubic box edge from the density guess.
        mass = sum(atomic_masses[atomic_numbers[sym]] for sym in symbols) * self.AMU_TO_G  # g
        volume_guess = mass / density_guess  # cm³
        initial_box_size = (volume_guess * 1e24) ** (1 / 3)  # Å

        positions_atoms = self._random_positions(
            len(symbols), initial_box_size, min_distance, max_attempts
        )

        atoms = Atoms(
            symbols=symbols,
            positions=positions_atoms,
            cell=[initial_box_size] * 3,
            pbc=True,
        )
        atoms.calc = self.calc
        return atoms

    def _make_status_printer(self, dyn, atoms):
        """Return a callback printing the step, pressure (bar) and volume (Å³)."""
        def print_status_func():
            step = dyn.get_number_of_steps()
            stress_tensor = atoms.get_stress(voigt=False) * self.EV_PER_A3_TO_BAR
            pressure = -np.trace(stress_tensor) / 3
            print(f"Step {step:6d} | P = {pressure:.6e} bar | V = {atoms.get_volume():8.2f} Å³")
        return print_status_func

    def run_npt_simulation(self, atoms, T, steps=1000, printInterval=100,
                           traj_file="npt_simulation.traj", print_status=True,
                           timestep_fs=1.0, traj_interval=10, logfile="npt_equili.log"):
        """
        Run NPT (constant pressure, temperature) molecular dynamics with a Berendsen
        thermostat/barostat at 1.01325 bar. Velocities are re-drawn from a
        Maxwell-Boltzmann distribution at ``T`` before the run.

        Parameters:
        -----------
        atoms : ASE Atoms object
            System to simulate (modified in place)
        T : float
            Temperature (K)
        steps : int
            Number of MD steps
        printInterval : int
            Interval (steps) for printing status
        traj_file : str
            Output trajectory file
        print_status : bool
            Whether to print simulation status
        timestep_fs : float
            MD timestep (fs)
        traj_interval : int
            A trajectory frame is written every ``traj_interval`` steps
        logfile : str
            ASE MD log file

        Returns:
        --------
        None
        """
        MaxwellBoltzmannDistribution(atoms, temperature_K=T)

        dyn = NPTBerendsen(
            atoms,
            timestep=timestep_fs * units.fs,
            temperature_K=T,
            taut=100 * units.fs,
            pressure_au=1.01325 * units.bar,
            taup=1000 * units.fs,
            compressibility_au=4.0e-5 / units.bar,
            logfile=logfile,
        )

        trajectory_npt = Trajectory(traj_file, "w", atoms)
        try:
            dyn.attach(trajectory_npt.write, interval=traj_interval)
            if print_status:
                dyn.attach(self._make_status_printer(dyn, atoms), interval=printInterval)
            dyn.run(steps)
        finally:
            trajectory_npt.close()
        print(f"NPT trajectory saved to {traj_file}")

    def run_nvt_simulation(self, atoms, T, steps=1000, printInterval=100,
                           traj_file="nvt_simulation.traj", print_status=True,
                           timestep_fs=1.0, traj_interval=10, logfile="nvt_run.log"):
        """
        Run NVT (constant volume, temperature) molecular dynamics with a
        Nose-Hoover chain thermostat. Velocities are re-drawn from a
        Maxwell-Boltzmann distribution at ``T`` before the run.

        Parameters:
        -----------
        atoms : ASE Atoms object
            System to simulate (modified in place)
        T : float
            Temperature (K)
        steps : int
            Number of MD steps
        printInterval : int
            Interval (steps) for printing status
        traj_file : str
            Output trajectory file
        print_status : bool
            Whether to print simulation status
        timestep_fs : float
            MD timestep (fs)
        traj_interval : int
            A trajectory frame is written every ``traj_interval`` steps
        logfile : str
            ASE MD log file

        Returns:
        --------
        None
        """
        MaxwellBoltzmannDistribution(atoms, temperature_K=T)

        dyn = NoseHooverChainNVT(
            atoms,
            timestep=timestep_fs * units.fs,
            temperature_K=T,
            tdamp=100 * units.fs,
            logfile=logfile,
        )

        trajectory_nvt = Trajectory(traj_file, "w", atoms)
        try:
            dyn.attach(trajectory_nvt.write, interval=traj_interval)
            if print_status:
                dyn.attach(self._make_status_printer(dyn, atoms), interval=printInterval)
            dyn.run(steps)
        finally:
            trajectory_nvt.close()
        print(f"NVT trajectory saved to {traj_file}")

    @staticmethod
    def _equilibrium_tail(values, fraction=0.1):
        """
        Return the last ``fraction`` of ``values`` as an array, always keeping at
        least one element (``x[-int(0.1 * n):]`` would return everything when n < 10).

        Raises:
        -------
        ValueError
            If ``values`` is empty.
        """
        values = np.asarray(values)
        if len(values) == 0:
            raise ValueError("Cannot take the equilibrium part of an empty series")
        n_tail = max(1, int(len(values) * fraction))
        return values[-n_tail:]

    def compute_density(self, traj_file, fraction=0.1):
        """
        Compute density from trajectory file.

        Parameters:
        -----------
        traj_file : str
            Path to trajectory file
        fraction : float
            Trailing fraction of frames averaged as the equilibrium volume

        Returns:
        --------
        density : float
            Density in g/cm³
        """
        with Trajectory(traj_file) as traj:
            volumes = [atoms.get_volume() for atoms in traj]
            mass_g = traj[0].get_masses().sum() * self.AMU_TO_G

        equilibrium_volume = np.mean(self._equilibrium_tail(volumes, fraction))  # Å³
        return mass_g / (equilibrium_volume * 1e-24)  # g/cm³

    def plot_density_vs_time(self, traj_file, title="Density vs Time", frame_interval_fs=10.0):
        """
        Plot density evolution during simulation.

        Parameters:
        -----------
        traj_file : str
            Path to trajectory file
        title : str
            Plot title
        frame_interval_fs : float
            Simulated time between consecutive frames (fs)

        Returns:
        --------
        densities: array
            Densities in g/cm³, one per frame
        """
        with Trajectory(traj_file) as traj:
            volumes = np.array([atoms.get_volume() for atoms in traj])
            mass_g = traj[0].get_masses().sum() * self.AMU_TO_G

        densities = mass_g / (volumes * 1e-24)
        eq_density = np.mean(self._equilibrium_tail(densities))
        times_ps = np.arange(len(densities)) * frame_interval_fs / 1000.0

        plt.figure()
        plt.plot(times_ps, densities, color='midnightblue', label='Density')
        plt.axhline(y=eq_density, color='darkorange', linestyle='--', linewidth=3, label='Equilibrium density')
        plt.title(title)
        plt.xlabel("Time (ps)")
        plt.ylabel("Density (g/cm³)")
        plt.grid(True, alpha=0.3)
        plt.legend(loc='upper right')
        plt.tight_layout()
        plt.show()

        return densities

    def compute_thermal_expansion(self, npt_dir, salt_name, temperatures, fraction=0.1):
        """
        Compute the linear thermal expansion coefficient from the equilibrium box
        lengths of NPT runs at several temperatures. The coefficient is the slope of
        a linear fit of a/a0 vs. T, with a0 the box length at the lowest temperature found.

        Parameters:
        -----------
        npt_dir : str
            Directory with NPT trajectories named ``npt_{salt_name}_{T}K.traj``
        salt_name : str
            Name of the salt
        temperatures : list
            List of temperatures (K); temperatures without a trajectory file are skipped
        fraction : float
            Trailing fraction of frames averaged as the equilibrium volume

        Returns:
        --------
        dict
            "temperatures" (temperatures actually found), "box_ratios", "fit",
            "T_fit", "fit_line", "thermal_expansion" (1/K) and "densities" (g/cm³)

        Raises:
        -------
        ValueError
            If fewer than two trajectories are found.
        """
        found_temperatures = []
        box_lengths = []
        densities = []

        for T in temperatures:
            traj_file = os.path.join(npt_dir, f"npt_{salt_name}_{T}K.traj")
            if not os.path.exists(traj_file):
                continue

            with Trajectory(traj_file) as traj:
                volumes = [atoms.get_volume() for atoms in traj]
                mass_g = traj[0].get_masses().sum() * self.AMU_TO_G
            eq_vol = np.mean(self._equilibrium_tail(volumes, fraction))

            found_temperatures.append(T)
            box_lengths.append(eq_vol ** (1 / 3))
            densities.append(mass_g / (eq_vol * 1e-24))

        if len(found_temperatures) < 2:
            raise ValueError(
                f"Need NPT trajectories for at least two temperatures of {salt_name} "
                f"in {npt_dir}; found {len(found_temperatures)}"
            )

        temperatures = np.array(found_temperatures)
        box_lengths = np.array(box_lengths)
        box_ratios = box_lengths / box_lengths[0]

        # Linear fit of a/a0 vs. T; the slope is the linear expansion coefficient.
        fit = np.polyfit(temperatures, box_ratios, 1)
        T_fit = np.linspace(min(temperatures), max(temperatures), 100)
        fit_line = np.polyval(fit, T_fit)

        return {
            "temperatures": temperatures,
            "box_ratios": box_ratios,
            "fit": fit,
            "T_fit": T_fit,
            "fit_line": fit_line,
            "thermal_expansion": fit[0],
            "densities": np.array(densities),
        }

    def compute_heat_capacity(self, traj_file, T, pressure_bar=1.01325, fraction=0.1):
        """
        Compute heat capacity from enthalpy fluctuations,
        C = var(H) / (k_B T² m), with H = E_kin + E_pot + P V.

        NOTE(review): in an NVT trajectory V is constant, so var(H) = var(E) and the
        result corresponds to the constant-volume heat capacity. Enthalpy fluctuations
        give Cp only for an NPT ensemble.

        Parameters:
        -----------
        traj_file : str
            Path to trajectory file
        T : float
            Temperature (K)
        pressure_bar : float
            External pressure used in the PV term (bar)
        fraction : float
            Trailing fraction of frames used for the variance

        Returns:
        --------
        Cp : float
            Heat capacity in J/g/K
        """
        pv_prefactor = pressure_bar * units.bar  # eV/Å³, so P*V is in eV
        with Trajectory(traj_file) as traj:
            H = np.array([
                atoms.get_kinetic_energy() + atoms.get_potential_energy()
                + pv_prefactor * atoms.get_volume()
                for atoms in traj
            ])
            mass_total = traj[0].get_masses().sum() * self.AMU_TO_G  # g

        H_equil = self._equilibrium_tail(H, fraction)
        var_H = np.var(H_equil, ddof=1) * self.EV_TO_J ** 2  # J²

        return var_H / (self.K_B * T ** 2 * mass_total)  # J/g/K

    def compute_diffusion_coefficient(self, traj_file, frame_interval_fs=10.0):
        """
        Compute the self-diffusion coefficient (averaged over all atoms) from the
        Einstein relation MSD(t) = 6 D t. The MSD is measured from the first frame,
        and the line is fitted over all frames.

        NOTE(review): assumes the stored positions are not wrapped back into the cell
        (otherwise the MSD saturates) — confirm for the MD integrators used.

        Parameters:
        -----------
        traj_file : str
            Path to trajectory file
        frame_interval_fs : float
            Simulated time between consecutive frames (fs)

        Returns:
        --------
        D : float
            Diffusion coefficient in Å²/fs

        Raises:
        -------
        ValueError
            If the trajectory has fewer than two frames.
        """
        with Trajectory(traj_file) as traj:
            positions = np.array([atoms.get_positions() for atoms in traj])

        nframes = positions.shape[0]
        if nframes < 2:
            raise ValueError("At least two frames are needed to fit the MSD")

        displacement = positions - positions[0]
        msd = np.mean(np.sum(displacement ** 2, axis=2), axis=1)  # Å²

        # Times in fs (the old code used ASE internal time units and 1 fs per frame).
        times_fs = np.arange(nframes) * frame_interval_fs
        slope, _ = np.polyfit(times_fs, msd, 1)

        return slope / 6.0  # Å²/fs

    def compute_rdf(self, traj_file, rmax=6, nbins=100, pairs=None):
        """
        Compute radial distribution functions averaged over all frames.

        NOTE(review): ASE's ``Analysis.get_rdf`` may treat a pair of element symbols
        as a mask over both species rather than a strict A–B partial RDF — confirm
        against the installed ASE version before interpreting mixed pairs.

        Parameters:
        -----------
        traj_file : str
            Path to trajectory file
        rmax : float
            Maximum distance (Å)
        nbins : int
            Number of bins
        pairs : list of tuples
            Pairs to compute RDF for (default: every unordered element pair)

        Returns:
        --------
        dict
            ``{pair: (bin_centres, averaged g(r))}``
        """
        with Trajectory(traj_file) as traj:
            symbols = traj[0].get_chemical_symbols()
            # Plain copies (symbols, positions, cell) of every frame.
            atoms_list = [
                Atoms(symbols=symbols, positions=frame.get_positions(),
                      cell=frame.get_cell(), pbc=True)
                for frame in traj
            ]

        ana = Analysis(atoms_list)

        unique_elements = sorted(set(symbols))
        if pairs is None:
            pairs = [(a, b) for i, a in enumerate(unique_elements)
                     for b in unique_elements[i:]]

        # Centres of the nbins equal-width bins on [0, rmax].
        bin_width = rmax / nbins
        distances = (np.arange(nbins) + 0.5) * bin_width

        rdf_results = {}
        for pair in pairs:
            rdfs = ana.get_rdf(rmax=rmax, nbins=nbins, elements=pair)
            if rdfs:
                rdf_results[pair] = (distances, np.mean(rdfs, axis=0))

        return rdf_results

    def plot_rdf(self, traj_file, title="Radial Distribution Function", rmax=6, nbins=100):
        """
        Plot radial distribution functions, marking the position of each main peak.

        Parameters:
        -----------
        traj_file : str
            Path to trajectory file
        title : str
            Plot title
        rmax : float
            Maximum distance (Å)
        nbins : int
            Number of bins
        """
        rdf_data = self.compute_rdf(traj_file, rmax=rmax, nbins=nbins)

        colors = ["midnightblue", "darkorange", "crimson", "green", "purple"]

        plt.figure()
        for i, (pair, (distances, avg_rdf)) in enumerate(rdf_data.items()):
            # Cycle colors so no pair is dropped when there are more than 5 pairs.
            color = colors[i % len(colors)]
            peak_r = distances[np.argmax(avg_rdf)]

            plt.plot(
                distances, avg_rdf,
                label=f"{pair[0]}-{pair[1]} (peak: {peak_r:.2f} Å)",
                linewidth=2,
                color=color,
            )
            plt.axvline(peak_r, linestyle='--', color=color, alpha=0.5)

        plt.xlabel("Distance (Å)")
        plt.ylabel("g(r)")
        plt.title(title)
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

    @staticmethod
    def _autocorrelation(x, nmax):
        """Unbiased autocorrelation of ``x`` for lags 0 .. nmax-1 (each lag divided by its sample count)."""
        n = len(x)
        corr = np.correlate(x, x, mode='full')[n - 1:n - 1 + nmax]
        return corr / np.arange(n, n - nmax, -1)

    def compute_viscosity(self, traj_file, T, tmax_fs=None, frame_interval_fs=10.0):
        """
        Compute shear viscosity with the Green-Kubo relation
        eta = V / (k_B T) * integral of <P_ab(0) P_ab(t)> dt,
        averaged over the xy, xz and yz stress components.

        NOTE(review): every frame must carry a stored stress tensor; whether the MD
        runs store one depends on the calculator — confirm before use.

        Parameters:
        -----------
        traj_file : str
            Path to trajectory file
        T : float
            Temperature (K)
        tmax_fs : float
            Maximum correlation time (fs); default uses all lags
        frame_interval_fs : float
            Simulated time between consecutive frames (fs)

        Returns:
        --------
        eta : float
            Viscosity in Pa·s
        """
        with Trajectory(traj_file) as traj:
            stress_ts = np.array(
                [atoms.get_stress(voigt=False) for atoms in traj],
                dtype=float,
            )
            volumes = np.array([atoms.get_volume() for atoms in traj])
        nframes = stress_ts.shape[0]

        # Off-diagonal components, mean-removed and converted from eV/Å³ to Pa.
        shear = np.stack([stress_ts[:, 0, 1], stress_ts[:, 0, 2], stress_ts[:, 1, 2]])
        shear = (shear - shear.mean(axis=1, keepdims=True)) * self.EV_PER_A3_TO_PA

        if tmax_fs is None:
            nmax = nframes
        else:
            nmax = max(1, min(nframes, int(np.ceil(tmax_fs / frame_interval_fs))))

        ac_mean = np.mean([self._autocorrelation(comp, nmax) for comp in shear], axis=0)

        dt = frame_interval_fs * 1e-15  # s between frames
        times = np.arange(ac_mean.size) * dt

        V = np.mean(volumes) * 1e-30  # m³
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        integral = trapezoid(ac_mean, times)

        return V * integral / (self.K_B * T)

# Running the simulation

In [None]:
sim = MoltenSaltSimulator()
sim.set_calculator(model_name="grace", model_parameters={"model_size": "medium", "layer": 2})
npt_dir, nvt_dir = sim.create_folder(folder_name="GRACE_medium_2L_0.3NaCl-0.2KCl-0.5MgCl2")

# Salts to simulate, as
#   "salt_name": ([anions], [cations], [count per anion], [count per cation])
salts = {
    # "NaCl": (["Cl"], ["Na"], [100], [100]),
    "0.3NaCl-0.2KCl-0.5MgCl2": (["Cl"], ["K", "Mg", "Na"], [150], [20, 50, 30]),
}

# Temperatures (K) at which each salt is simulated
temperatures = {
    "NaCl": [1100, 1125, 1150, 1175, 1200],
    "0.3NaCl-0.2KCl-0.5MgCl2": [700, 800, 900, 1000, 1100],
}

# Initial density guess (g/cm³) for each temperature above, in the same order
density_guesses = {
    "NaCl": [1.542, 1.528, 1.515, 1.501, 1.488],
    "0.3NaCl-0.2KCl-0.5MgCl2": [1.761, 1.719, 1.677, 1.635, 1.593],
}

# Build, equilibrate (NPT) and then sample (NVT) every salt at every temperature
for salt_name, composition in salts.items():
    anions, cations, n_anions, n_cations = composition
    print(f"Running NPT simulations for {salt_name}...")

    for T, density_guess in zip(temperatures[salt_name], density_guesses[salt_name]):
        atoms = sim.build_system(anions, cations, n_anions, n_cations, density_guess)
        traj_file_npt = os.path.join(npt_dir, f"npt_{salt_name}_{T}K.traj")
        traj_file_nvt = os.path.join(nvt_dir, f"nvt_{salt_name}_{T}K.traj")

        # The same Atoms object is passed on, so NVT starts from the positions
        # and cell left by the NPT run.
        sim.run_npt_simulation(atoms, T, steps=1000, printInterval=100,
                               traj_file=traj_file_npt, print_status=True)
        sim.run_nvt_simulation(atoms, T, steps=1000, printInterval=100,
                               traj_file=traj_file_nvt, print_status=True)

# Density vs. Time

In [None]:
sim = MoltenSaltSimulator()

# Folder where trajectories were saved, and the calculator label used in plot titles
run_folder = "GRACE_medium_2L_NaCl"  # Replace with the folder name you used
calculator_label = "GRACE medium 2L"  # Keep in sync with run_folder
npt_dir = os.path.join(os.getcwd(), run_folder, "NPT")

# Salts and temperatures (K) for which to plot the density evolution
salts = {
    "NaCl": [1100, 1125, 1150, 1175, 1200],
    # "0.3NaCl-0.2KCl-0.5MgCl2": [700, 800, 900, 1000, 1100],
}

# Plot density vs. time for every salt/temperature NPT trajectory
for salt_name, temps in salts.items():
    for T in temps:
        traj_file_npt = os.path.join(npt_dir, f"npt_{salt_name}_{T}K.traj")
        if not os.path.exists(traj_file_npt):
            print(f"Skipping {traj_file_npt}: file not found")
            continue
        sim.plot_density_vs_time(traj_file_npt, title=f"{salt_name} at {T}K using {calculator_label}")

# Density

In [None]:
sim = MoltenSaltSimulator()

# Literature densities (g/cm³) at each simulated temperature (K)
salts = {
    "NaCl": {"temps": [1100, 1125, 1150, 1175, 1200], "rho_lit": [1.542, 1.528, 1.515, 1.501, 1.488]},
}

# Folders where trajectories were saved, one per calculator
salt_folders = {
    "NaCl": ["GRACE_medium_2L_NaCl", "GRACE_medium_1L_NaCl"],
}

# Legend label for each folder above, in the same order
calculator_names = ["GRACE medium 2L", "GRACE medium 1L"]

# Colors assigned in order of decreasing mean density (cycled if there are more datasets)
colors = ['lightseagreen', 'lightgreen', 'gold', 'darksalmon', 'hotpink', 'mediumorchid', 'navy', 'cyan', 'grey']

for salt_name, data in salts.items():
    run_folders = salt_folders[salt_name]
    temps = data["temps"]
    rho_lit = data["rho_lit"]

    # Simulated equilibrium densities per run folder, as (salt, T, density)
    all_densities = {folder: [] for folder in run_folders}

    for folder in run_folders:
        npt_dir = os.path.join(os.getcwd(), folder, "NPT")

        print(f"\n=== Densities for {salt_name} in {folder} ===")
        for T, rho_ref in zip(temps, rho_lit):
            traj_file_npt = os.path.join(npt_dir, f"npt_{salt_name}_{T}K.traj")
            density = sim.compute_density(traj_file_npt)
            print("=========================")
            print("Density for T=", T, "K")
            print("Literature density:", rho_ref, "g/cm³")
            print(f"Density: {density:.4f} g/cm³")
            # Relative deviation of the simulation from the literature value
            print("difference (sim vs. literature):", (density / rho_ref - 1) * 100, "%")
            all_densities[folder].append((salt_name, T, density))

    # (mean density, label, temperatures, densities, marker) for every dataset
    all_plot_data = [(np.mean(rho_lit), 'Literature densities', temps, rho_lit, 'v')]
    for folder, calc_name in zip(run_folders, calculator_names):
        rho_sim = [d for (_, _, d) in all_densities[folder]]
        all_plot_data.append((np.mean(rho_sim), calc_name, temps, rho_sim, 'o'))

    # Highest mean density first, so colors are assigned in a stable order
    all_plot_data.sort(key=lambda entry: entry[0], reverse=True)

    fig, ax = plt.subplots()
    temps_fit = np.linspace(min(temps), max(temps), 100)
    for i, (_, label, temps_data, rho_data, marker) in enumerate(all_plot_data):
        color = colors[i % len(colors)]
        ax.scatter(temps_data, rho_data, marker=marker, color=color, label=label, s=80)
        rho_fit = Polynomial.fit(temps_data, rho_data, deg=1)
        ax.plot(temps_fit, rho_fit(temps_fit), linestyle='--', color=color, alpha=0.5)

    ax.set_xlabel("Temperature (K)")
    ax.set_ylabel("Density (g/cm³)")
    ax.set_title(f"Density of {salt_name}")
    ax.legend(fontsize=9.5)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

# Thermal expansion


In [None]:
sim = MoltenSaltSimulator()

salts = {
    "NaCl": {"temps": [1100, 1125, 1150, 1175, 1200]},
}

# Full paths to the NPT directories, one per calculator
salt_folders = {
    "NaCl": [
        os.path.join(os.getcwd(), "GRACE_medium_2L_NaCl", "NPT"),
    ]
}

# Legend label for each NPT directory above, in the same order
calculator_names = ["GRACE medium 2L"]

# Literature linear thermal expansion coefficient (1/K)
literature_value = 3.96e-5

markers = ['o', 's', 'D', '^', 'v', 'P', '.', '+', '8', '*']
# The last color is reserved for the literature line
colors = ['lightseagreen', 'lightgreen', 'gold', 'darksalmon', 'hotpink', 'mediumorchid', 'navy', 'cyan']

for salt_name, data in salts.items():
    temps = data["temps"]
    run_folders = salt_folders[salt_name]

    results = {}

    print(f"\n=== Thermal Expansion for {salt_name} ===")
    for folder, calc_name in zip(run_folders, calculator_names):
        result = sim.compute_thermal_expansion(npt_dir=folder, salt_name=salt_name, temperatures=temps)
        results[folder] = result
        print(f"Thermal expansion calculated using {calc_name}:  β = {result['thermal_expansion']:.6e} 1/K")

    plt.figure()

    for i, ((folder, result_data), label) in enumerate(zip(results.items(), calculator_names)):
        color = colors[i % (len(colors) - 1)]
        plt.scatter(result_data["temperatures"], result_data["box_ratios"],
                    marker=markers[i % len(markers)],
                    label=f"{label} (β={result_data['thermal_expansion']:.2e})",
                    color=color)
        plt.plot(result_data["T_fit"], result_data["fit_line"],
                 linestyle='--', alpha=0.7, color=color)

    # Literature line a/a0 = 1 + β (T - T0), anchored at the first requested temperature
    lit_temps = np.asarray(temps, dtype=float)
    literature_fit = 1 + literature_value * (lit_temps - lit_temps[0])
    plt.plot(lit_temps, literature_fit,
             linestyle='--', alpha=0.7, color=colors[-1],
             label=f'Literature (β={literature_value})',
             marker=markers[-1])

    plt.xlabel("Temperature (K)")
    plt.ylabel("Thermal Expansion (a/a₀)")
    plt.title(f"Thermal Expansion of {salt_name}")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

# Heat capacity

In [None]:
sim = MoltenSaltSimulator()

# Salt compositions (anions, cations, anion_counts, cation_counts); kept for reference,
# the heat capacities below are read from existing trajectories.
salt_compositions = {"0.3NaCl-0.2KCl-0.5MgCl2": (["Cl"], ["K", "Mg", "Na"], [150], [20, 50, 30])}

# Temperatures (K) and literature heat capacities (J/g/K) per salt
salt_data = {
    "0.3NaCl-0.2KCl-0.5MgCl2": {
        "temps": [700, 800, 900, 1000, 1100],
        "cp_lit_dulong_petit": [1.0443, 1.0443, 1.0443, 1.0443, 1.0443],
        "cp_lit_mole_fraction_ave": [1.0292, 1.0262, 1.0232, 1.0202, 1.0172],
    }
}

# Literature values above are in J/g/K; simulated values are printed/plotted in J/kg/K
LIT_TO_J_PER_KG_K = 1000.0

# Folders where trajectories were saved
run_folders = ["GRACE_medium_2L_0.3NaCl-0.2KCl-0.5MgCl2", "GRACE_medium_1L_0.3NaCl-0.2KCl-0.5MgCl2"]

# Legend label for each folder above, in the same order
calculator_names = ["GRACE medium 2L", "GRACE medium 1L"]

# Heat capacities from all runs, as (salt, T, heat capacity in J/kg/K) per folder
all_heat_capacities = {folder: [] for folder in run_folders}

# NOTE(review): these are NVT trajectories, where V is fixed and the fluctuation
# formula yields the constant-volume heat capacity; confirm it is comparable to
# the Cp literature values.
for folder in run_folders:
    nvt_dir = os.path.join(os.getcwd(), folder, "NVT")

    for salt_name, data in salt_data.items():
        print(f"\n=== Heat capacities for {salt_name} in {folder} ===")
        for T in data["temps"]:
            traj_file_nvt = os.path.join(nvt_dir, f"nvt_{salt_name}_{T}K.traj")
            heat_capacity = sim.compute_heat_capacity(traj_file_nvt, T) * 1000  # J/kg/K
            all_heat_capacities[folder].append((salt_name, T, heat_capacity))
            print(f"Heat capacity at {T} K: {heat_capacity:.4f} J/kg/K")

# Colors assigned in order of decreasing mean heat capacity (cycled if needed)
colors = ['lightseagreen', 'lightgreen', 'gold', 'darksalmon', 'hotpink', 'mediumorchid', 'navy', 'crimson', 'teal', 'olive']

for salt_name, data in salt_data.items():
    temps = data["temps"]

    # (mean, label, temperatures, values in J/kg/K) for every dataset
    all_plot_data = []

    # Both literature datasets, converted from J/g/K to J/kg/K
    for key, label in (("cp_lit_dulong_petit", "Dulong-Petit"),
                       ("cp_lit_mole_fraction_ave", "Mole fraction average")):
        cp_lit = [cp * LIT_TO_J_PER_KG_K for cp in data[key]]
        all_plot_data.append((np.mean(cp_lit), label, temps, cp_lit))

    # Simulated values, plotted at the temperatures they were computed for
    for folder, calc_name in zip(run_folders, calculator_names):
        records = [(T, cp) for (salt, T, cp) in all_heat_capacities[folder] if salt == salt_name]
        if records:
            T_sim = [T for T, _ in records]
            cp_sim = [cp for _, cp in records]
            all_plot_data.append((np.mean(cp_sim), calc_name, T_sim, cp_sim))

    # Highest mean heat capacity first
    all_plot_data.sort(key=lambda entry: entry[0], reverse=True)

    fig, ax = plt.subplots()
    temps_fit = np.linspace(min(temps), max(temps), 100)
    for i, (_, label, temps_data, cp_data) in enumerate(all_plot_data):
        color = colors[i % len(colors)]
        ax.scatter(temps_data, cp_data, color=color, label=label, s=80)

        # Fit only when there are several points that actually vary
        if len(temps_data) > 1 and np.std(cp_data) > 0:
            cp_fit = Polynomial.fit(temps_data, cp_data, deg=1)
            ax.plot(temps_fit, cp_fit(temps_fit), linestyle='--', color=color, alpha=0.5)

    ax.set_xlabel("Temperature (K)", fontsize=12)
    ax.set_ylabel("Heat capacity (J/kg/K)", fontsize=12)
    ax.set_title(f"Heat capacity of {salt_name}", fontsize=13)
    ax.legend(fontsize=9, loc='best')
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

# Diffusion coefficient

In [None]:
sim = MoltenSaltSimulator()

# Literature self-diffusion coefficients (cm²/s) at each simulated temperature (K)
salts = {
    "NaCl": {"temps": [1100, 1125, 1150, 1175, 1200], "D_lit": [7.18e-5, 7.725e-5, 8.295e-5, 8.8775e-5, 9.47e-5]},
}

# Folders where trajectories were saved, organized by salt
salt_folders = {
    "NaCl": ["GRACE_medium_2L_NaCl", "GRACE_medium_1L_NaCl"]
}

# Legend label for each folder, organized by salt in the same order
calculator_names = {
    "NaCl": ["GRACE medium 2L", "GRACE medium 1L"]
}

# Colors assigned in order of decreasing mean D (cycled if there are more datasets)
colors = ['lightseagreen', 'lightgreen', 'gold', 'darksalmon', 'hotpink', 'mediumorchid', 'navy', 'cyan', 'gray']

for salt_name, data in salts.items():
    run_folders = salt_folders[salt_name]
    temps = data["temps"]
    D_lit = data["D_lit"]

    # Diffusion coefficients per run folder, as (salt, T, D in cm²/s)
    all_diffusion_coefficients = {folder: [] for folder in run_folders}

    for folder in run_folders:
        nvt_dir = os.path.join(os.getcwd(), folder, "NVT")

        print(f"\n=== Diffusion coefficients for {salt_name} in {folder} ===")
        for T in temps:
            traj_file_nvt = os.path.join(nvt_dir, f"nvt_{salt_name}_{T}K.traj")
            # compute_diffusion_coefficient returns Å²/fs; 1 Å²/fs = 0.1 cm²/s
            diffusion_coefficient = sim.compute_diffusion_coefficient(traj_file_nvt) * 1e-1
            all_diffusion_coefficients[folder].append((salt_name, T, diffusion_coefficient))
            print(f"Diffusion coefficient at {T} K: {diffusion_coefficient:.4e} cm²/s")

    # (mean D, label, temperatures, D values, marker) for every dataset
    all_plot_data = [(np.mean(D_lit), 'Literature diffusion coefficients', temps, D_lit, 'v')]
    for folder, calc_name in zip(run_folders, calculator_names[salt_name]):
        D_sim = [d for (_, _, d) in all_diffusion_coefficients[folder]]
        all_plot_data.append((np.mean(D_sim), calc_name, temps, D_sim, 'o'))

    # Highest mean D first, so colors are assigned in a stable order
    all_plot_data.sort(key=lambda entry: entry[0], reverse=True)

    fig, ax = plt.subplots()
    temps_fit = np.linspace(min(temps), max(temps), 100)
    for i, (_, label, temps_data, D_data, marker) in enumerate(all_plot_data):
        color = colors[i % len(colors)]
        ax.scatter(temps_data, D_data, marker=marker, color=color, label=label, s=80)
        D_fit = Polynomial.fit(temps_data, D_data, deg=1)
        ax.plot(temps_fit, D_fit(temps_fit), linestyle='--', color=color, alpha=0.5)

    ax.set_xlabel("Temperature (K)")
    ax.set_ylabel("Diffusion coefficient (cm²/s)")
    ax.set_title(f"Diffusion coefficient of {salt_name}")
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

# Radial Distribution Function (RDF)

In [None]:
sim = MoltenSaltSimulator()

# RDFs are computed from the NVT trajectories of this run
run_folder = "GRACE_medium_2L_NaCl"  # Replace with the folder name you used
nvt_dir = os.path.join(os.getcwd(), run_folder, "NVT")

# Temperatures (K) at which to plot the RDF, per salt
salts = {"NaCl": [1100, 1125, 1150, 1175, 1200]}

# One RDF figure per (salt, temperature)
for salt_name, temps in salts.items():
    for T in temps:
        title = f"{salt_name} at {T}K"
        traj_file_nvt = os.path.join(nvt_dir, f"nvt_{salt_name}_{T}K.traj")
        sim.plot_rdf(traj_file_nvt, title)

# Viscosity

In [None]:
from numpy.polynomial import Polynomial

# Compute the viscosity of each salt from the NVT trajectories of one or more
# runs (one run folder per ML potential), then plot viscosity vs. temperature
# next to literature values, with a dashed straight-line fit through each series.
sim = MoltenSaltSimulator()

# Salt name -> temperatures (K) to analyse and literature viscosities (Pa·s)
# given at those same temperatures.
salts = {
    "NaCl": {"temps": [1100, 1125, 1150, 1175, 1200], "eta_lit": [9.8526e-4, 9.3406e-4, 8.8759e-4, 8.4525e-4, 8.0658e-4]},
}

# Folder where trajectories were saved - organized by salt
salt_folders = {
    "NaCl": ["GRACE_medium_2L_NaCl", "GRACE_medium_1L_NaCl"]
}

# Legend label for each run folder, in the same order as salt_folders[salt].
calculator_names = {
    "NaCl": ["GRACE medium 2L", "GRACE medium 1L"]
}

# Colors handed out in order of decreasing mean viscosity. The index wraps
# around, so more series than colors no longer raises IndexError.
colors = ['lightseagreen', 'lightgreen', 'gold', 'darksalmon', 'hotpink', 'mediumorchid', 'navy', 'cyan', 'gray']

# Legend label of the literature series; it is drawn with a distinct marker.
lit_label = "Literature viscosities"

# Process each salt separately
for salt_name, data in salts.items():
    run_folders = salt_folders[salt_name]
    run_labels = calculator_names[salt_name]
    temps = data["temps"]
    eta_lit = data["eta_lit"]

    # Fail fast, before any expensive viscosity computation, if a run folder
    # has no legend label.
    if len(run_labels) < len(run_folders):
        raise ValueError(
            f"calculator_names['{salt_name}'] has {len(run_labels)} labels "
            f"but salt_folders['{salt_name}'] lists {len(run_folders)} run folders"
        )

    # Run folder -> list of (salt, T, viscosity) tuples collected for this salt.
    all_viscosities = {folder: [] for folder in run_folders}

    # Loop over all run folders and collect viscosities
    for folder in run_folders:
        nvt_dir = os.path.join(os.getcwd(), folder, "NVT")

        print(f"\n=== Viscosities for {salt_name} in {folder} ===")
        for T in temps:
            traj_file_nvt = os.path.join(nvt_dir, f"nvt_{salt_name}_{T}K.traj")
            viscosity = sim.compute_viscosity(traj_file_nvt, T)  # Already in Pa·s
            all_viscosities[folder].append((salt_name, T, viscosity))
            print(f"Viscosity at {T} K: {viscosity:.4e} Pa·s")

    # Create a single plot
    fig, ax = plt.subplots()

    # One (mean viscosity, legend label, temperatures, viscosities) entry per series.
    all_plot_data = [(np.mean(eta_lit), lit_label, temps, eta_lit)]
    for folder, label in zip(run_folders, run_labels):
        # all_viscosities is rebuilt per salt, so every entry belongs to salt_name.
        eta_sim = [eta for (_, _, eta) in all_viscosities[folder]]
        all_plot_data.append((np.mean(eta_sim), label, temps, eta_sim))

    # Sort by mean viscosity (highest to lowest) so the color order is deterministic.
    all_plot_data.sort(key=lambda x: x[0], reverse=True)

    temps_fit = np.linspace(min(temps), max(temps), 100)
    for i, (mean_eta, label, temps_data, eta_data) in enumerate(all_plot_data):
        color = colors[i % len(colors)]
        # Literature points use a downward triangle. Simulation points keep
        # matplotlib's default scatter marker.
        marker_kw = {"marker": "v"} if label == lit_label else {}
        ax.scatter(temps_data, eta_data, color=color, label=label, s=80, **marker_kw)

        # A straight-line fit needs at least two distinct finite points. Skip the
        # trend line otherwise (e.g. one failed NaN viscosity, or a single
        # temperature) rather than letting Polynomial.fit abort the whole cell.
        x = np.asarray(temps_data, dtype=float)
        y = np.asarray(eta_data, dtype=float)
        finite = np.isfinite(x) & np.isfinite(y)
        if np.unique(x[finite]).size >= 2:
            eta_fit = Polynomial.fit(x[finite], y[finite], deg=1)
            ax.plot(temps_fit, eta_fit(temps_fit), linestyle='--', color=color, alpha=0.5)

    ax.set_xlabel("Temperature (K)")
    ax.set_ylabel("Viscosity (Pa·s)")
    ax.set_title(f"Viscosity of {salt_name}")
    ax.legend(fontsize=9.5)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()