In [1]:
import torch
print(torch.__version__)

torch.backends.mps.is_available()

import torch
import time

device_cpu = torch.device("cpu")
device_mps = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model_cpu = torch.nn.Linear(1000, 1000).to(device_cpu)
model_mps = torch.nn.Linear(1000, 1000).to(device_mps)

inputs_cpu = torch.randn(1000, 1000).to(device_cpu)
inputs_mps = torch.randn(1000, 1000).to(device_mps)

# Test CPU Time
start_time = time.time()
for _ in range(500):
    outputs_cpu = model_cpu(inputs_cpu)
cpu_time = time.time() - start_time
print(f"CPU Time: {cpu_time:.4f} seconds")

# Test MPS Time
start_time = time.time()
for _ in range(500):
    outputs_mps = model_mps(inputs_mps)
torch.mps.synchronize()  # Make sure MPS compeletes the computation
mps_time = time.time() - start_time
print(f"MPS Time: {mps_time:.4f} seconds")


2.2.2
CPU Time: 6.1144 seconds
MPS Time: 0.6143 seconds


### Goals: 
* Showcase computation of our desired jacobian-vector product (JVP) for a small toy network case using either:
* 1) torch.autograd.functional.jacobian (which has $\mathcal{O}(d^2)$ complexity),
  2) torch.autograd.grad (which has $\mathcal{O}(d)$ complexity using for loops,
  3) using torch.autograd.grad + torch.vmap to avoid loop over batch dimension.
* This is meant as a simple/early example for Bruce to work with.

In [2]:
#general 
import numpy as np
import os, sys
import torch

In [3]:
#specific to vfm repo
#this is mostly so you have a net to work with ... 
#sys.path.append('/home/dfd4/vfm_D_min_clipping/') #swap this for path to VFM repo in your local machine 
sys.path.append('/Users/zhanglige/Desktop/JP-Lab/Code/Velocity_Flow_Matching/')
import dnnlib
from training.networks import ToyMLP

### Construct a network to use 

In [4]:
#set up device first
device_name = 'cpu' #can swap this to cuda:0, etc pending on resources
device = torch.device(device_name)
print(device)

cpu


In [5]:
#create ToyMLP instance, 
#adjust it to train mode, tracking grads and pass it to device
mlp = ToyMLP(dim=784, time_varying=True, n_hidden=6, w=64)
mlp.train().to(device)

ToyMLP(
  (net): Sequential(
    (0): Linear(in_features=785, out_features=64, bias=True)
    (1): Tanh()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): Tanh()
    (4): Linear(in_features=64, out_features=64, bias=True)
    (5): Tanh()
    (6): Linear(in_features=64, out_features=64, bias=True)
    (7): Tanh()
    (8): Linear(in_features=64, out_features=64, bias=True)
    (9): Tanh()
    (10): Linear(in_features=64, out_features=64, bias=True)
    (11): Tanh()
    (12): Linear(in_features=64, out_features=64, bias=True)
    (13): Tanh()
    (14): Linear(in_features=64, out_features=784, bias=True)
  )
)

#### Above yields a simple mlp with $n$ number of hidden layers, each with width $w$.  
#### Note that input feature size is data_dim + 1. This is by design, as this net takes also a time input which is concatenated to flattened image input.
### Let's compute one desired jvp using torch.autograd.functional.jacobian ... 

In [6]:
#first, set up inputs for net 
#am choosing small batch size to make computation faster 
batch_size = 3
flat_data_dim = 784 
imgs = torch.randn(batch_size, flat_data_dim).type(torch.float32).to(device)
ts = torch.rand(batch_size, device=device) 

In [7]:
#ok now calc Jacobian of net_out w.r.t imgs input 
#set requires_grad to True for net inputs... 
ts.requires_grad=True
imgs.requires_grad=True

We can think img as $x \in \mathbb{R}^{B \times D}$, where $B$ is the batch size and $D$ is the input dimension. Also ts as $t \in \mathbb{R}^{B}$. The mlp we denote as y, thus the overall function can be written as $y=f(x,t) \in \mathbb{R}^{B \times D'}$, where $D'$ is the output dimension.

Now if we don't consider batch, y,x are vectors, the Jacobian will be an intuitive matrix. But if we consider batches, it's a bit complex.

In [8]:
mlp_jac = torch.autograd.functional.jacobian(mlp, (imgs, ts))

In [9]:
print(len(mlp_jac))
print(mlp_jac[0].shape) #[B,D',B,D]
print(mlp_jac[1].shape) 

2
torch.Size([3, 784, 3, 784])
torch.Size([3, 784, 3])


In [10]:
Jacobian_4_img = mlp_jac[0]
print(Jacobian_4_img[2][1])

#This is the y_1 in batch 2, it should only be the output of x_i from batch 2, so we see the first two rows 
#should be 

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-1.1842e-05,  1.5036e-04,  3.9824e-04,  ...,  2.2364e-04,
          3.0547e-04,  2.9108e-04]])


#### Note that torch's Jac method will produce one Jacobian output per each input of the function call. So, in this case, first item in list is Jacobian of net output w.r.t. first (i.e., imgs) input and second item is Jacobian of output w.r.t second input (ts).
#### Note also that Jac ouputs have several FULL ZERO rows... This is because batch items are independent of each other - that is, we can collapse across batch dim in dim==2.

In [None]:
#imgs jvp
print(mlp_jac[0].shape)
torch_jac = torch.sum(mlp_jac[0], dim=2) #collapse over extra batch dim, see comment above # then it would be img:(B,D',D)
print(torch_jac.shape) #[B,D',D]
nabla_imgs = torch_jac.transpose(2,1) #transpose, to get gradient (SHOULDN"T BE?? I THINK)
print(nabla_imgs.shape) #[B,D,D']
#compute actual u \cdot nabla product. 
#note that this is between above nabla_imgs and corresponding network output tensor u 
imgs_jvp = torch.einsum('bij, bjk -> bik', mlp(imgs, ts).unsqueeze(1), nabla_imgs).squeeze(1) #bs, dim 
#mlp(imgs, ts).unsqueeze(1):=[B,1,D']    bjk:=[B,D',D] -> [B,1,D], after squeeze it would be [B,D]
#if you think about it [D',D] is the correct shape of a Jacobian
print(imgs_jvp)
print(imgs_jvp.shape)

torch.Size([3, 784, 3, 784])
torch.Size([3, 784, 784])
torch.Size([3, 784, 784])
tensor([[ 0.0030, -0.0006,  0.0008,  ...,  0.0011,  0.0006,  0.0009],
        [ 0.0027, -0.0002,  0.0014,  ...,  0.0011,  0.0006,  0.0014],
        [ 0.0023, -0.0009,  0.0006,  ...,  0.0009,  0.0003,  0.0004]],
       grad_fn=<SqueezeBackward1>)
torch.Size([3, 784])


For each batch sample \( i \), you are computing the following:

$$
\text{imgs\_jvp}^{(i)} = \mathbf{y}^{(i)} \cdot \left( \frac{\partial \mathbf{y}^{(i)}}{\partial \mathbf{x}^{(i)}} \right)^\top
$$

Where:

- $ \mathbf{y}^{(i)} \in \mathbb{R}^{D'} $: The model output for the \( i \)-th sample.
- $ \frac{\partial \mathbf{y}^{(i)}}{\partial \mathbf{x}^{(i)}} \in \mathbb{R}^{D' \times D} $: The Jacobian matrix for the \( i \)-th sample, representing the partial derivatives of the output with respect to the input.
- $ \text{imgs\_jvp}^{(i)} \in \mathbb{R}^{D} $: The result of the Jacobian-Vector Product, which is a vector of the same dimension as the input features.

This operation represents how the model output, when treated as a vector, interacts with the gradient (Jacobian) of the model with respect to the input, effectively measuring how changes in the input affect the output in the direction of the output itself.


#### Let's de-compress the above a bit. First, we take Jacobian of net output w.r.t to its first input (imgs) and sum/collapse it over dimension 2, since each output is only dependent on its corresponding batch input. 
#### Then, we take transpose w.r.t to 2 last dimensions - this is d/t relationship between Jacobian and gradients... Typically, for scalar functions, gradients are column vectors (i.e., rows of Jac transposed)
#### Will triple check with John that this is needed here... 

#### Finally, we use an Einstein summation to compute the desired product over the batch -- this yields the $u \cdot \nabla $ product we want for the Lie derivative calculation.

#### Couple of points here: 
* 1) In actual code, this will take LONGER to run, because in there imgs are not just simple inputs, but instead interpolations between original data and outputs of another, separate encoder network.
  2) Additionally, we compute two such JVPs (one for flow net, another for dynamics net) and compute partial derivative w.r.t. to time argument (ts) as well.
  3) All of these then go into computing out Lie derivative loss: $\mathcal{L}_{Lie} = \partial_{\tau} \mathbf{v} + \mathbf{u} \cdot \nabla_{\mathbf{v}} - \mathbf{v} \cdot \nabla_{\mathbf{u}}$

