Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GraphCast improvements - Part I #510

Merged
merged 10 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions examples/weather/graphcast/__init__.py

This file was deleted.

36 changes: 19 additions & 17 deletions examples/weather/graphcast/conf/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ hydra:
dir: ./outputs/

processor_layers: 16
hidden_dim: 64 # 512
hidden_dim: 512
multimesh_level: 6
segments: 1
norm_type: "TELayerNorm"
# norm_type options: "TELayerNorm" or "LayerNorm"
force_single_checkpoint: False
checkpoint_encoder: True
checkpoint_processor: False
Expand All @@ -32,18 +35,18 @@ checkpoint_encoder_finetune: True
checkpoint_processor_finetune: True
checkpoint_decoder_finetune: True
concat_trick: True
cugraphops_encoder: False
cugraphops_processor: False
cugraphops_decoder: False
recompute_activation: False
wb_mode: "disabled"
cugraphops_encoder: True
cugraphops_processor: True
cugraphops_decoder: True
recompute_activation: True
wb_mode: "online"
synthetic_dataset: false
dataset_path: "/data"
static_dataset_path: "datasets/static"
latlon_res: (721, 1440)
num_samples_per_year_train: 1448
static_dataset_path: null
latlon_res: [721, 1440]
num_samples_per_year_train: 1408
num_workers: 0 # 8
num_channels: 3 # 34
num_channels: 474
stadlmax marked this conversation as resolved.
Show resolved Hide resolved
num_channels_val: 3
num_val_steps: 8
num_val_spy: 3 # SPY: Samples Per Year
Expand All @@ -59,13 +62,12 @@ num_iters_step1: 1000
num_iters_step2: 299000
num_iters_step3: 11000
step_change_freq: 1000
save_freq: 1 # 500
val_freq: 1 # 1000
ckpt_path: "checkpoints_34var"
val_dir: "validation_34var"
ckpt_name: "model_34var.pt"
use_apex: False
save_freq: 500
val_freq: 1000
ckpt_path: "checkpoints"
val_dir: "validation"
ckpt_name: "model"
use_apex: True
pyt_profiler: False
profile: False
profile_range: (90, 110)
icospheres_path: "icospheres.json"
73 changes: 73 additions & 0 deletions examples/weather/graphcast/conf/config_small.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Hydra runtime settings. `job` and `run` must be nested under `hydra`
# (and `chdir` / `dir` under them); flattened to column 0 they would be
# parsed as unrelated top-level keys.
hydra:
  job:
    chdir: True
  run:
    dir: ./outputs_small/

# Model architecture; train_graphcast.py forwards these to GraphCastNet.
processor_layers: 16
hidden_dim: 512
multimesh_level: 5
segments: 1
norm_type: "TELayerNorm"  # "TELayerNorm" or "LayerNorm"

# Checkpointing switches; the *_finetune variants presumably apply to the
# fine-tuning stage -- confirm against the trainer.
force_single_checkpoint: False
checkpoint_encoder: False
checkpoint_processor: False
checkpoint_decoder: False
force_single_checkpoint_finetune: False
checkpoint_encoder_finetune: False
checkpoint_processor_finetune: False
checkpoint_decoder_finetune: False

# Model implementation options (do_concat_trick / use_cugraphops_* in
# GraphCastNet).
concat_trick: False
cugraphops_encoder: False
cugraphops_processor: False
cugraphops_decoder: False
recompute_activation: False

# Passed as `mode` when the Weights & Biases logger is initialized.
wb_mode: "online"

