In [75]:
from pathlib import Path

import numpy as np
import torch
import torch_geometric as pyg
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from torch_geometric.data import Batch, Data
import functorch
import copy
from ocpmodels.transfer_learning.models.distribution_regression import (
    GaussianKernel,
    KernelMeanEmbeddingRidgeRegression,
    LinearMeanEmbeddingKernel,
    StandardizedOutputRegression,
    median_heuristic,
)

from ocpmodels.transfer_learning.common.utils import (
    ATOMS_TO_GRAPH_KWARGS,
    load_xyz_to_pyg_batch,
    load_xyz_to_pyg_data,
)
from ocpmodels.transfer_learning.loaders import BaseLoader

In [2]:
# Change two directories up; the printed path below shows this lands in the `ocp`
# repository, which the relative checkpoint/data paths used later assume.
%cd ../..

/home/isak/life/references/projects/src/python_lang/ocp


In [76]:
### Configuration (paths are relative to the repository root)
CHECKPOINT_PATH = Path("checkpoints/s2ef_efwt/all/schnet/schnet_all_large.pt")
DATA_PATH = Path("data/luigi/example-traj-Fe-N2-111.xyz")
representation_layer = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Load checkpoint
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")

### Load data
schnet_graph_kwargs = ATOMS_TO_GRAPH_KWARGS["schnet"]
raw_data, data_batch, num_frames, num_atoms = load_xyz_to_pyg_batch(DATA_PATH, schnet_graph_kwargs)
raw_data, data_list, num_frames, num_atoms = load_xyz_to_pyg_data(DATA_PATH, schnet_graph_kwargs)

### Build the representation model from the checkpoint config
base_loader = BaseLoader(
    checkpoint["config"],
    representation=True,
    representation_kwargs=dict(representation_layer=representation_layer),
)
# strict_load=False: the truncated representation model has fewer interaction
# blocks than the checkpoint, so extra keys are reported (see output below).
base_loader.load_checkpoint(CHECKPOINT_PATH, strict_load=False)
model = base_loader.model.to(device)
model.mekrr_forces = True
model.regress_forces = False

	Unexpected key(s) in state_dict: "atomic_mass", "interactions.2.mlp.0.weight", "interactions.2.mlp.0.bias", "interactions.2.mlp.2.weight", "interactions.2.mlp.2.bias", "interactions.2.conv.lin1.weight", "interactions.2.conv.lin2.weight", "interactions.2.conv.lin2.bias", "interactions.2.conv.nn.0.weight", "interactions.2.conv.nn.0.bias", "interactions.2.conv.nn.2.weight", "interactions.2.conv.nn.2.bias", "interactions.2.lin.weight", "interactions.2.lin.bias", "interactions.3.mlp.0.weight", "interactions.3.mlp.0.bias", "interactions.3.mlp.2.weight", "interactions.3.mlp.2.bias", "interactions.3.conv.lin1.weight", "interactions.3.conv.lin2.weight", "interactions.3.conv.lin2.bias", "interactions.3.conv.nn.0.weight", "interactions.3.conv.nn.0.bias", "interactions.3.conv.nn.2.weight", "interactions.3.conv.nn.2.bias", "interactions.3.lin.weight", "interactions.3.lin.bias", "interactions.4.mlp.0.weight", "interactions.4.mlp.0.bias", "interactions.4.mlp.2.weight", "interactions.4.mlp.2.bias", "

In [4]:
from ocpmodels.common.utils import get_pbc_distances, radius_graph_pbc
#from torch_scatter import scatter, segment_coo, segment_csr
# import torch_scatter

# # For the CSR operator:
# class SumCSR(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, values, groups):
#         ctx.save_for_backward(groups)
#         return torch_scatter.segment_csr(values, groups, reduce="sum")

#     @staticmethod
#     def backward(ctx, grad_output):
#         (groups,) = ctx.saved_tensors
#         return GatherCSR.apply(grad_output, groups), None


# class GatherCSR(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, values, groups):
#         ctx.save_for_backward(groups)
#         return torch_scatter.gather_csr(values, groups)

#     @staticmethod
#     def backward(ctx, grad_output):
#         (groups,) = ctx.saved_tensors
#         return SumCSR.apply(grad_output, groups), None
    
# def segment_csr(values, groups, reduce=None):
#     return SumCSR.apply(values, groups)


# # For the COO operator:
# class SumCOO(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, values, groups, dim_size):
#         ctx.save_for_backward(groups)
#         ctx.dim_size = dim_size
#         return torch_scatter.segment_coo(
#             values, groups, dim_size=dim_size, reduce="sum"
#         )

#     @staticmethod
#     def backward(ctx, grad_output):
#         (groups,) = ctx.saved_tensors
#         return GatherCOO.apply(grad_output, groups, ctx.dim_size), None, None


# class GatherCOO(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, values, groups, dim_size):
#         ctx.save_for_backward(groups)
#         ctx.dim_size = dim_size
#         return torch_scatter.gather_coo(values, groups)

