In [1]:
import bayesfast as bf
import numpy as np
import multiprocess
from distributed import Client, LocalCluster

In [2]:
D = 16  # number of dims
a = 1.  # logp treats x_0 as N(0, a**2)
b = 0.5  # logp gives x_1, ... the variance exp(2 * b * x_0)

# Flat prior box: x_0 in [-4, 4], every other coordinate in [-30, 30].
lower = np.full(D, -30.)
upper = np.full(D, 30.)
lower[0], upper[0] = -4, 4
bound = np.vstack((lower, upper)).T  # shape (D, 2): one [lower, upper] row per dim
diff = upper - lower  # width of the box along each dim
const = np.log(diff).sum()  # log-volume of the box = normalization of the flat prior

def logp(x):
    """Log-density evaluated at `x`, offset by the flat-prior normalization.

    ``x[..., 0]`` is treated as N(0, a**2). Each remaining coordinate is
    treated as N(0, exp(2 * b * x[..., 0])). The log-volume ``const`` of the
    prior box is subtracted at the end.

    Parameters
    ----------
    x : numpy.ndarray, shape (..., n)
        Point(s) to evaluate; leading axes are treated as a batch.

    Returns
    -------
    numpy.ndarray or float
        Log-density with the batch shape of ``x``.
    """
    dim = x.shape[-1]
    head = x[..., 0]
    tail = x[..., 1:]
    log_head = -0.5 * head**2 / a**2
    log_tail = -0.5 * np.sum(tail**2, axis=-1) * np.exp(-2 * b * head)
    # Gaussian normalizations; the variance of the tail depends on `head`.
    log_norm = (-0.5 * np.log(2 * np.pi * a**2) -
                0.5 * (dim - 1) * np.log(2 * np.pi) - (dim - 1) * b * head)
    return log_head + log_tail + log_norm - const

def grad(x):
    """Gradient of `logp` with respect to `x`.

    Accepts a single point of shape ``(n,)`` or a batch of shape ``(..., n)``,
    matching the ``...``-indexing used by `logp`.

    Parameters
    ----------
    x : array_like, shape (..., n)
        Point(s) at which to evaluate the gradient.

    Returns
    -------
    numpy.ndarray, shape (..., n)
        d logp / d x, with one row per input point.
    """
    x = np.asarray(x)
    n = x.shape[-1]
    # Inverse variance exp(-2 b x_0) of x_1, ...; keep a trailing axis of size 1
    # so it broadcasts against x for any batch shape.
    # (The old np.full-based version broadcast the per-point values along the
    # wrong axis and only worked for 1-D input.)
    inv_var = np.exp(-2 * b * x[..., :1])
    # d logp / d x_i = -x_i * exp(-2 b x_0) for i >= 1; column 0 is replaced below.
    out = -x * inv_var
    # d logp / d x_0: own Gaussian term, the dependence of the other coordinates'
    # variance on x_0, and the log-normalization term -(n - 1) b x_0.
    # Index with [..., 0] (not [0]) so the x_0 column is updated for every batch row.
    out[..., 0] = (-x[..., 0] / a**2
                   + b * np.sum(x[..., 1:]**2, axis=-1) * inv_var[..., 0]
                   - (n - 1) * b)
    return out

In [3]:
# Wrap logp/grad as a bayesfast density on the D-dim prior box `bound`.
den = bf.DensityLite(
    logp=logp,
    grad=grad,
    input_size=D,
    input_scales=bound,
    hard_bounds=True,
)

# 12 reproducible standard-normal test points, one per row of `x`.
np.random.seed(0)
x = np.random.randn(12, D)

In [4]:
# Local dask cluster: 12 single-threaded workers, matching the 12 rows of `x`.
cluster = LocalCluster(
    n_workers=12,
    threads_per_worker=1,
)
client = Client(cluster)
client  # display the client / cluster summary

0,1
Client  Scheduler: tcp://127.0.0.1:34807  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 12  Cores: 12  Memory: 540.15 GB


### 1. First, map `den` with `dask`

