
Pyro integration #895

Merged

merged 23 commits into master from pyro on Feb 4, 2021

Conversation

@adamgayoso (Member) commented Jan 29, 2021

Initial implementation of a Pyro base module class, and a Pyro training plan (pytorch lightning class). See the example in the tests file I added.

@adamgayoso (Member, Author):

@vitkl, if you have a chance, it would be great if you could let me know if this code is general enough for your use case. I'm not super familiar with Pyro, but by having everything in a torch nn.Module (see test case), the high-level API's save/load, etc., will work automatically.

The thing to note about PyTorch Lightning is that you need to write device-agnostic code; the trainer handles moving things to the GPU for you. This code works on my GPU.

@vitkl (Contributor) commented Jan 29, 2021

Thanks @adamgayoso! We are currently quite busy with a paper revision. We (@yozhikoff and I) will try using the cell2location model with your code next week (hopefully) or the week after. Given your comments, getting this to work will probably need some tweaking.

Maybe I am missing something, but where is _CONSTANTS.X_KEY specified? E.g. how do I tell scVI what the arguments to the model are? a25ee28#diff-663073fe4ef3cfa8f3fad68a502f760ee33dbb94c3c803c50635f0d63f67ee30R28-R29

As far as I understand pyro plates, they need to be given the index from the data loader and other args (subsample_size=None, subsample=None, dim=None, http://docs.pyro.ai/en/0.3.0-release/primitives.html#pyro.plate, a25ee28#diff-663073fe4ef3cfa8f3fad68a502f760ee33dbb94c3c803c50635f0d63f67ee30R36) to work as expected, i.e. to normalise the loss by minibatch size and to select the correct indices for a local parameter. Without providing the data-loader index as subsample, the plate will sample random indices; in the model below it is important to get the correct indices for the local variable cell_norm.

We currently do this by providing indices as argument to the model / guide functions:

    def create_plates(self, x_data, idx, cell2sample, cell2covar):
        return [pyro.plate("obs_axis", self.n_obs, dim=-2,
                           subsample_size=self.minibatch_size,
                           subsample=idx),
                pyro.plate("var_axis", self.n_var, dim=-1),
                pyro.plate("factor_axis", self.n_fact, dim=-2),
                pyro.plate("experim_axis", self.n_experim, dim=-2)]

    def model(self, x_data, idx, cell2sample, cell2covar):

        obs_axis, var_axis, factor_axis, experim_axis = self.create_plates(x_data, idx, cell2sample, cell2covar)

        with var_axis, factor_axis:
            gene_loadings_fg = pyro.sample('gene_loadings_fg',
                                           dist.Gamma(torch.ones([self.n_fact, self.n_var]),
                                                      torch.ones([self.n_fact, self.n_var])))
        with var_axis:
            gene_alpha_g = pyro.sample('gene_alpha_g', dist.Exponential(torch.ones([1, self.n_var])))

        with experim_axis:
            detection_mean = pyro.sample('detection_mean',
                                         dist.Gamma(torch.ones([self.n_experim, 1]),
                                                    torch.ones([self.n_experim, 1])))
        with obs_axis as ind:
            cell_norm = pyro.sample('cell_norm',
                                    dist.Gamma(200, 200 / torch.mm(cell2sample[ind], detection_mean)))

        with var_axis:
            with obs_axis as ind:
                mu_biol = torch.mm(cell2covar, gene_loadings_fg)
                total_count, logits = _convert_mean_disp_to_counts_logits(mu_biol, gene_alpha_g,
                                                                          eps=1e-8)  # from scVI
                data_target = pyro.sample('data_target',
                                          dist.NegativeBinomial(total_count=total_count, logits=logits),
                                          obs=x_data)  # I did not manage to make pyro work with the scVI NegativeBinomial distribution class

However, we are still figuring out how to use pyro plates correctly and need to do more testing.

Do failed checks mean that this code does not work?

@adamgayoso (Member, Author):

> Do failed checks mean that this code does not work?

Yes, but the issue is just that I need to tell it not to use the GPU, as I tested it locally with a GPU. I'll fix that.

> Maybe I am missing something but where is _CONSTANTS.X_KEY specified?

This is definitely opaque. What happens is that setup_anndata() adds a data registry to the anndata object itself, which is then used by our data loader. So our AnnDataLoader returns a dictionary whose keys are accessible via scvi._CONSTANTS and whose values are the minibatched tensors that were registered originally. What Romain did for Stereoscope was to add an obs column containing the index of each cell and register it, so that it could be used during training. You can see this in the init of the spatial model.
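To make the pattern concrete, here is a dependency-free sketch of the dict-per-minibatch idea described above. The key names and the minibatches helper are hypothetical stand-ins, not the real scvi._CONSTANTS or AnnDataLoader API:

```python
# Hypothetical stand-ins for the registry keys; the real names live in
# scvi._CONSTANTS and are wired up by setup_anndata().
X_KEY = "X"
IND_X_KEY = "ind_x"  # the registered per-cell index column

def minibatches(X, batch_size):
    """Yield dicts of minibatched data plus each cell's original index,
    mimicking a data loader that returns key -> tensor dictionaries."""
    for start in range(0, len(X), batch_size):
        idx = list(range(start, min(start + batch_size, len(X))))
        yield {X_KEY: [X[i] for i in idx], IND_X_KEY: idx}
```

With the index column in every batch, a model can pass it to pyro.plate as subsample so local variables line up with the right cells.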

And going back to your original comment about mixing autoguides with NN encoders -- that's not something that's really in our control as much as it is a question of whether you can do that in Pyro easily.

codecov bot commented Jan 29, 2021

Codecov Report

Merging #895 (2f9f8a6) into master (83e312e) will decrease coverage by 0.04%.
The diff coverage is 79.16%.


@@            Coverage Diff             @@
##           master     #895      +/-   ##
==========================================
- Coverage   89.27%   89.23%   -0.05%     
==========================================
  Files          73       73              
  Lines        5604     5647      +43     
==========================================
+ Hits         5003     5039      +36     
- Misses        601      608       +7     
Impacted Files Coverage Δ
scvi/lightning/__init__.py 100.00% <ø> (ø)
scvi/model/base/_base_model.py 89.90% <36.36%> (-3.03%) ⬇️
scvi/lightning/_trainingplans.py 94.39% <88.88%> (-0.85%) ⬇️
scvi/compose/__init__.py 100.00% <100.00%> (ø)
scvi/compose/_base_module.py 98.63% <100.00%> (+0.16%) ⬆️
scvi/data/_anntorchdataset.py 94.23% <0.00%> (+1.92%) ⬆️
scvi/dataloaders/_ann_dataloader.py 87.75% <0.00%> (+4.08%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 83e312e...2f9f8a6.

@romain-lopez (Member):

I added a simple version of scVI with Pyro. I think this is enough to show that our API makes it easy to implement models with Pyro as well! Once this is merged, we can add a skeleton in pyro to make this integration more clear!

@romain-lopez (Member):

what's with the codacy?

@vitkl (Contributor) commented Jan 31, 2021 via email

@vitkl (Contributor) commented Jan 31, 2021

With AnnotationDataLoader or otherwise, is it possible to load and use the whole dataset in each training iteration rather than mini-batches?
E.g. here: https://github.com/YosefLab/scvi-tools/blob/e90dacd7485e5342badb02c791d977fd9e73609d/tests/models/test_pyro.py#L152

@romain-lopez (Member):

It should be easy to change the PyTorch data loader behavior to send the full dataset at each iteration (maximal batch size and un-randomized sampling of indices).

@adamgayoso (Member, Author) commented Feb 1, 2021

> it should be easy to change the pytorch data loader behavior to send the full dataset at each iteration (maximal batch size, and un-randomized sampling of indices)

The only caveat is that it would move the data onto CUDA at each iteration -- full-batch data loading on GPU is new functionality we need to add.

> Is "ind_x" in this line an index in the minibatch or is it a column in st_adata.obs?

@vitkl it's actually both: before training we make an obs column from this 1d index array, and it gets loaded by the data loader, so in each minibatch we know which original index each cell had.

@vitkl (Contributor) commented Feb 1, 2021 via email

@adamgayoso (Member, Author) commented Feb 1, 2021

I'm not really sure how these autoguides work, but it's not happy on load, as when it's initialized it has no parameters. I think its parameters get initialized after some data passes through?

> is it possible to load full data

Yes, but it will be slower if using a GPU; just set the batch size to the full dataset size and shuffle to False.
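The suggestion above can be sketched with a plain PyTorch DataLoader (illustrative, not the scvi-tools loader):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset of 10 "cells".
ds = TensorDataset(torch.arange(10, dtype=torch.float32))

# Full-batch "loading": the batch size equals the dataset size and
# shuffling is off, so each epoch yields exactly one batch containing
# every example in its original order.
loader = DataLoader(ds, batch_size=len(ds), shuffle=False)

batches = list(loader)
```

As noted above, with a GPU this moves the whole dataset onto the device at every iteration, which is why dedicated full-batch GPU loading is a separate feature.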

@adamgayoso (Member, Author):

@galenxing @romain-lopez I think I figured it out -- if using AutoGuides (which, from my understanding, are like ADVI?), you have to run some data through the model so the guides get their torch params, and THEN you can load from the state dict. Otherwise the guides have no torch params (I believe Pyro needs to trace through the model to infer the guide, and this requires data).

In other words, I think we need to do the following

  1. If the underlying module is Pyro and there are AutoGuides, run the model's train method for 1 training step just before loading the state dict (might need to clear the param store first, too?)

There are probably a few good ways we can detect the Pyro/autoguide prereq.

@vitkl (Contributor) commented Feb 1, 2021 via email

@romain-lopez (Member):

@fritzo @jamestwebber
With our incoming major refactor, we are exploring possibilities for Pyro support. In particular, I have implemented a simple version of scVI here, based on your VAE tutorial, as well as your scanvi reimplementation. If this is something interesting to you, it would be great to have some feedback / suggestions.

@martinjankowiak:
@romain-lopez exciting! in case it helps i implemented scanvi here. also see here

Review by @fritzo:

Looks great, glad to see you're trying out Pyro 😄

I have only a couple warnings about sharp edges like PyroModule requirements. Let me know if you have any questions or want to chat (Pyro slack, PyTorch slack, zoom, ...).

    self,
    pyro_module: PyroBaseModuleClass,
    lr: float = 1e-3,
    loss_fn: Callable = pyro.infer.Trace_ELBO(),
@fritzo:

Warning: it's dangerous to set stateful default arguments. In this case the same default Trace_ELBO instance will be shared by all models and (details below) bad things may happen. I'd recommend defaulting to None and doing a standard if loss_fn is None: loss_fn = pyro.infer.Trace_ELBO().

Each Trace_ELBO instance guesses and stores the number of plates in its model, and assumes it will be associated with only a single model. If you use it with a different model with a different number of plates, you might see tensor shape errors.
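The hazard generalizes to any stateful default argument. A dependency-free sketch (Loss here is a stand-in for Trace_ELBO, not the real Pyro class):

```python
class Loss:
    """Stand-in for a stateful loss such as Trace_ELBO, which guesses
    and caches plate information for the first model it is used with."""
    def __init__(self):
        self.max_plate_nesting = None  # cached per model

def make_plan_shared(loss_fn=Loss()):
    # The default is evaluated ONCE at function definition time, so every
    # call without an argument receives the same Loss instance -- state
    # cached for one model leaks into all others.
    return loss_fn

def make_plan_fresh(loss_fn=None):
    # Recommended pattern: default to None, construct inside the body,
    # so each model gets its own independent loss object.
    if loss_fn is None:
        loss_fn = Loss()
    return loss_fn
```

Calling make_plan_shared() twice returns the same object, while make_plan_fresh() returns a new one each time, which is exactly the fix suggested above.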

Comment on lines 18 to 22
    class BayesianRegression(PyroModule, PyroBaseModuleClass):
        def __init__(self, in_features, out_features):
            super().__init__()

            self._auto_guide = AutoDiagonalNormal(self)
@fritzo commented Feb 1, 2021:
Technical detail: one of the rules of PyroModules is that the model and guide must be separate PyroModules and cannot be contained in a single PyroModule; however, they can be contained in a single nn.Module. You might consider refactoring to separate PyroBaseModuleClass from some sort of PyroModelClass like this (where s <: t means issubclass(s, t)):

PyroBaseModuleClass <: nn.Module
  .model : PyroBaseModelClass <: PyroModule
  .guide : AutoNormal <: PyroModule

Basically "a model can't have its guide as an attribute".

The reason is due to PyroModule naming schemes and caching of pyro.sample calls. If model and guide are contained in a single PyroModule there may be weird conflicts in both names and the pyro.sample cache.

@adamgayoso (Member, Author):

this is super helpful!

        # score against actual counts
        pyro.sample("obs", x_dist.to_event(1), obs=x)

    @pyro_method
@fritzo:
See note above: a PyroModule model cannot own its guide; you'll need an outer nn.Module to contain both.

@romain-lopez (Member):

Wonderful! Thanks @martinjankowiak and @fritzo! We will be in touch if we have more questions!

@adamgayoso (Member, Author):

Also @vitkl, the API changed a bit; please take a look at the Bayesian regression example. The wrapper nn.Module class needs to have model and guide attributes.

    try:
        model.module.load_state_dict(model_state_dict)
    except RuntimeError as err:
        if isinstance(model.module, PyroBaseModuleClass):
Member:
nice!

    # sets a prior over the weight vector
    # self.linear.weight = PyroSample(
    #     dist.Normal(self.zero, self.one)
    #     .expand([out_features, in_features])
Member:
should we keep those, uncomment, or delete?

Review by @romain-lopez (Member):
sounds great!

@adamgayoso adamgayoso changed the title [WIP]: Pyro integration Pyro integration Feb 4, 2021
@adamgayoso adamgayoso merged commit c1b069f into master Feb 4, 2021
@adamgayoso adamgayoso deleted the pyro branch February 4, 2021 01:42

Successfully merging this pull request may close these issues.

interoperability with pyro to simplify implementing complex hierarchical priors?
5 participants