diff --git a/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py b/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py index 1a9c44a3cb..e8ecc3f1fa 100644 --- a/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py +++ b/pycode/examples/simulation/2020_npis_sarscov2_wildtype_germany.py @@ -3,6 +3,7 @@ import os import memilio.simulation as mio import memilio.simulation.secir as secir +import memilio.plot.createGIF as mp from enum import Enum from memilio.simulation.secir import (Model, Simulation, @@ -426,7 +427,7 @@ def get_graph(self, end_date): return graph - def run(self, num_days_sim, num_runs=10, save_graph=True): + def run(self, num_days_sim, num_runs=10, save_graph=True, create_gif=True): mio.set_log_level(mio.LogLevel.Warning) end_date = self.start_date + datetime.timedelta(days=num_days_sim) @@ -459,6 +460,11 @@ def run(self, num_days_sim, num_runs=10, save_graph=True): secir.save_results( ensemble_results, ensemble_params, node_ids, self.results_dir, save_single_runs, save_percentiles) + if create_gif: + # any compartments in the model (see InfectionStates) + compartments = [c for c in range(1, 8)] + mp.create_gif_map_plot( + self.results_dir + "/p75", self.results_dir, compartments) return 0 @@ -468,5 +474,5 @@ def run(self, num_days_sim, num_runs=10, save_graph=True): data_dir=os.path.join(file_path, "../../../data"), start_date=datetime.date(year=2020, month=12, day=12), results_dir=os.path.join(file_path, "../../../results_secir")) - num_days_sim = 30 + num_days_sim = 50 sim.run(num_days_sim, num_runs=2) diff --git a/pycode/memilio-plot/README.md b/pycode/memilio-plot/README.md index 495714dae4..014b8a8221 100644 --- a/pycode/memilio-plot/README.md +++ b/pycode/memilio-plot/README.md @@ -59,6 +59,8 @@ Required python packages: - mapclassify - geopandas - h5py +- imageio +- datetime Testing and Coverage -------------------- diff --git a/pycode/memilio-plot/memilio/plot/createGIF.py b/pycode/memilio-plot/memilio/plot/createGIF.py new file mode 100644 index 0000000000..340eed4224 --- /dev/null +++ b/pycode/memilio-plot/memilio/plot/createGIF.py @@ -0,0 +1,155 @@ +############################################################################# +# Copyright (C) 2020-2024 MEmilio +# +# Authors: Henrik Zunker, Maximilian Betz +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + +import datetime as dt +import os.path +import imageio +import tempfile + +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + +import memilio.epidata.getPopulationData as gpd +import memilio.plot.plotMap as pm +from memilio.epidata import geoModificationGermany as geoger +import memilio.epidata.progress_indicator as progind +import warnings +warnings.simplefilter(action='ignore', category=FutureWarning) + + +def create_plot_map(day, filename, files_input, output_path, compartments, file_format='h5', relative=False, age_groups={0: '0-4', 1: '5-14', 2: '15-34', 3: '35-59', 4: '60-79', 5: '80+'}): + """! Plots region-specific information for a single day of the simulation. + @param[in] day Day of the simulation. + @param[in] filename Name of the file to be created. + @param[in] files_input Dictionary of input files. + @param[in] output_path Output path for the figure. + @param[in] compartments List of compartments to be plotted. + @param[in] file_format Format of the file to be created. Either 'h5' or 'json'. + @param[in] relative Defines if data should be scaled relative to population. + @param[in] age_groups Dictionary of age groups to be considered. + """ + + if len(age_groups) == 6: + filter_age = None + else: + if file_format == 'json': + filter_age = [val for val in age_groups.values()] + else: + filter_age = ['Group' + str(key) for key in age_groups.keys()] + + # In file_input there can be two different files. When we enter two files, + # both files are plotted side by side in the same figure. + file_index = 0 + for file in files_input.values(): + + df = pm.extract_data( + file, region_spec=None, column=None, date=day, + filters={'Group': filter_age, 'InfectionState': compartments}, + file_format=file_format) + + if relative: + + try: + population = pd.read_json( + 'data/pydata/Germany/county_current_population.json') + # pandas>1.5 raise FileNotFoundError instead of ValueError + except (ValueError, FileNotFoundError): + print( + "Population data was not found. Downloading it from the internet.") + population = gpd.get_population_data( + read_data=False, file_format=file_format, + out_folder='data/pydata/Germany/', no_raw=True, merge_eisenach=True) + + # For fitting of different age groups we need format ">X". + age_group_values = list(age_groups.values()) + age_group_values[-1] = age_group_values[-1].replace( + '80+', '>79') + # scale data + df = pm.scale_dataframe_relative( + df, age_group_values, population) + + if file_index == 0: + dfs_all = pd.DataFrame(df.iloc[:, 0]) + + dfs_all[df.columns[-1] + ' ' + str(file_index)] = df[df.columns[-1]] + file_index += 1 + + dfs_all = dfs_all.apply(pd.to_numeric, errors='coerce') + + dfs_all_sorted = dfs_all.sort_values(by='Region') + dfs_all_sorted = dfs_all_sorted.reset_index(drop=True) + + min_val = dfs_all_sorted[dfs_all_sorted.columns[1:]].min().min() + max_val = dfs_all_sorted[dfs_all_sorted.columns[1:]].max().max() + + pm.plot_map( + dfs_all_sorted, scale_colors=np.array([min_val, max_val]), + legend=['', ''], + title='Synthetic data (relative) day ' + f'{day:2d}', plot_colorbar=True, + output_path=output_path, + fig_name=filename, dpi=300, + outercolor=[205 / 255, 238 / 255, 251 / 255]) + + +def create_gif_map_plot(input_data, output_dir, compartments, filename="simulation", relative=True, age_groups={0: '0-4', 1: '5-14', 2: '15-34', + 3: '35-59', 4: '60-79', 5: '80+'}): + """! Creates a gif of the simulation results by calling create_plot_map for each day of the simulation and then + storing the single plots in a temporary directory. Currently only works for the results created by the parameter study. + + @param[in] input_data Path to the input data. The Path should contain a file called 'Results' which contains + the simulation results. This is the default output folder of the parameter study. + @param[in] output_dir Path where the gif should be stored. + @param[in] filename Name of the temporary file. + @param[in] relative Defines if data should be scaled relative to population. + @param[in] age_groups Dictionary of age groups to be considered. + """ + + files_input = {'Data set': input_data + '/Results'} + file_format = 'h5' + + if len(age_groups) == 6: + filter_age = None + else: + filter_age = ['Group' + str(key) for key in age_groups.keys()] + + num_days = pm.extract_time_steps( + files_input[list(files_input.keys())[0]], file_format=file_format) + + # create gif + frames = [] + with progind.Percentage() as indicator: + with tempfile.TemporaryDirectory() as tmpdirname: + for day in range(0, num_days): + create_plot_map(day, filename, files_input, tmpdirname, + compartments, file_format, relative, age_groups) + + image = imageio.v2.imread( + os.path.join(tmpdirname, filename + ".png")) + frames.append(image) + + # Close the current figure to free up memory + plt.close('all') + indicator.set_progress((day+1)/num_days) + + imageio.mimsave(os.path.join(output_dir, filename + '.gif'), + frames, # array of input frames + duration=0.2, # duration of each frame in seconds + loop=0) # optional: frames per second diff --git a/pycode/memilio-plot/memilio/plot/plotMap.py b/pycode/memilio-plot/memilio/plot/plotMap.py index a1b60ea315..5580acad49 100755 --- a/pycode/memilio-plot/memilio/plot/plotMap.py +++ b/pycode/memilio-plot/memilio/plot/plotMap.py @@ -27,6 +27,7 @@ import pandas as pd from matplotlib import pyplot as plt from matplotlib.gridspec import GridSpec +import matplotlib.colors as mcolors from memilio.epidata import geoModificationGermany as geoger from memilio.epidata import getDataIntoPandasDataFrame as gd @@ -143,11 +144,14 @@ def extract_data( # Set no filtering if filters were set to None. if filters == None: - filters['Group'] = list(h5file[regions[i]].keys())[ + filters['Group'] = list(h5file[regions[0]].keys())[ :-2] # Remove 'Time' and 'Total'. filters['InfectionState'] = list( range(h5file[regions[i]]['Group1'].shape[1])) + if filters['Group'] == None: + filters['Group'] = list(h5file[regions[0]].keys())[:-2] + InfectionStateList = [j for j in filters['InfectionState']] # Create data frame to store results to plot. @@ -192,8 +196,8 @@ def extract_data( k += 1 else: - raise gd.ValueError( - "Time point not found for region " + str(regions[i]) + ".") + raise gd.DataError( + "Time point " + str(date) + " not found for region " + str(regions[i]) + ".") # Aggregated or matrix output. if output == 'sum': @@ -207,6 +211,29 @@ def extract_data( raise gd.DataError("Data could not be read in.") +def extract_time_steps(file, file_format='json'): + """ Reads data from a general json or specific hdf5 file as output by the + MEmilio simulation framework and extracts the number of days used. + + @param[in] file Path and filename of file to be read in, relative from current + directory. + @param[in] file_format File format; either json or h5. + @return Number of time steps. + """ + input_file = os.path.join(os.getcwd(), str(file)) + if file_format == 'json': + df = pd.read_json(input_file + '.' + file_format) + if 'Date' in df.columns: + time_steps = df['Date'].nunique() + else: + time_steps = 1 + elif file_format == 'h5': + h5file = h5py.File(input_file + '.' + file_format, 'r') + regions = list(h5file.keys()) + time_steps = len(h5file[regions[0]]['Time']) + return time_steps + + def scale_dataframe_relative(df, age_groups, df_population): """! Scales a population-related data frame relative to the size of the local populations or subpopulations (e.g., if not all age groups are @@ -225,7 +252,7 @@ def scale_dataframe_relative(df, age_groups, df_population): """ # Merge population data of Eisenach (if counted separately) with Wartburgkreis. - if 16056 in df_population[df.columns[0]].values: + if 16056 in df_population['ID_County'].values: for i in range(1, len(df_population.columns)): df_population.loc[df_population[df.columns[0]] == 16063, df_population.columns[i] ] += df_population.loc[df_population.ID_County == 16056, df_population.columns[i]] @@ -235,12 +262,12 @@ def scale_dataframe_relative(df, age_groups, df_population): columns=[df_population.columns[0]] + age_groups) # Extrapolate on oldest age group with maximumg age 100. for region_id in df.iloc[:, 0]: - df_population_agegroups.loc[len(df_population_agegroups.index), :] = [region_id] + list( - mdfs.fit_age_group_intervals(df_population[df_population.iloc[:, 0] == region_id].iloc[:, 2:], age_groups)) + df_population_agegroups.loc[len(df_population_agegroups.index), :] = [int(region_id)] + list( + mdfs.fit_age_group_intervals(df_population[df_population.iloc[:, 0] == int(region_id)].iloc[:, 2:], age_groups)) def scale_row(elem): population_local_sum = df_population_agegroups[ - df_population_agegroups[df.columns[0]] == elem[0]].iloc[ + df_population_agegroups['ID_County'] == int(elem[0])].iloc[ :, 1:].sum(axis=1) return elem['Count'] / population_local_sum.values[0] @@ -272,7 +299,8 @@ def plot_map(data: pd.DataFrame, output_path: str = '', fig_name: str = 'customPlot', dpi: int = 300, - outercolor='white'): + outercolor='white', + log_scale=False): """! Plots region-specific information onto a interactive html map and returning svg and png image. Allows the comparisons of a variable list of data sets. @@ -288,12 +316,14 @@ def plot_map(data: pd.DataFrame, @param[in] fig_name Name of the figure created. @param[in] dpi Dots-per-inch value for the exported figure. @param[in] outercolor Background color of the plot image. + @param[in] log_scale Defines if the colorbar is plotted in log scale. """ region_classifier = data.columns[0] + region_data = data[region_classifier].to_numpy().astype(int) data_columns = data.columns[1:] # Read and filter map data. - if data[region_classifier].isin(geoger.get_county_ids()).all(): + if np.isin(region_data, geoger.get_county_ids()).all(): try: map_data = geopandas.read_file( os.path.join( @@ -342,17 +372,23 @@ def plot_map(data: pd.DataFrame, # Use top row for title. tax = fig.add_subplot(gs[0, :]) tax.set_axis_off() - tax.set_title(title, fontsize=18) + tax.set_title(title, fontsize=16) if plot_colorbar: # Prepare colorbar. cax = fig.add_subplot(gs[1, 0]) else: cax = None + if log_scale: + norm = mcolors.LogNorm(vmin=scale_colors[0], vmax=scale_colors[1]) + for i in range(len(data_columns)): ax = fig.add_subplot(gs[:, i+2]) - if cax is not None: + if log_scale: + map_data.plot(data_columns[i], ax=ax, cax=cax, legend=True, + norm=norm) + elif cax is not None: map_data.plot(data_columns[i], ax=ax, cax=cax, legend=True, vmin=scale_colors[0], vmax=scale_colors[1]) else: @@ -364,8 +400,4 @@ def plot_map(data: pd.DataFrame, ax.set_axis_off() plt.subplots_adjust(bottom=0.1) - plt.savefig(os.path.join(output_path, fig_name + '.png'), dpi=dpi) - plt.savefig(os.path.join(output_path, fig_name + '.svg'), dpi=dpi) - - plt.show() diff --git a/pycode/memilio-plot/memilio/plot_test/test_plot_createGIF.py b/pycode/memilio-plot/memilio/plot_test/test_plot_createGIF.py new file mode 100644 index 0000000000..5ae9bab443 --- /dev/null +++ b/pycode/memilio-plot/memilio/plot_test/test_plot_createGIF.py @@ -0,0 +1,88 @@ +import unittest +from unittest.mock import patch, MagicMock +import pandas as pd +import numpy as np +from memilio.plot import createGIF + + +class TestCreateGif(unittest.TestCase): + + @patch('memilio.plot.createGIF.pm.extract_data') + @patch('pandas.read_json') + @patch('memilio.plot.createGIF.pm.scale_dataframe_relative') + @patch('memilio.plot.createGIF.pm.plot_map') + def test_create_plot_map(self, mock_plot_map, mock_scale_dataframe_relative, mock_read_json, mock_extract_data): + # Mock the return values of the functions + mock_extract_data.return_value = pd.DataFrame( + {'Region': [0, 1], 'Value': [1, 2]}) + mock_read_json.return_value = pd.DataFrame( + {'Region': [0, 1], 'Population': [100, 200]}) + mock_scale_dataframe_relative.return_value = pd.DataFrame( + {'Region': [0, 1], 'Value': [0.01, 0.01]}) + + # Call the function with test parameters + createGIF.create_plot_map(day=1, filename='test', files_input={'file1': 'path1'}, output_path='output', compartments=[ + 'compartment1'], file_format='h5', relative=True, age_groups={0: '0-4', 1: '5-14', 2: '15-34', 3: '35-59', 4: '60-79', 5: '80+'}) + + # Assert that the mocked functions were called with the correct parameters + mock_extract_data.assert_called() + mock_read_json.assert_called() + mock_scale_dataframe_relative.assert_called() + + assert mock_plot_map.called + + # Get the arguments passed to the mock and assert that they are correct + args, kwargs = mock_plot_map.call_args + + pd.testing.assert_frame_equal(args[0], pd.DataFrame( + {'Region': [0, 1], 'Value 0': [0.01, 0.01]})) + np.testing.assert_array_equal( + kwargs['scale_colors'], np.array([0.01, 0.01])) + self.assertEqual(kwargs['legend'], ['', '']) + self.assertEqual(kwargs['title'], 'Synthetic data (relative) day 1') + self.assertEqual(kwargs['plot_colorbar'], True) + self.assertEqual(kwargs['output_path'], 'output') + self.assertEqual(kwargs['fig_name'], 'test') + self.assertEqual(kwargs['dpi'], 300) + self.assertEqual(kwargs['outercolor'], [ + 205 / 255, 238 / 255, 251 / 255]) + + @patch('memilio.plot.createGIF.pm.extract_time_steps') + @patch('memilio.plot.createGIF.create_plot_map') + @patch('tempfile.TemporaryDirectory') + @patch('imageio.v2.imread') + @patch('imageio.mimsave') + def test_create_gif_map_plot(self, mock_mimsave, mock_imread, mock_tempdir, mock_create_plot_map, mock_extract_time_steps): + # Mock the return values of the functions + mock_extract_time_steps.return_value = 10 + mock_tempdir.return_value.__enter__.return_value = 'tempdir' + mock_imread.return_value = 'image' + + # Call the function with test parameters + createGIF.create_gif_map_plot(input_data='input', output_dir='output', compartments=[ + 'compartment1'], filename='test') + + # Assert that the mocked functions were called with the correct parameters + mock_extract_time_steps.assert_called_with( + 'input/Results', file_format='h5') + mock_create_plot_map.assert_called() + mock_tempdir.assert_called() + mock_imread.assert_called_with('tempdir/test.png') + mock_mimsave.assert_called_with( + 'output/test.gif', ['image']*10, duration=0.2, loop=0) + + # Get the arguments passed to the mock and assert that they are correct + args, kwargs = mock_create_plot_map.call_args + self.assertEqual(args[0], 9) + self.assertEqual(args[1], 'test') + self.assertEqual(args[2], {'Data set': 'input/Results'}) + self.assertEqual(args[3], 'tempdir') + self.assertEqual(args[4], ['compartment1']) + self.assertEqual(args[5], 'h5') + self.assertEqual(args[6], True) + self.assertEqual(args[7], { + 0: '0-4', 1: '5-14', 2: '15-34', 3: '35-59', 4: '60-79', 5: '80+'}) + + +if __name__ == '__main__': + unittest.main() diff --git a/pycode/memilio-plot/memilio/plot_test/test_plot_plotMap.py b/pycode/memilio-plot/memilio/plot_test/test_plot_plotMap.py index 14803ba07d..b21cb3c53f 100644 --- a/pycode/memilio-plot/memilio/plot_test/test_plot_plotMap.py +++ b/pycode/memilio-plot/memilio/plot_test/test_plot_plotMap.py @@ -59,6 +59,12 @@ class TestPlotMap(fake_filesystem_unittest.TestCase): age_groups = {0: '0-4', 1: '5-14', 2: '15-34', 3: '35-59', 4: '60-79', 5: '80+'} + def test_extract_time_steps(self): + for file in self.files_input.values(): + num_days = pm.extract_time_steps( + file, file_format=self.file_format) + assert num_days == 1 + def test_extract_data(self): filter_age = None i = 0 diff --git a/pycode/memilio-plot/setup.py b/pycode/memilio-plot/setup.py index 98e302b1c2..30efd2c76f 100644 --- a/pycode/memilio-plot/setup.py +++ b/pycode/memilio-plot/setup.py @@ -76,7 +76,9 @@ def run(self): 'matplotlib', 'mapclassify', 'geopandas', - 'h5py' + 'h5py', + 'imageio', + 'datetime' ], extras_require={ 'dev': [