#     @staticmethod
#     def backward(ctx, grad_output):
#         (groups,) = ctx.saved_tensors
#         return SumCOO.apply(grad_output, groups, ctx.dim_size), None, None

# def segment_coo(values, groups, dim_size, reduce=None):
#     return SumCOO.apply(values, groups, dim_size)

def my_radius_graph_pbc(data, radius, max_num_neighbors_threshold, pbc=(True, True, True)):
    """Periodic radius graph for a batch plus the number of edges per structure.

    ``radius_graph_pbc_`` only reports the total edge count of whatever it is
    given, so the per-structure counts are obtained by running it on each
    structure of the batch separately.

    Args:
        data: Batched structures. An object exposing ``to_data_list()`` (a PyG
            ``Batch``) is split per structure; anything without it is treated
            as a single structure.
        radius: Cutoff radius.
        max_num_neighbors_threshold: Forwarded to ``radius_graph_pbc_``.
        pbc: Per-axis periodicity flags. Never mutated; each call below gets
            its own copy because ``radius_graph_pbc_`` may modify the list.

    Returns:
        ``(edge_index, unit_cell, neighbors)`` where ``neighbors`` is a 1-D long
        tensor with one edge count per structure.
    """
    edge_index, unit_cell, total_edges = radius_graph_pbc_(
        data, radius, max_num_neighbors_threshold, list(pbc)
    )
    try:
        structures = data.to_data_list()
    except AttributeError:
        # Not a batch (e.g. a single Data object): the whole input is one structure.
        structures = None

    if structures is None:
        neighbors = [total_edges]
    else:
        neighbors = [
            radius_graph_pbc_(structure, radius, max_num_neighbors_threshold, list(pbc))[2]
            for structure in structures
        ]
    neighbors = torch.stack(neighbors).long()
    return edge_index, unit_cell, neighbors

