In [None]:
import numpy as np
import pandas as pd
import torch.nn as nn
import matplotlib.pyplot as plt
from matplotlib import colors as mcolors


import sys
sys.path.append('../../..')
from batchflow.opensets import Imagenette160
from batchflow import Pipeline, B, V, C, W

from batchflow.models.torch import ResNet34, ResBlock, SelfAttention
from batchflow.models.torch.layers import ConvBlock

from batchflow.models.metrics import ClassificationMetrics
from batchflow.research import Research, Option, Results, KV, RP, REU, RI
from batchflow.utils import plot_results_by_config, show_research, print_results

In [None]:
# Experiment-wide settings
NUM_ITERS = 15000               # iterations passed to `research.run` for every model
N_REPS = 5                      # repetitions of each configuration (`n_reps` of the domain)
RESEARCH_NAME = 'research'      # name given to the research run
DEVICES = [4, 5, 6, 7]          # device indices handed to `research.run`
WORKERS = len(DEVICES)          # one simultaneously trained model per listed device

data = Imagenette160()          # dataset that both pipelines are built from

In [None]:
class SAResBlock(nn.Module):
    """ Residual block combined with a self-attention block.

    The module wraps a single `ConvBlock` built from two configuration dicts:
    the first uses `ResBlock` as its base and receives all extra keyword
    arguments, the second uses `SelfAttention` as its base.

    Parameters
    ----------
    inputs : optional
        Forwarded to `ConvBlock` as `inputs`
        (presumably an example tensor used to infer shapes — TODO confirm).
    **kwargs
        Extra entries merged into the `ResBlock` configuration dict.
    """
    def __init__(self, inputs=None, **kwargs):
        super().__init__()

        res_stage = {'base': ResBlock, **kwargs}
        attention_stage = {'base': SelfAttention}
        self.layer = ConvBlock(res_stage, attention_stage, inputs=inputs)

    def forward(self, x):
        """ Pass `x` through the wrapped `ConvBlock`. """
        return self.layer(x)

In [None]:
# Research domain: four values of the `body` option, each a `KV` pairing a
# config dict for the encoder blocks with the alias used to label it.
domain = Option('body', [
    KV({'encoder/blocks/base': ResBlock, 'encoder/blocks/se': False}, 'ResBlock'),
    KV({'encoder/blocks/base': SAResBlock}, 'SAResBlock'),
    KV({'encoder/blocks/base': ResBlock, 'encoder/blocks/se': True}, 'SEResBlock'),
    KV({'encoder/blocks/base': SAResBlock, 'encoder/blocks/se': True}, 'SESAResBlock'),
])

In [None]:
# Model configuration passed to `init_model`; `C(...)` entries are named
# expressions looked up in the pipeline config when the model is built.
config = {
    'inputs/labels/classes': 10,    # number of target classes
    'body': C('body'),              # taken from the `body` option of the research domain
    'head/layout': 'cV',            # head layout string (letter meanings are defined by batchflow)
    'device': C('device'),          # presumably filled in per experiment by Research — TODO confirm
}

In [None]:
# Root pipeline: loads and preprocesses training batches; it is passed to
# `add_pipeline` as `root`, with the model pipeline attached as `branch`.
train_root = (
    data.train.p
    .crop(shape=(160, 160), origin='center')
    .to_array(channels='first', dtype=np.float32)
    .run_later(64, n_epochs=None, drop_last=True, shuffle=True)
)

# Branch pipeline: creates the ResNet34 model from `config` and trains it on
# the batch images/labels, storing the fetched loss in the `loss` variable.
train_pipeline = (
    Pipeline()
    .init_variable('loss')
    .init_model('dynamic', ResNet34, 'my_model', config=config)
    .train_model('my_model', B('images'), B('labels'),
                 fetches='loss', save_to=V('loss'))
)

