Skip to content

Commit

Permalink
Merge branch 'NumpyDatabase' into 'master'
Browse files Browse the repository at this point in the history
Numpy database

See merge request n2d2/n2d2!105
  • Loading branch information
thibaultallenet-cea committed Mar 9, 2023
2 parents 6693b9a + 9a4eb6f commit b7964a1
Show file tree
Hide file tree
Showing 13 changed files with 562 additions and 22 deletions.
43 changes: 43 additions & 0 deletions docs/python_api/databases.rst
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,49 @@ Then to load the database we will use :
:members:
:inherited-members:

Numpy
~~~~~

The :py:class:`n2d2.database.Numpy` allows to create a database using Numpy array.
This can be especially usefull if you already have a dataloader written in Python.

.. note::

The labels are optional, this can be usefull if you have previously trained your model and only need data to calibrate you model using the :py:func:`n2d2.quantizer.PTQ` function.

Usage example
^^^^^^^^^^^^^

.. code-block:: python
import n2d2
import numpy as np
db = n2d2.database.Numpy()
db.load([
np.ones([1,2,3]),
np.zeros([1,2,3]),
np.ones([1,2,3]),
np.zeros([1,2,3]),
],
[
0,
1,
0,
1
])
db.partition_stimuli(1., 0., 0.) # Learn Validation Test
provider = n2d2.provider.DataProvider(db, [3, 2, 1], batch_size=2)
provider.set_partition("Learn")
print("First stimuli :")
print(next(provider))
.. autoclass:: n2d2.database.Numpy
:members:
:inherited-members:

MNIST
~~~~~

Expand Down
133 changes: 133 additions & 0 deletions docs/python_api/example/load_numpy_data.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
Load Numpy Data
===============

In this example, we will see how to load data from a Numpy array using :py:class:`n2d2.database.numpy`.

You can find the full python script here :download:`load_numpy_data.py</../python/examples/load_numpy_data.py>`.

Preliminary
-----------

For this tutorial, we will create a database using the following keras data loader : https://keras.io/api/datasets/fashion_mnist/.

Available by importing :

.. code-block::
from tensorflow.keras.datasets.fashion_mnist import load_data
(x_train, y_train), (x_test, y_test) = load_data()
Load data into N2D2
-------------------

Now that we have our data in the form of Numpy array we can create and populate the :py:class:`n2d2.database.numpy`.

.. code-block::
import n2d2
# Instanciate Numpy database object
db = n2d2.database.Numpy()
# Load train set
db.load([a for a in x_train], [(int)(i.item()) for i in y_train])
# Add the loaded data to the Learn partition
db.partition_stimuli(1., 0., 0.) # Learn Validation Test
# Load test set in the validation partition
db.load([a for a in x_test], [(int)(i.item()) for i in y_test], partition="Validation")
# Print a summary
db.get_partition_summary()
.. testoutput::

Number of stimuli : 70000
Learn : 60000 stimuli (85.71%)
Test : 0 stimuli (0.0%)
Validation : 10000 stimuli (14.29%)
Unpartitioned : 0 stimuli (0.0%)


Training a model using the numpy database
-----------------------------------------

Before anything, we will import the following modules :

.. code-block::
import n2d2
from n2d2.cells.nn import Fc, Softmax
from n2d2.cells import Sequence
from n2d2.solver import SGD
from n2d2.activation import Rectifier, Linear
from math import ceil
For this example we will create a very simple model :

.. code-block::
model = Sequence([
Fc(28*28, 128, activation=Rectifier()),
Fc(128, 10, activation=Linear()),
])
softmax = Softmax(with_loss=True)
model.set_solver(SGD(learning_rate=0.001))
print("Model :")
print(model)
In order to provide data to the model for the training, we will create a :py:class:`n2d2.provider.DataProvider`.

.. code-block::
provider = n2d2.provider.DataProvider(db, [28, 28, 1], batch_size=BATCH_SIZE)
provider.set_partition("Learn")
target = n2d2.target.Score(provider)
Then we can write a classic training loop to learn using the :py:class:`n2d2.provider.DataProvider` :

