This repository has been archived by the owner on May 1, 2023. It is now read-only.

Quantize my model question? #360

Closed
BlossomingL opened this issue Aug 23, 2019 · 13 comments
Comments

@BlossomingL

BlossomingL commented Aug 23, 2019

Hi~, I implemented post-training quantization of my model like this:

import torch
import distiller
from distiller import apputils
from distiller.quantization import PostTrainLinearQuantizer

quantizer = PostTrainLinearQuantizer(model)
quantizer.prepare_model(torch.rand(*your_input_shape))
apputils.save_checkpoint(0, 'mymodel', model, optimizer=None, name='model', dir='quantization')

But the saved model did not shrink to 1/4 of its original size; it did not change at all. Could you please help me? Thanks!
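
A minimal sketch of one way to compare the on-disk checkpoint sizes (the paths below are placeholders, not the actual files):

import os

def ckpt_size_mb(path):
    # Report the size of a serialized checkpoint in megabytes.
    return os.path.getsize(path) / (1024 ** 2)

# Placeholder paths -- substitute the original FP32 checkpoint and the one written by save_checkpoint() above.
print('original :', ckpt_size_mb('checkpoints/model_fp32.pth.tar'), 'MB')
print('quantized:', ckpt_size_mb('quantization/model_checkpoint.pth.tar'), 'MB')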

@nzmora
Contributor

nzmora commented Aug 23, 2019

Hi @Linxxxx ,

Please see our FAQ.
Cheers
Neta

@BlossomingL
Author

@nzmora Thanks for your quick reply! But I don't understand why the resnet20-cifar example can output a model half the original size (and why the 8-bit and 16-bit checkpoints are the same size?).

@nzmora
Contributor

nzmora commented Aug 25, 2019

Hi @Linxxxx,
I did not understand your question. Please explain in more detail what you don't understand.
Cheers
Neta

@BlossomingL
Author

BlossomingL commented Aug 26, 2019

@nzmora My question is: is it the case that no quantized model is actually compressed (made smaller on disk) by quantization?

@BlossomingL
Author

BlossomingL commented Aug 26, 2019

@nzmora Hi, I have another question about quantizing my own model. When I pass a random dummy_input to prepare_model, I found this in your code:

self.model.quantizer_metadata["dummy_input"] = dummy_input
if dummy_input is not None:
    summary_graph = distiller.SummaryGraph(self.model, dummy_input)
    self.adjacency_map = summary_graph.adjacency_map(dedicated_modules_only=False)

But 'self.adjacency_map' does not seem to be used anywhere else in the project. What is the purpose of this step?
Thanks!

@nzmora
Contributor

nzmora commented Aug 26, 2019

Hi @Linxxxx

Question 1: correct. As explained in the documentation, to benefit from quantization you need the framework and the HW to support execution of quantized models. PyTorch 1.2 supports quantization, so hopefully this will change soon.
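
As a side note, a minimal sketch of PyTorch's native dynamic quantization (a separate path from Distiller's PostTrainLinearQuantizer, assuming PyTorch >= 1.3), which stores int8 weights for the converted layers and therefore does shrink the serialized state_dict:

import torch
import torch.nn as nn

# A toy FP32 model built only from Linear layers (the module type dynamic quantization converts).
model_fp32 = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))

# Replace the Linear layers with dynamically quantized versions holding int8 weights.
model_int8 = torch.quantization.quantize_dynamic(model_fp32, {nn.Linear}, dtype=torch.qint8)

# The quantized state_dict serializes noticeably smaller, since the Linear weights are packed as int8.
torch.save(model_fp32.state_dict(), 'fp32.pth')
torch.save(model_int8.state_dict(), 'int8.pth')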

Question 2: class SummaryGraph has an adjacency map which "Returns a mapping from each op in the graph to its immediate predecessors and successors". This helps us traverse the data dependencies between the layers in the graph. For example, we might need to know that a Convolution layer is followed by a BatchNorm layer. Using plain Pytorch, it is hard to know this.
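
For example, a minimal sketch of dumping that map for a torchvision ResNet (the exact structure of each entry may differ between Distiller versions, so it is simply printed here):

import torch
import distiller
from torchvision.models import resnet18

model = resnet18()
sg = distiller.SummaryGraph(model, torch.rand(1, 3, 224, 224))
adj_map = sg.adjacency_map(dedicated_modules_only=False)

# Each entry maps an op to its immediate predecessors and successors, which is what
# lets Distiller see, e.g., that a Convolution op feeds a BatchNorm op.
for op_name, entry in adj_map.items():
    if 'conv' in op_name.lower():
        print(op_name, '->', entry)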

Cheers
Neta

@BlossomingL
Author

BlossomingL commented Aug 26, 2019

@nzmora Thanks! For question 1: when I run your resnet20-cifar example, the quantized model is compressed from 2.2 MB (original) to 1.1 MB (INT8). According to what you said, the quantized model should not be compressed, so I don't understand.

@BlossomingL
Author

BlossomingL commented Aug 26, 2019

@nzmora For question 2: when I quantize my own model, I get an error in distiller/summary_graph.py, in the function add_macs_attr:

E0826 15:55:58.187671 140177155864384 summary_graph.py:314] An input to a Convolutional layer is missing shape information (MAC values will be wrong)
E0826 15:55:58.187805 140177155864384 summary_graph.py:316] For details see https://github.com/NervanaSystems/distiller/issues/168
I found the error source code:

            try:
                # MACs = volume(OFM) * (#IFM * K^2) / #Groups
                op['attrs']['MACs'] = int(
                    ofm_vol * SummaryGraph.volume(conv_w) * self.params[conv_in]['shape'][1] / groups)
            except IndexError:
                # Todo: change the method for calculating MACs
                msglogger.error("An input to a Convolutional layer is missing shape information "
                                "(MAC values will be wrong)")
                msglogger.error("For details see https://github.com/NervanaSystems/distiller/issues/168")
                op['attrs']['MACs'] = 0

I debugged the error and found that self.params[conv_in]['shape'][1] fails with "IndexError: tuple index out of range".
In add_macs_attr this is wrapped in a try-except block, so execution continues, but the next function, add_footprint_attr, has no try-except block, so the code crashes with the same "tuple index out of range" error. Do you have any suggestions?

@nzmora
Contributor

nzmora commented Aug 26, 2019

Hi @Linxxxx ,

Please provide your PyTorch model.
Thanks
Neta

@BlossomingL
Author

BlossomingL commented Aug 27, 2019

@nzmora Hello, this is my model:

import copy
import torch
import torch.nn as nn
from torchvision.models.resnet import resnet50, Bottleneck
from opt import opt

num_classes = opt.num_classes  # change this depending on your dataset


class MGN(nn.Module):
    def __init__(self):
        super(MGN, self).__init__()

        feats = 256
        resnet = resnet50(pretrained=True)

        self.backbone = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3[0],
        )

        res_conv4 = nn.Sequential(*resnet.layer3[1:])

        res_g_conv5 = resnet.layer4

        res_p_conv5 = nn.Sequential(
            Bottleneck(1024, 512, downsample=nn.Sequential(nn.Conv2d(1024, 2048, 1, bias=False), nn.BatchNorm2d(2048))),
            Bottleneck(2048, 512),
            Bottleneck(2048, 512))
        res_p_conv5.load_state_dict(resnet.layer4.state_dict())

        self.p1 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_g_conv5))
        self.p2 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))
        self.p3 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))

        self.maxpool_zg_p1 = nn.MaxPool2d(kernel_size=(12, 4))
        self.maxpool_zg_p2 = nn.MaxPool2d(kernel_size=(24, 8))
        self.maxpool_zg_p3 = nn.MaxPool2d(kernel_size=(24, 8))
        self.maxpool_zp2 = nn.MaxPool2d(kernel_size=(12, 8))
        self.maxpool_zp3 = nn.MaxPool2d(kernel_size=(8, 8))

        self.reduction = nn.Sequential(nn.Conv2d(2048, feats, 1, bias=False), nn.BatchNorm2d(feats), nn.ReLU())

        self._init_reduction(self.reduction)

        self.fc_id_2048_0 = nn.Linear(feats, num_classes)
        self.fc_id_2048_1 = nn.Linear(feats, num_classes)
        self.fc_id_2048_2 = nn.Linear(feats, num_classes)

        self.fc_id_256_1_0 = nn.Linear(feats, num_classes)
        self.fc_id_256_1_1 = nn.Linear(feats, num_classes)
        self.fc_id_256_2_0 = nn.Linear(feats, num_classes)
        self.fc_id_256_2_1 = nn.Linear(feats, num_classes)
        self.fc_id_256_2_2 = nn.Linear(feats, num_classes)

        self._init_fc(self.fc_id_2048_0)
        self._init_fc(self.fc_id_2048_1)
        self._init_fc(self.fc_id_2048_2)

        self._init_fc(self.fc_id_256_1_0)
        self._init_fc(self.fc_id_256_1_1)
        self._init_fc(self.fc_id_256_2_0)
        self._init_fc(self.fc_id_256_2_1)
        self._init_fc(self.fc_id_256_2_2)

    @staticmethod
    def _init_reduction(reduction):
        # conv
        nn.init.kaiming_normal_(reduction[0].weight, mode='fan_in')
        # nn.init.constant_(reduction[0].bias, 0.)

        # bn
        nn.init.normal_(reduction[1].weight, mean=1., std=0.02)
        nn.init.constant_(reduction[1].bias, 0.)

    @staticmethod
    def _init_fc(fc):
        nn.init.kaiming_normal_(fc.weight, mode='fan_out')
        # nn.init.normal_(fc.weight, std=0.001)
        nn.init.constant_(fc.bias, 0.)

    def forward(self, x):
        x = self.backbone(x)

        p1 = self.p1(x)
        p2 = self.p2(x)
        p3 = self.p3(x)

        zg_p1 = self.maxpool_zg_p1(p1)
        zg_p2 = self.maxpool_zg_p2(p2)
        zg_p3 = self.maxpool_zg_p3(p3)

        zp2 = self.maxpool_zp2(p2)
        z0_p2 = zp2[:, :, 0:1, :]
        z1_p2 = zp2[:, :, 1:2, :]

        zp3 = self.maxpool_zp3(p3)
        z0_p3 = zp3[:, :, 0:1, :]
        z1_p3 = zp3[:, :, 1:2, :]
        z2_p3 = zp3[:, :, 2:3, :]

        fg_p1 = self.reduction(zg_p1).squeeze(dim=3).squeeze(dim=2)
        fg_p2 = self.reduction(zg_p2).squeeze(dim=3).squeeze(dim=2)
        fg_p3 = self.reduction(zg_p3).squeeze(dim=3).squeeze(dim=2)
        f0_p2 = self.reduction(z0_p2).squeeze(dim=3).squeeze(dim=2)
        f1_p2 = self.reduction(z1_p2).squeeze(dim=3).squeeze(dim=2)
        f0_p3 = self.reduction(z0_p3).squeeze(dim=3).squeeze(dim=2)
        f1_p3 = self.reduction(z1_p3).squeeze(dim=3).squeeze(dim=2)
        f2_p3 = self.reduction(z2_p3).squeeze(dim=3).squeeze(dim=2)

        l_p1 = self.fc_id_2048_0(fg_p1)
        l_p2 = self.fc_id_2048_1(fg_p2)
        l_p3 = self.fc_id_2048_2(fg_p3)

        l0_p2 = self.fc_id_256_1_0(f0_p2)
        l1_p2 = self.fc_id_256_1_1(f1_p2)
        l0_p3 = self.fc_id_256_2_0(f0_p3)
        l1_p3 = self.fc_id_256_2_1(f1_p3)
        l2_p3 = self.fc_id_256_2_2(f2_p3)

        predict = torch.cat([fg_p1, fg_p2, fg_p3, f0_p2, f1_p2, f0_p3, f1_p3, f2_p3], dim=1)

        return predict, fg_p1, fg_p2, fg_p3, l_p1, l_p2, l_p3, l0_p2, l1_p2, l0_p3, l1_p3, l2_p3
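
For reference, a minimal sketch of feeding this model to the quantizer from the first post, assuming a 1x3x384x128 dummy input (the MaxPool2d kernel sizes above imply a 384x128 input for the ResNet-50 backbone):

import torch
from distiller.quantization import PostTrainLinearQuantizer

model = MGN()
model.eval()  # post-training quantization is applied to a model in eval mode

quantizer = PostTrainLinearQuantizer(model)
# 1x3x384x128 is an assumption derived from the pooling kernel sizes (12x4 at stride 32, 24x8 at stride 16)
quantizer.prepare_model(torch.rand(1, 3, 384, 128))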

@nzmora
Contributor

nzmora commented Aug 27, 2019

Hi @Linxxxx

Please try adding a try-except block in add_footprint_attr as you suggested and tell me if this helps; then I'll add it to the code. You are correct that it should be there, but I'd like to know whether it solves your issue.
Thanks
Neta

@BlossomingL
Author

BlossomingL commented Aug 27, 2019

@nzmora Hi~ It does work, but does it have any influence on the quantization results? This is my revised code:
try:
    n_ifm = self.param_shape(conv_in)[1]
    n_ofm = self.param_shape(conv_out)[1]
    weights_vol = kernel_size * n_ifm * n_ofm / group
    op['attrs']['n_ifm'] = n_ifm
    op['attrs']['n_ofm'] = n_ofm
    op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol
    op['attrs']['fm_vol'] = ofm_vol + ifm_vol
    op['attrs']['weights_vol'] = weights_vol
except IndexError:
    # Shape information is missing, so fall back to zero values (same pattern as add_macs_attr)
    msglogger.error('{} {} did not exist'.format(conv_in, conv_out))
    op['attrs']['n_ifm'] = 0
    op['attrs']['n_ofm'] = 0
    op['attrs']['footprint'] = ofm_vol + ifm_vol
    op['attrs']['fm_vol'] = ofm_vol + ifm_vol
    op['attrs']['weights_vol'] = 0

nzmora added a commit that referenced this issue Oct 25, 2019
Add try/except block around code accessing missing convolution
shape information.
@nzmora
Contributor

nzmora commented Oct 25, 2019

Please re-open if you need further information or assistance.

@nzmora nzmora closed this as completed Oct 25, 2019
michaelbeale-IL pushed a commit that referenced this issue Apr 24, 2023
Add try/except block around code accessing missing convolution
shape information.