In [1]:
# Import requires libraries

import numpy as np
import time, math
from tqdm import tqdm_notebook as tqdm

import tensorflow as tf
import tensorflow.contrib.eager as tfe

from PIL import Image
import numpy as np
import skimage.io as io

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



In [0]:
# enable TF eager mode
tf.enable_eager_execution()

In [0]:
# set training parameters

BATCH_SIZE = 512 #@param {type:"integer"}
MOMENTUM = 0.9 #@param {type:"number"}
LEARNING_RATE = 0.4 #@param {type:"number"}
WEIGHT_DECAY = 5e-4 #@param {type:"number"}
EPOCHS = 24 #@param {type:"integer"}

In [0]:
def init_pytorch(shape, dtype=tf.float32, partition_info=None):
  fan = np.prod(shape[:-1])
  bound = 1 / math.sqrt(fan)
  return tf.random.uniform(shape, minval=-bound, maxval=bound, dtype=dtype)

In [0]:
# Defining the Convlution Block.

class ConvBN(tf.keras.Model):
  def __init__(self, c_out):
    super().__init__()
    self.conv = tf.keras.layers.Conv2D(filters=c_out, kernel_size=3, padding="SAME", kernel_initializer=init_pytorch, use_bias=False)
    self.bn = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

  def call(self, inputs):
    return tf.nn.relu(self.bn(self.conv(inputs)))

In [0]:
# Defining the Residual Block.

class ResBlk(tf.keras.Model):
  def __init__(self, c_out, pool, res = False):
    super().__init__()
    self.conv_bn = ConvBN(c_out)
    self.pool = pool
    self.res = res
    if self.res:
      self.res1 = ConvBN(c_out)
      self.res2 = ConvBN(c_out)

  def call(self, inputs):
    h = self.pool(self.conv_bn(inputs))
    if self.res:
      h = h + self.res2(self.res1(h))
    return h

In [0]:
# Defing the David Net Model.

class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = ConvBN(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight

  def call(self, x, y):
    h = self.pool(self.blk3(self.blk2(self.blk1(self.init_conv_bn(x)))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

In [8]:
# Loading the Dataset and then Normalizing it.

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
len_train, len_test = len(x_train), len(x_test)
y_train = y_train.astype('int64').reshape(len_train)
y_test = y_test.astype('int64').reshape(len_test)

train_mean = np.mean(x_train, axis=(0,1,2))
train_std = np.std(x_train, axis=(0,1,2))

normalize = lambda x: ((x - train_mean) / train_std).astype('float32') # todo: check here
pad4 = lambda x: np.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)], mode='reflect')

x_train = normalize(pad4(x_train))
x_test = normalize(x_test)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [0]:
# Defining the Optimization Strategies.

model = DavidNet()
batches_per_epoch = len_train//BATCH_SIZE + 1

lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, EPOCHS], [0, LEARNING_RATE, 0])[0]
global_step = tf.train.get_or_create_global_step()
lr_func = lambda: lr_schedule(global_step/batches_per_epoch)/BATCH_SIZE
opt = tf.train.MomentumOptimizer(lr_func, momentum=MOMENTUM, use_nesterov=True)
data_aug = lambda x, y: (tf.image.random_flip_left_right(tf.random_crop(x, [32, 32, 3])), y)

In [10]:
############Define parameters###########################
IMAGE_HEIGHT = 40
IMAGE_WIDTH = 40
IMAGE_DEPTH = 3
NUM_CLASSES = 10

############## ENCODE ####################################
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

#################

def _convert_to_tfrecord(data, labels, tfrecords_filename):
  """Converts a file to TFRecords."""
  print('Generating %s' % tfrecords_filename)
  with tf.python_io.TFRecordWriter(tfrecords_filename) as record_writer:
    num_entries_in_batch = len(labels)
    for i in range(num_entries_in_batch):
      example = tf.train.Example(features=tf.train.Features(
        feature={
          'image': _bytes_feature(data[i].tobytes()),
          'label': _int64_feature(labels[i])
        }))
      record_writer.write(example.SerializeToString())

#################

train_tfrecords_filename = 'TrainCifar10.tfrecords'
test_tfrecords_filename = 'TestCifar10.tfrecords'

_convert_to_tfrecord(x_train, y_train,train_tfrecords_filename)
_convert_to_tfrecord(x_test,y_test,test_tfrecords_filename)

##################################################

Generating TrainCifar10.tfrecords
Generating TestCifar10.tfrecords


In [0]:
####################### DECODE #####################################

def parse_record(serialized_example, isTraining = True):
  features = tf.parse_single_example(
    serialized_example,
    features={
      'image': tf.FixedLenFeature([], tf.string),
      'label': tf.FixedLenFeature([], tf.int64),
    })
  
  image = features['image']
  image = tf.decode_raw(image, tf.float32)

  if(isTraining):
    image = tf.reshape(image, [IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH])
  else:
    image = tf.reshape(image, [32, 32, IMAGE_DEPTH])
  
  label = tf.cast(features['label'], tf.int64)
  return image, label

