In [None]:
import numpy as np
import pandas as pd
import arviz as az
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.special import expit

import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, Trace_ELBO, Predictive, NUTS, MCMC

# Seeded NumPy generator; handed to split_train_test for the train/test split.
rng_np = np.random.default_rng(101)
# Number of latent factors passed to pmf_model (outcome model).
n_factors_pmf = 10
# Number of latent factors passed to pf_model / pf_guide (exposure model).
n_factors_exposure = 20
# Rating-scale constant: ratings are rescaled as (rating - 1) / (K - 1) and
# predictions are mapped back with the inverse transform in pmf_predict.
K = 5

# MovieLens data

In [None]:
# Read the tab-separated MovieLens ratings file and keep only the
# (user, movie, rating) columns.
names = ['userId', 'movieId', 'rating', 'timestamp']
df_ratings = (
    pd.read_csv('./ml-100k/u.data', sep='\t', names=names)
    .drop(columns='timestamp')
)
print(df_ratings.shape)
df_ratings.head(3)


## Train-test split

In [None]:
# Taken from https://docs.pymc.io/notebooks/probabilistic_matrix_factorization.html

# Define a function for splitting train/test data.
def split_train_test(data, percent_test, rng):
    """Hold out a random fraction of the observed cells of a 2-D matrix.

    Parameters
    ----------
    data : 2-D float ndarray in which NaN marks an unobserved cell.
    percent_test : fraction of the observed (non-NaN) cells moved to the
        test matrix; the count is truncated with ``int``.
    rng : object exposing ``choice(a, replace=..., size=...)`` (e.g. a
        ``numpy.random.Generator``) used to pick the held-out cells.

    Returns
    -------
    (train, test) : two arrays shaped like ``data``. ``train`` is a copy of
        ``data`` with the held-out cells set to NaN; ``test`` is NaN
        everywhere except the held-out cells.
    """
    n_rows, n_cols = data.shape
    n_cells = n_rows * n_cols

    train = data.copy()
    test = np.full(data.shape, np.nan)

    # Coordinates of every observed cell (NaN cells are never sampled).
    obs_rows, obs_cols = np.where(~np.isnan(train))
    n_observed = len(obs_rows)
    test_size = int(n_observed * percent_test)
    train_size = n_observed - test_size

    # Sample positions into the list of observed cells, without replacement.
    picked = rng.choice(np.arange(n_observed), replace=False, size=test_size)

    # Move the sampled cells from train to test in one vectorised step.
    held_rows, held_cols = obs_rows[picked], obs_cols[picked]
    test[held_rows, held_cols] = train[held_rows, held_cols]
    train[held_rows, held_cols] = np.nan

    # Sanity check: the observed-cell counts match the requested split sizes.
    assert train_size == n_cells - np.isnan(train).sum()
    assert test_size == n_cells - np.isnan(test).sum()
    return train, test


# Sorted ids, used to check the row/column order of the pivoted matrix.
all_users = np.sort(df_ratings['userId'].unique())
all_movies = np.sort(df_ratings['movieId'].unique())

# Dense user x movie rating matrix (rows: userId, columns: movieId).
df_dense_data = df_ratings.pivot_table(values='rating', index='userId', columns='movieId')
assert (df_dense_data.columns == all_movies).all()
assert (df_dense_data.index == all_users).all()

# Hold out 20% of the observed ratings for evaluation.
train, test = split_train_test(df_dense_data.to_numpy(), 0.2, rng=rng_np)
del df_dense_data
print(train.shape, test.shape)

In [None]:
# Matrix dimensions reused by the exposure-model cells: rows of `train` are
# users and columns are movies (see the pivot that built it).
n_users, n_items = train.shape

# Exposure model

In [None]:
# Binary exposure indicator: 1 where the training matrix holds a positive
# rating, 0 where it is NaN (NaN is first mapped to 0).
ratings_filled = jnp.nan_to_num(train, nan=0)
exposure_train = jnp.where(ratings_filled > 0, 1, 0)
print(exposure_train.shape)

In [None]:
# Summary of the exposure matrix: mean, sum, and number of non-zero cells.
exposure_stats = (
    exposure_train.mean(),
    exposure_train.sum(),
    jnp.count_nonzero(exposure_train),
)
print(*exposure_stats)

In [None]:
def pf_model(
    n_users: int,
    n_items: int,
    n_factors: int,
    exposure=None,
):
    """Generative program of the Poisson factorization model.

    Sample sites:
      * ``prefs`` -- (n_users, n_factors) draws from Gamma(0.23, 1 / 0.23)
      * ``atts``  -- (n_items, n_factors) draws from the same Gamma
      * ``rates`` -- deterministic site holding ``prefs @ atts.T``
      * ``obs``   -- Poisson(rates), conditioned on ``exposure`` when given

    Returns the value of the ``obs`` site (``exposure`` itself when it is
    supplied, otherwise a fresh Poisson draw).
    """
    concentration = 0.23
    rate = 1 / concentration
    latent_prior = dist.Gamma(concentration, rate)

    user_prefs = numpyro.sample('prefs', latent_prior, sample_shape=(n_users, n_factors))
    item_atts = numpyro.sample('atts', latent_prior, sample_shape=(n_items, n_factors))
    poisson_rates = numpyro.deterministic('rates', user_prefs @ item_atts.T)
    return numpyro.sample('obs', dist.Poisson(poisson_rates), obs=exposure)


def pf_guide(
    n_users: int,
    n_items: int,
    n_factors: int,
    exposure=None,
):
    """Variational guide for the Poisson factorization model.

    Registers two positive-constrained parameters, ``a_prefs`` with shape
    (n_users, n_factors) and ``a_atts`` with shape (n_items, n_factors),
    both initialised to 0.23. Each latent site is drawn from
    ``Gamma(a, 1 / a)``, so the rate is tied to the learned concentration.
    ``exposure`` is accepted for signature parity with the model and is not
    read here.

    Returns the deterministic ``rates`` site, ``prefs @ atts.T``.
    """
    init_value = 0.23

    prefs_concentration = numpyro.param(
        'a_prefs',
        jnp.ones((n_users, n_factors)) * init_value,
        constraint=constraints.positive,
    )
    atts_concentration = numpyro.param(
        'a_atts',
        jnp.ones((n_items, n_factors)) * init_value,
        constraint=constraints.positive,
    )

    user_prefs = numpyro.sample(
        'prefs', dist.Gamma(prefs_concentration, 1 / prefs_concentration)
    )
    item_atts = numpyro.sample(
        'atts', dist.Gamma(atts_concentration, 1 / atts_concentration)
    )
    return numpyro.deterministic('rates', user_prefs @ item_atts.T)


# Smoke test: run the model (no observations) and the guide once under a
# fixed seed and inspect the shapes / sparsity of what they produce.
with numpyro.handlers.seed(rng_seed=10134):
    prior_sample = pf_model(n_users, n_items, n_factors_exposure)
    print(prior_sample.shape)
    prior_stats = (
        prior_sample.mean(),
        prior_sample.sum(),
        jnp.count_nonzero(prior_sample),
    )
    print(*prior_stats)
    guide_vals = pf_guide(n_users, n_items, n_factors_exposure)
    print(guide_vals.shape)

In [None]:
# SVI: fit the exposure model's guide with Adam, running the update steps
# inside a single lax.scan.
n_iters = 15
step_size = 0.01
optimizer = numpyro.optim.Adam(step_size=step_size)
svi = SVI(pf_model, pf_guide, optimizer, loss=Trace_ELBO())

# Positional arguments shared by the model and the guide.
svi_args = (n_users, n_items, n_factors_exposure, exposure_train)
init_state = svi.init(random.PRNGKey(202), *svi_args)


def _svi_step(carry, _):
    """Run one SVI update; returns (new_state, loss) as lax.scan expects."""
    return svi.update(carry, *svi_args)


state, losses = jax.lax.scan(_svi_step, init_state, jnp.arange(n_iters))

In [None]:
# Loss trace of the SVI run, one value per scan step.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title='Exposure model: SVI loss', xlabel='SVI step', ylabel='Trace_ELBO loss');

In [None]:
# Run n_iters further SVI updates in a plain Python loop.
# `svi_state` was never defined before this cell (NameError on a fresh
# kernel); seed it from `state`, the final state returned by the lax.scan
# run, so the loop continues that optimisation.
svi_state = state
for i in range(n_iters):
    svi_state, loss = svi.update(
        svi_state,
        n_users, n_items, n_factors_exposure, exposure_train
    )

In [None]:
# TODO: initialize model with result of SVI

In [None]:
# NUTS on the exposure model. The kernel must wrap `pf_model` (the name
# `poisson_factorization_model` does not exist in this notebook), and the
# arguments given to `run` are forwarded to the model, so they have to follow
# pf_model's signature: (n_users, n_items, n_factors, exposure).
nuts_kernel = NUTS(pf_model)
mcmc = MCMC(nuts_kernel, num_samples=500, num_warmup=1000, num_chains=1)
rng_key = random.PRNGKey(1234)
mcmc.run(rng_key, n_users, n_items, n_factors_exposure, exposure=exposure_train)
posterior_samples = mcmc.get_samples()


# Outcome model

In [None]:
# Inputs of the outcome model.
# NOTE(review): presumably the substitute confounders come from the exposure
# model and have the same (users, items) shape as `train` -- confirm.
confounders = jnp.load('./substitute_confounders.npy')
print(confounders.shape)

train_no_nan = jnp.nan_to_num(train, nan=0)
print(train_no_nan.shape)

# Rescale ratings with (r - 1) / (K - 1); cells without a rating become 0.
scaled_ratings = jnp.nan_to_num((train - 1) / (K - 1), nan=0)
print(scaled_ratings.shape)

# Binary exposure indicator: 1 where a positive training rating exists.
exposure_train = jnp.where(train_no_nan > 0, 1, 0)
print(exposure_train.shape)
exposure_train

In [None]:
def pmf_model(exposure, confounders, ratings, n_factors: int, **kwargs):
    """Outcome model: matrix factorization plus a per-user confounder term.

    Sample sites:
      * ``U``      -- (n_users, n_factors), Normal(0, 0.1)
      * ``V``      -- (n_items, n_factors), Normal(0, 0.1)
      * ``gammas`` -- (n_users, 1), Normal(0, 1)
      * ``R``      -- Normal(mean, 0.001), conditioned on ``ratings``

    where ``mean = expit((U @ V.T) * exposure + gammas * confounders)`` and
    (n_users, n_items) is taken from ``exposure.shape``. Extra keyword
    arguments are accepted and ignored.

    Returns ``mean``.
    """
    n_users, n_items = exposure.shape
    factor_prior = dist.Normal(0., 0.1)

    U = numpyro.sample('U', factor_prior, sample_shape=(n_users, n_factors))
    V = numpyro.sample('V', factor_prior, sample_shape=(n_items, n_factors))
    gammas = numpyro.sample('gammas', dist.Normal(0., 1), sample_shape=(n_users, 1))

    # The factor term only contributes where the user was exposed to the item.
    logits = (U @ V.T) * exposure + (gammas * confounders)
    mean = expit(logits)
    noise_scale = jnp.ones_like(mean) * 0.001
    numpyro.sample('R', dist.Normal(mean, noise_scale), obs=ratings)
    return mean


# Smoke test: one seeded forward pass of the outcome model; inspect the shape
# and range of the mean it returns.
with numpyro.handlers.seed(rng_seed=101):
    res_model = pmf_model(
        exposure_train, confounders, scaled_ratings, n_factors_pmf
    )
    print(res_model.shape)
    mean_range = (res_model.min(), res_model.mean(), res_model.max())
    print(*mean_range)
res_model

