Closed
Labels
triaged (Issue has been triaged by maintainers)
Description
I am struggling to load a Caffe network with a PReLU layer (for example, det1 from https://github.com/OAID/FaceDetection/tree/master/models).
Here is the whole function:
ILayer* parsePReLU(INetworkDefinition& network, const trtcaffe::LayerParameter& msg, CaffeWeightFactory& weightFactory,
    BlobNameToTensor& tensors)
{
    // Caffe stores the slopes as weights rather than as a tensor, and only supports different slopes
    // per channel
    if (!checkBlobs(msg, 1, 1))
    {
        return nullptr;
    }
    const trtcaffe::PReLUParameter& p = msg.prelu_param();
    bool channelShared = p.has_channel_shared() ? p.channel_shared() : false;
    auto inputDims = tensors[msg.bottom(0)]->getDimensions();
    if (inputDims.nbDims < 2)
    {
        return nullptr;
    }
    int nWeights = channelShared ? 1 : inputDims.d[1]; // Caffe treats second input dimension as channels
    Dims slopesDims{inputDims.nbDims, {1}, {DimensionType::kSPATIAL}};
    slopesDims.d[1] = nWeights;
    Weights w = weightFactory.isInitialized() ? weightFactory(msg.name(), WeightType::kGENERIC) :
        weightFactory.allocateWeights(nWeights, std::uniform_real_distribution<float>(0.F, 1.F));
    auto constLayer = network.addConstant(slopesDims, w);
    return network.addParametricReLU(*tensors[msg.bottom(0)], *constLayer->getOutput(0));
}
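I can see at least one problem here. For the function to succeed, inputDims.nbDims must be >= 2 (in my case it is 3). However, "slopesDims{inputDims.nbDims, {1}, {DimensionType::kSPATIAL}}" explicitly initialises only the first dimension (d[0] = 1, type[0] = kSPATIAL). The line "slopesDims.d[1] = nWeights;" sets the second dimension but not its type, and the remaining dimension(s) are never set explicitly, so aggregate initialisation leaves them at 0. For my 3-D input the slope shape therefore ends up as {1, nWeights, 0}.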
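A minimal sketch that reproduces this outside TensorRT, assuming the Dims type in this TensorRT version is a plain aggregate laid out as nbDims, d[MAX_DIMS], type[MAX_DIMS]. DimsSketch, DimensionType and buggySlopesDims below are hypothetical stand-ins, not the TensorRT types:

```cpp
#include <cassert>

// Hypothetical stand-ins for nvinfer1::DimensionType and nvinfer1::Dims.
// Assumption: the real Dims is a plain aggregate with this layout.
enum class DimensionType { kSPATIAL, kCHANNEL, kINDEX, kSEQUENCE };

struct DimsSketch
{
    static constexpr int MAX_DIMS = 8;
    int nbDims;
    int d[MAX_DIMS];
    DimensionType type[MAX_DIMS];
};

// Reproduces the two slope-shape lines from parsePReLU.
DimsSketch buggySlopesDims(int nbDims, int nWeights)
{
    assert(nbDims >= 2 && nbDims <= DimsSketch::MAX_DIMS);
    // Aggregate initialisation: d[0] = 1 and type[0] = kSPATIAL; every other
    // element of d and type is zero-initialised, not left indeterminate.
    DimsSketch slopesDims{nbDims, {1}, {DimensionType::kSPATIAL}};
    slopesDims.d[1] = nWeights; // type[1] keeps its zero value
    return slopesDims;
}
```

Calling buggySlopesDims(3, 10) gives d[0] = 1, d[1] = 10, d[2] = 0; that trailing 0 is the extent an "all dimensions >= 1" check rejects.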
This causes addConstant to fail later when creating the Constant layer, with the message "Parameter check failed at: ../builder/Network.cpp::addConstant::770, condition: allDimsGtEq(dimensions, 1)".
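One possible fix, sketched with hypothetical stand-in types (not the TensorRT API): set every extent of the slope shape to 1 first, then write the channel count, so the shape broadcasts against the input and has no zero extents. allDimsAtLeast is a hypothetical re-implementation of the condition in the error message, not the TensorRT function. The channel axis is passed in explicitly (the original code uses index 1), so the sketch does not bake in an assumption about which axis holds channels:

```cpp
#include <cassert>

// Hypothetical stand-ins for nvinfer1::DimensionType and nvinfer1::Dims
// (assumed aggregate layout; not the real TensorRT types).
enum class DimensionType { kSPATIAL, kCHANNEL, kINDEX, kSEQUENCE };

struct DimsSketch
{
    static constexpr int MAX_DIMS = 8;
    int nbDims;
    int d[MAX_DIMS];
    DimensionType type[MAX_DIMS];
};

// Hypothetical re-implementation of the reported condition:
// every one of the nbDims extents must be >= minValue.
bool allDimsAtLeast(const DimsSketch& dims, int minValue)
{
    for (int i = 0; i < dims.nbDims; ++i)
    {
        if (dims.d[i] < minValue)
        {
            return false;
        }
    }
    return true;
}

// Builds a slope shape of rank nbDims that is 1 on every axis except
// channelAxis, which holds nWeights (1 when the slope is channel-shared).
DimsSketch makeSlopesDims(int nbDims, int channelAxis, int nWeights)
{
    assert(nbDims >= 1 && nbDims <= DimsSketch::MAX_DIMS);
    assert(channelAxis >= 0 && channelAxis < nbDims);
    assert(nWeights >= 1);
    DimsSketch slopesDims{}; // value-initialise everything first
    slopesDims.nbDims = nbDims;
    for (int i = 0; i < nbDims; ++i)
    {
        slopesDims.d[i] = 1;                          // broadcast over this axis
        slopesDims.type[i] = DimensionType::kSPATIAL; // same type the original code uses
    }
    slopesDims.d[channelAxis] = nWeights;
    return slopesDims;
}
```

With makeSlopesDims(3, 1, 10) the shape is {1, 10, 1} and allDimsAtLeast(shape, 1) holds, whereas the original construction produces {1, 10, 0} for a 3-D input.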