In [30]:
import numpy as np
import torch
import torch.nn as nn
from blackbirds.jacfwd import jacfwd

# step 1: calibration network

In [31]:
calibNN = nn.Linear(5, 1)
calib_data = torch.randn(16, 5)

abm_params = calibNN(calib_data)

In [32]:
print(abm_params.shape)

gt_data = torch.rand_like(abm_params)
print(gt_data.shape)

torch.Size([16, 1])
torch.Size([16, 1])


# step 2: abm network

In [33]:
def model_fwd(abm_params):
    return 2 * abm_params

def abm(abm_params, model_fwd, gt_data):
    abm_preds = model_fwd(abm_params)
    loss = nn.MSELoss()(abm_preds, gt_data)
    
    return loss, loss

# step 3: jacobian calculator and optimizer

In [34]:
grad_mode = 'forward' # or 'reverse'
chunk_size = int(abm_params.shape[0]/2)

In [35]:
if grad_mode == 'forward':
    grad_function = lambda *args, **kwargs: jacfwd(randomness="same", *args, **kwargs)
else:
    grad_function = torch.func.jacrev

In [36]:
jacobian_calculator = grad_function(abm, 0, has_aux=True, chunk_size=chunk_size)

In [37]:
abm_loss_jac, loss = jacobian_calculator(abm_params, model_fwd, gt_data)
print(abm_loss_jac.shape, loss.shape)

torch.Size([16, 1]) torch.Size([])


In [38]:
print(abm_loss_jac.shape, loss.shape)

torch.Size([16, 1]) torch.Size([])


In [39]:
# loss function on parameters of CalibNN
diff_grad = torch.dot(abm_loss_jac.flatten(), abm_params.flatten())
diff_grad.backward()

In [40]:
for param in calibNN.parameters():
    print(param.grad)

tensor([[-1.4103, -0.1383, -2.8588,  7.3415,  2.3704]])
tensor([-7.8232])


## Comments

In [21]:
flow_params = torch.tensor([1., 2, 3])

def flow(flow_params):
    return 2 * flow_params

def abm(abm_params): # ABM and loss function
    return torch.sum(abm_params**2), torch.sum(abm_params**2)


In [25]:
jacobian_calculator = torch.func.jacfwd(abm, 0, has_aux=True) # differentiate w.r.t first one

In [26]:
jacobian_calculator(flow_params)

(tensor([2., 4., 6.], grad_fn=<ViewBackward0>),
 tensor(14., grad_fn=<SelectBackward0>))

In [27]:
flow_params = torch.tensor([1., 2, 3], requires_grad=True)
abm_params = flow(flow_params)

abm_jacobian = jacobian_calculator(abm_params)

to_diff = torch.dot(abm_jacobian[0], abm_params)
to_diff.backward()


In [28]:
flow_params.grad

tensor([16., 32., 48.])