# Train a ready-to-use TensorFlow model with a simple pipeline

In [1]:
import os
import sys
import warnings
warnings.filterwarnings("ignore")

import numpy as np

# the following line is not required if BatchFlow is installed as a python package.
sys.path.append("../..")
from batchflow import Pipeline, B, C, F, V
from batchflow.opensets import MNIST, CIFAR10
from batchflow.models.tf import ResNet18

If you comment out the line below, training will take much longer and accuracy might decrease slightly.
So it is always a good idea to import [best_practice](https://analysiscenter.github.io/batchflow/intro/best_practice.html).

In [2]:
from batchflow import best_practice

You can increase BATCH_SIZE on modern GPUs with plenty of memory (4 GB or more).

In [3]:
BATCH_SIZE = 64
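With `drop_last=True`, as in both `run` calls in this notebook, the number of iterations per epoch is the dataset size floor-divided by the batch size. The small helper below is a hypothetical sketch (not part of BatchFlow) that only assumes the standard MNIST split of 60,000 training and 10,000 test images. With `BATCH_SIZE = 64` it gives 937 training and 156 test batches, matching the totals shown in the progress bars.

```python
import math

TRAIN_SIZE, TEST_SIZE = 60_000, 10_000  # standard MNIST split sizes


def iterations_per_epoch(n_items, batch_size, drop_last=True):
    """Number of batches produced by one pass over n_items."""
    if drop_last:
        return n_items // batch_size         # the incomplete last batch is skipped
    return math.ceil(n_items / batch_size)   # the incomplete last batch is kept


for bs in (64, 128, 256):
    print(bs, iterations_per_epoch(TRAIN_SIZE, bs), iterations_per_epoch(TEST_SIZE, bs))
```

Doubling the batch size roughly halves the number of iterations, while each iteration needs more GPU memory.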

# Create a dataset

[MNIST](http://yann.lecun.com/exdb/mnist/) is a dataset of handwritten digits frequently used as a baseline for machine learning tasks.

Downloading the MNIST database might take a few minutes.

In [4]:
dataset = MNIST()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Extracting /tmp/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Extracting /tmp/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Extracting /tmp/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Extracting /tmp/t10k-labels-idx1-ubyte.gz


There are also predefined CIFAR10 and CIFAR100 datasets.

# Define a pipeline config

A config lets you create flexible pipelines that take parameters.

For instance, if you put a model type into the config, you can run the same pipeline against different models.

See [a list of available models](https://analysiscenter.github.io/batchflow/intro/tf_models.html#ready-to-use-models) to choose the one that best fits your task.

In [5]:
config = dict(model=ResNet18)
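In the template pipeline the model is referenced as `C('model')` rather than by class, so the concrete model is looked up in the config when the pipeline runs. The toy sketch below imitates that late binding in plain Python; `ConfigRef` and `build_step` are hypothetical names used only for illustration, not BatchFlow's implementation, and the model names are plain strings here.

```python
class ConfigRef:
    """Toy stand-in for a named expression such as C('model'): it stores a key and is resolved later."""

    def __init__(self, key):
        self.key = key

    def resolve(self, config):
        return config[self.key]


def build_step(action, **kwargs):
    """Return a function that resolves any ConfigRef arguments of one pipeline action against a config."""
    def step(config):
        resolved = {name: value.resolve(config) if isinstance(value, ConfigRef) else value
                    for name, value in kwargs.items()}
        return action, resolved
    return step


# The step is written once; only the config decides which model it gets.
init_step = build_step("init_model", model=ConfigRef("model"), name="conv_nn")
print(init_step({"model": "ResNet18"}))
print(init_step({"model": "ResNet34"}))
```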

# Create a template pipeline

A template pipeline is not linked to any dataset. It's just an abstract sequence of actions, so it cannot be executed, but it serves as a convenient building block.

In [6]:
train_template = (Pipeline(config=config)
                .init_variable('loss_history', init_on_each_run=list)
                .init_variable('current_loss')
                .init_model('dynamic', C('model'), 'conv_nn',
                            config={'inputs': dict(images={'shape': B('image_shape')},
                                                   labels={'classes': 10}),
                                    'initial_block/inputs': 'images'})
                .to_array()
                .train_model('conv_nn', fetches='loss', images=B('images'), labels=B('labels'),
                             save_to=V('current_loss'))
                .update_variable('loss_history', V('current_loss'), mode='a')
)

# Train the model

Apply a dataset to a template pipeline to create a runnable pipeline:

In [7]:
train_pipeline = (train_template << dataset.train)
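To make the split between a template and a runnable pipeline concrete, here is a toy model in plain Python. `TemplatePipeline` and `RunnablePipeline` are hypothetical classes that only mimic the idea: a template is an ordered list of actions with no data, and `template << data` binds a dataset and yields something that can be run batch by batch.

```python
class TemplatePipeline:
    """Toy template: an ordered list of actions with no data attached."""

    def __init__(self, actions=()):
        self.actions = list(actions)

    def add(self, action):
        # Return a new template so the original can be reused with other datasets.
        return TemplatePipeline(self.actions + [action])

    def __lshift__(self, dataset):
        # `template << dataset` binds data and yields a runnable pipeline.
        return RunnablePipeline(self.actions, dataset)


class RunnablePipeline:
    """Toy runnable pipeline: applies every action to each batch in order."""

    def __init__(self, actions, dataset):
        self.actions = actions
        self.dataset = list(dataset)

    def run(self, batch_size, drop_last=True):
        outputs = []
        for start in range(0, len(self.dataset), batch_size):
            batch = self.dataset[start:start + batch_size]
            if drop_last and len(batch) < batch_size:
                break  # skip the incomplete final batch
            for action in self.actions:
                batch = action(batch)
            outputs.append(batch)
        return outputs


template = TemplatePipeline().add(lambda b: [x * 2 for x in b]).add(sum)
print((template << range(10)).run(batch_size=4))
```

The same `template` can be linked to any other dataset, which is what makes templates useful as building blocks.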

Run the pipeline (this might take anywhere from a few minutes to a few hours, depending on your hardware):

In [8]:
train_pipeline.run(BATCH_SIZE, shuffle=True, n_epochs=1, drop_last=True, bar=True, prefetch=1)

  3%|▎         | 25/937 [00:42<22:30,  1.48s/it]

KeyboardInterrupt: 

Note that the progress bar often advances by 2 batches at a time: that's prefetching in action.

Prefetching gains little here, though, because almost all the time is spent on model training, which runs under a thread lock one batch after another without any parallelism. Otherwise the model would not learn anything, since concurrently processed batches would overwrite one another's weight updates.
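The interplay between prefetching and sequential training can be sketched with Python's standard `queue` and `threading` modules. This is a toy under stated assumptions, not BatchFlow's implementation: a background thread prepares up to `prefetch` batches ahead, while the hypothetical `train_with_prefetch` loop applies weight updates strictly one batch after another. Because each update depends on the previous weight, the order and atomicity of updates matter.

```python
import queue
import threading


def train_with_prefetch(batches, prefetch=1, lr=0.5):
    """Toy loop: a background thread prepares batches ahead; updates are applied strictly in order."""
    ready = queue.Queue(maxsize=max(prefetch, 1))  # at most `prefetch` prepared batches wait in line
    done = object()                                 # sentinel marking the end of the data
    lock = threading.Lock()                         # guards the weights if several threads could train
    weight = 0.0
    history = []

    def producer():
        for batch in batches:
            ready.put([float(x) for x in batch])  # stand-in for loading/augmenting a batch
        ready.put(done)

    threading.Thread(target=producer, daemon=True).start()

    while True:
        batch = ready.get()
        if batch is done:
            break
        with lock:  # one read-modify-write of the weights at a time
            target = sum(batch) / len(batch)
            weight += lr * (target - weight)  # each update depends on the previous weight
            history.append(weight)
    return weight, history


print(train_with_prefetch([[2, 2], [4, 4]]))
```

Batch preparation overlaps with training, but training itself stays serial, so when training dominates the runtime, prefetching cannot speed things up much.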

# Test the model

Testing is much faster than training, but without a GPU it still takes some patience.

In [9]:
test_pipeline = (dataset.test.p
                .import_model('conv_nn', train_pipeline)
                .init_variable('predictions') 
                .init_variable('metrics', init_on_each_run=None) 
                .to_array()
                .predict_model('conv_nn', fetches='predictions', images=B('images'), labels=B('labels'),
                               save_to=V('predictions'))
                .gather_metrics('class', targets=B('labels'), predictions=V('predictions'),
                                fmt='logits', axis=-1, save_to=V('metrics'), mode='w')
                .run(BATCH_SIZE, shuffle=True, n_epochs=1, drop_last=True, bar=True)
)



Let's get the accumulated [metrics information](https://analysiscenter.github.io/batchflow/intro/models.html#model-metrics)

In [10]:
metrics = test_pipeline.get_variable('metrics')

Now we can easily calculate any metrics we need:

In [11]:
metrics.evaluate('accuracy')

0.890625
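Because `gather_metrics` receives raw network outputs with `fmt='logits', axis=-1`, the predicted class of each sample is the index of its largest logit along the last axis. Assuming accuracy is the usual fraction of correct predictions, it can be reproduced with a short numpy sketch; `accuracy_from_logits` is a hypothetical helper, not BatchFlow's API.

```python
import numpy as np


def accuracy_from_logits(targets, logits, axis=-1):
    """Fraction of samples whose highest-scoring class equals the target label."""
    predicted = np.argmax(logits, axis=axis)  # logits -> class indices
    return float(np.mean(predicted == np.asarray(targets)))


logits = np.array([[2.0, 0.1, -1.0],
                   [0.3, 1.5, 0.2],
                   [0.0, 0.2, 0.1],
                   [1.0, 3.0, 0.5]])
targets = [0, 1, 2, 1]
print(accuracy_from_logits(targets, logits))  # 3 of 4 samples are correct
```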

In [12]:
metrics.evaluate(['false_positive_rate', 'false_negative_rate'], multiclass=None)

{'false_positive_rate': array([0.01785714, 0.        , 0.        , 0.0483871 , 0.        ,
        0.        , 0.        , 0.03703704, 0.01694915, 0.        ]),
 'false_negative_rate': array([0.        , 0.125     , 0.33333333, 0.        , 0.        ,
        0.14285714, 0.        , 0.2       , 0.        , 1.        ])}
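With `multiclass=None` the rates are returned per class, as arrays of ten values. Assuming they follow the usual one-vs-rest definitions, FPR = FP / (FP + TN) and FNR = FN / (FN + TP), they can be derived from a confusion matrix as in the numpy sketch below. `confusion_matrix` and `per_class_rates` are hypothetical helpers rather than BatchFlow's API, and classes with an empty denominator get `nan` here, which may differ from how BatchFlow handles them.

```python
import numpy as np


def confusion_matrix(targets, predictions, num_classes):
    """Rows are true classes, columns are predicted classes."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (np.asarray(targets), np.asarray(predictions)), 1)
    return cm


def safe_divide(num, den):
    """Element-wise num / den, with nan where the denominator is zero."""
    out = np.full(num.shape, np.nan)
    np.divide(num, den, out=out, where=den > 0)
    return out


def per_class_rates(cm):
    """One-vs-rest false positive and false negative rates for every class."""
    tp = np.diag(cm)
    fn = cm.sum(axis=1) - tp  # samples of class k predicted as something else
    fp = cm.sum(axis=0) - tp  # samples of other classes predicted as k
    tn = cm.sum() - tp - fn - fp
    return safe_divide(fp, fp + tn), safe_divide(fn, fn + tp)


cm = confusion_matrix([0, 0, 1, 1, 2], [0, 1, 1, 1, 0], num_classes=3)
fpr, fnr = per_class_rates(cm)
print(fpr, fnr)
```

A class that is never predicted correctly, like class 2 in this toy data, gets an FNR of 1.0.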

# Save the model

After training the model, you may want to save it, which takes a single call:

In [13]:
train_pipeline.save_model('conv_nn', path='path/to/save')

## What's next?

See [the image augmentation tutorial](./06_image_augmentation.ipynb) or return to the [table of contents](./00_description.ipynb).