In [None]:
from pathlib import Path

import numpy as np
import torch
import torch_geometric as pyg
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from torch_geometric.data import Batch, Data
import functorch
import copy
from ocpmodels.transfer_learning.models.distribution_regression import (
    GaussianKernel,
    KernelMeanEmbeddingRidgeRegression,
    LinearMeanEmbeddingKernel,
    StandardizedOutputRegression,
    median_heuristic,
)

from ocpmodels.transfer_learning.common.utils import (
    ATOMS_TO_GRAPH_KWARGS,
    load_xyz_to_pyg_batch,
    load_xyz_to_pyg_data,
)
from ocpmodels.transfer_learning.loaders import BaseLoader

In [None]:
# Move to the repository root so the relative paths used below ("checkpoints/...",
# "data/...") resolve. NOTE: not idempotent — re-running moves up two more levels.
%cd ../..

In [None]:
# Paths below are relative to the repository root (see the `%cd ../..` cell).
### Load checkpoint
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
CHECKPOINT_PATH = Path("checkpoints/s2ef_efwt/all/schnet/schnet_all_large.pt")
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")

### Load data
DATA_PATH = Path("data/luigi/example-traj-Fe-N2-111.xyz")
# NOTE(review): the same file is parsed twice; the second call rebinds raw_data,
# num_frames and num_atoms. Only `data_batch` is used later, `data_list` is not.
raw_data, data_batch, num_frames, num_atoms = load_xyz_to_pyg_batch(DATA_PATH, ATOMS_TO_GRAPH_KWARGS["schnet"])
raw_data, data_list, num_frames, num_atoms = load_xyz_to_pyg_data(DATA_PATH, ATOMS_TO_GRAPH_KWARGS["schnet"])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model from the checkpoint config with representation output taken
# from `representation_layer`.
representation_layer = 2
base_loader = BaseLoader(
    checkpoint["config"],
    representation=True,
    representation_kwargs={
        "representation_layer": representation_layer,
    },
)
# strict_load=False presumably tolerates missing/unexpected state-dict keys — confirm in BaseLoader.
base_loader.load_checkpoint(CHECKPOINT_PATH, strict_load=False)
model = base_loader.model
model.to(device)
# NOTE(review): the meaning of these flags is defined by the model class (not visible here).
model.mekrr_forces = True
model.regress_forces = False

In [None]:
class LinearKernel:
    """Plain dot-product kernel: ``k(x, y) = x @ y^T``.

    Stateless; the default ``object`` constructor is sufficient.
    """

    def __call__(self, x, y):
        y_transposed = y.T
        return x @ y_transposed

# Mean-embedding kernel built on top of the plain linear kernel.
lkernel = LinearMeanEmbeddingKernel(LinearKernel())

# Small batch of the first `frames` structures for the gradient experiments below.
frames = 2
dat = Batch.from_data_list(data_batch[:frames]).to(device)
# Track gradients with respect to the atomic positions.
dat.pos.requires_grad = True

In [None]:
def lin_op(c_0, c_1, dat, kernel):
    """Experimental linear operator acting on coefficients ``(c_0, c_1)``.

    Work in progress: currently returns ``(K @ c_0, c_1)``, where ``K`` is
    ``kernel(latent_vars, latent_vars)`` for the model features of ``dat``.
    A jvp of the kernel with respect to the atomic positions along ``c_1`` is
    computed and its shape printed, but the result is not used yet.

    Args:
        c_0: right-hand factor of the Gram-matrix product.
        c_1: tangent for the position jvp; must have the same shape as ``dat.pos``.
        dat: batch of structures whose positions are differentiated.
        kernel: callable applied to (frames, num_atoms, feature_dim) tensors.

    Returns:
        Tuple ``(K @ c_0, c_1)`` with ``c_1`` returned unchanged.
    """
    # NOTE(review): other cells call model(dat) without [0] — confirm what the model returns.
    latent_vars = model(dat)[0]
    # Group features per frame: (frames, num_atoms, feature_dim).
    latent_vars = latent_vars.reshape((-1, num_atoms, latent_vars.shape[-1])).clone()
    def model_wrapper(pos, **data):
        # Rebuild a Data object with `pos` substituted so the output depends on it.
        data["pos"] = pos
        data = Data.from_dict(data).to(device)
        return model(data)[0]
    # Work on a deep copy of the batch, split into positions and the remaining fields.
    _dat = copy.deepcopy(dat).to_dict()
    pos = _dat.pop("pos")

    c_0 = kernel(latent_vars, latent_vars)@c_0
    print(c_0.shape)
    # Forward-mode derivative of kernel(h(x), latent_vars) at x = pos along c_1.
    jvp =  torch.autograd.functional.jvp(lambda x: kernel(model_wrapper(x, **_dat).reshape(-1, num_atoms, latent_vars.shape[-1]), latent_vars), pos, c_1)[1]
    print(jvp.shape)
    # TODO(review): no-op assignment; the jvp above is presumably meant to feed the result.
    c_1 = c_1
    return c_0, c_1