def radius_graph_pbc_(data, radius, max_num_neighbors_threshold, pbc=(True, True, True)):
    """Build a radius graph under periodic boundary conditions for a batch of structures.

    Meant to match ``ocpmodels.common.utils.radius_graph_pbc`` (compared in the
    cells below), except that the per-atom max-neighbours masking step
    (``get_max_neighbors_mask``) was removed because it caused errors with jvps.

    Args:
        data: Object with ``pos`` (num_atoms_total, 3), ``natoms`` (batch_size,)
            and ``cell`` (batch_size, 3, 3). An optional ``pbc`` attribute of
            boolean flags, shape (3,) or (batch_size, 3), overrides ``pbc``.
            ``data`` itself is not modified.
        radius: Cutoff distance. A pair is kept if 0.01 < distance <= radius.
        max_num_neighbors_threshold: Accepted for API compatibility but ignored.
        pbc: Per-axis periodicity flags used when ``data`` has no ``pbc``.
            Never mutated (a private copy is made).

    Returns:
        edge_index: (2, num_edges) long tensor. Row 0 holds ``index2`` (the atom
            shifted by ``unit_cell``), row 1 holds ``index1``.
        unit_cell: (num_edges, 3) float tensor of integer lattice offsets
            applied to the row-0 atom.
        num_neighbors_image: 0-d long tensor with the *total* number of edges
            over the whole batch (not a per-structure count).

    Raises:
        RuntimeError: if structures in the batch disagree on a periodicity flag.
    """
    device = data.pos.device
    batch_size = len(data.natoms)

    # Private copy: mutating the default (or a caller's list) would leak the
    # periodicity of this batch into every later call.
    pbc = list(pbc)
    if hasattr(data, "pbc"):
        # Local 2-D view only; do not write it back onto ``data``.
        pbc_flags = torch.atleast_2d(data.pbc)
        for i in range(3):
            if not torch.any(pbc_flags[:, i]).item():
                pbc[i] = False
            elif torch.all(pbc_flags[:, i]).item():
                pbc[i] = True
            else:
                raise RuntimeError(
                    "Different structures in the batch have different PBC configurations. This is not currently supported."
                )

    # position of the atoms
    atom_pos = data.pos

    # Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch
    num_atoms_per_image = data.natoms
    num_atoms_per_image_sqr = (num_atoms_per_image**2).long()

    # index offset between images
    index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image

    index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr)
    num_atoms_per_image_expand = torch.repeat_interleave(num_atoms_per_image, num_atoms_per_image_sqr)

    # Compute a tensor containing sequences of numbers that range from 0 to num_atoms_per_image_sqr for each image
    # that is used to compute indices for the pairs of atoms. This is a vectorised (loop-free) version of:
    # for batch_idx in range(batch_size):
    #    batch_count = torch.cat([batch_count, torch.arange(num_atoms_per_image_sqr[batch_idx], device=device)], dim=0)
    num_atom_pairs = torch.sum(num_atoms_per_image_sqr)
    index_sqr_offset = torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr
    index_sqr_offset = torch.repeat_interleave(index_sqr_offset, num_atoms_per_image_sqr)
    atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset

    # Compute the indices for the pairs of atoms (using division and mod)
    # If the systems get too large this approach could run into numerical precision issues
    index1 = (torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor")) + index_offset_expand
    index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand
    # Get the positions for each atom
    pos1 = torch.index_select(atom_pos, 0, index1)
    pos2 = torch.index_select(atom_pos, 0, index2)

    # Calculate required number of unit cells in each direction.
    # Smallest distance between planes separated by a1 is
    # 1 / ||(a2 x a3) / V||_2, since a2 x a3 is the area of the plane.
    # Note that the unit cell volume V = a1 * (a2 x a3) and that
    # (a2 x a3) / V is also the reciprocal primitive vector
    # (crystallographer's definition).

    cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1)
    cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True)

    if pbc[0]:
        inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1)
        rep_a1 = torch.ceil(radius * inv_min_dist_a1)
    else:
        rep_a1 = data.cell.new_zeros(1)

    if pbc[1]:
        cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1)
        inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1)
        rep_a2 = torch.ceil(radius * inv_min_dist_a2)
    else:
        rep_a2 = data.cell.new_zeros(1)

    if pbc[2]:
        cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1)
        inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1)
        rep_a3 = torch.ceil(radius * inv_min_dist_a3)
    else:
        rep_a3 = data.cell.new_zeros(1)

    # Take the max over all images for uniformity. This is essentially padding.
    # Note that this can significantly increase the number of computed distances
    # if the required repetitions are very different between images
    # (which they usually are). Changing this to sparse (scatter) operations
    # might be worth the effort if this function becomes a bottleneck.
    max_rep = [rep_a1.max(), rep_a2.max(), rep_a3.max()]

    # Tensor of unit cells
    cells_per_dim = [torch.arange(-rep, rep + 1, device=device, dtype=torch.float) for rep in max_rep]
    unit_cell = torch.cartesian_prod(*cells_per_dim)
    num_cells = len(unit_cell)
    unit_cell_per_atom = unit_cell.view(1, num_cells, 3).repeat(len(index2), 1, 1)
    unit_cell = torch.transpose(unit_cell, 0, 1)
    unit_cell_batch = unit_cell.view(1, 3, num_cells).expand(batch_size, -1, -1)

    # Compute the x, y, z positional offsets for each cell in each image
    data_cell = torch.transpose(data.cell, 1, 2)
    pbc_offsets = torch.bmm(data_cell, unit_cell_batch)
    pbc_offsets_per_atom = torch.repeat_interleave(pbc_offsets, num_atoms_per_image_sqr, dim=0)

    # Expand the positions and indices over all num_cells periodic images
    pos1 = pos1.view(-1, 3, 1).expand(-1, -1, num_cells)
    pos2 = pos2.view(-1, 3, 1).expand(-1, -1, num_cells)
    index1 = index1.view(-1, 1).repeat(1, num_cells).view(-1)
    index2 = index2.view(-1, 1).repeat(1, num_cells).view(-1)
    # Add the PBC offsets for the second atom
    pos2 = pos2 + pbc_offsets_per_atom

    # Compute the squared distance between atoms
    atom_distance_sqr = torch.sum((pos1 - pos2) ** 2, dim=1)
    atom_distance_sqr = atom_distance_sqr.view(-1)

    # Remove pairs that are too far apart
    mask_within_radius = torch.le(atom_distance_sqr, radius * radius)
    # Remove pairs with the same atoms (distance = 0.0)
    mask_not_same = torch.gt(atom_distance_sqr, 0.0001)
    mask = torch.logical_and(mask_within_radius, mask_not_same)
    index1 = torch.masked_select(index1, mask)
    index2 = torch.masked_select(index2, mask)
    unit_cell = torch.masked_select(unit_cell_per_atom.view(-1, 3), mask.view(-1, 1).expand(-1, 3))
    unit_cell = unit_cell.view(-1, 3)
    atom_distance_sqr = torch.masked_select(atom_distance_sqr, mask)

    # The upstream per-atom neighbour cap (get_max_neighbors_mask) is intentionally
    # not applied: it caused errors with jvps. Consequently
    # ``max_num_neighbors_threshold`` is ignored and only the total edge count of
    # the whole input is returned.
    num_neighbors_image = torch.tensor(len(index1)).long().to(device)

    edge_index = torch.stack((index2, index1))

    return edge_index, unit_cell, num_neighbors_image

# def get_max_neighbors_mask(natoms, index, atom_distance, max_num_neighbors_threshold):
#     """
#     Give a mask that filters out edges so that each atom has at most
#     `max_num_neighbors_threshold` neighbors.
#     Assumes that `index` is sorted.
#     """
#     device = natoms.device
#     num_atoms = natoms.sum()

#     # Get number of neighbors
#     # segment_coo assumes sorted index
#     ones = index.new_ones(1).expand_as(index)
#     num_neighbors = segment_coo(ones, index, dim_size=num_atoms)
#     max_num_neighbors = num_neighbors.max()
#     num_neighbors_thresholded = num_neighbors.clamp(max=max_num_neighbors_threshold)

