Merge pull request #224 from SnowMasaya/clamp
Clamp isolated min/max
jaybdub committed Jan 7, 2020
2 parents d526b24 + 1756aa9 commit e22844a
Showing 1 changed file with 75 additions and 7 deletions.
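Summary of the change: previously convert_clamp unconditionally read min and max from positional arguments (ctx.method_args[1] and ctx.method_args[2]), so converting a module that passes only one bound as a keyword, e.g. torch.clamp(x, max=0.1) or x.clamp(min=-0.1), failed. The converter now inspects ctx.method_kwargs and emits only the elementwise MAX layer (for an isolated min bound), only the elementwise MIN layer (for an isolated max bound), or both. A minimal sketch of a module that now converts (the module, input, and shape below are illustrative, not part of the commit):

    import torch
    from torch2trt import torch2trt

    class ClampMaxOnly(torch.nn.Module):
        def forward(self, x):
            # keyword-only upper bound; this path previously failed in the converter
            return torch.clamp(x, max=0.1)

    # illustrative input, matching the shape used by the tests below
    x = torch.randn(1, 3, 224, 224).cuda()
    model_trt = torch2trt(ClampMaxOnly().eval().cuda(), [x])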
82 changes: 75 additions & 7 deletions torch2trt/converters/clamp.py
@@ -85,18 +85,28 @@ def test_tensor_clamp_max():
 
 
 # CLAMP
 
 
 @tensorrt_converter('torch.clamp')
 @tensorrt_converter('torch.Tensor.clamp')
 def convert_clamp(ctx):
     input = ctx.method_args[0]
-    min_val = ctx.method_args[1]
-    max_val = ctx.method_args[2]
     output = ctx.method_return
 
-    layer = __add_clamp(ctx.network, input._trt, min_val, trt.ElementWiseOperation.MAX)
-    layer = __add_clamp(ctx.network, layer.get_output(0), max_val, trt.ElementWiseOperation.MIN)
+    if "min" in ctx.method_kwargs and "max" in ctx.method_kwargs:
+        min_val = ctx.method_kwargs["min"]
+        max_val = ctx.method_kwargs["max"]
+        layer = __add_clamp(ctx.network, input._trt, min_val, trt.ElementWiseOperation.MAX)
+        layer = __add_clamp(ctx.network, layer.get_output(0), max_val, trt.ElementWiseOperation.MIN)
+    elif "min" in ctx.method_kwargs:
+        min_val = ctx.method_kwargs["min"]
+        layer = __add_clamp(ctx.network, input._trt, min_val, trt.ElementWiseOperation.MAX)
+    elif "max" in ctx.method_kwargs:
+        max_val = ctx.method_kwargs["max"]
+        layer = __add_clamp(ctx.network, input._trt, max_val, trt.ElementWiseOperation.MIN)
+    else:
+        min_val = ctx.method_args[1]
+        max_val = ctx.method_args[2]
+        layer = __add_clamp(ctx.network, input._trt, min_val, trt.ElementWiseOperation.MAX)
+        layer = __add_clamp(ctx.network, layer.get_output(0), max_val, trt.ElementWiseOperation.MIN)
 
     output._trt = layer.get_output(0)

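For intuition: clamping to a lower bound is an elementwise MAX against a constant, and clamping to an upper bound is an elementwise MIN, which is why the min branch above uses trt.ElementWiseOperation.MAX and the max branch uses trt.ElementWiseOperation.MIN. A pure-PyTorch analogue of the same dispatch (illustrative only, not part of the commit):

    import torch

    def clamp_like(x, min=None, max=None):
        # lower bound -> elementwise MAX; upper bound -> elementwise MIN
        if min is not None:
            x = torch.max(x, torch.full_like(x, min))
        if max is not None:
            x = torch.min(x, torch.full_like(x, max))
        return x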
@@ -118,4 +128,62 @@ def forward(self, x):
 
 @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)])
 def test_tensor_clamp():
-    return TensorClamp()
\ No newline at end of file
+    return TensorClamp()
+
+
+class TorchClampOptionMax(torch.nn.Module):
+    def forward(self, x):
+        return torch.clamp(x, max=0.1)
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)])
+def test_torch_clamp_option_max():
+    return TorchClampOptionMax()
+
+class TorchClampOptionMin(torch.nn.Module):
+    def forward(self, x):
+        return torch.clamp(x, min=-0.1)
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)])
+def test_torch_clamp_option_min():
+    return TorchClampOptionMin()
+
+
+class TorchClampOptionMaxMin(torch.nn.Module):
+    def forward(self, x):
+        return torch.clamp(x, min=-0.1, max=0.1)
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)])
+def test_torch_clamp_option_max_min():
+    return TorchClampOptionMaxMin()
+
+
+class TensorClampOptionMax(torch.nn.Module):
+    def forward(self, x):
+        return x.clamp(max=0.1)
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)])
+def test_tensor_clamp_option_max():
+    return TensorClampOptionMax()
+
+class TensorClampOptionMin(torch.nn.Module):
+    def forward(self, x):
+        return x.clamp(min=-0.1)
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)])
+def test_tensor_clamp_option_min():
+    return TensorClampOptionMin()
+
+
+class TensorClampOptionMaxMin(torch.nn.Module):
+    def forward(self, x):
+        return x.clamp(min=-0.1, max=0.1)
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)])
+def test_tensor_clamp_option_max_min():
+    return TensorClampOptionMaxMin()
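These module tests follow the add_module_test pattern used throughout torch2trt and cover every branch of the new dispatch: keyword max only, keyword min only, both keywords, for both torch.clamp and Tensor.clamp. Assuming the repository's test harness behaves as documented in its README, they can be filtered and run with something like python3 -m torch2trt.test --name=clamp (the --name filter here is an assumption, not stated in this commit).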
