Activation function and Loss function not read correctly in RnnOutputLayer #67
Hi @fdebeyan, unfortunately I wasn't able to replicate this issue. Can you try the configuration below by copying it into the GUI and let me know how it goes? If I'm not mistaken, this is the output layer configuration you're having an issue with.
I tested this on a tiny version of the IMDB dataset (attached), and it worked without a hitch.
When setting the activation function and loss function for RnnOutputLayer, I'm using sigmoid and LossBinaryXENT, since the class is binary. However, Weka doesn't seem to read these options properly. I tried it from both the command line and the GUI.
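For reference, a minimal plain-Java sketch of what the requested pairing computes for a single binary target: the sigmoid maps the output logit to a probability, and binary cross-entropy scores that probability against the 0/1 label. The class and method names are illustrative assumptions; this is not DL4J's LossBinaryXENT implementation.

```java
// Illustrative sketch only: sigmoid activation + binary cross-entropy for one
// binary target. Not DL4J's LossBinaryXENT or ActivationSigmoid code.
class SigmoidBce {

    // Logistic sigmoid: maps a raw output (logit) to a probability in (0, 1).
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Binary cross-entropy for a label y in {0, 1} and predicted probability p.
    static double binaryCrossEntropy(double y, double p) {
        // Clamp p away from 0 and 1 so Math.log never receives 0.
        double eps = 1e-12;
        double q = Math.min(Math.max(p, eps), 1.0 - eps);
        return -(y * Math.log(q) + (1.0 - y) * Math.log(1.0 - q));
    }

    public static void main(String[] args) {
        double p = sigmoid(0.0); // a logit of 0 gives probability 0.5
        // At p = 0.5 the loss equals ln 2 whichever label is the true one.
        System.out.println("p = " + p);
        System.out.println("loss (y=1) = " + binaryCrossEntropy(1.0, p));
        System.out.println("loss (y=0) = " + binaryCrossEntropy(0.0, p));
    }
}
```

The clamp is a common guard against log(0); the value 1e-12 is an arbitrary choice for this sketch.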
To Reproduce
[Low] RnnSequenceClassifier$2030411960|-S 1 -tBPTTBackward 25 -tBPTTForward 25 -cache-mode MEMORY -early-stopping "weka.dl4j.earlystopping.EarlyStopping -maxEpochsNoImprovement 0 -valPercentage 0.0" -normalization "Standardize training data" -iterator "weka.dl4j.iterators.instance.sequence.text.rnn.RnnTextEmbeddingInstanceIterator -stopWords "weka.dl4j.text.stopwords.Dl4jRainbow " -tokenPreProcessor "weka.dl4j.text.tokenization.preprocessor.CommonPreProcessor " -tokenizerFactory "weka.dl4j.text.tokenization.tokenizer.factory.NGramTokenizerFactory -NMax 3 -NMin 1 -delimiters \" \\\\\\r\\\\\\n\\\\\\t.,;:\\\'\\\"()?!\"" -truncationLength 100 -wordVectorLocation /home/w2v-sample-2.arff -bs 1" -iteration-listener "weka.dl4j.listener.EpochListener -eval true -n 5" -layer "weka.dl4j.layers.LSTM -gateActivation "weka.dl4j.activations.ActivationSigmoid " -nOut 100 -activation "weka.dl4j.activations.ActivationReLU " -name "LSTM layer"" -layer "weka.dl4j.layers.RnnOutputLayer -lossFn "weka.dl4j.lossfunctions.LossBinaryXENT " -nOut 2 -activation "weka.dl4j.activations.ActivationSigmoid " -name "RnnOutput layer"" -logConfig "weka.core.LogConfiguration -append true -dl4jLogLevel WARN -logFile /wekafiles/wekaDeeplearning4j.log -nd4jLogLevel INFO -wekaDl4jLogLevel INFO" -config "weka.dl4j.NeuralNetConfiguration -biasInit 0.0 -biasUpdater "weka.dl4j.updater.Sgd -lr 0.001 -lrSchedule \"weka.dl4j.schedules.ConstantSchedule -scheduleType EPOCH\"" -dist "weka.dl4j.distribution.Disabled " -dropout "weka.dl4j.dropout.Disabled " -gradientNormalization None -gradNormThreshold 1.0 -l1 NaN -l2 NaN -minimize -algorithm STOCHASTIC_GRADIENT_DESCENT -updater "weka.dl4j.updater.Adam -beta1MeanDecay 0.9 -beta2VarDecay 0.999 -epsilon 1.0E-8 -lr 0.001 -lrSchedule \"weka.dl4j.schedules.ConstantSchedule -scheduleType EPOCH\"" -weightInit XAVIER -weightNoise "weka.dl4j.weightnoise.Disabled "" -numEpochs 10 -numGPUs 1 -averagingFrequency 10 -prefetchSize 24 -queueSize 0 -zooModel "weka.dl4j.zoo.CustomNet -channelsLast false -pretrained NONE"
Expected behavior
The activation function should be set to sigmoid and the loss function to LossBinaryXENT.
Error
Caused by: org.deeplearning4j.exception.DL4JInvalidConfigException: Invalid output layer configuration for layer "RnnOutput layer": softmax activation function in combination with LossBinaryXENT (binary cross entropy loss function). For multi-class classification, use softmax + MCXENT (multi-class cross entropy); for binary multi-label classification, use sigmoid + XENT.
OR:
Caused by: org.deeplearning4j.exception.DL4JInvalidConfigException: Invalid output layer configuration for layer "RnnOutput layer": sigmoid activation function in combination with LossMCXENT (binary cross entropy loss function). For multi-class classification, use softmax + MCXENT (multi-class cross entropy); for binary multi-label classification, use sigmoid + XENT.
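Both errors point the same way: in each case one of the two options did not reach DL4J as configured. The first message reports softmax where sigmoid was set; the second reports LossMCXENT where LossBinaryXENT was set. A minimal sketch of the pairing rule these messages state, assuming only two activations and two losses (the enums and method are hypothetical illustrations, not DL4J's or WekaDeeplearning4j's classes, and DL4J's real validation covers more cases):

```java
// Hypothetical sketch of the activation/loss pairing rule quoted in the
// DL4J error messages. Not DL4J's actual validation code.
class OutputLayerCheck {

    enum Activation { SIGMOID, SOFTMAX }

    // XENT stands in for LossBinaryXENT, MCXENT for LossMCXENT.
    enum Loss { XENT, MCXENT }

    // True only for the two pairings the error message recommends:
    // softmax + MCXENT (multi-class) and sigmoid + XENT (binary / multi-label).
    static boolean isCompatible(Activation act, Loss loss) {
        return (act == Activation.SOFTMAX && loss == Loss.MCXENT)
            || (act == Activation.SIGMOID && loss == Loss.XENT);
    }

    public static void main(String[] args) {
        // Pairing requested in the RnnOutputLayer -layer option above.
        System.out.println("sigmoid + XENT:   " + isCompatible(Activation.SIGMOID, Loss.XENT));
        // Pairings reported in the two error messages.
        System.out.println("softmax + XENT:   " + isCompatible(Activation.SOFTMAX, Loss.XENT));
        System.out.println("sigmoid + MCXENT: " + isCompatible(Activation.SIGMOID, Loss.MCXENT));
    }
}
```

Under this rule the configured pairing is valid and only the two reported pairings are rejected, which is consistent with one option being dropped or overwritten before the layer is built rather than with an invalid configuration.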
Additional Information