In [None]:
import argparse
import logging
import os
import random
from typing import Dict, List, Tuple

import ase.io
from ase.neighborlist import neighbor_list
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, random_split
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as PyGDataLoader
from torch_geometric.nn import radius_graph

# Assuming 'models.py' is in the same directory or accessible in the python path
from models import DualReadoutMACE
from train import evaluate, DeltaEnergyLoss, load_data, pyg_collate, AtomsDataset

def _train_one_epoch(
    model: torch.nn.Module,
    loader,
    optimizer: torch.optim.Optimizer,
    loss_fn,
    device: torch.device,
) -> float:
    """Run one optimisation pass over ``loader`` and return the mean loss per graph.

    Args:
        model: Module called with a plain ``dict`` built from each batch.
        loader: Iterable of PyG batches; must expose ``.dataset`` with a
            non-zero length.
        optimizer: Optimizer stepping the trainable parameters of ``model``.
        loss_fn: Callable invoked as ``loss_fn(output, batch)`` returning a
            scalar tensor.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Sum of ``loss * batch.num_graphs`` over all batches divided by
        ``len(loader.dataset)``.
    """
    # NOTE(review): this also switches the frozen base model to train mode;
    # harmless only if it has no dropout/batch-norm layers -- confirm.
    model.train()
    total_loss = 0.0

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        # The model is fed a dict in which the coordinates are stored under
        # 'positions' rather than PyG's 'pos' key.
        data_dict = batch.to_dict()
        if 'pos' in data_dict:
            data_dict['positions'] = data_dict.pop('pos')

        output = model(data_dict)
        loss = loss_fn(output, batch)

        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller last batch does not skew the
        # epoch average (assumes loss_fn returns a per-batch mean -- confirm).
        total_loss += loss.item() * batch.num_graphs

    return total_loss / len(loader.dataset)


def main():
    """Train the delta-readout head of a ``DualReadoutMACE`` model.

    The configuration is defined inline at the top of the function.  The
    function loads the dataset and the pre-trained base MACE model, builds a
    train/validation split, trains for a fixed number of epochs and writes
    ``best_model.pt`` (lowest validation loss) and ``final_model.pt`` into
    the output directory.

    Returns:
        None.  Unrecoverable setup problems (missing/unreadable base model,
        missing ``r_max``/``atomic_numbers``, fewer than two configurations)
        are logged and cause an early return instead of raising.
    """
    # --- Configuration ---
    # Define your training parameters here instead of using command-line arguments.

    # Path to the dataset file
    dataset_path = "delta_learning_dataset.xyz"

    # IMPORTANT: Path to the pre-trained base MACE model (.pt or .model file)
    # You MUST provide a valid path to your model file.
    base_model_path = "MACE-MP_small.model"

    # Training hyperparameters
    epochs = 100
    lr = 1e-3
    batch_size = 10

    # Dataset split
    validation_split = 0.1
    # Seed for the train/validation split, so that the validation set (and
    # therefore the "best" checkpoint) is comparable between runs.
    split_seed = 42

    # Output directory for saved models
    output_dir = "delta_model_checkpoints"
    # --- End of Configuration ---

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")

    # 1. Load Data
    full_atoms_list = load_data(dataset_path)

    # 2. Load base model to get r_max and atomic_numbers for data preparation
    logging.info(f"Loading base MACE model from {base_model_path}")
    try:
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only load model files that come from a trusted source.
        base_mace_model = torch.load(base_model_path, map_location=device, weights_only=False)
        base_mace_model.to(dtype=torch.float64)  # Ensure model is float64
        base_mace_model.eval()
    except FileNotFoundError:
        logging.error(f"Base model file not found at '{base_model_path}'. Please provide a valid path.")
        return
    except Exception as e:
        # Keep the traceback: the message alone rarely identifies the cause.
        logging.exception(f"Error loading base model: {e}")
        return

    try:
        r_max = base_mace_model.r_max.item()
        z_map = {z.item(): i for i, z in enumerate(base_mace_model.atomic_numbers)}
        logging.info(f"Using r_max={r_max:.2f} and atomic numbers from the base model.")
    except AttributeError:
        logging.error("Could not find 'r_max' or 'atomic_numbers' in the base model. Please check the model file.")
        return
    except Exception as e:
        logging.exception(f"Error extracting model parameters: {e}")
        return

    dataset = AtomsDataset(full_atoms_list, r_max=r_max, z_map=z_map)

    # 3. Split data
    n_total = len(dataset)
    if n_total < 2:
        # With fewer than two samples one of the two subsets would be empty,
        # which breaks the per-epoch averaging and the validation step.
        logging.error(
            f"Need at least 2 configurations for a train/validation split, got {n_total}."
        )
        return
    # Clamp so that both subsets hold at least one configuration.
    n_val = min(max(int(n_total * validation_split), 1), n_total - 1)
    n_train = n_total - n_val
    split_generator = torch.Generator().manual_seed(split_seed)
    train_dataset, val_dataset = random_split(
        dataset, [n_train, n_val], generator=split_generator
    )

    logging.info(f"Training set size: {len(train_dataset)}")
    logging.info(f"Validation set size: {len(val_dataset)}")

    # NOTE(review): torch_geometric's DataLoader appears to discard a
    # caller-supplied collate_fn and use its own Collater -- confirm whether
    # pyg_collate is actually applied here.
    train_loader = PyGDataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=pyg_collate)
    val_loader = PyGDataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=pyg_collate)

    # 4. Initialize the delta-learning model
    model = DualReadoutMACE(base_mace_model)
    model.to(dtype=torch.float64)  # Ensure the wrapper and new layer are also float64
    model.to(device)

    # 5. Setup optimizer, scheduler, and loss
    # The optimizer will only act on parameters where requires_grad is True,
    # which is only our new delta_readout layer.
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        lr=lr,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10)
    loss_fn = DeltaEnergyLoss()

    # 6. Training loop
    best_val_loss = float('inf')
    logging.info("Starting training...")

    for epoch in range(epochs):
        avg_train_loss = _train_one_epoch(model, train_loader, optimizer, loss_fn, device)

        # Validation
        avg_val_loss, mae, rmse = evaluate(model, val_loader, loss_fn, device)
        scheduler.step(avg_val_loss)

        logging.info(
            f"Epoch {epoch+1}/{epochs} | "
            f"Train Loss: {avg_train_loss:.6f} | "
            f"Val Loss: {avg_val_loss:.6f} | "
            f"Val MAE (delta): {mae:.6f} | "
            f"Val RMSE (delta): {rmse:.6f}"
        )

        # Save checkpoint if validation loss improves
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            checkpoint_path = os.path.join(output_dir, "best_model.pt")
            # We save the state_dict of the entire model, but only the delta_readout part is trained.
            torch.save(model.state_dict(), checkpoint_path)
            logging.info(f"New best model saved to {checkpoint_path}")

    logging.info("Training complete.")
    final_model_path = os.path.join(output_dir, "final_model.pt")
    torch.save(model.state_dict(), final_model_path)
    logging.info(f"Final model saved to {final_model_path}")

