In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import warnings

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

sys.path.append("../../..")
from batchflow import *
from batchflow.opensets import MNIST
from batchflow.models.eager_torch import *
from batchflow.models.eager_torch.layers import ConvBlock, update_layers, ConvGroup

# Set GPU
# %env CUDA_VISIBLE_DEVICES=6

In [None]:
# Dummy 4-D float32 tensor of ones used to build and probe the layers below;
# presumably (batch, channels, height, width) — TODO confirm against ConvBlock.
value = torch.from_numpy(np.ones((5, 5, 5, 5), dtype=np.float32))
inputs = value

In [None]:
# Build a single ConvBlock sized from the dummy `inputs`; 'cna' is a batchflow layout string
# (presumably conv / norm / activation) and filters='same' presumably keeps the input channel count.
layer = ConvBlock(layout='cna', filters='same', inputs=inputs)

In [None]:
# Display the module structure of `layer`.
layer

In [None]:
# Build a ConvGroup from positional module / kwargs arguments, presumably repeated twice with
# two branches (per the n_repeats / n_branches keywords) that are merged with '+'.
# NOTE(review): this cell passes `combine='+'` while the super_group cell passes
# `branch_combine='+'` — confirm which keyword ConvGroup actually consumes.
group = ConvGroup(nn.Identity(), {}, {}, inputs=inputs, n_repeats=2, n_branches=2, combine='+',
                  layout='cna', filters='same+1')

In [None]:
# Display the module structure of `group`.
group

In [None]:
# Output shape of `group` on the dummy inputs (`get_shape` comes from one of the star imports
# at the top — presumably batchflow's eager_torch utilities).
get_shape(group(inputs))

In [None]:
# ConvGroup overloads `*`; presumably this stacks the group twice in sequence — confirm in ConvGroup.
multiplied = group * 2

In [None]:
# Output shape of the `*`-combined group on the dummy inputs.
get_shape(multiplied(inputs))

In [None]:
# ConvGroup overloads `%`; presumably this runs two copies of the group as parallel branches — confirm.
branched = group % 2

In [None]:
# Display the module structure of `branched`.
branched

In [None]:
# Output shape of the `%`-combined group on the dummy inputs.
get_shape(branched(inputs))

In [None]:
# Nest the previously built `group` inside another ConvGroup, together with an nn.Identity
# paired with {'filters': '5'}.
# NOTE(review): this cell passes `branch_combine='+'` while the `group` cell passes
# `combine='+'` — confirm which keyword ConvGroup actually consumes.
super_group = ConvGroup(group, {}, nn.Identity(), {'filters': '5'},
                        inputs=inputs, n_repeats=2, n_branches=2,
                        branch_combine='+', layout='cna', filters='same+1')

In [None]:
# Display the module structure of `super_group`.
super_group

In [None]:
# Output shape of the nested group on the dummy inputs.
get_shape(super_group(inputs))

# Setup

In [None]:
mnist = MNIST(batch_class=ImagesBatch)

# Under an interactive kernel `__name__` is '__main__', so both knobs fall back to None.
# NOTE(review): when this notebook is driven by a test harness, MICROBATCH / DEVICE are
# presumably injected beforehand instead — confirm.
if __name__ == '__main__':
    MICROBATCH = DEVICE = None

print('\nMicrobatching is: {}'.format(MICROBATCH))
print('\nDevice is: {}'.format(DEVICE))

In [None]:
IMAGE_SHAPE = (1, 28, 28)  # channels-first MNIST image shape; not referenced by the default configs below


def _make_pipeline_config(model_class, default_config, config, target_key, target_component):
    """ Assemble the pipeline config consumed by `get_pipeline`.

    Parameters
    ----------
    model_class : type
        Model class to initialize inside the pipeline.
    default_config : dict
        Task-specific model-config defaults.
    config : dict
        User model config; its keys override `default_config` (shallow merge, inputs are not mutated).
    target_key : str
        Feed-dict key for the model targets; it is placed right after 'images'.
    target_component : str
        Batch component fed under `target_key`.

    Returns
    -------
    dict
        With 'model', 'model_config' and 'feed_dict' keys.
    """
    return {
        'model': model_class,
        'model_config': {**default_config, **config},
        # `get_pipeline` treats 'images' as inputs and the first remaining entry as targets,
        # so the insertion order here matters.
        'feed_dict': {'images': B('images'),
                      target_key: B(target_component)},
    }


def get_classification_config(model_class, config):
    """ Pipeline config for classification: cross-entropy loss, images -> labels. """
    default_config = {
        'loss': 'ce',
        'microbatch': MICROBATCH,
        'device': DEVICE,
    }
    return _make_pipeline_config(model_class, default_config, config, 'labels', 'labels')


