In [1]:
from pathlib import Path

import numpy as np
import torch
import torch_geometric as pyg
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from torch_geometric.data import Batch, Data
import functorch
import copy
from ocpmodels.transfer_learning.models.distribution_regression import (
    GaussianKernel,
    KernelMeanEmbeddingRidgeRegression,
    LinearMeanEmbeddingKernel,
    StandardizedOutputRegression,
    median_heuristic,
)

from ocpmodels.transfer_learning.common.utils import (
    ATOMS_TO_GRAPH_KWARGS,
    load_xyz_to_pyg_batch,
    load_xyz_to_pyg_data,
)
from ocpmodels.transfer_learning.loaders import BaseLoader

In [2]:
%cd ../..

/home/isak/life/references/projects/src/python_lang/ocp


# Load model

In [3]:
#%cd ../..
### Load checkpoint
CHECKPOINT_PATH = Path("checkpoints/s2ef_efwt/all/schnet/schnet_all_large.pt")
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")

### Load data
DATA_PATH = Path("data/luigi/example-traj-Fe-N2-111.xyz")
raw_data, data_batch, num_frames, num_atoms = load_xyz_to_pyg_batch(DATA_PATH, ATOMS_TO_GRAPH_KWARGS["schnet"])
raw_data, data_list, num_frames, num_atoms = load_xyz_to_pyg_data(DATA_PATH, ATOMS_TO_GRAPH_KWARGS["schnet"])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

representation_layer = 2
base_loader = BaseLoader(
    checkpoint["config"],
    representation=True,
    representation_kwargs={
        "representation_layer": representation_layer,
    },
)
base_loader.load_checkpoint(CHECKPOINT_PATH, strict_load=False)
model = base_loader.model
model.to(device)
model.mekrr_forces = True
model.regress_forces = False

  cell = torch.Tensor(atoms.get_cell()).view(1, 3, 3)
	Unexpected key(s) in state_dict: "atomic_mass", "interactions.2.mlp.0.weight", "interactions.2.mlp.0.bias", "interactions.2.mlp.2.weight", "interactions.2.mlp.2.bias", "interactions.2.conv.lin1.weight", "interactions.2.conv.lin2.weight", "interactions.2.conv.lin2.bias", "interactions.2.conv.nn.0.weight", "interactions.2.conv.nn.0.bias", "interactions.2.conv.nn.2.weight", "interactions.2.conv.nn.2.bias", "interactions.2.lin.weight", "interactions.2.lin.bias", "interactions.3.mlp.0.weight", "interactions.3.mlp.0.bias", "interactions.3.mlp.2.weight", "interactions.3.mlp.2.bias", "interactions.3.conv.lin1.weight", "interactions.3.conv.lin2.weight", "interactions.3.conv.lin2.bias", "interactions.3.conv.nn.0.weight", "interactions.3.conv.nn.0.bias", "interactions.3.conv.nn.2.weight", "interactions.3.conv.nn.2.bias", "interactions.3.lin.weight", "interactions.3.lin.bias", "interactions.4.mlp.0.weight", "interactions.4.mlp.0.bias", "intera

# Data Loading

In [4]:
def prepare_batch(batch, frames=5):
    data = Batch.from_data_list(data_batch[:frames])
    data.pos.requires_grad = True
    return data

In [5]:
frames = 20
data = prepare_batch(data_batch, frames=frames)

# Precalculate things

In [6]:
h = model(data[0])
d = h.shape[-1]



# Autograd

In [7]:
import copy 

def f(pos, data, model):
    pos_list = []
    batch_idx = data.batch
    batch_unique_idx = torch.unique(batch_idx)
    for uidx in batch_unique_idx:
        pos_list.append(pos[batch_idx == uidx])

    data_list = data.to_data_list()
    for i, pos in enumerate(pos_list):
        data_list[i].pos = pos

    new_batch = Batch.from_data_list(data_list)
    h = model(new_batch)
    return h

def prepare_batch_for_f(batch, frames=None):
    batch = copy.deepcopy(batch)
    if frames is not None:
        batch = prepare_batch(batch, frames)
    batch.pos.requires_grad = True
    pos = batch.pop("pos")
    return batch, pos

## Try autograd interface

In [8]:
no_pos_data, pos = prepare_batch_for_f(data_batch, frames=10)
y = f(
    pos,
    no_pos_data,
    model,
)
m = y.shape[0]
gr = torch.autograd.grad(
    outputs=y,
    inputs=pos,
    grad_outputs=torch.ones_like(y),
    retain_graph=False,
    create_graph=False,
    allow_unused=False,
    is_grads_batched=False,
)[0]
gr



