In [None]:
# ref - https://github.com/mfouesneau/NUTS

"""
Hamiltonian Monte Carlo (HMC) takes a series of steps informed by first-order gradient information. This feature allows it to converge much more quickly to high-dimensional target
distributions than simpler methods such as Metropolis or Gibbs sampling (and their derivatives).

NUTS uses a recursive algorithm to find likely candidate points that automatically stops when it
starts to double back and retrace its steps.  Empirically, NUTS performs at least as efficiently as,
and sometimes more efficiently than, a well-tuned standard HMC method, without requiring user intervention or costly tuning runs.
"""

import os
import sys
# Put the parent of the current working directory on the import path
# (presumably so the `irt` package imported below resolves -- confirm layout).
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

## Load dataset

In [2]:
from irt.load_data import *
from irt.irt_model import *

# Settings shared by the simulated 1PL data and by both dataset splits.
sim_kwargs = dict(num_person=10000, num_item=100, ability_dim=1, nonlinear=False)

# Simulated 1PL data for the configuration above.
dataset = load_1pl_simulation(**sim_kwargs)

# Train split first, then test split; neither caps the number of persons or items.
train_dataset = load_dataset(train=True, max_num_person=None, max_num_item=None, **sim_kwargs)
test_dataset = load_dataset(train=False, max_num_person=None, max_num_item=None, **sim_kwargs)

## Preprocess data

In [3]:
import torch

# Use the first CUDA device when one is available, otherwise stay on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

num_item, num_person = train_dataset.num_item, train_dataset.num_person

# Work on a private copy of the response array: the fill-in below is an
# in-place write, and writing straight into `train_dataset.response` would
# permanently erase the dataset's -1 markers for every later cell.
response = train_dataset.response.copy()
mask = train_dataset.mask

# Entries coded -1 are replaced with 0, a filler value inside the response
# support. Presumably `mask` is what tells the model which entries are real
# observations -- TODO confirm against the model definition.
response[response == -1] = 0
response = torch.from_numpy(response).float().to(device)
mask = torch.from_numpy(mask).long().to(device)

In [4]:
from pyro.infer.mcmc.util import initialize_model, predictive

"""
`initialize_model`: given a Python callable with Pyro primitives, generates the following model-specific
properties needed for inference using HMC/NUTS kernels:

- initial parameters to be sampled using an HMC kernel,
- a potential function whose input is a dict of parameters in unconstrained space,
- transforms to transform latent sites of `model` to unconstrained space,
- a prototype trace to be used in MCMC to consume traces from sampled parameters.
"""

# The 1PL IRT model is the Pyro callable used from here on.
irt_model = irt_model_1pl

# Positional arguments handed to the model. The leading 1 sits in the slot
# that the later `mcmc.run` call fills with `ability_dim`; the meaning of the
# trailing 1 is not visible here -- TODO confirm against the model signature.
model_args = (1, num_person, num_item, device, response, mask, 1)

# The fourth returned value is discarded.
init_params, potential_fn, transforms, _ = initialize_model(
    irt_model, model_args=model_args, num_chains=1
)

In [5]:
from pyro.infer.mcmc import NUTS
from pyro.infer.mcmc.api import MCMC

# Sampler configuration.
ability_dim, num_chains = 1, 1
num_samples, num_warmup = 200, 100

# NUTS kernel built from the potential function.
nuts_kernel = NUTS(potential_fn=potential_fn)

# The chain starts from `init_params` and is given `transforms`.
mcmc = MCMC(
    nuts_kernel,
    num_samples=num_samples, warmup_steps=num_warmup, num_chains=num_chains,
    initial_params=init_params, transforms=transforms,
)

# Run with the same positional inputs the IRT model takes.
run_args = (ability_dim, num_person, num_item, device, response, mask, 1)
mcmc.run(*run_args)

Sample: 100%|██████████| 300/300 [00:45,  6.56it/s, step size=1.46e-01, acc. prob=0.868]


In [6]:
# Shapes of the posterior draws:
#   ability   : (num samples, num persons, 1)
#   item_feat : (num samples, num items, 1)

# Pull every sampled site back to the CPU before summarising.
samples = {site: draws.cpu() for site, draws in mcmc.get_samples().items()}

# Per-site mean and variance, reduced over the leading (sample) axis.
sample_means = {site: torch.mean(draws, dim=0) for site, draws in samples.items()}
sample_variances = {site: torch.var(draws, dim=0) for site, draws in samples.items()}

In [7]:
# Show the shape of the posterior-mean tensor for each of the two sites.
for site in ('ability', 'item_feat'):
    print(sample_means[site].shape)

torch.Size([8000, 1])
torch.Size([100, 1])


In [8]:
from pyro.util import ignore_experimental_warning

def sample_posterior_predictive(model, posterior_samples, *args):
    """Run `predictive` on a model with experimental warnings silenced.

    Args:
        model: callable forwarded to `predictive` as its first argument.
        posterior_samples: posterior draws forwarded as the second argument.
        *args: extra positional arguments forwarded unchanged after the two
            above.

    Returns:
        Whatever `predictive(model, posterior_samples, *args)` returns.
    """
    with ignore_experimental_warning():
        return predictive(model, posterior_samples, *args)
    
# Draw posterior predictive samples of the response. The two slots that held
# `response` and `mask` in the MCMC run are None here (presumably so the model
# generates responses instead of conditioning on observed ones -- confirm).
predictive_args = (ability_dim, num_person, num_item, device, None, None, 1)
posterior_predict_samples = sample_posterior_predictive(
    irt_model, samples, *predictive_args
)



In [9]:
# Shape of the posterior predictive `response` draws.
print(posterior_predict_samples['response'].size())

torch.Size([200, 8000, 100, 1])