In [None]:
# Inputs for lin_op: `c_0` and `c_1` were never defined before this cell, so it
# raised NameError on a fresh kernel.
# c_1 must match the shape of the stacked positions, because jvp requires the
# primal and the tangent to have the same shape. c_0 holds one coefficient per frame.
# NOTE(review): the (frames, 1) shape of c_0 assumes kernel(latent_vars, latent_vars)
# is a (frames, frames) Gram matrix — confirm.
c_0 = torch.zeros(frames, 1, device=device)
c_1 = torch.zeros_like(dat.pos)
_ = lin_op(c_0, c_1, dat, lkernel)

In [None]:
# Gaussian base kernel wrapped in a linear mean-embedding kernel.
# sigma = 1.0 is only a placeholder: a later cell assigns gklme.kernel.sigma from
# the median heuristic.
gk = GaussianKernel()
gk.sigma = 1.0
gklme = LinearMeanEmbeddingKernel(gk)

In [None]:
kernel = gklme
T, n = frames, num_atoms
# NOTE(review): lin_op indexes model(dat)[0]; here the raw output is used as a
# tensor — confirm which form the model returns.
latent_vars = model(dat)
d = latent_vars.shape[-1]
# Features as (frames, atoms, feature_dim), detached from the autograd graph.
latent_vars = latent_vars.reshape((T, n, d)).clone().detach()
def model_wrapper(pos_tensor, **data):
    """Run ``model`` with positions ``pos_tensor`` of shape (T, n, 3).

    The remaining batch fields arrive through ``**data``. The positions are
    flattened to (T * n, 3) and stored under ``data["pos"]`` before the ``Data``
    object is built, so the model output depends on ``pos_tensor``. This is
    required for jvp/vjp with respect to it.

    Returns:
        Model output reshaped to (T, n, -1).
    """
    T, n, _ = pos_tensor.shape
    # Use the argument, not the module-level `pos`. Otherwise autograd sees no
    # dependence on `pos_tensor`, and jvp(strict=True) raises.
    data["pos"] = pos_tensor.reshape(-1, 3)
    data = Data.from_dict(data).to(device)
    return model(data).reshape(T, n, -1)
# Deep copy of the batch, split into flat positions and the remaining fields.
_dat = copy.deepcopy(dat).to_dict()
pos = _dat.pop("pos")

# Coefficients: one scalar per frame (c0) and one 3-vector per atom per frame (c1).
c0 = torch.zeros(T, 1, device=device)
c1 = torch.zeros(T, n, 3, device=device)

# Kernel bandwidth from the median heuristic on gradient-free features.
with torch.no_grad():
    h = model(dat)
    h = h.reshape(T, n, d)
    sigma = median_heuristic(h, h)
gklme.kernel.sigma = sigma

lmbda = 1e-4
k = kernel(h, h)
# Ridge-regularised Gram matrix K + lambda * I. This is exactly what the former
# (k - diag(diag(k))) + diag(lambda + diag(k)) expression computed.
klmbda = k + lmbda * torch.eye(k.shape[0], device=k.device, dtype=k.dtype)
# First part of the operator: the contribution of the frame coefficients.
a0 = klmbda @ c0
# TODO(review): the original left an unfinished `a0 += ...`; the remaining term is not implemented.

In [None]:
def func(x, **_dat):
    """Evaluate ``model_wrapper`` at positions ``x`` with the remaining batch fields ``_dat``."""
    features = model_wrapper(x, **_dat)
    return features

# Reshape the flat positions to (T, n, 3), the layout model_wrapper unpacks.
pos_tensor = pos.reshape(T, n, 3)
f = lambda x: func(x, **_dat)
y = f(pos_tensor)
# Forward-mode product along an all-ones direction. strict=True makes torch
# raise if the output does not depend on the input.
out, jvp =  torch.autograd.functional.jvp(f, pos_tensor, torch.ones_like(pos_tensor), strict=True)
print(out)
print(jvp)

In [None]:
from torchviz import make_dot

# Render the autograd graph of y with torchviz, labelling the position input.
make_dot(y, params={"pos": pos_tensor})

