Skip to content

Conversation

xuanguang-li
Copy link
Contributor

Key changes:

  • Removed PyMC: Used Numpyro to compute posteriors instead.
  • Refactored class and functions: Defined an AR1 class to pack parameters, separated plotting functions, and packed path statistics computation functions.
  • Removed for loops: Used lax.scan in almost all loops except for plotting functions and the Numpyro model.
  • Removed numpy.random: Used jax.random for simulation and fixed the random seed with PRNGKey(0).
  • Fixed typos and styling: Followed the style guide and PEP8.

This improved the efficiency of the main function from 3m to 2.7s on my computer!

@mmcky
Copy link
Contributor

mmcky commented Aug 31, 2025

@xuanguang-li this is a big translation effort. Nice work! @HumphreyYang would you have time to review (no rush).

@mmcky mmcky marked this pull request as ready for review August 31, 2025 22:23
@mmcky mmcky requested a review from Copilot August 31, 2025 22:24
@mmcky
Copy link
Contributor

mmcky commented Aug 31, 2025

@xuanguang-li sorry I just saw you have applied the in-work label. Let me know when you're ready for this to be reviewed.

Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR modernizes an econometric forecasting lecture by replacing PyMC with NumPyro and implementing JAX-based optimizations for significant performance improvements (3m → 2.7s). The changes maintain the educational content while improving computational efficiency through vectorization and functional programming patterns.

Key changes:

  • Framework Migration: Replaced PyMC with NumPyro for Bayesian inference and switched to JAX for numerical computing
  • Performance Optimization: Eliminated Python loops using JAX's lax.scan and vectorized operations
  • Code Restructuring: Introduced an AR1 class and separated plotting functions for better organization

@xuanguang-li
Copy link
Contributor Author

Thanks @mmcky. I will remove the in-work label after checking the typos in one or two days.

Copy link

github-actions bot commented Sep 1, 2025

@github-actions github-actions bot temporarily deployed to pull request September 1, 2025 03:12 Inactive
@github-actions github-actions bot temporarily deployed to pull request September 1, 2025 03:13 Inactive
@github-actions github-actions bot temporarily deployed to pull request September 1, 2025 06:42 Inactive
@github-actions github-actions bot temporarily deployed to pull request September 1, 2025 06:42 Inactive
@xuanguang-li xuanguang-li added ready and removed in-work labels Sep 1, 2025
@xuanguang-li
Copy link
Contributor Author

Hi @mmcky and @HumphreyYang, this PR is ready for review.

@HumphreyYang
Copy link
Member

Many thanks, @xuanguang-li!

Do you think we should set num-chains=1 to avoid this error, or should we just leave the error as is?

In numpyro, the number of chains is parallelized across the available processors. So, to run four chains, you can simply set numpyro.set_host_device_count(4).

@xuanguang-li
Copy link
Contributor Author

Thanks, @HumphreyYang.

I have set numpyro.set_host_device(4) at the import cell, but it did not resolve this error, as shown in the Github deploy. Should it be set inside the MCMC.run?

@HumphreyYang
Copy link
Member

Hi @xuanguang-li, nice observation. This is because we need to set host device before we import JAX:

import matplotlib.pyplot as plt
import seaborn as sns
from typing import NamedTuple

# numpyro
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
numpyro.set_host_device_count(4)

# jax
import jax
import jax.random as random
import jax.numpy as jnp
from jax import lax

# arviz
import arviz as az

sns.set_style('white')
colors = sns.color_palette()
key = random.PRNGKey(0)

This should fix the warning.

(See the note in https://num.pyro.ai/en/stable/utilities.html#set-host-device-count).

@github-actions github-actions bot temporarily deployed to pull request September 7, 2025 05:38 Inactive
@github-actions github-actions bot temporarily deployed to pull request September 7, 2025 05:46 Inactive
@github-actions github-actions bot temporarily deployed to pull request September 7, 2025 06:34 Inactive
@github-actions github-actions bot temporarily deployed to pull request September 7, 2025 06:36 Inactive
@xuanguang-li
Copy link
Contributor Author

Hi @HumphreyYang, I’m sorry but this method didn’t eliminate the error report in the GitHub deployment.

I printed jax.local_device_count() at the bottom of the import cell, and it still shows the device as 1. However, this method did work on my local machine.

Do you have any idea what might be happening?

@HumphreyYang
Copy link
Member

HumphreyYang commented Sep 8, 2025

Many thanks @xuanguang-li, Yeah I see why. That's because the series is running on a GPU server with just one GPU. I was testing on my local machine with multicore CPU.

For GPUs, nupyro parallization strategy is chain_method="vectorized". I just pushed a commit and it should fix the warning.

Also, I think the latest numpyro version can take a unicode string as variable name now (i.e. in numpyro.sample)

@github-actions github-actions bot temporarily deployed to pull request September 8, 2025 06:06 Inactive
@github-actions github-actions bot temporarily deployed to pull request September 8, 2025 06:07 Inactive
@xuanguang-li
Copy link
Contributor Author

Thanks @HumphreyYang! This solved the problem perfectly.

Also, I think the latest numpyro version can take a unicode string as a variable name now (i.e. in numpyro.sample)

I’ll try using Greek letters in numpyro.sample. Thanks for letting me know!

@github-actions github-actions bot temporarily deployed to pull request September 8, 2025 06:39 Inactive
@github-actions github-actions bot temporarily deployed to pull request September 8, 2025 06:40 Inactive

This comment was marked as outdated.

This comment was marked as outdated.

@github-actions github-actions bot temporarily deployed to pull request September 9, 2025 04:16 Inactive
@xuanguang-li
Copy link
Contributor Author

Hi @mmcky @HumphreyYang, I’ve implemented the requested changes. Could you take a look when you have time?

Regarding severe recession, I couldn’t find a clear definition in Wecker (1979), and the NBER website’s definition seems vague:

In our interpretation of this definition, we treat the three criteria — depth, diffusion, and duration — as somewhat interchangeable.

https://www.nber.org/research/business-cycle-dating

Copy link

📖 Netlify Preview Ready!

Preview URL: https://pr-584--sunny-cactus-210e3e.netlify.app (aee4dfd)

📚 Changed Lecture Pages: ar1_turningpts

@mmcky mmcky added the review label Sep 17, 2025
@jstac
Copy link
Contributor

jstac commented Sep 19, 2025

We don't need an instantiation function here. Defaults can be stored in the class.

class AR1(NamedTuple):
    """
    Represents a univariate first-order autoregressive (AR(1)) process.

    Parameters
    ----------
    ρ : float
        Autoregressive coefficient, must satisfy |ρ| < 1 for stationarity.
    σ : float
        Standard deviation of the error term.
    y0 : float
        Initial value of the process at time t=0.
    T0 : int, optional
        Length of the initial observed path (default is 100).
    T1 : int, optional
        Length of the future path to simulate (default is 100).
    """
    ρ: float
    σ: float
    y0: float
    T0: int
    T1: int


def make_ar1(ρ: float, σ: float, y0: float, T0: int = 100, T1: int = 100):
    """
    Factory function to create an AR1 instance with default values for T0 and T1.
    
    Returns
    -------
    AR1
        AR1 named tuple containing the specified parameters.
    """
    return AR1(ρ=ρ, σ=σ, y0=y0, T0=T0, T1=T1)

@jstac
Copy link
Contributor

jstac commented Sep 19, 2025

Link needs fixing?

image

@jstac
Copy link
Contributor

jstac commented Sep 19, 2025

I'm not in favor of dropping down into lax in functions like this -- plotting code where speed is irrelevant.

I recommend using at[x].set(y) or a appending to an ordinary Python list. What do you think @mmcky @HumphreyYang ?

def plot_Wecker(ar1: AR1, initial_path, ax, N=1000):
    """
    Plot the predictive distributions from "pure" Wecker's method.

    Parameters
    ----------
    ar1 : AR1
        An AR1 named tuple containing the process parameters (ρ, σ, T0, T1).
    initial_path : array-like
        The initial observed path of the AR(1) process.
    N : int
        Number of future sample paths to simulate for predictive distributions.
    """
    # Plot simulated initial and future paths
    y_T0 = initial_path[-1]
    future_path = AR1_simulate_future(ar1, y_T0, N=N)
    plot_path(ar1, initial_path, future_path, ax[0, 0])

    # Simulate future paths and compute statistics
    def step(carry, n):
        future_temp = future_path[n, :]
        (next_reces, severe_rec, min_val_8q, next_up_turn, next_down_turn
         ) = compute_path_statistics(initial_path, future_temp)
    
        return carry, (next_reces, severe_rec, min_val_8q, next_up_turn, next_down_turn)
    
    _, (next_reces, severe_rec, min_val_8q, next_up_turn, next_down_turn
        ) = lax.scan(step, None, jnp.arange(N))
    
    # Plot path statistics
    plot_path_stats(next_reces, severe_rec, min_val_8q, 
                    next_up_turn, next_down_turn, ax)


fig, ax = plt.subplots(3, 2, figsize=(15, 12))
plot_Wecker(ar1, initial_path, ax)
plt.show()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: In Progress
Development

Successfully merging this pull request may close these issues.

4 participants