
Validation loss increasing while WER decreases #78

Closed

SiddGururani opened this issue Jun 9, 2017 · 25 comments

@SiddGururani
Contributor

SiddGururani commented Jun 9, 2017

[plot: training and validation loss curves, validation loss increasing over epochs]

I would like to believe that the model is overfitting, but why would the WER keep decreasing if it were overfitting?

The architecture is as follows:
500 hidden size
5 RNN layers
default LR
1.001 annealing factor, which I'm increasing by 0.001 every epoch.

I'm training using Librispeech train-clean-100.tar.gz and validating on dev-clean.tar.gz

@ryanleary
Collaborator

Do you have a plot of the CER too?

@SiddGururani
Contributor Author

SiddGururani commented Jun 9, 2017

[plot: CER curve over epochs]
It essentially looks the same as the WER.

Note that training progressed in the meantime, so the first and second images aren't identical.

@SeanNaren
Owner

Try increasing the annealing factor to something like 1.01; the default annealing works well for AN4 since it's really small.
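For reference, "annealing" here means dividing the learning rate after every epoch, so a larger factor decays it faster. A minimal sketch of that schedule; the optimizer setup and stand-in model are illustrative assumptions, not the repo's exact code:

```python
import torch

model = torch.nn.Linear(161, 29)  # stand-in for the acoustic model
optimizer = torch.optim.SGD(model.parameters(), lr=3e-4, momentum=0.9)
annealing_factor = 1.01  # value suggested above; the AN4-tuned default is ~1.001

for epoch in range(10):
    # ... run one training epoch here ...
    for group in optimizer.param_groups:
        group['lr'] /= annealing_factor  # decay the LR once per epoch
    print(f"epoch {epoch}: lr = {optimizer.param_groups[0]['lr']:.6e}")
```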

@SiddGururani
Contributor Author

@SeanNaren I can try that out.
But I was wondering why the accuracy metric still shows an improvement in performance while the loss on the validation set is increasing. If I hadn't plotted the validation loss, I would not have noticed this discrepancy, because we were only tracking the CER and WER on the validation set and the average loss during training. I would expect the general trends of the loss and WER to be pretty correlated. One thing that comes to mind that might cause this is that the temporal alignment of the words is incorrect, leading to a high cost but a low WER, but I don't see why it would keep getting worse.
Any ideas?

@SeanNaren
Owner

This might be stupid, but could you try changing the batch norm flag here to true and re-training?
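For context, a hedged sketch of what an RNN layer with an optional batch norm over its input might look like; the class and argument names are illustrative, not necessarily the repo's exact definitions:

```python
import torch.nn as nn

class BatchRNN(nn.Module):
    """Bidirectional LSTM with an optional batch norm on its input features."""
    def __init__(self, input_size, hidden_size, batch_norm=True):
        super().__init__()
        self.batch_norm = nn.BatchNorm1d(input_size) if batch_norm else None
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True)

    def forward(self, x):  # x: (seq_len, batch, features)
        if self.batch_norm is not None:
            t, n, f = x.shape
            x = self.batch_norm(x.reshape(t * n, f)).reshape(t, n, f)
        out, _ = self.rnn(x)
        return out
```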

@SiddGururani
Contributor Author

I'll do that. I'm keeping my old learning rate schedule for this experiment though. I'll post an update in a few hours when I have some results.

@SiddGururani
Contributor Author

[plot: loss and WER after enabling batch norm (screenshot from 2017-06-13)]
Changing the batch_norm flag to True for the 0th layer didn't really affect the general trend of the loss and WER. Could there be something wrong with how I'm computing the loss? I just copied the code segment from the training section of the code and refactored it to compute the validation loss.
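For comparison, here is roughly what a refactored validation-loss pass could look like; the loader, batch layout, and CTC criterion are assumptions for illustration, not the actual branch code:

```python
import torch

def validation_loss(model, val_loader, criterion):
    model.eval()
    total_loss, n_batches = 0.0, 0
    with torch.no_grad():  # no gradient bookkeeping during validation
        for inputs, targets, input_lengths, target_lengths in val_loader:
            log_probs = model(inputs)  # (T, N, C) log-probabilities
            loss = criterion(log_probs, targets, input_lengths, target_lengths)
            total_loss += loss.item()
            n_batches += 1
    model.train()
    return total_loss / max(n_batches, 1)
```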

@ryanleary
Collaborator

Do you have that code in a branch somewhere so we can look?

@SiddGururani
Contributor Author

SiddGururani commented Jun 13, 2017

I set up a new branch with the code here: https://github.com/SiddGururani/deepspeech.pytorch/blob/val_loss/train.py
Have a look at the section after this line.

@SeanNaren
Owner

Hey, so I saw the same behaviour in some internal tests on a Torch (Lua) version of DeepSpeech. My colleagues suggested that at a certain point the model overfits, since the dataset is small, and as a result the validation loss increases.

Further investigation is needed but I think that sounds like good intuition as to why val loss increases whereas WER/CER stays low.

@SiddGururani
Contributor Author

SiddGururani commented Jun 19, 2017

Correct me if I'm not understanding this properly:
So the output loss is the negative log of the total probability of the most likely sequence of labels (computed dynamically). Since this function doesn't directly translate to an error between the true labels and the output labels, this number could be high while still giving a good CER/WER. It's just that the CTC layer hasn't learned the parameters such that the best path has a very high probability.

Now my question is: since our goal is to eventually have a low WER, should we care too much about the probability of the best sequence? If you look at my plots in this issue thread, the average validation loss is on the order of tens, which puts the probability of the best sequence around e^(-10), a very small number (order of 10^-5). I don't like such a low probability for the best sequence of labels.
The concept of overfitting is also weird to me in this case: the model is still able to output better label sequences (on unseen data) as training continues, but the average probability of the sequences gets lower.
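As a quick sanity check of the loss-to-probability conversion above, assuming the reported loss is a plain negative log-probability in nats:

```python
import math

loss = 10.0             # average validation loss from the plots
prob = math.exp(-loss)  # invert loss = -log(prob)
print(prob)             # ~4.54e-05, i.e. on the order of 10^-5
```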

@SeanNaren
Owner

@SiddGururani really good thoughts, cheers! Unfortunately I can't really add much guidance at this stage, but I can try to follow up with a few people.

@SiddGururani
Contributor Author

@SeanNaren Thanks haha!

It'd be great if someone could look at this from a more theoretical/mathematical point of view. I'm running a full LibriSpeech training with a large enough model (512*6 LSTM) and the default learning rate and annealing. Going to see what the validation loss looks like.

@SiddGururani
Contributor Author

Once again, with the architecture that I've mentioned above, I observe the same behavior. So I don't think this has to do with the size of the dataset.
[plots: loss and WER curves for the full LibriSpeech run]

Another thing I'm curious about is that the average training loss and the average validation loss are very different from the start. Does that have any significance here? I could do some digging into LibriSpeech to see whether the utterance lengths are distributed differently between the training and validation sets. But has anyone else observed this with another dataset?

@ryanleary
Collaborator

It should be unsurprising that the validation loss is different from the training loss. The training loss is being directly optimized as part of the training process. The validation loss indicates how well that training generalizes. The training loss will almost always be lower than the validation loss. (This is why the results of testing on training data are meaningless other than to indicate that the model is actually learning something).

@SiddGururani
Contributor Author

SiddGururani commented Jun 23, 2017

@ryanleary I agree with everything you're saying, which is exactly why I'm confused about the validation loss being so low to begin with. I think you misread my statement by accident.
Actually, I see why you said that: I referenced the figures without actually pointing out that the validation loss is the lower of the two. My bad :D

@SiddGururani
Contributor Author

Update:
So the model I was training ran for 23 epochs and it seems like it started overfitting after the 9th. The trends are here:
[plot: loss and WER trends over 23 epochs (screenshot from 2017-06-23)]

What's different from before is that the WER has converged to 20.8. @ryanleary are you also observing this kind of WER trend for librispeech? I know that our model architectures are different so we can't really compare but I'd still like to know because it seems like your model also has a similar WER on the librispeech validation set.

@ryanleary
Collaborator

@SiddGururani I haven't been tracking the validation loss. Can you please submit a PR that includes your change for that? I've run into an issue with the tensorboard logging where, if training gets restarted, the plots get all messed up, so I don't have any good plots at the moment. What was your model architecture, out of curiosity?

I still don't have a really good intuition as to what's going on with the validation loss. Something seems amiss.

@dlmacedo

dlmacedo commented Jun 24, 2017

I think we are beginning to overfit...

So we need more regularization...

I think that if we fix #4, this overfitting problem will also be fixed.

@SiddGururani
Contributor Author

SiddGururani commented Jun 24, 2017

@ryanleary the issue that you're running into with tensorboard should have been fixed in #92
Have you tried restarting tensorboard?
My model architecture was 6 layers of 512 LSTMs after the 2 CNN layers. I used the default learning rate and annealing, with a batch size of 30, trained on 3*1080Ti's.
I'll submit my validation-loss logging branch as a PR on Sunday. I was actually curious about your WER. Is it converging to ~21 after the first 10 epochs?
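For concreteness, a hedged sketch of the architecture described above (2 conv layers feeding 6 bidirectional 512-unit LSTM layers); the kernel sizes, strides, input feature count, and output vocabulary are illustrative assumptions, not the repo's exact values:

```python
import torch.nn as nn

def conv_out_len(n, kernel, stride, pad):
    return (n + 2 * pad - kernel) // stride + 1

class DeepSpeechSketch(nn.Module):
    def __init__(self, n_features=161, hidden=512, n_layers=6, n_classes=29):
        super().__init__()
        self.conv = nn.Sequential(  # 2 conv layers over the spectrogram
            nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5)),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5)),
            nn.ReLU(),
        )
        freq = conv_out_len(conv_out_len(n_features, 41, 2, 20), 21, 2, 10)
        self.rnn = nn.LSTM(32 * freq, hidden, num_layers=n_layers,
                           bidirectional=True)
        self.fc = nn.Linear(hidden * 2, n_classes)

    def forward(self, x):  # x: (batch, 1, freq, time)
        x = self.conv(x)
        b, c, f, t = x.shape
        x = x.reshape(b, c * f, t).permute(2, 0, 1)  # -> (time, batch, features)
        x, _ = self.rnn(x)
        return self.fc(x)  # (time, batch, classes) logits for CTC
```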

@SeanNaren
Owner

So I asked my colleague @willfrey and this is his intuition:

The model might be latching onto one particular valid path, which means it may be penalizing other valid paths by making them less likely.
So the word error rate goes down because that one valid path is getting drilled into the model,
but other valid paths, which CTC considers, become less likely and may increase the loss.
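A toy numeric illustration of this intuition, with made-up probabilities for two valid alignments that both collapse to the correct transcript:

```python
import math

# Early in training, mass is spread across both valid alignments; later, one
# path is "drilled in" while the other is suppressed.
early = {"path_A": 0.30, "path_B": 0.30}
late  = {"path_A": 0.50, "path_B": 0.01}

for stage, paths in [("early", early), ("late", late)]:
    total = sum(paths.values())  # CTC scores the sum over valid paths
    best = max(paths.values())   # greedy decoding follows the single best path
    print(f"{stage}: CTC-style loss = {-math.log(total):.3f}, "
          f"best-path prob = {best:.2f}")
# The best path strengthens (so greedy WER can improve) even though the summed
# probability shrinks, i.e. the loss goes up: 0.511 -> 0.673.
```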

@SiddGururani
Contributor Author

SiddGururani commented Jun 26, 2017

@SeanNaren that sounds right.
I've run another network with 7 layers of 400 units each and I'm getting similar results to my previously posted plots: validation loss is increasing but the WER has converged after around 9-10 epochs.
This suggests that the initial suspicion that the dataset is too small might be true, because both times I ran the network with the complete LibriSpeech dataset, the WER converged while the validation loss started to increase, which suggests overfitting.

Maybe we figured out the specific issue of WER decreasing while loss increases. What remains unanswered is:

  1. Why is the validation loss so low from the start of training?
    A possible explanation is that it's just the dataset. I can test this by validating against AN4, since it's small enough.

  2. How do we tackle the overfitting?
    The LibriSpeech dataset should be large enough that it isn't overfitting, but that's just me speculating. It's possible that the few differences between Baidu's DeepSpeech and the current implementation are a factor. Is there any scope for adding a regularizer to the loss? I'm not an expert in CTC, so I can't offer input there. (One generic option is sketched after this list.)
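Not CTC-specific, but one generic regularizer that can be added without changing the loss itself is L2 weight decay via the optimizer. A minimal sketch, where the stand-in model, learning rate, and decay value are illustrative assumptions:

```python
import torch

model = torch.nn.Linear(161, 29)  # stand-in for the acoustic model
# weight_decay adds an L2 penalty on the weights at each update step
optimizer = torch.optim.SGD(model.parameters(), lr=3e-4, momentum=0.9,
                            weight_decay=1e-5)
```

Dropout between the RNN layers would be another conventional option.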

@SeanNaren
Owner

@SiddGururani sorry for getting back to you late about this.
Regarding 2:
LibriSpeech is still fairly small; it isn't enough to train a production model. Combining multiple free datasets, as well as fisher_english etc. from the LDC datasets, will give you enough data to get a better model out. Combining that with some noise injection and data augmentation will help as well.
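As one example of the noise injection mentioned, a hedged sketch that mixes a noise clip into a waveform at a target signal-to-noise ratio; the function name, tensor layout, and default SNR are assumptions for illustration:

```python
import torch

def inject_noise(waveform: torch.Tensor, noise: torch.Tensor,
                 snr_db: float = 15.0) -> torch.Tensor:
    """Mix `noise` into `waveform` (both 1-D) at roughly the given SNR."""
    noise = noise[: waveform.numel()]  # crop the noise clip to length
    sig_power = waveform.pow(2).mean()
    noise_power = noise.pow(2).mean().clamp_min(1e-10)
    # Scale noise so sig_power / (scale^2 * noise_power) == 10^(snr_db / 10).
    scale = torch.sqrt(sig_power / (noise_power * 10 ** (snr_db / 10)))
    return waveform + scale * noise
```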

@SeanNaren
Owner

I'm going to close this for now. I'll end by saying I've seen this behaviour on all the models I've trained, but it doesn't affect the accuracy of the models!

@monk1337

I was facing the same issue during Whisper model training; just commenting for future reference. Increasing validation loss with decreasing WER might be due to the model becoming more confident in specific predictions (leading to a better WER score) while getting the overall probability distribution more wrong (leading to a higher loss). Concentration on a particular path, overfitting, and dataset issues are possible causes too.
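A made-up two-example illustration of that confidence effect (not from any real training run):

```python
import math

# Probability the model assigns to the TRUE label on an easy and a hard
# example. As confidence grows, the easy example's loss shrinks slightly but
# the hard example's loss explodes, so the mean loss rises even though the
# argmax predictions (what WER measures) stay the same or improve.
for stage, p_easy, p_hard in [("early", 0.70, 0.30), ("late", 0.99, 0.01)]:
    mean_nll = (-math.log(p_easy) - math.log(p_hard)) / 2
    print(f"{stage}: mean loss = {mean_nll:.3f}")
# early: 0.780   late: 2.308
```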