In [None]:
# NUTS on the outcome model: 1000 warm-up steps, 500 kept samples, one chain.
nuts_kernel = NUTS(pmf_model)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=500, num_chains=1)
rng_key = random.PRNGKey(15)
mcmc.run(
    rng_key,
    exposure_train,
    confounders,
    scaled_ratings,
    n_factors_pmf,
)
posterior_samples = mcmc.get_samples()

# Results

In [None]:
# Functions taken from https://docs.pymc.io/notebooks/probabilistic_matrix_factorization.html

def rmse(test_data, predicted):
    """Root mean squared error over the observed entries of ``test_data``.

    Entries where ``test_data`` is NaN are treated as missing and ignored.

    Parameters
    ----------
    test_data : array with NaN marking missing values.
    predicted : array broadcastable against ``test_data``.

    Returns
    -------
    Scalar array: ``sqrt(mean((test_data - predicted) ** 2))`` taken over the
    non-NaN entries of ``test_data`` (NaN if there are none).
    """
    observed = ~jnp.isnan(test_data)   # True where a test value is present
    n_observed = observed.sum()        # number of non-missing values
    # Mask with `where` rather than boolean indexing so the output shape is
    # static and the function also works under jax.jit.
    sq_error = jnp.where(observed, jnp.square(test_data - predicted), 0.0)
    return jnp.sqrt(sq_error.sum() / n_observed)


def mae(test_data, predicted):
    """Mean absolute error over the observed entries of ``test_data``.

    Entries where ``test_data`` is NaN are treated as missing and ignored.

    Parameters
    ----------
    test_data : array with NaN marking missing values.
    predicted : array broadcastable against ``test_data``.

    Returns
    -------
    Scalar array: ``mean(|test_data - predicted|)`` taken over the non-NaN
    entries of ``test_data`` (NaN if there are none).
    """
    observed = ~jnp.isnan(test_data)   # True where a test value is present
    n_observed = observed.sum()        # number of non-missing values
    # Mask with `where` rather than boolean indexing so the output shape is
    # static and the function also works under jax.jit.
    abs_error = jnp.where(observed, jnp.abs(test_data - predicted), 0.0)
    return abs_error.sum() / n_observed

In [None]:
def pmf_predict(U, V, gammas, confounders, k):
    """Predict ratings on the original scale with every item treated as exposed.

    Computes ``expit(U @ V.T + gammas * confounders)`` (the exposure factor
    is fixed to 1) and maps the result from (0, 1) to the rating scale via
    ``x * (k - 1) + 1``.

    Returns an array shaped like ``U @ V.T`` broadcast against
    ``gammas * confounders``.
    """
    exposed_interaction = (U @ V.T) * 1.
    confounder_effect = gammas * confounders
    unit_scale = expit(exposed_interaction + confounder_effect)
    return 1 + (k - 1) * unit_scale

# Posterior samples previously saved to disk.
U_samples = jnp.load('./samples/U.npy')
V_samples = jnp.load('./samples/V.npy')
gammas_samples = jnp.load('./samples/gammas.npy')
# NOTE(review): beta_0 is loaded and summarised below but is not passed to
# pmf_predict, and pmf_model in this notebook has no `beta_0` site -- confirm
# whether it is still needed.
beta_0_samples = jnp.load('./samples/beta_0.npy')

# Point estimates: median over the leading (sample) axis.
U, V, gammas, beta_0 = (
    np.median(draws, axis=0)
    for draws in (U_samples, V_samples, gammas_samples, beta_0_samples)
)

# The saved factors must have the latent dimension configured for the model.
assert U.shape[1] == n_factors_pmf
assert V.shape[1] == n_factors_pmf

preds_pmf = pmf_predict(U, V, gammas, confounders, K)
print(preds_pmf.shape)

In [None]:
# Held-out error of the point predictions, on the original rating scale.
test_rmse = rmse(test, preds_pmf)
print('RMSE:', test_rmse)
test_mae = mae(test, preds_pmf)
print('MAE:', test_mae)