tensor([[  -6.2902,    5.1938,   13.4935],
        [  24.6666,  -12.7259,  -24.2249],
        [  -3.8296,   -5.3556,   -7.1037],
        ...,
        [  12.9996,    2.3244,   14.8525],
        [  29.7450, -108.7377,  244.1773],
        [ -12.3267,   60.5505,  -66.2716]])

## torch.func interface

In [9]:
from torch.func import jvp, vjp, grad, vmap

In [10]:
# function
func = torch.func.functionalize(lambda x: f(x, data, model))

In [11]:
# First vjp
non_pos_data, pos = prepare_batch_for_f(data_batch, frames=5)
y = f(
    pos,
    non_pos_data,
    model,
)
output, vjp_fn = vjp(func, pos)
vjp_fn(torch.ones_like(torch.ones_like(y)))[0]



IndexError: The shape of the mask [940] at index 0 does not match the shape of the indexed tensor [235, 3] at index 0

In [None]:
no_pos_data, pos = prepare_batch_for_f(data_batch, frames=5)
out, jvp_out = jvp(func, (pos,), (torch.randn(*pos.shape),), strict=True)

# Gradient KRR

Now we focus on using the kernel KRR with gradients, so we need to change the function.

In [14]:
gk = GaussianKernel()
gk.sigma = 1.0
gklme = LinearMeanEmbeddingKernel(gk)

In [15]:
def _pre_batch_pos_for_f(pos, no_pos_data):
    pos_list = []
    batch_idx = no_pos_data.batch
    batch_unique_idx = torch.unique(batch_idx)
    for uidx in batch_unique_idx:
        pos_list.append(pos[batch_idx == uidx])

    data_list = no_pos_data.to_data_list()
    for i, pos in enumerate(pos_list):
        data_list[i].pos = pos

    new_batch = Batch.from_data_list(data_list)
    return new_batch

In [16]:
frames = 20
data = prepare_batch(data_batch, frames)
no_pos_data, pos = prepare_batch_for_f(data)

In [17]:
# Linear Operator: Solving the linear system to get the two coefficients

# Initial coefficients
c0 = torch.randn(frames)
c1 = torch.randn(frames, num_atoms, 3)

# Will be passed into the model
f_kernel_kwargs = {
    "pos": pos,
    "data": no_pos_data,
    "kernel": gklme, 
    "model": model,
}

# Updates, these are split up into 4 updates 00, 01, 10, 11 for each of the submatrices
# We want to express them using autograd

In [18]:
# linop00
def f00(pos, data, kernel, model, num_atoms=num_atoms, d=d):
    frames = len(torch.unique(data.batch))
    new_batch = _pre_batch_pos_for_f(pos, data)
    h = model(new_batch).reshape(frames, num_atoms, d)
    h_ = h.clone().detach() # Stop gradients
    k = kernel(h, h_)
    return k

def linop00(c0, **f_kernel_kwargs):
    """f_kernel_kwargs are additional kwargs we pass onto f"""
    with torch.no_grad():
        k = f00(**f_kernel_kwargs)
    return torch.matmul(k, c0)    

c00 = linop00(c0, **f_kernel_kwargs)



In [29]:
# linop01

def f01(pos, data, kernel, model, num_atoms=num_atoms, d=d):
    frames = len(torch.unique(data.batch))
    new_batch = _pre_batch_pos_for_f(pos.reshape(frames*num_atoms, -1), data)
    h = model(new_batch).reshape(frames, num_atoms, d)
    h_ = h.clone().detach() # Stop gradients
    k = kernel(h, h_)
    return k

def linop01(c1, **f_kernel_kwargs):
    """f_kernel_kwargs are additional kwargs we pass onto f"""
    pos = f_kernel_kwargs.pop("pos")
    # Create function
    func = torch.func.functionalize(lambda x: f01(x, c1, **f_kernel_kwargs))
    frames = len(torch.unique(data.batch))
    print(func(pos.reshape(frames, num_atoms, -1)))
    # Get JVP
    out, jvp_out = jvp(func, (pos.reshape(frames, num_atoms, -1),), (c1,), strict=True)
    # NOTE: this is done so that we have G_t c1 on each row,
    # since the linop01 is \sum_t^T G_t c1 we can make this by simpy summing over the correct axis
    # Note that we have nans, I do not know why, but for now we just set it to zero
    c01 = torch.nan_to_num(jvp_out, 0.0)
    return jvp_out

c01 = linop01(c1, **f_kernel_kwargs)

KeyError: 'pos'

In [None]:
#pos, no_pos_data


In [21]:
f_kernel_kwargs = {
    "pos": pos,
    "data": no_pos_data,
    "kernel": gklme, 
    "model": model,
}
pos = f_kernel_kwargs.pop("pos")

## Test
full_h = model(data)
full_h = full_h.reshape(frames, num_atoms, -1).clone().detach()

