Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove pre_sample, pre_warmup from NUTS and fix warning #545

Merged
merged 11 commits into from
Oct 11, 2024
Merged

Conversation

amal-ghamdi
Copy link
Contributor

@amal-ghamdi amal-ghamdi commented Oct 6, 2024

closes #524
closes #537

Verification Code 1:

Code verifies this change still gives same results as old NUTS (also this is verified by regression tests)

from git import Repo
repo = Repo('./CUQIpy')
branch = repo.active_branch
# Sample with NUTS + time it
import time
import numpy as np
from cuqi.testproblem import Deconvolution1D
from cuqi.distribution import Gaussian, Gamma, JointDistribution, GMRF
from cuqi.experimental.mcmc import NUTS
import matplotlib.pyplot as plt

# Inverse problem
np.random.seed(0) 
testproblem_obj = Deconvolution1D(dim=128, phantom='sinc', noise_std=0.001)
posterior = testproblem_obj.posterior

# NUTS
Ns=15
Nb=15
print(branch.name)
print('NUTS')
start = time.time()
np.random.seed(0)
NUTS_sampler = NUTS(posterior, step_size=None)
NUTS_sampler.warmup(Nb, tune_freq=1/Nb).sample(Ns)

end = time.time()
print('Time:', end - start)
print('samples norm:', np.linalg.norm(NUTS_sampler.get_samples().samples))
plt.plot([np.linalg.norm(NUTS_sampler.get_samples().samples[:,i]) for i in range(Ns+Nb)])
plt.ylim([0, 14])
# add branch name in title
plt.title(branch.name)
plt.figure()
plt.plot(NUTS_sampler.num_tree_node_list)
plt.title(branch.name)

Results main branch:

nuts_main_chain
nuts_main_tree

Results this branch:
nuts_branch_chain
nuts_branch_tree

Verification Code 2 (added as a test in the PR, for smaller sample size):

import cuqi
import numpy as np
from cuqi.distribution import Gamma, Gaussian, GMRF, JointDistribution, LMRF
from cuqi.experimental.mcmc import NUTS, HybridGibbs, Conjugate, LinearRTO, ConjugateApprox, UGLA
from cuqi.testproblem import Deconvolution1D
from git import Repo
import matplotlib.pyplot as plt
import time
repo = Repo('../')
branch = repo.active_branch

# Forward problem
np.random.seed(0)
A, y_data, info = Deconvolution1D(dim=128, phantom='sinc', noise_std=0.001).get_components()

# Bayesian Inverse Problem
s = Gamma(1, 1e-4)
x = GMRF(np.zeros(A.domain_dim), 50)
y = Gaussian(A@x, lambda s: 1/s)

# Posterior
target = JointDistribution(y, x, s)(y=y_data)

Nb=40
sampling_strategy = {
    "x" : NUTS(max_depth=7),
    "s" : Conjugate()
}

# Here we do 10 internal steps with NUTS for each Gibbs step
num_sampling_steps = {
    "x" : 1,
    "s" : 1
}

sampler = HybridGibbs(target, sampling_strategy, num_sampling_steps)
# start time
start_time = time.time()
sampler.warmup(Nb)
sampler.sample(40)
samples = sampler.get_samples()
# end time
end_time = time.time()
print(end_time - start_time)


samples["x"].plot_ci(exact=info.exactSolution)
plt.title(branch.name +" time: "+str(end_time - start_time))

results main branch

gibbs_main

results this branch

gibbs_branch

@amal-ghamdi amal-ghamdi mentioned this pull request Oct 6, 2024
2 tasks
Copy link
Collaborator

@nabriis nabriis left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @amal-ghamdi. Really nice! I only had a comment related to the deepcopy. Perhaps we can find a more efficient method.

cuqi/experimental/mcmc/_gibbs.py Show resolved Hide resolved
cuqi/experimental/mcmc/_hmc.py Outdated Show resolved Hide resolved
tests/zexperimental/test_mcmc.py Show resolved Hide resolved
Copy link
Contributor

@chaozg chaozg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @amal-ghamdi , this PR looks good and I don't have any technical comments; but I did try the two demos on my machine and the results look great. BTW, it's good to know this trick branch = repo.active_branch : )

@amal-ghamdi
Copy link
Contributor Author

Hi @amal-ghamdi , this PR looks good and I don't have any technical comments; but I did try the two demos on my machine and the results look great. BTW, it's good to know this trick branch = repo.active_branch : )

Thank you @chaozg for your review!

@amal-ghamdi amal-ghamdi changed the title Fix nuts v4 Remove pre_sample, pre_warmup from NUTS and fix warning Oct 11, 2024
@amal-ghamdi amal-ghamdi merged commit 85645ce into main Oct 11, 2024
6 checks passed
@amal-ghamdi amal-ghamdi deleted the fix_NUTS_v4 branch October 11, 2024 10:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
3 participants