# Evaluation pipeline: imports the trained model, runs it over the held-out
# split for one epoch and accumulates true labels and predictions batch by batch.
# It was previously built from `data.train`, so the reported accuracy was
# measured on the training images rather than on unseen data.
# NOTE(review): assumes the dataset exposes a `test` split — confirm.
test_pipeline = (data.test.p
                 .import_model('my_model', C('import_from'))
                 .init_variable('true', [])
                 .update(V('true', mode='a'), B.labels)
                 .init_variable('predictions', [])
                 .crop(shape=(160, 160), origin='center')
                 .to_array(channels='first', dtype=np.float32)
                 .predict_model('my_model', B('images'), fetches='predictions',
                                save_to=V('predictions', mode='a'))
                 .run_later(128, n_epochs=1, drop_last=False,
                            shuffle=True)
                 )

In [None]:
def acc(iteration, experiment):
    """ Compute classification accuracy from the variables of an executed pipeline.

    Parameters
    ----------
    iteration
        Not used by the computation; kept so the callable signature stays unchanged.
    experiment
        Object whose `pipeline` holds the `predictions` and `true` variables,
        each a sequence of per-batch arrays.

    Returns
    -------
    Value of `ClassificationMetrics(...).accuracy()`, with predictions
    interpreted as logits over 10 classes along axis 1.
    """
    ppl = experiment.pipeline
    logits = np.concatenate(ppl.v('predictions'))
    targets = np.concatenate(ppl.v('true'))
    metrics = ClassificationMetrics(targets, logits, fmt='logits',
                                    num_classes=10, axis=1)
    return metrics.accuracy()

In [None]:
# Research plan: train every domain configuration `N_REPS` times, and with
# `execute=10` run the evaluation pipeline and the accuracy callable.
research = (
    Research()
    .init_domain(domain, n_reps=N_REPS)
    # training: data root + model branch, the `loss` variable is recorded
    .add_pipeline(root=train_root, branch=train_pipeline,
                  variables='loss', name='train_ppl', logging=True)
    # evaluation: imports the model from the `train_ppl` pipeline
    .add_pipeline(test_pipeline, name='test_ppl', execute=10, run=True,
                  import_from=RP('train_ppl'))
    # accuracy computed from the `test_ppl` unit, stored as `acc_vall`
    .add_callable(acc, returns='acc_vall', name='acc_fn', execute=10,
                  iteration=RI(), experiment=REU('test_ppl'))
)

In [None]:
# Launch the experiments with the settings from the constants cell,
# showing a progress bar.
research.run(
    NUM_ITERS,
    name=RESEARCH_NAME,
    devices=DEVICES,
    workers=WORKERS,
    bar=True,
)

In [None]:
# Load the stored results; `results.df` is used by the cells below.
results = research.load_results(concat_config=True)

In [None]:
def aggreg(values):
    """ Average the last three non-missing entries of `values`.

    Parameters
    ----------
    values : iterable
        Sequence of scalars (e.g. a column of a grouped dataframe); entries
        for which `pd.isna` is true are ignored.

    Returns
    -------
    Mean of up to three trailing valid entries, or `np.nan` when there are no
    valid entries at all (returned explicitly so that `np.mean` is never called
    on an empty list, which would emit a RuntimeWarning).
    """
    valid = [item for item in values if not pd.isna(item)]
    if not valid:
        return np.nan
    return np.mean(valid[-3:])

# Per-config summary sorted by accuracy, best first.
# The columns are selected with a list: selecting them with a bare tuple
# (`[...]['a', 'b']`) is rejected by recent pandas versions.
(results.df
 .groupby(['config'])[['sample_index', 'acc_vall']]
 .agg(aggreg)
 .reset_index()
 .sort_values('acc_vall', ascending=False)
)

In [None]:
# Plot the training loss and the accuracy series using the Tableau palette.
show_research(
    results.df,
    layout=['train_ppl/loss', 'acc_fn/acc_vall'],
    average_repetitions=True,
    color=list(mcolors.TABLEAU_COLORS.keys()),
    log_scale=False,
    rolling_window=10,
)

In [None]:
# Table of results for the accuracy series (`n_last=10`, ascending order).
print_results(
    results.df, 'acc_fn/acc_vall', False,
    ascending=True, n_last=10,
)

In [None]:
#plot_results_by_config(results.df, (('train_ppl', 'loss'), ('acc_fn', 'acc_vall')))