def f(pos, data, kernel, model, num_atoms=num_atoms, d=d):
    frames = len(torch.unique(data.batch))
    new_batch = _pre_batch_pos_for_f(pos, data)
    h = model(new_batch).reshape(frames, num_atoms, d)
    k = kernel(h, full_h) # 1 x T
    return k



In [27]:
dat = Batch.from_data_list([data_batch[0]])
pos = dat.pop("pos")
f_kernel_kwargs = {
    #"pos": pos,
    "data": dat,
    "kernel": gklme,
    "model": model,
}


# Create function
# func = torch.func.functionalize(lambda x: f(x, **f_kernel_kwargs))
# out, jvp_out = jvp(func, (pos,), (c1.reshape(frames, num_atoms, -1)[0],), strict=True)
new_c1s = []
for t in range(4):
    func = torch.func.functionalize(lambda x: f(x, **f_kernel_kwargs))
    out, jvp_out = jvp(func, (pos,), (c1.reshape(frames, num_atoms, -1)[t],), strict=True)
    jvp_out = torch.nan_to_num(jvp_out)
    new_c1s.append(jvp_out)
    




In [28]:
new_c1s

[tensor([[ 0.0000e+00, -1.6500e-02, -3.7023e-02, -7.3338e-03, -2.8508e-03,
          -2.4859e-03, -3.6808e-03, -4.6446e-03, -1.5159e-03,  5.4215e-04,
           3.0420e-04,  8.6274e-05,  9.2874e-04,  4.0223e-04,  9.4349e-04,
           1.8706e-03,  1.8683e-03,  1.1277e-03,  2.4083e-04,  9.9377e-04]],
        grad_fn=<NanToNumBackward0>),
 tensor([[ 0.0000,  0.0557, -0.0100, -0.0024,  0.0004,  0.0025,  0.0032,  0.0024,
          -0.0003, -0.0004, -0.0002, -0.0003, -0.0007, -0.0005, -0.0003,  0.0004,
           0.0006,  0.0005,  0.0005, -0.0005]], grad_fn=<NanToNumBackward0>),
 tensor([[ 0.0000,  0.0181,  0.0077, -0.0016, -0.0036, -0.0039, -0.0056, -0.0050,
          -0.0012,  0.0008,  0.0005,  0.0005,  0.0005,  0.0009,  0.0012,  0.0018,
           0.0020,  0.0011,  0.0002, -0.0004]], grad_fn=<NanToNumBackward0>),
 tensor([[ 0.0000, -0.0160, -0.0117, -0.0006, -0.0018, -0.0007, -0.0011, -0.0009,
          -0.0017, -0.0013, -0.0006, -0.0004,  0.0002, -0.0004, -0.0005, -0.0007,
          -0

In [None]:
full_h = model(data)
full_h = full_h.reshape(frames, num_atoms, -1).clone().detach()

def f(pos, data, kernel, model, num_atoms=num_atoms, d=d):
    frames = len(torch.unique(data.batch))
    assert(frames == 1)
    new_batch = _pre_batch_pos_for_f(pos, data)
    h = model(new_batch).reshape(frames, num_atoms, d)
    k = kernel(h, full_h)
    return k



# Create function
# func = torch.func.functionalize(lambda x: f(x, **f_kernel_kwargs))
# out, jvp_out = jvp(func, (pos,), (c1.reshape(frames, num_atoms, -1)[0],), strict=True)
jvps = []
for t in range(frames):
    dat = Batch.from_data_list([data_batch[t]])
    pos = dat.pop("pos")
    f_kernel_kwargs = {
        #"pos": pos,
        "data": dat,
        "kernel": gklme,
        "model": model,
    }
    
    func = torch.func.functionalize(lambda x: f(x, **f_kernel_kwargs))
    out, jvp_out = jvp(func, (pos,), (c1.reshape(frames, num_atoms, -1)[t],), strict=True)
    jvps.append(torch.nan_to_num(jvp_out))
jvp

In [None]:
torch.stack(jvps).shape

In [None]:
def linop01(c1, **f_kernel_kwargs):
    """f_kernel_kwargs are additional kwargs we pass onto f"""
    pos = f_kernel_kwargs.pop("pos")
    # Create function
    func = torch.func.functionalize(lambda x: f01(x, **f_kernel_kwargs))
    # Get JVP
    out, jvp_out = jvp(func, (pos,), (c1,), strict=True)
    # Note, this is done so that we have G_t c1 on each row,
    # since the linop01 is \sum_t^T G_t c1 we can make this by simpy summing over the correct axis
    # Note that we have nans, I do not know why, but for now we just set it to zero
    c01 = jvp_out.fillna(0.0)
    return c0


c01 = linop01(c1, **f_kernel_kwargs)

In [None]:
c01