From 9352af06e0e1979c5cc766a7bc9495dfd64fde37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CHenrik?= <“henrik.zunker@dlr.de”> Date: Mon, 4 Sep 2023 14:38:03 +0200 Subject: [PATCH 01/13] fix plotmap + log scale colormap --- pycode/memilio-plot/memilio/plot/plotMap.py | 22 +++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/pycode/memilio-plot/memilio/plot/plotMap.py b/pycode/memilio-plot/memilio/plot/plotMap.py index caa5d8b3af..0fe2f9db92 100755 --- a/pycode/memilio-plot/memilio/plot/plotMap.py +++ b/pycode/memilio-plot/memilio/plot/plotMap.py @@ -27,6 +27,7 @@ import pandas as pd from matplotlib import pyplot as plt from matplotlib.gridspec import GridSpec +import matplotlib.colors as mcolors from memilio.epidata import geoModificationGermany as geoger from memilio.epidata import getDataIntoPandasDataFrame as gd @@ -143,11 +144,14 @@ def extract_data( # Set no filtering if filters were set to None. if filters == None: - filters['Group'] = list(h5file[regions[i]].keys())[ + filters['Group'] = list(h5file[regions[0]].keys())[ :-2] # Remove 'Time' and 'Total'. filters['InfectionState'] = list( range(h5file[regions[i]]['Group1'].shape[1])) + if filters['Group'] == None: + filters['Group'] = list(h5file[regions[0]].keys())[:-2] + InfectionStateList = [j for j in filters['InfectionState']] # Create data frame to store results to plot. @@ -225,7 +229,7 @@ def scale_dataframe_relative(df, age_groups, df_population): """ # Merge population data of Eisenach (if counted separately) with Wartburgkreis. - if 16056 in df_population[df.columns[0]].values: + if 16056 in df_population['ID_County'].values: for i in range(1, len(df_population.columns)): df_population.loc[df_population[df.columns[0]] == 16063, df_population.columns[i] ] += df_population.loc[df_population.ID_County == 16056, df_population.columns[i]] @@ -235,12 +239,12 @@ def scale_dataframe_relative(df, age_groups, df_population): columns=[df_population.columns[0]] + age_groups) # Extrapolate on oldest age group with maximumg age 100. for region_id in df.iloc[:, 0]: - df_population_agegroups.loc[len(df_population_agegroups.index), :] = [region_id] + list( - mdfs.fit_age_group_intervals(df_population[df_population.iloc[:, 0] == region_id].iloc[:, 2:], age_groups)) + df_population_agegroups.loc[len(df_population_agegroups.index), :] = [int(region_id)] + list( + mdfs.fit_age_group_intervals(df_population[df_population.iloc[:, 0] == int(region_id)].iloc[:, 2:], age_groups)) def scale_row(elem): population_local_sum = df_population_agegroups[ - df_population_agegroups[df.columns[0]] == elem[0]].iloc[ + df_population_agegroups['ID_County'] == int(elem[0])].iloc[ :, 1:].sum(axis=1) return elem['Count'] / population_local_sum.values[0] @@ -290,10 +294,11 @@ def plot_map(data: pd.DataFrame, @param[in] outercolor Background color of the plot image. """ region_classifier = data.columns[0] + region_data = data[region_classifier].to_numpy().astype(int) data_columns = data.columns[1:] # Read and filter map data. - if data[region_classifier].isin(geoger.get_county_ids()).all(): + if np.isin(region_data, geoger.get_county_ids()).all(): try: map_data = geopandas.read_file( os.path.join( @@ -349,16 +354,21 @@ def plot_map(data: pd.DataFrame, else: cax = None + # log scale of colorbar. + # norm = mcolors.LogNorm(vmin=scale_colors[0], vmax=scale_colors[1]) + for i in range(len(data_columns)): ax = fig.add_subplot(gs[:, i+2]) if cax is not None: map_data.plot(data_columns[i], ax=ax, cax=cax, legend=True, vmin=scale_colors[0], vmax=scale_colors[1]) + # map_data.plot(data_columns[i], ax=ax, cax=cax, legend=True, norm=norm) else: # Do not plot colorbar. map_data.plot(data_columns[i], ax=ax, legend=False, vmin=scale_colors[0], vmax=scale_colors[1]) + # map_data.plot(data_columns[i], ax=ax, cax=cax, legend=True, norm=norm) ax.set_title(legend[i], fontsize=12) ax.set_axis_off() From ee501f404ab366e2a013060ba343076dd695253a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CHenrik?= <“henrik.zunker@dlr.de”> Date: Mon, 4 Sep 2023 14:39:52 +0200 Subject: [PATCH 02/13] plot gif example --- pycode/examples/plot/plotGifResultsGermany.py | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 pycode/examples/plot/plotGifResultsGermany.py diff --git a/pycode/examples/plot/plotGifResultsGermany.py b/pycode/examples/plot/plotGifResultsGermany.py new file mode 100644 index 0000000000..209bf99fb1 --- /dev/null +++ b/pycode/examples/plot/plotGifResultsGermany.py @@ -0,0 +1,105 @@ +import datetime as dt +import os.path +import imageio + +import numpy as np +import pandas as pd + +import memilio.epidata.getPopulationData as gpd +import memilio.plot.plotMap as pm +from memilio.epidata import geoModificationGermany as geoger + +if __name__ == '__main__': + + files_input = {'Data set 1': 'test2/p75/Results'} + file_format = 'h5' + # Define age groups which will be considered through filtering + # Keep keys and values as well as its assignment constant, remove entries + # if only part of the population should be plotted or considered, e.g., by + # setting: + # age_groups = {1: '5-14', 2: '15-34'} + age_groups = {0: '0-4', 1: '5-14', 2: '15-34', + 3: '35-59', 4: '60-79', 5: '80+'} + if len(age_groups) == 6: + filter_age = None + else: + if file_format == 'json': + filter_age = [val for val in age_groups.values()] + else: + filter_age = ['Group' + str(key) for key in age_groups.keys()] + + relative = True + num_days = 92 + for day in range(0, num_days): + + i = 0 + for file in files_input.values(): + # MEmilio backend hdf5 example + + deads = [24, 25, 26] + susceptible = [0, 1, 23] + infected = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22] + + df = pm.extract_data( + file, region_spec=None, column=None, date=day, + filters={'Group': filter_age, 'InfectionState': infected}, + file_format=file_format) + + if relative: + + try: + population = pd.read_json( + 'data/pydata/Germany/county_current_population.json') + # pandas>1.5 raise FileNotFoundError instead of ValueError + except (ValueError, FileNotFoundError): + print( + "Population data was not found. Download it from the internet.") + population = gpd.get_population_data( + read_data=False, file_format=file_format, + out_folder='data/pydata/Germany/', no_raw=True, + split_gender=False, merge_eisenach=True) + + # For fitting of different age groups we need format ">X". + age_group_values = list(age_groups.values()) + age_group_values[-1] = age_group_values[-1].replace( + '80+', '>79') + # scale data + df = pm.scale_dataframe_relative( + df, age_group_values, population) + + if i == 0: + dfs_all = pd.DataFrame(df.iloc[:, 0]) + + dfs_all[df.columns[-1] + ' ' + str(i)] = df[df.columns[-1]] + i += 1 + + filename = 'customPlot_day_' + str(day) + + dfs_all = dfs_all.apply(pd.to_numeric, errors='coerce') + + dfs_all_sorted = dfs_all.sort_values(by='Region') + dfs_all_sorted = dfs_all_sorted.reset_index(drop=True) + + min_val = dfs_all_sorted[dfs_all_sorted.columns[1:]].min().min() + max_val = dfs_all_sorted[dfs_all_sorted.columns[1:]].max().max() + + pm.plot_map( + dfs_all_sorted, scale_colors=np.array([min_val, max_val]), + legend=['', ''], + title='Synthetic data (relative) day ' + str(day), plot_colorbar=True, + output_path="plot_gif", + fig_name=filename, dpi=300, + outercolor=[205 / 255, 238 / 255, 251 / 255]) + + # create gif + frames = [] + for day in range(0, num_days): + filename = 'customPlot_day_' + str(day) + '.png' + output_path = "plot_gif" + image = imageio.v2.imread(os.path.join(output_path, filename)) + frames.append(image) + + imageio.mimsave(os.path.join(output_path, 'sim.gif'), # output gif + frames, # array of input frames + duration=10) # optional: frames per second From ffa2ab33dba37a2b59d94b47d09b8e93b8297e74 Mon Sep 17 00:00:00 2001 From: MaxBetzDLR Date: Fri, 8 Sep 2023 13:36:33 +0200 Subject: [PATCH 03/13] Rearranging gif creation and add time extraction --- pycode/examples/plot/plotGifResultsGermany.py | 141 +++++++++--------- pycode/memilio-plot/memilio/plot/plotMap.py | 27 +++- 2 files changed, 98 insertions(+), 70 deletions(-) diff --git a/pycode/examples/plot/plotGifResultsGermany.py b/pycode/examples/plot/plotGifResultsGermany.py index 209bf99fb1..937e5e0cdc 100644 --- a/pycode/examples/plot/plotGifResultsGermany.py +++ b/pycode/examples/plot/plotGifResultsGermany.py @@ -9,9 +9,70 @@ import memilio.plot.plotMap as pm from memilio.epidata import geoModificationGermany as geoger + +def create_plot_map(day, age_groups, relative, filename, file_format, files_input, output_path): + + i = 0 + for file in files_input.values(): + # MEmilio backend hdf5 example + + deads = [24, 25, 26] + susceptible = [0, 1, 23] + infected = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22] + + df = pm.extract_data( + file, region_spec=None, column=None, date=day, + filters={'Group': filter_age, 'InfectionState': infected}, + file_format=file_format) + + if relative: + + try: + population = pd.read_json( + 'data/pydata/Germany/county_current_population.json') + # pandas>1.5 raise FileNotFoundError instead of ValueError + except (ValueError, FileNotFoundError): + print( + "Population data was not found. Download it from the internet.") + population = gpd.get_population_data( + read_data=False, file_format=file_format, + out_folder='data/pydata/Germany/', no_raw=True, merge_eisenach=True) + + # For fitting of different age groups we need format ">X". + age_group_values = list(age_groups.values()) + age_group_values[-1] = age_group_values[-1].replace( + '80+', '>79') + # scale data + df = pm.scale_dataframe_relative( + df, age_group_values, population) + + if i == 0: + dfs_all = pd.DataFrame(df.iloc[:, 0]) + + dfs_all[df.columns[-1] + ' ' + str(i)] = df[df.columns[-1]] + i += 1 + + dfs_all = dfs_all.apply(pd.to_numeric, errors='coerce') + + dfs_all_sorted = dfs_all.sort_values(by='Region') + dfs_all_sorted = dfs_all_sorted.reset_index(drop=True) + + min_val = dfs_all_sorted[dfs_all_sorted.columns[1:]].min().min() + max_val = dfs_all_sorted[dfs_all_sorted.columns[1:]].max().max() + + pm.plot_map( + dfs_all_sorted, scale_colors=np.array([min_val, max_val]), + legend=['', ''], + title='Synthetic data (relative) day ' + str(day), plot_colorbar=True, + output_path=output_path, + fig_name=filename, dpi=300, + outercolor=[205 / 255, 238 / 255, 251 / 255]) + + if __name__ == '__main__': - files_input = {'Data set 1': 'test2/p75/Results'} + files_input = {'Data set 1': 'p75/Results'} file_format = 'h5' # Define age groups which will be considered through filtering # Keep keys and values as well as its assignment constant, remove entries @@ -29,77 +90,21 @@ filter_age = ['Group' + str(key) for key in age_groups.keys()] relative = True - num_days = 92 - for day in range(0, num_days): - - i = 0 - for file in files_input.values(): - # MEmilio backend hdf5 example - - deads = [24, 25, 26] - susceptible = [0, 1, 23] - infected = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22] - - df = pm.extract_data( - file, region_spec=None, column=None, date=day, - filters={'Group': filter_age, 'InfectionState': infected}, - file_format=file_format) - - if relative: - - try: - population = pd.read_json( - 'data/pydata/Germany/county_current_population.json') - # pandas>1.5 raise FileNotFoundError instead of ValueError - except (ValueError, FileNotFoundError): - print( - "Population data was not found. Download it from the internet.") - population = gpd.get_population_data( - read_data=False, file_format=file_format, - out_folder='data/pydata/Germany/', no_raw=True, - split_gender=False, merge_eisenach=True) - - # For fitting of different age groups we need format ">X". - age_group_values = list(age_groups.values()) - age_group_values[-1] = age_group_values[-1].replace( - '80+', '>79') - # scale data - df = pm.scale_dataframe_relative( - df, age_group_values, population) - - if i == 0: - dfs_all = pd.DataFrame(df.iloc[:, 0]) - - dfs_all[df.columns[-1] + ' ' + str(i)] = df[df.columns[-1]] - i += 1 - - filename = 'customPlot_day_' + str(day) - - dfs_all = dfs_all.apply(pd.to_numeric, errors='coerce') - - dfs_all_sorted = dfs_all.sort_values(by='Region') - dfs_all_sorted = dfs_all_sorted.reset_index(drop=True) - - min_val = dfs_all_sorted[dfs_all_sorted.columns[1:]].min().min() - max_val = dfs_all_sorted[dfs_all_sorted.columns[1:]].max().max() - - pm.plot_map( - dfs_all_sorted, scale_colors=np.array([min_val, max_val]), - legend=['', ''], - title='Synthetic data (relative) day ' + str(day), plot_colorbar=True, - output_path="plot_gif", - fig_name=filename, dpi=300, - outercolor=[205 / 255, 238 / 255, 251 / 255]) + time_steps = pm.extract_time_steps( + files_input[list(files_input.keys())[0]], file_format=file_format) + filename = "file" # create gif frames = [] - for day in range(0, num_days): - filename = 'customPlot_day_' + str(day) + '.png' - output_path = "plot_gif" - image = imageio.v2.imread(os.path.join(output_path, filename)) + output_path = "plot_gif" + + for day in range(0, time_steps): + create_plot_map(day, age_groups, relative, filename, + file_format, files_input, output_path) + image = imageio.v2.imread(os.path.join(output_path, filename + ".png")) frames.append(image) imageio.mimsave(os.path.join(output_path, 'sim.gif'), # output gif frames, # array of input frames - duration=10) # optional: frames per second + duration=10, + loop=0) # optional: frames per second diff --git a/pycode/memilio-plot/memilio/plot/plotMap.py b/pycode/memilio-plot/memilio/plot/plotMap.py index 0fe2f9db92..d95f4131b9 100755 --- a/pycode/memilio-plot/memilio/plot/plotMap.py +++ b/pycode/memilio-plot/memilio/plot/plotMap.py @@ -196,8 +196,8 @@ def extract_data( k += 1 else: - raise gd.ValueError( - "Time point not found for region " + str(regions[i]) + ".") + raise gd.DataError( + "Time point " + str(date) + " not found for region " + str(regions[i]) + ".") # Aggregated or matrix output. if output == 'sum': @@ -211,6 +211,29 @@ def extract_data( raise gd.DataError("Data could not be read in.") +def extract_time_steps(file, file_format='json'): + """ Reads data from a general json or specific hdf5 file as output by the + MEmilio simulation framework and extracts the number of days used. + + @param[in] file Path and filename of file to be read in, relative from current + directory. + @param[in] file_format File format; either json or h5. + @return Number of time steps. + """ + input_file = os.path.join(os.getcwd(), str(file)) + if file_format == 'json': + df = pd.read_json(input_file + '.' + file_format) + if 'Date' in df.columns: + time_steps = df['Date'].nunique() + else: + time_steps = 1 + elif file_format == 'h5': + h5file = h5py.File(input_file + '.' + file_format, 'r') + regions = list(h5file.keys()) + time_steps = len(h5file[regions[0]]['Time']) + return time_steps + + def scale_dataframe_relative(df, age_groups, df_population): """! Scales a population-related data frame relative to the size of the local populations or subpopulations (e.g., if not all age groups are From 9525d17fef1b1c7f82c00d9c534dc169a712cfbe Mon Sep 17 00:00:00 2001 From: MaxBetzDLR Date: Fri, 8 Sep 2023 14:27:18 +0200 Subject: [PATCH 04/13] Use temporary files --- pycode/examples/plot/plotGifResultsGermany.py | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/pycode/examples/plot/plotGifResultsGermany.py b/pycode/examples/plot/plotGifResultsGermany.py index 937e5e0cdc..d1cc8373d0 100644 --- a/pycode/examples/plot/plotGifResultsGermany.py +++ b/pycode/examples/plot/plotGifResultsGermany.py @@ -1,6 +1,7 @@ import datetime as dt import os.path import imageio +import tempfile import numpy as np import pandas as pd @@ -8,6 +9,9 @@ import memilio.epidata.getPopulationData as gpd import memilio.plot.plotMap as pm from memilio.epidata import geoModificationGermany as geoger +import memilio.epidata.progress_indicator as progind +import warnings +warnings.simplefilter(action='ignore', category=FutureWarning) def create_plot_map(day, age_groups, relative, filename, file_format, files_input, output_path): @@ -64,7 +68,7 @@ def create_plot_map(day, age_groups, relative, filename, file_format, files_inpu pm.plot_map( dfs_all_sorted, scale_colors=np.array([min_val, max_val]), legend=['', ''], - title='Synthetic data (relative) day ' + str(day), plot_colorbar=True, + title='Synthetic data (relative) day ' + f'{day:2d}', plot_colorbar=True, output_path=output_path, fig_name=filename, dpi=300, outercolor=[205 / 255, 238 / 255, 251 / 255]) @@ -90,19 +94,23 @@ def create_plot_map(day, age_groups, relative, filename, file_format, files_inpu filter_age = ['Group' + str(key) for key in age_groups.keys()] relative = True - time_steps = pm.extract_time_steps( + num_days = pm.extract_time_steps( files_input[list(files_input.keys())[0]], file_format=file_format) - filename = "file" # create gif frames = [] + filename = 'tempplot' output_path = "plot_gif" - for day in range(0, time_steps): - create_plot_map(day, age_groups, relative, filename, - file_format, files_input, output_path) - image = imageio.v2.imread(os.path.join(output_path, filename + ".png")) - frames.append(image) + with progind.Percentage() as indicator: + with tempfile.TemporaryDirectory() as tmpdirname: + for day in range(0, num_days): + create_plot_map(day, age_groups, relative, filename, + file_format, files_input, tmpdirname) + image = imageio.v2.imread( + os.path.join(tmpdirname, filename + ".png")) + frames.append(image) + indicator.set_progress((day+1)/num_days) imageio.mimsave(os.path.join(output_path, 'sim.gif'), # output gif frames, # array of input frames From 9b41ab43c55f2bb31b5443a5eae75a396851d6fb Mon Sep 17 00:00:00 2001 From: MaxBetzDLR Date: Tue, 12 Sep 2023 13:18:34 +0200 Subject: [PATCH 05/13] Test --- pycode/memilio-plot/memilio/plot/plotMap.py | 2 +- pycode/memilio-plot/memilio/plot_test/test_plot_plotMap.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/pycode/memilio-plot/memilio/plot/plotMap.py b/pycode/memilio-plot/memilio/plot/plotMap.py index d95f4131b9..abfc8af0e7 100755 --- a/pycode/memilio-plot/memilio/plot/plotMap.py +++ b/pycode/memilio-plot/memilio/plot/plotMap.py @@ -370,7 +370,7 @@ def plot_map(data: pd.DataFrame, # Use top row for title. tax = fig.add_subplot(gs[0, :]) tax.set_axis_off() - tax.set_title(title, fontsize=18) + tax.set_title(title, fontsize=16) if plot_colorbar: # Prepare colorbar. cax = fig.add_subplot(gs[1, 0]) diff --git a/pycode/memilio-plot/memilio/plot_test/test_plot_plotMap.py b/pycode/memilio-plot/memilio/plot_test/test_plot_plotMap.py index 2de1fd29ee..aee4587546 100644 --- a/pycode/memilio-plot/memilio/plot_test/test_plot_plotMap.py +++ b/pycode/memilio-plot/memilio/plot_test/test_plot_plotMap.py @@ -59,6 +59,12 @@ class TestPlotMap(fake_filesystem_unittest.TestCase): age_groups = {0: '0-4', 1: '5-14', 2: '15-34', 3: '35-59', 4: '60-79', 5: '80+'} + def test_extract_time_steps(self): + for file in self.files_input.values(): + num_days = pm.extract_time_steps( + file, file_format=self.file_format) + assert num_days == 1 + def test_extract_data(self): filter_age = None i = 0 From eecc68ae571ebd49b1dbbb7e1dfcb00c3b99dae2 Mon Sep 17 00:00:00 2001 From: HenrZu Date: Wed, 20 Sep 2023 11:49:51 +0200 Subject: [PATCH 06/13] fix naming. TODO: add dataset --- pycode/examples/plot/plotGifResultsGermany.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pycode/examples/plot/plotGifResultsGermany.py b/pycode/examples/plot/plotGifResultsGermany.py index d1cc8373d0..d6db6aa1ab 100644 --- a/pycode/examples/plot/plotGifResultsGermany.py +++ b/pycode/examples/plot/plotGifResultsGermany.py @@ -76,7 +76,7 @@ def create_plot_map(day, age_groups, relative, filename, file_format, files_inpu if __name__ == '__main__': - files_input = {'Data set 1': 'p75/Results'} + files_input = {'Data set': 'p75/Results'} file_format = 'h5' # Define age groups which will be considered through filtering # Keep keys and values as well as its assignment constant, remove entries From c03bd928fbdb73c3058d29210bfa070b00998a71 Mon Sep 17 00:00:00 2001 From: HenrZu Date: Wed, 6 Dec 2023 14:31:22 +0100 Subject: [PATCH 07/13] adjust init, function to own file --- pycode/memilio-plot/README.md | 3 + .../memilio/plot/createGIF.py} | 80 +++++++++++++++++-- pycode/memilio-plot/setup.py | 5 +- 3 files changed, 79 insertions(+), 9 deletions(-) rename pycode/{examples/plot/plotGifResultsGermany.py => memilio-plot/memilio/plot/createGIF.py} (54%) diff --git a/pycode/memilio-plot/README.md b/pycode/memilio-plot/README.md index 495714dae4..ba512d8d20 100644 --- a/pycode/memilio-plot/README.md +++ b/pycode/memilio-plot/README.md @@ -59,6 +59,9 @@ Required python packages: - mapclassify - geopandas - h5py +- imageio +- tempfile +- datetime Testing and Coverage -------------------- diff --git a/pycode/examples/plot/plotGifResultsGermany.py b/pycode/memilio-plot/memilio/plot/createGIF.py similarity index 54% rename from pycode/examples/plot/plotGifResultsGermany.py rename to pycode/memilio-plot/memilio/plot/createGIF.py index d6db6aa1ab..b081d8b04d 100644 --- a/pycode/examples/plot/plotGifResultsGermany.py +++ b/pycode/memilio-plot/memilio/plot/createGIF.py @@ -1,3 +1,23 @@ +############################################################################# +# Copyright (C) 2020-2024 MEmilio +# +# Authors: Henrik Zunker, Maximilian Betz +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + import datetime as dt import os.path import imageio @@ -14,20 +34,24 @@ warnings.simplefilter(action='ignore', category=FutureWarning) -def create_plot_map(day, age_groups, relative, filename, file_format, files_input, output_path): +def create_plot_map(day, filename, files_input, output_path, compartments, file_format='h5', relative=False, age_groups={0: '0-4', 1: '5-14', 2: '15-34', 3: '35-59', 4: '60-79', 5: '80+'}): + """! Plots region-specific information for a single day of the simulation. + @param[in] day Day of the simulation. + @param[in] filename Name of the file to be created. + @param[in] files_input Dictionary of input files. + @param[in] output_path Output path for the figure. + @param[in] compartments List of compartments to be plotted. + @param[in] file_format Format of the file to be created. Either 'h5' or 'json'. + @param[in] relative Defines if data should be scaled relative to population. + @param[in] age_groups Dictionary of age groups to be considered. + """ i = 0 for file in files_input.values(): - # MEmilio backend hdf5 example - - deads = [24, 25, 26] - susceptible = [0, 1, 23] - infected = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22] df = pm.extract_data( file, region_spec=None, column=None, date=day, - filters={'Group': filter_age, 'InfectionState': infected}, + filters={'Group': filter_age, 'InfectionState': compartments}, file_format=file_format) if relative: @@ -74,6 +98,46 @@ def create_plot_map(day, age_groups, relative, filename, file_format, files_inpu outercolor=[205 / 255, 238 / 255, 251 / 255]) +def create_gif_map_plot(output_dir, filename="simulation", relative=False): + """! Creates a gif of the simulation results by calling create_plot_map for each day of the simulation and then storing the single plots in a temporary directory. + """ + + files_input = {'Data set': 'p75/Results'} + file_format = 'h5' + + # Define age groups which will be considered through filtering + age_groups = {0: '0-4', 1: '5-14', 2: '15-34', + 3: '35-59', 4: '60-79', 5: '80+'} + if len(age_groups) == 6: + filter_age = None + else: + if file_format == 'json': + filter_age = [val for val in age_groups.values()] + else: + filter_age = ['Group' + str(key) for key in age_groups.keys()] + + # TODO: not working for json + num_days = pm.extract_time_steps( + files_input[list(files_input.keys())[0]], file_format=file_format) + + # create gif + frames = [] + with progind.Percentage() as indicator: + with tempfile.TemporaryDirectory() as tmpdirname: + for day in range(0, num_days): + create_plot_map(day, age_groups, relative, filename, + file_format, files_input, tmpdirname) + image = imageio.v2.imread( + os.path.join(tmpdirname, "tempplot.png")) + frames.append(image) + indicator.set_progress((day+1)/num_days) + + imageio.mimsave(os.path.join(output_dir, filename + '.gif'), + frames, # array of input frames + duration=10, # duration of each frame in milliseconds + loop=0) # optional: frames per second + + if __name__ == '__main__': files_input = {'Data set': 'p75/Results'} diff --git a/pycode/memilio-plot/setup.py b/pycode/memilio-plot/setup.py index 98e302b1c2..6e55837eb8 100644 --- a/pycode/memilio-plot/setup.py +++ b/pycode/memilio-plot/setup.py @@ -76,7 +76,10 @@ def run(self): 'matplotlib', 'mapclassify', 'geopandas', - 'h5py' + 'h5py', + 'imageio', + 'tempfile', + 'datetime' ], extras_require={ 'dev': [ From 3877acebc1ec32bc2cc3a180e50322101eec0837 Mon Sep 17 00:00:00 2001 From: HenrZu Date: Wed, 6 Dec 2023 15:56:40 +0100 Subject: [PATCH 08/13] some tests, use in 2020sim, still problems with tmp dir --- .../2020_npis_sarscov2_wildtype_germany.py | 8 +- pycode/memilio-plot/README.md | 1 - pycode/memilio-plot/memilio/plot/createGIF.py | 78 ++++++------------- .../memilio/plot_test/test_plot_createGIF.py | 58 ++++++++++++++ pycode/memilio-plot/setup.py | 1 - 5 files changed, 88 insertions(+), 58 deletions(-) create mode 100644 pycode/memilio-plot/memilio/plot_test/test_plot_createGIF.py diff --git a/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py b/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py index 1a9c44a3cb..eba505ec57 100644 --- a/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py +++ b/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py @@ -3,6 +3,7 @@ import os import memilio.simulation as mio import memilio.simulation.secir as secir +import memilio.plot.createGIF as mp from enum import Enum from memilio.simulation.secir import (Model, Simulation, @@ -426,7 +427,7 @@ def get_graph(self, end_date): return graph - def run(self, num_days_sim, num_runs=10, save_graph=True): + def run(self, num_days_sim, num_runs=10, save_graph=True, create_gif=True): mio.set_log_level(mio.LogLevel.Warning) end_date = self.start_date + datetime.timedelta(days=num_days_sim) @@ -459,6 +460,11 @@ def run(self, num_days_sim, num_runs=10, save_graph=True): secir.save_results( ensemble_results, ensemble_params, node_ids, self.results_dir, save_single_runs, save_percentiles) + if create_gif: + # any compartments in the model (see InfectionStates) + compartments = [1, 2, 3, 4] + mp.create_gif_map_plot( + self.results_dir + "/p75", self.results_dir, compartments) return 0 diff --git a/pycode/memilio-plot/README.md b/pycode/memilio-plot/README.md index ba512d8d20..014b8a8221 100644 --- a/pycode/memilio-plot/README.md +++ b/pycode/memilio-plot/README.md @@ -60,7 +60,6 @@ Required python packages: - geopandas - h5py - imageio -- tempfile - datetime Testing and Coverage diff --git a/pycode/memilio-plot/memilio/plot/createGIF.py b/pycode/memilio-plot/memilio/plot/createGIF.py index b081d8b04d..6855501441 100644 --- a/pycode/memilio-plot/memilio/plot/createGIF.py +++ b/pycode/memilio-plot/memilio/plot/createGIF.py @@ -47,6 +47,14 @@ def create_plot_map(day, filename, files_input, output_path, compartments, file """ i = 0 + if len(age_groups) == 6: + filter_age = None + else: + if file_format == 'json': + filter_age = [val for val in age_groups.values()] + else: + filter_age = ['Group' + str(key) for key in age_groups.keys()] + for file in files_input.values(): df = pm.extract_data( @@ -98,11 +106,18 @@ def create_plot_map(day, filename, files_input, output_path, compartments, file outercolor=[205 / 255, 238 / 255, 251 / 255]) -def create_gif_map_plot(output_dir, filename="simulation", relative=False): - """! Creates a gif of the simulation results by calling create_plot_map for each day of the simulation and then storing the single plots in a temporary directory. +def create_gif_map_plot(input_data, output_dir, compartments, filename="simulation", relative=True): + """! Creates a gif of the simulation results by calling create_plot_map for each day of the simulation and then + storing the single plots in a temporary directory. Currently only works for the results created by the parameter study. + + @param[in] input_data Path to the input data. The Path should contain a file called 'Results' which contains + the simulation results. This is the default output folder of the parameter study. + @param[in] output_dir Path where the gif should be stored. + @param[in] filename Name of the temporary file. + @param[in] relative Defines if data should be scaled relative to population. """ - files_input = {'Data set': 'p75/Results'} + files_input = {'Data set': input_data + '/Results'} file_format = 'h5' # Define age groups which will be considered through filtering @@ -111,12 +126,8 @@ def create_gif_map_plot(output_dir, filename="simulation", relative=False): if len(age_groups) == 6: filter_age = None else: - if file_format == 'json': - filter_age = [val for val in age_groups.values()] - else: - filter_age = ['Group' + str(key) for key in age_groups.keys()] + filter_age = ['Group' + str(key) for key in age_groups.keys()] - # TODO: not working for json num_days = pm.extract_time_steps( files_input[list(files_input.keys())[0]], file_format=file_format) @@ -125,10 +136,11 @@ def create_gif_map_plot(output_dir, filename="simulation", relative=False): with progind.Percentage() as indicator: with tempfile.TemporaryDirectory() as tmpdirname: for day in range(0, num_days): - create_plot_map(day, age_groups, relative, filename, - file_format, files_input, tmpdirname) + create_plot_map(day, filename, files_input, tmpdirname, + compartments, file_format, relative, age_groups) + image = imageio.v2.imread( - os.path.join(tmpdirname, "tempplot.png")) + os.path.join(tmpdirname, "filename.png")) frames.append(image) indicator.set_progress((day+1)/num_days) @@ -136,47 +148,3 @@ def create_gif_map_plot(output_dir, filename="simulation", relative=False): frames, # array of input frames duration=10, # duration of each frame in milliseconds loop=0) # optional: frames per second - - -if __name__ == '__main__': - - files_input = {'Data set': 'p75/Results'} - file_format = 'h5' - # Define age groups which will be considered through filtering - # Keep keys and values as well as its assignment constant, remove entries - # if only part of the population should be plotted or considered, e.g., by - # setting: - # age_groups = {1: '5-14', 2: '15-34'} - age_groups = {0: '0-4', 1: '5-14', 2: '15-34', - 3: '35-59', 4: '60-79', 5: '80+'} - if len(age_groups) == 6: - filter_age = None - else: - if file_format == 'json': - filter_age = [val for val in age_groups.values()] - else: - filter_age = ['Group' + str(key) for key in age_groups.keys()] - - relative = True - num_days = pm.extract_time_steps( - files_input[list(files_input.keys())[0]], file_format=file_format) - - # create gif - frames = [] - filename = 'tempplot' - output_path = "plot_gif" - - with progind.Percentage() as indicator: - with tempfile.TemporaryDirectory() as tmpdirname: - for day in range(0, num_days): - create_plot_map(day, age_groups, relative, filename, - file_format, files_input, tmpdirname) - image = imageio.v2.imread( - os.path.join(tmpdirname, filename + ".png")) - frames.append(image) - indicator.set_progress((day+1)/num_days) - - imageio.mimsave(os.path.join(output_path, 'sim.gif'), # output gif - frames, # array of input frames - duration=10, - loop=0) # optional: frames per second diff --git a/pycode/memilio-plot/memilio/plot_test/test_plot_createGIF.py b/pycode/memilio-plot/memilio/plot_test/test_plot_createGIF.py new file mode 100644 index 0000000000..b822f79973 --- /dev/null +++ b/pycode/memilio-plot/memilio/plot_test/test_plot_createGIF.py @@ -0,0 +1,58 @@ +import unittest +from unittest.mock import patch, MagicMock +import pandas as pd +import numpy as np +from memilio.plot import createGIF + + +class TestCreateGif(unittest.TestCase): + + @patch('memilio.plot.createGIF.pm.extract_data') + @patch('memilio.plot.createGIF.gpd.get_population_data') + @patch('memilio.plot.createGIF.pm.scale_dataframe_relative') + @patch('memilio.plot.createGIF.pm.plot_map') + def test_create_plot_map(self, mock_plot_map, mock_scale_dataframe_relative, mock_get_population_data, mock_extract_data): + # Mock the return values of the functions + mock_extract_data.return_value = pd.DataFrame( + {'Region': ['A', 'B'], 'Value': [1, 2]}) + mock_get_population_data.return_value = pd.DataFrame( + {'Region': ['A', 'B'], 'Population': [100, 200]}) + mock_scale_dataframe_relative.return_value = pd.DataFrame( + {'Region': ['A', 'B'], 'Value': [0.01, 0.01]}) + + # Call the function with test parameters + createGIF.create_plot_map(day=1, filename='test', files_input={'file1': 'path1'}, output_path='output', compartments=[ + 'compartment1'], file_format='h5', relative=True, age_groups={0: '0-4', 1: '5-14', 2: '15-34', 3: '35-59', 4: '60-79', 5: '80+'}) + + # Assert that the mocked functions were called with the correct parameters + mock_extract_data.assert_called() + mock_get_population_data.assert_called() + mock_scale_dataframe_relative.assert_called() + mock_plot_map.assert_called_with(pd.DataFrame({'Region': ['A', 'B'], 'Value': [0.01, 0.01]}), scale_colors=np.array([0.01, 0.01]), legend=[ + '', ''], title='Synthetic data (relative) day 1', plot_colorbar=True, output_path='output', fig_name='test', dpi=300, outercolor=[205 / 255, 238 / 255, 251 / 255]) + + @patch('memilio.plot.createGIF.pm.extract_time_steps') + @patch('memilio.plot.createGIF.create_plot_map') + @patch('memilio.plot.createGIF.imageio.v2.imread') + @patch('memilio.plot.createGIF.imageio.mimsave') + @patch('memilio.plot.createGIF.tempfile.TemporaryDirectory') + def test_createGIF(self, mock_temp_dir, mock_mimsave, mock_imread, mock_create_plot_map, mock_extract_time_steps): + # Mock the return values of the functions + mock_extract_time_steps.return_value = 5 + mock_temp_dir.return_value.__enter__.return_value = 'temp_dir' + mock_imread.return_value = 'image' + + # Call the function with test parameters + createGIF.createGIF(files_input={'file1': 'path1'}, output_dir='output', filename='test', compartments=[ + 'compartment1'], file_format='h5', relative=True, age_groups={0: '0-4', 1: '5-14', 2: '15-34', 3: '35-59', 4: '60-79', 5: '80+'}) + + # Assert that the mocked functions were called with the correct parameters + mock_extract_time_steps.assert_called_with('path1', file_format='h5') + mock_create_plot_map.assert_called() + mock_imread.assert_called_with('temp_dir/filename.png') + mock_mimsave.assert_called_with( + 'output/test.gif', ['image', 'image', 'image', 'image', 'image'], duration=10, loop=0) + + +if __name__ == '__main__': + unittest.main() diff --git a/pycode/memilio-plot/setup.py b/pycode/memilio-plot/setup.py index 6e55837eb8..30efd2c76f 100644 --- a/pycode/memilio-plot/setup.py +++ b/pycode/memilio-plot/setup.py @@ -78,7 +78,6 @@ def run(self): 'geopandas', 'h5py', 'imageio', - 'tempfile', 'datetime' ], extras_require={ From 07c431e1f2953e83ee00d98afaa0ee09ac498ffe Mon Sep 17 00:00:00 2001 From: HenrZu Date: Wed, 6 Dec 2023 16:19:10 +0100 Subject: [PATCH 09/13] some final fixes --- .../2020_npis_sarscov2_wildtype_germany.py | 2 +- pycode/memilio-plot/memilio/plot/createGIF.py | 4 ++-- pycode/memilio-plot/memilio/plot/plotMap.py | 19 +++++++++---------- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py b/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py index eba505ec57..4d953cba0b 100644 --- a/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py +++ b/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py @@ -474,5 +474,5 @@ def run(self, num_days_sim, num_runs=10, save_graph=True, create_gif=True): data_dir=os.path.join(file_path, "../../../data"), start_date=datetime.date(year=2020, month=12, day=12), results_dir=os.path.join(file_path, "../../../results_secir")) - num_days_sim = 30 + num_days_sim = 50 sim.run(num_days_sim, num_runs=2) diff --git a/pycode/memilio-plot/memilio/plot/createGIF.py b/pycode/memilio-plot/memilio/plot/createGIF.py index 6855501441..d778bc4f6e 100644 --- a/pycode/memilio-plot/memilio/plot/createGIF.py +++ b/pycode/memilio-plot/memilio/plot/createGIF.py @@ -140,11 +140,11 @@ def create_gif_map_plot(input_data, output_dir, compartments, filename="simulati compartments, file_format, relative, age_groups) image = imageio.v2.imread( - os.path.join(tmpdirname, "filename.png")) + os.path.join(tmpdirname, filename + ".png")) frames.append(image) indicator.set_progress((day+1)/num_days) imageio.mimsave(os.path.join(output_dir, filename + '.gif'), frames, # array of input frames - duration=10, # duration of each frame in milliseconds + duration=0.2, # duration of each frame in seconds loop=0) # optional: frames per second diff --git a/pycode/memilio-plot/memilio/plot/plotMap.py b/pycode/memilio-plot/memilio/plot/plotMap.py index 8ff8f66d28..5580acad49 100755 --- a/pycode/memilio-plot/memilio/plot/plotMap.py +++ b/pycode/memilio-plot/memilio/plot/plotMap.py @@ -299,7 +299,8 @@ def plot_map(data: pd.DataFrame, output_path: str = '', fig_name: str = 'customPlot', dpi: int = 300, - outercolor='white'): + outercolor='white', + log_scale=False): """! Plots region-specific information onto a interactive html map and returning svg and png image. Allows the comparisons of a variable list of data sets. @@ -315,6 +316,7 @@ def plot_map(data: pd.DataFrame, @param[in] fig_name Name of the figure created. @param[in] dpi Dots-per-inch value for the exported figure. @param[in] outercolor Background color of the plot image. + @param[in] log_scale Defines if the colorbar is plotted in log scale. """ region_classifier = data.columns[0] region_data = data[region_classifier].to_numpy().astype(int) @@ -377,28 +379,25 @@ def plot_map(data: pd.DataFrame, else: cax = None - # log scale of colorbar. - # norm = mcolors.LogNorm(vmin=scale_colors[0], vmax=scale_colors[1]) + if log_scale: + norm = mcolors.LogNorm(vmin=scale_colors[0], vmax=scale_colors[1]) for i in range(len(data_columns)): ax = fig.add_subplot(gs[:, i+2]) - if cax is not None: + if log_scale: + map_data.plot(data_columns[i], ax=ax, cax=cax, legend=True, + norm=norm) + elif cax is not None: map_data.plot(data_columns[i], ax=ax, cax=cax, legend=True, vmin=scale_colors[0], vmax=scale_colors[1]) - # map_data.plot(data_columns[i], ax=ax, cax=cax, legend=True, norm=norm) else: # Do not plot colorbar. map_data.plot(data_columns[i], ax=ax, legend=False, vmin=scale_colors[0], vmax=scale_colors[1]) - # map_data.plot(data_columns[i], ax=ax, cax=cax, legend=True, norm=norm) ax.set_title(legend[i], fontsize=12) ax.set_axis_off() plt.subplots_adjust(bottom=0.1) - plt.savefig(os.path.join(output_path, fig_name + '.png'), dpi=dpi) - plt.savefig(os.path.join(output_path, fig_name + '.svg'), dpi=dpi) - - plt.show() From b38be5136e1268f859dab509916df79acba40765 Mon Sep 17 00:00:00 2001 From: HenrZu Date: Wed, 6 Dec 2023 16:27:54 +0100 Subject: [PATCH 10/13] save memory --- pycode/memilio-plot/memilio/plot/createGIF.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pycode/memilio-plot/memilio/plot/createGIF.py b/pycode/memilio-plot/memilio/plot/createGIF.py index d778bc4f6e..5578dac17d 100644 --- a/pycode/memilio-plot/memilio/plot/createGIF.py +++ b/pycode/memilio-plot/memilio/plot/createGIF.py @@ -25,6 +25,7 @@ import numpy as np import pandas as pd +import matplotlib.pyplot as plt import memilio.epidata.getPopulationData as gpd import memilio.plot.plotMap as pm @@ -142,6 +143,9 @@ def create_gif_map_plot(input_data, output_dir, compartments, filename="simulati image = imageio.v2.imread( os.path.join(tmpdirname, filename + ".png")) frames.append(image) + + # Close the current figure to free up memory + plt.close('all') indicator.set_progress((day+1)/num_days) imageio.mimsave(os.path.join(output_dir, filename + '.gif'), From a9f1503d3f7a678961c5e3e4978e1b902ab14836 Mon Sep 17 00:00:00 2001 From: HenrZu Date: Thu, 7 Dec 2023 08:51:44 +0100 Subject: [PATCH 11/13] tests now working --- .../memilio/plot_test/test_plot_createGIF.py | 70 +++++++++++++------ 1 file changed, 50 insertions(+), 20 deletions(-) diff --git a/pycode/memilio-plot/memilio/plot_test/test_plot_createGIF.py b/pycode/memilio-plot/memilio/plot_test/test_plot_createGIF.py index b822f79973..5ae9bab443 100644 --- a/pycode/memilio-plot/memilio/plot_test/test_plot_createGIF.py +++ b/pycode/memilio-plot/memilio/plot_test/test_plot_createGIF.py @@ -8,17 +8,17 @@ class TestCreateGif(unittest.TestCase): @patch('memilio.plot.createGIF.pm.extract_data') - @patch('memilio.plot.createGIF.gpd.get_population_data') + @patch('pandas.read_json') @patch('memilio.plot.createGIF.pm.scale_dataframe_relative') @patch('memilio.plot.createGIF.pm.plot_map') - def test_create_plot_map(self, mock_plot_map, mock_scale_dataframe_relative, mock_get_population_data, mock_extract_data): + def test_create_plot_map(self, mock_plot_map, mock_scale_dataframe_relative, mock_read_json, mock_extract_data): # Mock the return values of the functions mock_extract_data.return_value = pd.DataFrame( - {'Region': ['A', 'B'], 'Value': [1, 2]}) - mock_get_population_data.return_value = pd.DataFrame( - {'Region': ['A', 'B'], 'Population': [100, 200]}) + {'Region': [0, 1], 'Value': [1, 2]}) + mock_read_json.return_value = pd.DataFrame( + {'Region': [0, 1], 'Population': [100, 200]}) mock_scale_dataframe_relative.return_value = pd.DataFrame( - {'Region': ['A', 'B'], 'Value': [0.01, 0.01]}) + {'Region': [0, 1], 'Value': [0.01, 0.01]}) # Call the function with test parameters createGIF.create_plot_map(day=1, filename='test', files_input={'file1': 'path1'}, output_path='output', compartments=[ @@ -26,32 +26,62 @@ def test_create_plot_map(self, mock_plot_map, mock_scale_dataframe_relative, moc # Assert that the mocked functions were called with the correct parameters mock_extract_data.assert_called() - mock_get_population_data.assert_called() + mock_read_json.assert_called() mock_scale_dataframe_relative.assert_called() - mock_plot_map.assert_called_with(pd.DataFrame({'Region': ['A', 'B'], 'Value': [0.01, 0.01]}), scale_colors=np.array([0.01, 0.01]), legend=[ - '', ''], title='Synthetic data (relative) day 1', plot_colorbar=True, output_path='output', fig_name='test', dpi=300, outercolor=[205 / 255, 238 / 255, 251 / 255]) + + assert mock_plot_map.called + + # Get the arguments passed to the mock and assert that they are correct + args, kwargs = mock_plot_map.call_args + + pd.testing.assert_frame_equal(args[0], pd.DataFrame( + {'Region': [0, 1], 'Value 0': [0.01, 0.01]})) + np.testing.assert_array_equal( + kwargs['scale_colors'], np.array([0.01, 0.01])) + self.assertEqual(kwargs['legend'], ['', '']) + self.assertEqual(kwargs['title'], 'Synthetic data (relative) day 1') + self.assertEqual(kwargs['plot_colorbar'], True) + self.assertEqual(kwargs['output_path'], 'output') + self.assertEqual(kwargs['fig_name'], 'test') + self.assertEqual(kwargs['dpi'], 300) + self.assertEqual(kwargs['outercolor'], [ + 205 / 255, 238 / 255, 251 / 255]) @patch('memilio.plot.createGIF.pm.extract_time_steps') @patch('memilio.plot.createGIF.create_plot_map') - @patch('memilio.plot.createGIF.imageio.v2.imread') - @patch('memilio.plot.createGIF.imageio.mimsave') - @patch('memilio.plot.createGIF.tempfile.TemporaryDirectory') - def test_createGIF(self, mock_temp_dir, mock_mimsave, mock_imread, mock_create_plot_map, mock_extract_time_steps): + @patch('tempfile.TemporaryDirectory') + @patch('imageio.v2.imread') + @patch('imageio.mimsave') + def test_create_gif_map_plot(self, mock_mimsave, mock_imread, mock_tempdir, mock_create_plot_map, mock_extract_time_steps): # Mock the return values of the functions - mock_extract_time_steps.return_value = 5 - mock_temp_dir.return_value.__enter__.return_value = 'temp_dir' + mock_extract_time_steps.return_value = 10 + mock_tempdir.return_value.__enter__.return_value = 'tempdir' mock_imread.return_value = 'image' # Call the function with test parameters - createGIF.createGIF(files_input={'file1': 'path1'}, output_dir='output', filename='test', compartments=[ - 'compartment1'], file_format='h5', relative=True, age_groups={0: '0-4', 1: '5-14', 2: '15-34', 3: '35-59', 4: '60-79', 5: '80+'}) + createGIF.create_gif_map_plot(input_data='input', output_dir='output', compartments=[ + 'compartment1'], filename='test') # Assert that the mocked functions were called with the correct parameters - mock_extract_time_steps.assert_called_with('path1', file_format='h5') + mock_extract_time_steps.assert_called_with( + 'input/Results', file_format='h5') mock_create_plot_map.assert_called() - mock_imread.assert_called_with('temp_dir/filename.png') + mock_tempdir.assert_called() + mock_imread.assert_called_with('tempdir/test.png') mock_mimsave.assert_called_with( - 'output/test.gif', ['image', 'image', 'image', 'image', 'image'], duration=10, loop=0) + 'output/test.gif', ['image']*10, duration=0.2, loop=0) + + # Get the arguments passed to the mock and assert that they are correct + args, kwargs = mock_create_plot_map.call_args + self.assertEqual(args[0], 9) + self.assertEqual(args[1], 'test') + self.assertEqual(args[2], {'Data set': 'input/Results'}) + self.assertEqual(args[3], 'tempdir') + self.assertEqual(args[4], ['compartment1']) + self.assertEqual(args[5], 'h5') + self.assertEqual(args[6], True) + self.assertEqual(args[7], { + 0: '0-4', 1: '5-14', 2: '15-34', 3: '35-59', 4: '60-79', 5: '80+'}) if __name__ == '__main__': From 7b72dbf249c88bd9d487a9f5c7636f5962d7be32 Mon Sep 17 00:00:00 2001 From: Henrik Zunker <69154294+HenrZu@users.noreply.github.com> Date: Thu, 7 Dec 2023 16:11:33 +0100 Subject: [PATCH 12/13] Apply suggestions from code review Co-authored-by: jubicker <113909589+jubicker@users.noreply.github.com> --- .../examples/simulation/2020_npis_sarscov2_wildtype_germany.py | 2 +- pycode/memilio-plot/memilio/plot/createGIF.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py b/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py index 4d953cba0b..322ecf8411 100644 --- a/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py +++ b/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py @@ -462,7 +462,7 @@ def run(self, num_days_sim, num_runs=10, save_graph=True, create_gif=True): save_single_runs, save_percentiles) if create_gif: # any compartments in the model (see InfectionStates) - compartments = [1, 2, 3, 4] + compartments = [c for c in range(1,8)] mp.create_gif_map_plot( self.results_dir + "/p75", self.results_dir, compartments) return 0 diff --git a/pycode/memilio-plot/memilio/plot/createGIF.py b/pycode/memilio-plot/memilio/plot/createGIF.py index 5578dac17d..9c4dd7deba 100644 --- a/pycode/memilio-plot/memilio/plot/createGIF.py +++ b/pycode/memilio-plot/memilio/plot/createGIF.py @@ -71,7 +71,7 @@ def create_plot_map(day, filename, files_input, output_path, compartments, file # pandas>1.5 raise FileNotFoundError instead of ValueError except (ValueError, FileNotFoundError): print( - "Population data was not found. Download it from the internet.") + "Population data was not found. Downloading it from the internet.") population = gpd.get_population_data( read_data=False, file_format=file_format, out_folder='data/pydata/Germany/', no_raw=True, merge_eisenach=True) From 9177377894a1629fd42d25c46414f3864ffd67ed Mon Sep 17 00:00:00 2001 From: HenrZu Date: Thu, 7 Dec 2023 16:29:06 +0100 Subject: [PATCH 13/13] review suggestions --- .../2020_npis_sarscov2_wildtype_germany.py | 2 +- pycode/memilio-plot/memilio/plot/createGIF.py | 17 +++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py b/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py index 322ecf8411..e8ecc3f1fa 100644 --- a/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py +++ b/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py @@ -462,7 +462,7 @@ def run(self, num_days_sim, num_runs=10, save_graph=True, create_gif=True): save_single_runs, save_percentiles) if create_gif: # any compartments in the model (see InfectionStates) - compartments = [c for c in range(1,8)] + compartments = [c for c in range(1, 8)] mp.create_gif_map_plot( self.results_dir + "/p75", self.results_dir, compartments) return 0 diff --git a/pycode/memilio-plot/memilio/plot/createGIF.py b/pycode/memilio-plot/memilio/plot/createGIF.py index 9c4dd7deba..340eed4224 100644 --- a/pycode/memilio-plot/memilio/plot/createGIF.py +++ b/pycode/memilio-plot/memilio/plot/createGIF.py @@ -47,7 +47,6 @@ def create_plot_map(day, filename, files_input, output_path, compartments, file @param[in] age_groups Dictionary of age groups to be considered. """ - i = 0 if len(age_groups) == 6: filter_age = None else: @@ -56,6 +55,9 @@ def create_plot_map(day, filename, files_input, output_path, compartments, file else: filter_age = ['Group' + str(key) for key in age_groups.keys()] + # In file_input there can be two different files. When we enter two files, + # both files are plotted side by side in the same figure. + file_index = 0 for file in files_input.values(): df = pm.extract_data( @@ -84,11 +86,11 @@ def create_plot_map(day, filename, files_input, output_path, compartments, file df = pm.scale_dataframe_relative( df, age_group_values, population) - if i == 0: + if file_index == 0: dfs_all = pd.DataFrame(df.iloc[:, 0]) - dfs_all[df.columns[-1] + ' ' + str(i)] = df[df.columns[-1]] - i += 1 + dfs_all[df.columns[-1] + ' ' + str(file_index)] = df[df.columns[-1]] + file_index += 1 dfs_all = dfs_all.apply(pd.to_numeric, errors='coerce') @@ -107,7 +109,8 @@ def create_plot_map(day, filename, files_input, output_path, compartments, file outercolor=[205 / 255, 238 / 255, 251 / 255]) -def create_gif_map_plot(input_data, output_dir, compartments, filename="simulation", relative=True): +def create_gif_map_plot(input_data, output_dir, compartments, filename="simulation", relative=True, age_groups={0: '0-4', 1: '5-14', 2: '15-34', + 3: '35-59', 4: '60-79', 5: '80+'}): """! Creates a gif of the simulation results by calling create_plot_map for each day of the simulation and then storing the single plots in a temporary directory. Currently only works for the results created by the parameter study. @@ -116,14 +119,12 @@ def create_gif_map_plot(input_data, output_dir, compartments, filename="simulati @param[in] output_dir Path where the gif should be stored. @param[in] filename Name of the temporary file. @param[in] relative Defines if data should be scaled relative to population. + @param[in] age_groups Dictionary of age groups to be considered. """ files_input = {'Data set': input_data + '/Results'} file_format = 'h5' - # Define age groups which will be considered through filtering - age_groups = {0: '0-4', 1: '5-14', 2: '15-34', - 3: '35-59', 4: '60-79', 5: '80+'} if len(age_groups) == 6: filter_age = None else: