# Train a ready-to-use TensorFlow model with a simple pipeline

In [85]:
import os
import sys
import warnings
warnings.filterwarnings("ignore")

import numpy as np

# the following line is not required if BatchFlow is installed as a python package.
sys.path.append("../..")
from batchflow import Pipeline, B, C, F, V, Batch, ImagesBatch, action
from batchflow.opensets import MNIST, CIFAR10
from batchflow.models.tf import ResNet18

If you comment out the line below, training will take much longer and the accuracy might decrease slightly.
So it is always a good idea to import [best_practice](https://analysiscenter.github.io/batchflow/intro/best_practice.html).

In [72]:
from batchflow import best_practice

BATCH_SIZE can be increased on modern GPUs with plenty of memory (4 GB or more).

In [73]:
BATCH_SIZE = 64

We define a function that converts labels greater than or equal to 5 to 1 and all other labels to 0.

In [74]:
def making_two_class(label):
    return np.where(label >= 5, 1, 0)
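
As a quick check of the mapping described above, the function can be applied to a handful of digit labels. The sample array `digits` is made up for illustration and is not taken from MNIST.

```python
import numpy as np

def making_two_class(label):
    # Labels 5..9 become class 1, labels 0..4 become class 0
    return np.where(label >= 5, 1, 0)

# Hypothetical sample of digit labels (illustration only)
digits = np.array([0, 4, 5, 9, 3, 7])
binary = making_two_class(digits)
print(binary)  # [0 0 1 1 0 1]
```

Because `np.where` is vectorized, the same function works on a whole batch of labels at once.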

# Create a dataset

[MNIST](http://yann.lecun.com/exdb/mnist/) is a dataset of handwritten digits frequently used as a baseline for machine learning tasks.

Downloading MNIST database might take a few minutes to complete.

In [75]:
dataset = MNIST()

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Extracting C:\Users\AACE~1\AppData\Local\Temp\train-labels-idx1-ubyte.gz
Extracting C:\Users\AACE~1\AppData\Local\Temp\t10k-images-idx3-ubyte.gz
Extracting C:\Users\AACE~1\AppData\Local\Temp\train-images-idx3-ubyte.gz
Extracting C:\Users\AACE~1\AppData\Local\Temp\t10k-labels-idx1-ubyte.gz

There are also predefined CIFAR10 and CIFAR100 datasets.

# Define a pipeline config

A config lets you create flexible pipelines that take parameters.

For instance, if you put a model type into config, you can run a pipeline against different models.

See [a list of available models](https://analysiscenter.github.io/batchflow/intro/tf_models.html#ready-to-use-models) to choose the one which fits you best.

In [76]:
config = dict(model=ResNet18)
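
The idea of keeping the model type in a config can be sketched in plain Python. This is a toy illustration only: `SmallModel`, `LargeModel` and `describe_run` are hypothetical names, not BatchFlow classes or functions.

```python
# Toy stand-ins for model classes (hypothetical, not part of BatchFlow)
class SmallModel:
    name = "small"

class LargeModel:
    name = "large"

def describe_run(config):
    """Read the model class from the config only when the run starts."""
    model_class = config["model"]
    return f"training {model_class.name} model"

# The same code runs against different models by changing only the config
runs = [describe_run(dict(model=m)) for m in (SmallModel, LargeModel)]
print(runs)  # ['training small model', 'training large model']
```

The pipeline code stays the same; only the dictionary passed in changes.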

# Create a template pipeline

A template pipeline is not linked to any dataset. It's just an abstract sequence of actions, so it cannot be executed, but it serves as a convenient building block.

Add our new function to the pipeline and change the number of classes to 2.

In [77]:
train_template = (Pipeline(config=config)
                .init_variable('loss_history', init_on_each_run=list)
                .init_variable('current_loss')
                .init_model('dynamic', C('model'), 'conv_nn',
                            config={'inputs': dict(images={'shape': B('image_shape')},
                                                   labels={'classes': 2}),
                                    'initial_block/inputs': 'images'})
                .to_array()
                .apply_transform(making_two_class, src='labels', dst='labels')
                .train_model('conv_nn', fetches='loss', images=B('images'), labels=B('labels'),
                             save_to=V('current_loss'))
                .update_variable('loss_history', V('current_loss'), mode='a')
                
)

# Train the model

Apply a dataset to a template pipeline to create a runnable pipeline:

In [78]:
train_pipeline = (train_template << dataset.train)
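
To show how a template can stay dataset-free while `<<` produces something runnable, here is a heavily simplified sketch. `ToyTemplate` is a hypothetical class written for this illustration; it does not reflect how BatchFlow implements `Pipeline`.

```python
class ToyTemplate:
    """A hypothetical, much simplified stand-in for a pipeline template."""

    def __init__(self, actions=(), dataset=None):
        self.actions = list(actions)
        self.dataset = dataset

    def add(self, func):
        # Return a new template so the original can still be reused
        return ToyTemplate(self.actions + [func], self.dataset)

    def __lshift__(self, dataset):
        # Binding a dataset yields a new runnable object; the template is unchanged
        return ToyTemplate(self.actions, dataset)

    def run(self):
        if self.dataset is None:
            raise RuntimeError("a template has no dataset and cannot be executed")
        results = []
        for item in self.dataset:
            for action in self.actions:
                item = action(item)
            results.append(item)
        return results

template = ToyTemplate().add(lambda x: x + 1).add(lambda x: x * 2)
train_results = (template << [1, 2, 3]).run()
test_results = (template << [10, 20]).run()
print(train_results, test_results)  # [4, 6, 8] [22, 42]
```

In this sketch one template serves two datasets, and calling `run()` on the unbound template raises an error.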

Check that the labels now have two classes.

In [79]:
train_pipeline.next_batch(BATCH_SIZE).labels

array([1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0,
       1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1,
       0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0])

In [80]:
train_pipeline.run(BATCH_SIZE, shuffle=True, n_epochs=1, drop_last=True, bar=True, prefetch=1)

100%|████████████████████████████████████████| 937/937 [28:38<00:00,  1.56s/it]


<batchflow.pipeline.Pipeline at 0x2700cee95f8>

Note that the progress bar often increments by 2 at a time; that's prefetch in action.

Prefetch does not help much here, though, since almost all the time is spent in model training, which runs under a thread lock, one batch after another, without any parallelism. Otherwise the model would not learn anything, because different batches would overwrite one another's weight updates.
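
The interplay of parallel batch preparation and serialized training can be sketched with the standard library. This is a toy model under stated assumptions, not BatchFlow's implementation: `load_batch`, `train_step` and the `weights` dictionary are hypothetical stand-ins.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

weights = {"value": 0}          # stands in for shared model weights
train_lock = threading.Lock()

def load_batch(index):
    # Stands in for loading and preprocessing, which may overlap across threads
    return [index] * 4

def train_step(batch):
    # Only one thread at a time may read and update the shared weights
    with train_lock:
        current = weights["value"]
        weights["value"] = current + sum(batch)

def process(index):
    train_step(load_batch(index))

# Two workers prepare batches concurrently, but the updates are serialized
with ThreadPoolExecutor(max_workers=2) as executor:
    list(executor.map(process, range(10)))

print(weights["value"])  # 4 * (0 + 1 + ... + 9) = 180
```

Without the lock, two threads could read the same `current` value and one update would be lost, which is the toy analogue of batches overwriting each other's weight updates.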

# Test the model

Testing is much faster than training, but without a GPU it still takes some patience.

Add our new function to a test pipeline.

In [81]:
test_pipeline = (dataset.test.p
                .import_model('conv_nn', train_pipeline)
                .init_variable('predictions') 
                .init_variable('metrics', init_on_each_run=None) 
                .to_array()
                .apply_transform(making_two_class, src='labels', dst='labels')
                .predict_model('conv_nn', fetches='predictions', images=B('images'), labels=B('labels'),
                               save_to=V('predictions'))
                .gather_metrics('class', targets=B('labels'), predictions=V('predictions'),
                                fmt='logits', axis=-1, save_to=V('metrics'), mode='w')
                .run(BATCH_SIZE, shuffle=True, n_epochs=1, drop_last=True, bar=True)
)

100%|████████████████████████████████████████| 156/156 [00:14<00:00, 11.10it/s]
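
The iteration counts in the two progress bars (937 for training, 156 for testing) follow from the standard MNIST split of 60,000 training and 10,000 test images, a batch size of 64, and `drop_last=True`. The helper `n_iterations` below is a hypothetical function written for this check, not a BatchFlow call.

```python
def n_iterations(n_items, batch_size, drop_last=True):
    """Number of batches in one epoch."""
    full, remainder = divmod(n_items, batch_size)
    # With drop_last=True the final incomplete batch is skipped
    return full if drop_last or remainder == 0 else full + 1

BATCH_SIZE = 64
train_iterations = n_iterations(60_000, BATCH_SIZE)
test_iterations = n_iterations(10_000, BATCH_SIZE)
print(train_iterations, test_iterations)  # 937 156
```

With `drop_last=False` the same helper would give 938 and 157, since the last partial batch would be kept.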


Let's get the accumulated [metrics information](https://analysiscenter.github.io/batchflow/intro/models.html#model-metrics):

In [82]:
metrics = test_pipeline.get_variable('metrics')

Now we can easily calculate any metrics we need:

In [83]:
metrics.evaluate('accuracy')

0.984375

In [84]:
metrics.evaluate(['false_positive_rate', 'false_negative_rate'], multiclass=None)

{'false_negative_rate': 0.03571428571428571, 'false_positive_rate': 0.0}
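
The following sketch shows how such values can arise. It assumes that `fmt='logits'` with `axis=-1` amounts to taking an argmax over the last axis, and it uses made-up logits. The confusion counts are one breakdown of 64 samples that reproduces the numbers above; this is an inference from the reported values, not something the outputs state.

```python
import numpy as np

# Hypothetical logits for 4 samples and 2 classes (illustration only)
logits = np.array([[2.0, -1.0], [0.3, 0.9], [-0.5, 1.5], [1.2, 0.1]])
predicted = logits.argmax(axis=-1)  # class with the largest logit per sample
print(predicted)  # [0 1 1 0]

# One confusion matrix consistent with the reported metrics (assumption)
tp, fn, fp, tn = 27, 1, 0, 36

accuracy = (tp + tn) / (tp + fn + fp + tn)
false_positive_rate = fp / (fp + tn)
false_negative_rate = fn / (fn + tp)
print(accuracy, false_positive_rate, false_negative_rate)
```

Here 63 of 64 predictions are correct, no negative is misclassified, and 1 of 28 positives is missed.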

# Save the model
After training the model, you may want to save it, which takes a single call.

In [13]:
train_pipeline.save_model('conv_nn', path='path/to/save')

## What's next?

See [the image augmentation tutorial](./06_image_augmentation.ipynb) or return to the [table of contents](./00_description.ipynb).