In [None]:
# `_dat` is a plain dict at this point (it comes from `.to_dict()`) and has no
# `to_data_list`, so split the batch itself into its per-frame Data objects.
dat.to_data_list()

In [None]:
# Raw model output on the batch (inspect its type and shape).
model(dat)

In [None]:
c1 = torch.zeros(T, n, 3, device=device)
y_pred = f(pos_tensor)

# Reverse-mode gradient of y_pred.sum() with respect to the flat positions `pos`.
(grad_pred,) = torch.autograd.grad(
    y_pred,
    pos,
    grad_outputs=torch.ones_like(y_pred),
    create_graph=False,
    allow_unused=False,
)
grad_pred.sum()

In [None]:
# Per-frame jvp of the kernel row k(h_t(x), latent_vars) along direction c1[t].
out_dict = {"out": [], "jvp": []}
for t in range(T):
    # NOTE(review): `_dat`, `pos` and `func` are globals rebound on every
    # iteration; later cells rely on the values left by the last frame.
    _dat = dat[t].to_dict()
    pos = _dat.pop("pos")
    func = lambda x: kernel(model_wrapper(x, **_dat).reshape(1, n, d), latent_vars)
    # model_wrapper unpacks a (T, n, 3) tensor, and jvp needs the primal and the
    # tangent to share a shape. So pass the whole frame as (1, n, 3), not the
    # single-atom row pos[t].
    out, jvp = torch.autograd.functional.jvp(func, pos.reshape(1, n, 3), c1[t].reshape(1, n, 3))
    out_dict["out"].append(out)
    out_dict["jvp"].append(jvp)

out = torch.cat(out_dict["out"])
jvp = torch.cat(out_dict["jvp"])

In [None]:
# Stacked per-frame jvps from the loop above.
print(jvp)

In [None]:
# Single-frame jvp for the last frame processed by the loop above, whose `_dat`
# and `pos` are still bound. `pos` is (n, 3) while model_wrapper expects
# (T, n, 3), and the tangent must match the primal. So reshape the frame and use
# that frame's slice of c1.
jvp = torch.autograd.functional.jvp(
    lambda x: kernel(model_wrapper(x, **_dat).reshape(-1, num_atoms, latent_vars.shape[-1]), latent_vars),
    pos.reshape(1, num_atoms, 3),
    c1[-1].reshape(1, num_atoms, 3),
)[1]

In [None]:
# func feeds model_wrapper, which unpacks a 3-D (T, n, 3) tensor, so pass the
# frame as (1, n, 3) rather than the 2-D (n, 3) positions.
print(pos.shape, func(pos.reshape(1, n, 3)).shape)

In [None]:
def f(pos, data, model, kernel):
    """Return ``kernel(h, stopgrad(h))``, where ``h`` is the model output at ``pos``.

    ``pos`` is written into a clone of ``data`` so that gradients and vjps taken
    with respect to ``pos`` flow through the model. The previous version
    ignored ``pos`` and always used ``data.pos``.
    """
    data = data.clone()
    data.pos = pos
    # NOTE(review): other cells call model(dat) without [0] — confirm which
    # output the model returns in this configuration.
    h = model(data)[0]
    h_ = h.clone().detach()
    return kernel(h, h_)


# Differentiate with respect to the batch's own positions. The global `pos` left
# over from the per-frame loop holds a single frame and is not part of y's graph.
y = f(
    dat.pos,
    dat,
    model,
    lkernel,
)
m = y.shape[0]
gr = torch.autograd.grad(
    outputs=y,
    inputs=dat.pos,
    grad_outputs=torch.ones_like(y),
    retain_graph=False,
    create_graph=False,
    allow_unused=False,
    is_grads_batched=False,
)[0]
assert gr.shape == dat.pos.shape

output, vjp_fn = torch.func.vjp(lambda x: f(x, dat, model, lkernel), dat.pos)

# Temp

In [None]:
# First we check that we can get gradients

In [None]:
def f(pos, data, model, kernel):
    """Return ``kernel(h, stopgrad(h))``, where ``h`` is the model output at ``pos``.

    ``pos`` is written into a clone of ``data`` so the gradient below with
    respect to ``pos`` actually flows through the model.
    """
    data = data.clone()
    data.pos = pos
    # NOTE(review): other cells call model(dat) without [0] — confirm the model's output.
    h = model(data)[0]
    h_ = h.clone().detach()
    return kernel(h, h_)


# Gradient check with respect to the batch's own positions.
y = f(
    dat.pos,
    dat,
    model,
    lkernel,
)
m = y.shape[0]
gr = torch.autograd.grad(
    outputs=y,
    inputs=dat.pos,
    grad_outputs=torch.ones_like(y),
    retain_graph=False,
    create_graph=False,
    allow_unused=False,
    is_grads_batched=False,
)[0]