# true -> SyntheticWeatherDataLoader, false -> ERA5HDF5Datapipe.
synthetic_dataset: false
# Root data directory; the trainer reads its "train" and "stats" subfolders.
dataset_path: "/data/era5_73var" #"/code/datasets/era5_73var"
mnabian marked this conversation as resolved.
Show resolved Hide resolved
# Location of the static fields. Per train_graphcast.py, when no static path
# is provided the trainer sets num_channels_static to 0 and logs a warning.
static_dataset_path: "/code/static" #"/code/mnabian/static"
# Grid resolution handed to the model as input_res; the trainer also uses it
# as the datapipe interpolation shape when it differs from (721, 1440).
latlon_res: [181, 360]
num_samples_per_year_train: 1408
num_workers: 8
# Climate channels are read from the datapipe and are the model's output
# channels; the model's input dimension is climate + static channels.
num_channels_climate: 73
# Must equal the channel count of the loaded static data (asserted by the
# trainer).
num_channels_static: 1
# NOTE(review): the visible validation.py hunk reads cfg.num_channels, which
# this file does not define -- confirm validation works with this config.
num_channels_val: 3
num_val_steps: 8
num_val_spy: 3 # SPY: Samples Per Year
grad_clip_norm: 32.0
jit: False
amp: False
amp_dtype: "bfloat16"
full_bf16: True
watch_model: False
lr: 1e-3
lr_step3: 3e-7
# Training runs until iter reaches the sum of the three step counts.
num_iters_step1: 1000
num_iters_step2: 299000
num_iters_step3: 11000
step_change_freq: 1000
save_freq: 500
val_freq: 1000
ckpt_path: "checkpoints_small"
val_dir: "validation_small"
ckpt_name: "model_small"
use_apex: True
pyt_profiler: False
profile: False
# Must be a YAML list: the trainer indexes profile_range[0]. The previous
# "(90, 110)" form is parsed as a plain string, so [0] yielded "(" and the
# profiling start check could never match an iteration number.
profile_range: [90, 110]
1 change: 1 addition & 0 deletions examples/weather/graphcast/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
git+https://github.com/NVIDIA/TransformerEngine.git@stable
75 changes: 68 additions & 7 deletions examples/weather/graphcast/train_graphcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from validation import Validation
from modulus.datapipes.climate import ERA5HDF5Datapipe, SyntheticWeatherDataLoader
from modulus.distributed import DistributedManager
from modulus.utils.graphcast.data_utils import StaticData

import hydra
from hydra.utils import to_absolute_path
Expand Down Expand Up @@ -85,16 +86,25 @@ def __init__(self, cfg: DictConfig, dist, rank_zero_logger):
else:
raise ValueError("Invalid dtype for config amp")

# Handle the number of static channels
if not self.static_dataset_path:
cfg.num_channels_static = 0
rank_zero_logger.warning(
"Static dataset path is not provided. Setting num_channels_static to 0."
)

