# Update domain in Research

Sometimes one needs to change the domain of parameters while a `Research` is running. The `update_domain` method makes this possible.

We start with some useful imports and constant definitions.

In [1]:
import sys
import os
import shutil

import numpy as np
import matplotlib
%matplotlib inline

In [2]:
sys.path.append('../../..')

from batchflow import Pipeline, B, C, V, D, F
from batchflow.opensets import CIFAR10
from batchflow.models.torch import VGG7, VGG16, ResNet18
from batchflow.research import Research, Domain, E, R, get_metrics

In [3]:
BATCH_SIZE = 64

ds = CIFAR10()

Let us solve the following problem: train three models (VGG7, VGG16 and ResNet18) for one epoch each, choose the model with the highest test accuracy, and finally train it for 10 epochs. We define pipelines in which `'model'` and `'n_epochs'` are taken from the research config, so that they can change from run to run.

In [4]:
model_config = {
    'inputs/images/shape': B('image_shape'),
    'inputs/labels/classes': D('num_classes'),
    'inputs/labels/name': 'targets',
    'initial_block/inputs': 'images'
}
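Keys such as `'inputs/images/shape'` are, as we understand batchflow's config handling, paths into a nested configuration, with `/` separating the levels. The sketch below only illustrates that reading: `expand_slash_keys` is a hypothetical helper, not batchflow's `Config` implementation, and placeholder strings stand in for the `B(...)` and `D(...)` named expressions, which batchflow evaluates later, when the pipeline runs.

```python
# Hypothetical helper (not part of batchflow): expands slash-separated keys
# into nested dicts to show how a key like 'inputs/images/shape' can be read.
def expand_slash_keys(flat):
    nested = {}
    for key, value in flat.items():
        *parents, leaf = key.split('/')
        node = nested
        for part in parents:
            # Create intermediate levels on first use
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested

# Placeholder strings replace B('image_shape') and D('num_classes')
flat_config = {
    'inputs/images/shape': '<image_shape>',
    'inputs/labels/classes': '<num_classes>',
    'inputs/labels/name': 'targets',
    'initial_block/inputs': 'images',
}
print(expand_slash_keys(flat_config))
```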

In [5]:
train_pipeline = (ds.train.p
    .init_variable('loss')
    .init_model('conv', C('model'), 'dynamic', config=model_config)
    .to_array(dtype='float32')
    .train_model('conv', B('images'), B('labels'),
                 fetches='loss', save_to=V('loss', mode='w'))
    .run_later(batch_size=BATCH_SIZE, n_epochs=C('n_epochs'))
)

test_pipeline = (ds.test.p
    .init_variable('predictions')
    .init_variable('metrics')
    .import_model('conv', C('import_from'))
    .to_array(dtype='float32')
    .predict_model('conv', B('images'),
                   fetches='predictions', save_to=V('predictions'))
    .gather_metrics('class', targets=B('labels'), predictions=V('predictions'), 
                    fmt='logits', axis=-1, save_to=V('metrics', mode='a'))
    .run_later(batch_size=BATCH_SIZE, n_epochs=1)
)
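The test pipeline passes `fmt='logits', axis=-1` to `gather_metrics`, so each prediction is a vector of class scores rather than a class label. A minimal NumPy sketch of accuracy in that setting, assuming the predicted class is the argmax over the last axis; `accuracy_from_logits` is a hypothetical helper, not batchflow's metrics code, and the numbers are made up.

```python
import numpy as np

def accuracy_from_logits(logits, targets, axis=-1):
    """Share of samples whose argmax over `axis` equals the integer target."""
    predicted = np.argmax(logits, axis=axis)
    return float(np.mean(predicted == targets))

# Toy batch: 4 samples, 3 classes (made-up scores, not CIFAR-10 output)
logits = np.array([[2.0, 0.1, -1.0],
                   [0.3, 1.5,  0.2],
                   [0.0, 0.2,  3.0],
                   [1.0, 2.0,  0.5]])
targets = np.array([0, 1, 2, 0])
print(accuracy_from_logits(logits, targets))  # 3 of 4 correct -> 0.75
```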

First, define the initial domain.

In [6]:
domain = Domain(model=[VGG7, VGG16, ResNet18], n_epochs=[1])
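A domain built from lists of values describes combinations of those values. Assuming `Domain` enumerates the Cartesian product of its keyword options (our reading of the API), the domain above yields three configs, one per model, each with `n_epochs=1`. The sketch below mimics that enumeration with `itertools.product`; strings stand in for the model classes so it runs without batchflow, and `enumerate_configs` is hypothetical.

```python
from itertools import product

# Stand-in for Domain(model=[VGG7, VGG16, ResNet18], n_epochs=[1])
options = {'model': ['VGG7', 'VGG16', 'ResNet18'], 'n_epochs': [1]}

def enumerate_configs(options):
    """Yield one dict per combination of option values (Cartesian product)."""
    names = list(options)
    for values in product(*(options[name] for name in names)):
        yield dict(zip(names, values))

configs = list(enumerate_configs(options))
for config in configs:
    print(config)
```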

To update the domain, we define a function that returns a new domain, or `None` if the domain should stay unchanged. In our case, the function `update_domain` receives the research results, converts them to a `pandas.DataFrame`, takes the model with the highest accuracy and creates a new domain with that model and `n_epochs=10`.

In [7]:
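def update_domain(results):
    # Convert research results into a pivoted DataFrame (one column per variable)
    results = results.to_df(pivot=True, use_alias=False)
    # idxmax returns an index label and skips NaN rows, so look it up with .loc
    best_model = results.loc[results['accuracy'].idxmax(), 'model']
    # New domain: only the best model, trained for 10 epochs
    domain = Domain(model=[best_model], n_epochs=[10])
    return domain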
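In the pivoted results, accuracy is filled only on some rows (the loss rows leave it empty, as the output table further down shows). `idxmax` skips missing values and returns an index label, not a position, which is why a label-based `.loc` lookup pairs naturally with it. A small pandas sketch on a made-up table; `pick_best_model` and the toy values are hypothetical, and the non-default index is chosen on purpose to make the label/position difference visible.

```python
import numpy as np
import pandas as pd

def pick_best_model(df, metric='accuracy'):
    """Return the 'model' value from the row with the largest `metric`.

    Rows where `metric` is NaN are skipped by idxmax (skipna=True by default).
    """
    best_label = df[metric].idxmax()  # an index label, not a position
    return df.loc[best_label, 'model']

# Made-up toy table: loss rows have no accuracy, test rows have no loss
toy = pd.DataFrame({
    'model':    ['VGG7', 'VGG7', 'VGG16', 'VGG16', 'ResNet18', 'ResNet18'],
    'loss':     [1.9, np.nan, 2.1, np.nan, 2.0, np.nan],
    'accuracy': [np.nan, 0.45, np.nan, 0.38, np.nan, 0.41],
}, index=[10, 11, 20, 21, 30, 31])
print(pick_best_model(toy))  # VGG7
```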

We pass the update function to the research through its `update_domain` method. The `when` parameter defines how often the function is applied: with `when='last'`, it is applied once the current domain is exhausted. The cell below passes `when="%2"` instead; see the `Research.update_domain` docstring for how this form is interpreted. All other keyword arguments are passed on to the update function, so `results=R()` hands it the research results.
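The `when='last'` behaviour described above can be pictured as a loop: run every config of the current domain, then call the update function on the accumulated results, and stop once it returns `None`. The sketch below is a hypothetical, batchflow-free driver under that assumption; plain dicts stand in for `Domain`, and `evaluate` returns made-up scores instead of training anything. Our toy `update` returns `None` after one update, which ends the loop.

```python
from itertools import product

def run_with_updates(domain, update, evaluate):
    """Hypothetical driver for when='last': exhaust the domain, then update it."""
    results = []
    n_updates = 0
    while domain is not None:
        names = list(domain)
        for values in product(*(domain[name] for name in names)):
            config = dict(zip(names, values))
            # Record which update stage produced this config
            results.append({**config, 'updates': n_updates,
                            'accuracy': evaluate(config)})
        domain = update(results)
        n_updates += 1
    return results

def update(results):
    # Stop after one update: return None once updated results exist
    if any(r['updates'] > 0 for r in results):
        return None
    best = max(results, key=lambda r: r['accuracy'])
    return {'model': [best['model']], 'n_epochs': [10]}

# Made-up scores standing in for real training and testing
fake_scores = {'VGG7': 0.45, 'VGG16': 0.38, 'ResNet18': 0.41}

def evaluate(config):
    return fake_scores[config['model']] + 0.01 * config['n_epochs']

initial = {'model': ['VGG7', 'VGG16', 'ResNet18'], 'n_epochs': [1]}
history = run_with_updates(initial, update, evaluate)
for row in history:
    print(row)
```

With three initial configs and one updated config, the sketch runs four configs in total.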

In [8]:
research = (Research(domain=domain)
            .add_pipeline('train_ppl', train_pipeline, variables='loss')
            .add_pipeline('test_ppl', test_pipeline, run=True,
                          import_from=E('train_ppl').pipeline, when='last')
            .get_metrics(pipeline=E('test_ppl').pipeline,
                         metrics_var='metrics',
                         metrics_name='accuracy',
                         save_to='accuracy',
                         when='last')
            .update_domain(update_domain, when="%2", results=R())
           )

research.run(dump_results=False, parallel=False, bar=True)

100%|██████████| 4/4 [06:10<00:00, 92.68s/it]


<batchflow.research.research.Research at 0x7fb052a79940>

In [9]:
research.results.to_df(pivot=True, remove_auxilary=False, use_alias=True)

| | id | model | n_epochs | repetition | updates | device | iteration | loss | accuracy |
|---|---|---|---|---|---|---|---|---|---|
| 0 | a0a92f2751121067 | `<class 'batchflow.models.torch.vgg.VGG7'>` | 1 | 0 | 0 | | 0 | 2.352301 | |
| 1 | a0a92f2751121067 | `<class 'batchflow.models.torch.vgg.VGG7'>` | 1 | 0 | 0 | | 1 | 2.409408 | |
| 2 | a0a92f2751121067 | `<class 'batchflow.models.torch.vgg.VGG7'>` | 1 | 0 | 0 | | 2 | 2.204967 | |
| 3 | a0a92f2751121067 | `<class 'batchflow.models.torch.vgg.VGG7'>` | 1 | 0 | 0 | | 3 | 2.045897 | |
| 4 | a0a92f2751121067 | `<class 'batchflow.models.torch.vgg.VGG7'>` | 1 | 0 | 0 | | 4 | 1.966656 | |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 10158 | d1b1ce1b39296110 | `<class 'batchflow.models.torch.vgg.VGG7'>` | 10 | 0 | 1 | | 7809 | 0.606818 | |
| 10159 | d1b1ce1b39296110 | `<class 'batchflow.models.torch.vgg.VGG7'>` | 10 | 0 | 1 | | 7810 | 0.649437 | |
| 10160 | d1b1ce1b39296110 | `<class 'batchflow.models.torch.vgg.VGG7'>` | 10 | 0 | 1 | | 7811 | 0.593156 | |
| 10161 | d1b1ce1b39296110 | `<class 'batchflow.models.torch.vgg.VGG7'>` | 10 | 0 | 1 | | 7812 | 0.792372 | |


The resulting `pandas.DataFrame` has an `'updates'` column that holds the number of domain updates made before the corresponding config was produced.
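Filtering on this column separates the initial stage from the updated one, much as `to_df(updates=1, ...)` does in the next cell. A pandas sketch on a made-up table: the column names are copied from the output above, but every value is invented for illustration.

```python
import numpy as np
import pandas as pd

# Made-up rows shaped like the pivoted results; 'updates' marks the stage
results = pd.DataFrame({
    'model':    ['VGG7', 'VGG16', 'ResNet18', 'VGG7', 'VGG7'],
    'n_epochs': [1, 1, 1, 10, 10],
    'updates':  [0, 0, 0, 1, 1],
    'accuracy': [0.45, 0.38, 0.41, np.nan, 0.60],
})

# Keep only rows produced after the first domain update
final_stage = results[results['updates'] == 1]
print(final_stage['model'].unique())

# Best accuracy reached at each stage (NaN rows are ignored by max)
print(results.groupby('updates')['accuracy'].max())
```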

In [10]:
acc = research.results.to_df(updates=1, pivot=True, use_alias=True)
print('Best model:    ', acc.model.values[0])
print('Final accuracy:', acc.accuracy.values[-1])

Best model:     <class 'batchflow.models.torch.vgg.VGG7'>
Final accuracy: 0.6238057324840764
