In [4]:
import torch
import ase.io
from tqdm import tqdm
import os
from mace.calculators import MACECalculator
from ase import Atoms
from typing import Dict, List

import mbe_automation
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def get_vacuum_energies(calc_mace_off: MACECalculator, calc_mace_mp: MACECalculator, z_list: List[int]) -> Dict[int, float]:
    """Calculates the energy (mace_off) for single, isolated atoms."""
    logging.info("Calculating vacuum energies for regression baseline...")
    vacuum_energies = {}
    unique_atomic_numbers = sorted(list(set(z_list)))
    
    for z in unique_atomic_numbers:
        atom = Atoms(numbers=[z])
        atom.calc = calc_mace_off
        vacuum_ref = atom.get_potential_energy()
        atom.calc = calc_mace_mp
        vacuum_base = atom.get_potential_energy()
        vacuum_energies[z] = vacuum_ref - vacuum_base
        logging.info(f"  - Referance vacuum energy for Z={z}: {vacuum_ref:.4f} eV")
        logging.info(f"  - Base vacuum energy for Z={z}: {vacuum_base:.4f} eV")

        
    return vacuum_energies

if __name__ == '__main__':
    MACE_OFF_MODEL_PATH = 'MACE-OFF24_medium.model'
    MP0_MODEL_PATH = 'MACE-MP_small.model'
    
    INPUT_HDF5_FILE = 'training_set.hdf5' 
    HDF5_TRAJECTORY_KEY = "training/md_sampling/crystal[dyn:T=298.15,p=0.00010]/trajectory"
    
    OUTPUT_DATASET_FILE = 'delta_learning_dataset_full.xyz'
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    for path in [MACE_OFF_MODEL_PATH, MP0_MODEL_PATH, INPUT_HDF5_FILE]:
        if not os.path.exists(path):
            raise FileNotFoundError(f"Could not find required file: '{path}'")

    # --- Step 1: Initialize Calculators ---
    print("Initializing calculators...")
    calc_mace_off = MACECalculator(model_paths=MACE_OFF_MODEL_PATH, device=device, default_dtype="float64")
    calc_mp0 = MACECalculator(model_paths=MP0_MODEL_PATH, device=device, default_dtype="float64")

    # --- Step 2: Load Trajectory from HDF5 ---
    print(f"Loading trajectory from HDF5 file: '{INPUT_HDF5_FILE}'")
    print(f"Using HDF5 key: '{HDF5_TRAJECTORY_KEY}'")
    
    # Read the custom trajectory object from the HDF5 file
    mbe_trajectory = mbe_automation.storage.read_trajectory(
        dataset=INPUT_HDF5_FILE,
        key=HDF5_TRAJECTORY_KEY
    )
    
    # Convert it to an ASE-compatible trajectory object
    # This object can be iterated over just like the list from ase.io.read
    trajectory = mbe_automation.storage.ASETrajectory(mbe_trajectory)
    print(f"Found {len(trajectory)} frames to process.")

    # --- Step 3: Calculate Vacuum Energy Shifts ---
    all_zs = [z for atoms in trajectory for z in atoms.get_atomic_numbers()]
    vacuum_energy_shifts = get_vacuum_energies(calc_mace_off, calc_mp0, all_zs)
    
    # --- Step 4: Calculate Energies and Modified Delta ---
    processed_atoms_list = []
    for atoms in tqdm(trajectory, desc="Processing frames"):
        # Calculate energy with the base model (mp-0)
        atoms.calc = calc_mp0
        energy_mp0 = atoms.get_potential_energy()

        # Calculate energy with the target model (mace-off)
        atoms.calc = calc_mace_off
        energy_mace_off = atoms.get_potential_energy()

        # Calculate the total energy difference
        total_delta_energy = energy_mace_off - energy_mp0
        
        # Calculate the total shift by summing the vacuum energies of all atoms in the frame
        total_vacuum_shift = sum(vacuum_energy_shifts[z] for z in atoms.get_atomic_numbers())
        
        # The residual delta is the new target for the delta-learning model
        # It's the part of the energy difference not captured by the simple atomic sum
        residual_delta_energy = total_delta_energy - total_vacuum_shift
        
        # Store all relevant information in the Atoms object's info dictionary
        atoms.info['energy_mp0'] = energy_mp0
        atoms.info['energy_mace_off'] = energy_mace_off
        atoms.info['total_delta_energy'] = total_delta_energy
        atoms.info['residual_delta_energy'] = residual_delta_energy # This is the new learning target!

        # Clear the calculator before appending to the list
        atoms.calc = None
        processed_atoms_list.append(atoms)

    # --- Step 5: Save the Processed Dataset ---
    print(f"Saving processed dataset to file: '{OUTPUT_DATASET_FILE}'")
    ase.io.write(OUTPUT_DATASET_FILE, processed_atoms_list)
    print("\nDone! The dataset with shifts has been successfully created.")

