Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import memilio.simulation as mio
import memilio.simulation.secir as secir
import memilio.plot.createGIF as mp

from enum import Enum
from memilio.simulation.secir import (Model, Simulation,
Expand Down Expand Up @@ -426,7 +427,7 @@ def get_graph(self, end_date):

return graph

def run(self, num_days_sim, num_runs=10, save_graph=True):
def run(self, num_days_sim, num_runs=10, save_graph=True, create_gif=True):
mio.set_log_level(mio.LogLevel.Warning)
end_date = self.start_date + datetime.timedelta(days=num_days_sim)

Expand Down Expand Up @@ -459,6 +460,11 @@ def run(self, num_days_sim, num_runs=10, save_graph=True):
secir.save_results(
ensemble_results, ensemble_params, node_ids, self.results_dir,
save_single_runs, save_percentiles)
if create_gif:
# any compartments in the model (see InfectionStates)
compartments = [c for c in range(1, 8)]
mp.create_gif_map_plot(
self.results_dir + "/p75", self.results_dir, compartments)
return 0


Expand All @@ -468,5 +474,5 @@ def run(self, num_days_sim, num_runs=10, save_graph=True):
data_dir=os.path.join(file_path, "../../../data"),
start_date=datetime.date(year=2020, month=12, day=12),
results_dir=os.path.join(file_path, "../../../results_secir"))
num_days_sim = 30
num_days_sim = 50
sim.run(num_days_sim, num_runs=2)
2 changes: 2 additions & 0 deletions pycode/memilio-plot/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ Required python packages:
- mapclassify
- geopandas
- h5py
- imageio
- datetime

Testing and Coverage
--------------------
Expand Down
155 changes: 155 additions & 0 deletions pycode/memilio-plot/memilio/plot/createGIF.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
#############################################################################
# Copyright (C) 2020-2024 MEmilio
#
# Authors: Henrik Zunker, Maximilian Betz
#
# Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#############################################################################

import datetime as dt
import os.path
import imageio
import tempfile

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import memilio.epidata.getPopulationData as gpd
import memilio.plot.plotMap as pm
from memilio.epidata import geoModificationGermany as geoger
import memilio.epidata.progress_indicator as progind
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)


def create_plot_map(day, filename, files_input, output_path, compartments, file_format='h5', relative=False, age_groups={0: '0-4', 1: '5-14', 2: '15-34', 3: '35-59', 4: '60-79', 5: '80+'}):
    """! Plots region-specific information for a single day of the simulation.

    For every input file the requested compartments are extracted for the
    given day, optionally scaled relative to the local population, and
    collected as one column per file. The combined data is passed to
    plotMap.plot_map with one color scale spanning the minimum and maximum
    over all files.

    @param[in] day Day of the simulation (integer).
    @param[in] filename Name of the file to be created.
    @param[in] files_input Dictionary of input files. Must not be empty.
    @param[in] output_path Output path for the figure.
    @param[in] compartments List of compartments to be plotted.
    @param[in] file_format Format of the file to be created. Either 'h5' or 'json'.
    @param[in] relative Defines if data should be scaled relative to population.
    @param[in] age_groups Dictionary of age groups to be considered. The
        default dictionary is only read, never modified.
    @exception ValueError If files_input is empty.
    """
    # Fail early with a clear message instead of an UnboundLocalError further
    # down when there is nothing to plot.
    if not files_input:
        raise ValueError("files_input must contain at least one input file.")

    # If all six age groups are requested, no age filter is passed on.
    # Otherwise json input is filtered by the age group labels and h5 input
    # by 'Group<key>' names.
    if len(age_groups) == 6:
        filter_age = None
    elif file_format == 'json':
        filter_age = list(age_groups.values())
    else:
        filter_age = ['Group' + str(key) for key in age_groups]

    # In file_input there can be two different files. When we enter two files,
    # both files are plotted side by side in the same figure.
    dfs_all = None
    for file_index, file in enumerate(files_input.values()):

        df = pm.extract_data(
            file, region_spec=None, column=None, date=day,
            filters={'Group': filter_age, 'InfectionState': compartments},
            file_format=file_format)

        if relative:
            # The population data is loaded anew for every file so that each
            # call to scale_dataframe_relative gets its own data frame.
            try:
                population = pd.read_json(
                    'data/pydata/Germany/county_current_population.json')
            # pandas>1.5 raise FileNotFoundError instead of ValueError
            except (ValueError, FileNotFoundError):
                print(
                    "Population data was not found. Downloading it from the internet.")
                population = gpd.get_population_data(
                    read_data=False, file_format=file_format,
                    out_folder='data/pydata/Germany/', no_raw=True, merge_eisenach=True)

            # For fitting of different age groups we need format ">X".
            age_group_values = list(age_groups.values())
            age_group_values[-1] = age_group_values[-1].replace(
                '80+', '>79')
            # scale data
            df = pm.scale_dataframe_relative(
                df, age_group_values, population)

        # The first column of the first file provides the region column.
        if dfs_all is None:
            dfs_all = pd.DataFrame(df.iloc[:, 0])

        # Append the last column of this file, made unique by the file index.
        dfs_all[df.columns[-1] + ' ' + str(file_index)] = df[df.columns[-1]]

    dfs_all = dfs_all.apply(pd.to_numeric, errors='coerce')

    dfs_all_sorted = dfs_all.sort_values(by='Region')
    dfs_all_sorted = dfs_all_sorted.reset_index(drop=True)

    # Common color scale over the value columns of all files.
    value_columns = dfs_all_sorted[dfs_all_sorted.columns[1:]]
    min_val = value_columns.min().min()
    max_val = value_columns.max().max()

    # Name the scaling that was actually applied in the title.
    scaling = 'relative' if relative else 'absolute'

    pm.plot_map(
        dfs_all_sorted, scale_colors=np.array([min_val, max_val]),
        legend=['', ''],
        title=f'Synthetic data ({scaling}) day {day:2d}', plot_colorbar=True,
        output_path=output_path,
        fig_name=filename, dpi=300,
        outercolor=[205 / 255, 238 / 255, 251 / 255])


