In [1]:
import pyro
import pyro.distributions as dist
import torch
import numpy as np
import matplotlib.pyplot as plt
import pyro.infer.mcmc as mcmc

In [None]:
def poisson_factorization(data, latent_dim, concentration=1.0, rate=1.0):
    """Hierarchical Poisson matrix-factorization model for Pyro inference.

    Models an ``(n, p)`` count matrix as ``X ~ Poisson(F @ G)`` with
    non-negative factors ``F`` of shape ``(n, latent_dim)`` and ``G`` of shape
    ``(latent_dim, p)``. Every entry of ``F`` / ``G`` has a Gamma prior whose
    rate is itself Gamma-distributed (``lambda_f`` for ``F``, ``lambda_g`` for
    ``G``), so the sample sites are ``lambda_f``, ``lambda_g``, ``F``, ``G``
    and the observed site ``X``.

    Args:
        data: 2-D array-like of non-negative counts (NumPy array or tensor).
        latent_dim: Number of latent factors ``k``; a fixed hyperparameter.
        concentration: Gamma shape used for the rate hyperpriors and for the
            factor entries. With the default 1.0 each factor entry is
            Exponential given its rate.
        rate: Gamma rate of the ``lambda_f`` / ``lambda_g`` hyperpriors.

    Returns:
        The value of the observed site ``X`` (``data`` as a float tensor).
    """
    # as_tensor avoids the copy-construct warning torch.tensor emits when
    # `data` is already a tensor; a floating dtype matches the float Poisson rate.
    obs = torch.as_tensor(data, dtype=torch.get_default_dtype())
    n, p = obs.shape

    # Rate hyperpriors, shaped like the factor they parameterise so that each
    # entry of F / G is drawn with its own rate (and actually enters the model).
    lambda_f = pyro.sample(
        "lambda_f", dist.Gamma(concentration, rate).expand([n, latent_dim])
    )
    lambda_g = pyro.sample(
        "lambda_g", dist.Gamma(concentration, rate).expand([latent_dim, p])
    )

    # Non-negative factors; batch shapes come from broadcasting against the rates.
    F = pyro.sample("F", dist.Gamma(concentration, lambda_f))
    G = pyro.sample("G", dist.Gamma(concentration, lambda_g))

    # Likelihood: each count is Poisson with rate (F @ G)[i, j].
    X = pyro.sample("X", dist.Poisson(F @ G), obs=obs)
    return X

# Define model

# The latent dimension k is a modelling hyperparameter: it is fixed here and
# not inferred by MCMC. To choose it, compare fits across several values.
latent_dim = 2
model = poisson_factorization

# Seed torch, numpy and `random` (pyro.set_rng_seed covers all three) so the
# fake data and the MCMC chain are reproducible under Restart & Run All.
SEED = 0
pyro.set_rng_seed(SEED)

# Generate fake count data
n, p = 100, 50
counts = np.random.poisson(5., size=(n, p))

# Convert to a floating-point tensor once, matching the float Poisson rate.
data = torch.as_tensor(counts, dtype=torch.get_default_dtype())

# Run MCMC
num_samples = 1000
# NOTE(review): 100 warmup steps is short for NUTS over several hundred
# positive-valued latents; Pyro's default is warmup_steps = num_samples.
warmup_steps = 100
kernel = mcmc.NUTS(model)
mcmc_run = mcmc.MCMC(kernel, num_samples=num_samples, warmup_steps=warmup_steps)

# Run MCMC on the data with the chosen latent dimension
mcmc_run.run(data, latent_dim)

In [4]:
# Extract posterior samples
posterior_samples = mcmc_run.get_samples()

# Per-site draws; the leading dim indexes MCMC samples, so
# F_samples is (num_samples, n, latent_dim) and G_samples is (num_samples, latent_dim, p).
F_samples = posterior_samples["F"]
G_samples = posterior_samples["G"]
lambda_f_samples = posterior_samples["lambda_f"]
lambda_g_samples = posterior_samples["lambda_g"]

# Plot pooled posterior draws of each factor column F[:, i] and factor row G[i, :]
# (titles use 1-based factor numbering).
# squeeze=False keeps `axs` 2-D even when latent_dim == 1; F and G live on
# different scales, so only panels in the same column share an x-axis.
fig, axs = plt.subplots(latent_dim, 2, figsize=(12, 6), sharex="col", squeeze=False)
for i in range(latent_dim):
    # Index the latent-factor axis (not observation i / MCMC draw i) and flatten,
    # so matplotlib draws one histogram rather than one per array column.
    f_draws = F_samples[:, :, i].detach().reshape(-1).cpu().numpy()
    g_draws = G_samples[:, i, :].detach().reshape(-1).cpu().numpy()

    axs[i, 0].hist(f_draws, bins=50)
    axs[i, 0].set_title(f"F[:, {i+1}]")
    axs[i, 0].set(xlabel="value", ylabel="count")

    axs[i, 1].hist(g_draws, bins=50)
    axs[i, 1].set_title(f"G[{i+1}, :]")
    axs[i, 1].set(xlabel="value", ylabel="count")
plt.tight_layout()
plt.show()

  
Warmup:   1%|          | 22/2000 [00:08,  1.81it/s, step size=2.43e-02, acc. prob=0.750]

KeyboardInterrupt: 