Using device: cpu
Initializing calculators...
Using head Default out of ['Default']


  torch.load(f=model_path, map_location=device)
  torch.load(f=model_path, map_location=device)


Using head Default out of ['Default']
Loading trajectory from HDF5 file: 'training_set.hdf5'
Using HDF5 key: 'training/md_sampling/crystal[dyn:T=298.15,p=0.00010]/trajectory'
Found 1001 frames to process.


Processing frames:   7%|▋         | 74/1001 [16:11<3:22:45, 13.12s/it]


KeyboardInterrupt: 

In [3]:
import torch
import ase.io
from tqdm import tqdm
import os

from mace.calculators import MACECalculator

from ase import Atoms
from typing import Dict, List

def get_vacuum_energies(calc_mp0, calc_mace_off, all_atomic_numbers: List[int]) -> Dict[int, float]:
    """
    Calculates the energy (mace_off) for single, isolated atoms.
    This serves as our per-atom energy shift.
    Code only for now it will not be neaded for any other method
    """
    print("Obliczanie energii dla pojedynczych atomów w próżni (shift)...")
    vacuum_energies = {}
    unique_atomic_numbers = sorted(list(set(all_atomic_numbers)))
    
    for z in unique_atomic_numbers:
        atom = Atoms(numbers=[z])
        
        # Oblicz energię z modelu docelowego (mace-off)
        atom.calc = calc_mace_off
        vacuum_e = atom.get_potential_energy()
        
        vacuum_energies[z] = vacuum_e
        print(f"  - Atom Z={z}: Shift = {vacuum_e:.4f} eV")
        
    return vacuum_energies