#     # Get number of (thresholded) neighbors per image
#     image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long)
#     image_indptr[1:] = torch.cumsum(natoms, dim=0)
#     num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr)

#     # If max_num_neighbors is below the threshold, return early
#     if max_num_neighbors <= max_num_neighbors_threshold or max_num_neighbors_threshold <= 0:
#         mask_num_neighbors = torch.tensor([True], dtype=bool, device=device).expand_as(index)
#         return mask_num_neighbors, num_neighbors_image

#     # Create a tensor of size [num_atoms, max_num_neighbors] to sort the distances of the neighbors.
#     # Fill with infinity so we can easily remove unused distances later.
#     distance_sort = torch.full([num_atoms * max_num_neighbors], np.inf, device=device)

#     # Create an index map to map distances from atom_distance to distance_sort
#     # index_sort_map assumes index to be sorted
#     index_neighbor_offset = torch.cumsum(num_neighbors, dim=0) - num_neighbors
#     index_neighbor_offset_expand = torch.repeat_interleave(index_neighbor_offset, num_neighbors)
#     index_sort_map = index * max_num_neighbors + torch.arange(len(index), device=device) - index_neighbor_offset_expand
#     distance_sort.index_copy_(0, index_sort_map, atom_distance)
#     distance_sort = distance_sort.view(num_atoms, max_num_neighbors)

#     # Sort neighboring atoms based on distance
#     distance_sort, index_sort = torch.sort(distance_sort, dim=1)
#     # Select the max_num_neighbors_threshold neighbors that are closest
#     distance_sort = distance_sort[:, :max_num_neighbors_threshold]
#     index_sort = index_sort[:, :max_num_neighbors_threshold]

#     # Offset index_sort so that it indexes into index
#     index_sort = index_sort + index_neighbor_offset.view(-1, 1).expand(-1, max_num_neighbors_threshold)
#     # Remove "unused pairs" with infinite distances
#     mask_finite = torch.isfinite(distance_sort)
#     index_sort = torch.masked_select(index_sort, mask_finite)

#     # At this point index_sort contains the index into index of the
#     # closest max_num_neighbors_threshold neighbors per atom
#     # Create a mask to remove all pairs not in index_sort
#     mask_num_neighbors = torch.zeros(len(index), device=device, dtype=bool)
#     mask_num_neighbors.index_fill_(0, index_sort, True)

#     return mask_num_neighbors, num_neighbors_image

In [77]:
# Collate the first frames into one batch for the graph-construction comparison below.
n_check_frames = 100
data = Batch.from_data_list(data_batch[:n_check_frames])

In [None]:
edge_index, cell_offsets, neighbors = my_radius_graph_pbc(
    data,
    radius=6.0,
    max_num_neighbors_threshold=10000,
)

In [None]:
# Edge geometry for the custom graph.
out = get_pbc_distances(
    data.pos, edge_index, data.cell, cell_offsets, neighbors,
    return_offsets=True, return_distance_vec=True,
)

In [None]:
# Dict of tensors returned by get_pbc_distances for the custom graph.
out

In [None]:
# Per-edge cell offsets and per-structure edge counts from my_radius_graph_pbc.
cell_offsets, neighbors

In [None]:
from ocpmodels.common.utils import get_max_neighbors_mask, get_pbc_distances, radius_graph_pbc

In [None]:
# Use the same cutoff and neighbour threshold as the `my_radius_graph_pbc` call above.
# The custom version ignores the threshold (its max-neighbours masking was removed),
# so a smaller threshold here could make the graphs differ for reasons unrelated
# to the code under test.
edge_index1, cell_offsets1, neighbors1 = radius_graph_pbc(data, 6.0, 10000)

In [None]:
# Edge geometry for the reference (upstream) graph.
out1 = get_pbc_distances(
    data.pos, edge_index1, data.cell, cell_offsets1, neighbors1,
    return_offsets=True, return_distance_vec=True,
)

In [None]:
torch.equal(edge_index, edge_index1)

In [None]:
torch.equal(cell_offsets, cell_offsets1)

In [None]:
torch.equal(neighbors, neighbors1)

In [None]:
trues = []
for key in out.keys():
    trues.append(torch.equal(out[key], out1[key]))
print(all(trues))

In [None]:
"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import logging

import torch
import torch.nn as nn
from torch_geometric.nn import radius_graph

from ocpmodels.common.utils import (
    compute_neighbors,
    conditional_grad,
    get_pbc_distances,
    radius_graph_pbc,
)


