In [2]:
import wandb
import sys
import matplotlib.pyplot as plt
import scprep
import pandas as pd
sys.path.append('../src/')
from model import AEDist
from omegaconf import OmegaConf
import os
import glob
import numpy as np
import torch
from torch import nn
from torch.autograd import grad
adjoint = False
if adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint
import torch.optim as optim
from torch.autograd.functional import jacobian

def compute_jacobian_function(f, x, create_graph=True, retain_graph=True):
    """
    Compute the Jacobian of the decoder wrt a batch of points in the latent space using an efficient broadcasting approach.
    :param model: The VAE model.
    :param z_batch: A batch of points in the latent space (tensor).
    :return: A batch of Jacobian matrices.
    """
    # z_batch = z_batch.clone().detach().requires_grad_(True)
    x = x.clone()
    x.requires_grad_(True)
    # model.no_grad()
    output = f(x)
    batch_size, output_dim, latent_dim = *output.shape, x.shape[-1]

    # Use autograd's grad function to get gradients for each output dimension
    jacobian = torch.zeros(batch_size, output_dim, latent_dim).to(x.device)
    for i in range(output_dim):
        grad_outputs = torch.zeros(batch_size, output_dim).to(x.device)
        grad_outputs[:, i] = 1.0
        gradients = grad(outputs=output, inputs=x, grad_outputs=grad_outputs, create_graph=create_graph, retain_graph=retain_graph, only_inputs=True)[0]
        jacobian[:, i, :] = gradients
    return jacobian

def pullback_metric(x, fcn, create_graph=True, retain_graph=True):
    jac = compute_jacobian_function(fcn, x, create_graph, retain_graph)
    metric = torch.einsum('Nki,Nkj->Nij', jac, jac)
    return metric

# def pullback_metric2(x, fcn):
#     jac = compute_jacobian_function(fcn, x)
#     metric = torch.einsum('Nki,Nkj->Nij', jac, jac)
#     return metric


In [3]:
wandb.login()
api = wandb.Api()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mxingzhis[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
entity = "xingzhis"
project = "dmae"
run_id = 'iio2bb24'
run = api.run(f"{entity}/{project}/{run_id}")
folder_path = '../src/wandb/'
cfg = OmegaConf.create(run.config)
folder_list = glob.glob(f"{folder_path}*{run.id}*")
ckpt_files = glob.glob(f"{folder_list[0]}/files/*.ckpt")
ckpt_path = ckpt_files[0]
data_path2 = os.path.join(cfg.data.root, cfg.data.name + cfg.data.filetype)
data = np.load(data_path2, allow_pickle=True)
model = AEDist.load_from_checkpoint(ckpt_path)
x_tensor = torch.tensor(data['data'], dtype=torch.float32, device=model.device)
x_tensor_normalized = model.normalize(x_tensor)

  rank_zero_warn(


In [5]:
pullback_metric(x_tensor_normalized, model.encoder).shape

torch.Size([3000, 3, 3])

In [6]:
class ODEFunc(nn.Module):

    def __init__(self):
        super(ODEFunc, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(3, 50),
            nn.Tanh(),
            nn.Linear(50, 3),
        )

        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.1)
                nn.init.constant_(m.bias, val=0)

    def forward(self, t, x):
        return self.net(x)


In [7]:
# class DeltaLengthNet(torch.nn.Module):
#     def __init__(self, fcn):
#         super(DeltaLengthNet, self).__init__()
#         self.fcn = fcn
#         self.odefunc = ODEFunc()
    
#     def forward(self, t, x):
#         metric = pullback_metric(x, self.fcn, create_graph=False, retain_graph=True)
#         xdot = self.odefunc(t, x)
#         return torch.sqrt(torch.einsum('Ni,Nij,Nj->N', xdot, metric, xdot))

In [8]:
metric = pullback_metric(x_tensor_normalized, model.encoder)

In [9]:
torch.einsum('Ni,Nij,Nj->N', x_tensor_normalized, metric, x_tensor_normalized).shape

torch.Size([3000])

In [10]:
# for model.encoder.net[0]

In [11]:
# dl = DeltaLengthNet(model.encoder)


In [12]:
tswiss_roll = x_tensor_normalized
# Swiss roll random integer
npts = len(tswiss_roll)
start = np.random.randint(0,npts,size=1)
starttwo = np.random.randint(0,npts,size=1)
startthree = np.random.randint(0,npts,size=1)
startfour = np.random.randint(0,npts,size=1)
startfive= np.random.randint(0,npts,size=1)
end = np.random.randint(0,npts,size=1)
#Select start and end points for NeuralODE

x0 = torch.tensor(tswiss_roll[start,:]).cpu() #Start point
x1 = torch.tensor(tswiss_roll[starttwo,:]).cpu() 
x2 = torch.tensor(tswiss_roll[startthree,:]).cpu() 
x3 = torch.tensor(tswiss_roll[startfour,:]).cpu() 
x4 = torch.tensor(tswiss_roll[startfive,:]).cpu() 
xfin = torch.tensor(tswiss_roll[end,:]).cpu() 
xbatch = torch.cat((x0,x1,x2,x3,x4),0)
endbtch = torch.cat((xfin,xfin,xfin,xfin,xfin),0)


print(xfin.shape)
print(xbatch.shape)
print(endbtch.shape)

torch.Size([1, 3])
torch.Size([5, 3])
torch.Size([5, 3])


  x0 = torch.tensor(tswiss_roll[start,:]).cpu() #Start point
  x1 = torch.tensor(tswiss_roll[starttwo,:]).cpu()
  x2 = torch.tensor(tswiss_roll[startthree,:]).cpu()
  x3 = torch.tensor(tswiss_roll[startfour,:]).cpu()
  x4 = torch.tensor(tswiss_roll[startfive,:]).cpu()
  xfin = torch.tensor(tswiss_roll[end,:]).cpu()


In [13]:
batch_t = torch.linspace(0,1,2)

In [14]:
xbatch

tensor([[-0.1956,  1.2660,  1.6567],
        [ 0.4090,  0.7648,  0.7226],
        [ 0.1463,  0.8104,  0.6780],
        [ 0.8805,  0.9799,  1.3358],
        [ 1.0704, -0.2958,  0.6806]])

In [15]:
model.encoder(xbatch)

tensor([[ 2.9277, -9.8357],
        [ 1.2194,  2.6928],
        [10.2101, -0.0439],
        [-8.2807, -2.8404],
        [ 3.1774,  4.7731]], grad_fn=<AddmmBackward0>)

In [16]:
# dl(batch_t, xbatch)

In [17]:
xbatch

tensor([[-0.1956,  1.2660,  1.6567],
        [ 0.4090,  0.7648,  0.7226],
        [ 0.1463,  0.8104,  0.6780],
        [ 0.8805,  0.9799,  1.3358],
        [ 1.0704, -0.2958,  0.6806]])

In [18]:
batch_t = torch.linspace(0, 1, 100)
odefunc = ODEFunc()
xs = odeint(odefunc, xbatch, batch_t)
ts = batch_t

In [19]:
fcn = model.encoder
original_shape = xs.shape
xs_flat = xs.view(-1, xs.shape[2])
metric_flat = pullback_metric(xs_flat, fcn, create_graph=False, retain_graph=True)

In [20]:
metric_flat.shape

torch.Size([500, 3, 3])

In [21]:
class ODEFunc(nn.Module):

    def __init__(self, fcn):
        super(ODEFunc, self).__init__()
        self.fcn = fcn

        self.net = nn.Sequential(
            nn.Linear(3, 50),
            nn.Tanh(),
            nn.Linear(50, 2), # coefficients
        )

    def forward(self, t, x):
        xs_flat = x.view(-1, xs.shape[2])
        coefs = self.net(xs_flat)
        jac = compute_jacobian_function(fcn, xs_flat, create_graph=True, retain_graph=True)
        U, S, Vt = torch.linalg.svd(jac, full_matrices=False)
        velo_flat = torch.einsum('ij,ijk->ik', coefs, Vt)
        velo = velo_flat.view(x.shape[0], x.shape[1], -1)
        return velo

fcn = model.encoder
odefunc = ODEFunc(fcn)

In [57]:
original_shape = xs.shape
xs_flat = xs.view(-1, xs.shape[2])
# xs_flat = xs_flat.detach().clone()
xs_flat.requires_grad_(True)
jac = compute_jacobian_function(fcn, xs_flat, create_graph=False, retain_graph=True)
# metric_flat = pullback_metric(xs_flat, fcn, create_graph=False, retain_graph=True)
metric_flat = torch.einsum('Nki,Nkj->Nij', jac, jac)
xdot = odefunc(ts, xs)
xdot_flat = xdot.view(-1, xdot.shape[2])
l_flat = torch.sqrt(torch.einsum('Ni,Nij,Nj->N', xdot_flat, metric_flat, xdot_flat))
l = l_flat.view(original_shape[0], original_shape[1])
l_batch = l.mean(axis=0)

In [58]:
xs_flat

tensor([[-0.1956,  1.2660,  1.6567],
        [ 0.4090,  0.7648,  0.7226],
        [ 0.1463,  0.8104,  0.6780],
        ...,
        [ 0.0764,  0.8725,  0.5668],
        [ 0.7312,  1.1494,  1.1908],
        [ 0.9777, -0.1643,  0.6917]], grad_fn=<ViewBackward0>)

In [63]:
jac

tensor([[[  5.9738,   1.7593,  -4.4991],
         [  6.9808,   0.6456,  11.7341]],

        [[  2.7883,  -1.4000,   1.8339],
         [ -3.4731,  -1.0087,   1.4866]],

        [[  9.2566,   4.5395,   2.4293],
         [  3.2944,  -2.4549,   1.8112]],

        ...,

        [[ -2.7907,  -0.3447,   6.7974],
         [  4.1363,  -1.0368,   7.1594]],

        [[-12.2313,   0.9186,  -9.5081],
         [  2.2051,  -1.1680,  -5.7299]],

        [[-13.7632,  -1.8136, -33.5215],
         [  7.4836,   1.0632,  -3.3347]]])

In [59]:
# scprep.plot.scatter3d(xs_flat.detach().cpu().numpy())

In [60]:
U, S, Vt = torch.linalg.svd(jac, full_matrices=False)

In [61]:
Vt.permute(0,2,1).shape

torch.Size([500, 3, 2])

In [62]:
Vt

tensor([[[ 0.4745,  0.0371,  0.8795],
         [-0.8510, -0.2363,  0.4690]],

        [[ 0.9997, -0.0246,  0.0035],
         [-0.0173, -0.5892,  0.8078]],

        [[ 0.8998,  0.3528,  0.2567],
         [ 0.2680, -0.9112,  0.3128]],

        ...,

        [[ 0.1440, -0.1021,  0.9843],
         [ 0.9851, -0.0801, -0.1524]],

        [[-0.7582,  0.0487, -0.6502],
         [ 0.6280, -0.2137, -0.7483]],

        [[-0.3782, -0.0498, -0.9244],
         [ 0.9152,  0.1299, -0.3814]]])

In [27]:
# torch.einsum('ni,ni->n', Vt[:,0,:], Vt[:,1,:])

In [28]:
endbtch

tensor([[-1.0350,  1.0839,  0.8666],
        [-1.0350,  1.0839,  0.8666],
        [-1.0350,  1.0839,  0.8666],
        [-1.0350,  1.0839,  0.8666],
        [-1.0350,  1.0839,  0.8666]])

In [29]:
xbatch

tensor([[-0.1956,  1.2660,  1.6567],
        [ 0.4090,  0.7648,  0.7226],
        [ 0.1463,  0.8104,  0.6780],
        [ 0.8805,  0.9799,  1.3358],
        [ 1.0704, -0.2958,  0.6806]])

In [30]:
# t = torch.linspace(0., 25., 100)
# true_y0 = torch.tensor([[2., 0.]])
# true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])
# class Lambda(nn.Module):

#     def forward(self, t, y):
#         return torch.mm(y**3, true_A)


# with torch.no_grad():
#     true_y = odeint(Lambda(), true_y0, t, method='dopri5')

# def get_batch():
#     s = torch.from_numpy(np.random.choice(np.arange(100 - 10, dtype=np.int64), 32, replace=False))
#     batch_y0 = true_y[s]  # (M, D)
#     batch_t = t[:10]  # (T)
#     batch_y = torch.stack([true_y[s + i] for i in range(10)], dim=0)  # (T, M, D)
#     return batch_y0, batch_t, batch_y


In [31]:
batch_y0, batch_t, batch_y = get_batch()

NameError: name 'get_batch' is not defined

In [None]:
batch_y0.shape

torch.Size([32, 1, 2])

In [None]:
batch_t.shape

torch.Size([10])

In [None]:
ode(batch_t, xbatch)

tensor([[ 0.0941,  0.0650,  0.0978],
        [ 0.0519, -0.0360,  0.0672],
        [ 0.1275,  0.0517,  0.0312],
        [ 0.1361,  0.0184, -0.0385],
        [-0.0176, -0.0065, -0.0898]], grad_fn=<AddmmBackward0>)

In [None]:
odeint(ode, xbatch, batch_t)

tensor([[[ 1.2636, -1.1970,  0.0490],
         [ 1.1355,  0.8826, -0.7435],
         [ 0.7154, -0.8976,  0.5071],
         [ 0.1271, -0.2776,  0.7833],
         [-1.0314, -0.0115,  0.6272]],

        [[ 1.3732, -1.1329,  0.1448],
         [ 1.1956,  0.8479, -0.6767],
         [ 0.8547, -0.8465,  0.5392],
         [ 0.2690, -0.2591,  0.7487],
         [-1.0580, -0.0182,  0.5400]]], grad_fn=<CopySlices>)

In [None]:
dl

DeltaLengthNet(
  (fcn): MLP(
    (net): Sequential(
      (0): Linear(in_features=3, out_features=256, bias=True)
      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=256, out_features=128, bias=True)
      (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Linear(in_features=128, out_features=64, bias=True)
      (7): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU()
      (9): Linear(in_features=64, out_features=2, bias=True)
    )
  )
  (odefunc): ODEFunc(
    (net): Sequential(
      (0): Linear(in_features=3, out_features=50, bias=True)
      (1): Tanh()
      (2): Linear(in_features=50, out_features=3, bias=True)
    )
  )
)

In [70]:
xj = x_tensor_normalized[:100,:].clone()
xj.requires_grad = True
jac2 = torch.autograd.functional.jacobian(fcn, xj)

In [71]:
xj = x_tensor_normalized[:100,:].clone()
xj.requires_grad = True
jac1 = compute_jacobian_function(fcn, xj, create_graph=False, retain_graph=True)

In [74]:
xj

tensor([[-0.8684,  0.0236, -1.0067],
        [ 0.1840,  0.2444, -1.2358],
        [ 0.2437, -1.3371,  0.1722],
        [ 0.0489,  0.6453, -0.4259],
        [ 0.3651,  0.4967,  0.3923],
        [ 0.4829,  0.6382, -0.4590],
        [ 0.6718,  1.1707, -1.2243],
        [ 0.1208, -0.7065, -1.5939],
        [-1.1967, -1.2995,  0.1857],
        [-1.3850, -0.9415,  0.5508],
        [-0.0245,  1.2696, -0.8185],
        [ 0.7865, -0.8530, -1.0688],
        [-0.8079,  0.3914, -1.1201],
        [ 0.4048,  0.4748,  1.5497],
        [-0.1745,  0.1158, -1.3989],
        [-0.9508,  0.6480, -1.1287],
        [-1.3634,  0.9255,  0.4177],
        [ 0.5199,  0.7263,  0.4666],
        [-0.2398,  1.0528, -1.2614],
        [ 0.6575, -0.4875, -1.3386],
        [ 0.5762,  0.2497,  0.6565],
        [ 0.1421,  0.4722, -0.3928],
        [ 0.1000,  0.2369,  1.4754],
        [ 0.4500, -0.7788,  0.8218],
        [ 0.1868, -0.5579,  0.6494],
        [ 1.1157, -0.2390, -1.0569],
        [ 1.1488,  1.2061, -0.4805],
 

In [72]:
jac2

tensor([[[[-1.6000e-01,  1.2661e+00,  7.0687e-01],
          [-8.0331e-02, -4.2649e-02, -1.2482e-01],
          [ 5.6130e-02,  1.2781e-02,  5.5852e-03],
          ...,
          [ 1.1844e-02, -1.0015e-01, -3.5252e-02],
          [ 9.6427e-02, -1.2403e-02, -3.5353e-02],
          [-4.5646e-03, -9.6671e-03, -3.0431e-02]],

         [[ 4.0855e+00,  1.5588e+00, -5.6226e+00],
          [ 1.1061e-01,  1.9104e-02,  1.1722e-01],
          [-8.5327e-02, -4.5160e-02,  4.8423e-02],
          ...,
          [-2.4106e-01, -1.2250e-01,  3.6439e-01],
          [-1.5442e-01, -8.9653e-02,  1.9829e-01],
          [ 2.7556e-02,  7.2798e-03,  7.6086e-02]]],


        [[[-2.3466e-02, -3.7578e-02,  5.1181e-02],
          [-2.8150e+00,  4.4632e-01, -6.3662e-01],
          [ 1.1519e-03,  1.3879e-02, -3.1905e-02],
          ...,
          [ 8.4265e-03, -2.8990e-02,  7.0546e-02],
          [ 1.3097e-01,  9.3238e-03,  1.6862e-01],
          [-6.4703e-02, -4.5413e-03, -1.6575e-01]],

         [[-7.3698e-02, -8.29

In [73]:
jac1

tensor([[[-6.2720e-01,  7.2077e-02,  1.2819e+00],
         [ 1.9052e+00,  3.7060e-01, -1.6726e+00]],

        [[ 5.3103e-01,  4.5460e-01,  2.0450e+00],
         [ 6.8977e-01,  2.8719e-01,  3.3891e+00]],

        [[-3.4549e-01,  6.7609e-02,  5.4959e-01],
         [ 1.1801e+00, -9.2412e-01, -2.2275e+00]],

        [[ 4.2492e-01,  2.0359e-02,  4.6186e-01],
         [-2.0389e+00,  8.3883e-01,  1.3473e+00]],

        [[-3.8764e-01,  8.1049e-01,  1.4613e+00],
         [ 1.1917e+00, -2.1976e-01, -1.0997e+00]],

        [[ 4.2549e-01, -4.8525e-01,  2.0891e+00],
         [-8.0178e-01,  8.0972e-02, -1.8817e+00]],

        [[ 3.8723e-01,  1.7810e-01, -3.5687e+00],
         [ 4.5282e-01, -6.9495e-01,  1.8389e-02]],

        [[ 3.4725e-01,  5.0117e-01, -1.8377e-01],
         [-1.6966e-02,  3.3439e-01,  1.3539e-01]],

        [[-1.4895e+00, -4.5776e-01,  6.0408e-01],
         [-1.0796e+00,  5.8865e-02, -1.5685e+00]],

        [[ 1.5767e-01,  6.6500e-02,  5.0274e-01],
         [-1.8245e+00, -2.7839e-