In [24]:
import tensorflow as tf
import model_fine as model_lib

In [7]:
#################################################################

class LinearLayer(tf.keras.layers.Layer):
    """Dense layer optionally followed by batch normalization (without ReLU).

    Args:
        num_classes: Output width. Either an int, or a callable that receives
            the input shape at build time and returns the width.
        use_bias: Whether a bias term is learned. When ``use_bn`` is True the
            dense layer itself has no bias and this flag is forwarded to the
            batch-norm layer as ``center`` instead.
        use_bn: Whether to apply ``BatchNormRelu(relu=False)`` after the dense
            layer.
        name: Name assigned to the layer.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, num_classes, use_bias=True, use_bn=False, name='linear_layer', **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.use_bias = use_bias
        self.use_bn = use_bn
        self._name = name
        if callable(self.num_classes):
            # The width depends on the input shape, so the dense layer can only
            # be created in build(). A placeholder Dense(-1) is not an option:
            # Keras validates `units` and rejects negative values.
            self.dense = None
        else:
            self.dense = self._make_dense(self.num_classes)
        if self.use_bn:
            self.bn_relu = BatchNormRelu(relu=False, center=use_bias)

    def _make_dense(self, units):
        """Create the dense sub-layer with ``units`` outputs."""
        return tf.keras.layers.Dense(
            units,
            kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
            # The bias is dropped when batch norm follows the dense layer.
            use_bias=self.use_bias and not self.use_bn)

    def build(self, input_shape):
        """Resolve a shape-dependent width, then finish the Keras build."""
        if self.dense is None:
            self.dense = self._make_dense(self.num_classes(input_shape))
        super().build(input_shape)

    def call(self, inputs, training):
        """Apply the dense layer (and batch norm, if enabled) to rank-2 inputs."""
        assert inputs.shape.ndims == 2, inputs.shape
        outputs = self.dense(inputs)
        if self.use_bn:
            outputs = self.bn_relu(outputs, training=training)
        return outputs

class ProjectionHead(tf.keras.layers.Layer):
    """Three-layer non-linear projection head.

    The first two layers keep the input width (resolved at build time) and use
    a batch-norm bias; the last layer projects to 128 units without a bias.
    ``call`` runs every layer but returns the head *input*, which is what the
    fine-tuning head consumes.
    """

    def __init__(self, **kwargs):
        depth = 3
        stack = []
        for idx in range(depth):
            is_last = idx == depth - 1
            if is_last:
                # Final layer: fixed 128-d output, no bias (and no ReLU in call()).
                layer = LinearLayer(num_classes=128, use_bias=False, use_bn=True, name='nl_%d' % idx)
            else:
                # Middle layers: same width as the input, with bias; ReLU is applied in call().
                layer = LinearLayer(
                    num_classes=lambda input_shape: int(input_shape[-1]),
                    use_bias=True, use_bn=True, name='nl_%d' % idx)
            stack.append(layer)
        self.linear_layers = stack

        super().__init__(**kwargs)

    def call(self, inputs, training):
        """Run the projection layers and return the fine-tune head input."""
        depth = 3
        activations = [tf.identity(inputs, 'proj_head_input')]
        for idx in range(depth):
            projected = self.linear_layers[idx](activations[-1], training)
            if idx != depth - 1:
                # Only the middle layers are followed by a ReLU.
                projected = tf.nn.relu(projected)
            activations.append(projected)

        # Element 0 (the head input) is what the fine-tune head consumes.
        return activations[0]

class SupervisedHead(tf.keras.layers.Layer):
    """Supervised classification head: a single ``LinearLayer`` producing logits."""

    def __init__(self, num_classes, name='head_supervised', **kwargs):
        super().__init__(name=name, **kwargs)
        self.linear_layer = LinearLayer(num_classes)

    def call(self, inputs, training):
        """Return the logits for ``inputs``, tagged with the name 'logits_sup'."""
        logits = self.linear_layer(inputs, training)
        return tf.identity(logits, name='logits_sup')

class Model(tf.keras.models.Model):
    """ResNet-18 backbone with a projection head and an optional supervised head.

    Args:
        num_classes: Number of output classes for the supervised head. When
            ``None`` (the default) no supervised head is created and the
            forward pass returns the features that would feed that head.
        **kwargs: Forwarded to ``tf.keras.models.Model``.
    """

    def __init__(self, num_classes=None, **kwargs):
        super().__init__(**kwargs)
        # Backbone.
        self.resnet_model = resnet(resnet_depth=18, cifar_stem=False)
        self._projection_head = ProjectionHead()
        # The head is always defined (possibly as None) so that __call__ never
        # fails on a missing attribute.
        if num_classes is None:
            self.supervised_head = None
        else:
            self.supervised_head = SupervisedHead(num_classes)

    def __call__(self, inputs, training):
        """Run the forward pass.

        Args:
            inputs: Batch of images fed to the ResNet backbone.
            training: Forwarded to every sub-layer.

        Returns:
            The supervised-head logits when the model was built with
            ``num_classes``; otherwise the supervised-head input features.
        """
        features = inputs
        num_transforms = 1

        # Split channels and stack the transforms along the batch axis:
        # (num_transforms * bsz, h, w, c).
        features_list = tf.split(features, num_or_size_splits=num_transforms, axis=-1)
        features = tf.concat(features_list, 0)

        # Base network forward pass.
        hiddens = self.resnet_model(features, training=training)

        # Heads.
        supervised_head_inputs = self._projection_head(hiddens, training)
        if self.supervised_head is None:
            return supervised_head_inputs

        return self.supervised_head(supervised_head_inputs, training)

In [8]:

##Conv2D function with fixed padding
class Conv2Ds(tf.keras.layers.Layer):
    """Bias-free 2-D convolution that pads explicitly when it is strided.

    With ``strides > 1`` the input is first padded by ``FixedPadding``; the
    convolution uses 'SAME' padding only when ``strides == 1`` and 'VALID'
    otherwise.
    """

    def __init__(self, filters, kernel_size, strides, data_format, **kwargs):
        super().__init__(**kwargs)
        is_strided = strides > 1
        self.fixed_padding = (
            FixedPadding(kernel_size, data_format=data_format) if is_strided else None)

        conv_padding = 'SAME' if strides == 1 else 'VALID'
        self.conv2d = tf.keras.layers.Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=conv_padding,
            use_bias=False,
            kernel_initializer=tf.keras.initializers.VarianceScaling(),
            data_format=data_format)

    def call(self, inputs, training):
        """Pad (if a padding layer was created) and convolve ``inputs``."""
        padded = inputs
        if self.fixed_padding is not None:
            padded = self.fixed_padding(padded, training=training)
        return self.conv2d(padded, training=training)

# Needed so that the input can be added to the shortcut in the bottleneck residual block
class FixedPadding(tf.keras.layers.Layer):
    """Zero-pads the two spatial axes by a total of ``kernel_size - 1``."""

    def __init__(self, kernel_size, data_format='channels_last', **kwargs):
        super().__init__(**kwargs)
        self.kernel_size = kernel_size
        self.data_format = data_format

    def call(self, inputs, training):
        """Return ``inputs`` padded on its spatial axes; ``training`` is unused."""
        total = self.kernel_size - 1
        before = total // 2
        after = total - before

        untouched = [0, 0]
        spatial = [before, after]
        if self.data_format == 'channels_first':
            # Spatial axes are the last two.
            paddings = [untouched, untouched, spatial, spatial]
        else:
            # Spatial axes are the middle two.
            paddings = [untouched, spatial, spatial, untouched]

        return tf.pad(inputs, paddings)

#Apply batch normalization with or without relu activation
class BatchNormRelu(tf.keras.layers.Layer):
    """Synchronized batch normalization optionally followed by ReLU.

    Args:
        relu: Whether to apply ``tf.nn.relu`` after normalization.
        init_zero: If True, gamma is initialized to zeros instead of ones.
        center: Forwarded to the batch-norm layer (learn beta).
        scale: Forwarded to the batch-norm layer (learn gamma).
        data_format: 'channels_first' normalizes axis 1; anything else
            normalizes the last axis.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, relu=True, init_zero=False, center=True, scale=True, data_format='channels_last', **kwargs):
        super().__init__(**kwargs)
        self.activation = relu

        if init_zero:
            gamma_initializer = tf.zeros_initializer()
        else:
            gamma_initializer = tf.ones_initializer()
        bn_axis = 1 if data_format == 'channels_first' else -1

        # The previous `if True: ... else: ...` always took the synchronized
        # branch; the unreachable BatchNormalization branch has been removed.
        self.bn = tf.keras.layers.experimental.SyncBatchNormalization(
            axis=bn_axis, momentum=0.9, epsilon=1.001e-5, center=center,
            scale=scale, gamma_initializer=gamma_initializer)

    def call(self, inputs, training):
        """Normalize ``inputs`` and apply ReLU when enabled."""
        outputs = self.bn(inputs, training=training)
        if self.activation:
            outputs = tf.nn.relu(outputs)
        return outputs


class ResidualBlock(tf.keras.layers.Layer):
    """Residual block of two 3x3 convolutions with a projection shortcut.

    The shortcut is always a 1x1 convolution (with the block's strides)
    followed by batch norm without ReLU. The main branch ends with a batch
    norm whose gamma starts at zero.
    """

    def __init__(self, filters, strides, data_format='channels_last', **kwargs):
        super().__init__(**kwargs)

        # The shortcut layers are created first, then the main branch.
        projection = [
            Conv2Ds(filters=filters, kernel_size=1, strides=strides, data_format=data_format),
            BatchNormRelu(relu=False, data_format=data_format),
        ]
        main_branch = [
            Conv2Ds(filters=filters, kernel_size=3, strides=strides, data_format=data_format),
            BatchNormRelu(data_format=data_format),
            Conv2Ds(filters=filters, kernel_size=3, strides=1, data_format=data_format),
            BatchNormRelu(relu=False, init_zero=True, data_format=data_format),
        ]

        self.conv2d_layers = main_branch
        self.shortcut_layers = projection

    def call(self, inputs, training):
        """Return ``relu(main_branch(inputs) + shortcut(inputs))``."""
        # Projection shortcut to match the filters and strides of the main branch.
        residual = inputs
        for layer in self.shortcut_layers:
            residual = layer(residual, training=training)

        transformed = inputs
        for layer in self.conv2d_layers:
            transformed = layer(transformed, training=training)

        return tf.nn.relu(transformed + residual)

