From 5a95bb73722b7ca0b4b48c33f4938daf1f88622a Mon Sep 17 00:00:00 2001 From: "Oriol (Prodesk)" Date: Mon, 28 Oct 2019 17:05:42 +0100 Subject: [PATCH 1/3] fix bug in emcee blobs handling --- arviz/data/io_emcee.py | 2 ++ arviz/tests/test_data_emcee.py | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/arviz/data/io_emcee.py b/arviz/data/io_emcee.py index 010e53c799..48ff7b8643 100644 --- a/arviz/data/io_emcee.py +++ b/arviz/data/io_emcee.py @@ -182,6 +182,8 @@ def blobs_to_dict(self): blobs = np.array(self.sampler.blobs) if blobs is None or blobs.size == 0: raise ValueError("No blobs in sampler, blob_names must be None") + if len(blobs.shape) == 2: + blobs = np.expand_dims(blobs, axis=-1) blobs = blobs.swapaxes(0, 2) nblobs, nwalkers, ndraws, *_ = blobs.shape if len(self.blob_names) != nblobs and len(self.blob_names) != 1: diff --git a/arviz/tests/test_data_emcee.py b/arviz/tests/test_data_emcee.py index 7c04e359dd..5e41a539b0 100644 --- a/arviz/tests/test_data_emcee.py +++ b/arviz/tests/test_data_emcee.py @@ -128,6 +128,13 @@ def test_peculiar_blobs(self, data): fails = check_multiple_attrs({"sample_stats": ["mix"]}, inference_data) assert not fails + def test_single_blob(self, data): + sampler = emcee.EnsembleSampler(6, 1, lambda x: (-x ** 2, 3)) + sampler.run_mcmc(np.random.normal(size=(6, 1)), 20) + inference_data = from_emcee(sampler, blob_names=["blob"]) + fails = check_multiple_attrs({"sample_stats": ["blob"]}, inference_data) + assert not fails + @pytest.mark.parametrize( "blob_args", [ From 1e87467113673162a90d8c0711943fce4bb31c1d Mon Sep 17 00:00:00 2001 From: "Oriol (Prodesk)" Date: Mon, 28 Oct 2019 17:24:40 +0100 Subject: [PATCH 2/3] lint --- arviz/tests/test_data_emcee.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/tests/test_data_emcee.py b/arviz/tests/test_data_emcee.py index 5e41a539b0..c5dc5a3b71 100644 --- a/arviz/tests/test_data_emcee.py +++ b/arviz/tests/test_data_emcee.py @@ -128,7 +128,7 @@ def test_peculiar_blobs(self, data): fails = check_multiple_attrs({"sample_stats": ["mix"]}, inference_data) assert not fails - def test_single_blob(self, data): + def test_single_blob(self): sampler = emcee.EnsembleSampler(6, 1, lambda x: (-x ** 2, 3)) sampler.run_mcmc(np.random.normal(size=(6, 1)), 20) inference_data = from_emcee(sampler, blob_names=["blob"]) From f4c944ac9976736be260c2f43e5af1b92c090efd Mon Sep 17 00:00:00 2001 From: "OriolAbril(UPF)" Date: Tue, 29 Oct 2019 10:16:18 +0100 Subject: [PATCH 3/3] still had old black locally --- arviz/tests/test_data_emcee.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/tests/test_data_emcee.py b/arviz/tests/test_data_emcee.py index c5dc5a3b71..f6ee6f2d58 100644 --- a/arviz/tests/test_data_emcee.py +++ b/arviz/tests/test_data_emcee.py @@ -129,7 +129,7 @@ def test_peculiar_blobs(self, data): assert not fails def test_single_blob(self): - sampler = emcee.EnsembleSampler(6, 1, lambda x: (-x ** 2, 3)) + sampler = emcee.EnsembleSampler(6, 1, lambda x: (-(x ** 2), 3)) sampler.run_mcmc(np.random.normal(size=(6, 1)), 20) inference_data = from_emcee(sampler, blob_names=["blob"]) fails = check_multiple_attrs({"sample_stats": ["blob"]}, inference_data)