This restriction also comes from the implementation of `torch.fx` (used to perform the quantization). `torch.fx` translates our `torch.nn.Module` into a `torch.fx.GraphModule`, which is a `torch.nn.Module` instance that holds a `Graph` as well as a `forward` method generated from that `Graph`. During this process, all self-defined class members are discarded, as the sketch and the two cases below illustrate.
- The first case also comes from the cellpose repo. The code snippet is here:

  ```python
  if self.pretrained_model:
      self.net.load_model(self.pretrained_model[0], device=self.device)
  ```

  Here, `self.net` is a `torch.nn.Module` instance and is the target of our quantization function. During the calibration step of quantization, `self.net` is transformed into a `torch.fx.GraphModule`, which causes an error when `self.net.load_model` is later called.
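  A sketch of this failure, assuming FX graph-mode quantization via `prepare_fx` (PyTorch >= 1.13 signature) and a hypothetical `CPnet` stand-in for cellpose's network:

  ```python
  import torch
  import torch.nn as nn
  from torch.ao.quantization import get_default_qconfig_mapping
  from torch.ao.quantization.quantize_fx import prepare_fx

  class CPnet(nn.Module):  # hypothetical stand-in for cellpose's network
      def __init__(self):
          super().__init__()
          self.conv = nn.Conv2d(1, 4, 3, padding=1)

      def forward(self, x):
          return self.conv(x)

      def load_model(self, path, device=None):  # self-defined method
          self.load_state_dict(torch.load(path, map_location=device))

  net = CPnet().eval()
  qconfig_mapping = get_default_qconfig_mapping("fbgemm")
  net = prepare_fx(net, qconfig_mapping, (torch.randn(1, 1, 32, 32),))
  # net is now a torch.fx.GraphModule: net.load_model(...) raises
  # AttributeError, so pretrained weights must be loaded *before* this point.
  ```

  The workaround is to call `load_model` before quantization starts, while `self.net` still has its original class.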
- The second case comes from the Deconoising repo. The code snippet is from here:

  ```python
  def predict(im, net, device, outScaling):
      stdTorch = torch.Tensor(np.array(net.std)).to(device)
      meanTorch = torch.Tensor(np.array(net.mean)).to(device)
      ...
  ```

  In this example, we use this function as our calibration function. However, `net` is a `torch.nn.Module` instance that gets transformed into a `torch.fx.GraphModule`, which causes an error when `net.std` is accessed. We can precompute these values before tracing, as sketched below.
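  A sketch of that precomputation, using a hypothetical `extract_stats` helper and a modified `predict` signature (both names are ours, not Deconoising's):

  ```python
  import numpy as np
  import torch

  # Read the self-defined members once, while `net` is still the original
  # torch.nn.Module, i.e. before calibration replaces it with a GraphModule.
  def extract_stats(net, device):
      stdTorch = torch.Tensor(np.array(net.std)).to(device)
      meanTorch = torch.Tensor(np.array(net.mean)).to(device)
      return stdTorch, meanTorch

  # Modified predict(): receives the precomputed tensors instead of reading
  # net.std / net.mean, so it also works once net is a GraphModule.
  def predict(im, net, device, outScaling, stdTorch, meanTorch):
      ...
  ```

  The traced module is then only ever *called* inside `predict`, never queried for attributes, so calibration proceeds without error.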