GraphCast improvements - Part I #510

Merged (10 commits) on May 22, 2024

Conversation

@mnabian (Collaborator) commented May 21, 2024

Modulus Pull Request

Description

Closes #506, #505, #486, #508, #509, #511, #516, #517

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

@mnabian mnabian self-assigned this May 21, 2024
@mnabian mnabian added the "3 - Ready for Review" label May 21, 2024
@mnabian (Collaborator Author) commented May 21, 2024

/blossom-ci

@stadlmax (Collaborator) commented:

@mnabian Since you are revisiting GraphCast now, adding a few comments:

  • Can we add the option to use transformer_engine.LayerNorm? In AIFS benchmarks, we got a 1.3x end-to-end improvement from doing so, since the PyTorch implementation is rather slow for the sizes we encounter in these workloads.
  • Can you check whether the current combination of MeshGraphNodeBlock and MeshGraphEdgeBlock actually matches the paper (https://github.com/NVIDIA/modulus/blob/main/modulus/models/graphcast/graph_cast_processor.py#L97-L98)? I created a schematic of the GraphCast architecture for some Arch folks last week, and I think the order of residuals over the edges does not match the paper here. I might have made a mistake when trying to reuse shared primitives last time. The issue is that in MeshGraphNet, the EdgeBlock already applies the "residual" to the edge features, so the NodeBlock then expects features that include the residual connection prior to message passing, whereas in GraphCast all residual connections are applied only after both the updated edge and node features have been computed (at least according to the paper). A small sketch of the difference follows this list.
  • What would you think of splitting GraphCastNet into a GraphCastNetERA5 and a GraphCastNet model? The current issue I see with GraphCastNet is that it is very specific to the nature of the ERA5 dataset (e.g. when it comes to preparing the input and output to switch between the HxW layout and the typical "serial" graph layout). GraphCastNet could then be a rather data-agnostic model defining the operations on (g2m_graph, mesh_graph, m2g_graph), while GraphCastNetERA5 defines the things specific to the workload, like checkpointing, input/output conversions, etc. In the longer term, I think it could really make sense to make things a bit more modular. In particular, this also includes things like "history" or the actual "prediction" mode, i.e. whether GraphCastNetERA5 predicts y_t = f(x_{t-1}) or y_t = x_{t-1} + f(x_{t-1}). It could make sense for the "backbone" to be agnostic to these things while having a specialized prediction wrapper.
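
For illustration, a minimal sketch of the two residual orderings from the second point. The `edge_mlp`, `node_mlp`, and `aggregate` callables are stand-ins rather than the actual Modulus primitives, and the real blocks also condition the edge update on sender/receiver node features; the sketch only pins down where the residuals are added:

```python
import torch

def mgn_style_step(x_nodes, e_edges, edge_mlp, node_mlp, aggregate):
    # MeshGraphNet-style: the edge block adds its residual immediately,
    # so the node update aggregates edge features that already include it.
    e_updated = edge_mlp(e_edges) + e_edges
    x_updated = node_mlp(torch.cat([x_nodes, aggregate(e_updated)], dim=-1)) + x_nodes
    return x_updated, e_updated

def graphcast_style_step(x_nodes, e_edges, edge_mlp, node_mlp, aggregate):
    # GraphCast-style (per the paper): compute both updates first,
    # then apply both residual connections at the end.
    e_new = edge_mlp(e_edges)
    x_new = node_mlp(torch.cat([x_nodes, aggregate(e_new)], dim=-1))
    return x_nodes + x_new, e_edges + e_new
```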

@mnabian (Collaborator Author) commented May 21, 2024

> @mnabian Since you are revisiting GraphCast now, adding a few comments: […]

Thanks @stadlmax , I'll add your comments to my epic and consider them all.
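
For reference, a very rough sketch of what the proposed backbone/wrapper split could look like. All class names, constructor arguments, and call signatures below are hypothetical, just to pin down the idea; they are not the actual Modulus API:

```python
import torch.nn as nn

class GraphCastBackbone(nn.Module):
    """Data-agnostic core: operates only on (g2m_graph, mesh_graph, m2g_graph)."""

    def __init__(self, encoder, processor, decoder):
        super().__init__()
        self.encoder, self.processor, self.decoder = encoder, processor, decoder

    def forward(self, node_feats, g2m_graph, mesh_graph, m2g_graph):
        mesh_feats = self.encoder(node_feats, g2m_graph)
        mesh_feats = self.processor(mesh_feats, mesh_graph)
        return self.decoder(mesh_feats, m2g_graph)

class GraphCastERA5Wrapper(nn.Module):
    """ERA5-specific concerns: HxW <-> node-list layout and the prediction mode."""

    def __init__(self, backbone, residual_prediction=True):
        super().__init__()
        self.backbone = backbone
        self.residual_prediction = residual_prediction

    def forward(self, x, graphs):
        # [B, C, H, W] grid layout -> [B, H*W, C] "serial" graph layout.
        b, c, h, w = x.shape
        nodes = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
        out = self.backbone(nodes, *graphs)
        out = out.reshape(b, h, w, -1).permute(0, 3, 1, 2)
        # y_t = x_{t-1} + f(x_{t-1}) vs. y_t = f(x_{t-1}); channel bookkeeping is
        # omitted, i.e. this assumes the output matches the prognostic input channels.
        return x + out if self.residual_prediction else out
```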

@mnabian mnabian requested a review from stadlmax May 21, 2024 16:50
@mnabian (Collaborator Author) commented May 21, 2024

Note to myself: the API updates break the GraphCast tests. Need to update them all.

@mnabian (Collaborator Author) commented May 21, 2024

@stadlmax as far as I remember, we were using fused LayerNorm and that gave us a nice speedup: https://github.com/NVIDIA/modulus/blob/main/modulus/models/gnn_layers/mesh_graph_mlp.py#L157 ... Did you also compare transformer_engine.LayerNorm with fused LayerNorm?

@stadlmax (Collaborator) commented:

> @stadlmax as far as I remember, we were using fused LayerNorm and that gave us a nice speedup (although I can't find it in the most recent code)... Did you also compare transformer_engine.LayerNorm with fused LayerNorm?

Yes, for AIFS, I found TE > APEX > PyTorch across a bunch of the usual sizes AIFS had in their RFI benchmark. The backward kernels in TE in particular are much better for our cases. (Reported numbers are runtimes; lower is better.)

num_channels = 256

| layer_norm_impl | 1626240 x 256 | 327660 x 256 | 40962 x 256 | 542080 x 256 | 814540 x 256 |
| --- | --- | --- | --- | --- | --- |
| apex | 9.75127 | 2.03821 | 0.371149 | 3.32072 | 4.9402 |
| pytorch | 10.752 | 4.17265 | 0.957743 | 3.63721 | 10.2774 |
| transformer_engine | 2.59236 | 0.580879 | 0.801795 | 0.916124 | 1.33596 |

num_channels = 384

| layer_norm_impl | 1626240 x 384 | 327660 x 384 | 40962 x 384 | 542080 x 384 | 814540 x 384 |
| --- | --- | --- | --- | --- | --- |
| apex | 11.2164 | 2.3109 | 0.359366 | 3.79922 | 5.64847 |
| pytorch | 11.8419 | 4.33466 | 0.583828 | 3.99414 | 10.6802 |
| transformer_engine | 3.98762 | 0.849599 | 0.396184 | 1.38306 | 2.022 |

num_channels = 512

| layer_norm_impl | 1626240 x 512 | 327660 x 512 | 40962 x 512 | 542080 x 512 | 814540 x 512 |
| --- | --- | --- | --- | --- | --- |
| apex | 12.1739 | 2.50785 | 0.37578 | 4.11927 | 6.14573 |
| pytorch | 12.7752 | 4.5477 | 0.615464 | 4.30874 | 11.2191 |
| transformer_engine | 4.90182 | 1.04243 | 0.391352 | 1.6877 | 2.4967 |
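
For context, a rough sketch of how such a comparison can be reproduced; this is not the script used for the numbers above, and it assumes apex and transformer_engine are installed (exact timings will depend on GPU, dtype, and library versions):

```python
import torch
import torch.nn as nn
from apex.normalization import FusedLayerNorm      # APEX fused LayerNorm
import transformer_engine.pytorch as te            # Transformer Engine LayerNorm

def bench(norm, rows, channels, iters=100, warmup=10):
    """Average time (ms) for forward + backward of `norm` on a (rows, channels) input."""
    x = torch.randn(rows, channels, device="cuda", requires_grad=True)
    grad = torch.ones(rows, channels, device="cuda")
    for _ in range(warmup):
        norm(x).backward(grad)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        norm(x).backward(grad)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

if __name__ == "__main__":
    rows, channels = 40962, 256  # one of the shapes from the tables above
    impls = {
        "pytorch": nn.LayerNorm(channels).cuda(),
        "apex": FusedLayerNorm(channels).cuda(),
        "transformer_engine": te.LayerNorm(channels).cuda(),
    }
    for name, norm in impls.items():
        print(f"{name}: {bench(norm, rows, channels):.4f} ms/iter")
```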

@mnabian (Collaborator Author) commented May 21, 2024

> Yes, for AIFS, I found TE > APEX > PyTorch across a bunch of the usual sizes AIFS had in their RFI benchmark. […]

This is a great comparison, thanks! I'll switch to TE then. Do we have any reason to still keep the fused LayerNorm from Apex, or should we just remove it?

@stadlmax (Collaborator) commented:

> This is a great comparison, thanks! I'll switch to TE then. Do we have any reason to still keep the fused LayerNorm from Apex, or should we just remove it?

I guess no, not really. TE should also be decently covered when it comes to development specifically for Blackwell and beyond; I know of a few POCs that try to optimize the LayerNorm in TE even further.
If we build on the DLFW containers, TE should also come pre-installed.
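
A minimal sketch of what a guarded backend selection could look like, falling back to the stock PyTorch implementation when TE (or APEX) is not installed. The `get_layer_norm` helper is illustrative, not the Modulus API:

```python
import torch.nn as nn

try:
    import transformer_engine.pytorch as te
except ImportError:
    te = None

try:
    from apex.normalization import FusedLayerNorm
except ImportError:
    FusedLayerNorm = None

def get_layer_norm(hidden_size: int, impl: str = "transformer_engine") -> nn.Module:
    """Return a LayerNorm of the requested implementation, if available."""
    if impl == "transformer_engine" and te is not None:
        return te.LayerNorm(hidden_size)
    if impl == "apex" and FusedLayerNorm is not None:
        return FusedLayerNorm(hidden_size)
    # Fall back to the stock PyTorch implementation.
    return nn.LayerNorm(hidden_size)
```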

@mnabian (Collaborator Author) commented May 21, 2024

@stadlmax added support for TE LayerNorm.
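
A hypothetical usage sketch of what enabling the new option might look like from the GNN layers; the argument name and accepted value are assumptions, so check the PR diff for the actual API:

```python
# Hypothetical usage; the actual argument/value names may differ in the merged PR.
from modulus.models.gnn_layers.mesh_graph_mlp import MeshGraphMLP

mlp = MeshGraphMLP(
    input_dim=512,
    output_dim=512,
    hidden_dim=512,
    hidden_layers=1,
    norm_type="TELayerNorm",  # assumed value selecting transformer_engine.LayerNorm
)
```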

@mnabian (Collaborator Author) commented May 21, 2024

> Note to myself: the API updates break the GraphCast tests. Need to update them all.

Done

@mnabian (Collaborator Author) commented May 21, 2024

/blossom-ci

@mnabian (Collaborator Author) commented May 21, 2024

/blossom-ci

@mnabian (Collaborator Author) commented May 22, 2024

/blossom-ci

@mnabian (Collaborator Author) commented May 22, 2024

/blossom-ci

@mnabian (Collaborator Author) commented May 22, 2024

/blossom-ci

@mnabian (Collaborator Author) commented May 22, 2024

/blossom-ci

@mnabian mnabian requested a review from stadlmax May 22, 2024 19:24
@stadlmax (Collaborator) commented:

Thanks for addressing the feedback, looks good to me.

@mnabian (Collaborator Author) commented May 22, 2024

/blossom-ci

Labels
3 - Ready for Review (Ready for review by team)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

GraphCast: Support for arbitrary input shapes
3 participants