# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import mse_loss, gaussian_nll_loss
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
import copy
import math
from ase.db import connect
from ase import Atoms
from torch.utils.data import DataLoader
from collections.abc import Sequence
import json
import itertools

# In [12]:
###############################################################
# Implementations of Batch, reduce_splits, #
# cosine_cutoff, BesselExpansion, compute_edge_vectors_and_norms, and sum_index
###############################################################

"""Data object classes and related utilities."""

###############################################################
# BaseData, Data, AtomsData, GeometricData classes
###############################################################

class BaseData:
    """A dict-like base class for data objects.

    Tensor-valued attributes are kept in ``self.tensors`` so they can be
    enumerated and moved between devices together; every other attribute
    lives in the regular instance ``__dict__``.
    """

    def __init__(self, **kwargs):
        """Create the tensor store and assign every keyword as an attribute."""
        self.tensors = dict()
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __getattr__(self, key):
        """Resolve attributes that normal lookup missed from the tensor store.

        Raises:
            AttributeError: If ``key`` is not a stored tensor.
        """
        # Read the store through __dict__ rather than ``self.tensors``: on an
        # instance whose __init__ has not run (e.g. created with __new__),
        # ``self.tensors`` would re-enter __getattr__ and recurse forever.
        tensors = self.__dict__.get("tensors")
        if tensors is not None and key in tensors:
            return tensors[key]
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'")

    def __setattr__(self, key, value) -> None:
        """Store tensors in ``self.tensors`` and everything else in ``__dict__``.

        Re-assigning a name with a value of the other kind removes the stale
        entry, so a name never exists in both places at once.
        """
        if isinstance(value, torch.Tensor):
            self.tensors[key] = value
            self.__dict__.pop(key, None)
        else:
            super().__setattr__(key, value)
            self.tensors.pop(key, None)

    def __getstate__(self) -> dict:
        """Return the instance ``__dict__`` (which includes ``tensors``)."""
        return self.__dict__

    def __setstate__(self, state: dict) -> None:
        """Restore the instance ``__dict__`` from ``state``."""
        self.__dict__ = state

    def validate(self) -> bool:
        """Assert that every entry of ``self.tensors`` is a tensor; return True."""
        for key, tensor in self.tensors.items():
            assert isinstance(tensor, torch.Tensor), f"'{key}' is not a tensor!"
        return True

    def to(self, device: torch.device) -> "BaseData":
        """Move every stored tensor to ``device`` in place.

        Returns:
            This object, so that ``data = data.to(device)`` works as it does
            for tensors and modules.
        """
        self.tensors = {k: v.to(device) for k, v in self.tensors.items()}
        return self



class Data(BaseData):
    """A data object describing a homogeneous graph.

    Includes general graph information about: nodes, edges, target labels and global features.
    """

    def __init__(
        self,
        node_features: torch.Tensor,
        edge_index: torch.Tensor | None = None,
        edge_features: torch.Tensor | None = None,
        targets: torch.Tensor | None = None,
        global_features: torch.Tensor | None = None,
        **kwargs
    ) -> None:
        """Store the graph tensors as attributes.

        Args:
            node_features: Per-node features; the first dimension indexes nodes.
            edge_index: Edge list with one ``(source, target)`` row per edge.
                When omitted, an empty ``(0, 2)`` integer tensor is used.
            edge_features: Optional per-edge features.
            targets: Optional target labels.
            global_features: Optional graph-level features.
            **kwargs: Extra attributes forwarded to the base class.
        """
        super().__init__(**kwargs)
        if edge_index is None:
            # Build a fresh, correctly typed empty edge list per instance. A
            # shared ``torch.tensor([])`` default is float with shape (0,),
            # which cannot be column-indexed and promotes concatenated
            # indices to float.
            device = node_features.device if isinstance(node_features, torch.Tensor) else None
            edge_index = torch.empty((0, 2), dtype=torch.long, device=device)
        self.node_features = node_features
        self.edge_index = edge_index
        self.edge_features = edge_features
        self.global_features = global_features
        self.targets = targets

    def validate(self) -> bool:
        """Assert shape consistency of node and edge tensors; return True."""
        super().validate()
        assert self.num_nodes > 0
        assert self.node_features.shape[0] == self.num_nodes
        assert self.node_features.ndim >= 2
        assert self.edge_index.shape[0] == self.num_edges
        assert self.edge_index.shape[0] == 0 or self.edge_index.shape[1] == 2
        assert self.edge_index.shape[0] == 0 or self.edge_index.max() < self.num_nodes
        if self.edge_features is not None:
            assert self.edge_features.shape[0] == self.num_edges
            assert self.num_edges == 0 or self.edge_features.ndim >= 2
        return True

    @property
    def num_nodes(self) -> torch.Tensor | int:
        """Stored ``num_nodes`` tensor if present, else the node_features row count."""
        return self.tensors.get("num_nodes", self.node_features.shape[0])

    @property
    def num_edges(self) -> torch.Tensor | int:
        """Stored ``num_edges`` tensor if present, else the edge_index row count."""
        return self.tensors.get("num_edges", self.edge_index.shape[0])

    @property
    def edge_index_source(self) -> torch.Tensor:
        """First column of ``edge_index``."""
        return self.edge_index[:, 0]

    @property
    def edge_index_target(self) -> torch.Tensor:
        """Second column of ``edge_index``."""
        return self.edge_index[:, 1]


class AtomsData(Data):
    """Graph of atoms: a ``Data`` graph extended with positions and optional
    per-structure quantities (energy, forces, cell, ...)."""

    def __init__(
        self,
        node_positions: torch.Tensor,
        energy: torch.Tensor | None = None,
        forces: torch.Tensor | None = None,
        magmoms: torch.Tensor | None = None,
        cell: torch.Tensor | None = None,
        volume: torch.Tensor | None = None,
        stress: torch.Tensor | None = None,
        pbc: torch.Tensor | None = None,
        edge_shift: torch.Tensor | None = None,
        **kwargs
    ) -> None:
        """Forward ``kwargs`` to ``Data`` and attach the atomistic fields."""
        super().__init__(**kwargs)
        # Assign in a fixed order so the key order of the tensor store is stable.
        fields = (
            ("node_positions", node_positions),
            ("energy", energy),
            ("forces", forces),
            ("magmoms", magmoms),
            ("cell", cell),
            ("volume", volume),
            ("stress", stress),
            ("pbc", pbc),
            ("edge_shift", edge_shift),
        )
        for name, value in fields:
            setattr(self, name, value)

    def validate(self) -> bool:
        """Assert that every provided field has the expected shape; return True."""
        super().validate()
        positions = self.node_positions
        assert positions.shape[0] == self.num_nodes
        assert positions.ndim == 2
        dim = positions.shape[1]

        if self.energy is not None:
            assert self.energy.shape == (1,)
        if self.forces is not None:
            assert self.forces.shape == (self.num_nodes, dim)
        if self.magmoms is not None:
            assert self.magmoms.shape == (self.num_nodes, 1)

        # cell and pbc must be given together.
        has_cell = self.cell is not None
        has_pbc = self.pbc is not None
        if has_cell or has_pbc:
            assert has_cell
            assert has_pbc
            assert self.cell.shape == (dim, dim)
            assert self.pbc.shape == (dim,)

        if self.volume is not None:
            assert self.cell is not None
            assert self.volume.shape == (1,)
            assert torch.isclose(self.volume, torch.linalg.det(self.cell))
        if self.stress is not None:
            assert self.cell is not None
            assert self.stress.shape == (dim, dim)
        if self.edge_shift is not None:
            assert self.edge_index is not None
            assert self.edge_shift.shape == (self.num_edges, dim)
        return True

    def any_pbc(self) -> bool:
        """Return True when ``pbc`` is set and at least one entry is truthy."""
        if self.pbc is None:
            return False
        return bool(torch.any(self.pbc))


