In [1]:
from pathlib import Path

import numpy as np
import torch
import torch_geometric as pyg
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from torch_geometric.data import Batch, Data
import functorch
import copy
from ocpmodels.transfer_learning.models.distribution_regression import (
    GaussianKernel,
    KernelMeanEmbeddingRidgeRegression,
    LinearMeanEmbeddingKernel,
    StandardizedOutputRegression,
    median_heuristic,
)

from ocpmodels.transfer_learning.common.utils import (
    ATOMS_TO_GRAPH_KWARGS,
    load_xyz_to_pyg_batch,
    load_xyz_to_pyg_data,
)
from ocpmodels.transfer_learning.loaders import BaseLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%cd /home/novelli/ocp
### Load checkpoint
CHECKPOINT_PATH = Path("checkpoints/s2ef_efwt/all/schnet/schnet_all_large.pt")
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")

### Load data
DATA_PATH = Path("data/luigi/example-traj-Fe-N2-111.xyz")
raw_data, data_batch, num_frames, num_atoms = load_xyz_to_pyg_batch(DATA_PATH, ATOMS_TO_GRAPH_KWARGS["schnet"])
raw_data, data_list, num_frames, num_atoms = load_xyz_to_pyg_data(DATA_PATH, ATOMS_TO_GRAPH_KWARGS["schnet"])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

representation_layer = 1
base_loader = BaseLoader(
    checkpoint["config"],
    representation=True,
    representation_kwargs={
        "representation_layer": representation_layer,
    },
)
base_loader.load_checkpoint(CHECKPOINT_PATH, strict_load=False)
model = base_loader.model
model.to(device)
model.mekrr_forces = True

/home/novelli/ocp


  cell = torch.Tensor(atoms.get_cell()).view(1, 3, 3)
	Unexpected key(s) in state_dict: "atomic_mass", "interactions.1.mlp.0.weight", "interactions.1.mlp.0.bias", "interactions.1.mlp.2.weight", "interactions.1.mlp.2.bias", "interactions.1.conv.lin1.weight", "interactions.1.conv.lin2.weight", "interactions.1.conv.lin2.bias", "interactions.1.conv.nn.0.weight", "interactions.1.conv.nn.0.bias", "interactions.1.conv.nn.2.weight", "interactions.1.conv.nn.2.bias", "interactions.1.lin.weight", "interactions.1.lin.bias", "interactions.2.mlp.0.weight", "interactions.2.mlp.0.bias", "interactions.2.mlp.2.weight", "interactions.2.mlp.2.bias", "interactions.2.conv.lin1.weight", "interactions.2.conv.lin2.weight", "interactions.2.conv.lin2.bias", "interactions.2.conv.nn.0.weight", "interactions.2.conv.nn.0.bias", "interactions.2.conv.nn.2.weight", "interactions.2.conv.nn.2.bias", "interactions.2.lin.weight", "interactions.2.lin.bias", "interactions.3.mlp.0.weight", "interactions.3.mlp.0.bias", "intera

In [3]:
class LinearKernel:
    """Linear base kernel ``k(x, y) = x @ y.T``.

    For 2-D inputs, entry ``(i, j)`` of the result is the inner product of
    row ``i`` of ``x`` with row ``j`` of ``y``.
    """

    def __call__(self, x, y):
        return x @ y.T

# Mean-embedding kernel built on top of the plain linear base kernel.
lkernel = LinearMeanEmbeddingKernel(LinearKernel())

# Small mini-batch made of the first `frames` frames; gradients are needed with
# respect to its atomic positions. requires_grad_ returns the same tensor, so
# `pos` and `dat["pos"]` remain one object.
frames = 2
dat = Batch.from_data_list(data_batch[:frames]).to(device)
pos = dat["pos"].requires_grad_(True)

In [4]:
# Zero coefficients: one entry per frame (c_0) and a tensor matching `pos`
# in shape, dtype and device (c_1). Neither tracks gradients (the default).
c_0 = torch.zeros(frames, device=device)
c_1 = torch.zeros_like(pos)