def create_gif_map_plot(input_data, output_dir, compartments, filename="simulation", relative=True, age_groups={0: '0-4', 1: '5-14', 2: '15-34',
                                                                                                                3: '35-59', 4: '60-79', 5: '80+'}):
    """! Creates a gif of the simulation results by calling create_plot_map for each day of the simulation and then
    storing the single plots in a temporary directory. Currently only works for the results created by the parameter study.

    @param[in] input_data Path to the input data. The Path should contain a file called 'Results' which contains
        the simulation results. This is the default output folder of the parameter study.
    @param[in] output_dir Path where the gif should be stored.
    @param[in] compartments List of compartments to be plotted.
    @param[in] filename Name (without extension) of the temporary frame image and of the resulting gif.
    @param[in] relative Defines if data should be scaled relative to population.
    @param[in] age_groups Dictionary of age groups to be considered.
    """

    file_format = 'h5'
    results_file = os.path.join(input_data, 'Results')
    files_input = {'Data set': results_file}

    num_days = pm.extract_time_steps(results_file, file_format=file_format)

    # create gif
    frames = []
    with progind.Percentage() as indicator, \
            tempfile.TemporaryDirectory() as tmpdirname:
        # Every day overwrites the same temporary image, which is read back
        # into memory right after it has been plotted.
        frame_path = os.path.join(tmpdirname, filename + ".png")
        for day in range(num_days):
            create_plot_map(day, filename, files_input, tmpdirname,
                            compartments, file_format, relative, age_groups)

            frames.append(imageio.v2.imread(frame_path))

            # Close the current figure to free up memory
            plt.close('all')
            indicator.set_progress((day+1)/num_days)

    # NOTE(review): newer imageio releases may interpret the GIF `duration`
    # in milliseconds instead of seconds - confirm against the pinned version.
    imageio.mimsave(os.path.join(output_dir, filename + '.gif'),
                    frames,  # array of input frames
                    duration=0.2,  # duration of each frame in seconds
                    loop=0)  # 0 = repeat the animation indefinitely
62 changes: 47 additions & 15 deletions pycode/memilio-plot/memilio/plot/plotMap.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec
import matplotlib.colors as mcolors

