added mse_loss #218

Merged (41 commits) on Apr 25, 2024
Changes from 25 commits

Commits (41):
3b8ccca  feat: rough mse_loss implementation (k223kim, Apr 16, 2024)
62cb80a  feat: added is_cpu_scalar_tensor for maybe_broadcast (k223kim, Apr 18, 2024)
8b4a503  feat: added mse_loss in torch/__init__.py (k223kim, Apr 18, 2024)
0ef9ac6  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 18, 2024)
e859a1c  feat: updated preserve_cpu_scalar_tensors to be False by default (k223kim, Apr 18, 2024)
296c9e1  feat: updated torchex with mse_loss (k223kim, Apr 18, 2024)
a41305c  feat: pre-commit (k223kim, Apr 18, 2024)
5f5deee  feat: added test cases for mse_loss (k223kim, Apr 18, 2024)
af5635b  feat: pre-commit (k223kim, Apr 18, 2024)
831f2f1  feat: updated test code for mse_loss (Apr 18, 2024)
26e7d30  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 18, 2024)
1906e9a  feat: updated test code for mse_loss (Apr 18, 2024)
8221f54  feat: added decorateinfo for mse_loss (k223kim, Apr 22, 2024)
9050609  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 22, 2024)
1a2c714  feat: updated utilities for mse_loss (k223kim, Apr 22, 2024)
9cd96cd  feat: updated mse_loss implementation with additional test case (k223kim, Apr 23, 2024)
bd3112d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 23, 2024)
7efd9f0  feat: updated mse_loss implementation and broadcast_tensors (k223kim, Apr 23, 2024)
42c7a4a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 23, 2024)
b5b7eec  Merge branch 'main' into k223kim/add_mse_loss (k223kim, Apr 23, 2024)
c012350  feat: commented grad related code (k223kim, Apr 23, 2024)
a296742  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 23, 2024)
75945ce  feat: add mse_loss to vjp_op_force (k223kim, Apr 23, 2024)
fccf612  feat: removed broadcasting (k223kim, Apr 24, 2024)
7867fa5  Merge branch 'main' into k223kim/add_mse_loss (k223kim, Apr 24, 2024)
0721496  feat: implemented nan_to_num with test case (Apr 22, 2024)
2c464f2  feat: updated test case for nan_to_num to test nans and floats properly (k223kim, Apr 22, 2024)
c243d11  Revert "feat: updated test case for nan_to_num to test nans and float… (k223kim, Apr 22, 2024)
43a6a32  feat: updated test case for nan_to_num to test nan and inf properly (k223kim, Apr 22, 2024)
88cf743  feat: small fix in test case (k223kim, Apr 23, 2024)
fc4e4c3  feat: updated nan_to_num to handle overflow (k223kim, Apr 23, 2024)
a73275f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 23, 2024)
0886096  feat: updated to resemble a.clone() (k223kim, Apr 23, 2024)
3a6dd37  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 23, 2024)
4b3dda9  feat: quick fix (Apr 23, 2024)
44793d5  feat: added docstring for nan_to_num with additional test cases (k223kim, Apr 23, 2024)
50d3484  feat: quick fix2 (k223kim, Apr 24, 2024)
143801c  feat: removed comments and unwanted code (k223kim, Apr 24, 2024)
cb3c947  fix: removed non mse_loss related code (k223kim, Apr 24, 2024)
e849b94  feat: added custom_comparator for failed test case (k223kim, Apr 24, 2024)
e99badd  Merge branch 'main' into k223kim/add_mse_loss (k223kim, Apr 24, 2024)
9 changes: 8 additions & 1 deletion thunder/clang/__init__.py
@@ -1213,14 +1213,21 @@ def matrix_transpose(a: TensorProxy) -> TensorProxy:
# TODO: add scalar support
# TODO: review hasattr pattern
@clangop()
def maybe_broadcast(*args):
def maybe_broadcast(*args, preserve_cpu_scalar_tensors: bool = False):
    # torch sets preserve_cpu_scalar_tensors=True by default,
    # but that currently raises errors in several places, so it is set to False here.
    # Once this implementation has been confirmed, it can be set back to True
    # and the places where the errors occur can be fixed.
    """Returns tensors with the same shape, possibly broadcasting inputs to the result shape."""

    # Computes common shape
    common_shape = compute_broadcast_shape(*map(lambda t: t.shape if hasattr(t, "shape") else None, args))

    def _maybe_broadcast(x, shape):
        if hasattr(x, "shape"):
            if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x):
                return x

            if not utils.same_shape(x.shape, common_shape):
                return expand(x, common_shape)

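Note: the new preserve_cpu_scalar_tensors flag leaves zero-dimensional CPU tensors unexpanded instead of broadcasting them to the common shape. A minimal sketch of the intended semantics using plain PyTorch tensors (the helper name _maybe_broadcast_sketch and the use of torch.broadcast_shapes are illustrative, not Thunder's internals):

import torch

def _maybe_broadcast_sketch(*args, preserve_cpu_scalar_tensors: bool = False):
    # Compute the common broadcast shape from every tensor argument.
    shapes = [t.shape for t in args if isinstance(t, torch.Tensor)]
    common_shape = torch.broadcast_shapes(*shapes)

    def _expand_one(x):
        if not isinstance(x, torch.Tensor):
            return x  # non-tensor (number) arguments pass through unchanged
        if preserve_cpu_scalar_tensors and x.ndim == 0 and x.device.type == "cpu":
            return x  # keep CPU scalar tensors as 0-dim tensors
        return x if x.shape == common_shape else x.expand(common_shape)

    return tuple(_expand_one(x) for x in args)

a = torch.ones(2, 3)
s = torch.tensor(2.0)  # 0-dim CPU scalar tensor
_, kept = _maybe_broadcast_sketch(a, s, preserve_cpu_scalar_tensors=True)
_, expanded = _maybe_broadcast_sketch(a, s, preserve_cpu_scalar_tensors=False)
assert kept.shape == () and expanded.shape == (2, 3)
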
4 changes: 4 additions & 0 deletions thunder/core/utils.py
@@ -635,6 +635,10 @@ def _reify(x):
#


def is_cpu_scalar_tensor(a: TensorProxy) -> bool:
    return a.ndim == 0 and a.device.devicetype == devices.DeviceType.CPU


# TODO: improve device handling by canonicalizing devices and expressing them per langctx
# TODO: should the comparison between devices be ==?
def check_same_device(*args):
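The new is_cpu_scalar_tensor helper treats a tensor as a CPU scalar tensor when it is zero-dimensional and lives on the CPU device. The equivalent check on plain PyTorch tensors would look roughly like this (the function name is illustrative):

import torch

def is_cpu_scalar_tensor_sketch(a: torch.Tensor) -> bool:
    # Zero-dimensional (a scalar) and located on the CPU device.
    return a.ndim == 0 and a.device.type == "cpu"

assert is_cpu_scalar_tensor_sketch(torch.tensor(1.0))        # 0-dim CPU tensor
assert not is_cpu_scalar_tensor_sketch(torch.tensor([1.0]))  # shape (1,), not 0-dim
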
30 changes: 30 additions & 0 deletions thunder/executors/torchex.py
@@ -1198,6 +1198,7 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim:
conv1d = _register_torch_operation("conv1d", module=torch.nn.functional)
conv2d = _register_torch_operation("conv2d", module=torch.nn.functional)
conv3d = _register_torch_operation("conv3d", module=torch.nn.functional)
mse_loss = _register_torch_operation("mse_loss", module=torch.nn.functional)
cross_entropy = _register_torch_operation("cross_entropy", module=torch.nn.functional)
dropout = _register_torch_operation("dropout", module=torch.nn.functional)
embedding = _register_torch_operation("embedding", module=torch.nn.functional)
@@ -1329,6 +1330,30 @@ def _convolution_transform(
    return convolution(a, weight, bias, stride, padding, dilation, bool(transposed), output_padding, groups)


# def _mse_loss_backward_impl(
#     g: torch.Tensor,
#     a: torch.Tensor,
#     target: torch.Tensor,
#     reduction: str,
# ) -> torch.Tensor:

#     if reduction == "none":
#         reduction_idx = 0
#     elif reduction == "mean":
#         reduction_idx = 1
#     elif reduction == "sum":
#         reduction_idx = 2
#     else:
#         reduction_idx = -1

#     utils.check(
#         reduction_idx > -1 and reduction_idx < 3,
#         lambda: f"{reduction} is not a valid value for reduction parameter.",
#     )

#     return torch.ops.aten.mse_loss_backward(g, a, target, reduction_idx)


def _cross_entropy_backward_impl(
    g: torch.Tensor,
    a: torch.Tensor,
@@ -1534,6 +1559,11 @@ def _pad_prim_impl(
_register_implementation(ltorch.conv1d, conv1d, checker=_always_executable)
_register_implementation(ltorch.conv2d, conv2d, checker=_always_executable)
_register_implementation(ltorch.conv3d, conv3d, checker=_always_executable)
_register_implementation(ltorch.mse_loss, mse_loss, checker=_always_executable)
# mse_loss_backward = ex.register_operator(
#     "torch_mse_loss_backward_impl", meta=ltorch.mse_loss_backward, fn=_mse_loss_backward_impl
# )
# _register_implementation(ltorch.mse_loss_backward, mse_loss_backward, checker=_always_executable)
_register_implementation(ltorch.cross_entropy, cross_entropy, checker=_always_executable)
cross_entropy_backward = ex.register_operator(
"torch_cross_entropy_backward_impl", meta=ltorch.cross_entropy_backward, fn=_cross_entropy_backward_impl
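The commented-out backward above maps the reduction string to an integer before calling torch.ops.aten.mse_loss_backward, which takes the reduction as PyTorch's enum value (0 for "none", 1 for "mean", 2 for "sum"). A minimal sketch of that mapping outside Thunder (the function name is illustrative; the aten call mirrors the one in the commented code):

import torch

def mse_loss_backward_sketch(g, a, target, reduction: str = "mean"):
    # aten.mse_loss_backward expects the reduction as an integer enum value.
    reduction_idx = {"none": 0, "mean": 1, "sum": 2}[reduction]
    return torch.ops.aten.mse_loss_backward(g, a, target, reduction_idx)

a = torch.randn(2, 3)
target = torch.randn(2, 3)
g = torch.ones(())  # upstream gradient of a "mean"-reduced scalar loss
grad_a = mse_loss_backward_sketch(g, a, target, "mean")
# For reduction="mean", d/da of mean((a - target) ** 2) is 2 * (a - target) / a.numel()
assert torch.allclose(grad_a, 2 * (a - target) / a.numel())
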
57 changes: 57 additions & 0 deletions thunder/tests/opinfos.py
@@ -7316,6 +7316,63 @@ def nll_loss_error_generator(op, device, dtype=torch.float32, **kwargs):
nn_ops.append(nll_loss_opinfo)


def mse_loss_sample_generator(op, device, dtype, requires_grad, **kwargs):
    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # input_shape, target_shape
    shapes = (
        ((2, 16), (2, 16)),
        ((7, 18), (7, 18)),
        ((3, 4, 2, 3), (3, 4, 2, 3)),
        ((3, 4, 2, 3), (4, 1, 3)),
        ((2, 3, 1), (3, 1)),
    )

    reduction_options = ("none", "mean", "sum")

    for shape, reduction_str in itertools.product(shapes, reduction_options):
        input_shape, target_shape = shape

        C = input_shape[1] if len(input_shape) >= 2 else input_shape[0]
        yield SampleInput(
            make(input_shape, low=0.0, high=1.0, dtype=dtype, requires_grad=True),
            make(target_shape, low=0.0, high=1.0, dtype=dtype, requires_grad=True),
            reduction=reduction_str,
        )


mse_loss_opinfo = OpInfo(
    ltorch.mse_loss,
    # supports_grad=True,
    sample_input_generator=mse_loss_sample_generator,
    torch_reference=torch.nn.functional.mse_loss,
    dtypes=(datatypes.floating,),
    test_directives=(
        # NOTE: PyTorch does not support bf16 mse_loss
        DecorateInfo(
            pytest.mark.skip,
            "test_core_vs_torch_consistency",
            dtypes=(datatypes.bfloat16,),
            devicetypes=(devices.DeviceType.CPU,),
        ),
        # NOTE: currently, mse_loss is encountering the following errors
        # RuntimeError: "mse_cpu" not implemented for 'BFloat16'
        # RuntimeError: "mse_backward_cpu_out" not implemented for 'Half'
        DecorateInfo(
            pytest.mark.skip,
            "test_phantom_grad_vs_torch_consistency",
            dtypes=(
                datatypes.bfloat16,
                datatypes.float16,
            ),
            devicetypes=(devices.DeviceType.CPU,),
        ),
    ),
)

nn_ops.append(mse_loss_opinfo)


def interpolate_sample_generator(op, device, dtype, requires_grad, **kwargs):
    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

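The skip directives in the mse_loss OpInfo above reflect the quoted errors: on CPU, the mse kernels are not implemented for bfloat16/float16 in the PyTorch versions this was tested against. A minimal reproduction sketch, hedged because newer PyTorch builds may support these dtypes:

import torch
import torch.nn.functional as F

a = torch.randn(2, 16, dtype=torch.bfloat16)
t = torch.randn(2, 16, dtype=torch.bfloat16)
try:
    F.mse_loss(a, t)
except RuntimeError as e:
    # On affected builds this raises: RuntimeError: "mse_cpu" not implemented for 'BFloat16'
    print(e)
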
1 change: 1 addition & 0 deletions thunder/tests/test_grad.py
@@ -62,6 +62,7 @@
    "split",
    "stack",
    "cumsum",
    "mse_loss",
}


46 changes: 46 additions & 0 deletions thunder/torch/__init__.py
@@ -3907,6 +3907,52 @@ def nll_loss_backward(
    return TensorProxy(like=g, shape=a.shape)


@torchsymbol(torch.nn.functional.mse_loss)
def mse_loss(
    a: TensorLike,
    /,
    target: TensorLike,
    size_average: None | Any = None,
    reduce: None | Any = None,
    reduction: str = "mean",
) -> TensorLike:
    utils.check(
        size_average is None and reduce is None,
        lambda: f"Deprecated size_average={size_average} and reduce={reduce} are not supported!",
    )
    utils.check(
        reduction in ("none", "sum", "mean"),
        lambda: f'Expected reduction string to be "none", "sum", or "mean", but it is {reduction}.',
        exception_type=ValueError,
    )

    # warn about broadcasting
    if a.size() != target.size():
        warnings.warn(
            f"Using a target size {target.size()} that is different to the input size {a.size()}. "
            "This will likely lead to incorrect results due to broadcasting. "
            "Please ensure they have the same size."
        )
    out = (a - target) ** 2

    # maybe add _apply_loss_reduction
    # (like https://github.com/pytorch/pytorch/blob/df5829d0babaefc6e271897d6fffd40073d8b723/torch/_refs/nn/functional/__init__.py#L490)
    # not sure if this would be useful
    if reduction == "none":
        return out
    elif reduction == "sum":
        return sum(out)
    elif reduction == "mean":
        return mean(out)
    else:
        raise ValueError(f"Reduction argument {reduction} to mse_loss is not supported")


# @torchsymbol("mse_loss_backward", id="mse_loss_backward", is_prim=True)
# def mse_loss_backward(g, a, /, target, reduction):
#     return TensorProxy(like=g, shape=a.shape)


# TODO Add annotations
# NOTE The scale parameter is kwarg-only in PyTorch
@torchsymbol(torch.nn.functional.scaled_dot_product_attention)
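As implemented in mse_loss above, the op computes the elementwise squared difference and then applies the requested reduction, matching torch.nn.functional.mse_loss (which the OpInfo uses as torch_reference). A quick numeric check of the three reduction modes:

import torch
import torch.nn.functional as F

a = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.5, 2.0, 2.0])

out = (a - target) ** 2  # elementwise squared error: [0.25, 0.0, 1.0]
assert torch.allclose(F.mse_loss(a, target, reduction="none"), out)
assert torch.isclose(F.mse_loss(a, target, reduction="sum"), out.sum())    # 1.25
assert torch.isclose(F.mse_loss(a, target, reduction="mean"), out.mean())  # 1.25 / 3, about 0.4167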