8 changes: 6 additions & 2 deletions modulus/metrics/general/crps.py
@@ -28,9 +28,11 @@ def _crps_gaussian(mean: Tensor, std: Tensor, obs: Union[Tensor, np.ndarray]) ->
Computes:

.. math::
CRPS(mean, std, y) = std * [ \\frac{1}{\\pi} - 2 \\phi ( \\frac{x-mean}{std} ) -
CRPS(mean, std, y) = std * [ \\frac{1}{\\sqrt{\\pi}} - 2 \\phi ( \\frac{x-mean}{std} ) -
( \\frac{x-mean}{std} ) * (2 \\Phi(\\frac{x-mean}{std}) - 1) ]

where \\phi and \\Phi are the standard normal PDF and CDF, respectively.

Parameters
----------
mean : Tensor
@@ -69,9 +71,11 @@ def _crps_gaussian(mean: Tensor, std: Tensor, obs: Union[Tensor, np.ndarray]) ->

d = (obs - mean) / std
phi = torch.exp(-0.5 * d**2) / torch.sqrt(torch.as_tensor(2 * torch.pi))

# Note: the expression below is erf(d / sqrt(2)) = 2 * Phi(d) - 1, not the Gaussian CDF itself
Phi = torch.erf(d / torch.sqrt(torch.as_tensor(2.0)))

return 2 * phi + (obs - mean) * Phi - std / torch.sqrt(torch.as_tensor(torch.pi))
return std * (2 * phi + d * Phi - 1.0 / torch.sqrt(torch.as_tensor(torch.pi)))


def _crps_from_cdf(
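For reference, the corrected return line implements the standard closed form for the CRPS of a Gaussian forecast, CRPS(N(mean, std^2), y) = std * [ z * (2 * Phi(z) - 1) + 2 * phi(z) - 1/sqrt(pi) ] with z = (y - mean)/std, and it can be sanity-checked against the sample-based definition CRPS = E|X - y| - 0.5 * E|X - X'|. A minimal standalone sketch (the function name and tolerances here are illustrative, not the module's API):

```python
import torch

def crps_gaussian(mean: torch.Tensor, std: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
    # Closed-form CRPS of a Gaussian forecast N(mean, std^2) at observation obs
    d = (obs - mean) / std
    # Standard normal pdf evaluated at d
    phi = torch.exp(-0.5 * d**2) / torch.sqrt(torch.as_tensor(2 * torch.pi))
    # erf(d / sqrt(2)) equals 2 * Phi(d) - 1, folding the CDF term into one call
    two_cdf_minus_one = torch.erf(d / torch.sqrt(torch.as_tensor(2.0)))
    return std * (2 * phi + d * two_cdf_minus_one - 1.0 / torch.sqrt(torch.as_tensor(torch.pi)))

# Monte Carlo cross-check: CRPS = E|X - y| - 0.5 * E|X - X'| with X, X' ~ N(mean, std^2)
torch.manual_seed(0)
mean, std, y = torch.tensor(0.3), torch.tensor(1.2), torch.tensor(-0.5)
x = mean + std * torch.randn(200_000)
xp = mean + std * torch.randn(200_000)
mc = (x - y).abs().mean() - 0.5 * (x - xp).abs().mean()
assert torch.isclose(crps_gaussian(mean, std, y), mc, atol=2e-2)
```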
100 changes: 64 additions & 36 deletions modulus/metrics/general/ensemble_metrics.py
@@ -40,12 +40,12 @@ class EnsembleMetrics(ABC):
def __init__(
self,
input_shape: Union[Tuple[int, ...], List[int]],
device: torch.device = "cpu",
device: Union[str, torch.device] = "cpu",
dtype: torch.dtype = torch.float32,
):
super().__init__()
self.input_shape = list(input_shape)
self.device = device
self.device = torch.device(device)
self.dtype = dtype

if DistributedManager.is_initialized() and not (dist.is_initialized()):
@@ -55,13 +55,13 @@ def __init__(
torch process group, see https://pytorch.org/docs/stable/distributed.html"
)

def _check_shape(self, input: Tensor) -> None:
def _check_shape(self, inputs: Tensor) -> None:
"""
Check that the non-batched dimensions of the input match the expected shape.
"""
assert [i == s for (i, s) in zip(input.shape[1:], self.input_shape)], (
assert [i == s for (i, s) in zip(inputs.shape[1:], self.input_shape)], (
"Expected new input to have compatible shape with existing shapes but got"
+ str(input.shape)
+ str(inputs.shape)
+ "and"
+ str(self.input_shape)
+ "."
@@ -89,7 +89,7 @@ def finalize(self, *args):
def _update_mean(
old_sum: Tensor,
old_n: Union[int, Tensor],
input: Tensor,
inputs: Tensor,
batch_dim: Union[int, None] = 0,
) -> Tuple[Tensor, Union[int, Tensor]]:
"""Updated mean sufficient statistics given new data
@@ -116,11 +116,11 @@ def _update_mean(
Updated (rolling sum, number of samples)
"""
if batch_dim is None:
input = torch.unsqueeze(input, 0)
inputs = torch.unsqueeze(inputs, 0)
batch_dim = 0

new_sum = old_sum + torch.sum(input, dim=batch_dim)
new_n = old_n + input.shape[batch_dim]
new_sum = old_sum + torch.sum(inputs, dim=batch_dim)
new_n = old_n + inputs.shape[batch_dim]

return new_sum, new_n
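To make the sufficient-statistics bookkeeping concrete: the running (sum, n) pair returned above reproduces the mean of all data seen so far. A self-contained sketch with the same semantics (reimplemented here purely for illustration):

```python
import torch

def update_mean(old_sum, old_n, inputs, batch_dim=0):
    # Accumulate the batch sum and the sample count; the mean is sum / n on demand.
    new_sum = old_sum + torch.sum(inputs, dim=batch_dim)
    new_n = old_n + inputs.shape[batch_dim]
    return new_sum, new_n

x, y = torch.randn(10, 3), torch.randn(5, 3)
s, n = update_mean(torch.zeros(3), 0, x)
s, n = update_mean(s, n, y)
assert torch.allclose(s / n, torch.cat([x, y]).mean(dim=0), atol=1e-6)
```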

@@ -139,23 +139,28 @@ class Mean(EnsembleMetrics):
def __init__(self, input_shape: Union[Tuple, List], **kwargs):
super().__init__(input_shape, **kwargs)
self.sum = torch.zeros(self.input_shape, dtype=self.dtype, device=self.device)
self.n = torch.zeros(1, dtype=torch.int32, device=self.device)
self.n = torch.zeros([1], dtype=torch.int32, device=self.device)

def __call__(self, input: Tensor) -> Tensor:
def __call__(self, inputs: Tensor, dim: int = 0) -> Tensor:
"""Calculate an initial mean

Parameters
----------
input : Tensor
inputs : Tensor
Input data
dim : int
Dimension of batched data

Returns
-------
Tensor
Mean value
"""
self.sum = torch.sum(input, dim=0)
self.n = torch.as_tensor(input.shape[0])
assert (
inputs.device == self.device
), f"Input device, {inputs.device}, and Module device, {self.device}, must be the same."
self.sum = torch.sum(inputs, dim=dim)
self.n = torch.as_tensor([inputs.shape[dim]], device=self.device)
# TODO(Dallas) Move distributed calls into finalize.

if DistributedManager.is_initialized() and dist.is_initialized():
@@ -164,29 +169,41 @@ def __call__(self, input: Tensor) -> Tensor:

return self.sum / self.n

def update(self, input: Tensor) -> Tensor:
def update(self, inputs: Tensor, dim: int = 0) -> Tensor:
"""Update current mean and essential statistics with new data

Parameters
----------
input : Tensor
Input tensor
inputs : Tensor
Input tensor
dim : int
Dimension of batched data

Returns
-------
Tensor
Current mean value
"""
self._check_shape(input)
self._check_shape(inputs)
assert (
inputs.device == self.device
), f"Inputs device, {inputs.device}, and Module device, {self.device}, must be the same."

# TODO(Dallas) Move distributed calls into finalize.
if DistributedManager.is_initialized() and dist.is_initialized():
sums, n = _update_mean(self.sum, self.n, input, batch_dim=0)
# Collect local sums, n
sums = torch.sum(inputs, dim=dim)
n = torch.as_tensor([inputs.shape[dim]], device=self.device)

# Reduce
dist.all_reduce(sums, op=dist.ReduceOp.SUM)
dist.all_reduce(n, op=dist.ReduceOp.SUM)

# Update
self.sum += sums
self.n += n
else:
self.sum, self.n = _update_mean(self.sum, self.n, input, batch_dim=0)
self.sum, self.n = _update_mean(self.sum, self.n, inputs, batch_dim=dim)
return self.sum / self.n

def finalize(
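The distributed branch above is the usual all-reduce pattern: each rank reduces its local (sum, n) across the process group before folding the totals into the running statistics, so the result matches the single-process mean. A single-process usage sketch, mirroring the updated tests further down:

```python
import torch
import modulus.metrics.general.ensemble_metrics as em

x = torch.randn(10, 1, 72, 144)
y = torch.randn(5, 1, 72, 144)

M = em.Mean((1, 72, 144), device="cpu")
mean_x = M(x)          # initial mean over the batch dimension (dim=0 by default)
mean_xy = M.update(y)  # running mean over all 15 samples
assert torch.allclose(mean_xy, torch.cat([x, y]).mean(dim=0), atol=1e-5)
```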
@@ -208,7 +225,7 @@ def _update_var(
old_sum: Tensor,
old_sum2: Tensor,
old_n: Union[int, Tensor],
input: Tensor,
inputs: Tensor,
batch_dim: Union[int, None] = 0,
) -> Tuple[Tensor, Tensor, Union[int, Tensor]]:
"""Updated variance sufficient statistics given new data
Expand All @@ -224,7 +241,7 @@ def _update_var(
Current, or old, running squared sum
old_n : Union[int, Tensor]
Current, or old, number of samples
input : Tensor
inputs : Tensor
New input to add to current/old sum. May be batched, in which case the batched
dimension must be flagged by passing an int to batch_dim.
batch_dim : Union[int, None], optional
Expand All @@ -245,12 +262,12 @@ def _update_var(
"""

if batch_dim is None:
input = torch.unsqueeze(input, 0)
inputs = torch.unsqueeze(inputs, 0)
batch_dim = 0

temp_n = input.shape[batch_dim]
temp_sum = torch.sum(input, dim=batch_dim)
temp_sum2 = torch.sum((input - temp_sum / temp_n) ** 2, dim=batch_dim)
temp_n = inputs.shape[batch_dim]
temp_sum = torch.sum(inputs, dim=batch_dim)
temp_sum2 = torch.sum((inputs - temp_sum / temp_n) ** 2, dim=batch_dim)

delta = old_sum * temp_n / old_n - temp_sum
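The delta computed above is a scaled difference between the old running mean and the new batch mean, which feeds the standard pairwise variance combination (Chan et al.): for disjoint batches A and B, M2 = M2_A + M2_B + delta^2 * n_A * n_B / (n_A + n_B) with delta = mean_B - mean_A. A minimal sketch of that merge rule (an illustrative reimplementation, not the hidden remainder of _update_var):

```python
import torch

def merge_var_stats(sum_a, m2_a, n_a, sum_b, m2_b, n_b):
    # Pairwise update: combine (sum, sum of squared deviations, count) of two batches
    n = n_a + n_b
    delta = sum_b / n_b - sum_a / n_a  # difference of the two batch means
    m2 = m2_a + m2_b + delta**2 * n_a * n_b / n
    return sum_a + sum_b, m2, n

a, b = torch.randn(8, 4), torch.randn(12, 4)
s, m2, n = merge_var_stats(
    a.sum(0), ((a - a.mean(0)) ** 2).sum(0), 8,
    b.sum(0), ((b - b.mean(0)) ** 2).sum(0), 12,
)
assert torch.allclose(m2 / (n - 1), torch.cat([a, b]).var(dim=0), atol=1e-5)
```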

@@ -282,35 +299,41 @@ class Variance(EnsembleMetrics):

def __init__(self, input_shape: Union[Tuple, List], **kwargs):
super().__init__(input_shape, **kwargs)
self.n = torch.zeros(1, dtype=torch.int32, device=self.device)
self.n = torch.zeros([1], dtype=torch.int32, device=self.device)
self.sum = torch.zeros(self.input_shape, dtype=self.dtype, device=self.device)
self.sum2 = torch.zeros(self.input_shape, dtype=self.dtype, device=self.device)

def __call__(self, inputs: Tensor) -> Tensor:
def __call__(self, inputs: Tensor, dim: int = 0) -> Tensor:
"""Calculate an initial variance

Parameters
----------
input : Tensor
inputs : Tensor
Input data
dim : int
Dimension of batched data

Returns
-------
Tensor
Unbiased variance values
"""
self.sum = torch.sum(inputs, dim=0)
self.n = torch.as_tensor(inputs.shape[0])
# TODO(Dallas) Move distributed calls into finalize.

assert (
inputs.device == self.device
), f"Input device, {inputs.device}, and Module device, {self.device}, must be the same."
self.sum = torch.sum(inputs, dim=dim)
self.n = torch.as_tensor([inputs.shape[dim]], device=self.device)

if DistributedManager.is_initialized() and dist.is_initialized():
# Compute mean and send around.
dist.all_reduce(self.sum, op=dist.ReduceOp.SUM)
dist.all_reduce(self.n, op=dist.ReduceOp.SUM)

self.sum2 = torch.sum((inputs - self.sum / self.n) ** 2, dim=0)
self.sum2 = torch.sum((inputs - self.sum / self.n) ** 2, dim=dim)
dist.all_reduce(self.sum2, op=dist.ReduceOp.SUM)
else:
self.sum2 = torch.sum((inputs - self.sum / self.n) ** 2, dim=0)
self.sum2 = torch.sum((inputs - self.sum / self.n) ** 2, dim=dim)

if self.n < 2.0:
return self.sum2
@@ -322,16 +345,21 @@ def update(self, inputs: Tensor) -> Tensor:

Parameters
----------
input : Tensor
inputs : Tensor
Input data

Returns
-------
Tensor
Unbiased variance tensor
"""

self._check_shape(inputs)
new_n = torch.as_tensor(inputs.shape[0])
assert (
inputs.device == self.device
), f"Input device, {inputs.device}, and Module device, {self.device}, must be the same."

new_n = torch.as_tensor([inputs.shape[0]], device=self.device)
new_sum = torch.sum(inputs, dim=0)
# TODO(Dallas) Move distributed calls into finalize.
if DistributedManager.is_initialized() and dist.is_initialized():
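Putting the Variance pieces together, a single-process usage sketch consistent with the updated tests below (device keyword as introduced in this PR; CPU assumed):

```python
import torch
import modulus.metrics.general.ensemble_metrics as em

x = torch.randn(10, 1, 72, 144)
y = torch.randn(5, 1, 72, 144)

V = em.Variance((1, 72, 144), device="cpu")
var_x = V(x)          # unbiased variance of the first batch
var_xy = V.update(y)  # running unbiased variance over all 15 samples
assert torch.allclose(var_xy, torch.cat([x, y]).var(dim=0), atol=1e-3)
```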
4 changes: 2 additions & 2 deletions test/metrics/test_metrics_general.py
@@ -330,7 +330,7 @@ def test_means_var(device, rtol: float = 1e-3, atol: float = 1e-3):
x = torch.randn((10, 1, 72, 144), device=device)
y = torch.randn((5, 1, 72, 144), device=device)

M = em.Mean((1, 72, 144))
M = em.Mean((1, 72, 144), device=device)
meanx = M(x)
assert torch.allclose(meanx, torch.mean(x, dim=0))
meanxy = M.update(y)
@@ -347,7 +347,7 @@ def test_means_var(device, rtol: float = 1e-3, atol: float = 1e-3):
_sumxy, _n = em._update_mean(_sumxy, _n, y[1:], batch_dim=0)
assert torch.allclose(meanxy, _sumxy / _n, rtol=rtol, atol=atol)

V = em.Variance((1, 72, 144))
V = em.Variance((1, 72, 144), device=device)
varx = V(x)
assert torch.allclose(varx, torch.var(x, dim=0))
varxy = V.update(y)
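Since the PR adds a dim argument to __call__, a natural (hypothetical, not part of this diff) follow-up assertion for this test would exercise it explicitly:

```python
# Hypothetical extension of test_means_var: pass the new dim argument explicitly.
M2 = em.Mean((1, 72, 144), device=device)
assert torch.allclose(M2(x, dim=0), torch.mean(x, dim=0))

V2 = em.Variance((1, 72, 144), device=device)
assert torch.allclose(V2(x, dim=0), torch.var(x, dim=0))
```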