In [None]:
# Model features for the whole batch (the output differentiated in the Jacobian cell).
phi = model(dat)

In [None]:
def gradient(y, x, grad_outputs=None):
    """Vector-Jacobian product ``(dy/dx)^T @ grad_outputs``.

    Args:
        y: output tensor computed from ``x``.
        x: tensor with ``requires_grad=True``.
        grad_outputs: cotangent shaped like ``y``; defaults to ``ones_like(y)``,
            i.e. the gradient of ``y.sum()``.

    Returns:
        Tensor shaped like ``x``. ``create_graph=True`` keeps the result
        differentiable and retains the graph, so repeated calls work.
    """
    if grad_outputs is None:
        grad_outputs = torch.ones_like(y)
    return torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0]


def jacobian(y, x):
    """Dense Jacobian ``dy/dx`` using one reverse-mode pass per element of ``y``.

    Entry ``i`` (in flattened ``y`` order) is ``gradient(y, x, e_i)`` for the
    one-hot cotangent ``e_i``. The previous version assumed 1-D ``y`` and ``x``
    and failed for multi-dimensional ``x`` such as atomic positions.

    Returns:
        Tensor of shape ``(*y.shape, *x.shape)``, allocated with ``x``'s device
        and dtype. For 1-D ``y`` and ``x`` this is the usual
        ``(len(y), len(x))`` matrix.
    """
    n_out = y.numel()
    jac = torch.zeros(n_out, x.numel(), device=x.device, dtype=x.dtype)
    for i in range(n_out):
        one_hot = torch.zeros(n_out, device=y.device, dtype=y.dtype)
        one_hot[i] = 1
        jac[i] = gradient(y, x, grad_outputs=one_hot.view(y.shape)).reshape(-1)
    return jac.reshape(*y.shape, *x.shape)

In [None]:
# NOTE(review): one reverse pass per element of phi, so the cost scales with phi.numel().
# Confirm that the jacobian definition in scope handles the multi-dimensional
# phi and dat.pos.
J = jacobian(phi, dat.pos)

In [None]:
def f(pos, data, model):
    """Model features at positions ``pos``, reshaped to (frames, num_atoms, -1).

    ``pos`` is written into a clone of ``data`` so the output depends on it.
    The previous version ignored ``pos``, so the vjp below with respect to ``x``
    could not see any dependence on its input.
    """
    data = data.clone()
    data.pos = pos
    return model(data).reshape(frames, num_atoms, -1)


y = f(
    dat.pos,
    dat,
    model,
)
m = y.shape[0]
gr = torch.autograd.grad(
    outputs=y,
    inputs=dat.pos,
    grad_outputs=torch.ones_like(y),
    retain_graph=False,
    create_graph=False,
    allow_unused=False,
    is_grads_batched=False,
)[0]

output, vjp_fn = torch.func.vjp(lambda x: f(x, dat, model), dat.pos)

In [None]:
# NOTE(review): redefinition of gradient/jacobian from an earlier cell — keep the
# two copies in sync or delete one.
def gradient(y, x, grad_outputs=None):
    """Vector-Jacobian product ``(dy/dx)^T @ grad_outputs``.

    Args:
        y: output tensor computed from ``x``.
        x: tensor with ``requires_grad=True``.
        grad_outputs: cotangent shaped like ``y``; defaults to ``ones_like(y)``,
            i.e. the gradient of ``y.sum()``.

    Returns:
        Tensor shaped like ``x``. ``create_graph=True`` keeps the result
        differentiable and retains the graph, so repeated calls work.
    """
    if grad_outputs is None:
        grad_outputs = torch.ones_like(y)
    return torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0]


def jacobian(y, x):
    """Dense Jacobian ``dy/dx`` using one reverse-mode pass per element of ``y``.

    Entry ``i`` (in flattened ``y`` order) is ``gradient(y, x, e_i)`` for the
    one-hot cotangent ``e_i``. Works for multi-dimensional ``x`` such as
    atomic positions.

    Returns:
        Tensor of shape ``(*y.shape, *x.shape)``, allocated with ``x``'s device
        and dtype. For 1-D ``y`` and ``x`` this is the usual
        ``(len(y), len(x))`` matrix.
    """
    n_out = y.numel()
    jac = torch.zeros(n_out, x.numel(), device=x.device, dtype=x.dtype)
    for i in range(n_out):
        one_hot = torch.zeros(n_out, device=y.device, dtype=y.dtype)
        one_hot[i] = 1
        jac[i] = gradient(y, x, grad_outputs=one_hot.view(y.shape)).reshape(-1)
    return jac.reshape(*y.shape, *x.shape)

