### Import Libraries

In [1]:
import numpy as np
import time, math
from tqdm import tqdm_notebook as tqdm
from matplotlib import pyplot as plt
%matplotlib inline
import tensorflow as tf
import tensorflow.contrib.eager as tfe

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



In [2]:
tf.enable_eager_execution()

In [3]:
import tensorflow as tf
tf_config=tf.ConfigProto()
tf_config.gpu_options.allow_growth=True
sess = tf.Session(config=tf_config)
import random

In [4]:

tf.__version__

'1.15.0'

### Model hyperparamaters

In [5]:
BATCH_SIZE = 512 #@param {type:"integer"}
MOMENTUM = 0.9 #@param {type:"number"}
LEARNING_RATE = 0.4 #@param {type:"number"}
WEIGHT_DECAY = 5e-4 #@param {type:"number"}
EPOCHS = 24 #@param {type:"integer"}

### Weight Initialization

In [6]:
def init_pytorch(shape, dtype=tf.float32, partition_info=None):
  fan = np.prod(shape[:-1])
  bound = 1 / math.sqrt(fan)
  return tf.random.uniform(shape, minval=-bound, maxval=bound, dtype=dtype)

### Resnet-18 Model
- Convolutional Block
- Residual Block

In [7]:
class ConvBN(tf.keras.Model):
  def __init__(self, c_out):
    super().__init__()
    self.conv = tf.keras.layers.Conv2D(filters=c_out, kernel_size=3, padding="SAME", kernel_initializer=init_pytorch, use_bias=False)
    self.bn = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

  def call(self, inputs):
    return tf.nn.relu(self.bn(self.conv(inputs)))

In [8]:
class ResBlk(tf.keras.Model):
  def __init__(self, c_out, pool, res = False):
    super().__init__()
    self.conv_bn = ConvBN(c_out)
    self.pool = pool
    self.res = res
    if self.res:
      self.res1 = ConvBN(c_out)
      self.res2 = ConvBN(c_out)

  def call(self, inputs):
    h = self.pool(self.conv_bn(inputs))
    if self.res:
      h = h + self.res2(self.res1(h))
    return h

In [9]:
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = ConvBN(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight

  def call(self, x, y):
    h = self.pool(self.blk3(self.blk2(self.blk1(self.init_conv_bn(x)))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

### Load CIFAR-10
- standardize images

In [10]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
len_train, len_test = len(x_train), len(x_test)
y_train = y_train.astype('int64').reshape(len_train)
y_test = y_test.astype('int64').reshape(len_test)

train_mean = np.mean(x_train, axis=(0,1,2))
train_std = np.std(x_train, axis=(0,1,2))

normalize = lambda x: ((x - train_mean) / train_std).astype('float32') # todo: check here
pad4 = lambda x: np.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)], mode='reflect')

x_train = normalize(pad4(x_train))
x_test = normalize(x_test)

In [11]:
x_train.shape

(50000, 40, 40, 3)

### Convert to TF records

In [12]:
IMAGE_HEIGHT = 40
IMAGE_WIDTH = 40
IMAGE_DEPTH = 3
NUM_CLASSES = 10
# Type convertion functions
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# convert images to tfrecords

def _convert_to_tfrecord(data, labels, tfrecords_filename):
  """Converts a file to TFRecords."""
  print('Generating %s' % tfrecords_filename)
  with tf.python_io.TFRecordWriter(tfrecords_filename) as record_writer:
    num_entries_in_batch = len(labels)
    for i in range(num_entries_in_batch):
      example = tf.train.Example(features=tf.train.Features(
        feature={
          'image': _bytes_feature(data[i].tobytes()),
          'label': _int64_feature(labels[i])
        }))
      record_writer.write(example.SerializeToString())


train_tfrecords_filename = 'TrainCifar10.tfrecords'
test_tfrecords_filename = 'TestCifar10.tfrecords'

_convert_to_tfrecord(x_train, y_train,train_tfrecords_filename)
_convert_to_tfrecord(x_test,y_test,test_tfrecords_filename)

Generating TrainCifar10.tfrecords
Generating TestCifar10.tfrecords


### Utility to decode tfrecords into tensor based images

In [13]:
# parsing the tf-record stored
def parse_record(serialized_example, isTraining = True):
  features = tf.parse_single_example(
    serialized_example,
    features={
      'image': tf.FixedLenFeature([], tf.string),
      'label': tf.FixedLenFeature([], tf.int64),
    })

  image = features['image']
  # decoding image data in bytes format to array
  image = tf.decode_raw(image, tf.float32)
  # reshape the image from linear list to image shape
  if(isTraining):
    image.set_shape([IMAGE_DEPTH * IMAGE_HEIGHT * IMAGE_WIDTH])
    image = tf.reshape(image, [IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH])
  else:
    image.set_shape([IMAGE_DEPTH * 32 * 32])
    image = tf.reshape(image, [32, 32, IMAGE_DEPTH])
  
  #casting label data to integer format
  label = tf.cast(features['label'], tf.int64)

  return image, label

