Concatenation Layer Issue with PyTorch ResNet #931

clw5710 opened this issue Nov 30, 2023 · 0 comments

clw5710 commented Nov 30, 2023

Quick Summary

Hello, I am having issues using hls4ml to convert a ResNet20 PyTorch model. I believe the problem comes from how the ResNet model is parsed. Specifically, in the topology printed by hls4ml, the batch dimension appears to be included in a torch.zeros tensor that is used by a concatenation layer, which gives that tensor an extra dimension.
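A minimal sketch of where the fixed leading dimension could come from, assuming hls4ml's PyTorch front end works from a torch.fx symbolic trace (the "symbolically traced variables cannot be used as inputs to control flow" message quoted in the code comments below is a torch.fx error). `torch.zeros(128, 16, 16, 16)` does not depend on the traced input, so torch.fx evaluates it eagerly and stores the result as a constant tensor attribute (a `get_attr` node) whose shape still contains the hard-coded 128. The module name `ZeroPadShortcut` is hypothetical.

```python
import torch
import torch.fx
from torch import nn


class ZeroPadShortcut(nn.Module):
    """Hypothetical module reproducing the stack2a shortcut padding."""

    def forward(self, d):
        # Same construction as in block.shortcut: 128 is a hard-coded batch size
        p = torch.zeros(128, 16, 16, 16)
        return torch.cat((d, p), dim=1)


gm = torch.fx.symbolic_trace(ZeroPadShortcut())
for node in gm.graph.nodes:
    print(node.op, node.target)

# The zeros tensor is baked into the graph as a constant attribute
const_names = [n.target for n in gm.graph.nodes if n.op == "get_attr"]
const = getattr(gm, const_names[0])
print(const_names[0], tuple(const.shape))
```

The traced graph has a single placeholder for `d`; the zeros tensor is not derived from it, so nothing in the graph marks its first axis as a batch axis.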

More Details

Topology:

Layer name: conv_in, layer type: Conv2D, input shape: [[None, 3, 32, 32]]
Layer name: bn_in, layer type: BatchNormalization, input shape: [[None, 16, 32, 32]]
Layer name: relu, layer type: Activation, input shape: [[None, 16, 32, 32]]
Layer name: stack1_0_conv1, layer type: Conv2D, input shape: [[None, 16, 32, 32]]
Layer name: stack1_0_bn1, layer type: BatchNormalization, input shape: [[None, 16, 32, 32]]
Layer name: stack1_0_relu1, layer type: Activation, input shape: [[None, 16, 32, 32]]
Layer name: stack1_0_conv2, layer type: Conv2D, input shape: [[None, 16, 32, 32]]
Layer name: stack1_0_bn2, layer type: BatchNormalization, input shape: [[None, 16, 32, 32]]
Layer name: add, layer type: Merge, input shape: [[None, 16, 32, 32], [None, 16, 32, 32]]
Layer name: stack1_0_relu2, layer type: Activation, input shape: [[None, 16, 32, 32]]
Layer name: stack1_1_conv1, layer type: Conv2D, input shape: [[None, 16, 32, 32]]
Layer name: stack1_1_bn1, layer type: BatchNormalization, input shape: [[None, 16, 32, 32]]
Layer name: stack1_1_relu1, layer type: Activation, input shape: [[None, 16, 32, 32]]
Layer name: stack1_1_conv2, layer type: Conv2D, input shape: [[None, 16, 32, 32]]
Layer name: stack1_1_bn2, layer type: BatchNormalization, input shape: [[None, 16, 32, 32]]
Layer name: add_1, layer type: Merge, input shape: [[None, 16, 32, 32], [None, 16, 32, 32]]
Layer name: stack1_1_relu2, layer type: Activation, input shape: [[None, 16, 32, 32]]
Layer name: stack1_2_conv1, layer type: Conv2D, input shape: [[None, 16, 32, 32]]
Layer name: stack1_2_bn1, layer type: BatchNormalization, input shape: [[None, 16, 32, 32]]
Layer name: stack1_2_relu1, layer type: Activation, input shape: [[None, 16, 32, 32]]
Layer name: stack1_2_conv2, layer type: Conv2D, input shape: [[None, 16, 32, 32]]
Layer name: stack1_2_bn2, layer type: BatchNormalization, input shape: [[None, 16, 32, 32]]
Layer name: add_2, layer type: Merge, input shape: [[None, 16, 32, 32], [None, 16, 32, 32]]
Layer name: stack1_2_relu2, layer type: Activation, input shape: [[None, 16, 32, 32]]
Layer name: stack2a_conv1, layer type: Conv2D, input shape: [[None, 16, 32, 32]]
Layer name: stack2a_bn1, layer type: BatchNormalization, input shape: [[None, 32, 16, 16]]
Layer name: stack2a_relu1, layer type: Activation, input shape: [[None, 32, 16, 16]]
Layer name: stack2a_conv2, layer type: Conv2D, input shape: [[None, 32, 16, 16]]
Layer name: stack2a_bn2, layer type: BatchNormalization, input shape: [[None, 32, 16, 16]]
Layer name: stack2a_downsample, layer type: AveragePooling2D, input shape: [[None, 16, 32, 32]]
Layer name: cat, layer type: Concatenate, input shape: [[None, 16, 16, 16], [None, 128, 16, 16, 16]]
Layer name: add_3, layer type: Merge, input shape: [[None, 32, 16, 16], [None, 144, 16, 16]]
Layer name: stack2a_relu2, layer type: Activation, input shape: [[None, 32, 16, 16]]
Layer name: stack2b_0_conv1, layer type: Conv2D, input shape: [[None, 32, 16, 16]]
Layer name: stack2b_0_bn1, layer type: BatchNormalization, input shape: [[None, 32, 16, 16]]
Layer name: stack2b_0_relu1, layer type: Activation, input shape: [[None, 32, 16, 16]]
Layer name: stack2b_0_conv2, layer type: Conv2D, input shape: [[None, 32, 16, 16]]
Layer name: stack2b_0_bn2, layer type: BatchNormalization, input shape: [[None, 32, 16, 16]]
Layer name: add_4, layer type: Merge, input shape: [[None, 32, 16, 16], [None, 32, 16, 16]]
Layer name: stack2b_0_relu2, layer type: Activation, input shape: [[None, 32, 16, 16]]
Layer name: stack2b_1_conv1, layer type: Conv2D, input shape: [[None, 32, 16, 16]]
Layer name: stack2b_1_bn1, layer type: BatchNormalization, input shape: [[None, 32, 16, 16]]
Layer name: stack2b_1_relu1, layer type: Activation, input shape: [[None, 32, 16, 16]]
Layer name: stack2b_1_conv2, layer type: Conv2D, input shape: [[None, 32, 16, 16]]
Layer name: stack2b_1_bn2, layer type: BatchNormalization, input shape: [[None, 32, 16, 16]]
Layer name: add_5, layer type: Merge, input shape: [[None, 32, 16, 16], [None, 32, 16, 16]]
Layer name: stack2b_1_relu2, layer type: Activation, input shape: [[None, 32, 16, 16]]
Layer name: stack3a_conv1, layer type: Conv2D, input shape: [[None, 32, 16, 16]]
Layer name: stack3a_bn1, layer type: BatchNormalization, input shape: [[None, 64, 8, 8]]
Layer name: stack3a_relu1, layer type: Activation, input shape: [[None, 64, 8, 8]]
Layer name: stack3a_conv2, layer type: Conv2D, input shape: [[None, 64, 8, 8]]
Layer name: stack3a_bn2, layer type: BatchNormalization, input shape: [[None, 64, 8, 8]]
Layer name: stack3a_downsample, layer type: AveragePooling2D, input shape: [[None, 32, 16, 16]]
Layer name: cat_1, layer type: Concatenate, input shape: [[None, 32, 8, 8], [None, 128, 32, 8, 8]]
Layer name: add_6, layer type: Merge, input shape: [[None, 64, 8, 8], [None, 160, 8, 8]]
Layer name: stack3a_relu2, layer type: Activation, input shape: [[None, 64, 8, 8]]
Layer name: stack3b_0_conv1, layer type: Conv2D, input shape: [[None, 64, 8, 8]]
Layer name: stack3b_0_bn1, layer type: BatchNormalization, input shape: [[None, 64, 8, 8]]
Layer name: stack3b_0_relu1, layer type: Activation, input shape: [[None, 64, 8, 8]]
Layer name: stack3b_0_conv2, layer type: Conv2D, input shape: [[None, 64, 8, 8]]
Layer name: stack3b_0_bn2, layer type: BatchNormalization, input shape: [[None, 64, 8, 8]]
Layer name: add_7, layer type: Merge, input shape: [[None, 64, 8, 8], [None, 64, 8, 8]]
Layer name: stack3b_0_relu2, layer type: Activation, input shape: [[None, 64, 8, 8]]
Layer name: stack3b_1_conv1, layer type: Conv2D, input shape: [[None, 64, 8, 8]]
Layer name: stack3b_1_bn1, layer type: BatchNormalization, input shape: [[None, 64, 8, 8]]
Layer name: stack3b_1_relu1, layer type: Activation, input shape: [[None, 64, 8, 8]]
Layer name: stack3b_1_conv2, layer type: Conv2D, input shape: [[None, 64, 8, 8]]
Layer name: stack3b_1_bn2, layer type: BatchNormalization, input shape: [[None, 64, 8, 8]]
Layer name: add_8, layer type: Merge, input shape: [[None, 64, 8, 8], [None, 64, 8, 8]]
Layer name: stack3b_1_relu2, layer type: Activation, input shape: [[None, 64, 8, 8]]
Layer name: avgpool, layer type: AveragePooling2D, input shape: [[None, 64, 8, 8]]
Layer name: squeeze, layer type: Reshape, input shape: [[None, 64, 1, 1]]
Layer name: fc_out, layer type: Dense, input shape: [[None, 64]]
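For comparison with the hls4ml topology above, a sketch that records the shapes PyTorch itself produces for the stack2a shortcut, using torch.fx's `ShapeProp` pass on a hypothetical reduced module (`Stack2aShortcut`, main-path conv only, BatchNormalization omitted). Note that the hls4ml input shape for add_3, [None, 144, 16, 16], equals 16 + 128 channels and that of add_6, [None, 160, 8, 8], equals 32 + 128, which appears consistent with the 128 of the zeros tensor being counted as channels; in PyTorch both operands of the addition have 32 channels.

```python
import torch
import torch.fx
from torch import nn
from torch.fx.passes.shape_prop import ShapeProp


class Stack2aShortcut(nn.Module):
    """Hypothetical reduction of stack2a: one strided conv plus the padded shortcut."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.downsample = nn.AvgPool2d(kernel_size=1, stride=2)

    def forward(self, x):
        z = self.conv(x)
        d = self.downsample(x)
        p = torch.zeros(128, 16, 16, 16)
        return z + torch.cat((d, p), dim=1)


gm = torch.fx.symbolic_trace(Stack2aShortcut())
# The hard-coded zeros tensor forces the example batch size to be 128
ShapeProp(gm).propagate(torch.randn(128, 16, 32, 32))

shapes = {}
for node in gm.graph.nodes:
    meta = node.meta.get("tensor_meta")
    if meta is not None:
        shapes[node.name] = tuple(meta.shape)
        print(f"{node.name:20s} {node.op:14s} {shapes[node.name]}")
```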

Steps to Reproduce

Environment:

  • hls4ml version: 0.9.0, commit hash: e4d5e9...
  • PyTorch version: 2.0.0
  • Python version: 3.10

I have added the ResNet architecture below. It is adapted from the following repository, with changes made to layers that hls4ml does not support; comments marked CHANGE in the code below indicate where these changes are. The repository's Jupyter notebook, data loader and training script were used to train the model.
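A quick PyTorch-side check of two of the substitutions marked CHANGE in the architecture: replacing `AdaptiveAvgPool2d((1,1))` with `AvgPool2d(8, 1)`, and replacing `z.view(z.size(0), -1)` with `z.squeeze()`. This is a sketch assuming the 8×8 feature map that a 32×32 input produces at `avgpool`; the pooling swap is only equivalent for that map size. `torch.flatten(z, 1)` is shown purely as the PyTorch equivalent of the original `view`; whether hls4ml accepts it is not established here.

```python
import torch
from torch import nn

z = torch.randn(4, 64, 8, 8)  # shape entering avgpool in this ResNet (batch of 4)

adaptive = nn.AdaptiveAvgPool2d((1, 1))(z)  # original layer
fixed = nn.AvgPool2d(8, 1)(z)               # replacement used for hls4ml
print(adaptive.shape, fixed.shape, torch.allclose(adaptive, fixed))

# z.squeeze() removes every size-1 dimension, including a batch dimension of 1
single = nn.AvgPool2d(8, 1)(torch.randn(1, 64, 8, 8))
print(single.squeeze().shape)          # batch dimension dropped
print(torch.flatten(single, 1).shape)  # batch dimension kept, like z.view(z.size(0), -1)
```

The pooling replacement is exact here, while `squeeze()` only differs from the original `view` when the batch size is 1.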

The build script that loads the trained PyTorch model and produces the hls4ml output fails in convert_from_pytorch_model with the following call:

import hls4ml

# model, config, output_dir and data_files are defined earlier in the build script
hls_model = hls4ml.converters.convert_from_pytorch_model(
    model,
    input_shape=[[None, 3, 32, 32]],
    hls_config=config,
    output_dir=output_dir,
    part='xcvu9p-flga2104-2-i',
    input_data_tb=data_files[0],
    output_data_tb=data_files[1],
)

When reproducing the concatenation bug, a ValueError is raised: "operands could not be broadcast together with shapes (16,) (32,)". The error comes from line 30 of bn_fuse.py, where bn_scale.data and parent_bias.data have mismatched sizes. I am aware that bn_fuse.py fuses BatchNormalization layers into the preceding Conv2D layers. However, the ResNet's output features between these layers match, and the input_shape specified in the build script is correct, so I am unsure what causes the error.

To begin, I am interested in solving the concatenation issue, which also occurs in a similar ResNet model I am trying to convert with hls4ml. That said, any thoughts on the BatchNormalization issue are welcome.
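To illustrate what the broadcast message means: folding a BatchNormalization into a preceding convolution combines per-output-channel vectors elementwise, so every vector must have the convolution's out_channels length. The sketch below is generic conv-BN folding written with NumPy; the `fold_bn_into_conv` name and its argument order are hypothetical and this is not the code in hls4ml's bn_fuse.py. With a 16-entry BN scale and a 32-entry conv bias it raises the same NumPy message, but it does not identify which layer pair hls4ml matched.

```python
import numpy as np


def fold_bn_into_conv(w, b, gamma, beta, mean, var, eps=1e-5):
    """Hypothetical generic conv-BN folding (PyTorch weight layout: out, in, kh, kw)."""
    scale = gamma / np.sqrt(var + eps)  # one entry per BN channel
    # Bias first: scale * b fails with "(16,) (32,)" if the lengths differ
    b_folded = scale * b - scale * mean + beta
    w_folded = w * scale[:, None, None, None]
    return w_folded, b_folded


# Matching sizes: a conv with 32 output channels followed by BatchNorm2d(32)
w = np.ones((32, 16, 3, 3))
b = np.zeros(32)
ones32, zeros32 = np.ones(32), np.zeros(32)
w_f, b_f = fold_bn_into_conv(w, b, ones32, zeros32, zeros32, ones32)
print(w_f.shape, b_f.shape)

# Mismatched sizes: a 16-channel BN paired with a 32-channel conv bias
mismatch_error = ""
try:
    fold_bn_into_conv(w, b, np.ones(16), np.zeros(16), np.zeros(16), np.ones(16))
except ValueError as err:
    mismatch_error = str(err)
print("ValueError:", mismatch_error)
```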

ResNet Architecture:

import torch
from torch import nn
import torch.nn.functional as F

class block(nn.Module):
    def __init__(self, filters, subsample=False):
        super().__init__()
        # Determine subsampling
        s = 0.5 if subsample else 1.0
        # Setup layers
        self.conv1 = nn.Conv2d(int(filters*s), filters, kernel_size=3, stride=int(1/s), padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(filters, track_running_stats=True)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(filters, track_running_stats=True)
        self.relu2 = nn.ReLU()
        # Shortcut downsampling
        self.downsample = nn.AvgPool2d(kernel_size=1, stride=2)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)   

    def shortcut(self, z, x, isDownsampling_2a, isDownsampling_3a):
        # CHANGE
        # previously: x.shape != z.shape
        # change condition to be an input parameter to address an hls4ml 
        # conversion error: "symbolically traced variables cannot be used as 
        # inputs to control flow"
        # CHANGE
        # modified downsampling to use torch.zeros
        # previous multiplication layer was invalid in hls4ml due to scalar argument 
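        if isDownsampling_2a:
            d = self.downsample(x)
            # zero padding for the extra 16 channels; the leading 128 is a hard-coded
            # batch size, which torch.cat(dim=1) requires to equal d's batch dimension
            p = torch.zeros(128, 16, 16, 16)
            return z + torch.cat((d, p), dim=1)
        elif isDownsampling_3a:
            d = self.downsample(x)
            # zero padding for the extra 32 channels (batch size again hard-coded to 128)
            p = torch.zeros(128, 32, 8, 8)
            return z + torch.cat((d, p), dim=1)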
        else:
            return z + x   

    # CHANGE
    # add isDownsampling parameter
    def forward(self, x, shortcuts=False, isDownsampling_2a=False, isDownsampling_3a=False):
        z = self.conv1(x)
        z = self.bn1(z)
        z = self.relu1(z)
        z = self.conv2(z)
        z = self.bn2(z)
        
        # Shortcut connection
        # This if statement is the only difference between
        # a convolutional net and a resnet!
        if shortcuts:
            # CHANGE
            # add isDownsampling argument
            z = self.shortcut(z, x, isDownsampling_2a, isDownsampling_3a)
        z = self.relu2(z)
        return z

class ResNet(nn.Module):
    def __init__(self, n, shortcuts=True):
        super().__init__()
        self.shortcuts = shortcuts
        # Input
        self.convIn = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bnIn   = nn.BatchNorm2d(16, track_running_stats=True)
        self.relu   = nn.ReLU()
        # Stack1
        self.stack1 = nn.ModuleList([block(16, subsample=False) for _ in range(n)])
        # Stack2
        self.stack2a = block(32, subsample=True)
        self.stack2b = nn.ModuleList([block(32, subsample=False) for _ in range(n-1)])
        # Stack3
        self.stack3a = block(64, subsample=True)
        self.stack3b = nn.ModuleList([block(64, subsample=False) for _ in range(n-1)])
    
        # Output       
        # CHANGE
        # remove AdaptiveAvgPool2d((1,1)) 'AdaptiveAvgPool2d' not supported by hls4ml
        self.avgpool = nn.AvgPool2d(8,1)
        self.fcOut   = nn.Linear(64, 10, bias=True)
        self.softmax = nn.Softmax(dim=-1)
        
        # Initialize weights in fully connected layer 
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                m.bias.data.zero_()      

    def forward(self, x): 
        z = self.convIn(x)
        z = self.bnIn(z)
        z = self.relu(z)
        # CHANGE
        # add isDownsampling arguments
        for l in self.stack1:
            z = l(z, shortcuts=self.shortcuts, isDownsampling_2a=False, isDownsampling_3a=False)
        z = self.stack2a(z, shortcuts=self.shortcuts, isDownsampling_2a=True, isDownsampling_3a=False)
        for l in self.stack2b: 
            z = l(z, shortcuts=self.shortcuts, isDownsampling_2a=False, isDownsampling_3a=False)
        z = self.stack3a(z, shortcuts=self.shortcuts, isDownsampling_2a=False, isDownsampling_3a=True)
        for l in self.stack3b: 
            z = l(z, shortcuts=self.shortcuts, isDownsampling_2a=False, isDownsampling_3a=False)
        z = self.avgpool(z)
        # CHANGE
        # remove z.view(z.size(0), -1), 'size' function not supported by hls4ml
        z = z.squeeze()
        # CHANGE
        # remove LogSoftmax, not supported by hls4ml
        # CrossEntropyLoss is now the criterion in main.ipynb, and it expects raw outputs (logits)
        return self.fcOut(z)
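For comparison, a sketch of the same zero-padded shortcut written without a hard-coded batch size, assuming (as in both downsampling blocks above) that the block's filters are exactly twice its input channels. `torch.zeros_like(d)` takes its shape from `d`, so the shortcut works in PyTorch for any batch size and, under torch.fx, appears as an operation on the traced input instead of a constant with 128 rows. The class name `PaddedShortcut` is hypothetical, and whether hls4ml 0.9.0 can parse `zeros_like` is not established here.

```python
import torch
import torch.fx
from torch import nn


class PaddedShortcut(nn.Module):
    """Hypothetical batch-size-independent version of the downsampling shortcut."""

    def __init__(self):
        super().__init__()
        self.downsample = nn.AvgPool2d(kernel_size=1, stride=2)

    def forward(self, z, x):
        d = self.downsample(x)
        # zeros_like(d) matches d's batch size and channel count, so the concatenation
        # doubles the channels (16 -> 32 or 32 -> 64) for any batch size
        p = torch.zeros_like(d)
        return z + torch.cat((d, p), dim=1)


shortcut = PaddedShortcut()
for batch in (1, 4):
    x = torch.randn(batch, 16, 32, 32)
    z = torch.randn(batch, 32, 16, 16)
    print(batch, tuple(shortcut(z, x).shape))

# Under torch.fx the padding is derived from the traced input, not stored as a constant
gm = torch.fx.symbolic_trace(shortcut)
ops = [(n.op, str(n.target)) for n in gm.graph.nodes]
print(ops)
```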
clw5710 added the bug label Nov 30, 2023