## stack bottleneck layers depending on resnet architecture 18,34,50,101 etc
class BlockGroup(tf.keras.layers.Layer):
    """A stack of blocks built by ``block_fn``.

    The first block receives ``strides``; every further block (up to
    ``blocks`` in total) uses a stride of 1.
    """

    def __init__(self, filters, block_fn, blocks, strides, data_format='channels_last', **kwargs):
        self._name = kwargs.get('name')
        super().__init__(**kwargs)

        stack = [block_fn(filters, strides, data_format=data_format)]
        stack.extend(
            block_fn(filters, 1, data_format=data_format) for _ in range(1, blocks))
        self.layers = stack

    def call(self, inputs, training):
        """Apply the blocks in order and tag the result with the group name."""
        outputs = inputs
        for block in self.layers:
            outputs = block(outputs, training=training)
        return tf.identity(outputs, self._name)


class IdentityLayer(tf.keras.layers.Layer):
    """Pass-through layer: returns ``tf.identity`` of its input."""

    def call(self, inputs, training):
        """Return an identity copy of ``inputs``; ``training`` is unused."""
        passthrough = tf.identity(inputs)
        return passthrough


class Resnet(tf.keras.layers.Layer):
    """ResNet feature extractor: a stem, four block groups and a spatial mean.

    Args:
        block_fn: Callable used by every ``BlockGroup`` to create its blocks.
        layers: Sequence whose first four entries give the number of blocks
            in each group.
        cifar_stem: If True, the stem is a 3x3 stride-1 convolution without
            max pooling; otherwise a 7x7 stride-2 convolution followed by a
            3x3 stride-2 max pool.
        data_format: 'channels_last' or 'channels_first'.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, block_fn, layers, cifar_stem=False, data_format='channels_last', **kwargs):
        super().__init__(**kwargs)
        self.data_format = data_format

        # The two stems differ only in the first convolution and in whether a
        # max pool follows the batch norm.
        stem_kernel, stem_strides = (3, 1) if cifar_stem else (7, 2)
        stem = [
            Conv2Ds(filters=64, kernel_size=stem_kernel, strides=stem_strides,
                    data_format=data_format, trainable=True),
            IdentityLayer(name='initial_conv', trainable=True),
            BatchNormRelu(data_format=data_format, trainable=True),
        ]
        if not cifar_stem:
            stem.append(tf.keras.layers.MaxPooling2D(
                pool_size=3, strides=2, padding='SAME',
                data_format=data_format, trainable=True))
        stem.append(IdentityLayer(name='initial_max_pool', trainable=True))
        self.initial_layers = stem

        # (name, filters, strides) for each of the four groups, in order.
        group_specs = (
            ('block_group1', 64, 1),
            ('block_group2', 128, 2),
            ('block_group3', 256, 2),
            ('block_group4', 512, 2),
        )
        self.block_groups = [
            BlockGroup(filters=width, block_fn=block_fn, blocks=layers[idx],
                       strides=stride, name=group_name,
                       data_format=data_format, trainable=True)
            for idx, (group_name, width, stride) in enumerate(group_specs)
        ]

    def call(self, inputs, training):
        """Run the stem and the block groups, then average over the spatial axes."""
        outputs = inputs
        for layer in [*self.initial_layers, *self.block_groups]:
            outputs = layer(outputs, training=training)

        spatial_axes = [1, 2] if self.data_format == 'channels_last' else [2, 3]
        pooled = tf.reduce_mean(outputs, spatial_axes)

        return tf.identity(pooled, 'final_avg_pool')

def resnet(resnet_depth, cifar_stem=False, data_format='channels_last'):
    """Build a ``Resnet`` of the requested depth.

    Args:
        resnet_depth: Network depth; only 18 is implemented.
        cifar_stem: Forwarded to ``Resnet``.
        data_format: Forwarded to ``Resnet``.

    Returns:
        A ``Resnet`` layer made of ``ResidualBlock`` blocks.

    Raises:
        ValueError: If ``resnet_depth`` is not an implemented depth.
    """
    # Number of blocks in each of the four groups, keyed by depth.
    blocks_per_depth = {18: [2, 2, 2, 2]}

    if resnet_depth not in blocks_per_depth:
        # A single formatted message; passing two arguments to ValueError
        # would render the message as a tuple.
        raise ValueError('Not implemented resnet_depth: %r' % (resnet_depth,))

    # Every implemented depth is built from basic residual blocks.
    return Resnet(ResidualBlock, blocks_per_depth[resnet_depth],
                  cifar_stem=cifar_stem, data_format=data_format)

In [9]:
# Instantiate the network defined above (ResNet-18 backbone plus projection head).
model = Model()

In [5]:
# Checkpoint manager for `model` that targets the 'drf/' directory and keeps one checkpoint.
# NOTE(review): `checkpoint_manager2` is rebound in a later cell with directory='pedro/'.
checkpoint_manager2 = tf.train.CheckpointManager(tf.train.Checkpoint(model=model), directory='drf/', max_to_keep=1)

In [None]:
# Checkpoint.restore() expects a checkpoint prefix (e.g. 'pretrain/ckpt-N'), not
# the directory itself, so resolve the newest checkpoint in 'pretrain/' first.
checkpoint_manager2.checkpoint.restore(tf.train.latest_checkpoint('pretrain/'))

In [10]:
# Checkpoint object with `model` passed positionally; it backs the 'pretrain/' manager below.
# NOTE(review): the restore cell further down wraps the model as `Checkpoint(model=model)`
# instead — confirm which layout the saved checkpoint actually uses.
checkpoint = tf.train.Checkpoint(model)

In [25]:
# Running means for the training scalars, collected in `all_metrics`.
weight_decay_metric = tf.keras.metrics.Mean('train/weight_decay')
total_loss_metric = tf.keras.metrics.Mean('train/total_loss')
supervised_loss_metric = tf.keras.metrics.Mean('train/supervised_loss')
supervised_acc_metric = tf.keras.metrics.Mean('train/supervised_acc')
all_metrics = [
    weight_decay_metric,
    total_loss_metric,
    supervised_loss_metric,
    supervised_acc_metric,
]

In [26]:
# Point a manager at the 'pretrain/' directory and look up the newest checkpoint
# recorded there; `latest_ckpt` is what the restore cell below consumes.
checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory='pretrain/',
      max_to_keep=1)
latest_ckpt = checkpoint_manager.latest_checkpoint

In [20]:
# Display the resolved checkpoint path.
latest_ckpt

'pretrain/ckpt-13635'

In [30]:
# Wrap `model` under the checkpoint attribute name `model` (manager directory: 'pedro/')
# and restore it from `latest_ckpt`.
# NOTE(review): the load status returned by restore() is only displayed, never checked.
checkpoint_manager2 = tf.train.CheckpointManager(tf.train.Checkpoint(model=model), directory='pedro/', max_to_keep=1)
checkpoint_manager2.checkpoint.restore(latest_ckpt)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f9b0c216b10>

In [31]:
# Rebind `model` to the object held by the checkpoint (attached as `model=model` when it was created).
model = checkpoint_manager2.checkpoint.model

In [32]:
# Display the model object.
model

<__main__.Model at 0x7f9b2c8ead50>