In [1]:
from collections import namedtuple

import numpy as np
import tensorflow as tf
import six

from tensorflow.python.training import moving_averages

In [2]:
# Immutable bundle of the hyper-parameters handed to the Resnet model.
# Field order is part of the interface (positional construction).
HParams = namedtuple(
    'HParams',
    [
        'batch_size',
        'num_classes',
        'min_lrn_rate',
        'lrn_rate',
        'num_residual_units',
        'use_bottleneck',
        'weight_decay_rate',
        'relu_leakiness',
        'optimizer',
    ],
)

In [4]:
class Resnet(object):
    def __init__(self, hps, images, labels, mode):
        self.hps = hps
        self._images = images
        self.labels = labels
        self.mode = mode
        
        self._extra_train_ops = []
    
    def build_graph(self):
        self.global_step = tf.contrib.framework.get_or_create_global_step()
        self._build_model()
        if self.mode == 'train':
            self._build_train_op()
        self.summaries = tf.summary.merge_all()
    
    def _stride_arr(self, stride):
        return [1, stride, stride, 1]
    
    def _build_model(self):
        with tf.variable_scope('init'):
            x = self._images
            x = self._conv('init_conv', x, 3, 3, 16, self._stride_arr(1))
        
        strides = [1, 2, 2]
        activate_before_residual = [True, False, False]
        if self.hsp.use_bottleneck:
            res_func = self._bottleneck_residual
            filters = [16, 64, 128, 256]
        else:
            res_func = self._residual
            filters = [16, 16, 32, 64]
        
        with tf.variable_scope('unit_1_0'):
            x = res_func(x, filter[0], filters[1], self._stride_arr(strides[0]),
                        activate_before_residual[0])
        for i in six.moves.range(1, self.hps.num_residual_units):
            with tf.variable_scope('unit_1_%d' % i):
                x = res_func(x, filters[1], filters[1], self._stride_arr(1), False)
        
        with tf.variable_scope('unit_2_0'):
            x = res_func(x, filters[1], filters[2], self._stride_arr(strides[1]),
                        activate_before_residual[1])
        for i in six.moves.range(1, self.hps.num_residual_units):
            with tf.variable_scope('unit_2_%d' % i):
                x = res_func(x, filters[2], filters[2], self._stride_arr(1), False)
        
        with tf.variable_scope('unit_3_0'):
            x = res_func(x, filters[2], filters[3], self._stride_arr(strides[2]),
                        activate_before_residual[2])
        for i in six.moves.range(1, self.hps.num_residual_units):
            with tf.variable_scope('unit_3_%d' % i):
                x = res_func(x, filters[3], filters[3], self._stride_arr(1), False)
        
        with tf.variable_scope('unit_last'):
            x = self._batch_norm('final_bn', x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._global_avg_pool(x)
        
        with tf.variable_scope('logit'):
            logits = self._fully_connected(x, self.hps.num_classes)
            self.predictions = tf.nn.softmax(logits)
        
        with tf.variable_scope('costs'):
            xent = tf.nn.softmax_cross_entropy_with_logits(
            logits=logits, labels=self.labels)
            self.cost = tf.reduce_mean(xent, name='xent')
            self.cost += self._decay()
            tf.summary.scalar('cost', self.cost)
    
    def _build_train_op(self):
        self.lrn_rate = tf.constant(self.hps.lrn_rate, tf.float32)
        tf.summary.scalar('learning_rate', self.lrn_rate)
        
        trainable_variables = tf.trainable_variables()
        grads = tf.gradients(self.cost, trainable_variables)
        
        if self.hps.optimizer == 'sgd':
            optimizer = tf.train.GradientDescentOptimizer(self.lrn_rate)
        elif self.hps.optimizer == 'mom':
            optimizer = tf.train.MomentumOptimizer(self.lrn_rate, 0.9)
        
        apply_op = optimizer.apply_gradients(
        zip(grads, trainable_variables), global_step=self.global_step,
        name='train_step')
        
        train_ops = [apply_op] + self._extra_train_ops
        self.train_op = tf.group(*train_ops)
        
    def _batch_norm(self, name, x):
        with tf.variable_scope(name):
            params_shape = [x.get_shape()[-1]]
            
            beta = tf.get_variable('beta', params_shape, tf.float32,
                                  initializer=tf.constant_initializer(0.0, tf.float32))
            gamma = tf.get_variable('gamma', params_shape, tf.float32,
                                   initializer=tf.constant_initializer(1.0, tf.float32))