In [1]:
%load_ext autoreload
%autoreload 2

import contextlib
import sys
import threading
import warnings
from pprint import pprint

import tensorflow as tf
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
import numpy as np
from batchflow.opensets import MNIST
from batchflow.models.tf import TFModel
from batchflow import Pipeline, L, F, V, D, B, DatasetIndex, Dataset, ImagesBatch, Config


In [2]:
# Create the MNIST dataset from batchflow.opensets, with ImagesBatch as its batch class.
# The recorded output below shows the archives being downloaded and extracted under /tmp.
mnist = MNIST(batch_class=ImagesBatch)

DownloadingDownloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
 http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
ExtractingExtracting /tmp/train-labels-idx1-ubyte.gz
 /tmp/train-images-idx3-ubyte.gz
Extracting /tmp/t10k-labels-idx1-ubyte.gz
Extracting /tmp/t10k-images-idx3-ubyte.gz


In [7]:
# Exponential learning-rate decay settings; reused by the top-level config and by
# the 'all' and 'head' train modes below.
decay_d = ('exp', {'learning_rate': 0.001,
                   'decay_steps': 150,
                   'decay_rate': 0.96})

# Configuration handed to the model through `init_model(..., config=model_config)`.
# NOTE(review): the meaning of the 'layout' letters and of the other block options is
# defined by batchflow and is not visible in this notebook -- confirm against its docs.
model_config = {'inputs': dict(images={'shape': (28, 28, 1)},
                               labels={'classes': 10}),
                'initial_block': {'layout': 'cna'*2,
                                  'filters': [6]*2, 'kernel_size': [3]*2,
                                  'inputs': 'images'},
                'body': {'layout': 'nca nca',
                         'filters': [8, 16],
                         'kernel_size': [3, 3]},
                'head': {'layout': 'Pfa'},
                'loss': 'ce',
                # Top-level optimizer/decay. In ModifiedTFModel.build (defined later in this
                # notebook) these become the 'default' train mode and the fallback for
                # train modes that omit them.
                'optimizer': 'Adam',
                'decay': decay_d,
                # Named train modes: each picks an optimizer, optionally a decay and a
                # variable scope to optimise.
                # NOTE(review): the leading '-' in the 'custom' scope presumably means
                # "everything except this scope" -- confirm against the batchflow docs.
                'train_modes': {'all': {'optimizer': 'RMSProp', 'decay': decay_d},
                                'body': {'optimizer': 'Adam', 'scope': 'body'},
                                'head': {'optimizer': 'Adagrad', 'scope': 'head', 'decay': decay_d},
                                'custom': {'optimizer': 'Adam', 'scope': '-initial_block/layer-0'}},
                'head/units': 10
}

# Maps the model's input names to the batch components that feed them.
data_dict = {'images': B('images'),
             'labels': B('labels')}

# Training pipeline over the MNIST train split: scale images by 1/255, create the
# 'conv' model and train it in the 'head' train mode, saving the fetched predictions
# into the 'predictions' pipeline variable.
# NOTE(review): the model class used here is the library TFModel, not the
# ModifiedTFModel defined later in this notebook -- confirm that is intended.
# NOTE(review): to_array() appears twice (before multiply and after init_model);
# the second call looks redundant -- confirm.
train_pipeline = (mnist.train.p
                 .init_variable('predictions')
                 .to_array()
                 .multiply(multiplier=1/255., preserve_type=False)
                 .init_model('dynamic', TFModel, 'conv', config=model_config)
                 .to_array()
                 .train_model('conv', fetches='predictions', feed_dict=data_dict, 
                              train_mode='head', save_to=V('predictions'))
)

In [8]:
%%time
# Request the next batch (first argument 256) from the training pipeline and keep it in `n_b`.
# The recorded output below is the model-building log printed while this call runs.
n_b = train_pipeline.next_batch(256, n_epochs=None)




CURRENT KEY:  all
SUBCONFIG:  {'optimizer': 'RMSProp', 'decay': ('exp', {'learning_rate': 0.001, 'decay_steps': 150, 'decay_rate': 0.96}), 'scope': ''}
OPTIMIZER_:  <tensorflow.python.training.rmsprop.RMSPropOptimizer object at 0x7f3f98574a90> 

SCOPE COLLECTION
<tf.Variable 'TFModel/initial_block/layer-0/conv2d/kernel:0' shape=(3, 3, 1, 6) dtype=float32_ref>
<tf.Variable 'TFModel/initial_block/layer-0/conv2d/bias:0' shape=(6,) dtype=float32_ref>
<tf.Variable 'TFModel/initial_block/layer-1/batch_normalization/gamma:0' shape=(6,) dtype=float32_ref>
<tf.Variable 'TFModel/initial_block/layer-1/batch_normalization/beta:0' shape=(6,) dtype=float32_ref>
<tf.Variable 'TFModel/initial_block/layer-3/conv2d/kernel:0' shape=(3, 3, 6, 6) dtype=float32_ref>
<tf.Variable 'TFModel/initial_block/layer-3/conv2d/bias:0' shape=(6,) dtype=float32_ref>
<tf.Variable 'TFModel/initial_block/layer-4/batch_normalization/gamma:0' shape=(6,) dtype=float32_ref>
<tf.Variable 'TFModel/initial_block/layer-4/batch_

