# Advanced Research Module Usage

We start with some useful imports and constant definitions.

In [1]:
import sys
import os
import shutil

import warnings
warnings.filterwarnings('ignore')

from tensorflow import logging
logging.set_verbosity(logging.ERROR)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import matplotlib
%matplotlib inline

In [2]:
sys.path.append('../../..')

from batchflow import Pipeline, B, C, V, D, L
from batchflow.opensets import MNIST
from batchflow.models.tf import VGG7, VGG16
from batchflow.research import Research, Option

In [3]:
BATCH_SIZE=64
ITERATIONS=10
TEST_EXECUTE_FREQ=10

In [4]:
def clear_previous_results(res_name):
    if os.path.exists(res_name):
        shutil.rmtree(res_name)

## Reducing Extra Dataset Loads

### Running Research Sequentially

In the previous tutorial we learned how to use Research to run experiments multiple times with varying parameters.

First we define a dataset to work with and a pipeline that reads this dataset.

In [5]:
mnist = MNIST()
train_root = mnist.train.p.run_later(BATCH_SIZE, shuffle=True, n_epochs=None)

Then we define a grid of parameters whose nodes will be used to form separate experiments

In [6]:
domain = Option('layout', ['cna', 'can']) * Option('bias', [True, False])        
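To make the grid concrete, the sketch below enumerates the same four nodes with `itertools.product`. It is a plain-Python illustration of what multiplying two options amounts to conceptually, not batchflow's own implementation; the names `options` and `nodes` are introduced here only for the example.

```python
from itertools import product

# Plain-Python stand-in for Option('layout', ...) * Option('bias', ...)
options = {
    'layout': ['cna', 'can'],
    'bias': [True, False],
}

# Each node of the grid is one combination of option values
nodes = [dict(zip(options, values)) for values in product(*options.values())]

for node in nodes:
    print(node)
```

Two options with two values each give four nodes, so every repetition of the research covers four configurations.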

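These parameters can be passed to the model's config using named expressions: in the cell below, `C('layout')` and `C('bias')` are replaced with the values from the current grid node.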

In [7]:
model_config={
    'inputs/images/shape': B('image_shape'),
    'inputs/labels/classes': D('num_classes'),
    'inputs/labels/name': 'targets',
    'initial_block/inputs': 'images',
    'body/block/layout': C('layout'),
    'common/conv/use_bias': C('bias'),
}
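The idea of a named expression can be pictured as a placeholder that is resolved only when an experiment's config is known. The toy below is an assumption-laden sketch of that idea: `ConfigRef` and `resolve` are hypothetical names and do not reflect how batchflow implements `C`.

```python
class ConfigRef:
    """Toy placeholder that remembers a key to look up later (hypothetical)."""
    def __init__(self, key):
        self.key = key


def resolve(template, experiment_config):
    """Return a copy of `template` with every ConfigRef replaced by its value."""
    return {name: experiment_config[value.key] if isinstance(value, ConfigRef) else value
            for name, value in template.items()}


template = {
    'inputs/labels/name': 'targets',
    'body/block/layout': ConfigRef('layout'),
    'common/conv/use_bias': ConfigRef('bias'),
}

# One grid node supplies the concrete values
resolved = resolve(template, {'layout': 'can', 'bias': False})
print(resolved)
```

The same template can therefore be reused for every grid node; only the values substituted at resolution time differ.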

After that we define a pipeline to run during our experiments. We initialise a pipeline variable `'loss'` to store the loss on each iteration.

In [8]:
train_template = (Pipeline()
            .init_variable('loss')
            .init_model('dynamic', VGG7, 'conv', config=model_config)
            .to_array()
            .train_model('conv', 
                         images=B('images'), labels=B('labels'),
                         fetches='loss', save_to=V('loss', mode='w'))
)

Each research is assigned a name and writes its results to a folder with this name. The names must be unique, so an attempt to run a research with a name that already exists raises an error. In the cell below we clear the results of previous research runs so that the research can be run multiple times. This is done solely for the purposes of this tutorial and should not be done in real work.

In [9]:
res_name = 'simple_research'
clear_previous_results(res_name)

Finally we define a Research that runs the pipeline, substituting its parameters with values from the different nodes of `domain`, and saves the values of the `'loss'` variable to results.

In [10]:
research = (Research()
            .add_pipeline(train_root + train_template, variables='loss')
            .init_domain(domain, n_reps=4))

research.run(n_iters=10, name=res_name, bar=True)

Research simple_research is starting...


Domain updated: 0: 100%|██████████| 160/160.0 [05:54<00:00,  2.21s/it]


<batchflow.research.research.Research at 0x7fe1c82d66d8>

16 experiments are run (4 grid nodes x 4 repetitions), each consisting of 10 iterations.

We can load results of the research and see that the table has 160 entries.

In [11]:
research.load_results().info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 160 entries, 0 to 159
Data columns (total 7 columns):
bias            160 non-null object
iteration       160 non-null int64
layout          160 non-null object
loss            160 non-null float64
name            160 non-null object
repetition      160 non-null int64
sample_index    160 non-null int64
dtypes: float64(1), int64(3), object(3)
memory usage: 8.8+ KB


### Branches: Reducing Data Loading and Preprocessing

Each experiment can be divided into 2 stages: a root stage that is roughly the same for all experiments (for example, data loading and preprocessing) and a branch stage that varies. If data loading and preprocessing take significant time, one can use the batches generated on a single root stage to feed several branches that belong to different experiments.

For example, if you want to test 4 different models, and your workflow includes some complicated data preprocessing and augmentation that is done separately for each model, you may want to do preprocessing and augmentation once and feed the resulting batches of data to all 4 models.

![Title](img/Branch_Root_Figure_crop.png)

The figure above shows the difference.

On the left, a simple workflow is shown. The same common preprocessing steps are performed 4 times, and the batches generated by different runs of the common stages also differ due to shuffling and possible randomisation inside the common steps.

On the right, the common steps are performed once on the root stage and the very same batches are passed to different branches. This has the advantage of reducing extra computations, but it also reduces variability because all models get exactly the same pieces of data.

To perform root-branch division, one should pass the `root` and `branch` parameters to `add_pipeline()` and define the number of branches per root via the `branches` parameter of `run()`.

A root with corresponding branches is called a **job**. Note that different roots still produce different batches.
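The trade-off described above can be mimicked in plain Python. In this sketch `make_batches` is a hypothetical stand-in for the root stage (it only draws shuffled index batches), and the branch names are made up; it is meant to show the sharing pattern, not how Research schedules work internally.

```python
import random


def make_batches(n_batches, batch_size, dataset_size, seed):
    """Toy root stage: draw shuffled index batches from a dataset."""
    rng = random.Random(seed)
    return [rng.sample(range(dataset_size), batch_size) for _ in range(n_batches)]


branches = ['model_a', 'model_b', 'model_c', 'model_d']

# Without root-branch division: every experiment runs its own root stage
separate = {name: make_batches(3, 4, 100, seed=i) for i, name in enumerate(branches)}
root_runs_separate = len(separate)   # the common steps are executed 4 times

# With root-branch division: one root run, the same batches go to all branches
shared_batches = make_batches(3, 4, 100, seed=0)
shared = {name: shared_batches for name in branches}
root_runs_shared = 1                 # the common steps are executed once

print(root_runs_separate, root_runs_shared)
print(all(batches is shared_batches for batches in shared.values()))
```

With sharing, the common work is done once per job, at the price that all branches of a job see identical batches.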

One constraint when using branches is that branch pipelines do not calculate dataset variables properly, so we have to redefine `model_config` and `train_template` and hard-code the `'inputs/labels/classes'` parameter.

In [12]:
model_config={
    'inputs/images/shape': B('image_shape'),
    'inputs/labels/classes': 10,
    'inputs/labels/name': 'targets',
    'initial_block/inputs': 'images',
    'body/block/layout': C('layout'),
    'common/conv/use_bias': C('bias'),
}

train_template = (Pipeline()
            .init_variable('loss')
            .init_model('dynamic', VGG7, 'conv', config=model_config)
            .to_array()
            .train_model('conv', 
                         images=B('images'), labels=B('labels'),
                         fetches='loss', save_to=V('loss', mode='w'))
)

res_name = 'no_extra_dataload_research'
clear_previous_results(res_name)
    
research = (Research()
            .add_pipeline(root=train_root, branch=train_template, variables='loss')
            .init_domain(domain, n_reps=4))

research.run(n_iters=10, branches=8, name=res_name, bar=True)

Research no_extra_dataload_research is starting...


Domain updated: 0: 100%|██████████| 20/20.0 [02:04<00:00,  6.23s/it]


<batchflow.research.research.Research at 0x7fe1ad886518>

Since every root is now assigned to 8 branches, there are only 2 jobs.
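The totals shown by the two progress bars (160 for the sequential run, 20 for the run with branches) are consistent with counting one step per job per iteration. The arithmetic below is a sketch under that reading of the output; rounding the number of jobs up when experiments do not divide evenly is an assumption, and `progress_total` is a name introduced here.

```python
import math

n_nodes, n_reps, n_iters = 4, 4, 10
n_experiments = n_nodes * n_reps   # 16 experiments in total


def progress_total(branches):
    """Return (number of jobs, jobs * iterations) for a given number of branches per root."""
    n_jobs = math.ceil(n_experiments / branches)
    return n_jobs, n_jobs * n_iters


print(progress_total(1))   # sequential run: (16, 160)
print(progress_total(8))   # 8 branches per root: (2, 20)
```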

The progress bars show that the whole research now takes less time (about 2 minutes instead of almost 6).
In this toy example we use only 10 iterations to make the effect of reduced dataset loading more visible.

The number of result entries is the same.

In [None]:
research.load_results().info()

### Functions on Root

If each job has several branches, they are all executed in parallel threads. To run a function on root, one should add it with `on_root=True`.

Functions on root have required parameters `iteration` and `experiments` and optional keyword parameters. They are not allowed to return anything.

In [15]:
res_name = 'on_root_research'
clear_previous_results(res_name)

def function_on_root():
    print('on root')
    
research = (Research()
            .add_callable(function_on_root, execute="#0", on_root=True)
            .add_pipeline(root=train_root, branch=train_template, variables='loss')
            .init_domain(domain, n_reps=4)
           )

research.run(branches=8, n_iters=ITERATIONS, name=res_name, bar=True)

Research on_root_research is starting...


Domain updated: 0: : 0it [00:00, ?it/s]

on root


Domain updated: 0:  50%|█████     | 10/20.0 [01:04<01:04,  6.49s/it]

on root


Domain updated: 0: 100%|██████████| 20/20.0 [02:02<00:00,  6.14s/it]


<batchflow.research.research.Research at 0x7f93cf54f4a8>

## Improving Performance

Research can run experiments in parallel if the number of workers is defined in the `workers` parameter.
Each worker starts in a separate process and performs one or several jobs assigned to it. Moreover, if several GPUs are accessible, one can pass the indices of the GPUs to use via the `devices` parameter.

The following parameters are also useful to control research execution:
* `timeout` in `run` specifies the time in minutes after which a non-responding job is killed; the default value is 5
* `trials` in `run` specifies the number of attempts to restart a job; the default value is 2
* `dump` in `add_pipeline`, `add_callable` and `get_metrics` tells how often results are written to disk and cleared from memory. By default results are dumped on the last iteration, but if they consume too much memory one may want to do it more often. The format is the same as for `execute`
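The effect of `dump` on memory use can be pictured with a toy results buffer. This is only a sketch: it assumes an integer `dump=N` means "flush every N iterations, plus once on the last iteration", and `ResultsBuffer` with its in-memory `disk` list is a hypothetical stand-in for whatever Research actually writes to the results folder.

```python
class ResultsBuffer:
    """Toy buffer that flushes accumulated results every `dump` iterations."""

    def __init__(self, dump, n_iters):
        self.dump = dump        # flush period (None: flush only on the last iteration)
        self.n_iters = n_iters
        self.memory = []        # results not yet written
        self.disk = []          # stand-in for files on disk, one list per flush

    def add(self, iteration, value):
        self.memory.append((iteration, value))
        is_last = iteration == self.n_iters - 1
        is_dump_step = self.dump is not None and (iteration + 1) % self.dump == 0
        if is_last or is_dump_step:
            self.disk.append(self.memory)   # "write" the chunk
            self.memory = []                # and free the memory


buffer = ResultsBuffer(dump=4, n_iters=10)
for it in range(10):
    buffer.add(it, value=it * 0.1)

print([len(chunk) for chunk in buffer.disk])   # [4, 4, 2]
```

A smaller `dump` value keeps fewer results in memory at any time at the cost of more frequent writes.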

In [None]:
from batchflow.research import ResearchPipeline as RP

os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"

model_config={
    'device': C('device'), # it's technical parameter for TFModel
    'inputs/images/shape': B('image_shape'),
    'inputs/labels/classes': 10,
    'inputs/labels/name': 'targets',
    'initial_block/inputs': 'images',
    'body/block/layout': C('layout'),
    'common/conv/use_bias': C('bias'),
}

train_template = (Pipeline()
            .init_variable('train_loss')
            .init_model('dynamic', VGG7, 'conv', config=model_config)
            .to_array()
            .train_model('conv', 
                         images=B('images'), labels=B('labels'),
                         fetches='loss', save_to=V('train_loss', mode='w'))
)

test_root = mnist.test.p.run_later(BATCH_SIZE, shuffle=True, n_epochs=1) #Note  n_epochs=1

test_template = (Pipeline()
                 .init_variable('predictions')
                 .init_variable('metrics')
                 .import_model('conv', C('import_from'))
                 .to_array()
                 .predict_model('conv', 
                                images=B('images'), labels=B('labels'),
                                fetches='predictions', save_to=V('predictions'))
                 .gather_metrics('class', targets=B('labels'), predictions=V('predictions'), 
                                fmt='logits', axis=-1, save_to=V('metrics')))

research = (Research()
            .add_pipeline(root=train_root, branch=train_template, variables='train_loss', name='train_ppl',
                          dump=TEST_EXECUTE_FREQ)
            .add_pipeline(root=test_root, branch=test_template, name='test_ppl',
                         execute=TEST_EXECUTE_FREQ, run=True, import_from=RP('train_ppl'))
            .get_metrics(pipeline='test_ppl', metrics_var='metrics', metrics_name='accuracy',
                         returns='accuracy', 
                         execute=TEST_EXECUTE_FREQ,
                         dump=TEST_EXECUTE_FREQ,)
            .init_domain(domain, n_reps=4))

res_name = 'faster_research'
clear_previous_results(res_name)

research.run(n_iters=ITERATIONS, name=res_name, bar=True, 
             branches=2, workers=2, devices=[0, 1], 
             timeout=2, trials=1)

Research faster_research is starting...


Domain updated: 0:  25%|██▌       | 20/80.0 [00:45<02:17,  2.30s/it]

In [16]:
results = research.load_results()
results.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 140 entries, 0 to 139
Data columns (total 8 columns):
accuracy        0 non-null float64
bias            140 non-null object
iteration       140 non-null int64
layout          140 non-null object
name            140 non-null object
repetition      140 non-null int64
sample_index    140 non-null int64
train_loss      140 non-null float64
dtypes: float64(2), int64(3), object(3)
memory usage: 8.8+ KB


## Cross-validation

One can easily perform cross-validation with Research.

First we define a dataset: we will use the train subset of MNIST.

In [17]:
mnist_train = MNIST().train
mnist_train.cv_split(n_splits=3)
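As a reminder of what a 3-fold split means, the sketch below partitions item indices into folds and builds the train and test parts for one fold. It is a generic illustration without shuffling; `cv_folds` and `train_test` are hypothetical helpers, and the way `cv_split` actually partitions the dataset may differ.

```python
def cv_folds(n_items, n_splits):
    """Split indices 0..n_items-1 into n_splits contiguous folds (no shuffling)."""
    base, extra = divmod(n_items, n_splits)
    folds, start = [], 0
    for i in range(n_splits):
        size = base + (1 if i < extra else 0)   # the first `extra` folds get one more item
        folds.append(list(range(start, start + size)))
        start += size
    return folds


def train_test(folds, fold):
    """Fold `fold` is the test part; all other folds form the train part."""
    test = folds[fold]
    train = [idx for i, part in enumerate(folds) if i != fold for idx in part]
    return train, test


folds = cv_folds(10, 3)
train, test = train_test(folds, fold=0)
print(test)    # [0, 1, 2, 3]
print(train)   # [4, 5, 6, 7, 8, 9]
```

Each item lands in the test part of exactly one fold, so iterating `fold` over 0, 1, 2 evaluates a model on every item once.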

Next, we define our train and test pipelines. When performing cross-validation, Research feeds the folds of the given dataset to the pipelines, and each pipeline selects its fold with `cv_fold`, so we define pipeline templates that wait for a dataset to work with. In contrast with the previous examples, we add `run_later` to the pipeline templates, not to dataset pipelines.

In [18]:
model_config={
    'inputs/images/shape': B('image_shape'),
    'inputs/labels/classes': D('num_classes'),
    'inputs/labels/name': 'targets',
    'initial_block/inputs': 'images',
    'body/block/layout': C('layout'),
}

train_template = (Pipeline()
            .cv_fold(C('fold'), 'train')
            .init_variable('train_loss')
            .init_model('dynamic', VGG7, 'conv', config=model_config)
            .to_array()
            .train_model('conv', 
                         images=B('images'), labels=B('labels'),
                         fetches='loss', save_to=V('train_loss', mode='w'))
            .run_later(BATCH_SIZE, shuffle=True, n_epochs=None))

test_template = (Pipeline()
                 .cv_fold(C('fold'), 'test')
                 .init_variable('predictions')
                 .init_variable('metrics')
                 .import_model('conv', C('import_from'))
                 .to_array()
                 .predict_model('conv', 
                                images=B('images'), labels=B('labels'),
                                fetches='predictions', save_to=V('predictions'))
                 .gather_metrics('class', targets=B('labels'), predictions=V('predictions'), 
                                fmt='logits', axis=-1, save_to=V('metrics'))
                 .run_later(BATCH_SIZE, shuffle=True, n_epochs=1))

We now define our Research object. To use cross-validation we should pass the `dataset` and `part` parameters to the `add_pipeline` methods. We will use the train subset of MNIST, `mnist_train`, created above, so we pass `dataset=mnist_train`. This subset was also split into train and test parts when created, so we pass `part='train'` when adding the train pipeline and `part='test'` when adding the test pipeline. We don't pass any dataset to `root` explicitly since it is now Research that takes care of the data.

Next, we call `run` with the `n_splits` parameter that defines the number of cv folds. We can also pass `shuffle` to specify whether to shuffle the dataset before splitting. `shuffle` can be a bool, an int, a `numpy.random.RandomState` or a callable; its default value is *False*, which means no shuffling.

In [19]:
domain = Option('layout', ['cna', 'can']) * Option('fold', [0, 1, 2])

research = (Research()
            .add_pipeline(train_template, dataset=mnist_train, variables='train_loss', name='train_ppl')
            .add_pipeline(test_template, dataset=mnist_train, name='test_ppl',
                         execute=TEST_EXECUTE_FREQ, run=True, import_from='train_ppl')
            .get_metrics(pipeline='test_ppl', metrics_var='metrics', metrics_name='accuracy', returns='accuracy', 
                         execute=TEST_EXECUTE_FREQ)
            .init_domain(domain))

res_name = 'cv_research'
clear_previous_results(res_name)
    
research.run(n_iters=ITERATIONS, name=res_name, bar=True, workers=1, devices=[0])

Research cv_research is starting...


Domain updated: 0: 100%|██████████| 60/60.0 [03:15<00:00,  3.26s/it]


<batchflow.research.research.Research at 0x7f395babeb70>

We can now load results, specifying which folds to get if needed

In [22]:
results = research.load_results(fold=0)
results.sample(5)

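|    | accuracy | fold | iteration | layout | name | repetition | sample_index | train_loss |
|----|----------|------|-----------|--------|------|------------|--------------|------------|
| 15 |          | 0    | 5.0       | can    | train_ppl | 0     | 0.0          | 0.91454    |
| 5  |          | 0    | 5.0       | cna    | train_ppl | 0     | 0.0          | 1.068651   |
| 12 |          | 0    | 2.0       | can    | train_ppl | 0     | 0.0          | 1.48679    |
| 17 |          | 0    | 7.0       | can    | train_ppl | 0     | 0.0          | 0.721579   |
| 2  |          | 0    | 2.0       | cna    | train_ppl | 0     | 0.0          | 1.761201   |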


In [None]:
from matplotlib import pyplot as plt
test_results = research.load_results(names= 'test_ppl_metrics', use_alias=True)

fig, ax = plt.subplots(1, 2, figsize=(15, 5))
for i, (config, df) in enumerate(test_results.groupby('config')):
    x, y = i//2, i%2
    df.pivot(index='iteration', columns='cv_split', values='accuracy').plot(ax=ax[y])
    ax[y].set_title(config)
    ax[y].set_xlabel('iteration')
    ax[y].set_ylabel('accuracy')
    ax[y].grid(True)
    ax[y].legend()

### Cross Validation with Extra Performance Settings 

We can still use root-branch division to preprocess the data once per job.

In [None]:
model_config={
    'inputs/images/shape': B('image_shape'),
    'inputs/labels/classes': 10,
    'inputs/labels/name': 'targets',
    'initial_block/inputs': 'images',
    'body/block/layout': C('layout'),
}

train_template = (Pipeline()
            .init_variable('train_loss')
            .init_model('dynamic', VGG7, 'conv', config=model_config)
            .to_array()
            .train_model('conv', 
                         images=B('images'), labels=B('labels'),
                         fetches='loss', save_to=V('train_loss', mode='w'))
            .run_later(BATCH_SIZE, shuffle=True, n_epochs=None))

augmentation_pipeline = Pipeline().salt(p=0.5).run_later(BATCH_SIZE, shuffle=True, n_epochs=None)

research = (Research()
            .add_pipeline(root=augmentation_pipeline, branch=train_template,
                          dataset=mnist_train, part='train', 
                          variables='train_loss', name='train_ppl')
            .add_pipeline(test_template, dataset=mnist_train, part='test', name='test_ppl',
                          execute=TEST_EXECUTE_FREQ, run=True, import_from='train_ppl')
            .get_metrics(pipeline='test_ppl', metrics_var='metrics', metrics_name='accuracy', returns='accuracy', 
                         execute=TEST_EXECUTE_FREQ)
            .init_domain(domain))

res_name = 'cv_branches_research'
clear_previous_results(res_name)

research.run(n_iters=ITERATIONS,
             n_splits=3, shuffle=True,
             workers=2, devices=[5, 6],
             branches=2, 
             name=res_name, bar=True)

research.load_results().info()
