Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LDA implementation #1132

Merged
merged 56 commits into from Sep 23, 2021
Merged

LDA implementation #1132

merged 56 commits into from Sep 23, 2021

Conversation

justjhong
Copy link
Contributor

@justjhong justjhong commented Sep 2, 2021

Autoencoder Variational Bayes implementation of Latent Dirichlet Allocation as a PyroModule. Runs in the same magnitude of time as sklearn's implementation of LDA and achieves better perplexity on average. See benchmark here: https://colab.research.google.com/drive/1Iq_drlBTLadM8KJtZwIs96RdzG8MmEFA?authuser=1#scrollTo=lINSWBVshwvr.

Note: to have this work with reparametrization stably, I used a logistic normal to approximate the dirichlet distribution.

@codecov
Copy link

codecov bot commented Sep 2, 2021

Codecov Report

Merging #1132 (3ccbabb) into master (3ce5514) will increase coverage by 0.29%.
The diff coverage is 97.11%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #1132      +/-   ##
==========================================
+ Coverage   90.27%   90.56%   +0.29%     
==========================================
  Files          93       95       +2     
  Lines        7341     7518     +177     
==========================================
+ Hits         6627     6809     +182     
+ Misses        714      709       -5     
Impacted Files Coverage Δ
scvi/model/base/_base_model.py 92.26% <83.33%> (-0.31%) ⬇️
scvi/train/_trainingplans.py 95.88% <94.73%> (+0.46%) ⬆️
scvi/model/_amortizedlda.py 96.22% <96.22%> (ø)
scvi/module/_amortizedlda.py 98.30% <98.30%> (ø)
scvi/dataloaders/_anntorchdataset.py 93.44% <100.00%> (-0.11%) ⬇️
scvi/external/stereoscope/_model.py 92.85% <100.00%> (+3.02%) ⬆️
scvi/model/__init__.py 100.00% <100.00%> (ø)
scvi/model/_destvi.py 87.00% <100.00%> (+2.38%) ⬆️
scvi/model/_scanvi.py 93.91% <100.00%> (-0.06%) ⬇️
scvi/model/_totalvi.py 84.44% <100.00%> (+0.23%) ⬆️
... and 7 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 3ce5514...3ccbabb. Read the comment docs.

@justjhong justjhong force-pushed the jhong/vanilla-lda branch 3 times, most recently from 780c4eb to 590d4f4 Compare September 10, 2021 18:15
@justjhong justjhong marked this pull request as ready for review September 10, 2021 19:10
scvi/model/_lda.py Outdated Show resolved Hide resolved
scvi/model/_lda.py Outdated Show resolved Hide resolved
scvi/model/_lda.py Outdated Show resolved Hide resolved
scvi/model/_lda.py Outdated Show resolved Hide resolved
scvi/model/_lda.py Outdated Show resolved Hide resolved
scvi/module/_lda.py Outdated Show resolved Hide resolved
scvi/module/_lda.py Outdated Show resolved Hide resolved
scvi/model/_lda.py Outdated Show resolved Hide resolved
scvi/model/_lda.py Outdated Show resolved Hide resolved
tests/models/test_pyro.py Show resolved Hide resolved
scvi/train/_trainingplans.py Outdated Show resolved Hide resolved
@adamgayoso adamgayoso added this to the 0.14.0 milestone Sep 22, 2021
Copy link
Member

@adamgayoso adamgayoso left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM after the last small change



def logistic_normal_approximation(
alpha: torch.Tensor,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a comment/docstring about what's going on here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also realized that the sigma here was the variance when I should be returning the std, don't think it actually changes a lot though

@justjhong justjhong merged commit 9b888b8 into master Sep 23, 2021
@justjhong justjhong deleted the jhong/vanilla-lda branch September 23, 2021 04:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants