### Parallelization  

Reference: the emcee parallelization tutorial, https://emcee.readthedocs.io/en/stable/tutorials/parallel/

In [3]:
import os

# Cap OpenMP-based libraries at one thread per process. This is set here,
# before numpy is imported in the next cell, so it is in place when numpy loads.
# NOTE(review): presumably this keeps numpy's own threading from competing with
# the multiprocessing pool used later in the notebook — confirm.
os.environ["OMP_NUM_THREADS"] = "1"

In [4]:
# A computationally expensive model
import time
import numpy as np


def log_prob(theta):
    """Toy log-probability that is deliberately slow to evaluate.

    Spins the CPU for a random 5-8 ms to stand in for an expensive model,
    then returns ``-0.5 * sum(theta**2)``.

    Parameters
    ----------
    theta : array-like supporting ``**`` (e.g. a NumPy array)
        Parameter values; every element contributes to the sum.

    Returns
    -------
    scalar
        ``-0.5 * np.sum(theta**2)``.
    """
    # Busy-wait instead of sleeping so the delay actually occupies a CPU core.
    # perf_counter() is monotonic, so a system clock adjustment cannot stretch
    # or cut short the wait the way it could with time.time().
    deadline = time.perf_counter() + np.random.uniform(0.005, 0.008)
    while time.perf_counter() < deadline:
        pass
    return -0.5 * np.sum(theta**2)

In [5]:
import emcee

# Seed NumPy's global RNG so the walkers' starting positions are reproducible.
np.random.seed(42)
initial = np.random.randn(32, 5)  # one row per walker, one column per parameter
nwalkers, ndim = initial.shape
nsteps = 100

sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob)
# Measure the interval with perf_counter(): it is monotonic and high-resolution,
# whereas time.time() follows the adjustable system clock.
start = time.perf_counter()
sampler.run_mcmc(initial, nsteps, progress=True)
end = time.perf_counter()
serial_time = end - start  # reused below to compute the parallel speed-up
print("Serial took {0:.1f} seconds".format(serial_time))

100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:22<00:00,  4.37it/s]

Serial took 23.1 seconds





In [None]:
# Run the same sampling with a multiprocessing Pool so the log-probability calls are spread across worker processes
from multiprocessing import Pool, cpu_count

# Pool() without arguments sizes itself from the machine's CPU count.
with Pool() as pool:
    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, pool=pool)
    # perf_counter() is the monotonic clock intended for measuring intervals.
    start = time.perf_counter()
    sampler.run_mcmc(initial, nsteps, progress=True)
    end = time.perf_counter()
    multi_time = end - start
    print("Multiprocessing took {0:.1f} seconds".format(multi_time))
    print("{0:.1f} times faster than serial".format(serial_time / multi_time))

# Report the core count so the speed-up above can be put in context.
ncpu = cpu_count()
print("{0} CPUs".format(ncpu))

100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 15.68it/s]

Multiprocessing took 6.4 seconds
3.6 times faster than serial
8 CPUs





In [7]:
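# Example with a large dataset handed to the log-probability through `args`.
# Run serially this adds no noticeable time; the extra cost only appears when a
# pool has to pickle the data for every call.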
def log_prob_data(theta, data):
    """Slow toy log-probability that also receives a dataset argument.

    The dataset is only touched (its first element is read) so that the
    argument is genuinely used; it does not influence the returned value.

    Parameters
    ----------
    theta : array-like supporting ``**`` (e.g. a NumPy array)
        Parameter values; every element contributes to the sum.
    data : indexable
        Dataset passed along with each call; must support ``data[0]``.

    Returns
    -------
    scalar
        ``-0.5 * np.sum(theta**2)``.
    """
    _ = data[0]  # Use the data somehow...
    # Busy-wait for 5-8 ms to imitate real computation. perf_counter() is
    # monotonic, so system clock adjustments cannot distort the delay.
    deadline = time.perf_counter() + np.random.uniform(0.005, 0.008)
    while time.perf_counter() < deadline:
        pass
    return -0.5 * np.sum(theta**2)


# A 5000 x 200 array (one million values) standing in for a large dataset.
data = np.random.randn(5000, 200)

# The dataset is forwarded to log_prob_data on every call through `args`.
sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob_data, args=(data,))
# perf_counter() is the monotonic clock intended for measuring intervals.
start = time.perf_counter()
sampler.run_mcmc(initial, nsteps, progress=True)
end = time.perf_counter()
serial_data_time = end - start  # baseline for the parallel comparison below
print("Serial took {0:.1f} seconds".format(serial_data_time))

100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:22<00:00,  4.36it/s]

Serial took 23.2 seconds





In [None]:
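# Read the data from a module-level global instead of passing it through `args`,
# so it is not pickled and sent to the worker processes on every call. This relies
# on the workers being able to see the global (e.g. the "fork" start method).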
def log_prob_data_global(theta):
    """Slow toy log-probability that reads the dataset from a global.

    Unlike a version that takes the dataset as an argument, this one looks up
    the module-level name ``data`` at call time, so ``data`` must be defined
    in whichever process evaluates the function.

    Parameters
    ----------
    theta : array-like supporting ``**`` (e.g. a NumPy array)
        Parameter values; every element contributes to the sum.

    Returns
    -------
    scalar
        ``-0.5 * np.sum(theta**2)``.

    Raises
    ------
    NameError
        If no global ``data`` exists where the function runs.
    """
    _ = data[0]  # Use the data somehow...
    # Busy-wait for 5-8 ms to imitate real computation. perf_counter() is
    # monotonic, so system clock adjustments cannot distort the delay.
    deadline = time.perf_counter() + np.random.uniform(0.005, 0.008)
    while time.perf_counter() < deadline:
        pass
    return -0.5 * np.sum(theta**2)


# No `args` here: log_prob_data_global reads `data` from the module globals.
# NOTE(review): this assumes the worker processes can see that global, as with
# the "fork" start method — confirm on platforms that default to "spawn".
with Pool() as pool:
    sampler = emcee.EnsembleSampler(
        nwalkers, ndim, log_prob_data_global, pool=pool
    )
    # perf_counter() is the monotonic clock intended for measuring intervals.
    start = time.perf_counter()
    sampler.run_mcmc(initial, nsteps, progress=True)
    end = time.perf_counter()
    multi_data_global_time = end - start
    print(
        "Multiprocessing took {0:.1f} seconds".format(multi_data_global_time)
    )
    print(
        "{0:.1f} times faster than serial".format(
            serial_data_time / multi_data_global_time
        )
    )

100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 21.40it/s]


Multiprocessing took 4.7 seconds
4.9 times faster than serial
