# <p style="text-align: center; font-weight: bold;">Part 2: MNIST with a Convolutional Neural Network (CNN)</p>

### <p style="text-align: center;">(Almond 0.8.1, Scala 2.12.8)</p>


## Dependencies

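Surprise! We load TensorFlow Scala 0.4.1 (Linux x86_64 CPU build) together with its `tensorflow-data` module, which provides the MNIST loader.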

In [None]:
// Core TensorFlow Scala 0.4.1 for Scala 2.12, using the native build selected by the
// "linux-cpu-x86_64" classifier (Linux, x86_64, CPU only).
interp.load.ivy(coursierapi.Dependency.of("org.platanios", "tensorflow_2.12", "0.4.1").withClassifier("linux-cpu-x86_64"))
// TensorFlow Scala data module, same version; the MNISTLoader used in this notebook lives in
// its org.platanios.tensorflow.data package.
interp.load.ivy("org.platanios" %% "tensorflow-data" % "0.4.1")

In [None]:
import java.nio.file.Paths

import org.platanios.tensorflow.api._

import org.platanios.tensorflow.api.tf
import org.platanios.tensorflow.api.tensors.Tensor
import org.platanios.tensorflow.api.core.Shape
import org.platanios.tensorflow.api.core.Indexer._
import org.platanios.tensorflow.api.core.client.Session
import org.platanios.tensorflow.data.image.MNISTLoader

import org.platanios.tensorflow.api.learn.layers.{ Softmax, AddBias, Sigmoid, Dropout, Flatten, Input, Linear, ReLU, SparseSoftmaxCrossEntropy, Mean, Conv2D, MaxPool }
import org.platanios.tensorflow.api.learn.{ Model, StopCriteria }
import org.platanios.tensorflow.api.learn.estimators.InMemoryEstimator

import org.platanios.tensorflow.api.ops.NN.SameConvPadding


## Display MNIST Dataset

In [None]:
{{
/** Displays the first `nb` MNIST training images in the notebook, each rendered as a 100x100 PNG.
  *
  * @param nb number of images to show, taken from the start of the training set
  */
def displayNumberMNIST(nb: Int): Unit = {
    val dataset = MNISTLoader.load(Paths.get("../resources/dataset"))
    val imagesToDisplay = dataset.trainImages.slice(0 :: nb, ::, ::)
    // A single session is reused for every image and closed in `finally`,
    // so its native resources are released even if encoding/display fails.
    val session = Session()
    try {
        for (index <- 0 until nb) {
            // Encode one 28x28 grayscale image (channel axis added) to PNG bytes.
            val png = session.run(fetches = tf.decodeRaw[Byte](tf.image.encodePng(imagesToDisplay(index).reshape(Shape(28, 28, 1)))))
            Image(png.entriesIterator.toArray).withFormat(Image.PNG).withWidth(100).withHeight(100).display
        }
    } finally {
        session.close()
    }
}
displayNumberMNIST(20)
}}

### Data iterator for training

In [None]:
// Load MNIST from the local resources folder.
val dataset = MNISTLoader.load(Paths.get("../resources/dataset"))

// Pixels cast to Float, then given an explicit trailing channel axis (N, H, W, 1).
val trainImages = dataset.trainImages.toFloat
val trainImagesReshape = {
  val dims = dataset.trainImages.shape
  val withChannel = Shape(dims(0), dims(1), dims(2), 1)
  tf.data.datasetFromTensorSlices(trainImages.reshape(withChannel))
}

// Labels cast to Long, matching the INT64 training input and the Long label type of the loss.
val trainLabels = tf.data.datasetFromTensorSlices(dataset.trainLabels.toLong)

// Endless stream of (image, label) pairs: shuffled through a 10000-element buffer,
// grouped into batches of 256, with up to 10 batches prefetched.
val trainData = trainImagesReshape
  .zip(trainLabels)
  .repeat()
  .shuffle(10000)
  .batch(256)
  .prefetch(10)


### Input shape

In [None]:
// Feature input: batches of 28x28 single-channel FLOAT32 images; -1 leaves the batch size open.
val input = Input(FLOAT32, Shape(-1, 28, 28, 1))
// Training input: one INT64 label per image in the batch.
val trainInput = Input(INT64, Shape(-1))

### Model Topology

CNN models are built from a succession of specific layers:

- Convolution layer, which locally scores a set of 2D patterns over the 2D grid, e.g.

    `Conv2D[Float]("Layer_0/Conv2D", Shape(3, 3, 1, 32), 1, 1, SameConvPadding)`
    
    
- Rectified Linear Unit (ReLU), which keeps only positive responses so that a pattern and its inverse are not detected symmetrically (mirror effect), e.g.

    `ReLU[Float]("Layer_0/ReLU")`
    
    
- Pooling layer, which keeps the best pattern score within a given region, e.g.

   `MaxPool[Float]("Layer_0/MaxPool", Seq(1, 2, 2, 1), 1, 1, SameConvPadding)`
   
Stacking several such layers builds a hierarchy of pattern detection and selection.

The network then ends with a flattening from 2D to 1D (removing locality), a fully connected layer, and the output layer that assigns a class to the detected patterns.


TODO:

Find a better model and try to reach an accuracy above 0.96 (15–20 minute exercise):

- add a convolution of shape (3, 3, 32, 64)?
- add a convolution of shape (3, 3, 64, 128)?
- add a `Dropout("Embedding/Dropout", 0.33F)` after Flatten?
- increase the number of training steps?


In [None]:

// Create the CNN model.
// Convolution and max-pooling both use stride 1 with SAME padding, so the 28x28 spatial
// size is preserved and Flatten yields 28 * 28 * 32 features per image.
val layer =
        Conv2D[Float]("Layer_0/Conv2D", Shape(3, 3, 1, 32), 1, 1, SameConvPadding) >>
        ReLU[Float]("Layer_0/ReLU") >>
        MaxPool[Float]("Layer_0/MaxPool", Seq(1, 2, 2, 1), 1, 1, SameConvPadding) >>
        Flatten[Float]("Layer_2/Flatten") >>
        // Hidden fully connected layer. It gets its own name (reusing "OutputLayer/Linear"
        // for two layers of different shapes risks a variable-name clash) and is followed
        // by a non-linearity: two stacked Linear layers collapse into a single linear map.
        Linear[Float]("Layer_3/Linear", 128) >>
        ReLU[Float]("Layer_3/ReLU") >>
        // Output layer: one logit per digit class, fed to the softmax cross-entropy loss.
        Linear[Float]("OutputLayer/Linear", 10)



### Loss, Optimizer and wrapping in an Estimator

In [None]:
// Objective: sparse softmax cross-entropy between the logits and the Long labels,
// reduced to a single scalar with a mean.
val loss =
  SparseSoftmaxCrossEntropy[Float, Long, Float]("Loss") >> Mean("Loss/Mean")

// Adam with its default hyper-parameters.
val optimizer = tf.train.Adam()

// Wire inputs, network, loss and optimizer into one supervised model.
val model = Model.simpleSupervised(input, trainInput, layer, loss, optimizer)

// Wrap the model in an in-memory estimator; training itself is launched in the next cell.
val estimator = InMemoryEstimator(model)



### Training!

In [None]:
// Train for a fixed number of steps and report how long it took.
// System.nanoTime is the JVM's monotonic clock meant for elapsed-time measurement;
// currentTimeMillis follows the system clock and can jump during the measurement.
val start = System.nanoTime()
estimator.train(() => trainData, StopCriteria(maxSteps = Some(12)))
val end = System.nanoTime()
println(s"Training time: ${(end - start) / 1000000L} ms")

In [None]:
/** Fraction of `images` whose predicted digit equals the matching entry of `labels`.
  *
  * Each batch is reshaped to (N, H, W, 1) and cast to Float before inference, like the
  * training pipeline; the predicted class is the arg-max of the output along axis 1.
  */
def accuracy(images: Tensor[UByte], labels: Tensor[UByte]): Float = {
    val withChannel = Shape(images.shape(0), images.shape(1), images.shape(2), 1)
    val logits = estimator.infer(() => images.reshape(withChannel).toFloat)
    val predictedClasses = logits.argmax(1).toUByte
    val hits = predictedClasses.equal(labels).toFloat
    hits.mean().scalar
}

// Evaluate on the first nbSample examples of each split.
val nbSample = 1000
Seq(
  ("Train", dataset.trainImages, dataset.trainLabels),
  ("Test", dataset.testImages, dataset.testLabels)
).foreach { case (split, splitImages, splitLabels) =>
  println(s"$split accuracy = ${accuracy(splitImages.slice(0 :: nbSample, ::, ::), splitLabels.slice(0 :: nbSample))}")
}

## Test results

In [None]:
val images = dataset.testImages

/** Classifies each selected image with the trained estimator, prints the predicted digit
  * and displays the image as a 100x100 PNG.
  *
  * @param indexes positions along the first axis of `images` to classify
  * @param images  images to pick from; each selected one is fed as a (1, 28, 28, 1) batch
  */
def inferOnSelectedImage(indexes: Seq[Int], images: Tensor[UByte]): Unit = {
    // A single session is reused for every PNG encoding and closed in `finally`,
    // so its native resources are released even if inference or display fails.
    val session = Session()
    try {
        indexes.foreach { index =>
            val imageToInfer = images.slice(index, ::, ::).reshape(Shape(1, 28, 28, 1))
            val predictions = estimator.infer(() => imageToInfer.toFloat)
            println(s"Label infered: ${predictions.argmax(1).scalar}")
            val png = session.run(fetches = tf.decodeRaw[Byte](tf.image.encodePng(imageToInfer.reshape(Shape(28, 28, 1)))))
            Image(png.entriesIterator.toArray).withFormat(Image.PNG).withWidth(100).withHeight(100).display
        }
    } finally {
        session.close()
    }
}

inferOnSelectedImage((10 to 20), images)
