In [0]:
import numpy as np
import time, math
from tqdm import tqdm_notebook as tqdm

import tensorflow as tf
import tensorflow.contrib.eager as tfe

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



In [0]:
tf.enable_eager_execution()


In [0]:
BATCH_SIZE = 512 #@param {type:"integer"}
MOMENTUM = 0.9 #@param {type:"number"}
LEARNING_RATE = 0.4 #@param {type:"number"}
WEIGHT_DECAY = 5e-4 #@param {type:"number"}
EPOCHS = 24 #@param {type:"integer"}

In [0]:
def init_pytorch(shape, dtype=tf.float32, partition_info=None):
  fan = np.prod(shape[:-1])
  bound = 1 / math.sqrt(fan)
  return tf.random.uniform(shape, minval=-bound, maxval=bound, dtype=dtype)

In [0]:
class ConvBN(tf.keras.Model):
  def __init__(self, c_out):
    super().__init__()
    self.conv = tf.keras.layers.Conv2D(filters=c_out, kernel_size=3, padding="SAME", kernel_initializer=init_pytorch, use_bias=False)
    self.bn = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

  def call(self, inputs):
    return tf.nn.relu(self.bn(self.conv(inputs)))

In [0]:
class ResBlk(tf.keras.Model):
  def __init__(self, c_out, pool, res = False):
    super().__init__()
    self.conv_bn = ConvBN(c_out)
    self.pool = pool
    self.res = res
    if self.res:
      self.res1 = ConvBN(c_out)
      self.res2 = ConvBN(c_out)

  def call(self, inputs):
    h = self.pool(self.conv_bn(inputs))
    if self.res:
      h = h + self.res2(self.res1(h))
    return h

In [0]:
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = ConvBN(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight

  def call(self, x, y):
    h = self.pool(self.blk3(self.blk2(self.blk1(self.init_conv_bn(x)))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

In [0]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
len_train, len_test = len(x_train), len(x_test)
y_train = y_train.astype('int64').reshape(len_train)
y_test = y_test.astype('int64').reshape(len_test)

train_mean = np.mean(x_train, axis=(0,1,2))
train_std = np.std(x_train, axis=(0,1,2))

normalize = lambda x: ((x - train_mean) / train_std).astype('float32') # todo: check here
pad4 = lambda x: np.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)], mode='reflect')

x_train = normalize(pad4(x_train))
x_test = normalize(x_test)

In [0]:
print(y_train.shape)

(50000,)


In [0]:
def byte_to_tf_feature(value):
  """Returns a bytes_list from a string / byte."""
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def float_to_tf_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def int64_to_tf_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

In [0]:
import os
from os import listdir
from os.path import join
 
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.keras.preprocessing.image import img_to_array, load_img
 
 
def save_tf_records(x_train, y_train, out_path):
    writer = tf.python_io.TFRecordWriter(out_path)
 
    for i in range(y_train.shape[0]):
   
        example = tf.train.Example(features=tf.train.Features(
            feature={'image': byte_to_tf_feature(x_train[i].tostring()),
                     'labels': int64_to_tf_feature(
                         y_train[i])
                     }))
 
        writer.write(example.SerializeToString())
 
    writer.close()

In [0]:
save_tf_records(x_train, y_train, './training.tf_records')

In [0]:
def load_tf_records(path):
    dataset = tf.data.TFRecordDataset(path)
 
    def parser(record):
        featdef = {
            'image': tf.FixedLenFeature(shape=[], dtype=tf.string),
            'labels': tf.FixedLenFeature(shape=[], dtype=tf.int64),
        }
 
        example = tf.parse_single_example(record, featdef)
        im = tf.decode_raw(example['image'], tf.float32)
        im = tf.reshape(im, ( 40, 40, 3))
        #im = tf.reshape(im, ( 40, 40, 3))
        lbl = tf.cast(example['labels'], tf.int64)
        #lbl = tf.decode_raw(example['labels'], tf.int64)
        #label = tf.cast(features['label'], tf.int32)
        #lbl = tf.reshape(lbl, [1])
        return im, lbl
 
    dataset = dataset.map(parser)
    #dataset = dataset.shuffle(buffer_size=50000)
    #dataset = dataset.batch(50000)
    #dataset = dataset.repeat(1)
    #iterator = dataset.make_one_shot_iterator()
    #return iterator.get_next()
    return dataset


    
    ##def parser(x):
      #  labels = tf.map_fn(lambda y: tf.cast(y, tf.int64), x['label'])
       # imgs = tf.map_fn(lambda y:tf.cast(tf.io.decode_raw(y, tf.uint8),tf.float32), x['image'], dtype=tf.float32)
        #imgs = tf.map_fn(lambda im: tf.reshape(im, shape=[32, 32, 3]),imgs, dtype=tf.float32)
        #return imgs,labels

In [0]:
ttt = load_tf_records('./training.tf_records')
ttt

<DatasetV1Adapter shapes: ((40, 40, 3), ()), types: (tf.float32, tf.int64)>

In [0]:
import sys

In [0]:
def replace_slice(input_: tf.Tensor, replacement, begin) -> tf.Tensor:
    inp_shape = tf.shape(input_)
    #print(inp_shape)
    size = tf.shape(replacement)
    #print(size)
    padding = tf.stack([begin, inp_shape - (begin + size)], axis=1)
    replacement_pad = tf.pad(replacement, padding)
    mask = tf.pad(tf.ones_like(replacement, dtype=tf.bool), padding)
    return tf.where(mask, replacement_pad, input_)

In [0]:
def get_cutout_eraser(minimum, maximum, area: int = 81, c: int = 3, min_aspect_ratio=0.5, max_aspect_ratio=2.0):
    sqrt_area = np.sqrt(area)

    def get_h_w(aspect_ratio):
        h = sqrt_area / aspect_ratio
        w = tf.math.round(area / h)
        h = tf.math.round(h)
        h = tf.cast(h, tf.int32)
        w = tf.cast(w, tf.int32)
        return h, w

    def tf_cutout(x: tf.Tensor) -> tf.Tensor:
        """
        Cutout data augmentation. Randomly cuts a h by w whole in the image, and fill the whole with zeros.
        :param x: Input image.
        :param h: Height of the hole.
        :param w: Width of the hole
        :param c: Number of color channels in the image. Default: 3 (RGB).
        :return: Transformed image.
        """
        dtype = x.dtype
        minval = tf.cast(minimum, dtype=dtype)
        maxval = tf.cast(maximum, dtype=dtype)

        #tf.print(minval, maxval, output_stream=sys.stdout, sep=',')
        #tf.print(min_aspect_ratio, max_aspect_ratio, output_stream=sys.stdout, sep=',')


        aspect_ratio = tf.random.uniform([], min_aspect_ratio, max_aspect_ratio)
        h, w = get_h_w(min_aspect_ratio)
        
        #tf.print(h, w, output_stream=sys.stdout, sep=',')
        shape = tf.shape(x)
        #tf.print(shape, output_stream=sys.stdout)

        
        x0 = tf.random.uniform([], 0, shape[0] + 1 - h, dtype=tf.int32)
        y0 = tf.random.uniform([], 0, shape[1] + 1 - w, dtype=tf.int32)

        slic = tf.random.uniform([h, w, c], minval=0, maxval=255, dtype=dtype)
        x = replace_slice(x, slic, [x0, y0, 0])
        return x

    return tf_cutout

In [0]:
cutout = get_cutout_eraser(minimum=0,maximum=255)

In [0]:
model = DavidNet()
batches_per_epoch = len_train//BATCH_SIZE + 1
last_batch_size = len_train % BATCH_SIZE

lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, EPOCHS], [0, LEARNING_RATE, 0])[0]
global_step = tf.train.get_or_create_global_step()
lr_func = lambda: lr_schedule(global_step/batches_per_epoch)/BATCH_SIZE
opt = tf.train.MomentumOptimizer(lr_func, momentum=MOMENTUM, use_nesterov=True)
data_aug = lambda x, y: (tf.image.random_flip_left_right(tf.random_crop(x, [32, 32, 3])), y)
data_aug2 = lambda x, y: (cutout(tf.random_crop(x, [32, 32, 3])), y)

In [0]:
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  #train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)
  train_set = load_tf_records('./training.tf_records')
  train_set = train_set.map(data_aug2).shuffle(len_train).batch(BATCH_SIZE)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)




Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 2.314319249267578 train acc: 0.10264 val loss: 2.3034358520507814 val acc: 0.1157 time: 55.94857358932495


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 2.067728614501953 train acc: 0.22232 val loss: 2.5639319702148438 val acc: 0.1188 time: 96.70784664154053


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 1.7579870092773437 train acc: 0.34922 val loss: 1.6767996337890625 val acc: 0.3836 time: 137.67745971679688


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 1.506467509765625 train acc: 0.44768 val loss: 1.5531252410888672 val acc: 0.4354 time: 178.96970534324646


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 1.2960021771240235 train acc: 0.53352 val loss: 2.031077178955078 val acc: 0.3875 time: 221.02517986297607


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37894736842105264 train loss: 1.0538534197998046 train acc: 0.62026 val loss: 1.1432040893554687 val acc: 0.5925 time: 263.254741191864


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.35789473684210527 train loss: 0.8582094927978515 train acc: 0.69376 val loss: 0.9455450073242188 val acc: 0.6777 time: 305.20783495903015


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.33684210526315794 train loss: 0.7358981036376954 train acc: 0.73912 val loss: 1.0973429077148438 val acc: 0.6397 time: 346.90628576278687


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.31578947368421056 train loss: 0.6359951028442383 train acc: 0.77536 val loss: 1.525308563232422 val acc: 0.517 time: 388.886515378952


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.2947368421052632 train loss: 0.5621119729614258 train acc: 0.80384 val loss: 0.8021380920410156 val acc: 0.7366 time: 430.02142810821533


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.2736842105263158 train loss: 0.4906763226318359 train acc: 0.8296 val loss: 2.593282696533203 val acc: 0.4024 time: 470.61569476127625


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.25263157894736843 train loss: 0.43897034881591795 train acc: 0.84666 val loss: 0.7444220535278321 val acc: 0.7526 time: 511.27403807640076


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.23157894736842108 train loss: 0.39132661865234375 train acc: 0.86378 val loss: 0.6211820495605469 val acc: 0.7912 time: 553.4114372730255


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.2105263157894737 train loss: 0.3465982339477539 train acc: 0.87974 val loss: 0.8035955657958984 val acc: 0.7238 time: 594.1978523731232


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.18947368421052635 train loss: 0.30956333572387695 train acc: 0.89166 val loss: 0.6951574462890625 val acc: 0.7788 time: 634.8443999290466


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.16842105263157897 train loss: 0.2705331477355957 train acc: 0.9067 val loss: 0.44940441513061524 val acc: 0.8486 time: 675.4265942573547


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.1473684210526316 train loss: 0.23157508697509765 train acc: 0.91976 val loss: 0.5927290725708008 val acc: 0.8129 time: 715.7538735866547


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.12631578947368421 train loss: 0.20516752868652344 train acc: 0.92874 val loss: 0.4085129730224609 val acc: 0.8672 time: 756.3794734477997


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.10526315789473689 train loss: 0.17865859603881837 train acc: 0.93856 val loss: 0.46414509201049803 val acc: 0.8543 time: 796.6763145923615


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.08421052631578951 train loss: 0.15116089767456053 train acc: 0.94984 val loss: 0.43306814270019534 val acc: 0.8632 time: 836.9648406505585


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.06315789473684214 train loss: 0.1325440948486328 train acc: 0.95742 val loss: 0.4309610954284668 val acc: 0.867 time: 877.2372715473175


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.04210526315789476 train loss: 0.11273083652496338 train acc: 0.964 val loss: 0.37505027160644533 val acc: 0.8807 time: 917.6604516506195


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.02105263157894738 train loss: 0.1011423365020752 train acc: 0.96818 val loss: 0.35858747329711915 val acc: 0.8892 time: 957.9918975830078


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.0 train loss: 0.08845540225982666 train acc: 0.9745 val loss: 0.34903181533813477 val acc: 0.8907 time: 998.6231739521027