def get_decoded_records(file_name, isTraining = True):
  # returns list of tuples each containing image and label
  dataset = tf.data.TFRecordDataset(filenames=file_name)
  dataset = dataset.map(lambda x: parse_record(x, isTraining))
  return dataset

### Cutout function modified for tensorflow

In [14]:
def random_erasing(img, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3):
    '''
    img is a 3-D variable (ex: tf.Variable(image, validate_shape=False) ) and  HWC order
    '''
    # HWC order
    height = tf.shape(img)[0]
    width = tf.shape(img)[1]
    channel = tf.shape(img)[2]
    area = tf.cast(width*height, tf.float32)

    erase_area_low_bound = tf.cast(tf.round(tf.sqrt(sl * area * r1)), tf.int32)
    erase_area_up_bound = tf.cast(tf.round(tf.sqrt((sh * area) / r1)), tf.int32)
    h_upper_bound = tf.minimum(erase_area_up_bound, height)
    w_upper_bound = tf.minimum(erase_area_up_bound, width)

    h = tf.random.uniform([], erase_area_low_bound, h_upper_bound, tf.int32)
    w = tf.random.uniform([], erase_area_low_bound, w_upper_bound, tf.int32)
    import random as r
    h = r.choice(range(0, 30, 2))
    w = r.choice(range(0, 30, 2))

    h=8
    w=8
    x1 = tf.random.uniform([], 0, height+1 - h, tf.int32)
    y1 = tf.random.uniform([], 0, width+1 - w, tf.int32)

    
    erase_area = tf.cast(tf.random.uniform([h, w, channel], 0, 255, tf.int32), tf.int8)
#     h1 = (40-h)/2
#     h2=h1
#     w1 = (40-w)/2
#     w2=w1
#     print(h1,h2,w1,w2)
#     if h>=6:
#         h1 = (40-h)/2
#         h2=h1
#         w1 = (4-w)/2
#         w2=w1
#     else:
#         h1 = ((40-h)/2)+2
#         h2 = (40-h-h1)
#         w1 = ((4-w)/2)-2
#         w2 = (40-w-w1)
#     paddings = tf.constant([[h1, h2], [w1, w2]])
    paddings = tf.constant([[12, 12], [12, 12]])
#     erase_area = erase_area[:,:,1]
#     paddings = tf.constant([[14, 14], [14, 14]])
    test1 = tf.pad(erase_area[:,:,0],paddings,"CONSTANT")
    test2 = tf.pad(erase_area[:,:,1],paddings,"CONSTANT")
    test3 = tf.pad(erase_area[:,:,2],paddings,"CONSTANT")
#     print(test1.shape)
#     tf.stack([r,g,b], axis=2)
    test = tf.stack([test1,test2,test3],axis=2)
#     print(test.shape)
#     return(test)
#     print(erase_area.shape)
#     print(img.shape)
    erasing_img1 = tf.add(tf.cast(test1,tf.float32),tf.cast(img[:,:,0],tf.float32))
    erasing_img2 = tf.add(tf.cast(test2,tf.float32),tf.cast(img[:,:,1],tf.float32))
    erasing_img3 = tf.add(tf.cast(test3,tf.float32),tf.cast(img[:,:,2],tf.float32))
    erasing_img = tf.stack([erasing_img1,erasing_img2,erasing_img3],axis=2)
    return(erasing_img)
    erasing_img = img[x1:x1+h, y1:y1+w, :].assign(erase_area)
    print(erasing_img.shape)
#     return(erasing_img)

    return tf.cond(tf.random.uniform([], 0, 1) > probability, lambda: img, lambda: erasing_img)

In [15]:
x = get_decoded_records('TrainCifar10.tfrecords', isTraining = True)






In [16]:
# mnist_example = x.take(1)
# for sample in mnist_example:
#     image, label = sample[0], sample[1].numpy()
#     image = random_erasing(image,probability=0.1)
#     print(type(image))
# #     image = image.numpy()
#     plt.imshow(image.numpy().astype(np.float32), cmap=plt.get_cmap("gray"))
#     plt.show()

#     print("Label: %d" % label)

### Define Model
- LR scheduler
- OPtimizer with momentum
- Data Augmentation using flip, random crop and cutout

In [17]:
model = DavidNet()
batches_per_epoch = len_train//BATCH_SIZE + 1

lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, EPOCHS], [0, LEARNING_RATE, 0])[0]
global_step = tf.train.get_or_create_global_step()
lr_func = lambda: lr_schedule(global_step/batches_per_epoch)/BATCH_SIZE
opt = tf.train.MomentumOptimizer(lr_func, momentum=MOMENTUM, use_nesterov=True)
data_aug = lambda x, y: (random_erasing(tf.image.random_flip_left_right(tf.random_crop(x, [32, 32, 3]))), y)
# data_aug = lambda x, y: (random_erasing(x), y)

