From fff61ead2f0dba3ca4b161952a64c07f7b1fd466 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 20 May 2024 18:07:41 -0700 Subject: [PATCH 1/9] graphcast improvements --- examples/weather/graphcast/__init__.py | 15 - examples/weather/graphcast/conf/config.yaml | 32 +- .../weather/graphcast/conf/config_small.yaml | 71 +++++ examples/weather/graphcast/train_graphcast.py | 10 +- examples/weather/graphcast/validation.py | 4 + modulus/datapipes/climate/era5_hdf5.py | 43 +++ modulus/datapipes/climate/synthetic.py | 62 +++- modulus/models/graphcast/graph_cast_net.py | 16 +- modulus/utils/graphcast/graph.py | 121 +++---- modulus/utils/graphcast/graph_utils.py | 82 +++-- modulus/utils/graphcast/icosahedral_mesh.py | 294 ++++++++++++++++++ modulus/utils/graphcast/icospheres.py | 63 ---- 12 files changed, 589 insertions(+), 224 deletions(-) delete mode 100644 examples/weather/graphcast/__init__.py create mode 100644 examples/weather/graphcast/conf/config_small.yaml create mode 100644 modulus/utils/graphcast/icosahedral_mesh.py delete mode 100644 modulus/utils/graphcast/icospheres.py diff --git a/examples/weather/graphcast/__init__.py b/examples/weather/graphcast/__init__.py deleted file mode 100644 index b2f171d4ac..0000000000 --- a/examples/weather/graphcast/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/examples/weather/graphcast/conf/config.yaml b/examples/weather/graphcast/conf/config.yaml index 89dd85b2d1..4bc446b6ff 100644 --- a/examples/weather/graphcast/conf/config.yaml +++ b/examples/weather/graphcast/conf/config.yaml @@ -21,7 +21,8 @@ hydra: dir: ./outputs/ processor_layers: 16 -hidden_dim: 64 # 512 +hidden_dim: 512 +multimesh_level: 6 segments: 1 force_single_checkpoint: False checkpoint_encoder: True @@ -32,18 +33,18 @@ checkpoint_encoder_finetune: True checkpoint_processor_finetune: True checkpoint_decoder_finetune: True concat_trick: True -cugraphops_encoder: False -cugraphops_processor: False -cugraphops_decoder: False -recompute_activation: False +cugraphops_encoder: True +cugraphops_processor: True +cugraphops_decoder: True +recompute_activation: True wb_mode: "disabled" synthetic_dataset: false dataset_path: "/data" -static_dataset_path: "datasets/static" -latlon_res: (721, 1440) -num_samples_per_year_train: 1448 +static_dataset_path: null +latlon_res: [721, 1440] +num_samples_per_year_train: 1408 num_workers: 0 # 8 -num_channels: 3 # 34 +num_channels: 474 num_channels_val: 3 num_val_steps: 8 num_val_spy: 3 # SPY: Samples Per Year @@ -59,13 +60,12 @@ num_iters_step1: 1000 num_iters_step2: 299000 num_iters_step3: 11000 step_change_freq: 1000 -save_freq: 1 # 500 -val_freq: 1 # 1000 -ckpt_path: "checkpoints_34var" -val_dir: "validation_34var" -ckpt_name: "model_34var.pt" -use_apex: False +save_freq: 500 +val_freq: 1000 +ckpt_path: "checkpoints" +val_dir: "validation" +ckpt_name: "model" +use_apex: True pyt_profiler: False profile: False profile_range: (90, 110) -icospheres_path: "icospheres.json" diff --git a/examples/weather/graphcast/conf/config_small.yaml b/examples/weather/graphcast/conf/config_small.yaml new file mode 100644 index 0000000000..8282bfc4df --- /dev/null +++ b/examples/weather/graphcast/conf/config_small.yaml @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: True + run: + dir: ./outputs_small/ + +processor_layers: 16 +hidden_dim: 512 +multimesh_level: 5 +segments: 1 +force_single_checkpoint: False +checkpoint_encoder: False +checkpoint_processor: False +checkpoint_decoder: False +force_single_checkpoint_finetune: False +checkpoint_encoder_finetune: True +checkpoint_processor_finetune: True +checkpoint_decoder_finetune: True +concat_trick: True +cugraphops_encoder: True +cugraphops_processor: True +cugraphops_decoder: True +recompute_activation: True +wb_mode: "disabled" +synthetic_dataset: false +dataset_path: "/data" +static_dataset_path: null +latlon_res: [181, 360] +num_samples_per_year_train: 1408 +num_workers: 0 # 8 +num_channels: 73 +num_channels_val: 3 +num_val_steps: 8 +num_val_spy: 3 # SPY: Samples Per Year +grad_clip_norm: 32.0 +jit: False +amp: False +amp_dtype: "bfloat16" +full_bf16: True +watch_model: False +lr: 1e-3 +lr_step3: 3e-7 +num_iters_step1: 1000 +num_iters_step2: 299000 +num_iters_step3: 11000 +step_change_freq: 1000 +save_freq: 500 +val_freq: 1000 +ckpt_path: "checkpoints_small" +val_dir: "validation_small" +ckpt_name: "model_small" +use_apex: True +pyt_profiler: False +profile: False +profile_range: (90, 110) diff --git a/examples/weather/graphcast/train_graphcast.py b/examples/weather/graphcast/train_graphcast.py index c8647d8e88..9bac614691 100644 --- a/examples/weather/graphcast/train_graphcast.py +++ b/examples/weather/graphcast/train_graphcast.py @@ -87,8 +87,9 @@ def __init__(self, cfg: DictConfig, dist, rank_zero_logger): # instantiate the model self.model = GraphCastNet( - meshgraph_path=to_absolute_path(cfg.icospheres_path), static_dataset_path=static_dataset_path, + multimesh_level=cfg.multimesh_level, + input_res=tuple(cfg.latlon_res), input_dim_grid_nodes=cfg.num_channels, input_dim_mesh_nodes=3, input_dim_edges=4, @@ -140,10 +141,14 @@ def __init__(self, cfg: DictConfig, dist, rank_zero_logger): DataPipe = ( SyntheticWeatherDataLoader if cfg.synthetic_dataset else ERA5HDF5Datapipe ) + self.interpolation_shape = ( + cfg.latlon_res if cfg.latlon_res != (721, 1440) else None + ) # interpolate if not in native resolution self.datapipe = DataPipe( data_dir=to_absolute_path(os.path.join(cfg.dataset_path, "train")), stats_dir=to_absolute_path(os.path.join(cfg.dataset_path, "stats")), channels=[i for i in range(cfg.num_channels)], + interpolation_shape=self.interpolation_shape, num_samples_per_year=cfg.num_samples_per_year_train, num_steps=1, batch_size=1, @@ -271,6 +276,7 @@ def main(cfg: DictConfig) -> None: iter < cfg.num_iters_step1 + cfg.num_iters_step2 + cfg.num_iters_step3 ), "Training is already finished!" for i, data in enumerate(trainer.datapipe): + # profiling if cfg.profile and iter == cfg.profile_range[0]: rank_zero_logger.info("Starting profile", "green") @@ -321,6 +327,8 @@ def main(cfg: DictConfig) -> None: data_dir=os.path.join(cfg.dataset_path, "train"), stats_dir=os.path.join(cfg.dataset_path, "stats"), channels=[i for i in range(cfg.num_channels)], + interpolation_shape=trainer.interpolation_shape, + num_samples_per_year=cfg.num_samples_per_year_train, num_steps=num_rollout_steps, batch_size=1, num_workers=cfg.num_workers, diff --git a/examples/weather/graphcast/validation.py b/examples/weather/graphcast/validation.py index d7c8bb4219..0b6419b68f 100644 --- a/examples/weather/graphcast/validation.py +++ b/examples/weather/graphcast/validation.py @@ -35,10 +35,14 @@ def __init__(self, cfg: DictConfig, model, dtype, dist): self.model = model self.dtype = dtype self.dist = dist + interpolation_shape = ( + cfg.latlon_res if cfg.latlon_res != (721, 1440) else None + ) # interpolate if not in native resolution self.val_datapipe = ERA5HDF5Datapipe( data_dir=os.path.join(cfg.dataset_path, "test"), stats_dir=os.path.join(cfg.dataset_path, "stats"), channels=[i for i in range(cfg.num_channels)], + interpolation_shape=interpolation_shape, num_steps=cfg.num_val_steps, batch_size=1, num_samples_per_year=cfg.num_val_spy, diff --git a/modulus/datapipes/climate/era5_hdf5.py b/modulus/datapipes/climate/era5_hdf5.py index beeeb83f56..586d8f5ac5 100644 --- a/modulus/datapipes/climate/era5_hdf5.py +++ b/modulus/datapipes/climate/era5_hdf5.py @@ -68,6 +68,12 @@ class ERA5HDF5Datapipe(Datapipe): stride 2 = 12 hours delta t, by default 1 num_steps : int, optional Number of timesteps are included in the output variables, by default 1 + interpolation_shape: Tuple[int, int], optional + Shape for resizing (H, W), by default None (no interpolation) + interpolation_type: str, optional + Interpolation type for resizing. Supports ["INTERP_NN", "INTERP_LINEAR", "INTERP_CUBIC", + "INTERP_LANCZOS3", "INTERP_TRIANGULAR", "INTERP_GAUSSIAN"]. Interpolation is performed + only if `interpolation_shape` is not None. by default "INTERP_LINEAR". patch_size : Union[Tuple[int, int], int, None], optional If specified, crops input and output variables so image dimensions are divisible by patch_size, by default None @@ -93,6 +99,8 @@ def __init__( batch_size: int = 1, num_steps: int = 1, stride: int = 1, + interpolation_shape: Union[Tuple[int, int], None] = None, + interpolation_type: str = "INTERP_LINEAR", patch_size: Union[Tuple[int, int], int, None] = None, num_samples_per_year: Union[int, None] = None, shuffle: bool = True, @@ -109,6 +117,8 @@ def __init__( self.stats_dir = Path(stats_dir) if stats_dir is not None else None self.channels = channels self.stride = stride + self.interpolation_shape = interpolation_shape + self.interpolation_type = interpolation_type self.num_steps = num_steps self.num_samples_per_year = num_samples_per_year self.process_rank = process_rank @@ -132,6 +142,22 @@ def __init__( if self.stats_dir is not None and not self.stats_dir.is_dir(): raise IOError(f"Error, stats directory {self.stats_dir} does not exist") + # Check interpolation type + if self.interpolation_shape is not None: + valid_interpolation = [ + "INTERP_NN", + "INTERP_LINEAR", + "INTERP_CUBIC", + "INTERP_LANCZOS3", + "INTERP_TRIANGULAR", + "INTERP_GAUSSIAN", + ] + if self.interpolation_type not in valid_interpolation: + raise ValueError( + f"Interpolation type {self.interpolation_type} not supported" + ) + self.interpolation_type = getattr(dali.types, self.interpolation_type) + self.parse_dataset_files() self.load_statistics() @@ -271,6 +297,7 @@ def _create_pipeline(self) -> dali.Pipeline: num_outputs=2, parallel=True, batch=False, + layout=["CHW", "FCHW"], ) if self.device.type == "cuda": # Move tensors to GPU as external_source won't do that. @@ -285,6 +312,22 @@ def _create_pipeline(self) -> dali.Pipeline: if self.stats_dir is not None: invar = dali.fn.normalize(invar, mean=self.mu[0], stddev=self.sd[0]) outvar = dali.fn.normalize(outvar, mean=self.mu, stddev=self.sd) + # Resize. + if self.interpolation_shape is not None: + invar = dali.fn.resize( + invar, + resize_x=self.interpolation_shape[1], + resize_y=self.interpolation_shape[0], + interp_type=self.interpolation_type, + antialias=False, + ) + outvar = dali.fn.resize( + outvar, + resize_x=self.interpolation_shape[1], + resize_y=self.interpolation_shape[0], + interp_type=self.interpolation_type, + antialias=False, + ) # Set outputs. pipe.set_outputs(invar, outvar) diff --git a/modulus/datapipes/climate/synthetic.py b/modulus/datapipes/climate/synthetic.py index 0d117560bd..bed5819eaa 100644 --- a/modulus/datapipes/climate/synthetic.py +++ b/modulus/datapipes/climate/synthetic.py @@ -109,26 +109,54 @@ def generate_data( Returns: numpy.ndarray: A 4D array of temperature values across days, channels, latitudes, and longitudes. """ - days: np.ndarray = np.arange(num_days) + days = np.arange(num_days) latitudes, longitudes = grid_size - daily_temps: np.ndarray = np.zeros( - (num_days, num_channels, latitudes, longitudes) + + # Create altitude effect and reshape + altitude_effect = np.arange(num_channels) * -0.5 + altitude_effect = altitude_effect[ + :, np.newaxis, np.newaxis + ] # Shape: (num_channels, 1, 1) + altitude_effect = np.tile( + altitude_effect, (1, latitudes, longitudes) + ) # Shape: (num_channels, latitudes, longitudes) + altitude_effect = altitude_effect[ + np.newaxis, :, :, : + ] # Shape: (1, num_channels, latitudes, longitudes) + altitude_effect = np.tile( + altitude_effect, (num_days, 1, 1, 1) + ) # Shape: (num_days, num_channels, latitudes, longitudes) + + # Create latitude variation and reshape + lat_variation = np.linspace(-amplitude, amplitude, latitudes) + lat_variation = lat_variation[:, np.newaxis] # Shape: (latitudes, 1) + lat_variation = np.tile( + lat_variation, (1, longitudes) + ) # Shape: (latitudes, longitudes) + lat_variation = lat_variation[ + np.newaxis, np.newaxis, :, : + ] # Shape: (1, 1, latitudes, longitudes) + lat_variation = np.tile( + lat_variation, (num_days, num_channels, 1, 1) + ) # Shape: (num_days, num_channels, latitudes, longitudes) + + # Create time effect and reshape + time_effect = np.sin(2 * np.pi * days / 365) + time_effect = time_effect[ + :, np.newaxis, np.newaxis, np.newaxis + ] # Shape: (num_days, 1, 1, 1) + time_effect = np.tile( + time_effect, (1, num_channels, latitudes, longitudes) + ) # Shape: (num_days, num_channels, latitudes, longitudes) + + # Generate noise + noise = np.random.normal( + scale=noise_level, size=(num_days, num_channels, latitudes, longitudes) ) - for day in days: - for channel in range(num_channels): - altitude_effect: float = ( - channel * -0.5 - ) # Temperature decreases with altitude - lat_variation: np.ndarray = np.linspace( - -amplitude, amplitude, latitudes - ) - daily_temps[day, channel] = ( - base_temp - + altitude_effect - + np.outer(lat_variation, np.sin(2 * np.pi * day / 365)) - + np.random.normal(scale=noise_level, size=(latitudes, longitudes)) - ) + # Calculate daily temperatures + daily_temps = base_temp + altitude_effect + lat_variation + time_effect + noise + return daily_temps def __len__(self) -> int: diff --git a/modulus/models/graphcast/graph_cast_net.py b/modulus/models/graphcast/graph_cast_net.py index d6f848d00e..adb6e97d24 100644 --- a/modulus/models/graphcast/graph_cast_net.py +++ b/modulus/models/graphcast/graph_cast_net.py @@ -66,11 +66,10 @@ class GraphCastNet(Module): Parameters ---------- - meshgraph_path : str - Path to the meshgraph file. If not provided, the meshgraph will be created - using PyMesh. static_dataset_path : str Path to the static dataset file. + multimesh_level: int, optional + Level of the multi-mesh, by default 6 input_res: Tuple[int, int] Input resolution of the latitude-longitude grid input_dim_grid_nodes : int, optional @@ -140,8 +139,8 @@ class GraphCastNet(Module): def __init__( self, - meshgraph_path: str, static_dataset_path: str, + multimesh_level: int = 6, input_res: tuple = (721, 1440), input_dim_grid_nodes: int = 474, input_dim_mesh_nodes: int = 3, @@ -183,13 +182,8 @@ def __init__( activation_fn = get_activation(activation_fn) # construct the graph - try: - self.graph = Graph(meshgraph_path, self.lat_lon_grid) - except FileNotFoundError: - raise FileNotFoundError( - "The icospheres_path is corrupted. " - "Tried using pymesh to generate the graph but could not find pymesh" - ) + self.graph = Graph(self.lat_lon_grid, multimesh_level) + self.mesh_graph = self.graph.create_mesh_graph(verbose=False) self.g2m_graph = self.graph.create_g2m_graph(verbose=False) self.m2g_graph = self.graph.create_m2g_graph(verbose=False) diff --git a/modulus/utils/graphcast/graph.py b/modulus/utils/graphcast/graph.py index e224f72201..0f13aec12e 100644 --- a/modulus/utils/graphcast/graph.py +++ b/modulus/utils/graphcast/graph.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import logging import numpy as np @@ -25,11 +24,16 @@ from .graph_utils import ( add_edge_features, add_node_features, - cell_to_adj, create_graph, create_heterograph, - get_edge_len, + get_face_centroids, latlon2xyz, + max_edge_length, +) +from .icosahedral_mesh import ( + faces_to_edges, + get_hierarchy_of_triangular_meshes_for_sphere, + merge_meshes, ) logger = logging.getLogger(__name__) @@ -40,47 +44,33 @@ class Graph: Parameters ---------- - icospheres_path : str - Path to the icospheres json file. - If the file does not exist, it will try to generate it using PyMesh. lat_lon_grid : Tensor Tensor with shape (lat, lon, 2) that includes the latitudes and longitudes meshgrid. + multimesh_level: int, optional + Level of the multi-mesh, by default 6 dtype : torch.dtype, optional Data type of the graph, by default torch.float """ def __init__( - self, icospheres_path: str, lat_lon_grid: Tensor, dtype=torch.float + self, lat_lon_grid: Tensor, multimesh_level=6, dtype=torch.float ) -> None: self.dtype = dtype - # Get or generate the icospheres - try: - with open(icospheres_path, "r") as f: - loaded_dict = json.load(f) - icospheres = { - key: (np.array(value) if isinstance(value, list) else value) - for key, value in loaded_dict.items() - } - logger.info(f"Opened pre-computed graph at {icospheres_path}.") - except FileNotFoundError: - from modulus.utils.graphcast.icospheres import ( - generate_and_save_icospheres, - ) - - logger.info( - f"Could not open {icospheres_path}...generating mesh from scratch." - ) - generate_and_save_icospheres() - - self.icospheres = icospheres - self.max_order = ( - len([key for key in self.icospheres.keys() if "faces" in key]) - 2 - ) # flatten lat/lon gird self.lat_lon_grid_flat = lat_lon_grid.permute(2, 0, 1).view(2, -1).permute(1, 0) + # create the multi-mesh + _meshes = get_hierarchy_of_triangular_meshes_for_sphere(splits=multimesh_level) + merged_mesh = merge_meshes(_meshes) + self.multimesh_src, self.multimesh_dst = faces_to_edges(merged_mesh.faces) + self.multimesh_vertices = np.array(merged_mesh.vertices) + self.multimesh_faces = merged_mesh.faces + finest_mesh = _meshes[-1] + self.finest_mesh_src, self.finest_mesh_dst = faces_to_edges(finest_mesh.faces) + self.finest_mesh_vertices = np.array(finest_mesh.vertices) + def create_mesh_graph(self, verbose: bool = True) -> Tensor: """Create the multimesh graph. @@ -94,19 +84,15 @@ def create_mesh_graph(self, verbose: bool = True) -> Tensor: DGLGraph Multimesh graph. """ - # create the bi-directional mesh graph - multimesh_faces = self.icospheres["order_0_faces"] - for i in range(1, self.max_order + 1): - multimesh_faces = np.concatenate( - (multimesh_faces, self.icospheres["order_" + str(i) + "_faces"]) - ) - - src, dst = cell_to_adj(multimesh_faces) mesh_graph = create_graph( - src, dst, to_bidirected=True, add_self_loop=False, dtype=torch.int32 + self.multimesh_src, + self.multimesh_dst, + to_bidirected=True, + add_self_loop=False, + dtype=torch.int32, ) mesh_pos = torch.tensor( - self.icospheres["order_" + str(self.max_order) + "_vertices"], + self.multimesh_vertices, dtype=torch.float32, ) mesh_graph = add_edge_features(mesh_graph, mesh_pos) @@ -132,53 +118,31 @@ def create_g2m_graph(self, verbose: bool = True) -> Tensor: Graph2mesh graph. """ # get the max edge length of icosphere with max order - edge_src = self.icospheres["order_" + str(self.max_order) + "_vertices"][ - self.icospheres["order_" + str(self.max_order) + "_faces"][:, 0] - ] - edge_dst = self.icospheres["order_" + str(self.max_order) + "_vertices"][ - self.icospheres["order_" + str(self.max_order) + "_faces"][:, 1] - ] - edge_len_1 = np.max(get_edge_len(edge_src, edge_dst)) - edge_src = self.icospheres["order_" + str(self.max_order) + "_vertices"][ - self.icospheres["order_" + str(self.max_order) + "_faces"][:, 0] - ] - edge_dst = self.icospheres["order_" + str(self.max_order) + "_vertices"][ - self.icospheres["order_" + str(self.max_order) + "_faces"][:, 2] - ] - edge_len_2 = np.max(get_edge_len(edge_src, edge_dst)) - edge_src = self.icospheres["order_" + str(self.max_order) + "_vertices"][ - self.icospheres["order_" + str(self.max_order) + "_faces"][:, 1] - ] - edge_dst = self.icospheres["order_" + str(self.max_order) + "_vertices"][ - self.icospheres["order_" + str(self.max_order) + "_faces"][:, 2] - ] - edge_len_3 = np.max(get_edge_len(edge_src, edge_dst)) - edge_len = max([edge_len_1, edge_len_2, edge_len_3]) + + max_edge_len = max_edge_length( + self.finest_mesh_vertices, self.finest_mesh_src, self.finest_mesh_dst + ) # create the grid2mesh bipartite graph cartesian_grid = latlon2xyz(self.lat_lon_grid_flat) n_nbrs = 4 - neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit( - self.icospheres["order_" + str(self.max_order) + "_vertices"] - ) + neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(self.multimesh_vertices) distances, indices = neighbors.kneighbors(cartesian_grid) src, dst = [], [] for i in range(len(cartesian_grid)): for j in range(n_nbrs): - if distances[i][j] <= 0.6 * edge_len: + if distances[i][j] <= 0.6 * max_edge_len: src.append(i) dst.append(indices[i][j]) - # NOTE this gives 1,624,344 edges, in the paper it is 1,618,746 - # this number is very sensitive to the chosen edge_len, not clear - # in the paper what they use. + # NOTE this gives 1,618,820 edges, in the paper it is 1,618,746 g2m_graph = create_heterograph( src, dst, ("grid", "g2m", "mesh"), dtype=torch.int32 - ) # number of edges is 3,114,720, exactly matches with the paper + ) g2m_graph.srcdata["pos"] = cartesian_grid.to(torch.float32) g2m_graph.dstdata["pos"] = torch.tensor( - self.icospheres["order_" + str(self.max_order) + "_vertices"], + self.multimesh_vertices, dtype=torch.float32, ) g2m_graph = add_edge_features( @@ -213,24 +177,21 @@ def create_m2g_graph(self, verbose: bool = True) -> Tensor: """ # create the mesh2grid bipartite graph cartesian_grid = latlon2xyz(self.lat_lon_grid_flat) - n_nbrs = 1 - neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit( - self.icospheres["order_" + str(self.max_order) + "_face_centroid"] + face_centroids = get_face_centroids( + self.multimesh_vertices, self.multimesh_faces ) + n_nbrs = 1 + neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(face_centroids) _, indices = neighbors.kneighbors(cartesian_grid) indices = indices.flatten() - src = [ - p - for i in indices - for p in self.icospheres["order_" + str(self.max_order) + "_faces"][i] - ] + src = [p for i in indices for p in self.multimesh_faces[i]] dst = [i for i in range(len(cartesian_grid)) for _ in range(3)] m2g_graph = create_heterograph( src, dst, ("mesh", "m2g", "grid"), dtype=torch.int32 ) # number of edges is 3,114,720, exactly matches with the paper m2g_graph.srcdata["pos"] = torch.tensor( - self.icospheres["order_" + str(self.max_order) + "_vertices"], + self.multimesh_vertices, dtype=torch.float32, ) m2g_graph.dstdata["pos"] = cartesian_grid.to(dtype=torch.float32) diff --git a/modulus/utils/graphcast/graph_utils.py b/modulus/utils/graphcast/graph_utils.py index bc50048c32..0936869d41 100644 --- a/modulus/utils/graphcast/graph_utils.py +++ b/modulus/utils/graphcast/graph_utils.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Tuple import dgl import numpy as np @@ -368,26 +368,6 @@ def rad2deg(rad): return rad * 180 / np.pi -def get_edge_len(edge_src: Tensor, edge_dst: Tensor, axis: int = 1): - """returns the length of the edge - - Parameters - ---------- - edge_src : Tensor - Tensor of shape (N, 3) containing the source of the edge - edge_dst : Tensor - Tensor of shape (N, 3) containing the destination of the edge - axis : int, optional - Axis along which the norm is computed, by default 1 - - Returns - ------- - Tensor - Tensor of shape (N, ) containing the length of the edge - """ - return np.linalg.norm(edge_src - edge_dst, axis=axis) - - def cell_to_adj(cells: List[List[int]]): """creates adjancy matrix in COO format from mesh cells @@ -405,3 +385,63 @@ def cell_to_adj(cells: List[List[int]]): src = [cells[i][indx] for i in range(num_cells) for indx in [0, 1, 2]] dst = [cells[i][indx] for i in range(num_cells) for indx in [1, 2, 0]] return src, dst + + +def max_edge_length( + vertices: List[List[float]], source_nodes: List[int], destination_nodes: List[int] +) -> float: + """ + Compute the maximum edge length in a graph. + + Parameters: + vertices (List[List[float]]): A list of tuples representing the coordinates of the vertices. + source_nodes (List[int]): A list of indices representing the source nodes of the edges. + destination_nodes (List[int]): A list of indices representing the destination nodes of the edges. + + Returns: + The maximum edge length in the graph (float). + """ + vertices_np = np.array(vertices) + source_coords = vertices_np[source_nodes] + dest_coords = vertices_np[destination_nodes] + + # Compute the squared distances for all edges + squared_differences = np.sum((source_coords - dest_coords) ** 2, axis=1) + + # Compute the maximum edge length + max_length = np.sqrt(np.max(squared_differences)) + + return max_length + + +def get_face_centroids( + vertices: List[Tuple[float, float, float]], faces: List[List[int]] +) -> List[Tuple[float, float, float]]: + """ + Compute the centroids of triangular faces in a graph. + + Parameters: + vertices (List[Tuple[float, float, float]]): A list of tuples representing the coordinates of the vertices. + faces (List[List[int]]): A list of lists, where each inner list contains three indices representing a triangular face. + + Returns: + List[Tuple[float, float, float]]: A list of tuples representing the centroids of the faces. + """ + centroids = [] + + for face in faces: + # Extract the coordinates of the vertices for the current face + v0 = vertices[face[0]] + v1 = vertices[face[1]] + v2 = vertices[face[2]] + + # Compute the centroid of the triangle + centroid = ( + (v0[0] + v1[0] + v2[0]) / 3, + (v0[1] + v1[1] + v2[1]) / 3, + (v0[2] + v1[2] + v2[2]) / 3, + ) + + centroids.append(centroid) + + return centroids diff --git a/modulus/utils/graphcast/icosahedral_mesh.py b/modulus/utils/graphcast/icosahedral_mesh.py new file mode 100644 index 0000000000..7c9099c9d2 --- /dev/null +++ b/modulus/utils/graphcast/icosahedral_mesh.py @@ -0,0 +1,294 @@ +# ignore_header_test +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: S101 +"""Utils for creating icosahedral meshes.""" + +import itertools +from typing import List, NamedTuple, Sequence, Tuple + +import numpy as np +from scipy.spatial import transform + + +class TriangularMesh(NamedTuple): + """Data structure for triangular meshes. + + Attributes: + vertices: spatial positions of the vertices of the mesh of shape + [num_vertices, num_dims]. + faces: triangular faces of the mesh of shape [num_faces, 3]. Contains + integer indices into `vertices`. + + """ + + vertices: np.ndarray + faces: np.ndarray + + +def merge_meshes(mesh_list: Sequence[TriangularMesh]) -> TriangularMesh: + """Merges all meshes into one. Assumes the last mesh is the finest. + + Args: + mesh_list: Sequence of meshes, from coarse to fine refinement levels. The + vertices and faces may contain those from preceding, coarser levels. + + Returns: + `TriangularMesh` for which the vertices correspond to the highest + resolution mesh in the hierarchy, and the faces are the join set of the + faces at all levels of the hierarchy. + """ + for mesh_i, mesh_ip1 in itertools.pairwise(mesh_list): + num_nodes_mesh_i = mesh_i.vertices.shape[0] + assert np.allclose(mesh_i.vertices, mesh_ip1.vertices[:num_nodes_mesh_i]) + + return TriangularMesh( + vertices=mesh_list[-1].vertices, + faces=np.concatenate([mesh.faces for mesh in mesh_list], axis=0), + ) + + +def get_hierarchy_of_triangular_meshes_for_sphere(splits: int) -> List[TriangularMesh]: + """Returns a sequence of meshes, each with triangularization sphere. + + Starting with a regular icosahedron (12 vertices, 20 faces, 30 edges) with + circumscribed unit sphere. Then, each triangular face is iteratively + subdivided into 4 triangular faces `splits` times. The new vertices are then + projected back onto the unit sphere. All resulting meshes are returned in a + list, from lowest to highest resolution. + + The vertices in each face are specified in counter-clockwise order as + observed from the outside the icosahedron. + + Args: + splits: How many times to split each triangle. + Returns: + Sequence of `TriangularMesh`s of length `splits + 1` each with: + + vertices: [num_vertices, 3] vertex positions in 3D, all with unit norm. + faces: [num_faces, 3] with triangular faces joining sets of 3 vertices. + Each row contains three indices into the vertices array, indicating + the vertices adjacent to the face. Always with positive orientation + (counterclock-wise when looking from the outside). + """ + current_mesh = get_icosahedron() + output_meshes = [current_mesh] + for _ in range(splits): + current_mesh = _two_split_unit_sphere_triangle_faces(current_mesh) + output_meshes.append(current_mesh) + return output_meshes + + +def get_icosahedron() -> TriangularMesh: + """Returns a regular icosahedral mesh with circumscribed unit sphere. + + See https://en.wikipedia.org/wiki/Regular_icosahedron#Cartesian_coordinates + for details on the construction of the regular icosahedron. + + The vertices in each face are specified in counter-clockwise order as observed + from the outside of the icosahedron. + + Returns: + TriangularMesh with: + + vertices: [num_vertices=12, 3] vertex positions in 3D, all with unit norm. + faces: [num_faces=20, 3] with triangular faces joining sets of 3 vertices. + Each row contains three indices into the vertices array, indicating + the vertices adjacent to the face. Always with positive orientation ( + counterclock-wise when looking from the outside). + + """ + phi = (1 + np.sqrt(5)) / 2 + vertices = [] + for c1 in [1.0, -1.0]: + for c2 in [phi, -phi]: + vertices.append((c1, c2, 0.0)) + vertices.append((0.0, c1, c2)) + vertices.append((c2, 0.0, c1)) + + vertices = np.array(vertices, dtype=np.float32) + vertices /= np.linalg.norm([1.0, phi]) + + # I did this manually, checking the orientation one by one. + faces = [ + (0, 1, 2), + (0, 6, 1), + (8, 0, 2), + (8, 4, 0), + (3, 8, 2), + (3, 2, 7), + (7, 2, 1), + (0, 4, 6), + (4, 11, 6), + (6, 11, 5), + (1, 5, 7), + (4, 10, 11), + (4, 8, 10), + (10, 8, 3), + (10, 3, 9), + (11, 10, 9), + (11, 9, 5), + (5, 9, 7), + (9, 3, 7), + (1, 6, 5), + ] + + # By default the top is an aris parallel to the Y axis. + # Need to rotate around the y axis by half the supplementary to the + # angle between faces divided by two to get the desired orientation. + # /O\ (top arist) + # / \ Z + # (adjacent face)/ \ (adjacent face) ^ + # / angle_between_faces \ | + # / \ | + # / \ YO-----> X + # This results in: + # (adjacent faceis now top plane) + # ----------------------O\ (top arist) + # \ + # \ + # \ (adjacent face) + # \ + # \ + # \ + + angle_between_faces = 2 * np.arcsin(phi / np.sqrt(3)) + rotation_angle = (np.pi - angle_between_faces) / 2 + rotation = transform.Rotation.from_euler(seq="y", angles=rotation_angle) + rotation_matrix = rotation.as_matrix() + vertices = np.dot(vertices, rotation_matrix) + + return TriangularMesh( + vertices=vertices.astype(np.float32), faces=np.array(faces, dtype=np.int32) + ) + + +def _two_split_unit_sphere_triangle_faces( + triangular_mesh: TriangularMesh, +) -> TriangularMesh: + """Splits each triangular face into 4 triangles keeping the orientation.""" + + # Every time we split a triangle into 4 we will be adding 3 extra vertices, + # located at the edge centres. + # This class handles the positioning of the new vertices, and avoids creating + # duplicates. + new_vertices_builder = _ChildVerticesBuilder(triangular_mesh.vertices) + + new_faces = [] + for ind1, ind2, ind3 in triangular_mesh.faces: + # Transform each triangular face into 4 triangles, + # preserving the orientation. + # ind3 + # / \ + # / \ + # / #3 \ + # / \ + # ind31 -------------- ind23 + # / \ / \ + # / \ #4 / \ + # / #1 \ / #2 \ + # / \ / \ + # ind1 ------------ ind12 ------------ ind2 + ind12 = new_vertices_builder.get_new_child_vertex_index((ind1, ind2)) + ind23 = new_vertices_builder.get_new_child_vertex_index((ind2, ind3)) + ind31 = new_vertices_builder.get_new_child_vertex_index((ind3, ind1)) + # Note how each of the 4 triangular new faces specifies the order of the + # vertices to preserve the orientation of the original face. As the input + # face should always be counter-clockwise as specified in the diagram, + # this means child faces should also be counter-clockwise. + new_faces.extend( + [ + [ind1, ind12, ind31], # 1 + [ind12, ind2, ind23], # 2 + [ind31, ind23, ind3], # 3 + [ind12, ind23, ind31], # 4 + ] + ) + return TriangularMesh( + vertices=new_vertices_builder.get_all_vertices(), + faces=np.array(new_faces, dtype=np.int32), + ) + + +class _ChildVerticesBuilder(object): + """Bookkeeping of new child vertices added to an existing set of vertices.""" + + def __init__(self, parent_vertices): + + # Because the same new vertex will be required when splitting adjacent + # triangles (which share an edge) we keep them in a hash table indexed by + # sorted indices of the vertices adjacent to the edge, to avoid creating + # duplicated child vertices. + self._child_vertices_index_mapping = {} + self._parent_vertices = parent_vertices + # We start with all previous vertices. + self._all_vertices_list = list(parent_vertices) + + def _get_child_vertex_key(self, parent_vertex_indices): + return tuple(sorted(parent_vertex_indices)) + + def _create_child_vertex(self, parent_vertex_indices): + """Creates a new vertex.""" + # Position for new vertex is the middle point, between the parent points, + # projected to unit sphere. + child_vertex_position = self._parent_vertices[list(parent_vertex_indices)].mean( + 0 + ) + child_vertex_position /= np.linalg.norm(child_vertex_position) + + # Add the vertex to the output list. The index for this new vertex will + # match the length of the list before adding it. + child_vertex_key = self._get_child_vertex_key(parent_vertex_indices) + self._child_vertices_index_mapping[child_vertex_key] = len( + self._all_vertices_list + ) + self._all_vertices_list.append(child_vertex_position) + + def get_new_child_vertex_index(self, parent_vertex_indices): + """Returns index for a child vertex, creating it if necessary.""" + # Get the key to see if we already have a new vertex in the middle. + child_vertex_key = self._get_child_vertex_key(parent_vertex_indices) + if child_vertex_key not in self._child_vertices_index_mapping: + self._create_child_vertex(parent_vertex_indices) + return self._child_vertices_index_mapping[child_vertex_key] + + def get_all_vertices(self): + """Returns an array with old vertices.""" + return np.array(self._all_vertices_list) + + +def faces_to_edges(faces: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Transforms polygonal faces to sender and receiver indices. + + It does so by transforming every face into N_i edges. Such if the triangular + face has indices [0, 1, 2], three edges are added 0->1, 1->2, and 2->0. + + If all faces have consistent orientation, and the surface represented by the + faces is closed, then every edge in a polygon with a certain orientation + is also part of another polygon with the opposite orientation. In this + situation, the edges returned by the method are always bidirectional. + + Args: + faces: Integer array of shape [num_faces, 3]. Contains node indices + adjacent to each face. + Returns: + Tuple with sender/receiver indices, each of shape [num_edges=num_faces*3]. + + """ + assert faces.ndim == 2 + assert faces.shape[-1] == 3 + senders = np.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]]) + receivers = np.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]]) + return senders, receivers diff --git a/modulus/utils/graphcast/icospheres.py b/modulus/utils/graphcast/icospheres.py deleted file mode 100644 index 1bd6cda721..0000000000 --- a/modulus/utils/graphcast/icospheres.py +++ /dev/null @@ -1,63 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -import numpy as np - -try: - import pymesh -except ImportError: - Warning("pymesh is not installed. Please install it to use icosphere.") - -# TODO apply a transformation to make faces parallel to ploes - - -def generate_and_save_icospheres( - save_path: str = "icospheres.json", level: int = 6 -) -> None: # pragma: no cover - """enerate icospheres from level 0 to 6 (inclusive) and save them to a json file. - - Parameters - ---------- - path : str - Path to save the json file. - """ - radius = 1 - center = np.array((0, 0, 0)) - icospheres = {"vertices": [], "faces": []} - - # Generate icospheres from level 0 to 6 (inclusive) - for order in range(level + 1): - icosphere = pymesh.generate_icosphere(radius, center, refinement_order=order) - icospheres["order_" + str(order) + "_vertices"] = icosphere.vertices - icospheres["order_" + str(order) + "_faces"] = icosphere.faces - icosphere.add_attribute("face_centroid") - icospheres[ - "order_" + str(order) + "_face_centroid" - ] = icosphere.get_face_attribute("face_centroid") - - # save icosphere vertices and faces to a json file - icospheres_dict = { - key: (value.tolist() if isinstance(value, np.ndarray) else value) - for key, value in icospheres.items() - } - with open(save_path, "w") as f: - json.dump(icospheres_dict, f) - - -if __name__ == "__main__": - generate_and_save_icospheres(level=6) From b7c0c634b27aa3d42429822b033d8e2d7d762118 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Nabian Date: Tue, 21 May 2024 13:59:08 -0700 Subject: [PATCH 2/9] formatting --- examples/weather/graphcast/conf/config.yaml | 2 + examples/weather/graphcast/train_graphcast.py | 1 + modulus/models/gnn_layers/embedder.py | 3 +- modulus/models/gnn_layers/mesh_edge_block.py | 3 +- .../models/gnn_layers/mesh_graph_decoder.py | 3 +- .../models/gnn_layers/mesh_graph_encoder.py | 3 +- modulus/models/gnn_layers/mesh_graph_mlp.py | 49 +++++++++++-------- modulus/models/gnn_layers/mesh_node_block.py | 3 +- modulus/models/graphcast/graph_cast_net.py | 3 +- .../models/graphcast/graph_cast_processor.py | 3 +- 10 files changed, 46 insertions(+), 27 deletions(-) diff --git a/examples/weather/graphcast/conf/config.yaml b/examples/weather/graphcast/conf/config.yaml index 4bc446b6ff..dd15690eb2 100644 --- a/examples/weather/graphcast/conf/config.yaml +++ b/examples/weather/graphcast/conf/config.yaml @@ -24,6 +24,8 @@ processor_layers: 16 hidden_dim: 512 multimesh_level: 6 segments: 1 +norm_type: "TELayerNorm" + # "TELayerNorm" or "LayerNorm" force_single_checkpoint: False checkpoint_encoder: True checkpoint_processor: False diff --git a/examples/weather/graphcast/train_graphcast.py b/examples/weather/graphcast/train_graphcast.py index 9bac614691..6ad3222c23 100644 --- a/examples/weather/graphcast/train_graphcast.py +++ b/examples/weather/graphcast/train_graphcast.py @@ -96,6 +96,7 @@ def __init__(self, cfg: DictConfig, dist, rank_zero_logger): output_dim_grid_nodes=cfg.num_channels, processor_layers=cfg.processor_layers, hidden_dim=cfg.hidden_dim, + norm_type=cfg.norm_type, do_concat_trick=cfg.concat_trick, use_cugraphops_encoder=cfg.cugraphops_encoder, use_cugraphops_processor=cfg.cugraphops_processor, diff --git a/modulus/models/gnn_layers/embedder.py b/modulus/models/gnn_layers/embedder.py index bc0fd698aa..68c8d0afe8 100644 --- a/modulus/models/gnn_layers/embedder.py +++ b/modulus/models/gnn_layers/embedder.py @@ -139,7 +139,8 @@ class GraphCastDecoderEmbedder(nn.Module): activation_fn : nn.Module, optional Type of activation function, by default nn.SiLU() norm_type : str, optional - Normalization type, by default "LayerNorm" + Normalization type ["TELayerNorm", "LayerNorm"]. + Use "TELayerNorm" for optimal performance. By default "LayerNorm". recompute_activation : bool, optional Flag for recomputing activation in backward to save memory, by default False. Currently, only SiLU is supported. diff --git a/modulus/models/gnn_layers/mesh_edge_block.py b/modulus/models/gnn_layers/mesh_edge_block.py index d9f4c2c69b..c53b86f7a9 100644 --- a/modulus/models/gnn_layers/mesh_edge_block.py +++ b/modulus/models/gnn_layers/mesh_edge_block.py @@ -44,7 +44,8 @@ class MeshEdgeBlock(nn.Module): activation_fn : nn.Module, optional Type of activation function, by default nn.SiLU() norm_type : str, optional - normalization type, by default "LayerNorm" + Normalization type ["TELayerNorm", "LayerNorm"]. + Use "TELayerNorm" for optimal performance. By default "LayerNorm". do_conat_trick: : bool, default=False Whether to replace concat+MLP with MLP+idx+sum recompute_activation : bool, optional diff --git a/modulus/models/gnn_layers/mesh_graph_decoder.py b/modulus/models/gnn_layers/mesh_graph_decoder.py index 60d9adf273..f724ece0b2 100644 --- a/modulus/models/gnn_layers/mesh_graph_decoder.py +++ b/modulus/models/gnn_layers/mesh_graph_decoder.py @@ -52,7 +52,8 @@ class MeshGraphDecoder(nn.Module): activation_fn : nn.Module, optional Type of activation function, by default nn.SiLU() norm_type : str, optional - Normalization type, by default "LayerNorm" + Normalization type ["TELayerNorm", "LayerNorm"]. + Use "TELayerNorm" for optimal performance. By default "LayerNorm". do_conat_trick: : bool, default=False Whether to replace concat+MLP with MLP+idx+sum recompute_activation : bool, optional diff --git a/modulus/models/gnn_layers/mesh_graph_encoder.py b/modulus/models/gnn_layers/mesh_graph_encoder.py index 3cefcf8bc8..9ace3c5230 100644 --- a/modulus/models/gnn_layers/mesh_graph_encoder.py +++ b/modulus/models/gnn_layers/mesh_graph_encoder.py @@ -54,7 +54,8 @@ class MeshGraphEncoder(nn.Module): activation_fn : nn.Module, optional Type of activation function, by default nn.SiLU() norm_type : str, optional - Normalization type, by default "LayerNorm" + Normalization type ["TELayerNorm", "LayerNorm"]. + Use "TELayerNorm" for optimal performance. By default "LayerNorm". do_conat_trick: : bool, default=False Whether to replace concat+MLP with MLP+idx+sum recompute_activation : bool, optional diff --git a/modulus/models/gnn_layers/mesh_graph_mlp.py b/modulus/models/gnn_layers/mesh_graph_mlp.py index c9b4fd5f6d..a689ed8a8c 100644 --- a/modulus/models/gnn_layers/mesh_graph_mlp.py +++ b/modulus/models/gnn_layers/mesh_graph_mlp.py @@ -26,11 +26,11 @@ from .utils import CuGraphCSC, concat_efeat, sum_efeat try: - from apex.normalization import FusedLayerNorm + from transformer_engine import pytorch as te - apex_imported = True + te_imported = True except ImportError: - apex_imported = False + te_imported = False class CustomSiLuLinearAutogradFunction(torch.autograd.Function): @@ -118,7 +118,8 @@ class MeshGraphMLP(nn.Module): activation_fn : nn.Module, optional , by default nn.SiLU() norm_type : str, optional - normalization type, by default "LayerNorm" + Normalization type ["TELayerNorm", "LayerNorm"]. + Use "TELayerNorm" for optimal performance. By default "LayerNorm". recompute_activation : bool, optional Flag for recomputing recompute_activation in backward to save memory, by default False. Currently, only SiLU is supported. @@ -147,14 +148,17 @@ def __init__( if norm_type is not None: if norm_type not in [ "LayerNorm", - "GraphNorm", - "InstanceNorm", - "BatchNorm", - "MessageNorm", + "TELayerNorm", ]: - raise ValueError(norm_type) - if norm_type == "LayerNorm" and apex_imported: - norm_layer = FusedLayerNorm + raise ValueError( + f"Invalid norm type {norm_type}. Supported types are LayerNorm and TELayerNorm." + ) + if norm_type == "TELayerNorm" and te_imported: + norm_layer = te.LayerNorm + elif norm_type == "TELayerNorm" and not te_imported: + raise ValueError( + "TELayerNorm requires transformer-engine to be installed." + ) else: norm_layer = getattr(nn, norm_type) layers.append(norm_layer(output_dim)) @@ -223,7 +227,8 @@ class MeshGraphEdgeMLPConcat(MeshGraphMLP): activation_fn : nn.Module, optional type of activation function, by default nn.SiLU() norm_type : str, optional - normalization type, by default "LayerNorm" + Normalization type ["TELayerNorm", "LayerNorm"]. + Use "TELayerNorm" for optimal performance. By default "LayerNorm". bias : bool, optional whether to use bias in the MLP, by default True recompute_activation : bool, optional @@ -295,7 +300,8 @@ class MeshGraphEdgeMLPSum(nn.Module): activation_fn : nn.Module, optional type of activation function, by default nn.SiLU() norm_type : str, optional - normalization type, by default "LayerNorm" + Normalization type ["TELayerNorm", "LayerNorm"]. + Use "TELayerNorm" for optimal performance. By default "LayerNorm". bias : bool, optional whether to use bias in the MLP, by default True recompute_activation : bool, optional @@ -349,14 +355,17 @@ def __init__( if norm_type is not None: if norm_type not in [ "LayerNorm", - "GraphNorm", - "InstanceNorm", - "BatchNorm", - "MessageNorm", + "TELayerNorm", ]: - raise ValueError(norm_type) - if norm_type == "LayerNorm" and apex_imported: - norm_layer = FusedLayerNorm + raise ValueError( + f"Invalid norm type {norm_type}. Supported types are LayerNorm and TELayerNorm." + ) + if norm_type == "TELayerNorm" and te_imported: + norm_layer = te.LayerNorm + elif norm_type == "TELayerNorm" and not te_imported: + raise ValueError( + "TELayerNorm requires transformer-engine to be installed." + ) else: norm_layer = getattr(nn, norm_type) layers.append(norm_layer(output_dim)) diff --git a/modulus/models/gnn_layers/mesh_node_block.py b/modulus/models/gnn_layers/mesh_node_block.py index 2c93543673..db34839eff 100644 --- a/modulus/models/gnn_layers/mesh_node_block.py +++ b/modulus/models/gnn_layers/mesh_node_block.py @@ -46,7 +46,8 @@ class MeshNodeBlock(nn.Module): activation_fn : nn.Module, optional Type of activation function, by default nn.SiLU() norm_type : str, optional - Normalization type, by default "LayerNorm" + Normalization type ["TELayerNorm", "LayerNorm"]. + Use "TELayerNorm" for optimal performance. By default "LayerNorm". recompute_activation : bool, optional Flag for recomputing activation in backward to save memory, by default False. Currently, only SiLU is supported. diff --git a/modulus/models/graphcast/graph_cast_net.py b/modulus/models/graphcast/graph_cast_net.py index adb6e97d24..046adae29e 100644 --- a/modulus/models/graphcast/graph_cast_net.py +++ b/modulus/models/graphcast/graph_cast_net.py @@ -91,7 +91,8 @@ class GraphCastNet(Module): activation_fn : str, optional Type of activation function, by default "silu" norm_type : str, optional - Normalization type, by default "LayerNorm" + Normalization type ["TELayerNorm", "LayerNorm"]. + Use "TELayerNorm" for optimal performance. By default "LayerNorm". use_cugraphops_encoder : bool, default=False Flag to select cugraphops kernels in encoder use_cugraphops_processor : bool, default=False diff --git a/modulus/models/graphcast/graph_cast_processor.py b/modulus/models/graphcast/graph_cast_processor.py index 5253fd2cb3..6fb6f2776d 100644 --- a/modulus/models/graphcast/graph_cast_processor.py +++ b/modulus/models/graphcast/graph_cast_processor.py @@ -46,7 +46,8 @@ class GraphCastProcessor(nn.Module): activation_fn : nn.Module, optional type of activation function, by default nn.SiLU() norm_type : str, optional - normalization type, by default "LayerNorm" + Normalization type ["TELayerNorm", "LayerNorm"]. + Use "TELayerNorm" for optimal performance. By default "LayerNorm". do_conat_trick: : bool, default=False whether to replace concat+MLP with MLP+idx+sum recompute_activation : bool, optional From 1420f88eea690fbcc20969af572913873a35a2de Mon Sep 17 00:00:00 2001 From: Mohammad Amin Nabian Date: Tue, 21 May 2024 14:24:22 -0700 Subject: [PATCH 3/9] linting --- test/models/graphcast/icospheres.json | 1 - test/models/graphcast/test_concat_trick.py | 4 ++-- test/models/graphcast/test_cugraphops.py | 8 +++----- test/models/graphcast/test_grad_checkpointing.py | 5 ++--- test/models/graphcast/test_graphcast.py | 12 ++++++------ test/models/graphcast/test_graphcast_snmg.py | 2 +- test/models/graphcast/utils.py | 8 -------- 7 files changed, 14 insertions(+), 26 deletions(-) delete mode 100644 test/models/graphcast/icospheres.json diff --git a/test/models/graphcast/icospheres.json b/test/models/graphcast/icospheres.json deleted file mode 100644 index c75d841674..0000000000 --- a/test/models/graphcast/icospheres.json +++ /dev/null @@ -1 +0,0 @@ -{"vertices": [], "faces": [], "order_0_vertices": [[-0.5257311121191336, 0.85065080835204, 0.0], [0.5257311121191336, 0.85065080835204, 0.0], [-0.5257311121191336, -0.85065080835204, 0.0], [0.5257311121191336, -0.85065080835204, 0.0], [0.0, -0.5257311121191336, 0.85065080835204], [0.0, 0.5257311121191336, 0.85065080835204], [0.0, -0.5257311121191336, -0.85065080835204], [0.0, 0.5257311121191336, -0.85065080835204], [0.85065080835204, 0.0, -0.5257311121191336], [0.85065080835204, 0.0, 0.5257311121191336], [-0.85065080835204, 0.0, -0.5257311121191336], [-0.85065080835204, 0.0, 0.5257311121191336]], "order_0_faces": [[0, 11, 5], [0, 5, 1], [0, 1, 7], [0, 7, 10], [0, 10, 11], [1, 5, 9], [5, 11, 4], [11, 10, 2], [10, 7, 6], [7, 1, 8], [3, 9, 4], [3, 4, 2], [3, 2, 6], [3, 6, 8], [3, 8, 9], [5, 4, 9], [2, 4, 11], [6, 2, 10], [8, 6, 7], [9, 8, 1]], "order_0_face_centroid": [[-0.4587939734903912, 0.4587939734903912, 0.4587939734903912], [0.0, 0.7423442429410713, 0.28355026945068], [0.0, 0.7423442429410713, -0.28355026945068], [-0.4587939734903912, 0.4587939734903912, -0.4587939734903912], [-0.7423442429410713, 0.28355026945068, 0.0], [0.4587939734903912, 0.4587939734903912, 0.4587939734903912], [-0.28355026945068, 0.0, 0.7423442429410713], [-0.7423442429410713, -0.28355026945068, 0.0], [-0.28355026945068, 0.0, -0.7423442429410713], [0.4587939734903912, 0.4587939734903912, -0.4587939734903912], [0.4587939734903912, -0.4587939734903912, 0.4587939734903912], [0.0, -0.7423442429410713, 0.28355026945068], [0.0, -0.7423442429410713, -0.28355026945068], [0.4587939734903912, -0.4587939734903912, -0.4587939734903912], [0.7423442429410713, -0.28355026945068, 0.0], [0.28355026945068, 0.0, 0.7423442429410713], [-0.4587939734903912, -0.4587939734903912, 0.4587939734903912], [-0.4587939734903912, -0.4587939734903912, -0.4587939734903912], [0.28355026945068, 0.0, -0.7423442429410713], [0.7423442429410713, 0.28355026945068, 0.0]], "order_1_vertices": [[-0.5257311121191336, 0.85065080835204, 0.0], [0.5257311121191336, 0.85065080835204, 0.0], [-0.5257311121191336, -0.85065080835204, 0.0], [0.5257311121191336, -0.85065080835204, 0.0], [0.0, -0.5257311121191336, 0.85065080835204], [0.0, 0.5257311121191336, 0.85065080835204], [0.0, -0.5257311121191336, -0.85065080835204], [0.0, 0.5257311121191336, -0.85065080835204], [0.85065080835204, 0.0, -0.5257311121191336], [0.85065080835204, 0.0, 0.5257311121191336], [-0.85065080835204, 0.0, -0.5257311121191336], [-0.85065080835204, 0.0, 0.5257311121191336], [-0.8090169943749475, 0.5, 0.3090169943749474], [-0.5, 0.3090169943749474, 0.8090169943749475], [-0.3090169943749474, 0.8090169943749475, 0.5], [0.3090169943749474, 0.8090169943749475, 0.5], [0.0, 1.0, 0.0], [0.3090169943749474, 0.8090169943749475, -0.5], [-0.3090169943749474, 0.8090169943749475, -0.5], [-0.5, 0.3090169943749474, -0.8090169943749475], [-0.8090169943749475, 0.5, -0.3090169943749474], [-1.0, 0.0, 0.0], [0.5, 0.3090169943749474, 0.8090169943749475], [0.8090169943749475, 0.5, 0.3090169943749474], [-0.5, -0.3090169943749474, 0.8090169943749475], [0.0, 0.0, 1.0], [-0.8090169943749475, -0.5, -0.3090169943749474], [-0.8090169943749475, -0.5, 0.3090169943749474], [0.0, 0.0, -1.0], [-0.5, -0.3090169943749474, -0.8090169943749475], [0.8090169943749475, 0.5, -0.3090169943749474], [0.5, 0.3090169943749474, -0.8090169943749475], [0.8090169943749475, -0.5, 0.3090169943749474], [0.5, -0.3090169943749474, 0.8090169943749475], [0.3090169943749474, -0.8090169943749475, 0.5], [-0.3090169943749474, -0.8090169943749475, 0.5], [0.0, -1.0, 0.0], [-0.3090169943749474, -0.8090169943749475, -0.5], [0.3090169943749474, -0.8090169943749475, -0.5], [0.5, -0.3090169943749474, -0.8090169943749475], [0.8090169943749475, -0.5, -0.3090169943749474], [1.0, 0.0, 0.0]], "order_1_faces": [[0, 12, 14], [11, 13, 12], [5, 14, 13], [12, 13, 14], [0, 14, 16], [5, 15, 14], [1, 16, 15], [14, 15, 16], [0, 16, 18], [1, 17, 16], [7, 18, 17], [16, 17, 18], [0, 18, 20], [7, 19, 18], [10, 20, 19], [18, 19, 20], [0, 20, 12], [10, 21, 20], [11, 12, 21], [20, 21, 12], [1, 15, 23], [5, 22, 15], [9, 23, 22], [15, 22, 23], [5, 13, 25], [11, 24, 13], [4, 25, 24], [13, 24, 25], [11, 21, 27], [10, 26, 21], [2, 27, 26], [21, 26, 27], [10, 19, 29], [7, 28, 19], [6, 29, 28], [19, 28, 29], [7, 17, 31], [1, 30, 17], [8, 31, 30], [17, 30, 31], [3, 32, 34], [9, 33, 32], [4, 34, 33], [32, 33, 34], [3, 34, 36], [4, 35, 34], [2, 36, 35], [34, 35, 36], [3, 36, 38], [2, 37, 36], [6, 38, 37], [36, 37, 38], [3, 38, 40], [6, 39, 38], [8, 40, 39], [38, 39, 40], [3, 40, 32], [8, 41, 40], [9, 32, 41], [40, 41, 32], [5, 25, 22], [4, 33, 25], [9, 22, 33], [25, 33, 22], [2, 35, 27], [4, 24, 35], [11, 27, 24], [35, 24, 27], [6, 37, 29], [2, 26, 37], [10, 29, 26], [37, 26, 29], [8, 39, 31], [6, 28, 39], [7, 31, 28], [39, 28, 31], [9, 41, 23], [8, 30, 41], [1, 23, 30], [41, 30, 23]], "order_1_face_centroid": [[-0.5479217002896761, 0.7198892675756624, 0.2696723314583158], [-0.7198892675756624, 0.2696723314583158, 0.5479217002896761], [-0.2696723314583158, 0.5479217002896761, 0.7198892675756624], [-0.5393446629166316, 0.5393446629166316, 0.5393446629166316], [-0.2782493688313603, 0.8865559342423291, 0.16666666666666666], [0.0, 0.7145883669563428, 0.6168836027840133], [0.2782493688313603, 0.8865559342423291, 0.16666666666666666], [0.0, 0.872677996249965, 0.3333333333333333], [-0.2782493688313603, 0.8865559342423291, -0.16666666666666666], [0.2782493688313603, 0.8865559342423291, -0.16666666666666666], [0.0, 0.7145883669563428, -0.6168836027840133], [0.0, 0.872677996249965, -0.3333333333333333], [-0.5479217002896761, 0.7198892675756624, -0.2696723314583158], [-0.2696723314583158, 0.5479217002896761, -0.7198892675756624], [-0.7198892675756624, 0.2696723314583158, -0.5479217002896761], [-0.5393446629166316, 0.5393446629166316, -0.5393446629166316], [-0.7145883669563428, 0.6168836027840133, 0.0], [-0.8865559342423291, 0.16666666666666666, -0.2782493688313603], [-0.8865559342423291, 0.16666666666666666, 0.2782493688313603], [-0.872677996249965, 0.3333333333333333, 0.0], [0.5479217002896761, 0.7198892675756624, 0.2696723314583158], [0.2696723314583158, 0.5479217002896761, 0.7198892675756624], [0.7198892675756624, 0.2696723314583158, 0.5479217002896761], [0.5393446629166316, 0.5393446629166316, 0.5393446629166316], [-0.16666666666666666, 0.2782493688313603, 0.8865559342423291], [-0.6168836027840133, 0.0, 0.7145883669563428], [-0.16666666666666666, -0.2782493688313603, 0.8865559342423291], [-0.3333333333333333, 0.0, 0.872677996249965], [-0.8865559342423291, -0.16666666666666666, 0.2782493688313603], [-0.8865559342423291, -0.16666666666666666, -0.2782493688313603], [-0.7145883669563428, -0.6168836027840133, 0.0], [-0.872677996249965, -0.3333333333333333, 0.0], [-0.6168836027840133, 0.0, -0.7145883669563428], [-0.16666666666666666, 0.2782493688313603, -0.8865559342423291], [-0.16666666666666666, -0.2782493688313603, -0.8865559342423291], [-0.3333333333333333, 0.0, -0.872677996249965], [0.2696723314583158, 0.5479217002896761, -0.7198892675756624], [0.5479217002896761, 0.7198892675756624, -0.2696723314583158], [0.7198892675756624, 0.2696723314583158, -0.5479217002896761], [0.5393446629166316, 0.5393446629166316, -0.5393446629166316], [0.5479217002896761, -0.7198892675756624, 0.2696723314583158], [0.7198892675756624, -0.2696723314583158, 0.5479217002896761], [0.2696723314583158, -0.5479217002896761, 0.7198892675756624], [0.5393446629166316, -0.5393446629166316, 0.5393446629166316], [0.2782493688313603, -0.8865559342423291, 0.16666666666666666], [0.0, -0.7145883669563428, 0.6168836027840133], [-0.2782493688313603, -0.8865559342423291, 0.16666666666666666], [0.0, -0.872677996249965, 0.3333333333333333], [0.2782493688313603, -0.8865559342423291, -0.16666666666666666], [-0.2782493688313603, -0.8865559342423291, -0.16666666666666666], [0.0, -0.7145883669563428, -0.6168836027840133], [0.0, -0.872677996249965, -0.3333333333333333], [0.5479217002896761, -0.7198892675756624, -0.2696723314583158], [0.2696723314583158, -0.5479217002896761, -0.7198892675756624], [0.7198892675756624, -0.2696723314583158, -0.5479217002896761], [0.5393446629166316, -0.5393446629166316, -0.5393446629166316], [0.7145883669563428, -0.6168836027840133, 0.0], [0.8865559342423291, -0.16666666666666666, -0.2782493688313603], [0.8865559342423291, -0.16666666666666666, 0.2782493688313603], [0.872677996249965, -0.3333333333333333, 0.0], [0.16666666666666666, 0.2782493688313603, 0.8865559342423291], [0.16666666666666666, -0.2782493688313603, 0.8865559342423291], [0.6168836027840133, 0.0, 0.7145883669563428], [0.3333333333333333, 0.0, 0.872677996249965], [-0.5479217002896761, -0.7198892675756624, 0.2696723314583158], [-0.2696723314583158, -0.5479217002896761, 0.7198892675756624], [-0.7198892675756624, -0.2696723314583158, 0.5479217002896761], [-0.5393446629166316, -0.5393446629166316, 0.5393446629166316], [-0.2696723314583158, -0.5479217002896761, -0.7198892675756624], [-0.5479217002896761, -0.7198892675756624, -0.2696723314583158], [-0.7198892675756624, -0.2696723314583158, -0.5479217002896761], [-0.5393446629166316, -0.5393446629166316, -0.5393446629166316], [0.6168836027840133, 0.0, -0.7145883669563428], [0.16666666666666666, -0.2782493688313603, -0.8865559342423291], [0.16666666666666666, 0.2782493688313603, -0.8865559342423291], [0.3333333333333333, 0.0, -0.872677996249965], [0.8865559342423291, 0.16666666666666666, 0.2782493688313603], [0.8865559342423291, 0.16666666666666666, -0.2782493688313603], [0.7145883669563428, 0.6168836027840133, 0.0], [0.872677996249965, 0.3333333333333333, 0.0]]} \ No newline at end of file diff --git a/test/models/graphcast/test_concat_trick.py b/test/models/graphcast/test_concat_trick.py index 3279f16dd3..a724dbd40f 100644 --- a/test/models/graphcast/test_concat_trick.py +++ b/test/models/graphcast/test_concat_trick.py @@ -54,8 +54,8 @@ def test_concat_trick(pytestconfig, recomp_act, num_channels=2, res_h=11, res_w= # Instantiate the model model = GraphCastNet( - meshgraph_path=icosphere_path, static_dataset_path=None, + multimesh_level=1, input_res=(res_h, res_w), input_dim_grid_nodes=num_channels, input_dim_mesh_nodes=3, @@ -72,8 +72,8 @@ def test_concat_trick(pytestconfig, recomp_act, num_channels=2, res_h=11, res_w= # Instantiate the model with concat trick enabled model_ct = GraphCastNet( - meshgraph_path=icosphere_path, static_dataset_path=None, + multimesh_level=1, input_res=(res_h, res_w), input_dim_grid_nodes=num_channels, input_dim_mesh_nodes=3, diff --git a/test/models/graphcast/test_cugraphops.py b/test/models/graphcast/test_cugraphops.py index 96d4d78d5b..aa382fdacd 100644 --- a/test/models/graphcast/test_cugraphops.py +++ b/test/models/graphcast/test_cugraphops.py @@ -24,7 +24,7 @@ import pytest # noqa: E402 import torch # noqa: E402 from pytest_utils import import_or_fail # noqa: E402 -from utils import fix_random_seeds, get_icosphere_path # noqa: E402 +from utils import fix_random_seeds # noqa: E402 @import_or_fail("dgl") @@ -34,8 +34,6 @@ def test_cugraphops( pytestconfig, recomp_act, concat_trick, num_channels=2, res_h=21, res_w=10 ): """Test cugraphops""" - icosphere_path = get_icosphere_path() - from modulus.models.graphcast.graph_cast_net import GraphCastNet if recomp_act and not common.utils.is_fusion_available("FusionDefinition"): @@ -54,8 +52,8 @@ def test_cugraphops( np.random.seed(0) model = GraphCastNet( - meshgraph_path=icosphere_path, static_dataset_path=None, + multimesh_level=1, input_res=(res_h, res_w), input_dim_grid_nodes=num_channels, input_dim_mesh_nodes=3, @@ -74,8 +72,8 @@ def test_cugraphops( fix_random_seeds() model_dgl = GraphCastNet( - meshgraph_path=icosphere_path, static_dataset_path=None, + multimesh_level=1, input_res=(res_h, res_w), input_dim_grid_nodes=num_channels, input_dim_mesh_nodes=3, diff --git a/test/models/graphcast/test_grad_checkpointing.py b/test/models/graphcast/test_grad_checkpointing.py index d6c0d693d2..4528f1993b 100644 --- a/test/models/graphcast/test_grad_checkpointing.py +++ b/test/models/graphcast/test_grad_checkpointing.py @@ -17,21 +17,20 @@ import pytest import torch from pytest_utils import import_or_fail -from utils import create_random_input, fix_random_seeds, get_icosphere_path +from utils import create_random_input, fix_random_seeds @import_or_fail("dgl") @pytest.mark.parametrize("device", ["cuda:0", "cpu"]) def test_grad_checkpointing(device, pytestconfig, num_channels=2, res_h=15, res_w=15): """Test gradient checkpointing""" - icosphere_path = get_icosphere_path() from modulus.models.graphcast.graph_cast_net import GraphCastNet # constants model_kwds = { - "meshgraph_path": icosphere_path, "static_dataset_path": None, + "multimesh_level": 2, "input_res": (res_h, res_w), "input_dim_grid_nodes": num_channels, "input_dim_mesh_nodes": 3, diff --git a/test/models/graphcast/test_graphcast.py b/test/models/graphcast/test_graphcast.py index 0061cacd1d..c76ea3ebfe 100644 --- a/test/models/graphcast/test_graphcast.py +++ b/test/models/graphcast/test_graphcast.py @@ -36,8 +36,8 @@ def test_graphcast_forward(device, pytestconfig, num_channels=2, res_h=10, res_w from modulus.models.graphcast.graph_cast_net import GraphCastNet model_kwds = { - "meshgraph_path": icosphere_path, "static_dataset_path": None, + "multimesh_level": 1, "input_res": (res_h, res_w), "input_dim_grid_nodes": num_channels, "input_dim_mesh_nodes": 3, @@ -70,8 +70,8 @@ def test_graphcast_constructor( # Define dictionary of constructor args arg_list = [ { - "meshgraph_path": icosphere_path, "static_dataset_path": None, + "multimesh_level": 1, "input_res": (res_h, res_w), "input_dim_grid_nodes": num_channels_1, "input_dim_mesh_nodes": 3, @@ -82,8 +82,8 @@ def test_graphcast_constructor( "do_concat_trick": True, }, { - "meshgraph_path": icosphere_path, "static_dataset_path": None, + "multimesh_level": 1, "input_res": (res_h, res_w), "input_dim_grid_nodes": num_channels_2, "input_dim_mesh_nodes": 3, @@ -119,8 +119,8 @@ def test_GraphCast_optims(device, pytestconfig, num_channels=2, res_h=10, res_w= def setup_model(): """Set up fresh model and inputs for each optim test""" model_kwds = { - "meshgraph_path": icosphere_path, "static_dataset_path": None, + "multimesh_level": 1, "input_res": (res_h, res_w), "input_dim_grid_nodes": num_channels, "input_dim_mesh_nodes": 3, @@ -162,8 +162,8 @@ def test_graphcast_checkpoint(device, pytestconfig, num_channels=2, res_h=10, re from modulus.models.graphcast.graph_cast_net import GraphCastNet model_kwds = { - "meshgraph_path": icosphere_path, "static_dataset_path": None, + "multimesh_level": 1, "input_res": (res_h, res_w), "input_dim_grid_nodes": num_channels, "input_dim_mesh_nodes": 3, @@ -197,8 +197,8 @@ def test_GraphCast_deploy(device, pytestconfig, num_channels=2, res_h=10, res_w= from modulus.models.graphcast.graph_cast_net import GraphCastNet model_kwds = { - "meshgraph_path": icosphere_path, "static_dataset_path": None, + "multimesh_level": 1, "input_res": (res_h, res_w), "input_dim_grid_nodes": num_channels, "input_dim_mesh_nodes": 3, diff --git a/test/models/graphcast/test_graphcast_snmg.py b/test/models/graphcast/test_graphcast_snmg.py index 82c2806d4a..ce59f070df 100644 --- a/test/models/graphcast/test_graphcast_snmg.py +++ b/test/models/graphcast/test_graphcast_snmg.py @@ -57,8 +57,8 @@ def run_test_distributed_graphcast( res_w = 32 model_kwds = { - "meshgraph_path": icosphere_path, "static_dataset_path": None, + "multimesh_level": 2, "input_res": (res_h, res_w), "input_dim_grid_nodes": 34, "input_dim_mesh_nodes": 3, diff --git a/test/models/graphcast/utils.py b/test/models/graphcast/utils.py index 8ac66c69f4..b64cc0bad4 100644 --- a/test/models/graphcast/utils.py +++ b/test/models/graphcast/utils.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os from typing import List import numpy as np @@ -37,13 +36,6 @@ def create_random_input(input_res, dim): return torch.randn(1, dim, *input_res) -def get_icosphere_path(): - """Get path to icosphere mesh""" - script_path = os.path.abspath(__file__) - icosphere_path = os.path.join(os.path.dirname(script_path), "icospheres.json") - return icosphere_path - - def compare_quantiles( t: torch.Tensor, ref: torch.Tensor, quantiles: List[float], tolerances: List[float] ): From b7c6ef17fd3c0101b918e1654d3d84cbc0adeeec Mon Sep 17 00:00:00 2001 From: Mohammad Amin Nabian Date: Tue, 21 May 2024 15:06:07 -0700 Subject: [PATCH 4/9] fix tests --- test/models/graphcast/test_concat_trick.py | 4 +--- test/models/graphcast/test_graphcast.py | 4 +--- test/models/graphcast/test_graphcast_snmg.py | 4 +--- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/test/models/graphcast/test_concat_trick.py b/test/models/graphcast/test_concat_trick.py index a724dbd40f..65b28ccac4 100644 --- a/test/models/graphcast/test_concat_trick.py +++ b/test/models/graphcast/test_concat_trick.py @@ -24,9 +24,7 @@ import pytest # noqa: E402 import torch # noqa: E402 from pytest_utils import import_or_fail # noqa: E402 -from utils import fix_random_seeds, get_icosphere_path # noqa: E402 - -icosphere_path = get_icosphere_path() +from utils import fix_random_seeds # noqa: E402 @import_or_fail("dgl") diff --git a/test/models/graphcast/test_graphcast.py b/test/models/graphcast/test_graphcast.py index c76ea3ebfe..28650b48d9 100644 --- a/test/models/graphcast/test_graphcast.py +++ b/test/models/graphcast/test_graphcast.py @@ -22,11 +22,9 @@ import common import pytest -from graphcast.utils import create_random_input, fix_random_seeds, get_icosphere_path +from graphcast.utils import create_random_input, fix_random_seeds from pytest_utils import import_or_fail -icosphere_path = get_icosphere_path() - @import_or_fail("dgl") @pytest.mark.parametrize("device", ["cuda:0", "cpu"]) diff --git a/test/models/graphcast/test_graphcast_snmg.py b/test/models/graphcast/test_graphcast_snmg.py index ce59f070df..f97aec98ed 100644 --- a/test/models/graphcast/test_graphcast_snmg.py +++ b/test/models/graphcast/test_graphcast_snmg.py @@ -22,15 +22,13 @@ import pytest import torch -from graphcast.utils import create_random_input, fix_random_seeds, get_icosphere_path +from graphcast.utils import create_random_input, fix_random_seeds from pytest_utils import import_or_fail from torch.nn.parallel import DistributedDataParallel from modulus.distributed import DistributedManager from modulus.models.graphcast.graph_cast_net import GraphCastNet -icosphere_path = get_icosphere_path() - def run_test_distributed_graphcast( rank: int, From 08c3edc3a74a4a9a29de759141c8319d323c9ef1 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Nabian Date: Wed, 22 May 2024 10:11:52 -0700 Subject: [PATCH 5/9] fix tests --- .../weather/graphcast/conf/config_small.yaml | 25 ++++++++++--------- examples/weather/graphcast/train_graphcast.py | 14 +++++++++-- modulus/utils/graphcast/data_utils.py | 18 +++++++------ test/models/test_fcn_mip_plugin.py | 3 +-- 4 files changed, 36 insertions(+), 24 deletions(-) diff --git a/examples/weather/graphcast/conf/config_small.yaml b/examples/weather/graphcast/conf/config_small.yaml index 8282bfc4df..55ea50019e 100644 --- a/examples/weather/graphcast/conf/config_small.yaml +++ b/examples/weather/graphcast/conf/config_small.yaml @@ -24,26 +24,27 @@ processor_layers: 16 hidden_dim: 512 multimesh_level: 5 segments: 1 +norm_type: "TELayerNorm" force_single_checkpoint: False checkpoint_encoder: False checkpoint_processor: False checkpoint_decoder: False force_single_checkpoint_finetune: False -checkpoint_encoder_finetune: True -checkpoint_processor_finetune: True -checkpoint_decoder_finetune: True -concat_trick: True -cugraphops_encoder: True -cugraphops_processor: True -cugraphops_decoder: True -recompute_activation: True -wb_mode: "disabled" +checkpoint_encoder_finetune: False +checkpoint_processor_finetune: False +checkpoint_decoder_finetune: False +concat_trick: False +cugraphops_encoder: False +cugraphops_processor: False +cugraphops_decoder: False +recompute_activation: False +wb_mode: "online" synthetic_dataset: false -dataset_path: "/data" -static_dataset_path: null +dataset_path: "/data/era5_73var" #"/code/datasets/era5_73var" +static_dataset_path: "/code/static" #"/code/mnabian/static" latlon_res: [181, 360] num_samples_per_year_train: 1408 -num_workers: 0 # 8 +num_workers: 8 num_channels: 73 num_channels_val: 3 num_val_steps: 8 diff --git a/examples/weather/graphcast/train_graphcast.py b/examples/weather/graphcast/train_graphcast.py index 6ad3222c23..e131f88932 100644 --- a/examples/weather/graphcast/train_graphcast.py +++ b/examples/weather/graphcast/train_graphcast.py @@ -184,11 +184,17 @@ def __init__(self, cfg: DictConfig, dist, rank_zero_logger): self.criterion = CellAreaWeightedLossFunction(self.area) try: self.optimizer = apex.optimizers.FusedAdam( - self.model.parameters(), lr=cfg.lr, betas=(0.9, 0.95), weight_decay=0.1 + self.model.parameters(), + lr=cfg.lr, + betas=(0.9, 0.95), + adam_w_mode=True, + weight_decay=0.1, ) rank_zero_logger.info("Using FusedAdam optimizer") except: - self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.lr) + self.optimizer = torch.optim.AdamW( + self.model.parameters(), lr=cfg.lr, betas=(0.9, 0.95), weight_decay=0.1 + ) scheduler1 = LinearLR( self.optimizer, start_factor=1e-3, @@ -250,11 +256,15 @@ def main(cfg: DictConfig) -> None: entity="Modulus", name="GraphCast-Training", group="GraphCast-DDP-Group", + mode=cfg.wb_mode, ) # Wandb logger logger = PythonLogger("main") # General python logger rank_zero_logger = RankZeroLoggingWrapper(logger, dist) # Rank 0 logger rank_zero_logger.file_logging() + # print ranks and devices + logger.info(f"Rank: {dist.rank}, Device: {dist.device}") + # specify the datapipe if cfg.synthetic_dataset: DataPipe = SyntheticWeatherDataLoader diff --git a/modulus/utils/graphcast/data_utils.py b/modulus/utils/graphcast/data_utils.py index b7b46b4ea6..082d05a24b 100644 --- a/modulus/utils/graphcast/data_utils.py +++ b/modulus/utils/graphcast/data_utils.py @@ -17,9 +17,9 @@ import os import netCDF4 as nc -import numpy as np import torch from torch import Tensor +from torch.nn.functional import interpolate from .graph_utils import deg2rad @@ -57,9 +57,10 @@ def get_lsm(self) -> Tensor: # pragma: no cover Tensor Land-sea mask with shape (1, 1, lat, lon). """ - ds = nc.Dataset(self.lsm_path) - lsm = np.expand_dims(ds["lsm"], axis=0) - return torch.tensor(lsm, dtype=torch.float32) + ds = torch.tensor(nc.Dataset(self.lsm_path)["lsm"], dtype=torch.float32) + ds = torch.unsqueeze(ds, dim=0) + ds = interpolate(ds, size=(self.lat.size(0), self.lon.size(0)), mode="bilinear") + return ds def get_geop(self, normalize: bool = True) -> Tensor: # pragma: no cover """Get geopotential from netCDF file. @@ -74,11 +75,12 @@ def get_geop(self, normalize: bool = True) -> Tensor: # pragma: no cover Tensor Normalized geopotential with shape (1, 1, lat, lon). """ - ds = nc.Dataset(self.geop_path) - geop = np.expand_dims(ds["z"], axis=0) + ds = torch.tensor(nc.Dataset(self.geop_path)["z"], dtype=torch.float32) + ds = torch.unsqueeze(ds, dim=0) + ds = interpolate(ds, size=(self.lat.size(0), self.lon.size(0)), mode="bilinear") if normalize: - geop = (geop - geop.mean()) / geop.std() - return torch.tensor(geop, dtype=torch.float32) + ds = (ds - ds.mean()) / ds.std() + return ds def get_lat_lon(self) -> Tensor: # pragma: no cover """Computes cosine of latitudes and sine and cosine of longitudes. diff --git a/test/models/test_fcn_mip_plugin.py b/test/models/test_fcn_mip_plugin.py index bba2491d2c..75ca77b802 100644 --- a/test/models/test_fcn_mip_plugin.py +++ b/test/models/test_fcn_mip_plugin.py @@ -179,10 +179,9 @@ def test_dlwp(tmp_path, batch_size, device, pytestconfig): def save_untrained_graphcast(path): """Function to save untrained GraphCast""" - icosphere_path = path / "icospheres.json" config = { - "meshgraph_path": icosphere_path.as_posix(), "static_dataset_path": None, + "multimesh_level": 1, "input_dim_grid_nodes": 2, "input_dim_mesh_nodes": 3, "input_dim_edges": 4, From 2e0969b5791a9aeb58ce5737dddedfed3d72260c Mon Sep 17 00:00:00 2001 From: Mohammad Amin Nabian Date: Wed, 22 May 2024 10:36:29 -0700 Subject: [PATCH 6/9] add requirements.txt --- examples/weather/graphcast/requirements.txt | 1 + 1 file changed, 1 insertion(+) create mode 100644 examples/weather/graphcast/requirements.txt diff --git a/examples/weather/graphcast/requirements.txt b/examples/weather/graphcast/requirements.txt new file mode 100644 index 0000000000..eb8fad4e56 --- /dev/null +++ b/examples/weather/graphcast/requirements.txt @@ -0,0 +1 @@ +git+https://github.com/NVIDIA/TransformerEngine.git@stable From 94f6faf382bca11148a1055d331a52130f8065fa Mon Sep 17 00:00:00 2001 From: Mohammad Amin Nabian Date: Wed, 22 May 2024 11:26:26 -0700 Subject: [PATCH 7/9] exclude graphcast from fcn-mip-plugin tests --- test/models/test_fcn_mip_plugin.py | 59 ------------------------------ 1 file changed, 59 deletions(-) diff --git a/test/models/test_fcn_mip_plugin.py b/test/models/test_fcn_mip_plugin.py index 75ca77b802..490bd67931 100644 --- a/test/models/test_fcn_mip_plugin.py +++ b/test/models/test_fcn_mip_plugin.py @@ -36,14 +36,6 @@ def dlwp_data_dir(): return path -@pytest.fixture -def graphcast_data_dir(): - """Data dir for graphcast package""" - - path = "/data/nfs/modulus-data/plugin_data/graphcast/" - return path - - def _copy_directory(src, dst): if not os.path.exists(dst): os.makedirs(dst) @@ -176,57 +168,6 @@ def test_dlwp(tmp_path, batch_size, device, pytestconfig): assert out.shape == x.shape -def save_untrained_graphcast(path): - """Function to save untrained GraphCast""" - - config = { - "static_dataset_path": None, - "multimesh_level": 1, - "input_dim_grid_nodes": 2, - "input_dim_mesh_nodes": 3, - "input_dim_edges": 4, - "output_dim_grid_nodes": 2, - "processor_layers": 3, - "hidden_dim": 2, - "do_concat_trick": True, - } - - from modulus.models.graphcast import GraphCastNet - - model = GraphCastNet(**config) - - config_path = path / "config.json" - with config_path.open("w") as f: - json.dump(config, f) - - check_point_path = path / "weights.tar" - save_ddp_checkpoint(model, check_point_path, del_device_buffer=False) - - url = f"file://{path.as_posix()}" - package = Package(url, seperator="/") - return package - - -@nfsdata_or_fail -@import_or_fail(["dgl", "ruamel.yaml", "tensorly", "torch_harmonics", "tltorch"]) -def test_graphcast(tmp_path, graphcast_data_dir, pytestconfig): - """Test GraphCast plugin""" - - from modulus.models.fcn_mip_plugin import graphcast_34ch - - source_dir = graphcast_data_dir - _copy_directory(source_dir, tmp_path) - - package = save_untrained_graphcast( - tmp_path - ) # here package needs to load after icosphere.json is copied. - model = graphcast_34ch(package, pretrained=False) - x = torch.randn(1, 34, 721, 1440).to("cuda") - with torch.no_grad(): - out = model(x) - assert out.shape == x.shape - - @nfsdata_or_fail @import_or_fail(["dgl", "ruamel.yaml", "tensorly", "torch_harmonics", "tltorch"]) @pytest.mark.parametrize("batch_size", [1, 2]) From 6d3d4393057466a0187fbebe1e6cfb94195e9968 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Nabian Date: Wed, 22 May 2024 12:22:58 -0700 Subject: [PATCH 8/9] taking static dataset outside of model definition --- examples/weather/graphcast/conf/config.yaml | 2 +- .../weather/graphcast/conf/config_small.yaml | 3 +- examples/weather/graphcast/train_graphcast.py | 50 +++++++++++++++++-- modulus/models/graphcast/graph_cast_net.py | 33 +----------- test/models/graphcast/test_concat_trick.py | 2 - test/models/graphcast/test_cugraphops.py | 2 - test/models/graphcast/test_graphcast.py | 6 --- test/models/graphcast/test_graphcast_snmg.py | 1 - 8 files changed, 50 insertions(+), 49 deletions(-) diff --git a/examples/weather/graphcast/conf/config.yaml b/examples/weather/graphcast/conf/config.yaml index dd15690eb2..0e6227d09d 100644 --- a/examples/weather/graphcast/conf/config.yaml +++ b/examples/weather/graphcast/conf/config.yaml @@ -39,7 +39,7 @@ cugraphops_encoder: True cugraphops_processor: True cugraphops_decoder: True recompute_activation: True -wb_mode: "disabled" +wb_mode: "online" synthetic_dataset: false dataset_path: "/data" static_dataset_path: null diff --git a/examples/weather/graphcast/conf/config_small.yaml b/examples/weather/graphcast/conf/config_small.yaml index 55ea50019e..0f6d54b8ae 100644 --- a/examples/weather/graphcast/conf/config_small.yaml +++ b/examples/weather/graphcast/conf/config_small.yaml @@ -45,7 +45,8 @@ static_dataset_path: "/code/static" #"/code/mnabian/static" latlon_res: [181, 360] num_samples_per_year_train: 1408 num_workers: 8 -num_channels: 73 +num_channels_climate: 73 +num_channels_static: 1 num_channels_val: 3 num_val_steps: 8 num_val_spy: 3 # SPY: Samples Per Year diff --git a/examples/weather/graphcast/train_graphcast.py b/examples/weather/graphcast/train_graphcast.py index e131f88932..c4c532fd04 100644 --- a/examples/weather/graphcast/train_graphcast.py +++ b/examples/weather/graphcast/train_graphcast.py @@ -43,6 +43,7 @@ from validation import Validation from modulus.datapipes.climate import ERA5HDF5Datapipe, SyntheticWeatherDataLoader from modulus.distributed import DistributedManager +from modulus.utils.graphcast.data_utils import StaticData import hydra from hydra.utils import to_absolute_path @@ -85,15 +86,22 @@ def __init__(self, cfg: DictConfig, dist, rank_zero_logger): else: raise ValueError("Invalid dtype for config amp") + # Handle the number of static channels + if not self.static_dataset_path: + cfg.num_channels_static = 0 + rank_zero_logger.warning( + "Static dataset path is not provided. Setting num_channels_static to 0." + ) + # instantiate the model self.model = GraphCastNet( static_dataset_path=static_dataset_path, multimesh_level=cfg.multimesh_level, input_res=tuple(cfg.latlon_res), - input_dim_grid_nodes=cfg.num_channels, + input_dim_grid_nodes=cfg.num_channels_climate + cfg.num_channels_static, input_dim_mesh_nodes=3, input_dim_edges=4, - output_dim_grid_nodes=cfg.num_channels, + output_dim_grid_nodes=cfg.num_channels_climate, processor_layers=cfg.processor_layers, hidden_dim=cfg.hidden_dim, norm_type=cfg.norm_type, @@ -148,7 +156,7 @@ def __init__(self, cfg: DictConfig, dist, rank_zero_logger): self.datapipe = DataPipe( data_dir=to_absolute_path(os.path.join(cfg.dataset_path, "train")), stats_dir=to_absolute_path(os.path.join(cfg.dataset_path, "stats")), - channels=[i for i in range(cfg.num_channels)], + channels=[i for i in range(cfg.num_channels_climate)], interpolation_shape=self.interpolation_shape, num_samples_per_year=cfg.num_samples_per_year_train, num_steps=1, @@ -226,6 +234,36 @@ def __init__(self, cfg: DictConfig, dist, rank_zero_logger): device=dist.device, ) + # Get the static data + if self.static_dataset_path: + self.static_data = StaticData( + static_dataset_path, self.latitudes, self.longitudes + ).get() + self.static_data = self.static_data.to(dtype=self.dtype).to( + device=dist.device + ) + assert cfg.num_channels_static == self.static_data.size(1), ( + f"Number of static channels in model ({cfg.num_channels_static}) " + + f"does not match the static data ({self.static_data.size(1)})" + ) + if ( + self.model.is_distributed and self.model.expect_partitioned_input + ): # TODO verify + # if input itself is distributed, we also need to distribute static data + self.static_data( + self.static_data[0].view(cfg.num_channels_static, -1).permute(1, 0) + ) + self.static_data = self.g2m_graph.get_src_node_features_in_partition( + self.static_data + ) + self.static_data = self.static_data.permute(1, 0).unsqueeze(dim=0) + self.static_data = self.static_data.to(dtype=self.dtype).to( + device=dist.device + ) + + else: + self.static_data = None + @hydra.main(version_base="1.3", config_path="conf", config_name="config") def main(cfg: DictConfig) -> None: @@ -337,7 +375,7 @@ def main(cfg: DictConfig) -> None: trainer.datapipe = DataPipe( data_dir=os.path.join(cfg.dataset_path, "train"), stats_dir=os.path.join(cfg.dataset_path, "stats"), - channels=[i for i in range(cfg.num_channels)], + channels=[i for i in range(cfg.num_channels_climate)], interpolation_shape=trainer.interpolation_shape, num_samples_per_year=cfg.num_samples_per_year_train, num_steps=num_rollout_steps, @@ -357,6 +395,10 @@ def main(cfg: DictConfig) -> None: # TODO modify for history > 0 data_x = data[0]["invar"] data_y = data[0]["outvar"] + + # add static data + invar = torch.concat((invar, trainer.static_data), dim=1) + # move to device & dtype data_x = data_x.to(dtype=trainer.dtype) grid_nfeat = data_x diff --git a/modulus/models/graphcast/graph_cast_net.py b/modulus/models/graphcast/graph_cast_net.py index 046adae29e..616196a0f4 100644 --- a/modulus/models/graphcast/graph_cast_net.py +++ b/modulus/models/graphcast/graph_cast_net.py @@ -37,7 +37,6 @@ from modulus.models.layers import get_activation from modulus.models.meta import ModelMetaData from modulus.models.module import Module -from modulus.utils.graphcast.data_utils import StaticData from modulus.utils.graphcast.graph import Graph from .graph_cast_processor import GraphCastProcessor @@ -66,8 +65,6 @@ class GraphCastNet(Module): Parameters ---------- - static_dataset_path : str - Path to the static dataset file. multimesh_level: int, optional Level of the multi-mesh, by default 6 input_res: Tuple[int, int] @@ -140,7 +137,6 @@ class GraphCastNet(Module): def __init__( self, - static_dataset_path: str, multimesh_level: int = 6, input_res: tuple = (721, 1440), input_dim_grid_nodes: int = 474, @@ -177,7 +173,6 @@ def __init__( self.lat_lon_grid = torch.stack( torch.meshgrid(self.latitudes, self.longitudes, indexing="ij"), dim=-1 ) - self.has_static_data = static_dataset_path is not None # Set activation function activation_fn = get_activation(activation_fn) @@ -235,25 +230,6 @@ def __init__( self.mesh_ndata ) - # Get the static data - if self.has_static_data: - self.static_data = StaticData( - static_dataset_path, self.latitudes, self.longitudes - ).get() - num_static_feat = self.static_data.size(1) - input_dim_grid_nodes += num_static_feat - if self.is_distributed and expect_partitioned_input: - # if input itself is distributed, we also need to distribute static data - self.static_data( - self.static_data[0].view(num_static_feat, -1).permute(1, 0) - ) - self.static_data = self.g2m_graph.get_src_node_features_in_partition( - self.static_data - ) - self.static_data = self.static_data.permute(1, 0).unsqueeze(dim=0) - else: - self.static_data = None - self.input_dim_grid_nodes = input_dim_grid_nodes self.output_dim_grid_nodes = output_dim_grid_nodes self.input_res = input_res @@ -640,15 +616,10 @@ def prepare_input(self, invar: Tensor, expect_partitioned_input: bool) -> Tensor """ if expect_partitioned_input and self.is_distributed: # partitioned input is [N, C, P] instead of [N, C, H, W] - if self.has_static_data: - invar = torch.concat((invar, self.static_data), dim=1) - invar = invar[0].permute(1, 0) + invar = invar[0].permute(1, 0) else: if invar.size(0) != 1: raise ValueError("GraphCast does not support batch size > 1") - # concat static data - if self.has_static_data: - invar = torch.concat((invar, self.static_data), dim=1) invar = invar[0].view(self.input_dim_grid_nodes, -1).permute(1, 0) if self.is_distributed: # partition node features @@ -711,8 +682,6 @@ def to(self, *args: Any, **kwargs: Any) -> Self: self.m2g_edata = self.m2g_edata.to(*args, **kwargs) self.mesh_ndata = self.mesh_ndata.to(*args, **kwargs) self.mesh_edata = self.mesh_edata.to(*args, **kwargs) - if self.has_static_data: - self.static_data = self.static_data.to(*args, **kwargs) device, _, _, _ = torch._C._nn._parse_to(*args, **kwargs) self.g2m_graph = self.g2m_graph.to(device) diff --git a/test/models/graphcast/test_concat_trick.py b/test/models/graphcast/test_concat_trick.py index 65b28ccac4..4c7900bcfa 100644 --- a/test/models/graphcast/test_concat_trick.py +++ b/test/models/graphcast/test_concat_trick.py @@ -52,7 +52,6 @@ def test_concat_trick(pytestconfig, recomp_act, num_channels=2, res_h=11, res_w= # Instantiate the model model = GraphCastNet( - static_dataset_path=None, multimesh_level=1, input_res=(res_h, res_w), input_dim_grid_nodes=num_channels, @@ -70,7 +69,6 @@ def test_concat_trick(pytestconfig, recomp_act, num_channels=2, res_h=11, res_w= # Instantiate the model with concat trick enabled model_ct = GraphCastNet( - static_dataset_path=None, multimesh_level=1, input_res=(res_h, res_w), input_dim_grid_nodes=num_channels, diff --git a/test/models/graphcast/test_cugraphops.py b/test/models/graphcast/test_cugraphops.py index aa382fdacd..48f71c627c 100644 --- a/test/models/graphcast/test_cugraphops.py +++ b/test/models/graphcast/test_cugraphops.py @@ -52,7 +52,6 @@ def test_cugraphops( np.random.seed(0) model = GraphCastNet( - static_dataset_path=None, multimesh_level=1, input_res=(res_h, res_w), input_dim_grid_nodes=num_channels, @@ -72,7 +71,6 @@ def test_cugraphops( fix_random_seeds() model_dgl = GraphCastNet( - static_dataset_path=None, multimesh_level=1, input_res=(res_h, res_w), input_dim_grid_nodes=num_channels, diff --git a/test/models/graphcast/test_graphcast.py b/test/models/graphcast/test_graphcast.py index 28650b48d9..b54b8167e6 100644 --- a/test/models/graphcast/test_graphcast.py +++ b/test/models/graphcast/test_graphcast.py @@ -34,7 +34,6 @@ def test_graphcast_forward(device, pytestconfig, num_channels=2, res_h=10, res_w from modulus.models.graphcast.graph_cast_net import GraphCastNet model_kwds = { - "static_dataset_path": None, "multimesh_level": 1, "input_res": (res_h, res_w), "input_dim_grid_nodes": num_channels, @@ -68,7 +67,6 @@ def test_graphcast_constructor( # Define dictionary of constructor args arg_list = [ { - "static_dataset_path": None, "multimesh_level": 1, "input_res": (res_h, res_w), "input_dim_grid_nodes": num_channels_1, @@ -80,7 +78,6 @@ def test_graphcast_constructor( "do_concat_trick": True, }, { - "static_dataset_path": None, "multimesh_level": 1, "input_res": (res_h, res_w), "input_dim_grid_nodes": num_channels_2, @@ -117,7 +114,6 @@ def test_GraphCast_optims(device, pytestconfig, num_channels=2, res_h=10, res_w= def setup_model(): """Set up fresh model and inputs for each optim test""" model_kwds = { - "static_dataset_path": None, "multimesh_level": 1, "input_res": (res_h, res_w), "input_dim_grid_nodes": num_channels, @@ -160,7 +156,6 @@ def test_graphcast_checkpoint(device, pytestconfig, num_channels=2, res_h=10, re from modulus.models.graphcast.graph_cast_net import GraphCastNet model_kwds = { - "static_dataset_path": None, "multimesh_level": 1, "input_res": (res_h, res_w), "input_dim_grid_nodes": num_channels, @@ -195,7 +190,6 @@ def test_GraphCast_deploy(device, pytestconfig, num_channels=2, res_h=10, res_w= from modulus.models.graphcast.graph_cast_net import GraphCastNet model_kwds = { - "static_dataset_path": None, "multimesh_level": 1, "input_res": (res_h, res_w), "input_dim_grid_nodes": num_channels, diff --git a/test/models/graphcast/test_graphcast_snmg.py b/test/models/graphcast/test_graphcast_snmg.py index f97aec98ed..2317e5c7e9 100644 --- a/test/models/graphcast/test_graphcast_snmg.py +++ b/test/models/graphcast/test_graphcast_snmg.py @@ -55,7 +55,6 @@ def run_test_distributed_graphcast( res_w = 32 model_kwds = { - "static_dataset_path": None, "multimesh_level": 2, "input_res": (res_h, res_w), "input_dim_grid_nodes": 34, From 7f2c44f78712d62959b803e328c330b46fb5c3dc Mon Sep 17 00:00:00 2001 From: Mohammad Amin Nabian Date: Wed, 22 May 2024 13:19:07 -0700 Subject: [PATCH 9/9] fix pytest --- examples/weather/graphcast/conf/config_small.yaml | 4 ++-- test/models/graphcast/test_grad_checkpointing.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/weather/graphcast/conf/config_small.yaml b/examples/weather/graphcast/conf/config_small.yaml index 0f6d54b8ae..d5995f9bbe 100644 --- a/examples/weather/graphcast/conf/config_small.yaml +++ b/examples/weather/graphcast/conf/config_small.yaml @@ -40,8 +40,8 @@ cugraphops_decoder: False recompute_activation: False wb_mode: "online" synthetic_dataset: false -dataset_path: "/data/era5_73var" #"/code/datasets/era5_73var" -static_dataset_path: "/code/static" #"/code/mnabian/static" +dataset_path: "/data/era5_73var" +static_dataset_path: "/code/static" latlon_res: [181, 360] num_samples_per_year_train: 1408 num_workers: 8 diff --git a/test/models/graphcast/test_grad_checkpointing.py b/test/models/graphcast/test_grad_checkpointing.py index 4528f1993b..f145afba8b 100644 --- a/test/models/graphcast/test_grad_checkpointing.py +++ b/test/models/graphcast/test_grad_checkpointing.py @@ -29,7 +29,6 @@ def test_grad_checkpointing(device, pytestconfig, num_channels=2, res_h=15, res_ # constants model_kwds = { - "static_dataset_path": None, "multimesh_level": 2, "input_res": (res_h, res_w), "input_dim_grid_nodes": num_channels,