-
Notifications
You must be signed in to change notification settings - Fork 0
Conversation
I fixed everything but the device at the end in my PRs |
src/scvi_v2/_model.py
Outdated
jit_inference_fn = self.module.get_jit_inference_fn( | ||
inference_kwargs={"mc_samples": mc_samples, "cf_sample": cf_sample} | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm pretty sure this will cause the fn to recompile each time, you kind of want a data structure containing each possible inference fn (one per value of sample).
Otherwise you'd want jax to trace over cf_sample, which we can also do
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think there should be an api here which allows for traced values, (i.e. pass kwargs into the inference_fn, not the get_jit_inference_fn)
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #1 +/- ##
=======================================
Coverage ? 92.89%
=======================================
Files ? 6
Lines ? 394
Branches ? 0
=======================================
Hits ? 366
Misses ? 28
Partials ? 0
Flags with carried forward coverage won't be shown. Click here to find out more. |
Tests pass using main branch of scvi-tools.
Exact port of torch code with only linear decoder options. Will build off of this.