class BaseModel(nn.Module):
    """Common base class for the models defined in this notebook.

    Subclasses implement ``forward`` and are expected to define the attributes
    ``cutoff``, ``max_neighbors``, ``use_pbc`` and ``otf_graph``, which
    ``generate_graph`` falls back to when the matching argument is ``None``.
    """

    def __init__(self, num_atoms=None, bond_feat_dim=None, num_targets=None):
        super(BaseModel, self).__init__()
        self.num_atoms = num_atoms
        self.bond_feat_dim = bond_feat_dim
        self.num_targets = num_targets

    def forward(self, data):
        raise NotImplementedError

    def _arg_or_default(self, value, name):
        """Return ``value``, or the model attribute ``name`` when ``value`` is None.

        Only ``None`` selects the default, so explicit falsy values such as
        ``use_pbc=False`` or ``otf_graph=False`` are honoured. The attribute is
        read lazily, i.e. only when it is actually needed.
        """
        return getattr(self, name) if value is None else value

    def generate_graph(
        self,
        data,
        cutoff=None,
        max_neighbors=None,
        use_pbc=None,
        otf_graph=None,
    ):
        """Return the interaction graph and edge geometry for ``data``.

        Each keyword argument defaults to the model attribute of the same name
        when left as ``None``. If ``otf_graph`` is false, ``data.edge_index``
        (and with PBC also ``data.cell_offsets`` / ``data.neighbors``) are
        reused; if any is missing a warning is logged and the graph is built
        on the fly instead (``my_radius_graph_pbc`` with PBC, PyG
        ``radius_graph`` without).

        Returns:
            Tuple ``(edge_index, edge_dist, distance_vec, cell_offsets,
            cell_offset_distances, neighbors)``.
        """
        cutoff = self._arg_or_default(cutoff, "cutoff")
        max_neighbors = self._arg_or_default(max_neighbors, "max_neighbors")
        use_pbc = self._arg_or_default(use_pbc, "use_pbc")
        otf_graph = self._arg_or_default(otf_graph, "otf_graph")

        if not otf_graph:
            try:
                edge_index = data.edge_index

                if use_pbc:
                    cell_offsets = data.cell_offsets
                    neighbors = data.neighbors

            except AttributeError:
                logging.warning("Turning otf_graph=True as required attributes not present in data object")
                otf_graph = True

        if use_pbc:
            if otf_graph:
                # Notebook-local variant of radius_graph_pbc defined in an earlier cell.
                edge_index, cell_offsets, neighbors = my_radius_graph_pbc(data, cutoff, max_neighbors)

            out = get_pbc_distances(
                data.pos,
                edge_index,
                data.cell,
                cell_offsets,
                neighbors,
                return_offsets=True,
                return_distance_vec=True,
            )

            edge_index = out["edge_index"]
            edge_dist = out["distances"]
            cell_offset_distances = out["offsets"]
            distance_vec = out["distance_vec"]
        else:
            if otf_graph:
                edge_index = radius_graph(
                    data.pos,
                    r=cutoff,
                    batch=data.batch,
                    max_num_neighbors=max_neighbors,
                )

            j, i = edge_index
            distance_vec = data.pos[j] - data.pos[i]

            edge_dist = distance_vec.norm(dim=-1)
            # Without PBC every edge lives in the home cell: all offsets are zero.
            cell_offsets = torch.zeros(edge_index.shape[1], 3, device=data.pos.device)
            cell_offset_distances = torch.zeros_like(cell_offsets, device=data.pos.device)
            neighbors = compute_neighbors(data, edge_index)

        return (
            edge_index,
            edge_dist,
            distance_vec,
            cell_offsets,
            cell_offset_distances,
            neighbors,
        )

    @property
    def num_params(self):
        """Total number of elements across all parameters of the model."""
        return sum(p.numel() for p in self.parameters())


In [None]:
"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import torch
from torch_geometric.nn import SchNet
from torch_scatter import scatter

from ocpmodels.common.registry import registry
from ocpmodels.common.utils import (
    conditional_grad,
    get_pbc_distances,
    radius_graph_pbc,
)
#from ocpmodels.models.base import BaseModel


