Skip to content

Commit

Permalink
Merge pull request #162 from jmccreight/bug_model_graph
Browse files Browse the repository at this point in the history
fix and fix up model graphs
  • Loading branch information
jmccreight committed May 12, 2023
2 parents adb5a87 + a8b4da2 commit f773770
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 25 deletions.
3 changes: 1 addition & 2 deletions autotest/test_nhm_self_drive.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import pathlib as pl

# import pytest
import pywatershed as pws
import xarray as xr

import pywatershed as pws

n_time_steps = 50

Expand Down
27 changes: 15 additions & 12 deletions pywatershed/analysis/model_graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pathlib as pl
import tempfile
import warnings
from pprint import pprint

from ..base.model import Model

Expand Down Expand Up @@ -34,6 +35,8 @@ def __init__(
if not has_pydot:
warnings.warn("pydot not available")

self.graph = None

self.model = model
self.show_params = show_params
self.process_colors = process_colors
Expand All @@ -46,8 +49,6 @@ def __init__(
self.node_spacing = node_spacing
self.hide_variables = hide_variables

self.graph = None

return

def build_graph(self):
Expand All @@ -60,12 +61,13 @@ def build_graph(self):
)

# Solve the connections
files = []
connections = []
self.files = []
self.connections = []
for process in self.model.process_order:
frm_already = []
for var, frm in self.model.process_input_from[process].items():
var_con = f":{var}"

if self.hide_variables:
var_con = ""
if frm in frm_already:
Expand All @@ -79,13 +81,14 @@ def build_graph(self):
color = self.default_edge_color
if self.process_colors:
color = self.process_colors[frm]
connections += [
self.connections += [
(f"{frm}{var_con}", f"{process}{var_con}", color)
]

else:
file_name = frm.name
files += [file_name]
connections += [
self.files += [file_name]
self.connections += [
(
f"Files:{file_name.split('.')[0]}",
f"{process}{var_con}",
Expand All @@ -95,7 +98,7 @@ def build_graph(self):

# Build the file node, reset the position
self._current_pos = 0
self.file_node = self._file_node(files)
self.file_node = self._file_node(self.files)

# build the graph
self.graph = pydot.Dot(
Expand All @@ -109,20 +112,20 @@ def build_graph(self):
for process in self.model.process_order:
self.graph.add_node(self.process_nodes[process])

for con in connections:
for con in self.connections:
self.graph.add_edge(pydot.Edge(con[0], con[1], color=con[2]))

return

def SVG(self, verbose: bool = False):
def SVG(self, verbose: bool = False, dpi=45):
"""Display an SVG in jupyter notebook (via tempfile)."""

if not has_ipython:
warnings.warn("IPython is not available")
tmp_file = pl.Path(tempfile.NamedTemporaryFile().name)
if self.graph is None:
self.build_graph()
self.graph.write_svg(tmp_file)
self.graph.write_svg(tmp_file, prog=["dot", f"-Gdpi={dpi}"])
if verbose:
print(f"Displaying SVG written to temp file: {tmp_file}")

Expand Down Expand Up @@ -183,7 +186,7 @@ def _process_node(self, process, show_params: bool = False):
if vv in mass_budget_vars:
border_color_str = 'border="1" COLOR="BLUE"'
label += f" <TR>\n"
label += f' <TD COLSPAN="4" BGCOLOR="{category_colors[varset_name]}" {border_color_str} PORT="{vv}" ><FONT POINT-SIZE="9.0">{vv}</FONT></TD>\n'
label += f' <TD COLSPAN="4" BGCOLOR="{category_colors[varset_name]}" {border_color_str} PORT="{vv}"><FONT POINT-SIZE="9.0">{vv}</FONT></TD>\n'
label += f" </TR>\n"

label += f"</TABLE>>\n"
Expand Down
26 changes: 16 additions & 10 deletions pywatershed/analysis/process_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ def __init__(
# HRU one-time setups
self.hru_gdf = (
gpd.read_file(self.hru_shapefile)
# .drop("nhm_id", axis=1)
.rename(columns={"nhru_v1_1": "nhm_id"}).set_index("nhm_id")
.drop("nhm_id", axis=1)
.rename(columns={"nhru_v1_1": "nhm_id"})
.set_index("nhm_id")
)

# segment one-time setup
Expand All @@ -94,15 +95,13 @@ def __init__(
return

def plot(self, var_name: str, process: StorageUnit, cmap: str = None):
var_dims = list(
meta.get_vars(var_name)[var_name]["dimensions"].values()
)
var_dims = list(meta.get_vars(var_name)[var_name]["dims"])
if "nsegment" in var_dims:
if not cmap:
cmap = "cool"
return self.plot_seg_var(self, var_name, process, cmap)
return self.plot_seg_var(var_name, process, cmap)
elif "nhru" in var_dims:
return self.plot_hru_var(self, var_name, process)
return self.plot_hru_var(var_name, process)
else:
raise ValueError()

Expand Down Expand Up @@ -197,16 +196,23 @@ def get_hru_var(self, var_name: str, model: Model):
).set_index("nhm_id")
return data_df

def plot_hru(
def plot_hru_var(
self,
var_name: str,
model: Model = None,
process: StorageUnit,
data: np.ndarray = None,
data_units: str = None,
nhm_id: np.ndarray = None,
):
if data is None:
data_df = self.get_hru_var(var_name, model)
# data_df = self.get_hru_var(var_name, model)
data_df = pd.DataFrame(
{
"nhm_id": process.control.params.parameters["nhm_id"],
var_name: process[var_name],
}
).set_index("nhm_id")

else:
if nhm_id is None:
nhm_id = model.control.params.parameters["nhm_id"]
Expand Down
2 changes: 1 addition & 1 deletion pywatershed/base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(
vars = deepcopy(class_vars)
c_inputs = inputs.pop(comp)
_ = vars.pop(comp)

self._inputs_from[comp] = {}
# inputs_from_prev[comp] = {}
# inputs
Expand Down Expand Up @@ -178,7 +179,6 @@ def _find_input_files(self):
load_n_time_batches=self._load_n_time_batches,
)
for process in self.process_order:
self.process_input_from[process] = {}
for input, frm in self._inputs_from[process].items():
if not frm:
fname = file_inputs[input]._fname
Expand Down

0 comments on commit f773770

Please sign in to comment.