I am interested in ODEs of the form dz(s)/ds = f_theta(s, r, z(s)), where r can be a vector independent of the input x.
Does the current implementation support this feature?
What you ask is not currently supported with the DataControl layer, but can be done pretty easily in two ways:
First, you can have your vector field nn.Module store r in self.r and use it inside a forward that takes only (s, z) as arguments.
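A minimal sketch of that first approach, in plain PyTorch. The names ConditionedField and euler_integrate are hypothetical, and the hand-written fixed-step Euler loop only stands in for torchdyn's solver, so this does not show torchdyn's own API. Appending s as an extra input feature is one common way to make the field time-dependent, not the only one.

```python
import torch
import torch.nn as nn


class ConditionedField(nn.Module):
    """Hypothetical vector field f_theta(s, r, z) that keeps r on the module."""

    def __init__(self, z_dim, r_dim, hidden=32):
        super().__init__()
        self.r = None  # condition of shape (1, r_dim), set before integrating
        self.net = nn.Sequential(
            nn.Linear(z_dim + r_dim + 1, hidden),
            nn.Tanh(),
            nn.Linear(hidden, z_dim),
        )

    def forward(self, s, z):
        # forward only receives (s, z); r is read from self.r
        batch = z.shape[0]
        r = self.r.expand(batch, -1).to(z)
        s_col = torch.full((batch, 1), float(s), dtype=z.dtype, device=z.device)
        return self.net(torch.cat([z, r, s_col], dim=1))


def euler_integrate(f, z0, s_span=(0.0, 1.0), steps=20):
    """Fixed-step explicit Euler, a stand-in for a proper ODE solver."""
    s0, s1 = s_span
    h = (s1 - s0) / steps
    z = z0
    for k in range(steps):
        z = z + h * f(s0 + k * h, z)
    return z


torch.manual_seed(0)
field = ConditionedField(z_dim=2, r_dim=3)
field.r = torch.randn(1, 3)  # condition, independent of the input x
x = torch.randn(8, 2)        # batch of inputs
z1 = euler_integrate(field, x)
print(z1.shape)  # torch.Size([8, 2])
```

Because r lives on the module, changing the condition between integrations is just a matter of reassigning field.r.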
Second, you can change how the data control is set. By default this happens in _prep_integration: the only module with a u attribute we consider is DataControl, which triggers _prep_integration to assign the current input to module.u. Depending on how general you'd like your implementation to be, you can simply modify DataControl:
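import torch
import torch.nn as nn

class CustomDataControl(nn.Module):
    """Data-control module. Allows for data-control inputs at arbitrary points of the DEFunc.
    Concatenates a fixed condition r of shape (1, r_dim) to the state."""
    def __init__(self, r):
        super().__init__()
        self.u = None  # mirrors DataControl's attribute; forward() does not read it
        self.r = r
    def forward(self, x):
        # Broadcast r over the batch and match x's dtype/device before concatenating
        r = self.r.expand(x.shape[0], -1).to(x)
        return torch.cat([x, r], 1)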
This is slightly hacky, but it works. Alternatively, you can alter the logic in _prep_integration to allow for custom assignments to module.u.
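For illustration, a data-control layer like this one sits at the front of an ordinary nn.Sequential vector field, so the next layer has to accept x_dim + r_dim features while the last layer maps back to x_dim. This standalone sketch repeats the module (with r broadcast over the batch) so it runs on its own; the dimensions and layer sizes are arbitrary assumptions, and it calls the vector field directly rather than going through torchdyn's integration machinery.

```python
import torch
import torch.nn as nn


class CustomDataControl(nn.Module):
    """Concatenates a fixed condition r of shape (1, r_dim) to the state."""

    def __init__(self, r):
        super().__init__()
        self.u = None
        self.r = r

    def forward(self, x):
        # Repeat r along the batch dimension so torch.cat sees matching sizes
        r = self.r.expand(x.shape[0], -1).to(x)
        return torch.cat([x, r], dim=1)


x_dim, r_dim = 2, 3
r = torch.randn(1, r_dim)

# The control layer widens the state to x_dim + r_dim features; the output
# layer maps back to x_dim so the field matches the state it differentiates.
vector_field = nn.Sequential(
    CustomDataControl(r),
    nn.Linear(x_dim + r_dim, 16),
    nn.Tanh(),
    nn.Linear(16, x_dim),
)

x = torch.randn(8, x_dim)
dx = vector_field(x)
print(dx.shape)  # torch.Size([8, 2])
```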
I managed to make it work following your suggestion!
One remaining question: in this case, does the conditioning vector r receive any gradient?
I am interested in building conditional CNFs like this work: https://github.com/stevenygd/PointFlow
Do you think that is doable with this library?
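Glad to hear it worked. Depending on your implementation, r should receive gradients just fine, as long as autograd tracks it (for example, r is an nn.Parameter or the output of an encoder) rather than being a plain tensor or buffer without requires_grad. It is certainly possible to build PointFlow-style models with torchdyn :)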
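One quick way to check the gradient question is to look at r.grad after a backward pass. The sketch below does this for a toy conditional CNF in plain PyTorch: it integrates the state together with the exact trace of the Jacobian using explicit Euler, and uses the instantaneous change of variables log p(x | r) = log N(z(1); 0, I) + ∫₀¹ tr(∂f/∂z) ds, flowing data at s = 0 to a standard normal at s = 1. All names, sizes and the toy data are hypothetical, and the field is time-independent for brevity. In a PointFlow-like model r would come from a shape encoder rather than being a free parameter, and a real implementation would use a proper solver and, in higher dimensions, a stochastic trace estimator.

```python
import torch
import torch.nn as nn


class CondField(nn.Module):
    """Hypothetical conditional vector field f_theta(z, r), time-independent."""

    def __init__(self, z_dim, r_dim, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim + r_dim, hidden), nn.Tanh(), nn.Linear(hidden, z_dim)
        )

    def forward(self, z, r):
        # r has shape (1, r_dim) and is broadcast over the batch
        return self.net(torch.cat([z, r.expand(z.shape[0], -1)], dim=1))


def divergence(f_out, z):
    """Exact trace of df/dz, one autograd call per state dimension."""
    div = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device)
    for i in range(z.shape[1]):
        # create_graph=True keeps the trace differentiable w.r.t. theta and r
        grad_i = torch.autograd.grad(f_out[:, i].sum(), z, create_graph=True)[0]
        div = div + grad_i[:, i]
    return div


def cnf_log_prob(field, x, r, steps=20):
    """Euler estimate of log p(x | r) with a N(0, I) base at s = 1."""
    h = 1.0 / steps
    z = x.clone().requires_grad_(True)
    trace_integral = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
    for _ in range(steps):
        f_out = field(z, r)
        trace_integral = trace_integral + h * divergence(f_out, z)
        z = z + h * f_out
    base = torch.distributions.Normal(0.0, 1.0)
    return base.log_prob(z).sum(dim=1) + trace_integral


torch.manual_seed(0)
x_dim, r_dim = 2, 4
field = CondField(x_dim, r_dim)
r = nn.Parameter(torch.randn(1, r_dim))  # learnable condition
opt = torch.optim.Adam(list(field.parameters()) + [r], lr=1e-2)

x = 0.5 * torch.randn(16, x_dim) + 1.0   # toy data for illustration only
loss = -cnf_log_prob(field, x, r).mean()
opt.zero_grad()
loss.backward()
print(r.grad is not None)  # True: the condition receives gradients
opt.step()

r_fixed = torch.randn(1, r_dim)  # plain tensor: autograd does not track it
(-cnf_log_prob(field, x, r_fixed).mean()).backward()
print(r_fixed.grad is None)  # True: no gradient reaches r_fixed
```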