
add model-parallel metrics with tests#975

Merged
mcgibbon merged 5 commits into ai2cm:main from E3SM-Project:feature/model-parallel-metrics
Mar 20, 2026

Conversation


@mahf708 mahf708 commented Mar 15, 2026

Implement both zonal_mean and gradient_magnitude_percent_diff for the ModelTorchDistributed backend. The gradient_magnitude_percent_diff impl is quite inefficient, but that's okay because it's not really on a critical path (and maybe we should just disable it?)

Changes:

  • a minor API change in the base backend class exposing gather_spatial_tensor, then building gather_spatial on top of it in distributed. Not 100% sure this is better or worse; it makes life slightly easier for downstream use (see comment below)

  • Tests added (in the "right" place)

```python
result[k] = self.spatial_reduce_sum(global_tensor)
return result
```

```python
return {
    k: self._distributed.gather_spatial_tensor(v, img_shape)
```
Contributor Author
this now gets something from base.py, which felt a little odd because base.py was bare before. I think the new way is better, but I am not 100% sure. It helps in later gather ops in the backends...

Comment on lines +377 to +379

```python
self.gather_spatial_tensor(truth, img_shape),
self.gather_spatial_tensor(predicted, img_shape),
weights=self.gather_spatial_tensor(weights, img_shape),
```
Contributor Author
@mahf708 Mar 15, 2026

this is a "correct" impl, but not an efficient one. We have some paths forward for a distributed future:

  1. a fully correct impl with halo exchanges (will need to code up that stuff carefully, etc.)
  2. something fully correct, but in between (e.g., avoid excessive compute on all ranks, avoid excessive alloc on all ranks, etc.)
  3. a nearly correct impl that simply ignores boundaries; see below for a sketch
```python
def gradient_magnitude_percent_diff(
    self,
    truth: torch.Tensor,
    predicted: torch.Tensor,
    weights: torch.Tensor,
    dim: tuple[int, ...],
) -> torch.Tensor:
    from fme.core.metrics import gradient_magnitude

    truth_grad_mag = gradient_magnitude(truth, dim)
    predicted_grad_mag = gradient_magnitude(predicted, dim)
    truth_mean = self._weighted_nanmean(truth_grad_mag, weights, dim)
    predicted_mean = self._weighted_nanmean(predicted_grad_mag, weights, dim)
    return 100 * (predicted_mean - truth_mean) / truth_mean

def _weighted_nanmean(
    self,
    data: torch.Tensor,
    weights: torch.Tensor,
    dim: tuple[int, ...],
) -> torch.Tensor:
    expanded_weights = weights.expand(data.shape)
    valid_weights = torch.where(torch.isnan(data), 0.0, expanded_weights)
    local_weighted_sum = torch.nan_to_num(data * expanded_weights).sum(dim=dim)
    local_weight_sum = valid_weights.sum(dim=dim)
    return self.spatial_reduce_sum(local_weighted_sum) / self.spatial_reduce_sum(
        local_weight_sum
    )
```

let's open an issue, or ... shall we tuck this under the rug 👀 If it were up to me, I'd go for the last option and forget about it.
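For intuition, here is a single-process NumPy analogue of the `_weighted_nanmean` in the sketch above (the NumPy port and the standalone function name are illustrative only; the real sketch operates on torch tensors and sums across ranks via spatial_reduce_sum):

```python
import numpy as np

def weighted_nanmean(data, weights, axis):
    # Mirror of the sketch: zero out weights where data is NaN,
    # then divide the NaN-ignoring weighted sum by the sum of valid weights.
    w = np.broadcast_to(weights, data.shape)
    valid_w = np.where(np.isnan(data), 0.0, w)
    weighted_sum = np.nansum(data * w, axis=axis)
    weight_sum = valid_w.sum(axis=axis)
    return weighted_sum / weight_sum

data = np.array([[1.0, np.nan], [3.0, 4.0]])
weights = np.ones((2, 2))
print(weighted_nanmean(data, weights, axis=(0, 1)))  # (1 + 3 + 4) / 3 ≈ 2.6667
```

In the distributed version, the two local sums are each all-reduced across the spatial ranks before the division, which is why the NaN masking has to happen before the reduction.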

Contributor

Can you open an issue for this? I think it's fine for this to go in now but we should address it later.

Contributor

Yes let's open an issue for this and address it later.


mahf708 commented Mar 16, 2026

@elynnwu the one failure is a timer failure; is it something I should look into?


```python
"""All-reduce sum across spatial (h, w) ranks. Identity for non-spatial."""
...

def gather_spatial_tensor(
```
Contributor

I see what you mean about base not being purely abstract. The logic here is technically shared across all backends; does that make it a good exception? I'm deferring this to @mcgibbon, specifically whether this is a concern if we start to add more concrete methods in the future.

Contributor Author

Happy to wait on this until later; in the meantime, I put up a PR to skip the main metric needing special treatment (it's not strictly needed). See: #983

Contributor

Suggestion: Move this to Distributed instead of on the backend ABC.

Question: Why is gather_global not sufficient for this purpose? I'd like to avoid two ways to do the same thing, if possible, and it seems like they do the same thing except without the data-parallel portion of the gather. If we really need to be able to do a spatial-only gather, we could update gather_global to allow None for the data-parallel dim, and do spatial-only in this case.

Contributor Author

Question: Why is gather_global not sufficient for this purpose? I'd like to avoid two ways to do the same thing, if possible [...]

The only difference is the signature/return ... gather_spatial returns a stack of tensors, but gather_spatial_tensor returns a single tensor (from that stack). That's why I made gather_spatial use gather_spatial_tensor ... but maybe there's an easier/cleaner way without editing all files using gather_spatial (which are only tests anyway)?

Contributor Author

```python
def gather_spatial(
    self, data: dict[str, torch.Tensor], img_shape: tuple[int, int]
) -> dict[str, torch.Tensor]:
    return {
        k: self._distributed.gather_spatial_tensor(v, img_shape)
        for k, v in data.items()
    }
```
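To illustrate what a spatial gather has to do once every rank holds all the tiles, here is a single-process NumPy analogue of reassembling an (h, w)-decomposed field (the function name and the row-major rank ordering are assumptions; the real gather_spatial_tensor works on torch tensors produced by a collective gather):

```python
import numpy as np

def assemble_tiles(tiles, grid, img_shape):
    """Reassemble spatial tiles (as from an all_gather) into one array.

    `tiles` is a flat list in row-major (h, w) rank order; `grid` is
    the (rows, cols) layout of the spatial ranks.
    """
    rows, cols = grid
    # Stitch each row of tiles along the width axis, then stack rows along height.
    row_arrays = [
        np.concatenate(tiles[r * cols:(r + 1) * cols], axis=-1) for r in range(rows)
    ]
    full = np.concatenate(row_arrays, axis=-2)
    assert full.shape[-2:] == img_shape  # img_shape is the global (h, w)
    return full

full = np.arange(16).reshape(4, 4)
tiles = [full[:2, :2], full[:2, 2:], full[2:, :2], full[2:, 2:]]
print(np.array_equal(assemble_tiles(tiles, (2, 2), (4, 4)), full))  # True
```

This also shows why img_shape matters: it is the only way to validate that the stitched result has the intended global extent.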

Contributor Author

Sorry, I misread your question 💀

The difference between gather_global and this: gather_global does the work on the root only, whereas this does it on all ranks, right? I did consider this earlier and settled on the hierarchical approach in gather_spatial borrowing from gather_spatial_tensor


mahf708 commented Mar 17, 2026

Setting this to draft; we can live with #983 in the meantime.

Comment on lines +370 to +374

```python
h_total = torch.tensor(truth.shape[-2], device=truth.device)
w_total = torch.tensor(truth.shape[-1], device=truth.device)
torch.distributed.all_reduce(h_total, group=self._h_group)
torch.distributed.all_reduce(w_total, group=self._w_group)
img_shape = (int(h_total), int(w_total))
```
Contributor

Probably this should be a helper function, or more likely this code should go into gather_spatial_tensor as a way of supporting img_shape=None (you could consider still requiring the argument, so the user has to explicitly say "yeah, I don't have the img_shape in this scope", given how important it is to pass along).

We could also speed this up by passing the img_shape through here; we have access to it where we call this function (at least, it gets passed to the aggregator inits in general).


```python
torch.distributed.all_reduce(w_total, group=self._w_group)
img_shape = (int(h_total), int(w_total))

return metrics.gradient_magnitude_percent_diff(
```
Contributor

Not now: Later I'd like to flip this so that metrics.gradient_magnitude_percent_diff calls helper methods on dist, instead of vice-versa. E.g. the efficient version of this probably involves dist.weighted_mean_diff_h and dist.weighted_mean_diff_w functions called inside gradient_magnitude_percent_diff.

For example, you could do it this way:

  • pass your left data to/receive your right data from the rank next to you along the h group
  • compute weighted_mean_diff_h locally on each rank for its internal data, including the right side halo
  • gather these diffs to the global root

The mean_diff_w is identical except with a different dimension and comm group, so they can both use the same helper function.

zonal_mean feels like an atomic distributed function that long-term should exist on Distributed (though maybe later in a more general form), but grad mag percent diff doesn't.
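The halo-exchange idea in those steps can be sanity-checked without torch.distributed. A single-process NumPy sketch (function name hypothetical) showing that one halo column received from the right-hand neighbor makes the local forward difference match the global one:

```python
import numpy as np

def local_diff_with_halo(block, right_halo):
    """Forward difference along the last axis, using one halo column
    "received" from the neighboring rank to cover the block boundary."""
    padded = np.concatenate([block, right_halo], axis=-1)
    return np.diff(padded, axis=-1)

x = np.arange(8.0).reshape(1, 8)        # the global field
left, right = x[:, :4], x[:, 4:]        # split across two "ranks"
halo = right[:, :1]                     # one column exchanged between neighbors
local = local_diff_with_halo(left, halo)
global_diff = np.diff(x, axis=-1)
print(np.array_equal(local, global_diff[:, :4]))  # True
```

The w-direction version would be the same with a different axis and comm group, matching the point above that both can share one helper.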

@mahf708 mahf708 force-pushed the feature/model-parallel-metrics branch from edc03a6 to 013c8f7 Compare March 18, 2026 18:50
@mahf708 mahf708 marked this pull request as ready for review March 18, 2026 18:50
@mahf708 mahf708 requested review from elynnwu and mcgibbon March 18, 2026 19:18
Contributor

@mcgibbon mcgibbon left a comment

LGTM, one question and one or two lines to potentially revert.

```python
truth = self.gather_spatial_tensor(truth, img_shape)
predicted = self.gather_spatial_tensor(predicted, img_shape)
weights = self.gather_spatial_tensor(weights, img_shape)
return metrics.gradient_magnitude_percent_diff(
```
Contributor

Very simple!

Do we expect this to return a globally reduced value or just a spatially-reduced value? Currently, the returned value also needs to be passed into a dist.reduce_mean call afterwards, which I think is fine/expected/congruent with all our other metrics, but I wanted to check.

Contributor Author

I think we want the global value, right?

Contributor

Well, here you're returning the spatially-global but not rank-global value. Assuming the gather is allgather, each rank gets the grad mag percent diff for its data-parallel segment, and then later when we call dist.reduce_mean those get reduced to a truly global value for grad mag percent diff.
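A toy illustration of the two-stage reduction being described (the numbers are hypothetical): each data-parallel rank ends up with a spatially-global percent diff for its own batch segment, and the later dist.reduce_mean averages those into the truly global value:

```python
import numpy as np

# Hypothetical per-data-parallel-rank grad-mag percent diffs, each already
# spatially global after the allgather within that rank's spatial group.
per_rank = np.array([2.0, 4.0, 6.0])

# What the subsequent dist.reduce_mean across data-parallel ranks produces.
global_value = per_rank.mean()
print(global_value)  # 4.0
```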

@mahf708 mahf708 force-pushed the feature/model-parallel-metrics branch from e4c42e3 to 3a5b42b Compare March 19, 2026 20:17
@mcgibbon mcgibbon enabled auto-merge (squash) March 20, 2026 15:45
@mcgibbon mcgibbon merged commit d2f1525 into ai2cm:main Mar 20, 2026
7 checks passed