In [None]:
# Field names of the batch; the explicit-argument `f` below must accept exactly these keys.
dat_dict = dat.to_dict()
print(dat_dict.keys())

In [None]:
# Ordered list of the batch field names.
data_keys = [*dat_dict]

In [None]:
def f(pos, cell, atomic_numbers, natoms, tags, edge_index, cell_offsets, y, force, fixed, batch, ptr, model):
    """Rebuild a ``Data`` object from explicit fields and return the model output.

    Taking every batch field as its own argument lets ``pos`` be passed
    positionally (e.g. to ``torch.autograd.functional.jvp``) while the other
    fields are bound by keyword.
    """
    fields = dict(
        pos=pos,
        cell=cell,
        atomic_numbers=atomic_numbers,
        natoms=natoms,
        tags=tags,
        edge_index=edge_index,
        cell_offsets=cell_offsets,
        y=y,
        force=force,
        fixed=fixed,
        batch=batch,
        ptr=ptr,
    )
    return model(Data.from_dict(fields))

# Split the batch into positions (the variable we differentiate) and the
# remaining fields, which f takes as keyword arguments.
dat_dict = dat.to_dict()
pos = dat_dict.pop("pos")
y = f(pos=pos, **dat_dict, model=model)

# Reverse mode: gradient of y.sum() with respect to the positions.
gr = torch.autograd.grad(
    outputs=y,
    inputs=pos,
    grad_outputs=torch.ones_like(y),
    retain_graph=False,
    create_graph=False,
    allow_unused=False,
    is_grads_batched=False,
)[0]

# Forward mode: directional derivative of f along an all-ones tangent.
# strict=True makes torch raise if the output does not depend on `pos`.
func_output, jvp = torch.autograd.functional.jvp(lambda x: f(x, **dat_dict, model=model),
                                                 pos,
                                                 v=torch.ones_like(pos),
                                                 strict=True)

In [None]:
# Tangent output of the jvp above.
jvp

In [None]:
def f(pos, data, model):
    """Flattened model output, with ``pos`` (holding N*3 values) used as the positions.

    ``pos`` is reshaped to (-1, 3) and written into a clone of ``data``, so the
    output depends on ``pos`` through autograd. The previous definition was a
    syntax error (``def f(pos, , model)``) and referenced an undefined ``data``.
    """
    data = data.clone()
    data.pos = pos.reshape(-1, 3)
    h = model(data).reshape(-1)
    return h


# Differentiate with respect to a flat leaf tensor. A `.reshape(-1)` view made
# after the forward pass is not part of y's graph, so autograd cannot use it as
# an input.
pos_flat = dat.pos.detach().reshape(-1).clone().requires_grad_(True)
y = f(
    pos_flat,
    dat,
    model,
)
jac = jacobian(y, pos_flat)

In [None]:
from ocpmodels.common.utils import get_max_neighbors_mask, get_pbc_distances, radius_graph_pbc

In [None]:
#from torch_scatter import scatter, segment_coo, segment_csr
import torch
import torch_scatter

# For the CSR operator:
class SumCSR(torch.autograd.Function):
    """CSR segment sum whose backward is itself a custom Function (``GatherCSR``).

    Because the backward is built from another differentiable Function, the
    backward pass can be differentiated again (double backward).
    """

    @staticmethod
    def forward(ctx, values, groups):
        ctx.save_for_backward(groups)
        summed = torch_scatter.segment_csr(values, groups, reduce="sum")
        return summed

    @staticmethod
    def backward(ctx, grad_output):
        groups = ctx.saved_tensors[0]
        grad_values = GatherCSR.apply(grad_output, groups)
        # The CSR pointer tensor receives no gradient.
        return grad_values, None


class GatherCSR(torch.autograd.Function):
    """CSR gather (broadcast each segment value back to its members); adjoint of ``SumCSR``."""

    @staticmethod
    def forward(ctx, values, groups):
        ctx.save_for_backward(groups)
        gathered = torch_scatter.gather_csr(values, groups)
        return gathered

    @staticmethod
    def backward(ctx, grad_output):
        groups = ctx.saved_tensors[0]
        grad_values = SumCSR.apply(grad_output, groups)
        return grad_values, None


