In [1]:
import numpy as np
import time, math
from tqdm import tqdm_notebook as tqdm

import tensorflow as tf
import tensorflow.contrib.eager as tfe

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



In [0]:
tf.enable_eager_execution()


In [0]:
BATCH_SIZE = 512 #@param {type:"integer"}
MOMENTUM = 0.9 #@param {type:"number"}
LEARNING_RATE = 0.4 #@param {type:"number"}
WEIGHT_DECAY = 5e-4 #@param {type:"number"}
EPOCHS = 24 #@param {type:"integer"}

In [0]:
def init_pytorch(shape, dtype=tf.float32, partition_info=None):
  fan = np.prod(shape[:-1])
  bound = 1 / math.sqrt(fan)
  return tf.random.uniform(shape, minval=-bound, maxval=bound, dtype=dtype)

In [0]:
class ConvBN(tf.keras.Model):
  def __init__(self, c_out):
    super().__init__()
    self.conv = tf.keras.layers.Conv2D(filters=c_out, kernel_size=3, padding="SAME", kernel_initializer=init_pytorch, use_bias=False)
    self.bn = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

  def call(self, inputs):
    return tf.nn.relu(self.bn(self.conv(inputs)))

In [0]:
class ResBlk(tf.keras.Model):
  def __init__(self, c_out, pool, res = False):
    super().__init__()
    self.conv_bn = ConvBN(c_out)
    self.pool = pool
    self.res = res
    if self.res:
      self.res1 = ConvBN(c_out)
      self.res2 = ConvBN(c_out)

  def call(self, inputs):
    h = self.pool(self.conv_bn(inputs))
    if self.res:
      h = h + self.res2(self.res1(h))
    return h

In [0]:
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = ConvBN(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight

  def call(self, x, y):
    h = self.pool(self.blk3(self.blk2(self.blk1(self.init_conv_bn(x)))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

In [0]:
def get_random_eraser(p=0.5, s_l=0.02, s_h=0.4, r_1=0.3, r_2=1/0.3, v_l=0, v_h=255, pixel_level=False):
    def eraser(input_img):
        
        img_h, img_w, img_c = input_img.shape
        #img_h = input_img.shape[0].value
        #img_w = input_img.shape[1].value
        #img_c = input_img.shape[2].value
        
        
        p_1 = np.random.rand()

        if p_1 > p:
            return input_img

        while True:
            s = np.random.uniform(s_l, s_h) * img_h * img_w
            r = np.random.uniform(r_1, r_2)
            w = int(np.sqrt(s / r))
            h = int(np.sqrt(s * r))
            left = np.random.randint(0, img_w)
            top = np.random.randint(0, img_h)

            if left + w <= img_w and top + h <= img_h:
                break

        if pixel_level:
            c = np.random.uniform(v_l, v_h, (h, w, img_c))
        else:
            c = np.random.uniform(v_l, v_h)

        #print(type(c), type(img), img.shape)
        input_img[top:top + h, left:left + w, :] = c
        

        return input_img

    return eraser

In [0]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
len_train, len_test = len(x_train), len(x_test)
y_train = y_train.astype('int64').reshape(len_train)
y_test = y_test.astype('int64').reshape(len_test)

train_mean = np.mean(x_train, axis=(0,1,2))
train_std = np.std(x_train, axis=(0,1,2))

normalize = lambda x: ((x - train_mean) / train_std).astype('float32') # todo: check here
pad4 = lambda x: np.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)], mode='reflect')

x_train = normalize(pad4(x_train))
x_test = normalize(x_test)

eraser = get_random_eraser(p=0.5, s_l=0.02, s_h=0.4, r_1=0.3, r_2=1/0.3, v_l=0, v_h=255, pixel_level=False)
x_train = np.array(list(map(lambda img:eraser(img), x_train)))


In [0]:
def random_erasing(img, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3):
    '''
    img is a 3-D variable (ex: tf.Variable(image, validate_shape=False) ) and  HWC order
    '''
    # HWC order
    height = tf.shape(img)[0]
    width = tf.shape(img)[1]
    channel = tf.shape(img)[2]
    area = tf.cast(width*height, tf.float32)

    erase_area_low_bound = tf.cast(tf.round(tf.sqrt(sl * area * r1)), tf.int32)
    erase_area_up_bound = tf.cast(tf.round(tf.sqrt((sh * area) / r1)), tf.int32)
    h_upper_bound = tf.minimum(erase_area_up_bound, height)
    w_upper_bound = tf.minimum(erase_area_up_bound, width)

    h = tf.random.uniform([], erase_area_low_bound, h_upper_bound, tf.int32)
    w = tf.random.uniform([], erase_area_low_bound, w_upper_bound, tf.int32)

    x1 = tf.random.uniform([], 0, height+1 - h, tf.int32)
    y1 = tf.random.uniform([], 0, width+1 - w, tf.int32)

    erase_area = tf.cast(tf.random.uniform([h, w, channel], 0, 255, tf.int32), tf.uint8)
    print(type(img), type(erase_area))
    image_var = tf.Variable(img, validate_shape=False)
    init_op = tf.variables_initializer(var_list=[image_var])
    sess.run(init_op)
    
    erasing_img = img[x1:x1+h, y1:y1+w, :].assign(erase_area)
    

    return tf.cond(tf.random.uniform([], 0, 1) > probability, lambda: img, lambda: erasing_img)