In [5]:
# first run using dask
# Map `den` over the rows of `x` on the cluster and gather the results back here.
# Only this statement is timed. The recorded wall time of this first call is far
# larger than the repeats below (presumably one-off start-up/serialization cost).
%time foo = client.gather(client.map(den, x))
np.asarray(foo)  # show the gathered results as a single array

CPU times: user 138 ms, sys: 31.1 ms, total: 169 ms
Wall time: 1.47 s


array([-94.2455445 , -92.8864777 , -95.90057691, -86.70286387,
       -83.33536339, -96.04281462, -91.81630551, -88.46533025,
       -85.23804851, -99.53316646, -87.24607566, -91.3730939 ])

In [6]:
# second run using dask
# Identical call, repeated to compare its timing with the first run.
%time foo = client.gather(client.map(den, x))
np.asarray(foo)  # results should match the first run

CPU times: user 19.3 ms, sys: 8.39 ms, total: 27.7 ms
Wall time: 25.4 ms


array([-94.2455445 , -92.8864777 , -95.90057691, -86.70286387,
       -83.33536339, -96.04281462, -91.81630551, -88.46533025,
       -85.23804851, -99.53316646, -87.24607566, -91.3730939 ])

In [7]:
# third run using dask
# Identical call again, to check that the repeat timing is stable.
%time foo = client.gather(client.map(den, x))
np.asarray(foo)  # results should match the earlier runs

CPU times: user 17.7 ms, sys: 6.23 ms, total: 24 ms
Wall time: 22 ms


array([-94.2455445 , -92.8864777 , -95.90057691, -86.70286387,
       -83.33536339, -96.04281462, -91.81630551, -88.46533025,
       -85.23804851, -99.53316646, -87.24607566, -91.3730939 ])

### 2. Then, map `den` with `multiprocess`

In [8]:
# first run using multiprocess
# A fresh pool of 12 processes is created for this cell and torn down when the
# `with` block exits. Only `pool.map` is timed, so pool start-up is not included
# (just as the dask cluster start-up is not included above).
with multiprocess.Pool(12) as pool:
    %time foo = pool.map(den, x)
np.asarray(foo)  # show the mapped results as a single array

CPU times: user 21.1 ms, sys: 194 µs, total: 21.3 ms
Wall time: 20.7 ms


array([-94.2455445 , -92.8864777 , -95.90057691, -86.70286387,
       -83.33536339, -96.04281462, -91.81630551, -88.46533025,
       -85.23804851, -99.53316646, -87.24607566, -91.3730939 ])

In [9]:
# second run using multiprocess
# Same measurement with another fresh pool; again only `pool.map` is timed.
with multiprocess.Pool(12) as pool:
    %time foo = pool.map(den, x)
np.asarray(foo)  # results should match the first run

CPU times: user 16 ms, sys: 7.98 ms, total: 24 ms
Wall time: 22.9 ms


array([-94.2455445 , -92.8864777 , -95.90057691, -86.70286387,
       -83.33536339, -96.04281462, -91.81630551, -88.46533025,
       -85.23804851, -99.53316646, -87.24607566, -91.3730939 ])

### 3. Since `den` is just `logp` plus a wrapper, map the smaller `logp` object with `dask` instead

In [10]:
# first run using dask
# Map the bare `logp` function instead of the `den` wrapper; only this statement is timed.
%time foo = client.gather(client.map(logp, x))
np.asarray(foo)  # values match the `den` runs above

CPU times: user 15.6 ms, sys: 11.9 ms, total: 27.5 ms
Wall time: 27.2 ms


array([-94.2455445 , -92.8864777 , -95.90057691, -86.70286387,
       -83.33536339, -96.04281462, -91.81630551, -88.46533025,
       -85.23804851, -99.53316646, -87.24607566, -91.3730939 ])

In [11]:
# second run using dask
# Identical call, repeated to compare its timing with the first `logp` run.
%time foo = client.gather(client.map(logp, x))
np.asarray(foo)  # results should match the first run

CPU times: user 31.5 ms, sys: 1.41 ms, total: 32.9 ms
Wall time: 30.1 ms


array([-94.2455445 , -92.8864777 , -95.90057691, -86.70286387,
       -83.33536339, -96.04281462, -91.81630551, -88.46533025,
       -85.23804851, -99.53316646, -87.24607566, -91.3730939 ])