Feat (nn): deprecate QuantMaxPool (#858)
capnramses authored Feb 22, 2024
1 parent d3fc994 commit 4e9a643
Showing 8 changed files with 2 additions and 266 deletions.
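
In short, QuantMaxPool1d/QuantMaxPool2d and their ONNX/Torch qoperator export handlers are removed. Because max pooling only selects values and never mixes them, a plain torch.nn.MaxPool2d (or torch.nn.functional.max_pool2d) applied to a quantized activation gives the same result, so no quant-specific layer is needed. A minimal migration sketch, assuming the usual Brevitas behaviour that torch pooling ops accept the QuantTensor produced by a quant activation (module and size values below are illustrative, not taken from this diff):

import torch
from brevitas.nn import QuantIdentity

class PooledModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # Quantized activation producing a QuantTensor, as in the removed tests.
        self.act = QuantIdentity(bit_width=8, return_quant_tensor=True)
        # Before this commit: self.pool = QuantMaxPool2d(kernel_size=3, stride=3)
        self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=3)

    def forward(self, x):
        # Max pooling never mixes values, so no re-quantization is required here.
        return self.pool(self.act(x))

model = PooledModel()
out = model(torch.randn(1, 1, 12, 12))  # collect scale factors / run inference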
42 changes: 0 additions & 42 deletions src/brevitas/export/onnx/standard/qoperator/handler/pool.py

This file was deleted.

8 changes: 2 additions & 6 deletions src/brevitas/export/onnx/standard/qoperator/manager.py
@@ -27,8 +27,6 @@
 from .handler.parameter import StdQOpONNXQuantConv1dHandler
 from .handler.parameter import StdQOpONNXQuantConv2dHandler
 from .handler.parameter import StdQOpONNXQuantLinearHandler
-from .handler.pool import StdQOpONNXQuantMaxPool1d
-from .handler.pool import StdQOpONNXQuantMaxPool2d


 class StdQOpONNXManager(StdONNXBaseManager):
@@ -43,7 +41,7 @@ class StdQOpONNXManager(StdONNXBaseManager):
         F.max_pool3d,
         F.adaptive_max_pool1d,
         F.adaptive_max_pool2d,
-        F.adaptive_max_pool3d,]
+        F.adaptive_max_pool3d]

     handlers = [
         StdQOpONNXQuantConv1dHandler,
@@ -53,9 +51,7 @@ class StdQOpONNXManager(StdONNXBaseManager):
         StdQOpONNXQuantHardTanhHandler,
         StdQOpONNXQuantIdentityHandler,
         StdQOpONNXQuantTanhHandler,
-        StdQOpONNXQuantSigmoidHandler,
-        StdQOpONNXQuantMaxPool1d,
-        StdQOpONNXQuantMaxPool2d]
+        StdQOpONNXQuantSigmoidHandler]

     onnx_passes = [
         # remove unused graph inputs & initializers
60 changes: 0 additions & 60 deletions src/brevitas/export/torch/qoperator/handler/pool.py

This file was deleted.

4 changes: 0 additions & 4 deletions src/brevitas/export/torch/qoperator/manager.py
@@ -20,16 +20,12 @@
 from .handler.parameter import PytorchQuantConv1dHandler
 from .handler.parameter import PytorchQuantConv2dHandler
 from .handler.parameter import PytorchQuantLinearHandler
-from .handler.pool import PytorchQuantMaxPool1d
-from .handler.pool import PytorchQuantMaxPool2d


 class TorchQOpManager(BaseManager):
     target_name = 'torch'

     handlers = [
-        PytorchQuantMaxPool1d,
-        PytorchQuantMaxPool2d,
         PytorchQuantHardTanhHandler,
         PytorchQuantIdentityHandler,
         PytorchQuantReLUHandler,
2 changes: 0 additions & 2 deletions src/brevitas/nn/__init__.py
@@ -21,8 +21,6 @@
 from .quant_eltwise import QuantEltwiseAdd
 from .quant_embedding import QuantEmbedding
 from .quant_linear import QuantLinear
-from .quant_max_pool import QuantMaxPool1d
-from .quant_max_pool import QuantMaxPool2d
 from .quant_mha import QuantMultiheadAttention
 from .quant_rnn import QuantLSTM
 from .quant_rnn import QuantRNN
88 changes: 0 additions & 88 deletions src/brevitas/nn/quant_max_pool.py

This file was deleted.
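
With src/brevitas/nn/quant_max_pool.py deleted, downstream code that still imports QuantMaxPool1d/2d from brevitas.nn fails at import time. One way to keep such code running, sketched here purely as an illustration (not part of this change), is to fall back to plain torch pooling and swallow the quant-only keyword:

try:
    from brevitas.nn import QuantMaxPool2d  # removed by this commit
except ImportError:
    import torch.nn as nn

    class QuantMaxPool2d(nn.MaxPool2d):
        """Illustrative stand-in: plain max pooling, quant-only kwargs ignored."""

        def __init__(self, *args, return_quant_tensor=False, **kwargs):
            super().__init__(*args, **kwargs)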

1 change: 0 additions & 1 deletion tests/brevitas/export/test_qonnx_export.py
@@ -9,7 +9,6 @@
 from brevitas.nn import QuantConv2d
 from brevitas.nn import QuantIdentity
 from brevitas.nn import QuantLinear
-from brevitas.nn import QuantMaxPool2d
 from brevitas.nn import QuantReLU
 from brevitas.nn import TruncAvgPool2d
 from brevitas.quant.scaled_int import Int4WeightPerTensorFloatDecoupled
63 changes: 0 additions & 63 deletions tests/brevitas/export/test_torch_qop.py
@@ -7,7 +7,6 @@
 from brevitas.nn import QuantConv2d
 from brevitas.nn import QuantIdentity
 from brevitas.nn import QuantLinear
-from brevitas.nn import QuantMaxPool2d
 from brevitas.nn import QuantReLU
 from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
 from brevitas.quant.scaled_int import Int16Bias
@@ -203,65 +202,3 @@ def forward(self, x):
     pytorch_out = pytorch_qf_model(inp)
     atol = model.act2.quant_output_scale().item() * TOLERANCE
     assert pytorch_out.isclose(brevitas_out, rtol=0.0, atol=atol).all()
-
-
-@jit_disabled_for_export()
-def test_quant_max_pool2d_export():
-    IN_SIZE = (1, 1, IN_CH, IN_CH)
-    KERNEL_SIZE = 3
-
-    class Model(torch.nn.Module):
-
-        def __init__(self):
-            super().__init__()
-            self.act = QuantIdentity(
-                bit_width=8, act_quant=ShiftedUint8ActPerTensorFloat, return_quant_tensor=True)
-            self.pool = QuantMaxPool2d(
-                kernel_size=KERNEL_SIZE, stride=KERNEL_SIZE, return_quant_tensor=False)
-
-        def forward(self, x):
-            return self.pool(self.act(x))
-
-    inp = torch.randn(IN_SIZE)
-    model = Model()
-    model(inp)  # collect scale factors
-    model.eval()
-    inp = torch.randn(IN_SIZE) * RANDN_STD + RANDN_MEAN  # New input with bigger range
-    brevitas_out = model(inp)
-    pytorch_qf_model = export_torch_qop(model, input_t=inp)
-    pytorch_out = pytorch_qf_model(inp)
-    atol = model.act.quant_output_scale().item() * TOLERANCE
-    assert pytorch_out.isclose(brevitas_out, rtol=0.0, atol=atol).all()
-
-
-@requires_pt_ge('9999', 'Darwin')
-@jit_disabled_for_export()
-def test_func_quant_max_pool2d_export():
-    IN_SIZE = (1, 1, IN_CH, IN_CH)
-    KERNEL_SIZE = 2
-
-    class Model(torch.nn.Module):
-
-        def __init__(self):
-            super().__init__()
-            self.act1 = QuantIdentity(
-                bit_width=8, act_quant=ShiftedUint8ActPerTensorFloat, return_quant_tensor=True)
-            self.act2 = QuantIdentity(
-                bit_width=8, act_quant=ShiftedUint8ActPerTensorFloat, return_quant_tensor=False)
-
-        def forward(self, x):
-            x = self.act1(x)
-            x = torch.nn.functional.max_pool2d(x, KERNEL_SIZE)
-            x = self.act2(x)
-            return x
-
-    inp = torch.randn(IN_SIZE)
-    model = Model()
-    model(inp)  # collect scale factors
-    model.eval()
-    inp = torch.randn(IN_SIZE) * RANDN_STD + RANDN_MEAN  # New input with bigger range
-    brevitas_out = model(inp)
-    pytorch_qf_model = export_torch_qop(model, input_t=inp)
-    pytorch_out = pytorch_qf_model(inp)
-    atol = model.act2.quant_output_scale().item() * TOLERANCE
-    assert pytorch_out.isclose(brevitas_out, rtol=0.0, atol=atol).all()
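
The two deleted tests only verified that exported max pooling matched the Brevitas output. They can be dropped without losing coverage because max pooling commutes with affine quantization: rounding/clamping and max are both monotone, so pooling the integer representation and pooling the dequantized tensor give the same result. A standalone sketch of that identity (values are illustrative, not taken from the test suite):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 8, 8)
scale, zero_point = 0.05, 3.0

# Quantize, pool the integer representation, then dequantize ...
q = torch.clamp(torch.round(x / scale) + zero_point, 0, 255)
pooled_int = (F.max_pool2d(q, 2) - zero_point) * scale

# ... versus dequantizing first and pooling the fake-quantized tensor.
deq = (q - zero_point) * scale
pooled_float = F.max_pool2d(deq, 2)

assert torch.allclose(pooled_int, pooled_float)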