from memilio.epidata import geoModificationGermany as geoger
from memilio.epidata import getDataIntoPandasDataFrame as gd
Expand Down Expand Up @@ -143,11 +144,14 @@ def extract_data(

# Set no filtering if filters were set to None.
if filters == None:
filters['Group'] = list(h5file[regions[i]].keys())[
filters['Group'] = list(h5file[regions[0]].keys())[
:-2] # Remove 'Time' and 'Total'.
filters['InfectionState'] = list(
range(h5file[regions[i]]['Group1'].shape[1]))

if filters['Group'] == None:
filters['Group'] = list(h5file[regions[0]].keys())[:-2]

InfectionStateList = [j for j in filters['InfectionState']]

# Create data frame to store results to plot.
Expand Down Expand Up @@ -192,8 +196,8 @@ def extract_data(

k += 1
else:
raise gd.ValueError(
"Time point not found for region " + str(regions[i]) + ".")
raise gd.DataError(
"Time point " + str(date) + " not found for region " + str(regions[i]) + ".")

# Aggregated or matrix output.
if output == 'sum':
Expand All @@ -207,6 +211,29 @@ def extract_data(
raise gd.DataError("Data could not be read in.")


def extract_time_steps(file, file_format='json'):
    """! Reads data from a general json or specific hdf5 file as output by the
    MEmilio simulation framework and extracts the number of time steps stored.

    For json input, the number of distinct entries of the 'Date' column is
    returned, or 1 if there is no such column. For h5 input, the length of
    the 'Time' data set of the first region in the file is returned.

    @param[in] file Path and filename of file to be read in, relative from current
        directory. The file extension is appended from file_format.
    @param[in] file_format File format; either json or h5.
    @return Number of time steps.
    @exception ValueError If file_format is neither 'json' nor 'h5'.
    """
    # Reject unknown formats explicitly; without this the function would fail
    # with an UnboundLocalError on the return value.
    if file_format not in ('json', 'h5'):
        raise ValueError(
            "Unsupported file format '" + str(file_format) +
            "'. Use 'json' or 'h5'.")

    input_file = os.path.join(os.getcwd(), str(file)) + '.' + file_format

    if file_format == 'json':
        df = pd.read_json(input_file)
        if 'Date' in df.columns:
            return df['Date'].nunique()
        return 1

    # Open the hdf5 file in a context manager so the handle is closed again.
    with h5py.File(input_file, 'r') as h5file:
        regions = list(h5file.keys())
        return len(h5file[regions[0]]['Time'])


def scale_dataframe_relative(df, age_groups, df_population):
"""! Scales a population-related data frame relative to the size of the
local populations or subpopulations (e.g., if not all age groups are
Expand All @@ -225,7 +252,7 @@ def scale_dataframe_relative(df, age_groups, df_population):
"""

# Merge population data of Eisenach (if counted separately) with Wartburgkreis.
if 16056 in df_population[df.columns[0]].values:
if 16056 in df_population['ID_County'].values:
for i in range(1, len(df_population.columns)):
df_population.loc[df_population[df.columns[0]] == 16063, df_population.columns[i]
] += df_population.loc[df_population.ID_County == 16056, df_population.columns[i]]
Expand All @@ -235,12 +262,12 @@ def scale_dataframe_relative(df, age_groups, df_population):
columns=[df_population.columns[0]] + age_groups)
# Extrapolate on oldest age group with maximum age 100.
for region_id in df.iloc[:, 0]:
df_population_agegroups.loc[len(df_population_agegroups.index), :] = [region_id] + list(
mdfs.fit_age_group_intervals(df_population[df_population.iloc[:, 0] == region_id].iloc[:, 2:], age_groups))
df_population_agegroups.loc[len(df_population_agegroups.index), :] = [int(region_id)] + list(
mdfs.fit_age_group_intervals(df_population[df_population.iloc[:, 0] == int(region_id)].iloc[:, 2:], age_groups))

def scale_row(elem):
population_local_sum = df_population_agegroups[
df_population_agegroups[df.columns[0]] == elem[0]].iloc[
df_population_agegroups['ID_County'] == int(elem[0])].iloc[
:, 1:].sum(axis=1)
return elem['Count'] / population_local_sum.values[0]

Expand Down Expand Up @@ -272,7 +299,8 @@ def plot_map(data: pd.DataFrame,
output_path: str = '',
fig_name: str = 'customPlot',
dpi: int = 300,
outercolor='white'):
outercolor='white',
log_scale=False):
"""! Plots region-specific information onto a interactive html map and
returning svg and png image. Allows the comparisons of a variable list of
data sets.
Expand All @@ -288,12 +316,14 @@ def plot_map(data: pd.DataFrame,
@param[in] fig_name Name of the figure created.
@param[in] dpi Dots-per-inch value for the exported figure.
@param[in] outercolor Background color of the plot image.
@param[in] log_scale Defines if the colorbar is plotted in log scale.
"""
region_classifier = data.columns[0]
region_data = data[region_classifier].to_numpy().astype(int)

data_columns = data.columns[1:]
# Read and filter map data.
if data[region_classifier].isin(geoger.get_county_ids()).all():
if np.isin(region_data, geoger.get_county_ids()).all():
try:
map_data = geopandas.read_file(
os.path.join(
Expand Down Expand Up @@ -342,17 +372,23 @@ def plot_map(data: pd.DataFrame,
# Use top row for title.
tax = fig.add_subplot(gs[0, :])
tax.set_axis_off()
tax.set_title(title, fontsize=18)
tax.set_title(title, fontsize=16)
if plot_colorbar:
# Prepare colorbar.
cax = fig.add_subplot(gs[1, 0])
else:
cax = None

if log_scale:
norm = mcolors.LogNorm(vmin=scale_colors[0], vmax=scale_colors[1])

for i in range(len(data_columns)):

ax = fig.add_subplot(gs[:, i+2])
if cax is not None:
if log_scale:
map_data.plot(data_columns[i], ax=ax, cax=cax, legend=True,
norm=norm)
elif cax is not None:
map_data.plot(data_columns[i], ax=ax, cax=cax, legend=True,
vmin=scale_colors[0], vmax=scale_colors[1])
else:
Expand All @@ -364,8 +400,4 @@ def plot_map(data: pd.DataFrame,
ax.set_axis_off()

plt.subplots_adjust(bottom=0.1)

plt.savefig(os.path.join(output_path, fig_name + '.png'), dpi=dpi)
plt.savefig(os.path.join(output_path, fig_name + '.svg'), dpi=dpi)

plt.show()
Loading