In [2]:
import pyro
import pyro.distributions as dist
import torch
import numpy as np
import matplotlib.pyplot as plt
import pyro.infer.mcmc as mcmc

In [1]:
def poisson_factorization(data, latent_dim):
    # Define model parameters
    n, p = data.shape
    
    #Construct samples of F, G, and X. X is assumed to be poisson distribution of F * G
    F = pyro.sample("F", dist.Gamma(1., 1.).expand([n, latent_dim]))
    G = pyro.sample("G", dist.Gamma(1., 1.).expand([latent_dim, p]))
    
    
    # Define masking function to hide lower right corner of the data
    mask = torch.ones_like(data)
    mask[-20:, -2:] = 0.
    mask = mask.bool()
    
    # Observe the observed entries of X
    pyro.sample("X_observed", dist.Poisson((F @ G)[mask]), obs=torch.tensor(data[mask]))
    
    
    
    X = F @ G
    return X

# Define model

#Is our latent dimension here also deterined as a hyperparameter? 
latent_dim = 2
model = poisson_factorization

# Generate fake data
n, p = 50, 5
data = np.random.poisson(5., size=(n, p)) + 500

# Convert data to PyTorch tensor
data = torch.tensor(data)

# Run MCMC
num_samples = 1000
warmup_steps = 1000
kernel = mcmc.NUTS(model)
mcmc_run = mcmc.MCMC(kernel, num_samples=num_samples, warmup_steps=warmup_steps)

#Run MCMC process on our data with given Latent dimension
mcmc_run.run(data, latent_dim)

NameError: name 'np' is not defined

In [28]:
# Extract posterior samples
posterior_samples = mcmc_run.get_samples()

# Extract F, G, lambda_f, and lambda_g samples
F_samples = posterior_samples["F"]
G_samples = posterior_samples["G"]


# Plot posterior samples of F and G
# fig, axs = plt.subplots(latent_dim, 2, figsize=(12, 6), sharex=True)
# for i in range(latent_dim):
#     axs[i, 0].hist(F_samples[:, i, :], bins=50)
#     axs[i, 0].set_title(f"F[:, {i+1}]")
#     axs[i, 1].hist(G_samples[i, :, :], bins=50)
#     axs[i, 1].set_title(f"G[{i+1}, :]")
# plt.tight_layout()
# plt.show()

F_samples[-1] @ G_samples[-1]

tensor([[515.2449, 504.5264, 503.0182, 500.1896, 502.1837],
        [492.8720, 488.6221, 490.4343, 488.7366, 491.2715],
        [501.2274, 497.2758, 499.3195, 497.6552, 500.2719],
        [505.5141, 502.0925, 504.4592, 502.8754, 505.5734],
        [501.7811, 496.6788, 498.1031, 496.2445, 498.7442],
        [488.7137, 485.5200, 487.8701, 486.3580, 488.9782],
        [493.0889, 488.5482, 490.2044, 488.4574, 490.9633],
        [504.1041, 500.4686, 502.7076, 501.0907, 503.7578],
        [524.2792, 518.2939, 519.4275, 517.3757, 519.9191],
        [491.9086, 487.2018, 488.7582, 486.9857, 489.4672],
        [514.0132, 509.4611, 511.2859, 509.4953, 512.1265],
        [518.5750, 512.3624, 513.3251, 511.2466, 513.7318],
        [516.6780, 512.5861, 514.6827, 512.9640, 515.6594],
        [520.2710, 514.5070, 515.7271, 513.7204, 516.2628],
        [527.9942, 522.1348, 523.3676, 521.3295, 523.9086],
        [494.0881, 484.8686, 483.9963, 481.4617, 483.4846],
        [498.9581, 495.4121, 497.6568, 4

In [29]:
data

tensor([[507, 504, 504, 505, 507],
        [503, 506, 502, 510, 508],
        [501, 504, 505, 502, 504],
        [504, 506, 508, 506, 503],
        [505, 504, 506, 504, 506],
        [504, 505, 503, 504, 507],
        [504, 507, 501, 507, 501],
        [501, 504, 505, 507, 507],
        [500, 506, 504, 504, 508],
        [507, 511, 500, 502, 506],
        [504, 505, 507, 500, 503],
        [507, 507, 503, 504, 503],
        [504, 505, 505, 506, 506],
        [506, 502, 505, 506, 507],
        [508, 503, 508, 504, 508],
        [502, 506, 502, 505, 504],
        [503, 506, 503, 507, 504],
        [503, 504, 504, 503, 508],
        [503, 506, 502, 506, 508],
        [505, 506, 503, 504, 503],
        [502, 508, 506, 505, 505],
        [502, 502, 505, 506, 504],
        [503, 506, 504, 503, 504],
        [507, 506, 505, 501, 507],
        [506, 503, 508, 504, 503],
        [505, 502, 504, 500, 506],
        [507, 504, 506, 505, 505],
        [503, 504, 501, 504, 507],
        [503, 503, 5

In [27]:
a = torch.ones(5,5)
mask = torch.ones(5,5)
mask[-1,-1] = 0
mask = mask.bool()
a[mask]
a

tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])