[ar1_turningpts] Removed PyMC and changed the style #584
base: main
Conversation
@xuanguang-li this is a big translation effort. Nice work! @HumphreyYang would you have time to review (no rush)?
@xuanguang-li sorry, I just saw you have applied the …
Pull Request Overview
This PR modernizes an econometric forecasting lecture by replacing PyMC with NumPyro and implementing JAX-based optimizations for significant performance improvements (3m → 2.7s). The changes maintain the educational content while improving computational efficiency through vectorization and functional programming patterns.
Key changes:
- Framework Migration: Replaced PyMC with NumPyro for Bayesian inference and switched to JAX for numerical computing
- Performance Optimization: Eliminated Python loops using JAX's `lax.scan` and vectorized operations (see the sketch below)
- Code Restructuring: Introduced an `AR1` class and separated plotting functions for better organization
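As a rough illustration of the `lax.scan` pattern referred to above (a minimal sketch, not the lecture's actual code; `ar1_simulate` and the parameter values are made up for this example):

```python
import jax
import jax.numpy as jnp
from jax import lax


def ar1_simulate(ρ, σ, y0, T, key):
    """Simulate one AR(1) path of length T+1 without a Python loop."""
    shocks = σ * jax.random.normal(key, shape=(T,))

    def step(y_prev, ε):
        y_next = ρ * y_prev + ε
        return y_next, y_next          # (carry, output) for lax.scan

    _, path = lax.scan(step, y0, shocks)
    return jnp.concatenate([jnp.array([y0]), path])


path = ar1_simulate(0.9, 1.0, 10.0, 100, jax.random.PRNGKey(0))
```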
Thanks @mmcky. I will remove the …
Hi @mmcky and @HumphreyYang, this PR is ready for review.
Many thanks, @xuanguang-li!
In …
Thanks, @HumphreyYang. I have set …
Hi @xuanguang-li, nice observation. This is because we need to set the host device count before we import JAX:
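A minimal sketch of that (the device count of 4 is just an illustrative value):

```python
import numpyro

# Must be called before JAX is imported / any JAX computation runs,
# so that XLA exposes multiple host (CPU) devices.
numpyro.set_host_device_count(4)

import jax
import jax.numpy as jnp
```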
This should fix the warning. (See the note in https://num.pyro.ai/en/stable/utilities.html#set-host-device-count.)
Hi @HumphreyYang, I’m sorry, but this method didn’t eliminate the error report in the GitHub deployment. I printed … Do you have any idea what might be happening?
Many thanks @xuanguang-li. Yeah, I see why. That's because the series is running on a GPU server with just one GPU, while I was testing on my local machine with a multicore CPU. For GPUs, … Also, I think the latest …
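For reference, one way to check what JAX actually sees on a given machine (an illustrative check, not necessarily the snippet that was printed in this thread):

```python
import jax

# On a one-GPU server this reports a single device, so forcing multiple
# host (CPU) devices has no effect there.
print(jax.devices())        # e.g. a single GPU device
print(jax.device_count())   # 1 in that case
```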
Thanks @HumphreyYang! This solved the problem perfectly.
I’ll try using Greek letters in …
Hi @mmcky @HumphreyYang, I’ve implemented the requested changes. Could you take a look when you have time? Regarding the "severe recession" statistic, I couldn’t find a clear definition in Wecker (1979), and the NBER website’s definition seems vague:
📖 Netlify Preview Ready!
Preview URL: https://pr-584--sunny-cactus-210e3e.netlify.app (aee4dfd)
📚 Changed Lecture Pages: ar1_turningpts
We don't need an instantiation function here. Defaults can be stored in the class.

```python
from typing import NamedTuple


class AR1(NamedTuple):
    """
    Represents a univariate first-order autoregressive (AR(1)) process.

    Parameters
    ----------
    ρ : float
        Autoregressive coefficient, must satisfy |ρ| < 1 for stationarity.
    σ : float
        Standard deviation of the error term.
    y0 : float
        Initial value of the process at time t=0.
    T0 : int, optional
        Length of the initial observed path (default is 100).
    T1 : int, optional
        Length of the future path to simulate (default is 100).
    """
    ρ: float
    σ: float
    y0: float
    T0: int
    T1: int


def make_ar1(ρ: float, σ: float, y0: float, T0: int = 100, T1: int = 100):
    """
    Factory function to create an AR1 instance with default values for T0 and T1.

    Returns
    -------
    AR1
        AR1 named tuple containing the specified parameters.
    """
    return AR1(ρ=ρ, σ=σ, y0=y0, T0=T0, T1=T1)
```
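A sketch of the suggested alternative, assuming the defaults of 100 simply move onto the class fields (so `make_ar1` is no longer needed):

```python
from typing import NamedTuple


class AR1(NamedTuple):
    """AR(1) process parameters, with defaults stored on the class."""
    ρ: float
    σ: float
    y0: float
    T0: int = 100   # length of the initial observed path
    T1: int = 100   # length of the future path to simulate


# Defaults apply directly, no factory function required
ar1 = AR1(ρ=0.9, σ=1.0, y0=10.0)
```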
I'm not in favor of dropping down into … I recommend using …

```python
def plot_Wecker(ar1: AR1, initial_path, ax, N=1000):
    """
    Plot the predictive distributions from "pure" Wecker's method.

    Parameters
    ----------
    ar1 : AR1
        An AR1 named tuple containing the process parameters (ρ, σ, T0, T1).
    initial_path : array-like
        The initial observed path of the AR(1) process.
    ax : array of matplotlib Axes
        Axes to draw the plots on.
    N : int
        Number of future sample paths to simulate for predictive distributions.
    """
    # Plot simulated initial and future paths
    y_T0 = initial_path[-1]
    future_path = AR1_simulate_future(ar1, y_T0, N=N)
    plot_path(ar1, initial_path, future_path, ax[0, 0])

    # Compute path statistics for each simulated future path
    def step(carry, n):
        future_temp = future_path[n, :]
        (next_reces, severe_rec, min_val_8q, next_up_turn, next_down_turn
         ) = compute_path_statistics(initial_path, future_temp)
        return carry, (next_reces, severe_rec, min_val_8q, next_up_turn, next_down_turn)

    _, (next_reces, severe_rec, min_val_8q, next_up_turn, next_down_turn
        ) = lax.scan(step, None, jnp.arange(N))

    # Plot path statistics
    plot_path_stats(next_reces, severe_rec, min_val_8q,
                    next_up_turn, next_down_turn, ax)


fig, ax = plt.subplots(3, 2, figsize=(15, 12))
plot_Wecker(ar1, initial_path, ax)
plt.show()
```
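The specific recommendation is elided above, but for context: since the per-path statistics are independent, one common alternative to `lax.scan` here is `jax.vmap`. A hypothetical sketch of what the statistics step inside `plot_Wecker` could look like under that assumption (reusing `compute_path_statistics` and the surrounding variables from the lecture):

```python
import jax

# Map compute_path_statistics over the leading axis of future_path;
# each returned statistic becomes an array of length N.
(next_reces, severe_rec, min_val_8q, next_up_turn, next_down_turn) = jax.vmap(
    lambda path: compute_path_statistics(initial_path, path)
)(future_path)
```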
Key changes:
- Used `Numpyro` to compute posteriors instead of PyMC (see the sketch after this list).
- Introduced an `AR1` class to pack parameters, separated plotting functions, and packed path statistics computation functions.
- Removed `for` loops: used `lax.scan` in almost all loops except for plotting functions and the `Numpyro` model.
- Replaced `numpy.random`: used `jax.random` for simulation and fixed the random seed with `PRNGKey(0)`.

This improved the efficiency of the main function from 3m to 2.7s on my computer!
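For reference, a minimal sketch of what the NumPyro posterior computation might look like; the priors, sampler settings, and data handling here are illustrative assumptions, not necessarily what the PR uses:

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


def ar1_model(data):
    # Illustrative priors on the AR(1) parameters
    ρ = numpyro.sample("ρ", dist.Uniform(-1.0, 1.0))
    σ = numpyro.sample("σ", dist.HalfNormal(1.0))
    # Conditional likelihood: y_{t+1} | y_t ~ N(ρ * y_t, σ)
    numpyro.sample("obs", dist.Normal(ρ * data[:-1], σ), obs=data[1:])


# Fixed seed via PRNGKey(0), as in the summary above
mcmc = MCMC(NUTS(ar1_model), num_warmup=1000, num_samples=2000)
mcmc.run(jax.random.PRNGKey(0), data=jnp.asarray(initial_path))
posterior = mcmc.get_samples()   # dict of posterior draws for "ρ" and "σ"
```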