
Framing model.forward more like ML libraries? #188

Closed
iancze opened this issue Apr 6, 2023 · 1 comment · Fixed by #248
iancze commented Apr 6, 2023

A common idiom used in many machine learning libraries is something like

```python
model.fit(X, y)
```

where X is the feature vector and y is the target vector. Then, the model is "trained" internally. This shows up very commonly in scikit-learn, for example.
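As a toy illustration of this idiom (using a hypothetical `MeanModel`, not an actual scikit-learn class), the point is that training state lives inside the object:

```python
class MeanModel:
    """Predicts the mean of the training targets, ignoring the features."""

    def fit(self, X, y):
        # "Training" happens internally; fitted state is stored on self.
        self._mean = sum(y) / len(y)
        return self  # scikit-learn estimators conventionally return self

    def predict(self, X):
        return [self._mean for _ in X]


model = MeanModel().fit([[0], [1], [2]], [1.0, 2.0, 3.0])
print(model.predict([[5]]))  # -> [2.0]
```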

In PyTorch, I usually see something like

```python
outputs = model(inputs)
```

where the loss function is computed on a separate line, rather than internally to the model (though in theory it could be, e.g., in a cross-validation runner). This separation makes it relatively easy to divide a dataset into several smaller batches of inputs and outputs and train on them one batch at a time.
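A minimal sketch of that batched training loop, in plain Python standing in for PyTorch (the model is a single parameter `w` in `y = w * x`, and the batch size, learning rate, and data are made up for illustration):

```python
def batches(xs, ys, size):
    """Yield successive mini-batches of (inputs, targets)."""
    for i in range(0, len(xs), size):
        yield xs[i:i + size], ys[i:i + size]


w = 0.0                       # single parameter of the model y = w * x
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]     # generated with true w = 2

for epoch in range(200):
    for xb, yb in batches(xs, ys, size=2):
        preds = [w * x for x in xb]                   # forward pass on one batch
        # gradient of the mean squared error with respect to w
        grad = sum(2 * (p - y) * x for p, y, x in zip(preds, yb, xb)) / len(xb)
        w -= 0.01 * grad                              # SGD step

print(round(w, 3))  # converges toward 2.0
```

Because the model accepts arbitrary-sized `inputs`, nothing about the loop changes when the batch size does.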

On the other hand, in MPoL we usually have something like (for example)

```python
rml = precomposed.SimpleNet(coords=coords, nchan=dset.nchan)
modelVisibilityCube = rml()
```

and then use a loss function to compare modelVisibilityCube to the data. This works, but it has always bugged me a little that we don't take in some X or inputs, which prevents us from training in batches. I think the problem is that there isn't an obvious one-to-one relationship between the number of input points and the number of output points, because of (a) the Fourier nature of the problem and (b) how data averaging ("gridding") affects the number of visibilities.

The Fourier nature of the problem means that you'll always need to populate a full image and then do the full FFT, even if you are comparing to a single data point. This means there isn't much time-saving for training on a smaller "batch" compared to the full batch. This is especially true for the GriddedDataset, and probably still applies in some fashion to the NuFFT.

If we made the NuFFT layer the default FourierTransformer, then we could take in u,v coordinates, such that we'd have

```python
modelVisibilities = model(us, vs)
```

and then these could be used in a loss function.
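As a toy illustration of why loose (u, v) inputs are possible at all (this is not MPoL's API; `model_visibilities` and its arguments are hypothetical), a direct Fourier sum evaluates the model at arbitrary baseline coordinates, which is the kind of operation a NuFFT layer performs efficiently:

```python
import cmath


def model_visibilities(image, ls, ms, us, vs):
    """Direct Fourier sum: image[k] is the sky brightness at (ls[k], ms[k])."""
    vis = []
    for u, v in zip(us, vs):
        vis.append(sum(I * cmath.exp(-2j * cmath.pi * (u * l + v * m))
                       for I, l, m in zip(image, ls, ms)))
    return vis


# At (u, v) = (0, 0) the visibility is just the total flux of the image.
image = [1.0, 2.0, 3.0]
ls, ms = [0.0, 0.1, 0.2], [0.0, -0.1, 0.1]
print(abs(model_visibilities(image, ls, ms, [0.0], [0.0])[0]))  # -> 6.0
```

Note that even here the whole image enters the sum for every single (u, v) point, which is the batching obstacle described above.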

But this doesn't make a ton of sense for the SimpleNet and the FourierLayer, which returns a modelVisibilityCube that needs to be indexed by the GriddedDataset.

Is this just a quirk of the nature of our problem, and I shouldn't lose sleep over not taking in an X? Or are we framing the network in some sub-optimal way? The GriddedDataset approach is accurate enough that I think there will be few applications where we want the training loop to function using the NuFFT directly. Rather, it's much more useful for predicting loose visibilities for visualization applications.

iancze commented Dec 22, 2023

We've made progress on adopting this idiom with the SGD work, also partially addressing #161.

To finish these issues, identify and merge relevant functionality to main, then document with a SGD tutorial.