def get_segmentation_config(model_class, config):
    """ Pipeline config for segmentation: MSE loss, with the images themselves used as masks. """
    default_config = {
        # NOTE(review): an earlier comment claimed `concat` does not work from within pytest,
        # yet 'concat' is the value set here — confirm which combine op is intended.
        'body/decoder/blocks/combine_op': 'concat',
        'loss': 'mse',
        'microbatch': MICROBATCH,
        'device': DEVICE,
    }
    return _make_pipeline_config(model_class, default_config, config, 'masks', 'images')

In [None]:
def get_pipeline(pipeline_config):
    """ Build an (unrun) training pipeline from `pipeline_config`.

    Parameters
    ----------
    pipeline_config : dict
        Must contain 'model', 'model_config' and 'feed_dict' keys. 'feed_dict' must hold an
        'images' entry plus at least one more; the first non-'images' entry is used as the
        training targets. The argument is not modified.

    Returns
    -------
    Pipeline
        Converts batches to channels-first float32 arrays, initializes 'MODEL' from
        `C('model')` / `C('model_config')` and appends each training step's loss to the
        'loss_history' variable.

    Raises
    ------
    KeyError
        If 'feed_dict' has no 'images' entry.
    ValueError
        If 'feed_dict' has no entry besides 'images'.
    """
    # Work on a copy: popping from the caller's dict made a second call with the same
    # config fail with KeyError and silently stripped 'images' from the pipeline config.
    feed_dict = dict(pipeline_config['feed_dict'])
    images = feed_dict.pop('images')
    if not feed_dict:
        raise ValueError("`feed_dict` must contain a targets entry besides 'images'")
    targets = next(iter(feed_dict.values()))

    pipeline = (Pipeline(config=pipeline_config)
                .init_variable('loss_history', [])
                .to_array(channels='first', dtype='float32')
                .init_model('dynamic', C('model'),
                            'MODEL', config=C('model_config'))
                .train_model('MODEL', images, targets,
                             fetches='loss',
                             save_to=V('loss_history', mode='a'))
                )
    return pipeline

In [None]:
def run(task, model_class, config, description, batch_size=16, n_iters=10):
    """ Train `model_class` on the MNIST train split and return the trained pipeline.

    Parameters
    ----------
    task : str
        A name starting with 'c' selects the classification setup; a name starting with 's'
        selects the segmentation setup.
    model_class : type
        Model class to train.
    config : dict
        Model config overriding the task defaults.
    description : str
        Free-form label used only in the final message.
    batch_size : int
        Number of items per batch.
    n_iters : int
        Number of training iterations.

    Returns
    -------
    Pipeline
        The trained pipeline; its 'loss_history' variable holds the per-iteration losses.

    Raises
    ------
    ValueError
        If `task` starts with neither 'c' nor 's'.
    """
    if task.startswith('c'):
        pipeline_config = get_classification_config(model_class, config)
    elif task.startswith('s'):
        pipeline_config = get_segmentation_config(model_class, config)
    else:
        # Previously `pipeline_config` stayed unbound here and failed later with an opaque NameError.
        raise ValueError("Unknown task {!r}: expected a name starting with 'c' (classification) "
                         "or 's' (segmentation)".format(task))

    train_pipeline = get_pipeline(pipeline_config) << mnist.train
    # Progress-bar description is built from the last recorded loss.
    train_pipeline.run(batch_size, n_iters=n_iters, bar=True,
                       bar_desc=W(V('loss_history')[-1].format('Loss is {:7.7}')))

    print('{} {} is done'.format(task, description))
    return train_pipeline

# Classification

In [None]:
# Encoder classifier: 8-filter 'cna' initial block, 5 encoder stages and an 'faf' head
# with [50, 10] units (10 = number of MNIST classes, matching 'inputs/labels/classes').
config = {
    'inputs/labels/classes': 10,
    'initial_block': {'layout': 'cna', 'filters': 8},
    'body/encoder': {'num_stages': 5},
    'head': {'layout': 'faf', 'units': [50, 10]},
}

ppl = run(task='classification', model_class=Encoder, config=config,
          description='simple fc', batch_size=64, n_iters=100)

In [None]:
# Training loss of the Encoder classifier, one point per training iteration.
plt.plot(ppl.v('loss_history'))
plt.xlabel('iteration')
plt.ylabel('loss')
plt.title('Encoder (classification): training loss');

In [None]:
# Plain EagerTorch fully-connected model:
# - 'fafaf' initial block with units [128, 256, 10];
# - `order` adds a second step named 'ib_2' built by EagerTorch.initial_block, presumably
#   reusing the 'initial_block' config section — TODO confirm;
# - two losses and two train steps ('a', 'b'), presumably paired one-to-one;
# - 'decay': 'exp' with 'n_iters': 25, presumably a learning-rate schedule — TODO confirm.
config = {
    'loss': ['ce', 'ce'],
    'decay': 'exp',
    'n_iters': 25,
    'train_steps': {'a': {}, 'b': {}},
    'initial_block': {'layout': 'fafaf', 'units': [128, 256, 10]},
    'order': ['initial_block', ('ib_2', 'initial_block', EagerTorch.initial_block)],
}

ppl = run(task='classification', model_class=EagerTorch, config=config,
          description='simple fc', batch_size=64, n_iters=99)

In [None]:
# Training loss of the fully-connected EagerTorch model, one point per training iteration.
plt.plot(ppl.v('loss_history'))
plt.xlabel('iteration')
plt.ylabel('loss')
plt.title("EagerTorch 'simple fc' (classification): training loss");

In [None]:
from batchflow.models.eager_torch.layers import Combine

class TestModel(EagerTorch):
    """ EagerTorch model whose body is a `BodyModule`. """
    @classmethod
    def body(cls, inputs, **kwargs):
        """ Build the body as a `BodyModule` from `kwargs` resolved by `get_defaults('body', ...)`
        (presumably merged with the model's 'body' defaults). """
        return BodyModule(inputs=inputs, **cls.get_defaults('body', kwargs))

class BodyModule(nn.Module):
    """ Two ConvBlocks built from the same kwargs and applied to the same input, combined with 'concat'. """
    def __init__(self, inputs=None, **kwargs):
        super().__init__()
        # Submodules are created in the order x1, x2, combine so that parameter
        # initialization happens in the same sequence.
        self.x1 = ConvBlock(inputs=inputs, **kwargs)
        self.x2 = ConvBlock(inputs=inputs, **kwargs)

        self.combine = Combine(op='concat')

    def forward(self, x):
        """ Apply both branches to `x` (x1 first) and merge their outputs with `self.combine`. """
        return self.combine([self.x1(x), self.x2(x)])

In [None]:
# TestModel classifier:
# - 'Rcna. p cnap' initial block with filters [16, 32] and scale_factor 2;
# - 'body' points at BodyModule with module_kwargs ('cnacna', filters [32, 32]);
#   NOTE(review): confirm whether EagerTorch builds this module directly or goes through TestModel.body;
# - 'Dnfaf' head with dropout_rate 0.3 and units [600, 10].
config = {
    'initial_block': {'layout': 'Rcna. p cnap', 'filters': [16, 32], 'scale_factor': 2},
    'body': {'module': BodyModule,
             'module_kwargs': {'layout': 'cnacna', 'filters': [32, 32]}},
    'head': {'layout': 'Dnfaf', 'units': [600, 10], 'dropout_rate': 0.3},
}

ppl = run(task='classification', model_class=TestModel, config=config,
          description='simple fc', batch_size=64, n_iters=50)

In [None]:
# Training loss of TestModel, one point per training iteration.
plt.plot(ppl.v('loss_history'))
plt.xlabel('iteration')
plt.ylabel('loss')
plt.title('TestModel (classification): training loss');

In [None]:
# Evaluate the trained 'MODEL' from `ppl` on the MNIST test split for one epoch.
# Metrics are accumulated across batches with mode='u'. batchflow documents mode='w' as keeping
# only the last batch's metrics, which with drop_last=False is a small partial batch.
test_pipeline = (mnist.test.p
                .import_model('MODEL', ppl)
                .init_variable('predictions')
                .init_variable('metrics', init_on_each_run=None)
                .to_array(channels='first', dtype='float32')
                .predict_model('MODEL', B.images,
                               fetches='predictions', save_to=V('predictions'))
                .gather_metrics('class', targets=B.labels, predictions=V('predictions'),
                                fmt='logits', axis=-1, save_to=V('metrics', mode='u'))
                .run(64, shuffle=True, n_epochs=1, drop_last=False, bar=True)
)

In [None]:
# Accuracy from the metrics gathered by the test pipeline.
metrics = test_pipeline.get_variable('metrics')
metrics.evaluate('accuracy')

In [None]:
# Inspect the underlying network of the trained 'MODEL' (presumably the torch nn.Module).
ppl.get_model_by_name('MODEL').model

# Segmentation

In [None]:
# Decoder-only segmentation model. The segmentation setup feeds the images themselves as masks.
# It uses an 8-filter 'cna' initial block and 3 decoder stages, each with factor 1
# (presumably no spatial upsampling — TODO confirm).
config = {
    'initial_block': {'layout': 'cna', 'filters': 8},
    'body/decoder/num_stages': 3,
    'body/decoder/factor': [1, 1, 1],
}

ppl = run(task='segmentation', model_class=Decoder, config=config,
          description='unet?', batch_size=64, n_iters=10)

In [None]:
# Inspect the underlying network of the trained 'MODEL' (presumably the torch nn.Module).
ppl.get_model_by_name('MODEL').model

