Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add example which trains a distributed GraphCast model on shallow-water-equations data #400

Merged
merged 52 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
c7ec335
init swe example
Mar 13, 2024
22d143e
update distributed examples, minor fixes here and there
stadlmax Mar 18, 2024
209cbb9
fix default parameter
stadlmax Mar 18, 2024
cd7f527
update example
Mar 19, 2024
83489a6
format code
Mar 19, 2024
3ea5868
wip: imp dist swe example, prepare for final runs
Apr 8, 2024
2ce8623
format code
Apr 8, 2024
8b334a1
overwrite default output paths
stadlmax Apr 8, 2024
bf593f8
fix a few things
Apr 8, 2024
dd9f679
update training script
Apr 19, 2024
3171e93
update
stadlmax Apr 19, 2024
57a780d
Merge branch 'example-distributed-gnn-swe' of github.com:stadlmax/mod…
Apr 19, 2024
99e7bb7
add some updates
Apr 22, 2024
8ae9528
fix some things, improve tests
stadlmax Apr 23, 2024
c3b735c
update train.py
stadlmax Apr 23, 2024
3087e8b
format code
stadlmax Apr 23, 2024
96a8e69
format code
stadlmax Apr 23, 2024
2373b8e
push some clean-up changes from runs on eos
Apr 26, 2024
9fec1f6
replace send/recv pairs with all_to_all_v calls and update tests
stadlmax Apr 30, 2024
2aec08d
format code
stadlmax Apr 30, 2024
99afcb7
Merge branch 'main' of github.com:stadlmax/modulus into example-distr…
stadlmax Apr 30, 2024
5811a88
add eos updates
May 8, 2024
2e242a4
some changes
stadlmax May 8, 2024
0a77f21
format code
stadlmax May 8, 2024
c0e477e
Merge branch 'main' of github.com:stadlmax/modulus into example-distr…
stadlmax May 8, 2024
8f5f9c9
incorporate feedback
stadlmax May 17, 2024
36a43a1
update headers
stadlmax May 17, 2024
8826aa7
Merge branch 'main' of github.com:stadlmax/modulus into example-distr…
stadlmax May 17, 2024
311a70a
make partitioning logic more general
stadlmax May 17, 2024
bdb4491
format
stadlmax May 17, 2024
7a736d1
Merge branch 'main' of github.com:stadlmax/modulus into example-distr…
stadlmax May 21, 2024
84f0c3a
address feedback
stadlmax May 23, 2024
2634cda
resolve conflicts
stadlmax May 23, 2024
537ad3d
format
stadlmax May 23, 2024
35e2081
further address feedback
stadlmax May 23, 2024
59f592d
minor updates to tests to avoid hangs
stadlmax May 23, 2024
f5b8d25
format
stadlmax May 23, 2024
40b79f4
fix example
stadlmax May 23, 2024
66756bb
remove mention of icosphere.json in gitignore and update changelog
stadlmax May 23, 2024
d74728e
update manager
stadlmax May 23, 2024
af39937
resolve conflict in changelog
May 27, 2024
e50093a
update README
stadlmax May 27, 2024
e255e1f
format
stadlmax May 27, 2024
3e3df7d
format README
stadlmax May 27, 2024
b47b62b
resolve conflicts
stadlmax May 29, 2024
62e0f1d
minor pde cleanup
stadlmax May 29, 2024
cf5263e
resolve conflicts
stadlmax May 29, 2024
01900d7
address feedback (setup of process groups, mp+dp, resolve conflits, s…
stadlmax May 29, 2024
01d4155
format code
stadlmax May 29, 2024
799c490
resolve conflicts
stadlmax May 31, 2024
e03026a
fix changelog issue
stadlmax May 31, 2024
380f977
fix lint issue in md
stadlmax May 31, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,6 @@ checkpoints/
outputs/
multirun/
.hydra/

# SWE example
icospheres_*.json
146 changes: 146 additions & 0 deletions examples/cfd/swe_distributed_gnn/pde_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# coding=utf-8
stadlmax marked this conversation as resolved.
Show resolved Hide resolved

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import torch

from math import ceil

from shallow_water_equations import ShallowWaterSolver


class PdeDataset(torch.utils.data.Dataset):
    """Dataset that generates PDE training pairs on the fly with a spectral solver.

    Every sample is an ``(input, target)`` pair: the input is an initial
    condition produced by the solver and the target is that state advanced by
    ``nsteps`` solver steps. Both are returned on the grid (after
    ``solver.spec2grid``).

    Only rank 0 actually runs the solver in ``__getitem__``; every other rank
    receives uninitialized placeholder tensors of the expected shape.

    Args:
        dt: Time between input and target. The solver step is ``dt / nsteps``.
        nsteps: Number of solver steps used to compute the target.
        dims: Grid size as ``(nlat, nlon)``.
        pde: Name of the PDE. Only ``"shallow water equations"`` is supported.
        initial_condition: ``"random"`` or ``"galewsky"``.
        num_examples: Length reported by ``len()`` for random initial
            conditions.
        device: Device the solver is moved to and on which placeholder
            tensors are allocated.
        normalize: If true, samples are standardized with the per-channel
            mean and variance of one sample drawn at construction time.
        rank: Rank of the calling process; only rank 0 generates data.
        stream: Stored on the instance; not used by this class itself.
        dtype: Data type of the tensors returned by ``__getitem__``.

    Raises:
        NotImplementedError: If ``pde`` is not a supported PDE.
        ValueError: If ``initial_condition`` is not a supported type.
    """

    def __init__(
        self,
        dt,
        nsteps,
        dims=(384, 768),
        pde="shallow water equations",
        initial_condition="random",
        num_examples=32,
        device=torch.device("cpu"),
        normalize=True,
        rank=0,
        stream=None,
        dtype=torch.float32,
    ):
        super().__init__()

        self.dtype = dtype

        self.num_examples = num_examples
        self.device = device
        self.stream = stream
        self.rank = rank

        self.nlat = dims[0]
        self.nlon = dims[1]

        # number of solver steps used to compute the target
        self.nsteps = nsteps
        self.normalize = normalize

        if pde == "shallow water equations":
            lmax = ceil(self.nlat / 3)
            mmax = lmax
            # the solver advances in nsteps sub-steps that together span dt
            dt_solver = dt / float(self.nsteps)
            self.solver = (
                ShallowWaterSolver(
                    self.nlat,
                    self.nlon,
                    dt_solver,
                    lmax=lmax,
                    mmax=mmax,
                    grid="equiangular",
                )
                .to(self.device)
                .float()
            )
        else:
            raise NotImplementedError(f"unsupported pde: {pde!r}")

        self.set_initial_condition(ictype=initial_condition)

        # draw one sample to record the shapes and the normalization statistics
        inp0, tar0 = self._get_sample()
        self.inp_shape = inp0.shape
        self.tar_shape = tar0.shape

        if self.normalize:
            # per-channel statistics over the last two dimensions, shaped to
            # broadcast against a sample
            self.inp_mean = torch.mean(inp0, dim=(-1, -2)).reshape(-1, 1, 1)
            self.inp_var = torch.var(inp0, dim=(-1, -2)).reshape(-1, 1, 1)

    def __len__(self):
        """Return ``num_examples`` for random initial conditions, else 1."""
        length = self.num_examples if self.ictype == "random" else 1
        return length

    def set_initial_condition(self, ictype="random"):
        """Select the initial condition type used for subsequent samples."""
        self.ictype = ictype

    def set_num_examples(self, num_examples=32):
        """Set the dataset length reported for random initial conditions."""
        self.num_examples = num_examples

    def _get_sample(self):
        """Generate one un-normalized ``(input, target)`` pair on the grid.

        Raises:
            ValueError: If the current initial condition type is unsupported.
        """
        if self.ictype == "random":
            inp = self.solver.random_initial_condition(mach=0.2)
        elif self.ictype == "galewsky":
            inp = self.solver.galewsky_initial_condition()
        else:
            # without this branch ``inp`` would be unbound below
            raise ValueError(
                f"unsupported initial condition type: {self.ictype!r}; "
                "expected 'random' or 'galewsky'"
            )

        # solve pde for n steps to return the target
        tar = self.solver.timestep(inp, self.nsteps)
        inp = self.solver.spec2grid(inp)
        tar = self.solver.spec2grid(tar)

        return inp, tar

    def __getitem__(self, index):
        """Return one ``(input, target)`` pair; ``index`` is ignored.

        On rank 0 a fresh sample is generated, optionally normalized, and cast
        to ``self.dtype``. On all other ranks uninitialized tensors of shape
        ``(3, nlat, nlon)`` are returned as placeholders.
        """
        if self.rank == 0:
            # inference mode already disables gradient tracking, so no
            # additional no_grad context is needed
            with torch.inference_mode():
                inp, tar = self._get_sample()

                if self.normalize:
                    # the target is normalized with the input statistics
                    inp = (inp - self.inp_mean) / torch.sqrt(self.inp_var)
                    tar = (tar - self.inp_mean) / torch.sqrt(self.inp_var)

            # clone outside inference mode so that the returned tensors are
            # normal tensors that can take part in autograd downstream
            inp = inp.clone()
            tar = tar.clone()

            # .to() is a no-op when the dtype already matches
            inp = inp.to(dtype=self.dtype)
            tar = tar.to(dtype=self.dtype)

        else:
            # NOTE(review): assumes samples have three channels -- confirm
            # against the shapes recorded in inp_shape / tar_shape
            inp = torch.empty(
                (3, self.nlat, self.nlon), device=self.device, dtype=self.dtype
            )
            tar = torch.empty(
                (3, self.nlat, self.nlon), device=self.device, dtype=self.dtype
            )

        return inp, tar
2 changes: 2 additions & 0 deletions examples/cfd/swe_distributed_gnn/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
cartopy
torch-harmonics
Loading