# instantiate the model
self.model = GraphCastNet(
meshgraph_path=to_absolute_path(cfg.icospheres_path),
static_dataset_path=static_dataset_path,
input_dim_grid_nodes=cfg.num_channels,
multimesh_level=cfg.multimesh_level,
input_res=tuple(cfg.latlon_res),
input_dim_grid_nodes=cfg.num_channels_climate + cfg.num_channels_static,
input_dim_mesh_nodes=3,
input_dim_edges=4,
output_dim_grid_nodes=cfg.num_channels,
output_dim_grid_nodes=cfg.num_channels_climate,
processor_layers=cfg.processor_layers,
hidden_dim=cfg.hidden_dim,
norm_type=cfg.norm_type,
do_concat_trick=cfg.concat_trick,
use_cugraphops_encoder=cfg.cugraphops_encoder,
use_cugraphops_processor=cfg.cugraphops_processor,
Expand Down Expand Up @@ -140,10 +150,14 @@ def __init__(self, cfg: DictConfig, dist, rank_zero_logger):
DataPipe = (
SyntheticWeatherDataLoader if cfg.synthetic_dataset else ERA5HDF5Datapipe
)
self.interpolation_shape = (
cfg.latlon_res if cfg.latlon_res != (721, 1440) else None
) # interpolate if not in native resolution
self.datapipe = DataPipe(
data_dir=to_absolute_path(os.path.join(cfg.dataset_path, "train")),
stats_dir=to_absolute_path(os.path.join(cfg.dataset_path, "stats")),
channels=[i for i in range(cfg.num_channels)],
channels=[i for i in range(cfg.num_channels_climate)],
interpolation_shape=self.interpolation_shape,
num_samples_per_year=cfg.num_samples_per_year_train,
num_steps=1,
batch_size=1,
Expand Down Expand Up @@ -178,11 +192,17 @@ def __init__(self, cfg: DictConfig, dist, rank_zero_logger):
self.criterion = CellAreaWeightedLossFunction(self.area)
try:
self.optimizer = apex.optimizers.FusedAdam(
self.model.parameters(), lr=cfg.lr, betas=(0.9, 0.95), weight_decay=0.1
self.model.parameters(),
lr=cfg.lr,
betas=(0.9, 0.95),
adam_w_mode=True,
weight_decay=0.1,
)
rank_zero_logger.info("Using FusedAdam optimizer")
except:
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.lr)
self.optimizer = torch.optim.AdamW(
self.model.parameters(), lr=cfg.lr, betas=(0.9, 0.95), weight_decay=0.1
)
scheduler1 = LinearLR(
self.optimizer,
start_factor=1e-3,
Expand Down Expand Up @@ -214,6 +234,36 @@ def __init__(self, cfg: DictConfig, dist, rank_zero_logger):
device=dist.device,
)

# Get the static data
if self.static_dataset_path:
self.static_data = StaticData(
static_dataset_path, self.latitudes, self.longitudes
).get()
self.static_data = self.static_data.to(dtype=self.dtype).to(
device=dist.device
)
assert cfg.num_channels_static == self.static_data.size(1), (
f"Number of static channels in model ({cfg.num_channels_static}) "
+ f"does not match the static data ({self.static_data.size(1)})"
)
if (
self.model.is_distributed and self.model.expect_partitioned_input
): # TODO verify
# if input itself is distributed, we also need to distribute static data
self.static_data(
self.static_data[0].view(cfg.num_channels_static, -1).permute(1, 0)
)
self.static_data = self.g2m_graph.get_src_node_features_in_partition(
self.static_data
)
self.static_data = self.static_data.permute(1, 0).unsqueeze(dim=0)
self.static_data = self.static_data.to(dtype=self.dtype).to(
device=dist.device
)

else:
self.static_data = None


@hydra.main(version_base="1.3", config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
Expand Down Expand Up @@ -244,11 +294,15 @@ def main(cfg: DictConfig) -> None:
entity="Modulus",
name="GraphCast-Training",
group="GraphCast-DDP-Group",
mode=cfg.wb_mode,
) # Wandb logger
logger = PythonLogger("main") # General python logger
rank_zero_logger = RankZeroLoggingWrapper(logger, dist) # Rank 0 logger
rank_zero_logger.file_logging()

# print ranks and devices
logger.info(f"Rank: {dist.rank}, Device: {dist.device}")

# specify the datapipe
if cfg.synthetic_dataset:
DataPipe = SyntheticWeatherDataLoader
Expand All @@ -271,6 +325,7 @@ def main(cfg: DictConfig) -> None:
iter < cfg.num_iters_step1 + cfg.num_iters_step2 + cfg.num_iters_step3
), "Training is already finished!"
for i, data in enumerate(trainer.datapipe):

# profiling
if cfg.profile and iter == cfg.profile_range[0]:
rank_zero_logger.info("Starting profile", "green")
Expand Down Expand Up @@ -320,7 +375,9 @@ def main(cfg: DictConfig) -> None:
trainer.datapipe = DataPipe(
data_dir=os.path.join(cfg.dataset_path, "train"),
stats_dir=os.path.join(cfg.dataset_path, "stats"),
channels=[i for i in range(cfg.num_channels)],
channels=[i for i in range(cfg.num_channels_climate)],
interpolation_shape=trainer.interpolation_shape,
num_samples_per_year=cfg.num_samples_per_year_train,
num_steps=num_rollout_steps,
batch_size=1,
num_workers=cfg.num_workers,
Expand All @@ -338,6 +395,10 @@ def main(cfg: DictConfig) -> None:
# TODO modify for history > 0
data_x = data[0]["invar"]
data_y = data[0]["outvar"]

# add static data
invar = torch.concat((invar, trainer.static_data), dim=1)

# move to device & dtype
data_x = data_x.to(dtype=trainer.dtype)
grid_nfeat = data_x
Expand Down
4 changes: 4 additions & 0 deletions examples/weather/graphcast/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,14 @@ def __init__(self, cfg: DictConfig, model, dtype, dist):
self.model = model
self.dtype = dtype
self.dist = dist
interpolation_shape = (
cfg.latlon_res if cfg.latlon_res != (721, 1440) else None
) # interpolate if not in native resolution
self.val_datapipe = ERA5HDF5Datapipe(
data_dir=os.path.join(cfg.dataset_path, "test"),
stats_dir=os.path.join(cfg.dataset_path, "stats"),
channels=[i for i in range(cfg.num_channels)],
interpolation_shape=interpolation_shape,
num_steps=cfg.num_val_steps,
batch_size=1,
num_samples_per_year=cfg.num_val_spy,
Expand Down
Loading