### Ok, let's compute same jvp now with torch.grad.autograd

In [12]:
#small method to compute desired Jacobian, for a batch 
def batch_jacobian(model, imgs, ts):
    """Computes the Jacobian of a batch of outputs w.r.t a batch of inputs."""

    batch_size, input_size = imgs.shape #[B,D]
    output_size = model(imgs, ts).shape[1] #[D']

    jacobian = torch.zeros(batch_size, output_size, input_size) #[B,D',D]

    #note that we loop over batch AND dimensions here! 
    for i in range(batch_size): #for b \in [b_1,...,b_B]
        for j in range(output_size): #For y \in (y_1,y_2,...,y_{D'})
            grad_outputs = torch.zeros_like(model(imgs, ts)) #[B,D']
            grad_outputs[i, j] = 1.0 #Weight for gradient, but you may consider as a filter
            jacobian[i, j] = torch.autograd.grad(
                model(imgs, ts), imgs, grad_outputs=grad_outputs, retain_graph=True
            )[0][i]

    return jacobian

ag_jac = batch_jacobian(mlp, imgs, ts)
#check that shape matches - this should already be collapsed across extra batch dim
print(ag_jac.shape) #[B,D',D]

torch.Size([3, 784, 784])


In [None]:
#check that this produces correct/desired Jac 
np.allclose(ag_jac.detach().cpu().numpy(), torch_jac.detach().cpu().numpy())

True

In [None]:
#ok, now compute jvp 
nabla_imgs_ag = ag_jac.transpose(2,1) #transpose to get grad 
imgs_jvp_ag = torch.einsum('bij, bjk -> bik', mlp(imgs, ts).unsqueeze(1), nabla_imgs_ag.to(device)).squeeze(1) #bs, dim 

In [24]:
#check that these results are indeed the same... 
np.allclose(imgs_jvp.detach().cpu().numpy(), imgs_jvp_ag.detach().cpu().numpy())

True

### Ok, now let's use torch vmap and torch.autograd.grad to run the above 
#### This avoids loop over batch items but comes at cost of larger VRAM requirements...

In [17]:
get_jvp = lambda v: torch.autograd.grad(mlp(imgs, ts), imgs, v, retain_graph=True)

In [37]:
#method to create v vector we will need to vmap over batch elements 
def build_IN_vmap(shape):
    """
    Computes IN we will use for vmapping 
    over torch.autograd.grad call 
    """
    I_N = []
    for i in range(shape[0]):
        for j in range(shape[1]):
            curr_IN = torch.zeros(shape)
            curr_IN[i, j]=1
            I_N.append(curr_IN)
    I_N = torch.cat(I_N, dim=0)
    return I_N.reshape((-1, shape[0], shape[1]))

In [38]:
vmap_v = build_IN_vmap((6, 784))

In [39]:
#check that shape matches - should be [bs*dim, bs, dim]
vmap_v.shape

torch.Size([4704, 6, 784])

In [49]:
vmap_ag_jac = torch.vmap(get_jvp)(vmap_v.type(torch.float32).to(device))[0]

In [51]:
#check that shape is correct - should be [bs*dim, bs, dim]
print(vmap_ag_jac.shape)
#reshape this output to [bs, dim, bs, dim]
#collapse over dim==2 as before 
vmap_ag_jac = vmap_ag_jac.reshape(batch_size, imgs.shape[1], batch_size, imgs.shape[1])
vmap_ag_jac = vmap_ag_jac.sum(2)
print(vmap_ag_jac.shape)

torch.Size([4704, 6, 784])
torch.Size([6, 784, 784])


In [52]:
#ok now check that this is indeed identical to previous results obtained with loop + AG and Jac 
vmap_ag_jac.all() == ag_jac.all()

tensor(True, device='cuda:1')

In [54]:
vmap_ag_jac.all() == torch_jac.all()

tensor(True, device='cuda:1')

### General Comments: 
1) I need to check with John if indeed I need to transpose Jac here? This is how I implemented things originally and this matches results computed by hand.
2) All cases above still end up computing a Jacobian different ways and then doing the Jacobian vector product... This is in part due to difficulty of finding a proper vector $v$ that would correctly select ONLY desired row of Jac and ALSO multiply it by our corresponding output vector.
3) I don't think this is possible really? But discuss it with John and then update Bruce 