In [3]:
# TODO: Remove

from typing import Tuple

import torch
from torch import Tensor
from torch.autograd import Function

from brevitas.function.ops import binary_sign
from brevitas.function.ops import dpu_round
from brevitas.function.ops import round_to_zero
from brevitas.function.ops import tensor_clamp
from brevitas.function.ops import tensor_clamp_

class ScalarClampMinSteFn(Function):
    """
    Autograd function that implements ``torch.clamp_min`` with a straight-through gradient estimator
    for the gradient of y w.r.t. x, while the gradient of y w.r.t. ``min_val`` is always ``None``.

    ``ScalarClampMinSteFn.apply(*args)`` is first aliased to :func:`scalar_clamp_min_ste_impl(*args)
    <brevitas.ops.autograd_ste_ops.scalar_clamp_min_ste_impl>` and then wrapped by
    :func:`~brevitas.function.ops_ste.scalar_clamp_min_ste` and invoked when env ``BREVITAS_JIT=0``.
    See :func:`~brevitas.function.ops_ste.scalar_clamp_min_ste` for details on the interface and
    examples.
    """

    @staticmethod
    def forward(ctx, x: Tensor, min_val: float) -> Tensor:
        """Return ``torch.clamp_min(x, min_val)``.

        Args:
            ctx: autograd context; unused, since the STE backward needs no saved tensors.
            x: input tensor.
            min_val: lower bound, forwarded unchanged to ``torch.clamp_min`` (so a 0-d tensor
                works as well as a Python float).

        Returns:
            ``x`` with every element below ``min_val`` replaced by ``min_val``.
        """
        y = torch.clamp_min(x, min_val)
        return y

    @staticmethod
    def backward(ctx, grad_y: Tensor) -> Tuple[Tensor, None]:
        """Straight-through estimator: ``grad_y`` flows to ``x`` unchanged, even where clamped.

        ``min_val`` never receives a gradient.
        """
        return grad_y, None

    @staticmethod
    def symbolic(g, x: Tensor, min_val: float):
        """Export to ONNX as a ``Clip`` node with ``min_val`` as its min input and no max input."""
        # torch.as_tensor rather than torch.tensor: when min_val is already a Tensor,
        # torch.tensor(min_val) copy-constructs it and emits a UserWarning.
        y = g.op('Clip', x, torch.as_tensor(min_val))
        return y

In [20]:
import numpy as np

# Explicit imports instead of `import *`, so it is clear which names come from brevitas.
from brevitas.function.ops_ste import abs_binary_sign_grad
from brevitas.function.ops_ste import scalar_clamp_min_ste

scalar_clamp_min_ste_impl = ScalarClampMinSteFn.apply

EPS = 1.
min_val = torch.tensor(EPS)

# Signed clamp-min built from the brevitas helpers: clamp |x| from below, then restore x's sign.
x_signed = torch.tensor(-2.)
x_signed.requires_grad_(True)
out = torch.copysign(scalar_clamp_min_ste(abs_binary_sign_grad(x_signed), min_val), x_signed)
out.backward(torch.tensor(1.))
print(x_signed.grad)

# Gradient of clamp_min(x)**2 through the straight-through clamp, swept over [-1, 1].
# Each iteration builds a fresh leaf tensor and graph, so no grad reset / retain_graph is needed.
for x_value in np.linspace(-1, 1, 20):
    x_in = torch.tensor(x_value)
    x_in.requires_grad_(True)
    out = scalar_clamp_min_ste_impl(x_in, min_val)
    out = out * out
    out.backward(torch.tensor(1.))
    print(x_in.grad)



tensor(1.)
tensor(2., dtype=torch.float64)
tensor(2., dtype=torch.float64)
tensor(2., dtype=torch.float64)
tensor(2., dtype=torch.float64)
tensor(2., dtype=torch.float64)
tensor(2., dtype=torch.float64)
tensor(2., dtype=torch.float64)
tensor(2., dtype=torch.float64)
tensor(2., dtype=torch.float64)
tensor(2., dtype=torch.float64)
tensor(2., dtype=torch.float64)
tensor(2., dtype=torch.float64)
tensor(2., dtype=torch.float64)
tensor(2., dtype=torch.float64)
tensor(2., dtype=torch.float64)
tensor(2., dtype=torch.float64)
tensor(2., dtype=torch.float64)
tensor(2., dtype=torch.float64)
tensor(2., dtype=torch.float64)
tensor(2., dtype=torch.float64)


In [5]:
import brevitas
import numpy as np

scalar_clamp_min_ste_impl = ScalarClampMinSteFn.apply


class BrevitasClamp(brevitas.jit.ScriptModule):
    """Signed clamp-min: clamps ``|x|`` from below to ``min_val`` and keeps the sign of ``x``.

    The magnitude is clamped through ``ScalarClampMinSteFn`` (straight-through gradient).
    NOTE(review): ``forward`` calls a custom autograd ``Function.apply`` directly, which is
    presumably only usable with ``BREVITAS_JIT=0`` — confirm before scripting this module.
    """

    @brevitas.jit.script_method
    def forward(self, x: Tensor, min_val: float):
        magnitude = scalar_clamp_min_ste_impl(x.abs(), min_val)
        return magnitude.copysign(x)


from brevitas.core.function_wrapper.ops_ste import ScalarSignedClampMinSte

# Lower bound on |x| for the signed clamp. min_val is derived from EPS here: previously EPS was
# assigned but unused, and the cell silently reused whatever `min_val` another cell had left in
# the kernel, so Restart & Run All did not reproduce the recorded output.
EPS = 0.
min_val = EPS
signed_clamp = ScalarSignedClampMinSte(min_val)


def report_signed_clamp(x_value) -> None:
    """Print ``x``, ``signed_clamp(x)`` and d(out)/dx for the scalar input ``x_value``.

    A fresh leaf tensor is created on every call, so no gradient reset or retain_graph is needed.
    """
    x = torch.tensor(x_value)
    x.requires_grad_(True)
    out = signed_clamp(x)
    out.backward(torch.tensor(1.))
    print(f"{x.item():.2f}, {out.item():.2f}, {x.grad.item():.2f}")


# x = 0 first as an edge case, then a sweep around zero.
report_signed_clamp(0.)
for x_value in np.linspace(-0.5, 0.5, 20):
    report_signed_clamp(x_value)

0.00, 0.00, 0.00
-0.50, -0.50, 1.00
-0.45, -0.45, 1.00
-0.39, -0.39, 1.00
-0.34, -0.34, 1.00
-0.29, -0.29, 1.00
-0.24, -0.24, 1.00
-0.18, -0.18, 1.00
-0.13, -0.13, 1.00
-0.08, -0.08, 1.00
-0.03, -0.03, 1.00
0.03, 0.03, 1.00
0.08, 0.08, 1.00
0.13, 0.13, 1.00
0.18, 0.18, 1.00
0.24, 0.24, 1.00
0.29, 0.29, 1.00
0.34, 0.34, 1.00
0.39, 0.39, 1.00
0.45, 0.45, 1.00
0.50, 0.50, 1.00