In [24]:
def lin_op(c_0, c_1, dat, kernel):
    """Apply a (work-in-progress) kernel linear operator to coefficients (c_0, c_1).

    Args:
        c_0: coefficient vector multiplied by the kernel matrix of the current
            latent variables (in this notebook: shape ``(frames,)``).
        c_1: tangent direction for the JVP with respect to the atomic positions
            (in this notebook: ``zeros_like(dat["pos"])``).
        dat: batched graph data fed to the global ``model``.
        kernel: callable ``kernel(x, y)`` applied to latent tensors reshaped to
            ``(-1, num_atoms, feature_dim)``.

    Returns:
        ``(kernel(z, z) @ c_0, c_1)``, where ``z`` are the detached latent variables.

    Prints the shape of the kernel product and of the JVP (debug output).
    Relies on notebook globals ``model``, ``num_atoms`` and ``device``.

    NOTE(review): the JVP below is computed and printed but never used, and
    ``c_1`` is returned unchanged -- the operator looks unfinished.
    TODO: decide how the JVP term should enter the returned values.
    """
    # Latent variables reshaped to (frames, num_atoms, feature_dim) and cut from
    # the autograd graph so they act as fixed reference points for the kernel.
    latent_vars = model(dat)[0]
    latent_vars = latent_vars.reshape((-1, num_atoms, latent_vars.shape[-1])).clone().detach()
    def model_wrapper(pos, **data):
        # Rebuild a Data object around the given positions so the model output
        # is a function of `pos` alone, as required by autograd.functional.jvp.
        data["pos"] = pos
        data = Data.from_dict(data).to(device)
        return model(data)[0]
    # Work on a deep copy so nothing below touches the caller's `dat`.
    _dat = copy.deepcopy(dat).to_dict()
    pos = _dat.pop("pos")

    c_0 = kernel(latent_vars, latent_vars)@c_0
    print(c_0.shape)
    # Directional derivative of kernel(model(pos), latent_vars) w.r.t. pos along
    # c_1; index [1] selects the JVP ([0] would be the function value).
    jvp =  torch.autograd.functional.jvp(lambda x: kernel(model_wrapper(x, **_dat).reshape(-1, num_atoms, latent_vars.shape[-1]), latent_vars), pos, c_1)[1]
    print(jvp.shape)
    # NOTE(review): no-op -- `jvp` is discarded and c_1 is returned as given.
    c_1 = c_1
    return c_0, c_1

In [25]:
_ = lin_op(c_0=c_0, c_1=c_1, dat=dat, kernel=lkernel)



torch.Size([2])
torch.Size([2, 2])


In [None]:
def f(pos, data, model, kernel):
    """Evaluate ``kernel(h, stop_grad(h))`` with ``h = model(data with pos)[0]``.

    The model runs on a shallow copy of ``data`` whose ``pos`` is replaced by
    the ``pos`` argument, so the result is a genuine function of ``pos``. That
    is what ``torch.func.vjp(lambda x: f(x, ...), pos)`` needs in order to
    differentiate with respect to ``x``.

    Args:
        pos: atomic positions to feed to the model.
        data: graph data object (e.g. a PyG ``Batch``); it is not modified.
        model: callable; ``model(data)[0]`` is used as the representation.
        kernel: callable ``kernel(x, y)``.

    Returns:
        ``kernel(h, h_detached)``: gradients flow only through the first argument.
    """
    # Shallow copy (new attribute storage, same tensors) so that substituting
    # `pos` does not change the caller's `data`. Without this substitution
    # `pos` would be unused and a VJP w.r.t. it would see no dependency.
    data = copy.copy(data)
    data.pos = pos
    h = model(data)[0]
    # Second kernel argument is a detached copy, treated as a constant.
    h_ = h.detach().clone()
    return kernel(h, h_)


y = f(pos, dat, model, lkernel)
m = y.shape[0]

# Gradient of sum(y) w.r.t. the atomic positions (ones as grad_outputs);
# `gr` has the same shape as `pos`.
(gr,) = torch.autograd.grad(
    y,
    pos,
    grad_outputs=torch.ones_like(y),
    retain_graph=False,
    create_graph=False,
    allow_unused=False,
    is_grads_batched=False,
)

# Functional VJP of the same map, for comparison with the autograd gradient.
output, vjp_fn = torch.func.vjp(lambda x: f(x, dat, model, lkernel), pos)