.. code-block::
print("\n### Training ###")
for epoch in range(EPOCH):
provider.set_partition("Learn")
model.learn()
print("\n# Train Epoch: " + str(epoch) + " #")
for i in range(ceil(db.get_nb_stimuli('Learn')/BATCH_SIZE)):
x = provider.read_random_batch()
x = model(x)
x = softmax(x)
x = target(x)
x.back_propagate()
x.update()
print("Example: " + str(i * BATCH_SIZE) + ", loss: "
+ "{0:.3f}".format(target.loss()), end='\r')
print("\n### Validation ###")
target.clear_success()
provider.set_partition('Validation')
model.test()
for i in range(ceil(db.get_nb_stimuli('Validation')/BATCH_SIZE)):
batch_idx = i * BATCH_SIZE
x = provider.read_batch(batch_idx)
x = model(x)
x = softmax(x)
x = target(x)
print("Validate example: " + str(i * BATCH_SIZE) + ", val success: "
+ "{0:.2f}".format(100 * target.get_average_success()) + "%", end='\r')
print("\nEND")
6 changes: 6 additions & 0 deletions docs/quant/post.rst
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,12 @@ model:

n2d2 MobileNet_ONNX.ini -seed 1 -w /dev/null -export CPP -fuse -calib -1 -act-clipping-mode KL-Divergence

With the python API
~~~~~~~~~~~~~~~~~~~


.. autofunction:: n2d2.quantizer.PTQ


Examples and results
--------------------
Expand Down
73 changes: 73 additions & 0 deletions include/Database/Tensor_Database.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
(C) Copyright 2023 CEA LIST. All Rights Reserved.
Contributor(s): Cyril MOINEAU (cyril.moineau@cea.fr)
This software is governed by the CeCILL-C license under French law and
abiding by the rules of distribution of free software. You can use,
modify and/ or redistribute the software under the terms of the CeCILL-C
license as circulated by CEA, CNRS and INRIA at the following URL
"http://www.cecill.info".
As a counterpart to the access to the source code and rights to copy,
modify and redistribute granted by the license, users are provided only
with a limited warranty and the software's author, the holder of the
economic rights, and the successive licensors have only limited
liability.
The fact that you are presently reading this means that you have had
knowledge of the CeCILL-C license and that you accept its terms.
*/

#ifndef N2D2_TENSOR_DATABASE_H
#define N2D2_TENSOR_DATABASE_H

#include "Database/Database.hpp"
#include "containers/Tensor.hpp"

namespace N2D2 {
class Tensor_Database : public Database {
public:
Tensor_Database();
template <class T>
void load(
std::vector<Tensor<T>>& inputs,
std::vector<int>& labels)
{
assert(inputs.size() == labels.size());
// Check there is no disperency with stimuli before method
assert(mStimuli.size() == mStimuliData.size());
assert(mStimuli.size() == mStimuliSets(Unpartitioned).size() +
mStimuliSets(Test).size() +
mStimuliSets(Validation).size() +
mStimuliSets(Learn).size());

unsigned int nbStimuliToLoad = inputs.size();
unsigned int oldNbStimuli = mStimuli.size();

mStimuli.reserve(mStimuli.size() + nbStimuliToLoad);
mStimuliData.reserve(mStimuliData.size() + nbStimuliToLoad);
mStimuliSets(Unpartitioned).reserve(mStimuliSets(Unpartitioned).size() + nbStimuliToLoad);


for(unsigned int i = 0; i < nbStimuliToLoad; ++i){
mStimuliData.push_back((cv::Mat)inputs[i].clone());
std::ostringstream nameStr;
nameStr << "Tensor[" << mStimuli.size() << "]";
mStimuli.push_back(Stimulus(nameStr.str(), labels[i]));
mStimuliSets(Unpartitioned).push_back(mStimuli.size() - 1);
}

// Check there is no disperency with stimuli after method
assert(mStimuli.size() == nbStimuliToLoad + oldNbStimuli);
assert(mStimuli.size() == mStimuliData.size());
assert(mStimuli.size() == mStimuliSets(Unpartitioned).size() +
mStimuliSets(Test).size() +
mStimuliSets(Validation).size() +
mStimuliSets(Learn).size());
}

virtual ~Tensor_Database() {};
};
}