if __name__ == '__main__':
    # --- KONFIGURACJA (bez zmian) ---
    MACE_OFF_MODEL_PATH = 'MACE-OFF24_medium.model'
    MP0_MODEL_PATH = 'MACE-MP_small.model'
    INPUT_TRAJECTORY_FILE = 'mace_off_trajectory.xyz'
    OUTPUT_DATASET_FILE = 'delta_learning_dataset_with_shifts.xyz' # Nowy plik wyjściowy
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Używane urządzenie: {device}")

    # Sprawdzenie plików (bez zmian)
    for path in [MACE_OFF_MODEL_PATH, MP0_MODEL_PATH, INPUT_TRAJECTORY_FILE]:
        if not os.path.exists(path):
            raise FileNotFoundError(f"Nie znaleziono wymaganego pliku: '{path}'")

    # Krok 1: Inicjalizacja kalkulatorów (bez zmian)
    print("Inicjalizacja kalkulatorów...")
    calc_mace_off = MACECalculator(model_path=MACE_OFF_MODEL_PATH, device=device)
    calc_mp0 = MACECalculator(model_path=MP0_MODEL_PATH, device=device)

    # Krok 2: Wczytanie trajektorii (bez zmian)
    print(f"Wczytywanie trajektorii z pliku: '{INPUT_TRAJECTORY_FILE}'")
    trajectory = ase.io.read(INPUT_TRAJECTORY_FILE, index=':')
    print(f"Znaleziono {len(trajectory)} klatek do przetworzenia.")

    # --- NOWY KROK: Obliczenie energii w próżni (shift) ---
    all_zs = [z for atoms in trajectory for z in atoms.get_atomic_numbers()]
    vacuum_energy_shifts = get_vacuum_energies(calc_mp0, calc_mace_off, all_zs)
    # --------------------------------------------------------

    # Krok 3: Obliczenie energii i ZMODYFIKOWANEJ delty
    processed_atoms_list = []
    for atoms in tqdm(trajectory, desc="Przetwarzanie klatek"):
        atoms.calc = calc_mp0
        energy_mp0 = atoms.get_potential_energy()

        atoms.calc = calc_mace_off
        energy_mace_off = atoms.get_potential_energy()

        # Oblicz pełną deltę
        total_delta_energy = energy_mace_off - energy_mp0
        
        # --- ZMIANA: Oblicz sumę shiftów i resztkową deltę ---
        total_vacuum_shift = sum(vacuum_energy_shifts[z] for z in atoms.get_atomic_numbers())
        residual_delta_energy = total_delta_energy - total_vacuum_shift
        
        # Zapisz wszystkie informacje w obiekcie Atoms
        # 'residual_delta_energy' to będzie nasz nowy cel do uczenia
        atoms.info['energy_mp0'] = energy_mp0
        atoms.info['energy_mace_off'] = energy_mace_off
        atoms.info['total_delta_energy'] = total_delta_energy # Zachowaj dla ewaluacji
        atoms.info['residual_delta_energy'] = residual_delta_energy # Nowy cel!
        # --------------------------------------------------------

        atoms.calc = None
        processed_atoms_list.append(atoms)

    # Krok 4: Zapisz przetworzony zbiór danych (bez zmian)
    print(f"Zapisywanie przetworzonego zbioru danych do pliku: '{OUTPUT_DATASET_FILE}'")
    ase.io.write(OUTPUT_DATASET_FILE, processed_atoms_list)
    print("\nGotowe! Zbiór danych z uwzględnieniem shiftów został pomyślnie utworzony.")


  torch.load(f=model_path, map_location=device)


Używane urządzenie: cpu
Inicjalizacja kalkulatorów...
Using head Default out of ['Default']
No dtype selected, switching to float64 to match model dtype.


  torch.load(f=model_path, map_location=device)


Using head Default out of ['Default']
No dtype selected, switching to float64 to match model dtype.
Wczytywanie trajektorii z pliku: 'mace_off_trajectory.xyz'
Znaleziono 101 klatek do przetworzenia.
Obliczanie energii dla pojedynczych atomów w próżni (shift)...
  - Atom Z=1: Shift = -13.5720 eV
  - Atom Z=6: Shift = -1030.5672 eV
  - Atom Z=7: Shift = -1486.3750 eV


Przetwarzanie klatek: 100%|██████████| 101/101 [03:02<00:00,  1.81s/it]

Zapisywanie przetworzonego zbioru danych do pliku: 'delta_learning_dataset_with_shifts.xyz'

Gotowe! Zbiór danych z uwzględnieniem shiftów został pomyślnie utworzony.





In [1]:
!git clone https://github.com/modrzejewski/mbe-automation.git

Cloning into 'mbe-automation'...
remote: Enumerating objects: 3768, done.[K
remote: Counting objects: 100% (443/443), done.[K
remote: Compressing objects: 100% (338/338), done.[K
remote: Total 3768 (delta 180), reused 211 (delta 104), pack-reused 3325 (from 1)[K
Receiving objects: 100% (3768/3768), 868.77 KiB | 495.00 KiB/s, done.
Resolving deltas: 100% (2114/2114), done.
