Skip to content

模型中存在LayerNorm导致pnnx出现算子转换异常现象 #5942

@Xiaowei-coder

Description

@Xiaowei-coder

error log | 日志或报错信息 | ログ

PyTorch 模型中存在 LayerNorm 时,pnnx 转换会出现算子转换异常,具体如下图所示:

Image

model | 模型 | モデル

import torch
import torch.nn as nn


class Model(nn.Module):
    """Linear projection followed by a LayerNorm over the last two axes.

    ``forward`` permutes the input with ``(0, 2, 3, 1)``, squeezes axis 0,
    slices the last axis at index 12 and concatenates the two pieces back
    together, then applies ``intra_fc`` and ``intra_norm`` in that order.

    Args:
        input_size: Stored on the instance; ``forward`` does not read it.
        width: First entry of the ``intra_norm`` normalized shape.
        hidden_size: In/out feature count of ``intra_fc`` and second entry
            of the ``intra_norm`` normalized shape.
    """

    def __init__(self, input_size, width, hidden_size):
        super().__init__()
        self.input_size, self.width, self.hidden_size = (
            input_size,
            width,
            hidden_size,
        )

        self.intra_fc = nn.Linear(hidden_size, hidden_size)
        self.intra_norm = nn.LayerNorm((width, hidden_size), eps=1e-8)

    def forward(self, x):
        """Run the permute -> squeeze -> slice/cat -> linear -> norm chain.

        Args:
            x: 4-D tensor. After the permute the author labels the layout
                (B, T, F, C); the last axis must match ``hidden_size``.

        Returns:
            The output of ``intra_norm``.
        """
        # squeeze(0) only removes the leading axis when its size is 1;
        # otherwise the tensor keeps all four axes.
        frames = x.permute(0, 2, 3, 1).squeeze(0)

        # Split the last axis at index 12 and stitch it back together.
        head = frames[:, :, :12]
        tail = frames[:, :, 12:]
        merged = torch.cat([head, tail], dim=2)

        projected = self.intra_fc(merged)
        return self.intra_norm(projected)

class Model_without_layernorm(nn.Module):
    """Control variant: the same pipeline as ``Model`` minus the LayerNorm.

    ``forward`` permutes the input with ``(0, 2, 3, 1)``, squeezes axis 0,
    slices the last axis at index 12 and concatenates the two pieces back
    together, then applies ``intra_fc`` and returns its output directly.

    Args:
        input_size: Stored on the instance; ``forward`` does not read it.
        width: Stored on the instance; ``forward`` does not read it.
        hidden_size: In/out feature count of ``intra_fc``.
    """

    def __init__(self, input_size, width, hidden_size):
        super().__init__()
        self.input_size, self.width, self.hidden_size = (
            input_size,
            width,
            hidden_size,
        )
        self.intra_fc = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        """Run the permute -> squeeze -> slice/cat -> linear chain.

        Args:
            x: 4-D tensor. After the permute the author labels the layout
                (B, T, F, C); the last axis must match ``hidden_size``.

        Returns:
            The output of ``intra_fc``.
        """
        # squeeze(0) only removes the leading axis when its size is 1;
        # otherwise the tensor keeps all four axes.
        frames = x.permute(0, 2, 3, 1).squeeze(0)

        # Split the last axis at index 12 and stitch it back together.
        head = frames[:, :, :12]
        tail = frames[:, :, 12:]
        merged = torch.cat([head, tail], dim=2)

        return self.intra_fc(merged)


# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    # Reproducer driver: build both model variants, export each one to
    # TorchScript via tracing, reload the saved files and check that the
    # reloaded modules reproduce the eager outputs exactly.

    # create data
    x = torch.randn(1, 24, 63, 33)

    # initialize both variants with identical constructor arguments
    m = Model(24, 33, 24)
    m_without_layernorm = Model_without_layernorm(24, 33, 24)

    # Switch to eval mode BEFORE taking the eager reference outputs, so the
    # references are produced in the same mode the modules are traced in.
    m.eval()
    m_without_layernorm.eval()

    # Call the modules rather than .forward(): __call__ runs registered
    # hooks, which a direct forward() call silently skips.
    out0 = m(x)
    out1 = m_without_layernorm(x)

    # example_inputs is passed as a real one-element tuple; the original
    # "(x)" was just a parenthesised tensor.
    model = torch.jit.trace(m.cpu(), (x,), check_trace=True)
    model_without_layernorm = torch.jit.trace(
        m_without_layernorm.cpu(), (x,), check_trace=True
    )

    pt_model_path = "model.pt"
    pt_model_without_layernorm_path = "model_without_layernorm.pt"

    # Round-trip each traced module through disk and re-run it on the same
    # input.
    model.save(pt_model_path)
    model_ = torch.jit.load(pt_model_path).cpu()
    pt_out0 = model_(x)

    model_without_layernorm.save(pt_model_without_layernorm_path)
    model_without_layernorm_ = torch.jit.load(pt_model_without_layernorm_path).cpu()
    pt_out1 = model_without_layernorm_(x)

    # Bitwise comparison of eager vs. reloaded TorchScript outputs.
    equal = torch.equal(out0, pt_out0)
    equal_ = torch.equal(out1, pt_out1)

    print("convert torch Model to pt model!!!!")
    print("Model与pt Model输出结果一致:", equal)
    print("Model_without_layernorm与pt Model_without_layernorm输出结果一致:", equal_)

how to reproduce | 复现步骤 | 再現方法

  1. python model.py
  2. ./pnnx.exe model_without_layernorm.pt inputshape=[1,24,63,33] inputshape2=[1,24,100,33] fp16=0
  3. ./pnnx.exe model.pt inputshape=[1,24,63,33] inputshape2=[1,24,100,33] fp16=0

备注:环境版本为 torch 2.4.1,pnnx 20241223-windows

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions