Skip to content
This repository has been archived by the owner on Oct 9, 2024. It is now read-only.

Jax Port of MrVI #1

Merged
merged 11 commits into from
Oct 25, 2022
Merged

Jax Port of MrVI #1

merged 11 commits into from
Oct 25, 2022

Conversation

justjhong
Copy link
Collaborator

@justjhong justjhong commented Oct 18, 2022

Tests pass using main branch of scvi-tools.

Exact port of torch code with only linear decoder options. Will build off of this.

@adamgayoso
Copy link
Member

I fixed everything but the device at the end in my PRs

@justjhong justjhong marked this pull request as ready for review October 19, 2022 02:14
Comment on lines 211 to 213
jit_inference_fn = self.module.get_jit_inference_fn(
inference_kwargs={"mc_samples": mc_samples, "cf_sample": cf_sample}
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm pretty sure this will cause the fn to recompile each time, you kind of want a data structure containing each possible inference fn (one per value of sample).

Otherwise you'd want jax to trace over cf_sample, which we can also do

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think there should be an api here which allows for traced values, (i.e. pass kwargs into the inference_fn, not the get_jit_inference_fn)

@codecov
Copy link

codecov bot commented Oct 25, 2022

Codecov Report

❗ No coverage uploaded for pull request base (main@e4d68c5). Click here to learn what that means.
The diff coverage is n/a.

Additional details and impacted files
@@           Coverage Diff           @@
##             main       #1   +/-   ##
=======================================
  Coverage        ?   92.89%           
=======================================
  Files           ?        6           
  Lines           ?      394           
  Branches        ?        0           
=======================================
  Hits            ?      366           
  Misses          ?       28           
  Partials        ?        0           
Flag Coverage Δ
unittests 92.89% <0.00%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

@justjhong justjhong merged commit 829079c into main Oct 25, 2022
@justjhong justjhong deleted the jhong/mrvijax branch October 25, 2022 15:29
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants