# CNTK: Manual minibatch loop and Training Session
 
This tutorial demonstrates a canonical training minibatch loop and how it can be rewritten using a training session. We discuss what the main caveats are and how they can be avoided.

__Note: Please consider using a higher-level [Function.train](https://www.cntk.ai/pythondocs/cntk.ops.functions.html#cntk.ops.functions.Function.train)/[Function.test](https://www.cntk.ai/pythondocs/cntk.ops.functions.html#cntk.ops.functions.Function.test) functionality instead of directly using the training session API. For more information please see [this tutorial](https://github.com/Microsoft/CNTK/blob/v2.0/Tutorials/CNTK_200_GuidedTour.ipynb).__

## Manual training loop

Many scripts in CNTK have a very similar structure:
 - they create a network
 - instantiate a trainer and a learner with appropriate hyper-parameters
 - load training and testing data with minibatch sources
 - then run the main training loop fetching the data from the train minibatch source and feeding it to the trainer for N samples/sweeps
 - at the end they perform the eval loop using data from the test minibatch source and calling [test_minibatch](https://www.cntk.ai/pythondocs/cntk.train.trainer.html?highlight=test_minibatch#cntk.train.trainer.Trainer.test_minibatch) on the trainer or evaluator

As an example of such a script, we will take the toy task of learning the XOR operation with a simple feed-forward network.
We will try to learn the following function:

|x|y|result|
|:-|:-|:------|
|0|0|   0  |
|0|1|   1  |
|1|0|   1  |
|1|1|   0  |
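
The script below stores this truth table in the CNTK text format (CTF), one sample per line. As a minimal sketch, the same lines can be generated from the XOR definition itself. The helper `xor_ctf_lines` is hypothetical and not part of CNTK; the stream names `xy` and `r` are the ones used in this tutorial.

```python
# Hypothetical helper (not a CNTK API): render the XOR truth table as CTF lines
# with the stream names 'xy' (inputs) and 'r' (result) used in this tutorial.
def xor_ctf_lines():
    lines = []
    for x in (0, 1):
        for y in (0, 1):
            # x ^ y is the XOR of the two binary inputs
            lines.append("|xy %d %d\t|r %d" % (x, y, x ^ y))
    return lines


ctf_text = "\n".join(xor_ctf_lines()) + "\n"
print(ctf_text)
```

The order of the rows does not matter for training, because the minibatch source can randomize the samples.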

The network will have two dense layers: we use [tanh](https://www.cntk.ai/pythondocs/cntk.ops.html?highlight=tanh#cntk.ops.tanh) as the activation for the first layer and no activation for the second. The sample script is presented below:

In [2]:
from __future__ import print_function
from __future__ import division

import cntk
import cntk.ops
import cntk.io
import cntk.train

from cntk.layers import Dense, Sequential
from cntk.io import StreamDef, StreamDefs, MinibatchSource, CTFDeserializer
from cntk.logging import ProgressPrinter

# Let's prepare data in the CTF format. It exactly matches
# the table above
INPUT_DATA = r'''|xy 0 0	|r 0
|xy 1 0	|r 1
|xy 0 1	|r 1
|xy 1 1	|r 0
'''

# Write the data to a temporary file
input_file = 'input'
with open(input_file, 'w') as f:
    f.write(INPUT_DATA)

# Create a network
xy = cntk.input_variable(2)
label = cntk.input_variable(1)

model = Sequential([
    Dense(2, activation=cntk.ops.tanh),
    Dense(1)])

z = model(xy)
loss = cntk.squared_error(z, label)

# Define our input data streams
streams = StreamDefs(
    xy = StreamDef(field='xy', shape=2),
    r = StreamDef(field='r', shape=1))

# Create a learner and a trainer and a progress writer to 
# output current progress
learner = cntk.sgd(model.parameters, cntk.learning_rate_schedule(0.1, cntk.UnitType.sample))
trainer = cntk.train.Trainer(z, (loss, loss), learner, ProgressPrinter(freq=10))

# Now let's create a minibatch source for our input file
mb_source = MinibatchSource(CTFDeserializer(input_file, streams))
input_map = { xy : mb_source['xy'], label : mb_source['r'] }

# Run a manual training minibatch loop
minibatch_size = 4
max_samples = 800
train = True
while train and trainer.total_number_of_samples_seen < max_samples:
    data = mb_source.next_minibatch(minibatch_size, input_map)
    train = trainer.train_minibatch(data)

# Run a manual evaluation loop using the same data file for evaluation
test_mb_source = MinibatchSource(CTFDeserializer(input_file, streams), randomize=False, max_samples=100)
test_input_map = { xy : test_mb_source['xy'], label : test_mb_source['r'] }
total_samples = 0
error = 0.
data = test_mb_source.next_minibatch(32, test_input_map)
while data:
    total_samples += data[label].number_of_samples 
    error += trainer.test_minibatch(data) * data[label].number_of_samples
    data = test_mb_source.next_minibatch(32, test_input_map)

print("Error %f" % (error / total_samples))

Learning rate per sample: 0.1
 Minibatch[   1-  10]: loss = 0.307924 * 40, metric = 30.79% * 40;
 Minibatch[  11-  20]: loss = 0.235954 * 40, metric = 23.60% * 40;
 Minibatch[  21-  30]: loss = 0.219182 * 40, metric = 21.92% * 40;
 Minibatch[  31-  40]: loss = 0.191043 * 40, metric = 19.10% * 40;
 Minibatch[  41-  50]: loss = 0.147003 * 40, metric = 14.70% * 40;
 Minibatch[  51-  60]: loss = 0.088223 * 40, metric = 8.82% * 40;
 Minibatch[  61-  70]: loss = 0.035363 * 40, metric = 3.54% * 40;
 Minibatch[  71-  80]: loss = 0.009170 * 40, metric = 0.92% * 40;
 Minibatch[  81-  90]: loss = 0.001756 * 40, metric = 0.18% * 40;
 Minibatch[  91- 100]: loss = 0.010385 * 40, metric = 1.04% * 40;
 Minibatch[ 101- 110]: loss = 0.242481 * 40, metric = 24.25% * 40;
 Minibatch[ 111- 120]: loss = 0.004554 * 40, metric = 0.46% * 40;
 Minibatch[ 121- 130]: loss = 0.000244 * 40, metric = 0.02% * 40;
 Minibatch[ 131- 140]: loss = 0.000030 * 40, metric = 0.00% * 40;
 Minibatch[ 141- 150]: loss = 0.000004 *

As can be seen above, the actual model is specified in just two lines; the rest is boilerplate code that iterates over the data and feeds it manually for training and evaluation. With a manual loop, the user has complete flexibility in how to feed the data, but she also has to take several not-so-obvious things into account.
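
One of these details is already visible in the evaluation loop above: each per-minibatch result is multiplied by the number of samples before being accumulated, because the last minibatch is usually smaller than the others. The following is a minimal sketch in plain Python, assuming each evaluation call returns an average error per sample; the function name and the numbers are made up for illustration.

```python
def weighted_average_error(minibatch_results):
    """Combine (average_error, num_samples) pairs into one per-sample average."""
    total_samples = 0
    total_error = 0.0
    for average_error, num_samples in minibatch_results:
        # Convert the per-sample average back into a sum before accumulating
        total_error += average_error * num_samples
        total_samples += num_samples
    if total_samples == 0:
        raise ValueError("no samples were evaluated")
    return total_error / total_samples


# Illustrative values: 100 samples read in minibatches of 32, so the last one has 4
results = [(0.10, 32), (0.20, 32), (0.10, 32), (0.50, 4)]
print(weighted_average_error(results))
```

A plain mean of the four per-minibatch values would give the 4-sample minibatch the same weight as the 32-sample ones and overstate the error.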

For simplicity we use a toy example, but imagine a situation where your job runs for a couple of days.

### Failover and recovery

For the small sample above, recovery is not important, but when training spans several days or weeks it is not safe to assume that the machine stays online all the time with no hardware or software glitches. If the machine reboots or goes down, or the script has a bug, the user will have to rerun the same experiment from the beginning, which is highly undesirable. To avoid that, CNTK allows the user to take checkpoints and restore from them in the event of failure.

One of the means to save the model state in CNTK is the [save method](https://cntk.ai/pythondocs/cntk.ops.functions.html?highlight=save#cntk.ops.functions.Function.save) on the [Function class](https://cntk.ai/pythondocs/cntk.ops.functions.html?highlight=save#cntk.ops.functions.Function).
It is worth mentioning that this function only saves the model state, but there are other stateful entities in the script, including:
 * minibatch sources
 * trainer
 * learners
 
In order to save the complete state of the script, the user has to manually save the current state of the minibatch source and the trainer. The minibatch source provides the [get_checkpoint_state](https://www.cntk.ai/pythondocs/cntk.io.html?highlight=get_checkpoint_state#cntk.io.MinibatchSource.get_checkpoint_state) method; its result can be passed to the trainer's [save_checkpoint](https://www.cntk.ai/pythondocs/cntk.train.trainer.html?highlight=save_checkpoint#cntk.train.trainer.Trainer.save_checkpoint) method, which takes care of saving the state to disk or exchanging the state in case of distributed training. There are also corresponding [restore_from_checkpoint](https://www.cntk.ai/pythondocs/cntk.train.trainer.html?highlight=restore_from_checkpoint#cntk.train.trainer.Trainer.restore_from_checkpoint) methods on the trainer and the minibatch source that can be used for restoring. To recover from an error, on startup the user has to restore the state using the trainer and set the current position of the minibatch source.

With the above in mind, let's rewrite our loop as follows:

In [12]:
# Run a manual training minibatch loop with checkpointing
import os

# Initialize main objects
mb_source = MinibatchSource(CTFDeserializer(input_file, streams))
input_map = { xy : mb_source['xy'], label : mb_source['r'] }

learner = cntk.sgd(model.parameters, cntk.learning_rate_schedule(0.1, cntk.UnitType.sample))
trainer = cntk.train.Trainer(z, (loss, loss), learner, ProgressPrinter(freq=10))

# Try to restore if the checkpoint exists
checkpoint = 'manual_loop_checkpointed'

if os.path.exists(checkpoint):
    print("Trying to restore from checkpoint")
    mb_source_state = trainer.restore_from_checkpoint(checkpoint)
    mb_source.restore_from_checkpoint(mb_source_state)
    print("Restore has finished successfully")
else:
    print("No restore file found")
    
checkpoint_frequency = 100
last_checkpoint = 0
train = True
while train and trainer.total_number_of_samples_seen < max_samples:
    data = mb_source.next_minibatch(minibatch_size, input_map)
    train = trainer.train_minibatch(data)
    # Integer division: the quotient changes only once per checkpoint_frequency samples
    if trainer.total_number_of_samples_seen // checkpoint_frequency != last_checkpoint:
        mb_source_state = mb_source.get_checkpoint_state()
        trainer.save_checkpoint(checkpoint, mb_source_state)
        last_checkpoint = trainer.total_number_of_samples_seen // checkpoint_frequency


Trying to restore from checkpoint
Restore has finished successfully


At the beginning we check if the checkpoint file exists and we can restore from it. After that we start the training. Our loop is based on the total number of samples the trainer has seen. This information is included in the checkpoint, so in case of failure the training will resume at the saved position (this will become even more important for distributed training).

Depending on the checkpointing frequency the above script retrieves the current state of the minibatch source and creates a checkpoint using the trainer. If the script iterates over the same data many times, saving the state of the minibatch source is not that important, but for huge workloads you probably do not want to start seeing the same data from the beginning.
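
The frequency check can be illustrated without CNTK. The sketch below is a plain-Python simulation, assuming the quotient of samples seen and checkpoint frequency is meant as an integer bucket index (floor division); `checkpoint_points` is a hypothetical helper, and the increments stand in for `train_minibatch` and `save_checkpoint`.

```python
def checkpoint_points(max_samples, minibatch_size, checkpoint_frequency):
    """Return the sample counts at which the loop would write a checkpoint."""
    points = []
    last_checkpoint = 0
    samples_seen = 0
    while samples_seen < max_samples:
        samples_seen += minibatch_size  # stands in for trainer.train_minibatch
        bucket = samples_seen // checkpoint_frequency
        if bucket != last_checkpoint:
            points.append(samples_seen)  # stands in for trainer.save_checkpoint
            last_checkpoint = bucket
    return points


# Values from this tutorial: 800 samples, minibatches of 4, checkpoint every 100
print(checkpoint_points(max_samples=800, minibatch_size=4, checkpoint_frequency=100))
```

When the minibatch size does not divide the frequency evenly, the checkpoint is written at the first minibatch boundary after each multiple of the frequency. With true division (`/` under `from __future__ import division`) the quotient changes after every minibatch, so a checkpoint would be written on every iteration.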

At some point the user will want to parallelize the script to decrease the training time. Let's look at how this can be done in the next section.

### Distributed manual loop

In order to make training distributed, CNTK provides a set of distributed learners that encapsulate a set of algorithms (1BitSGD, BlockMomentum, data parallel SGD) that use MPI to exchange the state. From the script perspective, almost everything stays the same. The only difference is that the user needs to wrap the learner into the corresponding distributed learner and make sure she picks up the data from the minibatch source based on the current worker rank (also, the script should be run with ```mpiexec```):

In [13]:
# Run a manual training minibatch loop with distributed learner
checkpoint = 'manual_loop_distributed'

mb_source = MinibatchSource(CTFDeserializer(input_file, streams))
input_map = { xy : mb_source['xy'], label : mb_source['r'] }

# Make sure the learner is distributed
learner = cntk.distributed.data_parallel_distributed_learner(cntk.sgd(model.parameters, cntk.learning_rate_schedule(0.1, cntk.UnitType.sample)))
trainer = cntk.train.Trainer(z, (loss, loss), learner, ProgressPrinter(freq=10))

if os.path.exists(checkpoint):
    print("Trying to restore from checkpoint")
    mb_source_state = trainer.restore_from_checkpoint(checkpoint)
    mb_source.restore_from_checkpoint(mb_source_state)
else:
    print("No restore file found")

last_checkpoint = 0
train = True
partition = cntk.distributed.Communicator.rank()
num_partitions = cntk.distributed.Communicator.num_workers()
while train and trainer.total_number_of_samples_seen < max_samples:
    # Make sure each worker gets its own data only
    data = mb_source.next_minibatch(minibatch_size_in_samples = minibatch_size,
                                    input_map = input_map, device = cntk.use_default_device(), 
                                    num_data_partitions=num_partitions, partition_index=partition)
    train = trainer.train_minibatch(data)
    # Integer division: the quotient changes only once per checkpoint_frequency samples
    if trainer.total_number_of_samples_seen // checkpoint_frequency != last_checkpoint:
        mb_source_state = mb_source.get_checkpoint_state()
        trainer.save_checkpoint(checkpoint, mb_source_state)
        last_checkpoint = trainer.total_number_of_samples_seen // checkpoint_frequency

# When you use distributed learners, please call finalize MPI at the end of your script, 
# see the next cell.
# cntk.distributed.Communicator.finalize()

Trying to restore from checkpoint


In order for distribution to work properly, all workers should exit the minibatch loop at the same time. Some of the workers can have more data than others, so the exit condition of the loop should be based on the return value of the trainer (if a particular worker has no more work to do, it can communicate this by passing an empty minibatch to ```train_minibatch```).

As noted before, the decisions inside the loop are based on [Trainer.total_number_of_samples_seen](https://www.cntk.ai/pythondocs/cntk.train.trainer.html?highlight=restore_from_checkpoint#cntk.train.trainer.Trainer.total_number_of_samples_seen). Some of the operations (e.g. ```train_minibatch```, checkpointing, cross validation, if done in a distributed fashion) require synchronization, and to stay in step across all the workers they use a global state: the global number of samples seen by the trainer.
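
The exit condition can be illustrated with a small single-process simulation. This is only a sketch under stated assumptions, not how CNTK implements synchronization: the round-robin split in `partition_samples` is a hypothetical partitioning scheme, and the summed sample count stands in for the aggregated result that every worker would observe from ```train_minibatch```.

```python
def partition_samples(samples, num_partitions, partition_index):
    """Hypothetical scheme: worker i takes every num_partitions-th sample."""
    return samples[partition_index::num_partitions]


def simulate(samples, num_workers, minibatch_size):
    """Run synchronized steps until no worker has data; return (steps, global samples)."""
    queues = [partition_samples(samples, num_workers, rank) for rank in range(num_workers)]
    positions = [0] * num_workers
    global_samples_seen = 0
    steps = 0
    while True:
        # Each worker fetches its next local minibatch, which may be empty
        minibatches = []
        for rank in range(num_workers):
            start = positions[rank]
            minibatch = queues[rank][start:start + minibatch_size]
            positions[rank] = start + len(minibatch)
            minibatches.append(minibatch)
        # Synchronization point: all workers see the same aggregated count
        step_samples = sum(len(minibatch) for minibatch in minibatches)
        if step_samples == 0:
            break  # the decision is global, so all workers leave the loop together
        global_samples_seen += step_samples
        steps += 1
    return steps, global_samples_seen


# 10 samples over 3 workers: worker 0 holds 4 samples, workers 1 and 2 hold 3 each
print(simulate(list(range(10)), num_workers=3, minibatch_size=1))
```

In the fourth step only worker 0 still has data; the other two contribute empty minibatches but stay in the loop, because the decision to stop depends on the global count rather than on their local data.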

Even though writing manual training loops gives the user full flexibility, it can also be error-prone and requires a lot of boilerplate code to make everything work. When this flexibility is not required, it is better to use a higher abstraction.

## Using Training Session

Instead of writing the training loop manually and taking care of checkpointing and distribution herself, the user can delegate these aspects to the training session. It automatically takes care of the following things:
 1. checkpointing
 2. cross validation
 3. testing/evaluation

All that is needed from the user is to provide the corresponding configuration parameters. Besides providing a higher abstraction, the training session is also implemented in C++, so it is generally faster than a loop written in Python:

In [None]:
checkpoint = 'training_session'

# Minibatch sources
mb_source = MinibatchSource(CTFDeserializer(input_file, streams))
test_mb_source = MinibatchSource(CTFDeserializer(input_file, streams), randomize=False, max_samples=100)

learner = cntk.distributed.data_parallel_distributed_learner(cntk.sgd(model.parameters, cntk.learning_rate_schedule(0.1, cntk.UnitType.sample)))
trainer = cntk.train.Trainer(z, (loss, loss), learner, ProgressPrinter(freq=1))

test_config=cntk.TestConfig(minibatch_source = test_mb_source,
                            model_inputs_to_streams={ xy : test_mb_source['xy'], label : test_mb_source['r'] })

session = cntk.training_session(
    trainer = trainer, mb_source = mb_source, 
    mb_size = minibatch_size, 
    model_inputs_to_streams={ xy : mb_source['xy'], label : mb_source['r'] },
    max_samples = max_samples,
    checkpoint_config=cntk.CheckpointConfig(frequency=checkpoint_frequency, filename=checkpoint),
    test_config=cntk.TestConfig(minibatch_source = test_mb_source, minibatch_size = minibatch_size,
                                model_inputs_to_streams={ xy : test_mb_source['xy'], label : test_mb_source['r'] }))

session.train()

# When you use distributed learners, please call finalize MPI at the end of your script
cntk.distributed.Communicator.finalize()

Let's see how to configure different aspects of the training session:

### Progress tracking
In order to report progress, the training session uses [Trainer.summarize_training_progress](https://www.cntk.ai/pythondocs/cntk.train.trainer.html?highlight=restore_from_checkpoint#cntk.train.trainer.Trainer.summarize_training_progress) after each progress_frequency samples (rounded to the border of minibatches). Implicitly this call is dispatched to the corresponding calls of the [ProgressWriter](https://www.cntk.ai/pythondocs/cntk.logging.progress_print.html?highlight=progresswriter#module-cntk.logging.progress_print), which has its own set of parameters (e.g. freq can be used to specify how often to print the loss value). The ProgressWriter should be specified during trainer creation.
If you need custom logic for retrieving the current status, please consider implementing your own ProgressWriter or using the [Function.train](https://www.cntk.ai/pythondocs/cntk.ops.functions.html#cntk.ops.functions.Function.train) method.

### Checkpointing
[Checkpoint configuration](https://www.cntk.ai/pythondocs/cntk.train.training_session.html?highlight=checkpointconfig#cntk.train.training_session.CheckpointConfig) specifies how often to save a checkpoint to the given file. When given, the training session takes care of saving/restoring the state across the trainer/learners/minibatch source and propagating this information among distributed workers. If you need to preserve all checkpoints that were taken during training, please set ```preserve_all``` to ```True```. The checkpointing frequency is specified in samples.

### Cross validation
When a [cross validation](https://www.cntk.ai/pythondocs/cntk.train.training_session.html?highlight=checkpointconfig#cntk.train.training_session.CrossValidationConfig) config is given, the training session runs cross validation on the specified minibatch source with the specified frequency and reports the average metric error. The user can also provide a cross validation callback that will be called with the specified frequency. It is up to the user to perform cross validation in the callback and return ```True``` if the training should continue, or ```False``` otherwise.

### Testing
If the test configuration is given, the training session runs evaluation on the specified minibatch source after training completes. If you need to run only evaluation without training, consider using the [Function.test](https://www.cntk.ai/pythondocs/cntk.ops.functions.html#cntk.ops.functions.Function.test) method instead.