#endif // N2D2_TENSOR_DATABASE_H
121 changes: 121 additions & 0 deletions python/examples/load_numpy_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""
(C) Copyright 2023 CEA LIST. All Rights Reserved.
Contributor(s): Cyril MOINEAU (cyril.moineau@cea.fr)
This software is governed by the CeCILL-C license under French law and
abiding by the rules of distribution of free software. You can use,
modify and/ or redistribute the software under the terms of the CeCILL-C
license as circulated by CEA, CNRS and INRIA at the following URL
"http://www.cecill.info".
As a counterpart to the access to the source code and rights to copy,
modify and redistribute granted by the license, users are provided only
with a limited warranty and the software's author, the holder of the
economic rights, and the successive licensors have only limited
liability.
The fact that you are presently reading this means that you have had
knowledge of the CeCILL-C license and that you accept its terms.
"""

"""
This script showcase how to use data loaded with numpy to train your Network.
In this example we use the Keras dataloader fashion mnist : https://keras.io/api/datasets/fashion_mnist/
To learn a minimal LeNet Network
"""

from tensorflow.keras.datasets.fashion_mnist import load_data
import n2d2
from n2d2.cells.nn import Fc, Softmax
from n2d2.cells import Sequence
from n2d2.solver import SGD
from n2d2.activation import Rectifier, Linear
from math import ceil
import argparse

# ARGUMENTS PARSING
parser = argparse.ArgumentParser()

parser.add_argument('--dev', '-d', type=int, default=0, help='GPU device, only if CUDA is available. (default=0)')
parser.add_argument('--epochs', "-e", type=int, default=10, help='Number of epochs (default=10)')
parser.add_argument('--batch_size', "-b", type=int, default=32, help='Batchsize (default=32)')
args = parser.parse_args()

BATCH_SIZE = args.batch_size
EPOCH = args.epochs

if n2d2.global_variables.cuda_available:
n2d2.global_variables.default_model = "Frame_CUDA"
n2d2.global_variables.cuda_device = args.dev
else:
print("CUDA is not available")
(x_train, y_train), (x_test, y_test) = load_data()

db = n2d2.database.Numpy()

# x_train is a numpy array of shape [nb train, 28, 28].
# `n2d2.database.numpy.load` only take a list of stimuli.
# So we create a list of numpy array of shape [28, 28] using list comprehension.
db.load([a for a in x_train], [(int)(i.item()) for i in y_train])
db.partition_stimuli(1., 0., 0.) # Learn Validation Test

# Using test set for validation
db.load([a for a in x_test], [(int)(i.item()) for i in y_test], partition="Validation")

db.get_partition_summary()

model = Sequence([
Fc(28*28, 128, activation=Rectifier()),
Fc(128, 10, activation=Linear()),
])
softmax = Softmax(with_loss=True)
model.set_solver(SGD(learning_rate=0.001))

print("Model :")
print(model)


provider = n2d2.provider.DataProvider(db, [28, 28, 1], batch_size=BATCH_SIZE)

provider.set_partition("Learn")

target = n2d2.target.Score(provider)

print("\n### Training ###")
for epoch in range(EPOCH):

provider.set_partition("Learn")
model.learn()

print("\n# Train Epoch: " + str(epoch) + " #")

for i in range(ceil(db.get_nb_stimuli('Learn')/BATCH_SIZE)):

x = provider.read_random_batch()
x = model(x)
x = softmax(x)
x = target(x)
x.back_propagate()
x.update()

print("Example: " + str(i * BATCH_SIZE) + ", loss: "
+ "{0:.3f}".format(target.loss()), end='\r')

print("\n### Validation ###")

target.clear_success()

provider.set_partition('Validation')
model.test()

for i in range(ceil(db.get_nb_stimuli('Validation')/BATCH_SIZE)):
batch_idx = i * BATCH_SIZE

x = provider.read_batch(batch_idx)
x = model(x)
x = softmax(x)
x = target(x)

print("Validate example: " + str(i * BATCH_SIZE) + ", val success: "
+ "{0:.2f}".format(100 * target.get_average_success()) + "%", end='\r')
print("\nEND")
1 change: 1 addition & 0 deletions python/n2d2/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@
from n2d2.database.gtsrb import GTSRB
from n2d2.database.ilsvrc2012 import ILSVRC2012
from n2d2.database.mnist import MNIST
from n2d2.database.numpy import Numpy

0 comments on commit b7964a1

Please sign in to comment.