In torchdyn -> numerics -> sensitivity.py, function `_gather_odefunc_adjoint()`, line 71:

```python
dμ = torch.cat([el.flatten() if el is not None else torch.zeros(1) for el in dμ], dim=-1)
```
should be fixed by
```python
param_shapes = [p.shape for p in vf.parameters()]
dμ = torch.cat([el.flatten() if el is not None else torch.zeros(param_shapes[i]).to(t.device).flatten() for i, el in enumerate(dμ)], dim=-1)
```
Otherwise, the placeholder gradient (`torch.zeros(1)`) does not match the shape of the corresponding parameter in the vector field, so the concatenated gradient vector ends up with the wrong length.
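For context, here is a minimal sketch of the shape mismatch (not torchdyn code; the toy `vf` module and its layer sizes are made up for illustration). When some parameters receive `None` gradients, a single-element zero placeholder makes the flattened gradient shorter than the flattened parameter vector, while shape-matched zeros keep the lengths consistent:

```python
import torch
import torch.nn as nn

# Hypothetical two-layer vector field; its parameter tensors have different shapes.
vf = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 1))

# Only the first layer participates in the loss, so the second layer's
# parameters get `None` gradients (allow_unused=True).
x = torch.randn(5, 3)
loss = vf[0](x).sum()
grads = torch.autograd.grad(loss, list(vf.parameters()), allow_unused=True)

param_shapes = [p.shape for p in vf.parameters()]
n_params = sum(p.numel() for p in vf.parameters())

# Current behaviour: a single zero stands in for each missing gradient,
# so the flattened vector is shorter than the flattened parameters.
flat_buggy = torch.cat(
    [g.flatten() if g is not None else torch.zeros(1) for g in grads], dim=-1)

# Proposed fix: zeros shaped like the corresponding parameter.
flat_fixed = torch.cat(
    [g.flatten() if g is not None else torch.zeros(param_shapes[i]).flatten()
     for i, g in enumerate(grads)], dim=-1)

print(flat_buggy.numel(), flat_fixed.numel(), n_params)  # e.g. 18, 21, 21
```

Only `flat_fixed` has the same number of elements as the flattened parameter vector, which is what the adjoint gradient concatenation appears to require.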