class GeometricData(Data):
    """Geometric graph: a ``Data`` graph whose nodes carry positions and,
    optionally, velocities and accelerations."""

    def __init__(
        self,
        node_positions: torch.Tensor,
        node_velocities: torch.Tensor | None = None,
        node_accelerations: torch.Tensor | None = None,
        **kwargs
    ) -> None:
        """Forward ``kwargs`` to ``Data`` and attach the kinematic fields."""
        super().__init__(**kwargs)
        for name, value in (
            ("node_positions", node_positions),
            ("node_velocities", node_velocities),
            ("node_accelerations", node_accelerations),
        ):
            setattr(self, name, value)

    def validate(self) -> bool:
        """Assert that positions are 2-D and optional fields match them; return True."""
        super().validate()
        positions = self.node_positions
        assert positions.shape[0] == self.num_nodes
        assert positions.ndim == 2
        dim = positions.shape[1]
        # Velocities first, then accelerations: one row per node, same spatial dim.
        for per_node in (self.node_velocities, self.node_accelerations):
            if per_node is None:
                continue
            assert per_node.shape[0] == self.num_nodes
            assert per_node.shape[1] == dim
        return True

###############################################################
# Batch, collate_data and Utility functions
###############################################################

class Batch(Data):
    """A batch of data objects.

    Typically the disjoint union of several graphs, with ``num_nodes`` and
    ``num_edges`` holding the per-graph counts.
    """

    def __init__(self, **kwargs) -> None:
        """Build the batch from keyword tensors; index caches start empty."""
        super().__init__(**kwargs)
        self._node_data_index = None
        self._edge_data_index = None

    def validate(self) -> bool:
        """Assert batch-level shape consistency; return True."""
        for name, entry in self.tensors.items():
            assert isinstance(entry, torch.Tensor), f"'{name}' is not a tensor!"

        total_nodes = self.node_features.shape[0]
        total_edges = self.edge_index.shape[0]
        assert total_nodes == torch.sum(self.num_nodes)
        assert total_nodes == 0 or self.node_features.ndim >= 2
        assert total_edges == torch.sum(self.num_edges)
        assert total_edges == 0 or self.edge_index.shape[1] == 2
        assert total_edges == 0 or self.edge_index.max() < self.node_features.shape[0]
        assert self.num_data >= 1

        if self.edge_features is not None:
            assert self.edge_features.shape[0] == self.edge_index.shape[0]
            assert self.edge_features.ndim >= 2
        if self.global_features is not None:
            assert self.global_features.shape[0] == self.num_data
        if self.targets is not None:
            assert self.targets.ndim >= 2
        return True

    @property
    def num_data(self) -> int:
        """Number of graphs in the batch (length of the ``num_nodes`` tensor)."""
        return self.num_nodes.shape[0]

    @property
    def node_data_index(self) -> torch.Tensor:
        """Graph index of every node, computed on first access and cached."""
        cached = self._node_data_index
        if cached is None:
            counts = self.num_nodes
            graph_ids = torch.arange(counts.shape[0], device=counts.device)
            cached = torch.repeat_interleave(graph_ids, counts)
            self._node_data_index = cached
        return cached

    @property
    def edge_data_index(self) -> torch.Tensor:
        """Graph index of every edge, computed on first access and cached."""
        cached = self._edge_data_index
        if cached is None:
            counts = self.num_edges
            graph_ids = torch.arange(counts.shape[0], device=counts.device)
            cached = torch.repeat_interleave(graph_ids, counts)
            self._edge_data_index = cached
        return cached


def collate_data(list_of_data: Sequence[Data]) -> Batch:
    """Collate a list of data objects into a batch object.

    The input graphs are combined into a single graph as a disjoint union by
    concatenation of all data and appropriate adjustment of the edge_index.

    Only the tensor keys present on the first data object are collated.
    ``cell`` and ``stress`` get a new leading batch dimension; every other
    tensor is promoted to at least 2-D and concatenated along dimension 0.

    Raises:
        ValueError: If ``list_of_data`` is empty.
        Exception: If a tensor key cannot be collated; the original error is
            chained as ``__cause__``.
    """
    if not list_of_data:
        raise ValueError("Cannot collate an empty sequence of data objects.")

    batch = {}
    batch["num_nodes"] = torch.tensor([d.num_nodes for d in list_of_data])
    batch["num_edges"] = torch.tensor([d.num_edges for d in list_of_data])
    # Exclusive cumulative sum: index of the first node of each graph in the union.
    offset = torch.cumsum(batch["num_nodes"], dim=0) - batch["num_nodes"]
    batch["edge_index"] = torch.cat([d.edge_index + offset[i] for i, d in enumerate(list_of_data)])
    for k in list_of_data[0].tensors:
        if k in batch:
            continue
        try:
            if k in ("cell", "stress"):
                batch[k] = torch.cat([d.tensors[k].unsqueeze(0) for d in list_of_data])
            else:
                batch[k] = torch.cat([torch.atleast_2d(d.tensors[k]) for d in list_of_data])
        except Exception as e:
            # Keep the historical exception type, but chain the cause and put
            # the underlying error into a single readable message.
            raise Exception(f"Failed to add '{k}' to batch: {e}") from e
    return Batch(**batch)

def sum_splits(values: torch.Tensor, splits: torch.Tensor, out: torch.Tensor | None = None) -> torch.Tensor:
    if out is None:
        out = torch.zeros((splits.size(0),) + values.shape[1:], dtype=values.dtype, device=values.device)
    idx = torch.repeat_interleave(torch.arange(splits.size(0), device=values.device), splits)
    out.index_add_(0, idx, values)
    return out

def mean_splits(values: torch.Tensor, splits: torch.Tensor, out: torch.Tensor | None = None) -> torch.Tensor:
    """Average consecutive row segments of ``values``.

    The segment sums come from ``sum_splits`` (accumulated into ``out`` when
    given) and are divided by the segment sizes in ``splits``. The quotient
    is returned as a new tensor.
    """
    totals = sum_splits(values, splits, out=out)
    # Reshape the counts to (num_segments, 1, ..., 1) so they broadcast over trailing dims.
    trailing_ones = (1,) * (values.dim() - 1)
    counts = splits.view(-1, *trailing_ones)
    return totals / counts