In [0]:
#eraser = get_random_eraser(p=0.5, s_l=0.02, s_h=0.4, r_1=0.3, r_2=1/0.3, v_l=0, v_h=255, pixel_level=False)
#print(x_train.shape)
#print(type(x_train))
#x_temp = np.array(list(map(lambda img:eraser(img), x_train)))
#x_temp = [eraser(x_img) for x_img in x_train]
#print(x_temp.shape)
#print(type(x_temp))

In [0]:
model = DavidNet()
batches_per_epoch = len_train//BATCH_SIZE + 1

lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, EPOCHS], [0, LEARNING_RATE, 0])[0]
global_step = tf.train.get_or_create_global_step()
lr_func = lambda: lr_schedule(global_step/batches_per_epoch)/BATCH_SIZE
opt = tf.train.MomentumOptimizer(lr_func, momentum=MOMENTUM, use_nesterov=True)
data_aug = lambda x, y: (tf.image.random_flip_left_right(tf.random_crop(x, [32, 32, 3])), y)
#data_aug = lambda x, y: (random_erasing(x), y)

In [13]:
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 2.298995177001953 train acc: 0.11496 val loss: 2.7651372375488283 val acc: 0.1008 time: 49.39686989784241


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 2.3018826025390626 train acc: 0.13036 val loss: 2.294669427490234 val acc: 0.1013 time: 83.42735147476196


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 2.3426835021972656 train acc: 0.14558 val loss: 3.2855468505859373 val acc: 0.1 time: 117.80059504508972


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 2.22014978515625 train acc: 0.17972 val loss: 2.153628790283203 val acc: 0.1909 time: 152.31337666511536


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 2.0524633203125 train acc: 0.23454 val loss: 2.1266359619140625 val acc: 0.2047 time: 187.06967210769653


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37894736842105264 train loss: 1.9543112060546874 train acc: 0.2786 val loss: 2.5765706298828124 val acc: 0.2133 time: 221.53399300575256


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.35789473684210527 train loss: 1.7458061743164062 train acc: 0.35614 val loss: 1.6498028625488281 val acc: 0.4062 time: 255.88191938400269


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.33684210526315794 train loss: 1.5962169921875 train acc: 0.41748 val loss: 1.4613842895507811 val acc: 0.459 time: 290.4546046257019


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.31578947368421056 train loss: 1.4612428363037109 train acc: 0.46814 val loss: 1.5504996154785156 val acc: 0.4444 time: 325.3446841239929


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.2947368421052632 train loss: 1.3123903344726562 train acc: 0.5236 val loss: 1.0948476501464843 val acc: 0.6082 time: 359.9935030937195


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.2736842105263158 train loss: 1.171117720336914 train acc: 0.57922 val loss: 1.176301205444336 val acc: 0.5946 time: 395.0523056983948


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.25263157894736843 train loss: 1.0373220239257812 train acc: 0.62912 val loss: 1.636059799194336 val acc: 0.4804 time: 429.91463589668274


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.23157894736842108 train loss: 0.9354395068359375 train acc: 0.66738 val loss: 0.9938467254638672 val acc: 0.6491 time: 464.753098487854


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.2105263157894737 train loss: 0.8605134936523438 train acc: 0.69592 val loss: 0.8751170745849609 val acc: 0.6914 time: 499.65594363212585


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.18947368421052635 train loss: 0.8022829052734375 train acc: 0.71592 val loss: 0.9466999145507813 val acc: 0.6752 time: 534.68705534935


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.16842105263157897 train loss: 0.7449658416748047 train acc: 0.73746 val loss: 0.7400355773925781 val acc: 0.7423 time: 569.7117702960968


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.1473684210526316 train loss: 0.6996205096435547 train acc: 0.75396 val loss: 0.6819577346801757 val acc: 0.762 time: 604.8253288269043


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.12631578947368421 train loss: 0.6598103735351563 train acc: 0.76764 val loss: 0.67346650390625 val acc: 0.7695 time: 639.5763328075409


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.10526315789473689 train loss: 0.6226819885253906 train acc: 0.78358 val loss: 0.6209227691650391 val acc: 0.782 time: 673.7531177997589


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.08421052631578951 train loss: 0.5846632553100586 train acc: 0.79632 val loss: 0.7016513610839844 val acc: 0.7576 time: 707.7818737030029


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.06315789473684214 train loss: 0.5558993130493164 train acc: 0.807 val loss: 0.6184125305175782 val acc: 0.7909 time: 741.8124034404755


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.04210526315789476 train loss: 0.5216878295898437 train acc: 0.81948 val loss: 0.5989397735595703 val acc: 0.7957 time: 775.9591763019562


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.02105263157894738 train loss: 0.49099412017822264 train acc: 0.83036 val loss: 0.5292462707519531 val acc: 0.8169 time: 810.153561592102


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.0 train loss: 0.4638265548706055 train acc: 0.84156 val loss: 0.5107321548461914 val acc: 0.8241 time: 844.4459080696106