# For the COO operator:
class SumCOO(torch.autograd.Function):
    """Sorted-COO segment sum whose backward is ``GatherCOO`` (double-backward capable)."""

    @staticmethod
    def forward(ctx, values, groups, dim_size):
        ctx.dim_size = dim_size
        ctx.save_for_backward(groups)
        reduced = torch_scatter.segment_coo(
            values, groups, dim_size=dim_size, reduce="sum"
        )
        return reduced

    @staticmethod
    def backward(ctx, grad_output):
        groups = ctx.saved_tensors[0]
        grad_values = GatherCOO.apply(grad_output, groups, ctx.dim_size)
        # No gradients for the index tensor or the size argument.
        return grad_values, None, None


class GatherCOO(torch.autograd.Function):
    """Sorted-COO gather; adjoint of ``SumCOO``.

    ``dim_size`` is unused by the gather itself and is stored only so the
    backward ``SumCOO`` can size its output.
    """

    @staticmethod
    def forward(ctx, values, groups, dim_size):
        ctx.dim_size = dim_size
        ctx.save_for_backward(groups)
        gathered = torch_scatter.gather_coo(values, groups)
        return gathered

    @staticmethod
    def backward(ctx, grad_output):
        groups = ctx.saved_tensors[0]
        grad_values = SumCOO.apply(grad_output, groups, ctx.dim_size)
        return grad_values, None, None