# Start the training run only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

2025-10-08 12:00:09,303 - INFO - Using device: cpu
2025-10-08 12:00:09,303 - INFO - Loading data from 'delta_learning_dataset.xyz'...
2025-10-08 12:00:09,329 - INFO - Successfully loaded 101 configurations.
2025-10-08 12:00:09,329 - INFO - Loading base MACE model from MACE-MP_small.model
  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))


cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled.


2025-10-08 12:00:10,364 - INFO - Using r_max=6.00 and atomic numbers from the base model.
2025-10-08 12:00:10,365 - INFO - Training set size: 91
2025-10-08 12:00:10,366 - INFO - Validation set size: 10
2025-10-08 12:00:10,368 - INFO - Starting training...


Zamrażanie parametrów całego modelu bazowego MACE...
Wykryto 128 cech wejściowych do głowy (readout).
Wagi nowej głowy 'delta_readout' zostały zainicjalizowane zerami.
Hak poprawnie zarejestrowany na ostatnim bloku 'product'.


2025-10-08 12:00:46,325 - INFO - Epoch 1/100 | Train Loss: 4135116184.329174 | Val Loss: 4132480272.240263 | Val MAE (delta): 32125.273065 | Val RMSE (delta): 32125.273066
2025-10-08 12:00:46,354 - INFO - New best model saved to delta_model_checkpoints/best_model.pt
2025-10-08 12:01:22,925 - INFO - Epoch 2/100 | Train Loss: 4130905435.096497 | Val Loss: 4128285882.890433 | Val MAE (delta): 32092.641073 | Val RMSE (delta): 32092.641073
2025-10-08 12:01:22,960 - INFO - New best model saved to delta_model_checkpoints/best_model.pt
2025-10-08 12:02:00,428 - INFO - Epoch 3/100 | Train Loss: 4126765102.399417 | Val Loss: 4124105344.384257 | Val MAE (delta): 32060.100339 | Val RMSE (delta): 32060.100340
2025-10-08 12:02:00,470 - INFO - New best model saved to delta_model_checkpoints/best_model.pt
2025-10-08 12:02:40,681 - INFO - Epoch 4/100 | Train Loss: 4122634790.399866 | Val Loss: 4119927499.410063 | Val MAE (delta): 32027.564090 | Val RMSE (delta): 32027.564091
2025-10-08 12:02:40,733 - I