
'NaN' when using nesterov momentum and high learning rates #765

Closed

KenobySky opened this issue Nov 6, 2016 · 15 comments

Comments

@KenobySky
KenobySky commented Nov 6, 2016

Introduction
On August 21, I reported a bug about this on the lasagne-users group.
Basically, when training conv nets with nesterov momentum, if the learning rate starts too high, the loss function (both training and validation loss) turns into NaN while accuracy stays normal. In my case, this starts happening near epoch 9.

Note: The bug seems to happen more often when the dataset is unbalanced, i.e. dominated by a single label.

This happens with the most recent versions of Theano and Lasagne.

Right now, I am pinned to an older version of Lasagne that does not have the issue:
"pip install --upgrade --no-deps git+git://github.com/Lasagne/Lasagne.git@5a009f9"

Jan Schlüter asked me to bisect the code with git to find the guilty commit.
I did it twice to confirm and found it.

4d4e0b0796634c23ad43889685ee4b428fe30f8a is the first bad commit
commit 4d4e0b0
Author: Jan Schlüter <jan.schlueter@ .at>
Date: Thu Jun 30 13:30:29 2016 +0200

Have create_param() set the broadcast pattern of created shared variables

:040000 040000 93ce1ab99b4d4bd14d3825a0b143109aa7d234b2 020a514a738f4fd6d29f538fa2d0fdc5af76555b M lasagne

The model and the learning rate schedule I am using are below:

import sys

import lasagne
import numpy as np
import theano.tensor as tensor


def float32(k):
    # Cast to float32 so learning rates match Theano's floatX=float32 setting.
    return np.cast['float32'](k)


def getModel():
    try:
        input_var = tensor.tensor4('inputs')
        target_var = tensor.fmatrix('targets')

        # Input Layer
        network = lasagne.layers.InputLayer(shape=(None, 1, 7, 7), input_var=input_var)

        # First Convolution Layer
        network = lasagne.layers.Conv2DLayer(network, num_filters=60, filter_size=(3, 3), stride=1, pad=3)
        network = lasagne.layers.Conv2DLayer(network, num_filters=60, filter_size=(3, 3), stride=1, pad=2)
        network = lasagne.layers.MaxPool2DLayer(incoming=network, pool_size=(2, 2))

        # Second Convolution Layer
        network = lasagne.layers.Conv2DLayer(network, num_filters=120, filter_size=(3, 3), stride=1, pad=3)
        network = lasagne.layers.Conv2DLayer(network, num_filters=120, filter_size=(3, 3), stride=1, pad=2)
        network = lasagne.layers.MaxPool2DLayer(incoming=network, pool_size=(2, 2))

        # Hidden Layers
        network = lasagne.layers.DenseLayer(lasagne.layers.dropout(network, p=0.5), num_units=96)
        network = lasagne.layers.DenseLayer(lasagne.layers.dropout(network, p=0.5), num_units=32)
        network = lasagne.layers.DenseLayer(lasagne.layers.dropout(network, p=0.5), num_units=1, nonlinearity=lasagne.nonlinearities.sigmoid)

        return network, input_var, target_var

    except Exception as inst:
        print ("Failure to Build NN !", inst.message, type(inst), (inst.args), (inst))
        sys.exit(1)

def getLearningRate(current_epoch=0, max_epochs=1000):
    if current_epoch < (max_epochs * 5) / 100:
        return float32(0.150)

    if current_epoch < (max_epochs * 10) / 100:
        return float32(0.10)

    if current_epoch < (max_epochs * 30) / 100:
        return float32(0.010)

    if current_epoch < (max_epochs * 80) / 100:
        return float32(0.0010)

    if current_epoch <= (max_epochs * 100) / 100:
        return float32(0.00010)

Momentum is kept at the default value.

If you suspect you have found a bug, please first try updating to the bleeding-edge versions of Theano and Lasagne. It may have been fixed already.

Same issue with the bleeding-edge versions!

If you are not sure whether the problem lies within your code, Theano, or Lasagne, first post on our mailing list.

The problem is in the commit identified above.

Thanks! I really need this bug fixed, since my thesis uses Lasagne and references it a lot.

Please help!


I recreated the bug using the MNIST dataset. However, because my own data is different, I had to transform the labels with y[y > 1] = 0 to simulate the imbalance of my dataset.

And yes, the bug starts happening either in the initial epochs (2 or 3) or before epoch 10.

It doesn't happen on the Lasagne version I mentioned before: git+git://github.com/Lasagne/Lasagne.git@5a009f9

I tested more than 10 times on the bleeding-edge version and more than 10 times on the version above.

The script in the zip (and reproduced below) triggers it.
Main_RebalancedMnist.zip

#CPATH=~/cuda/include:$CPATH PATH=~/cuda/bin/:$PATH LD_LIBRARY_PATH=~/cuda/lib64/:$LD_LIBRARY_PATH THEANO_FLAGS=mode=FAST_RUN,device=gpu1,floatX=float32,optimizer_including=cudnn
# python Main_RebalancedMnist.py
#Using gpu device 1: GeForce GTX TITAN X (CNMeM is enabled with initial size: 95.0% of memory, cuDNN 5105)
#THEANO RC SETTINGS:
#
#[lib]
#cnmem=1

#[dnn]
#enabled=True
from __future__ import print_function

import os
import sys
import time

import lasagne
import numpy as np
import theano
import theano.tensor as T


def float32(k):
    return np.cast['float32'](k)


# ################## Download and prepare the MNIST dataset ##################
def load_dataset():
    # We first define a download function, supporting both Python 2 and 3.
    if sys.version_info[0] == 2:
        from urllib import urlretrieve
    else:
        from urllib.request import urlretrieve

    def download(filename, source='http://yann.lecun.com/exdb/mnist/'):
        print("Downloading %s" % filename)
        urlretrieve(source + filename, filename)

    # We then define functions for loading MNIST images and labels.
    # For convenience, they also download the requested files if needed.
    import gzip

    def load_mnist_images(filename):
        if not os.path.exists(filename):
            download(filename)
        # Read the inputs in Yann LeCun's binary format.
        with gzip.open(filename, 'rb') as f:
            data = np.frombuffer(f.read(), np.uint8, offset=16)
        # The inputs are vectors now, we reshape them to monochrome 2D images,
        # following the shape convention: (examples, channels, rows, columns)
        data = data.reshape(-1, 1, 28, 28)
        # The inputs come as bytes, we convert them to float32 in range [0,1].
        # (Actually to range [0, 255/256], for compatibility to the version
        # provided at http://deeplearning.net/data/mnist/mnist.pkl.gz.)
        return data / np.float32(256)

    def load_mnist_labels(filename):
        if not os.path.exists(filename):
            download(filename)
        # Read the labels in Yann LeCun's binary format.
        with gzip.open(filename, 'rb') as f:
            data = np.frombuffer(f.read(), np.uint8, offset=8)
        # The labels are vectors of integers now, that's exactly what we want.
        return data

    # We can now download and read the training and test set images and labels.
    X_train = load_mnist_images('train-images-idx3-ubyte.gz')
    y_train = load_mnist_labels('train-labels-idx1-ubyte.gz')
    X_test = load_mnist_images('t10k-images-idx3-ubyte.gz')
    y_test = load_mnist_labels('t10k-labels-idx1-ubyte.gz')

    # We reserve the last 10000 training examples for validation.
    X_train, X_val = X_train[:-10000], X_train[-10000:]
    y_train, y_val = y_train[:-10000], y_train[-10000:]

    # We just return all the arrays in order, as expected in main().
    # (It doesn't matter how we do this as long as we can read them again.)
    return X_train, y_train, X_val, y_val, X_test, y_test


def build_cnn(input_var=None):
    # Input Layer
    network = lasagne.layers.InputLayer(shape=(None, 1, 28, 28), input_var=input_var)

    # First Convolution Layer
    network = lasagne.layers.Conv2DLayer(network, num_filters=60, filter_size=(3, 3), stride=1, pad=3)
    network = lasagne.layers.Conv2DLayer(network, num_filters=60, filter_size=(3, 3), stride=1, pad=2)
    network = lasagne.layers.MaxPool2DLayer(incoming=network, pool_size=(2, 2))

    # Second Convolution Layer
    network = lasagne.layers.Conv2DLayer(network, num_filters=120, filter_size=(3, 3), stride=1, pad=3)
    network = lasagne.layers.Conv2DLayer(network, num_filters=120, filter_size=(3, 3), stride=1, pad=2)
    network = lasagne.layers.MaxPool2DLayer(incoming=network, pool_size=(2, 2))

    # Hidden Layers
    network = lasagne.layers.DenseLayer(lasagne.layers.dropout(network, p=0.5), num_units=96)
    network = lasagne.layers.DenseLayer(lasagne.layers.dropout(network, p=0.5), num_units=32)
    network = lasagne.layers.DenseLayer(lasagne.layers.dropout(network, p=0.5), num_units=1, nonlinearity=lasagne.nonlinearities.sigmoid)

    return network


def iterate_minibatches(inputs, targets, batchsize):
    assert len(inputs) == len(targets)

    indices = np.arange(len(inputs))
    np.random.shuffle(indices)

    for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
        excerpt = indices[start_idx:start_idx + batchsize]
        yield inputs[excerpt], targets[excerpt]


# ############################## Main program ################################
def main():
    num_epochs = 500
    batch_size = 3000
    learning_rate = 0.150
    momentum_rho = 0.9

    seed = 8000
    lasagne.random.set_rng(np.random.RandomState(seed))

    print("\nnum_epochs " + str(num_epochs))
    print("batch_size " + str(batch_size))
    print("learning_rate " + str(learning_rate))
    print("momentum_rho " + str(momentum_rho))
    print("seed " + str(seed))


    # Load the dataset
    print("\nLoading data...")
    X_train, y_train, X_val, y_val, X_test, y_test = load_dataset()

    # REBALANCE DATASET TO 0,1 - SIMULATE A DATASET WITH A LOT MORE 0's THAN 1's.

    # np.frombuffer returns read-only arrays; copy the labels so we can modify them.
    y_train = y_train.copy()
    y_val = y_val.copy()
    y_test = y_test.copy()

    def rebalance(k):
        k[k > 1] = 0
        return k

    X_train = X_train.reshape(-1, 1, 28, 28)
    X_val = X_val.reshape(-1, 1, 28, 28)
    X_test = X_test.reshape(-1, 1, 28, 28)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)
    y_val = y_val.reshape(-1, 1)

    y_train = rebalance(y_train)
    y_val = rebalance(y_val)
    y_test = rebalance(y_test)

    # Prepare Theano variables for inputs and targets
    input_var = T.tensor4('inputs')
    target_var = T.fmatrix('targets')

    # Create neural network model (depending on first command line parameter)
    print("Building model and compiling functions...")
    network = build_cnn(input_var)

    prediction = lasagne.layers.get_output(network)
    loss = lasagne.objectives.binary_crossentropy(prediction, target_var)
    loss = loss.mean()

    params = lasagne.layers.get_all_params(network, trainable=True)
    updates = lasagne.updates.nesterov_momentum(loss, params, learning_rate=float32(learning_rate), momentum=float32(momentum_rho))

    test_prediction = lasagne.layers.get_output(network, deterministic=True)
    test_loss = lasagne.objectives.binary_crossentropy(test_prediction, target_var)
    test_loss = test_loss.mean()

    test_acc = lasagne.objectives.binary_accuracy(test_prediction, target_var).mean()

    train_fn = theano.function([input_var, target_var], loss, updates=updates)

    val_fn = theano.function([input_var, target_var], [test_loss, test_acc])

    # Finally, launch the training loop.
    print("Starting training...")

    for epoch in range(num_epochs):
        # In each epoch, we do a full pass over the training data:
        train_err = 0
        train_batches = 0
        start_time = time.time()
        for batch in iterate_minibatches(X_train, y_train, batch_size):
            inputs, targets = batch
            train_err += train_fn(inputs, targets)
            train_batches += 1

        # And a full pass over the validation data:
        val_err = 0
        val_acc = 0
        val_batches = 0
        for batch in iterate_minibatches(X_val, y_val, batch_size):
            inputs, targets = batch
            err, acc = val_fn(inputs, targets)
            val_err += err
            val_acc += acc
            val_batches += 1

        # Then we print the results for this epoch:
        print("Epoch {} of {} took {:.3f}s".format(epoch + 1, num_epochs, time.time() - start_time))
        print("  training loss:\t\t{:.6f}".format(train_err / train_batches))
        print("  validation loss:\t\t{:.6f}".format(val_err / val_batches))
        print("  validation accuracy:\t\t{:.2f} %".format(val_acc / val_batches * 100))

    # After training, we compute and print the test error:
    test_err = 0
    test_acc = 0
    test_batches = 0
    for batch in iterate_minibatches(X_test, y_test, batch_size):
        inputs, targets = batch
        err, acc = val_fn(inputs, targets)
        test_err += err
        test_acc += acc
        test_batches += 1

    print("Final results:")
    print("  test loss:\t\t\t{:.6f}".format(test_err / test_batches))
    print("  test accuracy:\t\t{:.2f} %".format(test_acc / test_batches * 100))


if __name__ == '__main__':
    main()

@ebenolson
Member

Hi, good work tracking it down to a commit. My first assumption would have been that your learning rate was just too high, but since it behaved differently in the past we should figure out what's going on. Assuming the bug is not obvious (it's not to me), there are some things you can do to help this get fixed quicker:

  1. Create a fully contained test script that reproduces the error - use randomly generated or constant input data if possible (see the sketch after this list).
  2. Simplify the test case as much as possible - is the error dependent on this network structure or can you reduce it to a single layer? Can you replace getLearningRate with a constant?
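
Something along these lines would do as a starting point - a minimal sketch only, where the input size, single sigmoid layer, learning rate and artificial class imbalance are assumptions rather than your exact setup:

import numpy as np
import theano
import theano.tensor as T
import lasagne

np.random.seed(0)
X = np.random.rand(512, 20).astype('float32')          # random inputs
y = (np.random.rand(512, 1) < 0.1).astype('float32')   # heavily unbalanced 0/1 targets

input_var = T.matrix('inputs')
target_var = T.fmatrix('targets')

# A single sigmoid output unit, like the last layer of the reported model.
net = lasagne.layers.InputLayer((None, 20), input_var)
net = lasagne.layers.DenseLayer(net, num_units=1,
                                nonlinearity=lasagne.nonlinearities.sigmoid)

prediction = lasagne.layers.get_output(net)
loss = lasagne.objectives.binary_crossentropy(prediction, target_var).mean()
params = lasagne.layers.get_all_params(net, trainable=True)
updates = lasagne.updates.nesterov_momentum(loss, params,
                                            learning_rate=0.15, momentum=0.9)
train_fn = theano.function([input_var, target_var], loss, updates=updates)

for epoch in range(50):
    print(epoch, train_fn(X, y))   # watch for the loss turning into nan

Whether such a stripped-down setup still reproduces the NaNs is exactly what steps 1 and 2 should establish.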

@KenobySky
Author

I'm working on it now; however, I would like to ask that the Lasagne group also look into this.

It was working fine in the past, for a long time, probably for some six months before that commit.

So it wasn't just something that happened to work by accident.
I'm preparing a script with the MNIST data to see if I can reproduce the problem.

@KenobySky
Author

Done, script added!
I will try other network structures and different learning rates. That might take some time.

@KenobySky
Author

Test cases:

# Happens with learning_rate = 0.5, 0.3, 0.2, 0.1
# Doesn't seem to happen when lr = 0.01

It doesn't happen when you remove a hidden layer.
With just one hidden layer, it seems harder to trigger.
Adding more hidden layers or changing the number of kernels doesn't change this; the NaNs still occur.

For example, the model below causes NaNs:

def build_cnn(input_var=None):
    # Input Layer
    network = lasagne.layers.InputLayer(shape=(None, 1, 28, 28), input_var=input_var)

    # First Convolution Layer
    network = lasagne.layers.Conv2DLayer(network, num_filters=30, filter_size=(3, 3), stride=1, pad=3)
    #network = lasagne.layers.Conv2DLayer(network, num_filters=60, filter_size=(3, 3), stride=1, pad=2)
    #network = lasagne.layers.MaxPool2DLayer(incoming=network, pool_size=(2, 2))

    # Second Convolution Layer
    #network = lasagne.layers.Conv2DLayer(network, num_filters=120, filter_size=(3, 3), stride=1, pad=3)
    #network = lasagne.layers.Conv2DLayer(network, num_filters=120, filter_size=(3, 3), stride=1, pad=2)
    #network = lasagne.layers.MaxPool2DLayer(incoming=network, pool_size=(2, 2))

    # Hidden Layers
    network = lasagne.layers.DenseLayer(lasagne.layers.dropout(network, p=0.5), num_units=96)
    #network = lasagne.layers.DenseLayer(lasagne.layers.dropout(network, p=0.5), num_units=96)
    network = lasagne.layers.DenseLayer(lasagne.layers.dropout(network, p=0.5), num_units=32)
    network = lasagne.layers.DenseLayer(lasagne.layers.dropout(network, p=0.5), num_units=1, nonlinearity=lasagne.nonlinearities.sigmoid)

    return network

@eralmansouri

I think it might just be because of using a sigmoid activation function with binary crossentropy. If the output is exactly 1 or exactly 0, the error function produces a NaN, which then corrupts the weights and all future outputs.

Try clipping the predictions when calculating the error:

loss = lasagne.objectives.binary_crossentropy(T.clip(prediction, 1e-7, 1 - 1e-7), target_var)

If that doesn't work, it might be because the default activation function for dense layers is ReLU, IIRC. The weights can grow unrealistically large when the error is too big, producing values large enough to turn into NaNs.

To address that, I would suggest scaling the weights or the weight updates so they stay within a reasonable range; one possible sketch follows.
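
For the second suggestion, one possible way to do this with Lasagne's own utilities (if your version provides lasagne.updates.total_norm_constraint) is to rescale the gradients before building the Nesterov updates. This is only a sketch that reuses loss, params, learning_rate and momentum_rho from the test script above, and max_norm=5.0 is an arbitrary assumption:

# Clip the total gradient norm before handing the gradients to the optimizer.
# The lasagne.updates functions accept a list of gradients in place of a scalar loss.
grads = theano.grad(loss, params)
grads = lasagne.updates.total_norm_constraint(grads, max_norm=5.0)
updates = lasagne.updates.nesterov_momentum(grads, params,
                                            learning_rate=float32(learning_rate),
                                            momentum=float32(momentum_rho))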

@ebenolson
Member

Thanks for the test script. So it seems to me the issue is that create_param is marking the last dimension of W and b of the DenseLayer as broadcastable, since num_units = 1. This seems to cause a NaN gradient when a prediction is exactly 1 or 0; when the dimension is not broadcastable, the gradient is 0 instead.

It's still not clear to me why this difference happens - maybe @f0k can weigh in if it's a Theano bug, but I think we don't want to be marking those dims broadcastable anyway. Here's a simplified demonstration:

import numpy as np
import theano
import theano.tensor as T
import lasagne

W = theano.shared(np.array([[0], [1e3]]).astype('float32'), broadcastable=(False, True))

input_var = T.matrix()
target_var = T.matrix()

p = lasagne.nonlinearities.sigmoid(input_var.dot(W))
l = lasagne.objectives.binary_crossentropy(p, target_var).mean()
T.grad(l, W).eval({input_var: np.array([[0, 1]]).astype('float32'), target_var: np.array([[1]]).astype('float32')})

output:

array([[ nan],
       [ nan]], dtype=float32)

versus

W = theano.shared(np.array([[0], [1e6]]).astype('float32'), broadcastable=(False, False))

input_var = T.matrix()
target_var = T.matrix()

p = lasagne.nonlinearities.sigmoid(input_var.dot(W))
l = lasagne.objectives.binary_crossentropy(p, target_var).mean()
g = T.grad(l, W)
T.grad(l, W).eval({input_var:np.array([[0, 1]]).astype('float32'), target_var:np.array([[1]]).astype('float32')})

output:

array([[ 0.],
       [ 0.]], dtype=float32)   

@f0k
Member

f0k commented Nov 7, 2016

Thank you for bisecting this, Andre, and thank you for the clear demonstration, Eben!

maybe @f0k can weigh in if it's a Theano bug

I just commented on the mailing list post: https://groups.google.com/forum/#!topic/lasagne-users/y-yb6dO_Dzg
I think this is a Theano bug where some optimization is not applied because a Rebroadcast/Dimshuffle gets in the way (Theano/Theano#4451), although I'm not 100% sure whether the possibility of a broadcast poses a genuine problem for the optimization.

but I think we don't want to be marking those dims broadcastable anyway

#715 did that on purpose to obtain a broadcastable tensor from a 1-unit dense layer. Unfortunately, nobody noticed how it affected the log(sigmoid) optimization, and we didn't have any tests for this either -- we're fully relying on Theano there.

@andrelopes1705: A quick workaround for now is to have your target variable be a column vector as well:

target_var = T.vector()
...
loss = lasagne.objectives.binary_crossentropy(prediction, target_var.dimshuffle(0, 'x'))

Or, with even fewer changes to your code:

target_var = T.TensorType(theano.config.floatX, (False, True))('targets')

But we need to figure out how to change Lasagne or Theano to make this work seamlessly again, with existing code using a plain T.matrix() target.
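
In the meantime, a quick way to check whether a given model is affected is to compare the broadcast patterns of the prediction and the target. A sketch that reuses build_cnn, input_var and target_var from the test script in this thread:

network = build_cnn(input_var)
prediction = lasagne.layers.get_output(network)
# On the affected Lasagne revision, the 1-unit DenseLayer makes the last
# dimension of the prediction broadcastable, while a plain matrix target is not.
print(prediction.broadcastable)   # expected (False, True) on the bad commit
print(target_var.broadcastable)   # (False, False) for T.fmatrix('targets')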

@nouiz, any insights from your side?

@f0k
Member

f0k commented Nov 7, 2016

Just to see how the graphs differ -- with the same broadcast pattern for predictions and targets:

import theano
import theano.tensor as T
x, y = T.matrices('xy')
theano.printing.debugprint(theano.function([x, y], theano.grad(T.nnet.binary_crossentropy(T.nnet.sigmoid(x), y).mean(), x)))

Elemwise{Composite{((i0 * i1 * (i2 - scalar_sigmoid(i3))) - (i0 * (i2 - i1) * scalar_sigmoid(i3)))}} [id A] '(dmean/dx)'   9
 |Elemwise{Composite{(i0 / (i1 * i2))}} [id B] ''   8
 | |TensorConstant{(1, 1) of -1.0} [id C]
 | |InplaceDimShuffle{x,x} [id D] ''   7
 | | |Subtensor{int64} [id E] ''   5
 | |   |Elemwise{Cast{float32}} [id F] ''   3
 | |   | |MakeVector{dtype='int64'} [id G] ''   2
 | |   |   |Shape_i{0} [id H] ''   1
 | |   |   | |y [id I]
 | |   |   |Shape_i{1} [id J] ''   0
 | |   |     |y [id I]
 | |   |Constant{1} [id K]
 | |InplaceDimShuffle{x,x} [id L] ''   6
 |   |Subtensor{int64} [id M] ''   4
 |     |Elemwise{Cast{float32}} [id F] ''   3
 |     |Constant{0} [id N]
 |y [id I]
 |TensorConstant{(1, 1) of 1.0} [id O]
 |x [id P]

With a different broadcast pattern:

x = T.TensorType(theano.config.floatX, (False, True))('x')
theano.printing.debugprint(theano.function([x, y], theano.grad(T.nnet.binary_crossentropy(T.nnet.sigmoid(x), y).mean(), x)))

Elemwise{Composite{((((i0 * i1) / scalar_sigmoid(i2)) + (-((i0 * i3) / (i4 - scalar_sigmoid(i2))))) * scalar_sigmoid(i2) * (i4 - scalar_sigmoid(i2)))}}[(0, 1)] [id A] '(dmean/dx)'   14
 |Elemwise{Composite{(i0 / (i1 * i2))}} [id B] ''   13
 | |TensorConstant{(1, 1) of -1.0} [id C]
 | |InplaceDimShuffle{x,x} [id D] ''   12
 | | |Subtensor{int64} [id E] ''   10
 | |   |Elemwise{Cast{float32}} [id F] ''   7
 | |   | |MakeVector{dtype='int64'} [id G] ''   5
 | |   |   |Shape_i{0} [id H] ''   2
 | |   |   | |y [id I]
 | |   |   |Shape_i{1} [id J] ''   1
 | |   |     |y [id I]
 | |   |Constant{1} [id K]
 | |InplaceDimShuffle{x,x} [id L] ''   11
 |   |Subtensor{int64} [id M] ''   9
 |     |Elemwise{Cast{float32}} [id F] ''   7
 |     |Constant{0} [id N]
 |InplaceDimShuffle{0,x} [id O] ''   4
 | |Sum{axis=[1], acc_dtype=float64} [id P] ''   0
 |   |y [id I]
 |x [id Q]
 |InplaceDimShuffle{0,x} [id R] ''   8
 | |Sum{axis=[1], acc_dtype=float64} [id S] ''   6
 |   |Elemwise{sub,no_inplace} [id T] ''   3
 |     |TensorConstant{(1, 1) of 1.0} [id U]
 |     |y [id I]
 |TensorConstant{(1, 1) of 1.0} [id U]

@nouiz

nouiz commented Nov 10, 2016

Someone here will work on that in the next few days.

@KenobySky
Author

I appreciate the support!
Thanks a lot!
I forwarded this to my Master's degree advisor.

@nouiz

nouiz commented Nov 18, 2016

We merged a fix for that in Theano. Can you try it in your own environment to make sure the fix is complete?

@KenobySky
Author

Yes, of course! Should I try with my own test case, or do you want me to test with a specific script?

Requesting instructions to test.

@f0k
Member

f0k commented Nov 21, 2016

We merged a fix for that in Theano.

Great!

Can you try it in your own environment to make sure the fix is complete?

Yes, it works both for Eben's test case and mine. Good job @ReyhaneAskari!

Requesting instructions to test.

Just run the same code you used for bisecting Lasagne. It should train fine now when you update Theano to the latest version from git, both with old and recent Lasagne versions.

@f0k f0k closed this as completed Nov 21, 2016
@KenobySky
Author

It's Fixed!

@cembirler

In my case, I changed the learning rate to 0.0012 and it works now.
