poisson-topicmodels: Probabilistic Topic Modeling with Bayesian Inference


poisson-topicmodels is a modern Python package for probabilistic topic modeling using Bayesian inference, built on JAX and NumPyro.

Package documentation

Full package documentation is available here.

Statement of Need

Traditional topic modeling packages (e.g., Gensim, scikit-learn's LDA) use older inference methods and lack flexibility for emerging research needs. poisson-topicmodels addresses key gaps:

  1. Modern Probabilistic Inference: Built on NumPyro, enabling automatic differentiation, probabilistic programming, and integration with cutting-edge Bayesian methods.

  2. Advanced Topic Models: Goes beyond LDA with guided topic discovery (keyword priors), covariate effects, ideal point estimation, and embeddings—all with principled Bayesian inference.

  3. GPU Acceleration: Leverages JAX for transparent GPU computation, essential for large-scale corpus analysis and enabling research that would be prohibitively slow on CPU.

  4. Scalability & Reproducibility: Optimized for mini-batch SVI training with built-in seed control for exact reproducibility—critical for research validation and publication.

  5. Research-Friendly API: Purpose-built for computational social science and NLP researchers who need interpretable, flexible models beyond black-box approaches.

Whether analyzing legislative text, social media discourse, or scientific abstracts, poisson-topicmodels enables researchers to extract interpretable semantic structure with confidence in results.

Features

poisson-topicmodels provides multiple topic modeling approaches:

| Model | Use Case | Key Feature |
| --- | --- | --- |
| Poisson Factorization (PF) | Unsupervised baseline | Fast, interpretable word-topic associations |
| Seeded PF (SPF) | Guided discovery | Incorporate domain knowledge via keyword priors |
| Covariate PF (CPF) | Covariate effects | Model topics influenced by document metadata |
| Covariate Seeded PF (CSPF) | Guided + covariates | Combine keyword guidance with external factors |
| Text-Based Ideal Points (TBIP) | Ideal point estimation | Estimate author positions from legislative/social text |
| Structured Text-Based Scaling (STBS) | Topic-specific positions + author covariates | Estimate topic-specific ideal points and covariate-driven ideology shifts |
| Embedded Topic Models (ETM) | Modern embeddings | Integrate pre-trained word embeddings |

Core Capabilities:

  • ✨ Stochastic Variational Inference (SVI) with mini-batch training
  • ✨ Transparent GPU acceleration via JAX
  • ✨ Reproducible results with seed control
  • ✨ Type hints and comprehensive API documentation
  • ✨ >70% test coverage with continuous integration
  • ✨ Clear error messages and input validation
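The combination of mini-batch SVI and seed control can be illustrated with a small NumPy sketch. This is illustrative only; the package's internal batching logic may differ, but the principle is the same: a seeded generator yields the identical batch sequence on every run, which is what makes training exactly reproducible.

```python
import numpy as np

def minibatch_indices(num_docs, batch_size, num_steps, seed):
    """Yield one random batch of document indices per SVI step (illustrative)."""
    rng = np.random.default_rng(seed)
    for _ in range(num_steps):
        yield rng.choice(num_docs, size=batch_size, replace=False)

# The same seed reproduces the exact same batch sequence.
run1 = list(minibatch_indices(num_docs=100, batch_size=32, num_steps=3, seed=42))
run2 = list(minibatch_indices(num_docs=100, batch_size=32, num_steps=3, seed=42))
assert all(np.array_equal(a, b) for a, b in zip(run1, run2))
```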

Quick Start

Get started in 5 minutes:

import numpy as np
from scipy.sparse import csr_matrix
from poisson_topicmodels import PF

# Prepare data: document-term matrix and vocabulary
counts = csr_matrix(np.random.poisson(2, (100, 500)).astype(np.float32))
vocab = np.array([f'word_{i}' for i in range(500)])

# Initialize and train model
model = PF(counts, vocab, num_topics=10, batch_size=32)
params = model.train_step(num_steps=100, lr=0.01, random_seed=42)

# Extract results
topics, _ = model.return_topics()
top_words = model.return_top_words_per_topic(n=10)
print(f"Topics shape: {topics.shape}")
print(f"Top words: {top_words[:3]}")

See examples/ directory for detailed notebooks.

Installation

From PyPI (recommended)

pip install poisson-topicmodels

GPU installs (opt-in)

Automatic GPU detection at install time is not reliable across macOS/Windows/Linux/cloud runtimes. Use explicit install targets:

NVIDIA GPU (Linux x86_64/aarch64, CUDA 12)

pip install "poisson-topicmodels[gpu-cuda12]"

Apple Silicon GPU (Metal)

pip install "poisson-topicmodels[gpu-metal]"

Run with:

JAX_PLATFORMS=METAL python your_script.py

AMD GPU (ROCm)

Install the package first, then follow the official JAX AMD instructions: JAX AMD GPU install guide. JAX's AMD install uses ROCm plugin wheels and environment-specific commands, so it is not encoded as a generic PyPI extra.

Other GPU Installations

For other setups, see the manual JAX installation guide.

From Source

git clone https://github.com/BPro2410/poisson_topicmodels.git
cd poisson_topicmodels
pip install -e .

Development Setup

git clone https://github.com/BPro2410/poisson_topicmodels.git
cd poisson_topicmodels
pip install -e ".[dev]"
pytest tests/  # Verify installation

Requirements

  • Python ≥ 3.11
  • JAX 0.4.35 (GPU support via optional install targets above)
  • NumPyro ≥ 0.15.3
  • NumPy, SciPy, scikit-learn, pandas

See pyproject.toml for complete dependency list.

Documentation

Basic Usage Examples

1. Unsupervised Topic Discovery (PF)

from poisson_topicmodels import PF

model = PF(counts, vocab, num_topics=10, batch_size=64)
model.train_step(num_steps=500, lr=0.001, random_seed=42)

# Extract topics
topics, topic_probs = model.return_topics()
top_words = model.return_top_words_per_topic(n=15)

2. Guided Topic Modeling with Keywords (SPF)

from poisson_topicmodels import SPF

keywords = {
    0: ['climate', 'environment', 'carbon'],
    1: ['economy', 'growth', 'trade'],
}

model = SPF(counts, vocab, keywords, residual_topics=5, batch_size=64)
model.train_step(num_steps=500, lr=0.001, random_seed=42)
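Seed keywords only help if they actually occur in the vocabulary. A small pre-flight check (plain Python, not part of the package API; `vocab` and `keywords` here are toy data) catches typos before constructing SPF:

```python
import numpy as np

vocab = np.array(["climate", "carbon", "economy", "growth", "trade", "environment"])
keywords = {
    0: ["climate", "environment", "carbon"],
    1: ["economy", "growth", "tarde"],  # note the typo
}

# Collect, per topic, any seed words absent from the vocabulary.
vocab_set = set(vocab)
missing = {
    topic: [w for w in words if w not in vocab_set]
    for topic, words in keywords.items()
}
missing = {topic: words for topic, words in missing.items() if words}
if missing:
    print(f"Keywords not in vocabulary: {missing}")
```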

3. Covariate Effects (CPF)

from poisson_topicmodels import CPF

# Include document-level covariates
covariates = np.random.randn(100, 3)  # 100 documents, 3 covariates

model = CPF(counts, vocab, covariates, num_topics=10, batch_size=64)
model.train_step(num_steps=500, lr=0.001, random_seed=42)
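For covariate models it is common practice to standardize continuous covariates so that effect sizes are comparable across columns; whether CPF requires this is not specified here, so treat it as a generic preprocessing sketch in plain NumPy:

```python
import numpy as np

rng = np.random.default_rng(0)
covariates = rng.normal(loc=5.0, scale=3.0, size=(100, 3))

# Z-score each covariate column: zero mean, unit standard deviation.
means = covariates.mean(axis=0)
stds = covariates.std(axis=0)
covariates_std = (covariates - means) / stds
```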

Custom Model Extension

Thanks to its modular structure, poisson-topicmodels makes it easy to implement custom models. Here is a short example:

from poisson_topicmodels import NumpyroModel
import numpyro
from numpyro import plate, sample
import numpyro.distributions as dist

class MyModel(NumpyroModel):
    def _model(self, Y_batch, d_batch):
        with plate("n", len(Y_batch)):
            mu = sample("mu", dist.Normal(0, 1))
            sample("obs", dist.Normal(mu, 1), obs=Y_batch)

    def _guide(self, Y_batch, d_batch):
        mu_loc = numpyro.param("mu_loc", 0.0)
        mu_scale = numpyro.param("mu_scale", 1.0, constraint=dist.constraints.positive)
        with plate("n", len(Y_batch)):
            sample("mu", dist.Normal(mu_loc, mu_scale))

To implement a custom model, you only need to define the high-level model and guide; the poisson-topicmodels backbone handles training and inference.

Example Data

The repository includes data/10k_amazon.csv with ~10,000 Amazon product reviews for quick experimentation. See examples/01_getting_started.py for a complete walkthrough.

Docker Setup (Optional)

For a reproducible, isolated environment with JupyterLab:

# Build image
docker build -t topicmodels-jupyter .

# Run container (Linux/macOS)
docker run --rm -p 8888:8888 -v "$(pwd)":/workspace topicmodels-jupyter

# Then open http://localhost:8888 in your browser

Citation

If you use poisson_topicmodels in your research, please cite:

@software{topicmodels2026,
  title = {Poisson-topicmodels: Probabilistic Topic Modeling with Bayesian Inference},
  author = {Prostmaier, Bernd and Grün, Bettina and Hofmarcher, Paul},
  year = {2026},
  url = {https://github.com/BPro2410/poisson_topicmodels},
}

See CITATION.cff for additional citation formats.

Contributing

Contributions are welcome! Please see CONTRIBUTING.md for guidelines on:

  • Reporting bugs
  • Submitting pull requests
  • Code style and testing requirements
  • Documentation standards

License

This project is licensed under the MIT License. See LICENSE for details.

Built with ❤️ for researchers and practitioners in computational social science and NLP