def generate_input_fn(file_name, isTraining = True):
  dataset = tf.data.TFRecordDataset(filenames=file_name)
  dataset = dataset.map(lambda x: parse_record(x, isTraining))
  return dataset  
########################################################################

In [12]:
t = time.time()
test_set = generate_input_fn(test_tfrecords_filename, isTraining=False).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = generate_input_fn(train_tfrecords_filename).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)
  tf.keras.backend.set_learning_phase(1)

  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)






HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 1.5700133514404297 train acc: 0.4335 val loss: 1.2944655334472657 val acc: 0.5474 time: 59.146493434906006


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 0.8266529840087891 train acc: 0.7055 val loss: 0.7616413375854493 val acc: 0.7345 time: 96.29093360900879


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 0.6246615435791015 train acc: 0.782 val loss: 0.6903264373779296 val acc: 0.766 time: 132.88453102111816


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 0.53742564453125 train acc: 0.81488 val loss: 0.6760789138793946 val acc: 0.7832 time: 169.66940712928772


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 0.47151458251953127 train acc: 0.83776 val loss: 0.6483724807739257 val acc: 0.7883 time: 206.57940912246704


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 6 lr: 0.37894736842105264 train loss: 0.3824378744506836 train acc: 0.86886 val loss: 0.43738893661499023 val acc: 0.8513 time: 243.26254510879517


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 7 lr: 0.35789473684210527 train loss: 0.31419364349365236 train acc: 0.8919 val loss: 0.44823578186035157 val acc: 0.853 time: 279.88696336746216


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 8 lr: 0.33684210526315794 train loss: 0.2672888885498047 train acc: 0.90796 val loss: 0.5612103988647461 val acc: 0.8305 time: 316.3540744781494


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 9 lr: 0.31578947368421056 train loss: 0.23177601989746094 train acc: 0.92098 val loss: 0.35451045761108396 val acc: 0.8844 time: 353.0509297847748


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 10 lr: 0.2947368421052632 train loss: 0.19854824798583984 train acc: 0.93124 val loss: 0.32905500946044924 val acc: 0.8905 time: 389.66523575782776


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 11 lr: 0.2736842105263158 train loss: 0.17230789375305175 train acc: 0.94092 val loss: 0.31513900527954103 val acc: 0.8959 time: 425.83828020095825


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 12 lr: 0.25263157894736843 train loss: 0.14670713203430175 train acc: 0.94934 val loss: 0.38285472259521486 val acc: 0.8822 time: 462.0678904056549


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 13 lr: 0.23157894736842108 train loss: 0.12869091361999513 train acc: 0.95546 val loss: 0.3262969261169434 val acc: 0.8954 time: 499.92179346084595


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 14 lr: 0.2105263157894737 train loss: 0.11400431312561035 train acc: 0.96096 val loss: 0.3328449157714844 val acc: 0.8995 time: 537.5691368579865


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 15 lr: 0.18947368421052635 train loss: 0.0959979401397705 train acc: 0.9672 val loss: 0.3723915252685547 val acc: 0.8946 time: 574.2257447242737


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 16 lr: 0.16842105263157897 train loss: 0.07802940856933593 train acc: 0.97392 val loss: 0.2903667812347412 val acc: 0.9102 time: 610.6983668804169


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 17 lr: 0.1473684210526316 train loss: 0.0655011933708191 train acc: 0.97878 val loss: 0.2667197700500488 val acc: 0.9182 time: 648.166779756546


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 18 lr: 0.12631578947368421 train loss: 0.05549937532424927 train acc: 0.98216 val loss: 0.263979080581665 val acc: 0.9213 time: 685.6296856403351


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 19 lr: 0.10526315789473689 train loss: 0.04565949991226196 train acc: 0.9858 val loss: 0.2687651744842529 val acc: 0.9198 time: 723.6167635917664


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 20 lr: 0.08421052631578951 train loss: 0.03811391527175903 train acc: 0.98896 val loss: 0.25346498107910154 val acc: 0.9272 time: 761.4369156360626


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 21 lr: 0.06315789473684214 train loss: 0.029510499458312987 train acc: 0.99208 val loss: 0.250476277923584 val acc: 0.9263 time: 799.2508842945099


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 22 lr: 0.04210526315789476 train loss: 0.02606761640548706 train acc: 0.99314 val loss: 0.2462168960571289 val acc: 0.9295 time: 836.870521068573


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 23 lr: 0.02105263157894738 train loss: 0.022653040800094604 train acc: 0.99428 val loss: 0.24615054168701173 val acc: 0.9295 time: 873.9710192680359


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


epoch: 24 lr: 0.0 train loss: 0.021020205895900725 train acc: 0.99456 val loss: 0.24484500770568848 val acc: 0.9303 time: 911.3703899383545
