-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Closed
Description
error log | 日志或报错信息 | ログ
当 PyTorch 模型中包含 LayerNorm 时，pnnx 转换会出现异常，具体如下图所示：
model | 模型 | モデル
import torch
import torch.nn as nn
class Model(nn.Module):
    """Reproduction model for the pnnx LayerNorm conversion problem.

    Pipeline: permute to channels-last, ``squeeze(0)``, split along dim 2 at
    index 12 and concatenate the pieces back, ``Linear(hidden_size,
    hidden_size)``, then a ``LayerNorm`` over the last two axes
    ``(width, hidden_size)``.

    Args:
        input_size: stored as an attribute; not used by any layer.
        width: size of the second-to-last axis normalized by ``intra_norm``.
        hidden_size: in/out features of ``intra_fc`` and size of the last
            normalized axis.
    """

    def __init__(self, input_size, width, hidden_size):
        super().__init__()
        self.input_size, self.width, self.hidden_size = input_size, width, hidden_size
        # NOTE(review): intra_fc is applied to the channel axis (last axis after
        # the permute in forward), yet input_size is unused; Linear(input_size,
        # hidden_size) may have been intended -- confirm. Identical here since
        # the script uses input_size == hidden_size.
        self.intra_fc = nn.Linear(hidden_size, hidden_size)
        self.intra_norm = nn.LayerNorm((width, hidden_size), eps=1e-8)

    def forward(self, x):
        # (B, C, T, F) -> (B, T, F, C), using the reporter's axis labels.
        feats = x.permute(0, 2, 3, 1)
        # squeeze(0) removes dim 0 only when B == 1, giving (T, F, C); for
        # B > 1 the tensor stays 4-D. Alternative the reporter tried:
        #   feats.reshape(-1, feats.shape[2], feats.shape[3])  # (B*T, F, C)
        feats = feats.squeeze(0)
        # Split along dim 2 at index 12 and stitch the pieces back together.
        # Numerically an identity; it only adds slice/cat ops to the traced
        # graph. Alternative tried: torch.chunk(..., chunks=2, dim=-1).
        feats = torch.cat((feats[:, :, :12], feats[:, :, 12:]), dim=2)
        feats = self.intra_fc(feats)
        # Alternatives tried before normalizing: reshape back to
        # (batch, -1, width, hidden_size) or unsqueeze(0).
        # Normalize jointly over the last two axes, (width, hidden_size).
        return self.intra_norm(feats)
class Model_without_layernorm(nn.Module):
    """Same pipeline as the LayerNorm repro model, minus the final LayerNorm.

    Used as the control case: permute to channels-last, ``squeeze(0)``, split
    along dim 2 at index 12 and concatenate back, then
    ``Linear(hidden_size, hidden_size)``.

    Args:
        input_size: stored as an attribute; not used by any layer.
        width: stored as an attribute; not used by any layer.
        hidden_size: in/out features of ``intra_fc``.
    """

    def __init__(self, input_size, width, hidden_size):
        super().__init__()
        self.input_size, self.width, self.hidden_size = input_size, width, hidden_size
        # NOTE(review): input_size is unused even though intra_fc acts on the
        # channel axis; Linear(input_size, hidden_size) may have been intended.
        self.intra_fc = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        # (B, C, T, F) -> (B, T, F, C), using the reporter's axis labels.
        out = x.permute(0, 2, 3, 1)
        # Drops dim 0 only when B == 1 -> (T, F, C); otherwise stays 4-D.
        # Alternative tried: reshape(-1, F, C) to get (B*T, F, C).
        out = out.squeeze(0)
        # Split along dim 2 at index 12 and concatenate back (value identity).
        left, right = out[:, :, :12], out[:, :, 12:]
        out = torch.cat((left, right), dim=2)
        # Alternatives tried afterwards: reshape to (batch, -1, width,
        # hidden_size) or unsqueeze(0); neither is applied here.
        return self.intra_fc(out)
def trace_save_and_compare(module, example_input, path):
    """Trace ``module`` with TorchScript, save it, reload it, and compare outputs.

    The module is moved to CPU and switched to eval mode in place *before*
    the reference output is computed, so the eager run and the trace use the
    same module mode (the original script computed the reference in train
    mode and traced in eval mode).

    Args:
        module: the ``nn.Module`` to export.
        example_input: tensor used both as the tracing example and as the
            input for the eager-vs-reloaded comparison.
        path: file path the traced module is saved to (overwritten if present).

    Returns:
        ``(eager_out, reloaded_out, same)``. ``same`` is True when the reloaded
        TorchScript module reproduces the eager output within
        ``torch.allclose`` default tolerances.
    """
    module = module.cpu().eval()
    with torch.no_grad():
        # Call the module, not .forward(), so registered hooks still run.
        eager_out = module(example_input)
        # example_inputs must be a tuple; "(x)" is just x, so spell "(x,)".
        traced = torch.jit.trace(module, (example_input,), check_trace=True)
        traced.save(path)
        reloaded_out = torch.jit.load(path).cpu()(example_input)
    # Floating-point outputs: compare with a tolerance, not bitwise equality.
    return eager_out, reloaded_out, torch.allclose(eager_out, reloaded_out)


# Run as a script: export both repro models to TorchScript files for pnnx.
if __name__ == '__main__':
    # Input laid out as (B, C, T, F), matching inputshape=[1,24,63,33] below.
    x = torch.randn(1, 24, 63, 33)
    m = Model(24, 33, 24)
    m_without_layernorm = Model_without_layernorm(24, 33, 24)

    _, _, equal = trace_save_and_compare(m, x, "model.pt")
    _, _, equal_ = trace_save_and_compare(
        m_without_layernorm, x, "model_without_layernorm.pt"
    )

    print("convert torch Model to pt model!!!!")
    print("Model与pt Model输出结果一致:", equal)
    print("Model_without_layernorm与pt Model_without_layernorm输出结果一致:", equal_)
# How to reproduce: run this script, then convert the saved .pt files with pnnx.
- python model.py
- ./pnnx.exe model_without_layernorm.pt inputshape=[1,24,63,33] inputshape2=[1,24,100,33] fp16=0
- ./pnnx.exe model.pt inputshape=[1,24,63,33] inputshape2=[1,24,100,33] fp16=0
备注：torch 2.4.1，pnnx 20241223（Windows 版）
Metadata
Metadata
Assignees
Labels
No labels