In [None]:
# Training loss of the Decoder model, one point per training iteration.
plt.plot(ppl.v('loss_history'))
plt.xlabel('iteration')
plt.ylabel('loss')
plt.title('Decoder (segmentation): training loss');

In [None]:
# EncoderDecoder segmentation model: 8-filter 'cna' initial block and a 3-stage encoder;
# the decoder is left to the EncoderDecoder defaults.
config = {
    'initial_block': {'layout': 'cna', 'filters': 8},
    'body/encoder/num_stages': 3,
}

ppl = run(task='segmentation', model_class=EncoderDecoder, config=config,
          description='unet?', batch_size=64, n_iters=100)

In [None]:
# Training loss of the EncoderDecoder model, one point per training iteration.
plt.plot(ppl.v('loss_history'))
plt.xlabel('iteration')
plt.ylabel('loss')
plt.title('EncoderDecoder (segmentation): training loss');

In [None]:
# EagerTorch segmentation model defined by a single initial-block layout string.
# The layout has seven 'c'/'t' letters and `filters` has seven entries, presumably one per
# conv / transposed conv in order. `transposed_conv` and `side_branch` presumably configure
# specific layout letters — TODO confirm the letter mapping in batchflow's layout docs.
config = {
    'step_on_each': 1,
    'initial_block': {
        'layout': 'cnaRp cnaRp tna+ tna+ BScna*+ cnac',
        'filters': [16, 32, 32, 16, 'same', 8, 1],
        'transposed_conv': {'kernel_size': 2, 'strides': 2},
        'side_branch': {'layout': 'ca', 'filters': 'same'},
    },
}

ppl = run(task='segmentation', model_class=EagerTorch, config=config,
          description='unet?', batch_size=64, n_iters=100)

In [None]:
# Training loss of the layout-string EagerTorch model, one point per training iteration.
plt.plot(ppl.v('loss_history'))
plt.xlabel('iteration')
plt.ylabel('loss')
plt.title('EagerTorch layout model (segmentation): training loss');

In [None]:
# Inspect the underlying network of the trained 'MODEL' (presumably the torch nn.Module).
ppl.get_model_by_name('MODEL').model

In [None]:
class UNet(EncoderDecoder):
    """ UNet-style encoder-decoder whose filter counts are derived from the number of encoder stages.

    Unless set explicitly in the config:
    - encoder stage `i` uses ``2 * 2**i`` filters;
    - the embedding uses ``2 * 2**num_stages`` filters;
    - decoder blocks mirror the encoder filters in reverse order;
    - decoder upsampling reuses the decoder block filters.
    """
    @classmethod
    def default_config(cls):
        """ EncoderDecoder defaults with UNet block ordering and 'cna cna' blocks (kernel_size=3). """
        config = super().default_config()

        config['body/encoder/num_stages'] = 4
        config['body/encoder/order'] = ['block', 'skip', 'downsampling']
        config['body/encoder/blocks'] += dict(layout='cna cna', kernel_size=3)
        # No default `filters` for the embedding: a hard-coded value here (formerly 1024) made the
        # derived fallback in `build_config` unreachable and dwarfed the 2, 4, 8, ... encoder filters.
        config['body/embedding'] += dict(layout='cna cna', kernel_size=3)
        config['body/decoder/order'] = ['upsampling', 'combine', 'block']
        config['body/decoder/blocks'] += dict(layout='cna cna', kernel_size=3)

        config['loss'] = 'ce'
        return config

    def build_config(self):
        """ Fill in any filter counts the user config left unset (see the class docstring). """
        config = super().build_config()

        num_stages = config.get('body/encoder/num_stages')

        if config.get('body/encoder/blocks/filters') is None:
            config['body/encoder/blocks/filters'] = [2 * 2**i for i in range(num_stages)]

        if config.get('body/embedding/filters') is None:
            config['body/embedding/filters'] = 2 * 2**num_stages

        if config.get('body/decoder/blocks/filters') is None:
            # The decoder mirrors the encoder: deepest stage first.
            enc_filters = config.get('body/encoder/blocks/filters')
            config['body/decoder/blocks/filters'] = enc_filters[::-1]

        if config.get('body/decoder/upsample/filters') is None:
            config['body/decoder/upsample/filters'] = config.get('body/decoder/blocks/filters')

        return config

In [None]:
# UNet with a 2-filter initial block and 3 encoder stages; the remaining filter counts
# are derived inside UNet.build_config.
config = {
    'initial_block/filters': 2,
    'body/encoder': {'num_stages': 3},
}

ppl = run(task='segmentation', model_class=UNet, config=config,
          description='unet', batch_size=64, n_iters=1000)

In [None]:
# Training loss of the UNet, one point per training iteration.
plt.plot(ppl.v('loss_history'))
plt.xlabel('iteration')
plt.ylabel('loss')
plt.title('UNet (segmentation): training loss');