@registry.register_model("schnet")
class SchNetWrap(SchNet, BaseModel):
    r"""Wrapper around the continuous-filter convolutional neural network SchNet from the
    `"SchNet: A Continuous-filter Convolutional Neural Network for Modeling
    Quantum Interactions" <https://arxiv.org/abs/1706.08566>`_. Each layer uses an interaction
    block of the form:

    .. math::
        \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot
        h_{\mathbf{\Theta}} ( \exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu}))),

    Args:
        num_atoms (int): Unused argument
        bond_feat_dim (int): Unused argument
        num_targets (int): Number of targets to predict.
        representation (bool, optional): If set to :obj:`True`, the model outputs the per-atom
            hidden features after the last kept interaction block instead of an energy.
            Only supported together with :obj:`use_pbc=True`. (default: :obj:`False`)
        representation_layer (int, optional): Number of interaction blocks kept when
            :obj:`representation` is :obj:`True`; :obj:`None` keeps all of them.
            (default: :obj:`None`)
        use_pbc (bool, optional): If set to :obj:`True`, account for periodic boundary conditions.
            (default: :obj:`True`)
        regress_forces (bool, optional): If set to :obj:`True`, predict forces by differentiating
            energy with respect to positions.
            (default: :obj:`True`)
        otf_graph (bool, optional): If set to :obj:`True`, compute graph edges on the fly.
            (default: :obj:`False`)
        hidden_channels (int, optional): Number of hidden channels.
            (default: :obj:`6`)
        num_filters (int, optional): Number of filters to use.
            (default: :obj:`12`)
        num_interactions (int, optional): Number of interaction blocks
            (default: :obj:`6`)
        num_gaussians (int, optional): The number of gaussians :math:`\mu`.
            (default: :obj:`5`)
        cutoff (float, optional): Cutoff distance for interatomic interactions.
            (default: :obj:`6.0`)
        readout (string, optional): Whether to apply :obj:`"add"` or
            :obj:`"mean"` global aggregation. (default: :obj:`"add"`)
        mekrr_forces (bool, optional): If set to :obj:`True`, :meth:`forward` never calls
            ``requires_grad_`` on ``data.pos``; when regressing forces the caller must already
            have made ``data.pos`` require gradients. (default: :obj:`False`)
    """

    def __init__(
        self,
        num_atoms,  # not used
        bond_feat_dim,  # not used
        num_targets,
        representation=False,
        representation_layer=None,
        use_pbc=True,
        regress_forces=True,
        otf_graph=False,
        hidden_channels=6,
        num_filters=12,
        num_interactions=6,
        num_gaussians=5,
        cutoff=6.0,
        readout="add",
        mekrr_forces=False,
    ):
        self.num_targets = num_targets
        self.representation = representation
        self.representation_layer = representation_layer
        self.regress_forces = regress_forces
        self.use_pbc = use_pbc
        self.cutoff = cutoff
        self.otf_graph = otf_graph
        # Very large so that the neighbour cap is effectively disabled.
        self.max_neighbors = 10000000
        self.reduce = readout
        self.mekrr_forces = mekrr_forces
        super(SchNetWrap, self).__init__(
            hidden_channels=hidden_channels,
            num_filters=num_filters,
            num_interactions=num_interactions,
            num_gaussians=num_gaussians,
            cutoff=cutoff,
            readout=readout,
        )
        # Added by: Isak Falk
        # If using the model as representation we output the intermediate layer
        if self.representation:
            self.interactions = self.interactions[: self.representation_layer]

    @conditional_grad(torch.enable_grad())
    def _forward(self, data):
        # Added by: Isak Falk
        # Intermediate representations are only produced on the use_pbc=True path;
        # the plain SchNet path below would silently return an energy instead.
        if self.representation and not self.use_pbc:
            raise NotImplementedError(
                "representation=True is only supported with use_pbc=True"
            )

        z = data.atomic_numbers.long()
        pos = data.pos
        batch = data.batch

        (
            edge_index,
            edge_weight,
            distance_vec,
            cell_offsets,
            _,  # cell offset distances
            neighbors,
        ) = self.generate_graph(
            data
        )  # See the BaseModel.generate_graph method for more details

        if self.use_pbc:
            assert z.dim() == 1 and z.dtype == torch.long

            edge_attr = self.distance_expansion(edge_weight)

            h = self.embedding(z)
            # For potential representation output
            for interaction in self.interactions:
                h = h + interaction(h, edge_index, edge_weight, edge_attr)
            if self.representation:
                return h

            h = self.lin1(h)
            h = self.act(h)
            h = self.lin2(h)

            batch = torch.zeros_like(z) if batch is None else batch
            energy = scatter(h, batch, dim=0, reduce=self.reduce)
        else:
            # Cannot use this for representation (guarded above)
            energy = super(SchNetWrap, self).forward(z, pos, batch)
        return energy

    def forward(self, data):
        """Return the energy (or representation), plus forces when ``regress_forces``."""
        # Need to not change the requires_grad of the input for mekrr
        if self.regress_forces and not self.mekrr_forces:
            data.pos.requires_grad_(True)
        energy = self._forward(data)

        if self.regress_forces:
            forces = -1 * (
                torch.autograd.grad(
                    energy,
                    data.pos,
                    grad_outputs=torch.ones_like(energy),
                    create_graph=True,
                )[0]
            )
            return energy, forces
        else:
            return energy

    @property
    def num_params(self):
        """Total number of elements across all parameters of the model."""
        return sum(p.numel() for p in self.parameters())


In [None]:
# NOTE: this replaces the checkpoint-loaded `model` from the loading cell with a
# freshly constructed SchNetWrap (no checkpoint weights are loaded here).
model = SchNetWrap(
    num_targets=1,
    num_atoms=47,
    bond_feat_dim=50,
    representation=True,
    representation_layer=2,
    regress_forces=False,
    # `mekrr_forces` is the attribute forward() checks; `model.mekrr` had no effect.
    mekrr_forces=True,
)

In [84]:
frames = 5  # small batch for a quick forward pass
data = Batch.from_data_list(data_batch[:frames])
data.pos.requires_grad_(True);

In [85]:
h = model(data)
# Feature dimension of the per-atom representation.
d = h.size(-1)



In [92]:
frames = 50
data = Batch.from_data_list(data_batch[:frames])
# Take the positions out of `data` so they can be passed to `f` as a separate input.
pos = data.pop("pos")
pos.requires_grad_(True)

