In [6]:
import torch
import torchvision.models as models

from bnn import BConfig, prepare_binary_model
# Import a few examples of quantizers
from bnn.ops import BasicInputBinarizer, BasicScaleBinarizer, XNORWeightBinarizer

model = models.resnet18(weights='ResNet18_Weights.DEFAULT')
bconfig = BConfig(
    activation_pre_process = BasicInputBinarizer,
    activation_post_process = BasicScaleBinarizer,
    # optionally, one can pass certain custom variables
    weight_pre_process = XNORWeightBinarizer.with_args(center_weights=True)
)
bmodel = prepare_binary_model(model, bconfig)

from torchsummary import summary
summary(bmodel, (3, 32, 32), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
BasicInputBinarizer-1            [-1, 3, 32, 32]               0
XNORWeightBinarizer-2              [-1, 3, 7, 7]               0
BasicScaleBinarizer-3           [-1, 64, 16, 16]               0
            Conv2d-4           [-1, 64, 16, 16]           9,408
       BatchNorm2d-5           [-1, 64, 16, 16]             128
              ReLU-6           [-1, 64, 16, 16]               0
         MaxPool2d-7             [-1, 64, 8, 8]               0
BasicInputBinarizer-8             [-1, 64, 8, 8]               0
XNORWeightBinarizer-9             [-1, 64, 3, 3]               0
BasicScaleBinarizer-10             [-1, 64, 8, 8]               0
           Conv2d-11             [-1, 64, 8, 8]          36,864
      BatchNorm2d-12             [-1, 64, 8, 8]             128
             ReLU-13             [-1, 64, 8, 8]               0
BasicInputBinarizer-14          

In [11]:
from pthflops import count_ops

device = 'cuda:0' if torch.cuda.is_available() else "cpu"
inp = torch.rand(1, 3, 224, 224).to(device)

all_ops, all_data = count_ops(bmodel, inp, ignore_layers=['conv1'])
flops, bops = 0, 0
for op_name, ops_count in all_data.items():
    if 'Conv2d' in op_names and 'onnx::' not in op_name:
        bops += ops_count
    else:
        flops += ops_count
        
print('Total number of FLOPs: {}', flops)
print('Total number of BOPs: {}', bops)

TraceError: symbolically traced variables cannot be used as inputs to control flow