def radius_graph_pbc(data, radius, max_num_neighbors_threshold, pbc=(True, True, True)):
    """Build a radius graph over a batch of (optionally periodic) structures.

    This local copy removes the max-neighbour filtering because it caused errors
    with jvps. ``max_num_neighbors_threshold`` is therefore accepted for
    interface compatibility but ignored: every pair within ``radius`` is kept.

    Args:
        data: batch exposing ``pos`` (total_atoms, 3), ``natoms`` (batch_size,),
            ``cell`` (batch_size, 3, 3) and optionally ``pbc``. If ``data.pbc``
            exists, it is promoted to 2-D in place.
        radius: cutoff distance.
        max_num_neighbors_threshold: unused (see above).
        pbc: default periodicity per cell axis; overridden by ``data.pbc`` when
            present. It is copied before use, so neither the caller's sequence
            nor the default argument is ever mutated.

    Returns:
        edge_index: (2, num_edges) tensor with rows ``(index2, index1)``.
        unit_cell: (num_edges, 3) float tensor of integer cell offsets applied
            to the second atom of each edge.
        num_neighbors_image: (batch_size,) number of edges per structure. This
            is one count per image, presumably as consumed by
            ``get_pbc_distances`` (the previous version returned a single total).

    Raises:
        RuntimeError: if structures in the batch disagree on periodicity.
    """
    device = data.pos.device
    batch_size = len(data.natoms)
    # Work on a copy. The loop below overwrites entries, and mutating the old
    # list default leaked periodicity settings into later calls.
    pbc = list(pbc)

    if hasattr(data, "pbc"):
        data.pbc = torch.atleast_2d(data.pbc)
        for i in range(3):
            if not torch.any(data.pbc[:, i]).item():
                pbc[i] = False
            elif torch.all(data.pbc[:, i]).item():
                pbc[i] = True
            else:
                raise RuntimeError(
                    "Different structures in the batch have different PBC configurations. This is not currently supported."
                )

    # position of the atoms
    atom_pos = data.pos

    # Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch
    num_atoms_per_image = data.natoms
    num_atoms_per_image_sqr = (num_atoms_per_image**2).long()

    # index offset between images
    index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image

    index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr)
    num_atoms_per_image_expand = torch.repeat_interleave(num_atoms_per_image, num_atoms_per_image_sqr)

    # Compute a tensor containing sequences of numbers that range from 0 to num_atoms_per_image_sqr for each image
    # that is used to compute indices for the pairs of atoms (vectorised form of a per-image arange loop).
    num_atom_pairs = torch.sum(num_atoms_per_image_sqr)
    index_sqr_offset = torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr
    index_sqr_offset = torch.repeat_interleave(index_sqr_offset, num_atoms_per_image_sqr)
    atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset

    # Compute the indices for the pairs of atoms (using division and mod)
    # If the systems get too large this approach could run into numerical precision issues
    index1 = (torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor")) + index_offset_expand
    index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand
    # Get the positions for each atom
    pos1 = torch.index_select(atom_pos, 0, index1)
    pos2 = torch.index_select(atom_pos, 0, index2)

    # Calculate required number of unit cells in each direction.
    # Smallest distance between planes separated by a1 is
    # 1 / ||(a2 x a3) / V||_2, since a2 x a3 is the area of the plane.
    # Note that the unit cell volume V = a1 * (a2 x a3) and that
    # (a2 x a3) / V is also the reciprocal primitive vector
    # (crystallographer's definition).
    cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1)
    cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True)

    if pbc[0]:
        inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1)
        rep_a1 = torch.ceil(radius * inv_min_dist_a1)
    else:
        rep_a1 = data.cell.new_zeros(1)

    if pbc[1]:
        cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1)
        inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1)
        rep_a2 = torch.ceil(radius * inv_min_dist_a2)
    else:
        rep_a2 = data.cell.new_zeros(1)

    if pbc[2]:
        cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1)
        inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1)
        rep_a3 = torch.ceil(radius * inv_min_dist_a3)
    else:
        rep_a3 = data.cell.new_zeros(1)

    # Take the max over all images for uniformity (padding). This can increase
    # the number of computed distances when images need very different repetitions.
    max_rep = [rep_a1.max(), rep_a2.max(), rep_a3.max()]

    # Tensor of unit cells
    cells_per_dim = [torch.arange(-rep, rep + 1, device=device, dtype=torch.float) for rep in max_rep]
    unit_cell = torch.cartesian_prod(*cells_per_dim)
    num_cells = len(unit_cell)
    unit_cell_per_atom = unit_cell.view(1, num_cells, 3).repeat(len(index2), 1, 1)
    unit_cell = torch.transpose(unit_cell, 0, 1)
    unit_cell_batch = unit_cell.view(1, 3, num_cells).expand(batch_size, -1, -1)

    # Compute the x, y, z positional offsets for each cell in each image
    data_cell = torch.transpose(data.cell, 1, 2)
    pbc_offsets = torch.bmm(data_cell, unit_cell_batch)
    pbc_offsets_per_atom = torch.repeat_interleave(pbc_offsets, num_atoms_per_image_sqr, dim=0)

    # Expand the positions and indices for all the cells
    pos1 = pos1.view(-1, 3, 1).expand(-1, -1, num_cells)
    pos2 = pos2.view(-1, 3, 1).expand(-1, -1, num_cells)
    index1 = index1.view(-1, 1).repeat(1, num_cells).view(-1)
    index2 = index2.view(-1, 1).repeat(1, num_cells).view(-1)
    # Add the PBC offsets for the second atom
    pos2 = pos2 + pbc_offsets_per_atom

    # Compute the squared distance between atoms
    atom_distance_sqr = torch.sum((pos1 - pos2) ** 2, dim=1)
    atom_distance_sqr = atom_distance_sqr.view(-1)

    # Remove pairs that are too far apart
    mask_within_radius = torch.le(atom_distance_sqr, radius * radius)
    # Remove pairs with the same atoms (distance = 0.0)
    mask_not_same = torch.gt(atom_distance_sqr, 0.0001)
    mask = torch.logical_and(mask_within_radius, mask_not_same)
    index1 = torch.masked_select(index1, mask)
    index2 = torch.masked_select(index2, mask)
    unit_cell = torch.masked_select(unit_cell_per_atom.view(-1, 3), mask.view(-1, 1).expand(-1, 3))
    unit_cell = unit_cell.view(-1, 3)

    # The max-neighbour filter (get_max_neighbors_mask) is deliberately skipped
    # because it caused errors with jvps. Instead, count the surviving edges per
    # structure: index1 holds global atom indices, so map each one to its image.
    image_of_atom = torch.repeat_interleave(
        torch.arange(batch_size, device=device), num_atoms_per_image.long()
    )
    num_neighbors_image = torch.bincount(image_of_atom[index1], minlength=batch_size)

    edge_index = torch.stack((index2, index1))

    return edge_index, unit_cell, num_neighbors_image