def f(pos, data, model):
    """Run ``model`` on ``data`` with its atom positions replaced by ``pos``.

    ``pos`` is split per structure using ``data.batch`` and re-collated into a
    new batch, so gradients / vjps with respect to ``pos`` flow through the
    model. Used as the function under differentiation in the cells below.

    Args:
        pos: (num_atoms_total, 3) positions, presumably ordered like ``data.batch``.
        data: Batch (typically with ``pos`` popped) providing ``batch`` and
            ``to_data_list``. It is not modified.
        model: Representation model returning one feature row per atom.

    Returns:
        Tensor of shape (num_structures, atoms_per_structure, feature_dim).
        Assumes every structure has the same number of atoms.
    """
    batch_idx = data.batch
    data_list = data.to_data_list()
    # Iterate over graph ids 0..B-1 (not torch.unique) so an empty structure
    # still lines up with its Data object.
    for graph_idx, graph_data in enumerate(data_list):
        graph_data.pos = pos[batch_idx == graph_idx]

    new_batch = Batch.from_data_list(data_list)
    h = model(new_batch)
    # Derive shapes from the inputs rather than the notebook globals
    # `frames` / `num_atoms` / `d`, which go stale when `data` changes.
    return h.reshape(len(data_list), -1, h.shape[-1])

y = f(pos, data, model)
# Reference gradient of sum(y) with respect to pos, via plain autograd.
gr = torch.autograd.grad(
    outputs=y,
    inputs=pos,
    grad_outputs=torch.ones_like(y),
    retain_graph=False,
    create_graph=False,
    allow_unused=False,
    is_grads_batched=False,
)[0]
# Same function through torch.func; vjp_fn with a ones cotangent should reproduce `gr`.
output, vjp_fn = torch.func.vjp(torch.func.functionalize(lambda x: f(x, data, model)), pos)
# NOTE: a forward-mode variant (torch.func.jvp) was tried here and left disabled.



In [93]:
# Primal output of the torch.func.vjp call (the result of `f`).
output

tensor([[[-2.7650e-01,  1.5149e+00, -2.8976e-01,  ...,  1.9644e+00,
          -4.1806e-01, -9.0687e-01],
         [ 1.3620e-01,  1.6507e+00, -4.3054e-01,  ...,  6.2981e-01,
          -9.7201e-02, -1.1375e+00],
         [-9.8533e-02,  1.5343e+00, -2.0628e-01,  ...,  1.5074e+00,
          -6.5693e-01, -8.4682e-01],
         ...,
         [-5.1197e-01,  1.5668e+00, -4.0957e-01,  ...,  2.0028e+00,
          -3.2623e-01, -8.9610e-01],
         [ 2.4639e-01,  6.8669e-01,  4.5762e-01,  ...,  2.2897e+00,
          -4.6831e-01, -1.5964e+00],
         [-4.6865e-01,  5.8663e-01,  3.4989e-01,  ...,  1.5248e+00,
          -2.7791e-01, -1.3841e+00]],

        [[-3.1987e-01,  1.5374e+00, -2.9204e-01,  ...,  1.9518e+00,
          -4.2553e-01, -8.9740e-01],
         [ 6.2852e-02,  1.7195e+00, -3.5327e-01,  ...,  6.3971e-01,
          -1.1011e-01, -1.1836e+00],
         [-1.7572e-01,  1.4655e+00, -1.7963e-01,  ...,  1.5187e+00,
          -6.3084e-01, -8.0103e-01],
         ...,
         [-5.8822e-01,  1

In [97]:
# Ones cotangent, matching the `grad_outputs` used for `gr`; the result should equal `gr`.
vjp_fn(torch.ones_like(output))[0]

tensor([[ -6.2902,   5.1938,  13.4935],
        [ 24.6666, -12.7259, -24.2249],
        [ -3.8296,  -5.3556,  -7.1037],
        ...,
        [  1.1242,   2.9582,  35.0233],
        [-31.1952, -51.6997,  42.6889],
        [  8.8059, 100.3142,  99.7052]], grad_fn=<AddBackward0>)

In [96]:
# Plain-autograd gradient; compare with the vjp result above.
gr

tensor([[ -6.2902,   5.1938,  13.4935],
        [ 24.6666, -12.7259, -24.2249],
        [ -3.8296,  -5.3556,  -7.1037],
        ...,
        [  1.1242,   2.9582,  35.0233],
        [-31.1952, -51.6997,  42.6889],
        [  8.8059, 100.3142,  99.7052]])

In [18]:
# NOTE(review): this cell raises (see the RuntimeError output below): strict-mode
# torch.autograd.functional.jvp reports the output as independent of the input.
# It is also not reproducible top-down: a cell above pops "pos" from `data`, so
# `data.pos` presumably no longer exists on a fresh Run-All. Kept as a record of
# the attempt; TODO fix or remove.
y = f(
    data.pos,
    data,
    model,
)
gr = torch.autograd.grad(
    outputs=y,
    inputs=data.pos,
    grad_outputs=torch.ones_like(y),
    retain_graph=False,
    create_graph=False,
    allow_unused=False,
    is_grads_batched=False,
)[0]


func_output, jvp = torch.autograd.functional.jvp(torch.func.functionalize(lambda x: f(x, data, model=model), remove="mutations_and_views"),
                                                 data.pos,
                                                 v=torch.ones_like(data.pos),
                                                 strict=True)



RuntimeError: The output of the user-provided function is independent of input 0. This is not allowed in strict mode.

In [None]:
def create_batch(pos, data):
    """Return a new ``Batch`` equal to ``data`` but with atom positions ``pos``.

    ``pos`` (num_atoms_total, 3) is split per structure using ``data.batch``
    and assigned to the structures from ``data.to_data_list()``, which are then
    re-collated. ``data`` itself is not modified.
    """
    batch_idx = data.batch
    data_list = data.to_data_list()
    # Iterate over graph ids 0..B-1 so empty structures stay aligned.
    for graph_idx, graph_data in enumerate(data_list):
        graph_data.pos = pos[batch_idx == graph_idx]
    return Batch.from_data_list(data_list)

In [23]:
# Graph id of every atom in the current `data` batch.
batch_idx = data.batch
batch_idx

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4])

In [36]:
frames = 5
data = Batch.from_data_list(data_batch[:frames])
data.pos.requires_grad = True
pos = data.pop("pos")

# Take the graph assignment from *this* batch. Reusing `batch_idx` from an earlier
# cell pairs these positions with a different (stale) batch.
batch_idx = data.batch
batch_unique_idx = torch.unique(batch_idx)
pos_list = [pos[batch_idx == uidx] for uidx in batch_unique_idx]

data_list = data.to_data_list()
# Distinct loop name so the full `pos` tensor is not overwritten by the last chunk.
for i, graph_pos in enumerate(pos_list):
    data_list[i]["pos"] = graph_pos

new_batch = Batch.from_data_list(data_list)

In [38]:
# Graph assignment after re-collating; compare with data.batch below.
new_batch.batch

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4])

