Skip to content

A c++ implementation of Convolutional neural network, with a MNIST hand-written digits recognition application

License

Notifications You must be signed in to change notification settings

arnogranier/cppCNN

Repository files navigation

cpp CNN

Final project for ENSTA's SIM202 course

A c++ implementation of Convolutional neural network, with an application on MNIST hand-written digits recognition.

You can find the proposed subject together with a brief introductive survey of CNN's theory here : Subject

Currently training at 98% on LeNet architecture (LeNet is a classical architecture for OCR introduced in LeCun, Y., Bottou, L., Bengio, Y., & Haffner, P. (1998). Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11), 2278-2324. pdf

Rapport proposé (french) : pdf

Architecture

alt text

Example of usage

#include <iostream>
#include <list>
#include <ctime>
#include <CNN.hpp>
using namespace std;

int main()
{
    
    // * PARAMETERS * 
    
    // path to database
    string filename_train_images = "../../MNIST/train-images-idx3-ubyte";
    string filename_train_labels = "../../MNIST/train-labels-idx1-ubyte";
    string filename_test_images = "../../MNIST/t10k-images-idx3-ubyte";
    string filename_test_labels = "../../MNIST/t10k-labels-idx1-ubyte";
    
    // Feature detector architecture
    Layer3D* f1 = new ConvLayer(5, 1, 6, 1, 0);
    Layer3D* p1 = new MaxPoolLayer(2, 2);
    Layer3D* f2 = new ConvLayer(5, 6, 16, 1, 0);
    Layer3D* p2 = new MaxPoolLayer(2, 2);
    list<Layer3D*> feature_detector{f1, p1, f2, p2};
    
    // Classifier architecture
    FCLayer l1(4*4*16, 84);
    FCLayer l2(84, 10);
    list<FCLayer> classifier{l1, l2};
    
    // Learning rates
    double classifier_lr = 0.01;
    double feature_detector_lr = 0.0005;
    
    // Number of epochs
    uint n_epoch = 10;
    
    
    // * MAIN * 
    
    // Building the network
    CNN net(feature_detector, classifier, classifier_lr, feature_detector_lr);
    
    // Read and set the database
    net.set_db(filename_train_images, filename_train_labels, 
               filename_test_images, filename_test_labels);

    // Accuracy test
    double acc;
    acc = net.test_accuracy();
    cout << "Accuracy on test dataset: " << acc << " % " << endl;
    
    // Training phase
    net.train(n_epoch);
    
    // Accuracy test
    acc = net.test_accuracy();
    cout << "Accuracy on test dataset: " << acc << " % " << endl;
    
    net.save("mynet");
    // net.load("mynet");
    for (auto pointer:feature_detector) delete pointer;

    
    return 0;
}

About

A c++ implementation of Convolutional neural network, with a MNIST hand-written digits recognition application

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Languages