def get_max_neighbors_mask(natoms, index, atom_distance, max_num_neighbors_threshold):
    """
    Give a mask that filters out edges so that each atom has at most
    `max_num_neighbors_threshold` neighbors (the closest ones are kept).
    Assumes that `index` is sorted.

    Args:
        natoms: number of atoms per structure.
        index: sorted per-edge atom index used to group edges by atom.
        atom_distance: per-edge distance used to rank neighbours.
        max_num_neighbors_threshold: maximum number of neighbours kept per atom;
            a value <= 0 disables the filter.

    Returns:
        mask_num_neighbors: boolean mask over edges.
        num_neighbors_image: per-structure count of (thresholded) neighbours.
    """
    device = natoms.device
    num_atoms = natoms.sum()

    # Get number of neighbors
    # segment_coo assumes sorted index. Call through the `torch_scatter` module:
    # the bare `segment_coo` / `segment_csr` names are not imported in this notebook.
    ones = index.new_ones(1).expand_as(index)
    num_neighbors = torch_scatter.segment_coo(ones, index, dim_size=num_atoms)
    max_num_neighbors = num_neighbors.max()
    num_neighbors_thresholded = num_neighbors.clamp(max=max_num_neighbors_threshold)

    # Get number of (thresholded) neighbors per image
    image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long)
    image_indptr[1:] = torch.cumsum(natoms, dim=0)
    num_neighbors_image = torch_scatter.segment_csr(num_neighbors_thresholded, image_indptr)

    # If max_num_neighbors is below the threshold, return early
    if max_num_neighbors <= max_num_neighbors_threshold or max_num_neighbors_threshold <= 0:
        mask_num_neighbors = torch.tensor([True], dtype=bool, device=device).expand_as(index)
        return mask_num_neighbors, num_neighbors_image

    # Create a tensor of size [num_atoms, max_num_neighbors] to sort the distances of the neighbors.
    # Fill with infinity so we can easily remove unused distances later.
    distance_sort = torch.full([num_atoms * max_num_neighbors], np.inf, device=device)

    # Create an index map to map distances from atom_distance to distance_sort
    # index_sort_map assumes index to be sorted
    index_neighbor_offset = torch.cumsum(num_neighbors, dim=0) - num_neighbors
    index_neighbor_offset_expand = torch.repeat_interleave(index_neighbor_offset, num_neighbors)
    index_sort_map = index * max_num_neighbors + torch.arange(len(index), device=device) - index_neighbor_offset_expand
    distance_sort.index_copy_(0, index_sort_map, atom_distance)
    distance_sort = distance_sort.view(num_atoms, max_num_neighbors)

    # Sort neighboring atoms based on distance
    distance_sort, index_sort = torch.sort(distance_sort, dim=1)
    # Select the max_num_neighbors_threshold neighbors that are closest
    distance_sort = distance_sort[:, :max_num_neighbors_threshold]
    index_sort = index_sort[:, :max_num_neighbors_threshold]

    # Offset index_sort so that it indexes into index
    index_sort = index_sort + index_neighbor_offset.view(-1, 1).expand(-1, max_num_neighbors_threshold)
    # Remove "unused pairs" with infinite distances
    mask_finite = torch.isfinite(distance_sort)
    index_sort = torch.masked_select(index_sort, mask_finite)

    # At this point index_sort contains the index into index of the
    # closest max_num_neighbors_threshold neighbors per atom
    # Create a mask to remove all pairs not in index_sort
    mask_num_neighbors = torch.zeros(len(index), device=device, dtype=bool)
    mask_num_neighbors.index_fill_(0, index_sort, True)

    return mask_num_neighbors, num_neighbors_image

In [None]:
# Larger batch (up to the first 100 frames) for comparing the local radius_graph_pbc
# with the library implementation.
data = Batch.from_data_list(data_batch[:100])

In [None]:
# Radius graph from the local copy defined above (cutoff 6.0). Its max-neighbour
# masking is disabled, so the 10000 threshold has no effect.
edge_index, cell_offsets, neighbors = radius_graph_pbc(data, 6.0, 10000)

In [None]:
# Edge distances, offsets and distance vectors for the local graph.
out = get_pbc_distances(
    data.pos,
    edge_index,
    data.cell,
    cell_offsets,
    neighbors,
    return_offsets=True,
    return_distance_vec=True,
)

In [None]:
# Inspect the distance results for the local graph.
out

In [None]:
# Cell offsets and neighbour counts returned by the local radius_graph_pbc.
cell_offsets, neighbors

In [None]:
from ocpmodels.common.utils import get_max_neighbors_mask, get_pbc_distances, radius_graph_pbc

In [None]:
# Same graph from the library radius_graph_pbc (re-imported above, shadowing the
# local copy), presumably capped at 50 neighbours per atom.
edge_index1, cell_offsets1, neighbors1 = radius_graph_pbc(data, 6.0, 50)

In [None]:
# Edge distances, offsets and distance vectors for the library graph.
out1 = get_pbc_distances(
    data.pos,
    edge_index1,
    data.cell,
    cell_offsets1,
    neighbors1,
    return_offsets=True,
    return_distance_vec=True,
)

In [None]:
# Inspect the distance results for the library graph.
out1

In [None]:
# True only if both pipelines return identical edge lists (same shape, values and order).
torch.equal(out["edge_index"], out1["edge_index"])

In [None]:
# Compare the local and library pipelines key by key. The previous version
# compared each value with itself, which is trivially True for NaN-free tensors.
for key, val in out.items():
    print(key)
    print(torch.equal(val, out1[key]))