Skip to content

Commit

Permalink
Added dummy dimensions exceptions to plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
semohr committed May 21, 2021
1 parent a425612 commit 373ec0f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 12 deletions.
11 changes: 11 additions & 0 deletions covid19_npis/modelParams.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,17 @@ def countries(self):
"""
return self._countries

def country_by_name(self, name):
"""
Returns country by name
"""
for country in self._countries:
if country.name == name:
return country

# Error
raise Error("Name not found in country list")

@countries.setter
def countries(self, countries):
"""
Expand Down
44 changes: 32 additions & 12 deletions covid19_npis/plot/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from scipy import stats

from .rcParams import *

from .utils import (
get_posterior_prior_from_trace,
get_math_from_name,
get_dist_by_name_from_sample_state,
)
from .. import modelParams

mpl.rc("figure", max_open_warning=0)
log = logging.getLogger(__name__)
Expand Down Expand Up @@ -99,9 +99,11 @@ def helper_plot(posterior, prior, name_str):
if plot_age_groups_together and ("age_group" in df.index.names):
unq_age = df.index.get_level_values("age_group").unique()
fig, ax = plt.subplots(
len(unq_age), 1, figsize=(2.2, 2.2 * len(unq_age),), squeeze=False,
len(unq_age),
1,
figsize=(2.2, 2.2 * len(unq_age),),
squeeze=False
)
ax = ax[:, 0]
for i, ag in enumerate(unq_age):
# Create pivot table i.e. time on index and draw on columns
if posterior is not None:
Expand All @@ -126,16 +128,26 @@ def helper_plot(posterior, prior, name_str):
# Set title for axis
ax_now.set_title(ag)
elif len(df.index.names) == 1:
# Exception for dummy dimensions in change point logic
check = [
"d_i_c_p",
]
if dist.name in check:
intervention, country = name_str.split("/")
country = modelParams.modelParams.country_by_name(country)
num_rows = len(country.change_points[intervention])
else:
num_rows = len(df.index.get_level_values(df.index.names[0]).unique())

if num_rows == 0:
return

fig, ax = plt.subplots(
len(df.index.get_level_values(df.index.names[0]).unique()),
1,
figsize=(
2.2,
2.2 * len(df.index.get_level_values(df.index.names[0]).unique()),
),
num_rows,
1,
figsize=(2.2, 2.2 * num_rows,),
squeeze=False,
)
ax = ax[:, 0]
for i, ag in enumerate(
df.index.get_level_values(df.index.names[0]).unique()
):
Expand All @@ -148,16 +160,24 @@ def helper_plot(posterior, prior, name_str):
else:
prior_t = None

if i >= num_rows:
continue

if num_rows == 1:
p_axes = ax
else:
p_axes = ax[i]

# Plot
_distribution(
array_posterior=posterior_t,
array_prior=prior_t,
dist_name=dist.name,
dist_math=get_math_from_name(dist.name),
suffix=f"{i}",
ax=ax[i],
ax=p_axes,
)
ax[i].set_title(ag)
p_axes.set_title(ag)
else:
i = 0
if posterior is not None:
Expand Down

0 comments on commit 373ec0f

Please sign in to comment.