In [42]:
# Graph assignment of the original batch, for comparison with new_batch.batch above.
data.batch

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4])

In [43]:
# Positions of the re-collated batch (taken from the popped `pos` tensor).
new_batch.pos

tensor([[-2.0042e+00,  1.1571e+00, -2.5200e+00],
        [-2.0273e+00,  1.1828e+00, -1.9011e-01],
        [-2.0004e+00, -1.1554e+00, -1.8034e+00],
        [-2.0192e+00, -3.4735e+00, -1.0363e+00],
        [ 7.2491e-03, -2.4519e+00, -1.1423e-01],
        [ 1.4912e-04, -4.6227e+00, -1.7621e+00],
        [ 2.0244e+00, -3.4753e+00, -1.0387e+00],
        [ 4.0215e+00, -2.3029e+00, -1.8622e-01],
        [ 4.0186e+00, -4.6295e+00, -1.7866e+00],
        [ 6.0143e+00, -3.4703e+00, -1.0251e+00],
        [ 6.0130e+00, -1.1596e+00, -1.7765e+00],
        [ 6.0117e+00,  1.1477e+00, -1.8844e-01],
        [ 4.0055e+00,  2.3124e+00, -1.7776e+00],
        [ 4.0117e+00,  4.6355e+00, -1.8742e-01],
        [ 6.0139e+00,  3.4695e+00, -1.0276e+00],
        [ 6.0129e+00,  3.4715e+00, -3.3625e+00],
        [ 6.0129e+00,  1.1572e+00, -2.5204e+00],
        [ 4.0069e+00, -1.3589e-03, -1.0381e+00],
        [ 2.0252e+00,  1.1829e+00, -1.8995e-01],
        [ 2.0022e+00, -1.1564e+00, -1.8098e+00],
        [ 2.0047e+00

## LinOp: block linear operator (work in progress, sub-operators are still stubs)

In [None]:
def lin_op00(c0, data, kernel):
    """Placeholder for the (0, 0) block; not implemented yet, returns ``None``."""
    return None


def lin_op01(c1, data, kernel):
    """Placeholder for the (0, 1) block; not implemented yet, returns ``None``."""
    return None


def lin_op10(c0, data, kernel):
    """Placeholder for the (1, 0) block; not implemented yet, returns ``None``."""
    return None


def lin_op11(c1, data, kernel):
    """Placeholder for the (1, 1) block; not implemented yet, returns ``None``."""
    return None

def lin_op(c0, c1, data, kernel):
    """Apply a 2x2 block operator to the pair ``(c0, c1)``.

    Returns ``(A00 c0 + A01 c1, A10 c0 + A11 c1)`` where ``Aij`` is
    ``lin_opij(., data, kernel)``. Every block receives ``data`` and ``kernel``.
    """
    _c0 = lin_op00(c0, data, kernel) + lin_op01(c1, data, kernel)
    _c1 = lin_op10(c0, data, kernel) + lin_op11(c1, data, kernel)
    return _c0, _c1
    