### Validation Accuracy went as high as 89.4% in 24 epochs

In [18]:
t = time.time()
test_set = get_decoded_records(test_tfrecords_filename, isTraining=False).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = get_decoded_records(train_tfrecords_filename).map(data_aug).shuffle(len_train).\
    batch(BATCH_SIZE).prefetch(1)
  tf.keras.backend.set_learning_phase(1)

  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)




Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  # Remove the CWD from sys.path while we load stuff.


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 1 lr: 0.08 train loss: 2.209973565673828 train acc: 0.15352 val loss: 2.0896396911621093 val acc: 0.1968 time: 33.06539607048035


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 2 lr: 0.16 train loss: 1.8277530078125 train acc: 0.3254 val loss: 2.179609436035156 val acc: 0.3144 time: 57.813419342041016


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 3 lr: 0.24 train loss: 1.5720829345703125 train acc: 0.4174 val loss: 2.402398553466797 val acc: 0.3034 time: 82.59964227676392


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 4 lr: 0.32 train loss: 1.4138253759765624 train acc: 0.48222 val loss: 1.8524856140136718 val acc: 0.4145 time: 107.34524893760681


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 5 lr: 0.4 train loss: 1.257271201171875 train acc: 0.54466 val loss: 1.186934799194336 val acc: 0.5689 time: 132.06825470924377


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 6 lr: 0.37894736842105264 train loss: 1.085635213623047 train acc: 0.61056 val loss: 1.1350576751708985 val acc: 0.6205 time: 156.85752749443054


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 7 lr: 0.35789473684210527 train loss: 0.9106676440429687 train acc: 0.67482 val loss: 1.2904062957763671 val acc: 0.5795 time: 181.6296465396881


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 8 lr: 0.33684210526315794 train loss: 0.793538186340332 train acc: 0.72088 val loss: 1.2662181213378907 val acc: 0.6054 time: 206.39536094665527


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 9 lr: 0.31578947368421056 train loss: 0.7085600769042969 train acc: 0.74994 val loss: 1.9263920227050781 val acc: 0.4513 time: 230.26305222511292


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 10 lr: 0.2947368421052632 train loss: 0.6449219897460937 train acc: 0.7729 val loss: 0.792445199584961 val acc: 0.7316 time: 254.47372841835022


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 11 lr: 0.2736842105263158 train loss: 0.5863064056396484 train acc: 0.7938 val loss: 0.8120649337768555 val acc: 0.7298 time: 279.47020530700684


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 12 lr: 0.25263157894736843 train loss: 0.5392587322998047 train acc: 0.81258 val loss: 0.6659625015258789 val acc: 0.7927 time: 304.2263126373291


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 13 lr: 0.23157894736842108 train loss: 0.49857772064208983 train acc: 0.82694 val loss: 0.8037304397583008 val acc: 0.7428 time: 328.9921100139618


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 14 lr: 0.2105263157894737 train loss: 0.46297371429443357 train acc: 0.83838 val loss: 1.0774305877685546 val acc: 0.6597 time: 353.7749297618866


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 15 lr: 0.18947368421052635 train loss: 0.43406385955810545 train acc: 0.8481 val loss: 0.5718366638183594 val acc: 0.8084 time: 378.53198409080505


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 16 lr: 0.16842105263157897 train loss: 0.400472200012207 train acc: 0.8602 val loss: 0.6260364471435547 val acc: 0.7941 time: 403.3097379207611


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 17 lr: 0.1473684210526316 train loss: 0.36989736663818357 train acc: 0.87212 val loss: 0.6082208068847657 val acc: 0.799 time: 428.13292813301086


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 18 lr: 0.12631578947368421 train loss: 0.34336972076416017 train acc: 0.88174 val loss: 0.4813381103515625 val acc: 0.8429 time: 453.12046241760254


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 19 lr: 0.10526315789473689 train loss: 0.3138114041137695 train acc: 0.89134 val loss: 0.4944735153198242 val acc: 0.846 time: 477.91576051712036


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 20 lr: 0.08421052631578951 train loss: 0.29616096710205075 train acc: 0.8969 val loss: 0.43834268798828124 val acc: 0.8516 time: 502.66363167762756


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 21 lr: 0.06315789473684214 train loss: 0.26794476013183594 train acc: 0.9074 val loss: 0.42076380767822263 val acc: 0.8599 time: 527.3859009742737


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 22 lr: 0.04210526315789476 train loss: 0.2502942250061035 train acc: 0.9132 val loss: 0.3915229293823242 val acc: 0.872 time: 552.1640326976776


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 23 lr: 0.02105263157894738 train loss: 0.22866988731384277 train acc: 0.9223 val loss: 0.35262313232421877 val acc: 0.8846 time: 576.8463439941406


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


epoch: 24 lr: 0.0 train loss: 0.21513941848754883 train acc: 0.9271 val loss: 0.33112587127685544 val acc: 0.8937 time: 601.5121681690216