In [None]:
# Fetch the model registered as 'conv' in the training pipeline for direct inspection.
model = train_pipeline.get_model_by_name('conv')

In [None]:
train_pipeline = (mnist.test.p
                 .import_model('conv', train_pipeline)
                 .to_array()
                 .multiply(multiplier=1/255., preserve_type=False)
                 .init_model('dynamic', TFModel, 'conv', config=model_config)
                 .to_array()
                 .init_variable('result', init_on_each_run=list()) 
                 .predict_model('conv', fetches=['predictions', 'TFModel/body/Relu_1'], feed_dict=data_dict,
                                save_to=V('result'), mode='a')
)

In [None]:
batch = train_pipeline.next_batch(10, n_epochs=None)

In [None]:
train_pipeline.get_variable('result')[0][1].shape

In [None]:
# List every operation defined in the model's graph.
# NOTE(review): the full list can be very long; consider slicing it before sharing the notebook.
model.graph.get_operations()

In [None]:
# Query the model's own graph rather than TensorFlow's default graph: the model is
# inspected through `model.graph` elsewhere in this notebook, and `tf.get_collection`
# only looks at the default graph, which does not contain the model's variables when
# the model was built in a separate tf.Graph.
model.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

In [None]:
class ModifiedTFModel(TFModel):
    
    def __init__(self, *args, **kwargs):
        """Initialise empty model state and delegate to the parent constructor.

        If a ``session`` keyword argument is given, it is stored and its graph is
        reused; otherwise a new ``tf.Graph`` is created. Every positional and keyword
        argument is forwarded unchanged to the parent ``__init__``.
        """
        given_session = kwargs.get('session')
        if given_session is None:
            model_graph = tf.Graph()
        else:
            model_graph = given_session.graph
        self.session = given_session
        self.graph = model_graph

        # Tensors, ops and helpers that are filled in later (see `build`).
        self._graph_context = None
        self.is_training = self.global_step = self.loss = None
        # Dict of training ops keyed by train mode; None until `build` creates them.
        self.train_steps = None
        self._saver = None
        self.inputs = None

        # Lock that `train` acquires when called with use_lock=True.
        self._train_lock = threading.Lock()

        # Fresh bookkeeping containers for this instance.
        self._attrs = list()
        self._to_classes = dict()
        self._inputs = dict()

        super().__init__(*args, **kwargs)

    def build(self, *args, **kwargs):
        """Build the model graph and create one training op per train mode.

        Inside the model's graph (and on ``self.config['device']`` when that key is
        present) this method:

        1. creates the ``is_training`` placeholder and the ``global_step`` variable
           unless they are already set;
        2. builds the network with ``self._build(config)``;
        3. if ``train_steps`` is not set yet, creates the loss and, for every entry of
           ``config['train_modes']`` plus a ``'default'`` entry made from the top-level
           ``optimizer`` / ``decay`` / ``scope``, an op that minimises the loss over
           the trainable variables selected by the entry's ``scope``;
        4. creates a session and calls ``reset()`` when no session was supplied.

        A train mode that omits ``optimizer``, ``decay`` or ``scope`` inherits the
        top-level value. Scopes are relative to the class-name variable scope:
        ``''`` selects every trainable variable under it, and a leading ``'-'``
        selects everything except the named scope.

        The resulting ``{mode_name: train_op}`` dict is stored as ``train_steps``.
        ``*args`` and ``**kwargs`` are accepted but not used here.
        """
        def _device_context():
            # Pin ops to the configured device, or use a no-op context when none is set.
            if 'device' in self.config:
                return self.graph.device(self.config.get('device'))
            return contextlib.ExitStack()

        def _trainable_variables(scope):
            # Trainable variables selected by `scope` (relative to the class-name scope).
            root = self.__class__.__name__ + '/'
            scope = scope or ''
            if not scope.startswith('-'):
                return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, root + scope)
            # Complement: everything under the root except the named scope. Compare by
            # name so the result does not depend on how variables implement equality.
            excluded = {var.name
                        for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, root + scope[1:])}
            return [var for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, root)
                    if var.name not in excluded]

        with self.graph.as_default(), _device_context():
            with tf.variable_scope(self.__class__.__name__):
                with tf.variable_scope('globals'):
                    if self.is_training is None:
                        self.store_to_attr('is_training', tf.placeholder(tf.bool, name='is_training'))
                    if self.global_step is None:
                        self.store_to_attr('global_step', tf.Variable(0, trainable=False, name='global_step'))

                config = self.build_config()
                self._build(config)
                if self.train_steps is None:
                    self._make_loss(config)
                    self.store_to_attr('loss', tf.losses.get_total_loss())

                    if config.get('train_modes') is None:
                        config['train_modes'] = {}

                    _decay = config.get('decay')
                    _optimizer = config.get('optimizer')
                    _scope = config.get('scope')
                    # The top-level settings form the 'default' train mode.
                    if _optimizer is not None:
                        config['train_modes'].update({'default': {'optimizer': _optimizer,
                                                                  'decay': _decay,
                                                                  'scope': _scope}})

                    # Fill the gaps of every train mode with the top-level settings.
                    for key, subconfig in config.get('train_modes').items():
                        if subconfig.get('optimizer') is None:
                            subconfig.update({'optimizer': _optimizer})
                        if subconfig.get('decay') is None:
                            subconfig.update({'decay': _decay})
                        if subconfig.get('scope') is None:
                            subconfig.update({'scope': _scope})

                    train_steps = {}
                    for key, subconfig in config['train_modes'].items():
                        optimizer_ = self._make_optimizer(subconfig)
                        if optimizer_:
                            # Run the ops collected in UPDATE_OPS before each training step.
                            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                            with tf.control_dependencies(update_ops):
                                train_steps[key] = optimizer_.minimize(
                                    self.loss,
                                    global_step=self.global_step,
                                    var_list=_trainable_variables(subconfig.get('scope')))

                    self.store_to_attr('train_steps', train_steps)
                else:
                    # Train steps already exist: register them under the same attribute
                    # name that the branch above uses.
                    self.store_to_attr('train_steps', self.train_steps)

            if self.session is None:
                self.create_session(config)
                self.reset()
                
    @classmethod
    def default_config(cls):
        """Return a new ``Config`` populated with the model's default settings.

        The defaults are: empty ``inputs``, ``initial_block``, ``body`` and ``head``
        sections; no ``predictions`` or ``output``; the ``('Adam', {})`` optimizer;
        no decay; an empty ``scope``; and a batch-norm momentum of 0.1 under
        ``common``.
        """
        # Built inside the method so every call gets fresh mutable values.
        defaults = (
            ('inputs', {}),
            ('initial_block', {}),
            ('body', {}),
            ('head', {}),
            ('predictions', None),
            ('output', None),
            ('optimizer', ('Adam', dict())),
            ('decay', (None, dict())),
            ('scope', ''),
            ('common', {'batch_norm': {'momentum': .1}}),
        )

        config = Config()
        for option, value in defaults:
            config[option] = value
        return config
                
                
    def train(self, fetches=None, feed_dict=None, use_lock=False, train_mode='default', **kwargs):
        """Run one training step in the given train mode and return the fetched values.

        Parameters
        ----------
        fetches
            What to fetch besides the training op; resolved with ``self._fill_fetches``.
            ``None`` fetches nothing extra.
        feed_dict : dict, optional
            Values to feed; merged with ``**kwargs`` (keyword arguments win) and
            resolved with ``self._fill_feed_dict(..., is_training=True)``.
        use_lock : bool
            If True, hold ``self._train_lock`` while the session runs.
        train_mode : str
            Key into ``self.train_steps`` selecting the training op to run.
        **kwargs
            Additional feed values.

        Returns
        -------
        The result of ``self._fill_output`` applied to the fetched values.

        Raises
        ------
        KeyError
            If train steps exist but ``train_mode`` is not one of them.
        """
        with self.graph.as_default():
            feed_dict = {} if feed_dict is None else feed_dict
            feed_dict = {**feed_dict, **kwargs}
            _feed_dict = self._fill_feed_dict(feed_dict, is_training=True)
            if fetches is None:
                _fetches = tuple()
            else:
                _fetches = self._fill_fetches(fetches, default=None)

            _all_fetches = []
            if self.train_steps:
                if train_mode not in self.train_steps:
                    raise KeyError("Unknown train_mode %r; available train modes: %s"
                                   % (train_mode, list(self.train_steps)))
                _all_fetches.append(self.train_steps[train_mode])
            if _fetches is not None:
                _all_fetches.append(_fetches)

            output = None
            if use_lock:
                self._train_lock.acquire()
            try:
                if _all_fetches:
                    results = self.session.run(_all_fetches, feed_dict=_feed_dict)
                    # User fetches, when present, are always the last requested item;
                    # this also works when there is no training op in front of them.
                    if _fetches is not None:
                        output = results[-1]
            finally:
                # Release the lock even if session.run raises, so later calls do not block.
                if use_lock:
                    self._train_lock.release()

            return self._fill_output(output, _fetches)