# Parameter Estimation with Diffsol Sensitivities

We reproduce `examples/integration/parameter_estimation/estimate.py` in notebook form.
The data are noisy measurements of the exponential decay $x' = -k x$ with $x(0) = 1$.
Using forward sensitivities from diffsol, we fit `k` with a gradient-based optimizer (Adam).


In [None]:
import numpy as np
import torch
from diffsol_pytorch import DiffsolModule, testing
from helpers import generate_decay_data, preferred_device


In [None]:
# diffsl source for a one-state exponential-decay model.
#   in = [k]   -> k is declared as the model's input parameter; the fitting
#                 cell passes a 1-element parameter list to forward_mode.
#   k { 0.4 }  -> default value for k.
#   u { ... }  -> the single state x, with initial value 1.0.
#   F { ... }  -> right-hand side, dx/dt = -k * x.
# NOTE(review): this reading of the diffsl blocks follows the DSL's usual
# conventions and the ODE stated in the intro; confirm against the diffsl docs.
DECAY_CODE = '''
in = [k]
k { 0.4 }
u {
    x = 1.0,
}
F {
    -k * x,
}
'''
# Build the diffsol module once; later cells reuse it for every solve.
module = DiffsolModule(DECAY_CODE)


In [None]:
# Ask the local helper for the preferred torch device.
# NOTE(review): device_target is not referenced by any other cell in this
# notebook (params are created without a device argument) - confirm whether
# it is still needed.
device_target = preferred_device()
# Measurement times and noisy observations of the decay process.
# NOTE(review): assumed to be equal-length 1-D NumPy arrays (later cells call
# .tolist() on times and subtract obs from a NumPy prediction) - confirm in
# helpers.generate_decay_data.
times, obs = generate_decay_data()


In [None]:
# Fit k with a torch optimizer, but without autograd: diffsol returns the
# solution together with its forward sensitivities d(solution)/d(params), and
# we chain those with dL/d(solution) by hand to obtain dL/d(params).
K_INIT = 0.2
LEARNING_RATE = 5e-2
N_STEPS = 60

params = torch.tensor([K_INIT], dtype=torch.float64, requires_grad=True)
optimizer = torch.optim.Adam([params], lr=LEARNING_RATE)
loss_history = []
for _ in range(N_STEPS):
    # Solve the ODE plus sensitivities at the current parameter values.
    fwd = testing.forward_mode(module, params.detach().tolist(), times.tolist())
    pred = fwd.solution[0]  # row 0 of the solution (presumably the x trajectory)
    residual = pred - obs
    n_obs = residual.size

    # Objective L = 0.5 * mean(residual**2), hence dL/dpred = residual / n_obs.
    loss = 0.5 * float(np.mean(residual**2))
    dloss_dsol = residual.reshape(1, -1) / n_obs

    # Chain rule over every solution entry: dL/dp = sum_ij dsol_ij/dp * dL/dsol_ij.
    # NOTE(review): sensitivities assumed laid out as (param, state, time).
    param_grad = np.einsum("pij,ij->p", fwd.sensitivities, dloss_dsol)

    # Install the hand-computed gradient and let Adam take one step.
    optimizer.zero_grad()
    params.grad = torch.tensor(param_grad, dtype=params.dtype)
    optimizer.step()
    loss_history.append(loss)
params.item()


In [None]:
import matplotlib.pyplot as plt

# Training curve. Each recorded value is 0.5 * mean(residual**2), i.e. half the
# mean squared error, so the axis is labelled that way rather than as plain MSE.
plt.plot(loss_history)
plt.xlabel('Iteration')
plt.ylabel('Loss (0.5 * MSE)')
plt.title('Parameter estimation')
plt.show()


In [None]:
import matplotlib.pyplot as plt  # keep this cell runnable on its own

# Compare observations, the reference decay curve and the fitted model.
# forward_mode is handed plain detached Python floats (same as in the training
# loop), so there is no autograd graph to suppress and no torch.no_grad() needed.
# NOTE(review): TRUE_K mirrors the default `k { 0.4 }` in DECAY_CODE and x(0) = 1;
# confirm it matches the value used inside generate_decay_data.
TRUE_K = 0.4
pred = testing.forward_mode(module, params.detach().tolist(), times.tolist()).solution[0]
plt.plot(times, obs, 'o', label='observations')
plt.plot(times, np.exp(-TRUE_K * times), label=f'true k={TRUE_K}')
plt.plot(times, pred, label=f'estimated k={params.item():.3f}')
plt.legend()
plt.show()