def reduce_splits(values: torch.Tensor, splits: torch.Tensor, out: torch.Tensor | None = None, reduction: str = "sum") -> torch.Tensor:
    if reduction == "sum":
        return sum_splits(values, splits, out=out)
    elif reduction == "mean":
        return mean_splits(values, splits, out=out)
    else:
        raise ValueError(f"Unknown reduction method: {reduction}")

def cosine_cutoff(x: torch.Tensor, cutoff: float) -> torch.Tensor:
    """Evaluate ``0.5 * (cos(pi * x / cutoff) + 1)`` where ``x <= cutoff``, else 0."""
    assert cutoff > 0.0
    envelope = 0.5 * (torch.cos(math.pi * x / cutoff) + 1)
    within_range = x <= cutoff
    return envelope * within_range

class BesselExpansion(torch.nn.Module):
    """Expand ``x`` into ``size`` features ``sin(n * pi * x / cutoff) / x``, n = 1..size."""

    def __init__(self, size: int, cutoff: float = 5.0, trainable: bool = False) -> None:
        """Create the frequencies ``n * pi / cutoff``.

        Args:
            size: Number of expansion features.
            cutoff: Value dividing the frequencies.
            trainable: Whether the frequencies receive gradients.
        """
        super().__init__()
        self.size = size
        orders = torch.arange(1, size + 1)
        frequencies = orders * math.pi / cutoff
        # Assigning a Parameter attribute registers it under that name.
        self.b_pi_over_c = torch.nn.Parameter(frequencies, requires_grad=trainable)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the expansion of ``x``; a small offset guards the division at 0."""
        shifted = x + 1e-10
        return (self.b_pi_over_c * shifted).sin() / shifted

def compute_edge_vectors_and_norms(
    positions: torch.Tensor,
    edge_index: torch.Tensor,
    edge_shift: torch.Tensor | None = None,
    edge_cell: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    source_positions = positions[edge_index[:, 0]]
    target_positions = positions[edge_index[:, 1]]
    if edge_shift is not None:
        assert edge_cell is not None
        shift = torch.squeeze(edge_shift.unsqueeze(1) @ edge_cell, dim=1)
        target_positions = target_positions + shift
    vectors = target_positions - source_positions
    norms = torch.linalg.norm(vectors, dim=1, keepdim=True)
    return vectors, norms

def sum_index(
    values: torch.Tensor,
    index: torch.Tensor,
    out: torch.Tensor | None = None,
    num_out: int = 0
) -> torch.Tensor:
    assert out is not None or num_out > 0
    if out is None:
        out_shape = torch.Size([num_out]) + values.shape[1:]
        out = torch.zeros(out_shape, dtype=values.dtype, device=values.device)
    out.index_add_(0, index, values)
    return out

###############################################################
# Core PaiNN model code                                       #
###############################################################


class PaiNNInteractionBlock(nn.Module):
    """One interaction layer: an edge-wise message pass followed by a node-wise update.

    Scalar node states have shape ``(num_nodes, node_size)``; vector node
    states carry an extra axis at dim 1 (created with size 3 by ``PaiNN``).
    """

    def __init__(self, node_size: int, edge_size: int, cutoff: float):
        super().__init__()
        self.node_size = node_size
        self.edge_size = edge_size
        self.cutoff = cutoff

        # Each 3 * node_size output is later split into three node_size-wide parts.
        triple = 3 * node_size
        self.edge_filter_net = nn.Linear(edge_size, triple)
        self.scalar_message_net = nn.Sequential(
            nn.Linear(node_size, node_size),
            nn.SiLU(),
            nn.Linear(node_size, triple),
        )
        self.U_net = nn.Linear(node_size, node_size, bias=False)
        self.V_net = nn.Linear(node_size, node_size, bias=False)
        self.a_net = nn.Sequential(
            nn.Linear(2 * node_size, node_size),
            nn.SiLU(),
            nn.Linear(node_size, triple),
        )

    def _message_function(self, node_states_scalar, node_states_vector, edge_states, edge_vectors, edge_norms, edge_index):
        """Return the summed scalar and vector messages arriving at every node."""
        senders = edge_index[:, 0]
        receivers = edge_index[:, 1]

        # Edge filters, damped by the cutoff function of the edge length.
        filters = self.edge_filter_net(edge_states)
        filters = filters * cosine_cutoff(edge_norms, self.cutoff)

        sender_features = self.scalar_message_net(node_states_scalar)[senders]
        gate_nodes, gate_edges, messages_scalar = (filters * sender_features).split(self.node_size, dim=1)

        # Vector message: gated sender vector state plus gated edge direction.
        from_states = node_states_vector[senders] * gate_nodes.unsqueeze(1)
        from_edges = gate_edges.unsqueeze(1) * edge_vectors.unsqueeze(2)
        messages_vector = from_states + from_edges

        scalar_sum = sum_index(messages_scalar, receivers, out=torch.zeros_like(node_states_scalar))
        vector_sum = sum_index(messages_vector, receivers, out=torch.zeros_like(node_states_vector))
        return scalar_sum, vector_sum

    def _node_state_update_function(self, node_states_scalar, node_states_vector):
        """Return the node-local scalar and vector state increments."""
        Uv = self.U_net(node_states_vector)
        Vv = self.V_net(node_states_vector)

        # Squared length of Vv over dim 1, fed to a_net next to the scalar state.
        vv_sq_norm = (Vv ** 2).sum(dim=1)
        coefficients = self.a_net(torch.cat((node_states_scalar, vv_sq_norm), dim=1))
        a_ss, a_sv, a_vv = coefficients.split(self.node_size, dim=1)

        uv_dot_vv = (Uv * Vv).sum(dim=1)
        scalar_update = a_ss + a_sv * uv_dot_vv
        vector_update = a_vv.unsqueeze(1) * Uv
        return scalar_update, vector_update

    def forward(self, node_states_scalar, node_states_vector, edge_states, edge_vectors, edge_norms, edge_index):
        """Apply the message pass and then the update, each as a residual step."""
        msg_scalar, msg_vector = self._message_function(
            node_states_scalar, node_states_vector, edge_states, edge_vectors, edge_norms, edge_index
        )
        node_states_scalar = node_states_scalar + msg_scalar
        node_states_vector = node_states_vector + msg_vector

        upd_scalar, upd_vector = self._node_state_update_function(node_states_scalar, node_states_vector)
        return node_states_scalar + upd_scalar, node_states_vector + upd_vector

class PaiNN(nn.Module):
    """Stack of ``PaiNNInteractionBlock`` layers with an optional per-node
    readout MLP and an optional per-graph reduction."""

    def __init__(self, node_size=64, edge_size=20, num_interaction_blocks=3, cutoff=5.0,
                 pbc=False, use_readout=True, num_readout_layers=2, readout_size=1,
                 readout_reduction="sum"):
        """Build embedding, edge expansion, interaction blocks and readout.

        ``readout_reduction`` is passed to ``reduce_splits`` in ``forward``;
        a falsy value makes ``forward`` return per-node outputs instead.
        """
        super().__init__()
        self.node_size = node_size
        self.edge_size = edge_size
        self.num_interaction_blocks = num_interaction_blocks
        self.cutoff = cutoff
        self.pbc = pbc
        self.use_readout = use_readout
        self.num_readout_layers = num_readout_layers
        self.readout_size = readout_size
        self.readout_reduction = readout_reduction

        # presumably indexed by atomic number (0..118) — TODO confirm
        num_embeddings = 119
        self.node_embedding = nn.Embedding(num_embeddings, node_size)
        self.edge_expansion = BesselExpansion(edge_size, cutoff)
        self.interaction_blocks = nn.ModuleList(
            [PaiNNInteractionBlock(node_size, edge_size, cutoff) for _ in range(num_interaction_blocks)]
        )

        if not self.use_readout:
            self.readout_net = None
        else:
            # (num_readout_layers - 1) hidden Linear+SiLU pairs, then the output layer.
            hidden = []
            for _ in range(num_readout_layers - 1):
                hidden += [nn.Linear(node_size, node_size), nn.SiLU()]
            self.readout_net = nn.Sequential(*hidden, nn.Linear(node_size, readout_size))

    def forward(self, batch):
        """Return reduced per-graph outputs, or per-node outputs when no reduction is set."""
        node_states_scalar = self.node_embedding(batch.node_features.squeeze(-1))
        num_nodes = node_states_scalar.size(0)
        node_states_vector = torch.zeros(
            num_nodes, 3, self.node_size,
            dtype=node_states_scalar.dtype,
            device=node_states_scalar.device,
        )

        cell = getattr(batch, "cell", None)
        edge_shift = getattr(batch, "edge_shift", None)
        if cell is not None and edge_shift is not None:
            # One cell per graph -> one cell per edge, selected by the edge's graph index.
            cell = cell[batch.edge_data_index]
        edge_vectors, edge_norms = compute_edge_vectors_and_norms(
            batch.node_positions, batch.edge_index, edge_shift, cell
        )

        # Normalise edge vectors to unit length; the offset guards zero-length edges.
        edge_vectors = edge_vectors / (edge_norms + 1e-10)
        edge_states = self.edge_expansion(edge_norms)

        for layer in self.interaction_blocks:
            node_states_scalar, node_states_vector = layer(
                node_states_scalar, node_states_vector,
                edge_states, edge_vectors, edge_norms, batch.edge_index,
            )

        if self.use_readout:
            node_states_scalar = self.readout_net(node_states_scalar)

        if not self.readout_reduction:
            return node_states_scalar
        return reduce_splits(node_states_scalar, batch.num_nodes, reduction=self.readout_reduction)



###############################################################
# Loss function and Training utilities #
###############################################################

class MSELoss(torch.nn.Module):
    """Mean-squared-error loss on one target property, optionally mixed with forces."""

    def __init__(self, target_property: str = "", forces: bool = False):
        """An empty ``target_property`` selects ``"energy"``."""
        super().__init__()
        self.target_property = target_property if target_property else "energy"
        self.forces = forces

    def forward(self, preds: dict, batch) -> torch.Tensor:
        """Return the MSE of ``preds[target_property]`` against the batch reference.

        The reference is ``batch.energy`` for ``"energy"`` and
        ``batch.targets`` otherwise. With ``forces`` enabled, the result is
        the equal-weight mean of the property loss and the force loss.
        """
        name = self.target_property
        if name == "energy":
            reference = batch.energy
        else:
            reference = batch.targets
        property_loss = mse_loss(preds[name], reference)

        if not self.forces:
            return property_loss

        force_loss = mse_loss(preds["forces"], batch.forces)
        return 0.5 * property_loss + 0.5 * force_loss

class GaussianNLLLoss(torch.nn.Module):
    """Gaussian negative log-likelihood with a fixed variance on one target
    property, optionally mixed with forces."""

    def __init__(self, target_property: str = "", variance: float = 1.0, forces: bool = False):
        """An empty ``target_property`` selects ``"energy"``."""
        super().__init__()
        self.target_property = target_property if target_property else "energy"
        self.variance = variance
        self.forces = forces

    def forward(self, preds: dict, batch) -> torch.Tensor:
        """Return the mean Gaussian NLL of the prediction against the batch reference.

        The reference is ``batch.energy`` for ``"energy"`` and
        ``batch.targets`` otherwise. With ``forces`` enabled, the result is
        the equal-weight mean of the property loss and the force loss.
        """
        name = self.target_property
        prediction = preds[name]
        if name == "energy":
            reference = batch.energy
        else:
            reference = batch.targets
        # The same constant variance is used for every predicted element.
        property_var = torch.full_like(prediction, self.variance)
        property_loss = gaussian_nll_loss(prediction, reference, var=property_var, reduction='mean')

        if not self.forces:
            return property_loss

        force_prediction = preds["forces"]
        force_var = torch.full_like(force_prediction, self.variance)
        force_loss = gaussian_nll_loss(force_prediction, batch.forces, var=force_var, reduction='mean')
        return 0.5 * property_loss + 0.5 * force_loss


################################################################
# Trainer with SWA, SWAG, SAM, ASAM and Laplace Implementation #
################################################################

class Trainer:
    def __init__(self, model, lr=1e-3, use_sam=False, use_asam=False, sam_rho=0.05,
                 use_laplace=False, num_laplace_samples=10, prior_precision=0.0,
                 use_swa=False, swa_lrs=1e-4, swa_start_percent=0.8, annealing_percent=0.05, 
                 annealing_strategy='cos', 
                 use_swag=False, max_num_models=20, no_cov_mat=True, loss_type="mse", max_steps=10000): 

        self.model = model

        self.use_sam = use_sam
        self.use_asam = use_asam
        self.sam_rho = sam_rho

        self.use_laplace = use_laplace
        self.num_laplace_samples = num_laplace_samples
        self.prior_precision = prior_precision

        self.use_swa = use_swa
        self.use_swag = use_swag
        self.swa_lrs = swa_lrs
        self.swa_start_percent = swa_start_percent
        self.annealing_percent = annealing_percent
        self.annealing_strategy = annealing_strategy
        self.max_num_models = max_num_models
        self.no_cov_mat = no_cov_mat

        if self.use_sam or self.use_asam:
            # Initialize base optimizer (SGD with Momentum)
            self.optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
            mode = "ASAM" if self.use_asam else "SAM"
            print(f"Optimizer set to SGD with Momentum for {mode} (rho={self.sam_rho}.")
        else:
            self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            print("Optimizer set to Adam.")
        
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.9999)

        # Choose loss function
        if loss_type == "mse":
            print("Using MSE Loss.")
            self.loss_function = MSELoss(target_property="energy", forces=False)
        elif loss_type == "nll":
            print("Using Gaussian NLL Loss.")
            self.loss_function = GaussianNLLLoss(target_property="energy", variance=1.0, forces=False)
        else:
            raise ValueError(f"Unknown loss_type: {loss_type}")

        if use_swag:
            self.num_parameters = sum(p.numel() for p in self.model.parameters())
            device = next(self.model.parameters()).device  # Get device from model
            self.mean = torch.zeros(self.num_parameters, device=device)
            self.sq_mean = torch.zeros(self.num_parameters, device=device)
            if not self.no_cov_mat:
                self.cov_mat_sqrt = torch.zeros(self.max_num_models, self.num_parameters, device=device)
            self.num_models_collected = 0

        if use_laplace:
            self._init_laplace()

        if use_sam and use_asam:
            raise ValueError("Cannot use both SAM and ASAM simultaneously.")
        
        if self.use_swa and self.use_swag:
            raise ValueError("Cannot use both SWA and SWAG simultaneously.")

        self.current_step = 0          # Initialize step counter
        self.max_steps = max_steps      # Total number of training steps
        self.swa_start_step = None      # Will be set in set_max_steps
        self.annealing_steps = None     # Will be set in set_max_steps
        self.epoch = 0  # Initialize epoch counter

        # Initialize SWA/SWAG if enabled
        if self.use_swa or self.use_swag:
            self.setup_swa()

    def set_max_steps(self, max_steps):
        """Set the maximum number of training steps and compute SWA parameters."""
        self.max_steps = max_steps
        if self.use_swa or self.use_swag:
            self.swa_start_step = int(self.swa_start_percent * self.max_steps)
            self.annealing_steps = int(self.annealing_percent * self.max_steps)
            print(f"SWA will start at step {self.swa_start_step} and anneal over {self.annealing_steps} steps.")

    def setup_swa(self):
        if self.use_swa:
            self.averaged_model = AveragedModel(self.model)
            self.swa_scheduler = SWALR(self.optimizer, swa_lr=self.swa_lrs)
            print("SWA has been set up.")
        
        if self.use_swag:
            # SWAG setup already handled in __init__
            print("SWAG has been set up.")

    def _init_laplace(self):
        self.accumulated_squared_gradients = [torch.zeros_like(p) for p in self.model.parameters()]
        self.total_batches = 0


    def train_step(self, batch):
        self.model.train()
        # Forward pass
        pred = self.model(batch)
        # Assume batch.energy is the target
        preds_dict = {"energy": pred}  
        loss = self.loss_function(preds_dict, batch)

        if self.use_sam:
            self.optimizer.zero_grad()
            loss.backward()

            # Compute gradient norm
            grad_norm = torch.norm(
                torch.stack([p.grad.detach().norm(2) for p in self.model.parameters() if p.grad is not None])
            )

            # SAM perturbation
            e_ws = []
            with torch.no_grad():
                for p in self.model.parameters():
                    if p.grad is None:
                        e_ws.append(None)
                        continue
                    e_w = p.grad / (grad_norm + 1e-12) * self.sam_rho
                    p.add_(e_w)
                    e_ws.append(e_w)

            # Second forward-backward pass
            self.optimizer.zero_grad()
            pred2 = self.model(batch)
            loss2 = self.loss_function({"energy": pred2}, batch)  # Corrected line
            loss2.backward()

            with torch.no_grad():
                for p, e_w in zip(self.model.parameters(), e_ws):
                    if e_w is not None:
                        p.sub_(e_w)

            self.optimizer.step()

            # Define print frequency (e.g., every 10 steps)
            print_frequency = 10
            if self.current_step % print_frequency == 0:
                # Calculate average perturbation norm
                perturbation_norms = [e_w.norm().item() for e_w in e_ws if e_w is not None]
                avg_perturbation_norm = (
                    sum(perturbation_norms) / len(perturbation_norms) if perturbation_norms else 0.0
                )

                print(f"[SAM] Step {self.current_step}, Gradient Norm: {grad_norm.item():.4f}")
                print(f"[SAM] Step {self.current_step}, Average Perturbation Norm: {avg_perturbation_norm:.6f}")
                print(f"[SAM] Step {self.current_step}, Loss Before SAM: {loss.item():.6f}, Loss After SAM: {loss2.item():.6f}")
                print(f"[SAM] Step {self.current_step} SAM optimization completed.")

            final_loss = loss2

        elif self.use_asam:
            self.optimizer.zero_grad()
            loss.backward()
            param_norms = []
            with torch.no_grad():
                for p in self.model.parameters():
                    if p.grad is None:
                        param_norms.append(None)
                        continue
                    param_norm = torch.norm(p)
                    param_norms.append(param_norm)
                scaled_grads = []
                for p, pn in zip(self.model.parameters(), param_norms):
                    if p.grad is None:
                        continue
                    scaled_grad = p.grad / (pn + 1e-12)
                    scaled_grads.append(scaled_grad.view(-1))
                scaled_grad_norm = torch.norm(torch.cat(scaled_grads))
                epsilon = self.sam_rho / (scaled_grad_norm + 1e-12)

                perturbations = []
                for p, pn in zip(self.model.parameters(), param_norms):
                    if p.grad is None:
                        perturbations.append(None)
                        continue
                    perturbation = epsilon * p.grad / (pn + 1e-12)
                    p.add_(perturbation)
                    perturbations.append(perturbation)

            # Second forward-backward pass
            self.optimizer.zero_grad()
            pred2 = self.model(batch)
            loss2 = self.loss_function({"energy": pred2}, batch)  # Corrected line
            loss2.backward()

            with torch.no_grad():
                for p, perturbation in zip(self.model.parameters(), perturbations):
                    if perturbation is not None:
                        p.sub_(perturbation)

            self.optimizer.step()

            # Define print frequency (e.g., every 10 steps)
            print_frequency = 10
            if self.current_step % print_frequency == 0:
                # Calculate average perturbation norm for ASAM
                perturbation_norms_asam = [perturbation.norm().item() for perturbation in perturbations if perturbation is not None]
                avg_perturbation_norm_asam = (
                    sum(perturbation_norms_asam) / len(perturbation_norms_asam) if perturbation_norms_asam else 0.0
                )

                print(f"[ASAM] Step {self.current_step}, Average Perturbation Norm: {avg_perturbation_norm_asam:.6f}")
                print(f"[ASAM] Step {self.current_step}, Loss Before ASAM: {loss.item():.6f}, Loss After ASAM: {loss2.item():.6f}")
                print(f"[ASAM] Step {self.current_step} ASAM optimization completed.")

            final_loss = loss2
            
        else:
            self.optimizer.zero_grad()
            loss.backward()
            if self.use_laplace:
                # Accumulate grad^2
                with torch.no_grad():
                    for i, p in enumerate(self.model.parameters()):
                        if p.grad is not None:
                            self.accumulated_squared_gradients[i] += p.grad.data.clone() ** 2
                self.total_batches += 1
            self.optimizer.step()
            final_loss = loss

        # Increment step counter
        self.current_step += 1

        # Handle SWA based on steps
        if self.use_swa and self.current_step >= self.swa_start_step:
            # Update SWA parameters
            self.averaged_model.update_parameters(self.model)
            # Annealing the SWA learning rate
            if self.annealing_steps > 0 and self.current_step <= (self.swa_start_step + self.annealing_steps):
                self.swa_scheduler.step()
            elif self.annealing_steps > 0 and self.current_step > (self.swa_start_step + self.annealing_steps):
                # After annealing_steps, set SWA LR to a minimum value or keep it constant
                pass  # You can implement a strategy here if needed

        # Handle SWAG if enabled
        if self.use_swag and self.current_step >= self.swa_start_step:
            self.collect_swag_model()

        return final_loss.item()

    def end_epoch(self):
        """Close out one epoch: step the LR scheduler, then bump the epoch count."""
        scheduler = self.scheduler
        scheduler.step()

        self.epoch += 1

    def collect_swag_model(self):
        """Fold the current model weights into the running SWAG statistics.

        Updates, in place:

        * ``self.mean`` / ``self.sq_mean``: running averages of the flattened
          parameter vector and of its element-wise square.
        * ``self.cov_mat_sqrt`` (skipped when ``self.no_cov_mat`` is set): one
          row per snapshot holding ``params - mean`` (mean taken *after* this
          update). The rows are used as a ring buffer indexed by
          ``n % max_num_models``, so the buffer always holds the most recent
          ``max_num_models`` deviations.
        * ``self.num_models_collected``: incremented by one.

        A message is printed if the parameter vector contains NaNs; the
        snapshot is still folded in.
        """
        # parameters_to_vector returns a tensor that is attached to the autograd
        # graph. Without no_grad the in-place updates below would turn the
        # statistics buffers into non-leaf tensors and extend the graph (and the
        # memory it retains) on every collected snapshot.
        with torch.no_grad():
            param_vector = torch.nn.utils.parameters_to_vector(self.model.parameters())
            if torch.isnan(param_vector).any():
                print("NaNs detected in parameter vector during SWAG collection!")

            n = self.num_models_collected
            if n == 0:
                self.mean.copy_(param_vector)
                self.sq_mean.copy_(param_vector**2)
            else:
                # Incremental mean: m_{n+1} = m_n + (x - m_n) / (n + 1)
                self.mean += (param_vector - self.mean) / (n + 1)
                self.sq_mean += (param_vector**2 - self.sq_mean) / (n + 1)

            if not self.no_cov_mat and self.max_num_models > 0:
                # Overwrite the oldest slot once the buffer is full.
                slot = n % self.max_num_models
                self.cov_mat_sqrt[slot].copy_(param_vector - self.mean)

        self.num_models_collected += 1

    def swag_sample(self, scale=1.0, cov=False):
        """Draw one weight sample from the SWAG statistics and load it into ``self.model``.

        Args:
            scale: Multiplier applied to the random part of the sample
                (``0.0`` yields the SWAG mean).
            cov: If True and ``self.no_cov_mat`` is False, add the low-rank
                term built from the stored deviation rows in
                ``self.cov_mat_sqrt`` on top of the diagonal term.

        Side effects:
            Overwrites the parameters of ``self.model`` with the sample (they
            are not restored here). Prints a message if NaNs show up in the
            mean, the standard deviation or the sample.
        """
        mean = self.mean

        # Diagonal variance E[x^2] - E[x]^2; round-off can push it slightly
        # below zero, so clamp before taking the square root.
        var = torch.clamp(self.sq_mean - mean**2, min=0.0)
        std = torch.sqrt(var + 1e-30)

        if torch.isnan(mean).any() or torch.isnan(std).any():
            print("NaNs detected in SWAG mean or std!")

        z = torch.randn_like(mean)
        if cov and not self.no_cov_mat:
            c = self.cov_mat_sqrt[:min(self.num_models_collected, self.max_num_models)]
            z_cov = torch.randn(c.size(0), device=c.device, dtype=c.dtype)
            # Normalise by (K - 1) with K the number of deviation rows actually
            # used, not the total number of collected models, and never by
            # zero: with a single collected model the old denominator was
            # sqrt(0), which turned the whole sample into NaN.
            low_rank_norm = max(c.size(0) - 1, 1) ** 0.5
            sample = mean + scale * (z * std + c.t().matmul(z_cov) / low_rank_norm)
        else:
            sample = mean + scale * z * std

        if torch.isnan(sample).any():
            print("NaNs detected in SWAG sampled parameters!")

        torch.nn.utils.vector_to_parameters(sample, self.model.parameters())


    def predict(self, batch):
        """Run inference on ``batch`` and return ``(prediction, variance)``.

        The mode is chosen in this order:

        * Laplace (``use_laplace`` is set and ``hessian_diagonal`` exists):
          draws ``num_laplace_samples`` weight samples around ``param_means``
          and returns the mean and the population variance
          (``unbiased=False``) of the resulting predictions, both on the CPU.
          With a single sample the variance is all zeros; with zero samples
          both outputs are NaN tensors shaped like ``batch.energy``.
        * SWAG (``use_swag`` is set and at least one model was collected):
          the prediction of one ``swag_sample`` draw; variance is ``None``.
        * Otherwise: a plain forward pass; variance is ``None``.

        The model is switched to eval mode and left there. The model weights
        are put back to their pre-call values in every mode, also when the
        forward pass raises.
        """
        self.model.eval()
        with torch.no_grad():
            if self.use_laplace and hasattr(self, 'hessian_diagonal'):
                # Clamp Hessian diagonals to avoid extreme values
                for i, var_p in enumerate(self.hessian_diagonal):
                    self.hessian_diagonal[i] = torch.clamp(var_p, min=1e-10)

                # Per-parameter sampling std: 1 / sqrt(h + 1e-6), capped at 1e-3.
                # It does not depend on the draw, so compute it once.
                stds = [
                    torch.clamp((1.0 / (h_diag + 1e-6)) ** 0.5, max=0.001)
                    for h_diag in self.hessian_diagonal
                ]

                params = list(self.model.parameters())
                backup_params = [p.detach().clone() for p in params]
                predictions = []
                try:
                    # Sample parameters multiple times
                    for _ in range(self.num_laplace_samples):
                        for p, mean_p, std in zip(params, self.param_means, stds):
                            p.copy_(mean_p + torch.randn_like(std) * std)
                        predictions.append(self.model(batch).detach().cpu())
                finally:
                    # Restore original parameters, even if a forward pass failed
                    for p, bp in zip(params, backup_params):
                        p.copy_(bp)

                if len(predictions) == 0:
                    # No samples were drawn, return NaNs
                    preds_mean = torch.full_like(batch.energy, float('nan'))
                    preds_var = torch.full_like(batch.energy, float('nan'))
                    return preds_mean, preds_var

                if len(predictions) == 1:
                    # Only one sample, variance is zero
                    preds_mean = predictions[0]
                    preds_var = torch.zeros_like(preds_mean)
                    return preds_mean, preds_var

                # Multiple samples: compute mean and var
                predictions_tensor = torch.stack(predictions)
                preds_mean = predictions_tensor.mean(dim=0)
                preds_var = predictions_tensor.var(dim=0, unbiased=False)
                return preds_mean, preds_var

            elif self.use_swag and self.num_models_collected > 0:
                # swag_sample overwrites the model weights; keep a copy so that
                # predicting does not silently change the trained model
                # (same contract as the Laplace branch above).
                params = list(self.model.parameters())
                backup_params = [p.detach().clone() for p in params]
                try:
                    self.swag_sample(scale=1.0, cov=not self.no_cov_mat)
                    pred = self.model(batch)
                finally:
                    for p, bp in zip(params, backup_params):
                        p.copy_(bp)
                return pred, None
            else:
                pred = self.model(batch)
                return pred, None

    
    def finalize_laplace(self):
        """Turn the accumulated squared gradients into a diagonal Hessian estimate.

        Does nothing unless ``self.use_laplace`` is set. Otherwise it sets

        * ``self.hessian_diagonal``: per parameter,
          ``accumulated_squared_gradients / total_batches + prior_precision``;
        * ``self.param_means``: detached copies of the current model parameters;

        and releases ``self.accumulated_squared_gradients`` (set to ``None``).
        Calling it again after the accumulator has been released is a no-op
        that keeps the existing estimate.
        """
        if not self.use_laplace:
            return
        if self.accumulated_squared_gradients is None:
            # Already finalized: the accumulator is gone, keep the stored result
            # instead of wiping it and failing on iteration over None.
            return

        print("total batches: ", self.total_batches)
        # With zero accumulated batches the sums are all zero; dividing by 0
        # would give NaN, so fall back to the prior precision alone.
        num_batches = max(self.total_batches, 1)
        self.hessian_diagonal = [
            (sq_grad / num_batches) + self.prior_precision
            for sq_grad in self.accumulated_squared_gradients
        ]
        self.param_means = [p.detach().clone() for p in self.model.parameters()]
        self.accumulated_squared_gradients = None

           
    def finalize(self):
        """Run the end-of-training step for every enabled technique.

        * SWA: refreshes batch-norm statistics of the averaged model with
          ``update_bn`` over ``self.train_loader`` and makes the averaged
          network the model used from here on.
        * SWAG: only reports completion.
        * Laplace: delegates to ``finalize_laplace``.

        Raises:
            ValueError: if SWA is enabled but no ``train_loader`` attribute
                has been assigned to the trainer.
        """
        if self.use_swa:
            not_set = object()
            loader = getattr(self, 'train_loader', not_set)
            if loader is not_set:
                raise ValueError("Train loader not set. Please assign train_loader to the trainer.")
            averaged = self.averaged_model
            update_bn(loader, averaged)
            # From now on evaluation goes through the averaged weights
            self.model = averaged.module
            print("SWA has been finalized and batch norms updated.")

        if self.use_swag:
            # Nothing to compute for SWAG at this point
            print("SWAG has been finalized.")

        if self.use_laplace:
            self.finalize_laplace()
            print("Laplace Approximation finalized.")

In [3]:
# Location of the QM9 ASE database file.
# Machine-specific: point this at your own copy of qm9.db before running.
db_path = r"C:\Users\Jonat\Desktop\Datasets\qm9.db"

# Define a function to convert an ASE Atoms object and associated properties to AtomsData
def ase_to_atomsdata(atoms: Atoms, energy_property: str = "energy_U0", cutoff: float = 1.5) -> AtomsData:
    """
    Convert an ASE Atoms object and one target property to an AtomsData object.

    Args:
        atoms: Molecule to convert. The target is looked up in ``atoms.info``,
            first at the top level and then inside the nested
            ``"key_value_pairs"`` and ``"data"`` dictionaries (the layout ASE
            uses when a database row is converted with
            ``add_additional_information=True``).
        energy_property: Name of the property used as the energy target.
        cutoff: Distance threshold (same unit as the positions); every pair of
            distinct atoms closer than this is connected in both directions.

    Returns:
        A validated AtomsData with atomic numbers as (N, 1) node features,
        float32 positions, an (E, 2) ``edge_index`` of (source, destination)
        rows, an identity cell, ``pbc`` all False and no edge shift.

    Warns:
        UserWarning: if the property cannot be found. The energy then falls
        back to 0.0, as before, but no longer silently.
    """
    import warnings

    raw_positions = atoms.get_positions()

    # Atomic numbers as node features, one column per node
    atomic_numbers = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.long).view(-1, 1)
    positions = torch.tensor(raw_positions, dtype=torch.float32)

    # For QM9, no periodicity:
    cell = torch.eye(3, dtype=torch.float32)
    pbc = torch.tensor([False, False, False])

    # Target lookup: top level of atoms.info first, then the nested dicts.
    info = atoms.info
    for source in (info, info.get("key_value_pairs"), info.get("data")):
        if isinstance(source, dict) and energy_property in source:
            energy_value = source[energy_property]
            break
    else:
        warnings.warn(
            f"Property '{energy_property}' not found in atoms.info; "
            "using 0.0 as the energy target."
        )
        energy_value = 0.0
    # reshape(-1) gives shape (1,) for a scalar and also flattens values that
    # are stored as length-1 arrays.
    energy = torch.tensor(energy_value, dtype=torch.float32).reshape(-1)

    # Edges: all unordered pairs (i < j) in row-major order, tested in float64.
    pos64 = torch.as_tensor(raw_positions, dtype=torch.float64)
    num_nodes = pos64.shape[0]
    first, second = torch.triu_indices(num_nodes, num_nodes, offset=1)
    within_cutoff = (pos64[first] - pos64[second]).pow(2).sum(dim=-1).sqrt() < cutoff
    first, second = first[within_cutoff], second[within_cutoff]
    # Interleave (i, j) and (j, i) so each bond appears in both directions,
    # in the same order a nested i < j loop would produce.
    forward = torch.stack((first, second), dim=1)
    backward = torch.stack((second, first), dim=1)
    edge_index = torch.stack((forward, backward), dim=1).reshape(-1, 2).contiguous()

    # For QM9, no edge_shift needed:
    edge_shift = None

    data = AtomsData(
        node_features=atomic_numbers,
        edge_index=edge_index,
        node_positions=positions,
        energy=energy,
        cell=cell,
        pbc=pbc,
        edge_shift=edge_shift
    )
    data.validate()
    return data

# Connect to QM9 database and convert every row to an AtomsData object
with connect(db_path) as db:
    data_list = []
    for row in db.select():
        # row.toatoms() leaves atoms.info empty unless additional information
        # is requested, in which case the row's properties are attached under
        # the nested "key_value_pairs" and "data" entries. ase_to_atomsdata
        # reads the target from atoms.info, so without this every energy would
        # silently fall back to 0.0.
        atoms = row.toatoms(add_additional_information=True)
        # Lift the nested properties to the top level of atoms.info.
        # NOTE(review): assumes the target is stored as a key-value pair or in
        # the row's data dict -- confirm against the database schema.
        for section in ("key_value_pairs", "data"):
            atoms.info.update(atoms.info.get(section, {}))
        data_obj = ase_to_atomsdata(atoms, energy_property="energy_U0")
        data_list.append(data_obj)

In [4]:
# Path to the JSON file with the train/validation/test index lists.
# Machine-specific: point this at your own copy before running.
split_path = r"C:\Users\Jonat\Desktop\Datasets\randomsplits_110k_10k_rest.json"

with open(split_path, "r") as f:
    splits = json.load(f)

train_indices, val_indices, test_indices = (
    splits[part] for part in ("train", "validation", "test")
)

# Pick each subset out of the full dataset (data_list) by index
train_data, val_data, test_data = (
    [data_list[idx] for idx in indices]
    for indices in (train_indices, val_indices, test_indices)
)

# Mini-batch size shared by all three loaders
batch_size = 32

# Only the training loader shuffles; validation and test keep a fixed order
loader_options = dict(batch_size=batch_size, collate_fn=collate_data)
train_loader = DataLoader(train_data, shuffle=True, **loader_options)
val_loader = DataLoader(val_data, shuffle=False, **loader_options)
test_loader = DataLoader(test_data, shuffle=False, **loader_options)

In [None]:
# Network to be trained
model = PaiNN(node_size=64, edge_size=20, num_interaction_blocks=2)

# Trainer options, kept in one place so a run is easy to reconfigure
trainer_options = dict(
    # optimisation
    lr=5e-4,
    loss_type="mse",
    # SWA / SWAG
    use_swa=False,
    use_swag=False,
    swa_lrs=1e-4,
    swa_start_percent=0.7,
    annealing_percent=0.1,
    annealing_strategy="cos",
    max_num_models=50,
    # SAM / ASAM
    use_sam=True,
    use_asam=False,
    sam_rho=2e-3,
    # Laplace approximation
    use_laplace=False,
    num_laplace_samples=5,
    prior_precision=1.0,
    # training budget, counted in optimiser steps
    max_steps=100,
)
trainer = Trainer(model=model, **trainer_options)

# finalize() needs the training data when SWA is enabled
trainer.train_loader = train_loader

# Let the trainer derive its step-based SWA settings from the step budget
trainer.set_max_steps(trainer.max_steps)

# Run validation after this many training steps
validate_every_steps = 10

# Bookkeeping for the training loop
last_validation_step = 0  # step at which validation last ran
stop_training = False     # set once the step budget is used up
epoch = 0                 # first epoch index

# Start training: loop over epochs until the trainer's step budget is used up
for epoch in itertools.count(epoch):
    # Training Phase
    trainer.model.train()
    for batch in train_loader:
        loss_val = trainer.train_step(batch)

        # Train MAE is for monitoring only, so skip building an autograd graph
        with torch.no_grad():
            preds = trainer.model(batch)
            targets = batch.energy  # Assuming 'energy' is the target
            mae = F.l1_loss(preds, targets).item()

        print(f"Epoch {epoch}, Step {trainer.current_step}, Train MAE: {mae:.4f}")

        # Check if it's time to validate
        if trainer.current_step - last_validation_step >= validate_every_steps:
            # Validation Phase
            trainer.model.eval()
            abs_error_sum = 0.0
            num_targets = 0
            with torch.no_grad():
                for val_batch in val_loader:
                    val_preds = trainer.model(val_batch)
                    val_targets = val_batch.energy  # Assuming 'energy' is the target
                    abs_error = (val_preds - val_targets).abs()
                    abs_error_sum += abs_error.sum().item()
                    num_targets += abs_error.numel()
            # Average over all targets rather than over batches, so a smaller
            # final batch does not get the same weight as a full one.
            avg_val_mae = abs_error_sum / num_targets if num_targets else float("nan")
            print(f"Step {trainer.current_step}, Validation MAE: {avg_val_mae:.4f}")
            last_validation_step = trainer.current_step  # Update last validation step
            # Validation switched the model to eval mode; switch back so the
            # remaining steps of this epoch are trained in train mode.
            trainer.model.train()

        # Check if maximum training steps are reached using Trainer's max_steps
        if trainer.current_step >= trainer.max_steps:
            stop_training = True
            break

    # End-of-epoch logic (steps the trainer's scheduler and epoch counter)
    trainer.end_epoch()

    if stop_training:
        print("Reached maximum number of training steps.")
        break


# Finalize SWA/SWAG and Laplace if enabled
trainer.finalize()

# Testing Phase
# finalize() rebinds trainer.model to the averaged network when SWA is enabled,
# so evaluate trainer.model rather than the `model` variable created earlier.
trainer.model.eval()
abs_error_sum = 0.0
num_targets = 0

with torch.no_grad():
    for batch in test_loader:
        test_preds = trainer.model(batch)
        test_targets = batch.energy  # Assuming 'energy' is the target
        abs_error = (test_preds - test_targets).abs()
        abs_error_sum += abs_error.sum().item()
        num_targets += abs_error.numel()
# Average over all targets rather than over batches, so a smaller final batch
# does not get the same weight as a full one.
avg_test_mae = abs_error_sum / num_targets if num_targets else float("nan")
print(f"Test MAE: {avg_test_mae:.4f}")

# Predict with uncertainty using the Laplace approximation
if trainer.use_laplace:
    def get_laplace_predictions(trainer, test_loader, num_samples=10):
        """Estimate predictive mean and variance on the first test batch.

        Calls ``trainer.predict`` ``num_samples`` times. Each call returns the
        mean and the population variance over its own weight samples, so the
        overall variance is assembled with the law of total variance:
        mean of the per-call variances plus the population variance of the
        per-call means. Taking only the variance of the per-call means would
        underestimate the spread of individual predictions.

        Returns:
            ``(preds_mean, preds_var)`` for the first batch of ``test_loader``.
        """
        trainer.model.eval()
        sample_batch = next(iter(test_loader))
        call_means = []
        call_vars = []

        for _ in range(num_samples):
            call_mean, call_var = trainer.predict(sample_batch)
            call_means.append(call_mean)
            # predict() returns None as variance when it falls back to a plain
            # forward pass; treat that as zero within-call variance.
            call_vars.append(torch.zeros_like(call_mean) if call_var is None else call_var)

        call_means = torch.stack(call_means)  # [num_samples, ...]
        call_vars = torch.stack(call_vars)    # [num_samples, ...]
        preds_mean = call_means.mean(dim=0)
        preds_var = call_vars.mean(dim=0) + call_means.var(dim=0, unbiased=False)

        return preds_mean, preds_var

    # Get Laplace ensemble predictions
    laplace_mean, laplace_variance = get_laplace_predictions(trainer, test_loader, num_samples=10)

    print("Laplace Predictions Mean:", laplace_mean)
    print("Laplace Predictions Variance:", laplace_variance)

# Predict with uncertainty using SWAG
if trainer.use_swag:

    def get_swag_predictions(trainer, test_loader, num_samples=10):
        """Return mean and variance over ``num_samples`` predictions on the first test batch.

        Each prediction comes from one ``trainer.predict`` call; the variance
        entry that ``predict`` returns alongside it is ignored.
        """
        trainer.model.eval()
        sample_batch = next(iter(test_loader))

        # One prediction per call, stacked along a new leading dimension
        draws = torch.stack(
            [trainer.predict(sample_batch)[0] for _ in range(num_samples)]
        )

        return draws.mean(dim=0), draws.var(dim=0)

    # Get SWAG ensemble predictions
    swag_mean, swag_variance = get_swag_predictions(trainer, test_loader, num_samples=10)

    print("SWAG Predictions Mean:", swag_mean)
    print("SWAG Predictions Variance:", swag_variance)