### Links for other notebooks
Note: Each notebook's first cell has the summary of that notebook

* Notebook2 DavidNet.ipynb: https://colab.research.google.com/drive/1s3KPa8B-nugUZZID1R-gO6iGddIfJIJH  
DavidNet in TensorFlow  
Val Accuracy: **93.13 in 24 epochs** (Exp: BenchMark with Ammar's LR)

* Notebook3 DavidNet.ipynb: https://colab.research.google.com/drive/15sXfkEh4ptegc-7K9AcUVIBwoPQP3Jj_  
DavidNet (10 experiments)  
Val Accuracy: **93.23 in 27 epochs** (Exp: Test for 27 epochs)   

* Notebook4 DavidNet.ipynb: https://colab.research.google.com/drive/1PbkR1_cgtqTSeCqf8Ys6o0HtWuI7T4rx  
DavidNet implementation in Keras (21 experiments)  
Val Accuracy: 89.60 in 24 epochs (Exp: Exp 9-17, 1st best)

* Notebook5 ResNet.ipynb: https://colab.research.google.com/drive/1KkkaZft25mJ9ncHQSAlX50CS_p-ASTdn  
ResNet in TensorFlow Keras (2 experiments)  
Val Accuracy: 89.29 in 24 epochs (Exp: BN16)  

* Notebook6 ResNet.ipynb: https://colab.research.google.com/drive/1rKltNhgPwOA5WwJ8XfL9eyofa73e8GrJ  
ResNet using TensorFlow Keras (12 experiments)  
Val Accuracy: 88.56 in 23 epochs (Exp: MaxLR(8th))

* Notebook7 ResNet.ipynb: https://colab.research.google.com/drive/1T94mqainInCUQfdWRTOGR6T9ksOllFs2  
ResNet using Keras  
Val Accuracy: 87.4 in 24 epochs (Exp4)

### Summary


This notebook contains experiments that train the model first with augmentation removed and then with modified augmentation, followed by experiments to find the maximum learning rate (maxLR) for a one-cycle LR schedule that ends with a gradual drop.  
**Model**: DavidNet  
**Benchmark**: Val Accuracy: 93.10 in 24 epochs (Exp: Benchmark)  

In [0]:
# memory footprint support libraries/code
!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi
!pip install gputil
!pip install psutil
!pip install humanize
import psutil
import humanize
import os
import GPUtil as GPU
GPUs = GPU.getGPUs()
# XXX: only one GPU on Colab and isn’t guaranteed
gpu = GPUs[0]
def printm():
 process = psutil.Process(os.getpid())
 print("Gen RAM Free: " + humanize.naturalsize( psutil.virtual_memory().available ), " | Proc size: " + humanize.naturalsize( process.memory_info().rss))
 print("GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB".format(gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil*100, gpu.memoryTotal))
printm()

Collecting gputil
  Downloading https://files.pythonhosted.org/packages/ed/0e/5c61eedde9f6c87713e89d794f01e378cfd9565847d4576fa627d758c554/GPUtil-1.4.0.tar.gz
Building wheels for collected packages: gputil
  Building wheel for gputil (setup.py) ... done
  Created wheel for gputil: filename=GPUtil-1.4.0-cp36-none-any.whl size=7411 sha256=b6283c5397de406042e219e2c896155d0cde31d0c35c77332bd5268c352a1aaf
  Stored in directory: /root/.cache/pip/wheels/3d/77/07/80562de4bb0786e5ea186911a2c831fdd0018bda69beab71fd
Successfully built gputil
Installing collected packages: gputil
Successfully installed gputil-1.4.0
Gen RAM Free: 12.9 GB  | Proc size: 153.9 MB
GPU RAM Free: 11441MB | Used: 0MB | Util   0% | Total 11441MB


### Code



In [0]:
import numpy as np
import time, math
from tqdm import tqdm_notebook as tqdm

import tensorflow as tf
import tensorflow.contrib.eager as tfe

In [0]:
tf.enable_eager_execution()


In [0]:
BATCH_SIZE = 512 #@param {type:"integer"}
MOMENTUM = 0.9 #@param {type:"number"}
LEARNING_RATE = 0.4 #@param {type:"number"}
WEIGHT_DECAY = 5e-4 #@param {type:"number"}
EPOCHS = 24 #@param {type:"integer"}

In [0]:
def init_pytorch(shape, dtype=tf.float32, partition_info=None):
  fan = np.prod(shape[:-1])
  bound = 1 / math.sqrt(fan)
  return tf.random.uniform(shape, minval=-bound, maxval=bound, dtype=dtype)
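
To make the bound in `init_pytorch` concrete, here is a minimal stdlib-only sketch that evaluates it for two kernel shapes this model would plausibly produce. The helper name `uniform_bound` and the example shapes are illustrative assumptions, not taken from the notebook; the shapes assume Keras' `(height, width, in_channels, out_channels)` kernel layout and the default `c=64`.

```python
import math

def uniform_bound(shape):
    """Bound used by init_pytorch: 1 / sqrt(product of all dims except the last)."""
    fan = math.prod(shape[:-1])  # treated as fan-in: kernel area times input channels
    return 1 / math.sqrt(fan)

# Hypothetical shapes, assuming RGB input and c=64 (so blk4 has 64 * 16 = 1024 channels).
first_conv = uniform_bound((3, 3, 3, 64))   # fan = 27
dense = uniform_bound((1024, 10))           # fan = 1024

print(first_conv, dense)
```

Weights are then drawn uniformly from `[-bound, bound]`, so layers with a larger fan start with proportionally smaller weights.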

In [0]:
class Conv(tf.keras.Model):
  def __init__(self, c_out):
    super().__init__()
    self.conv = tf.keras.layers.Conv2D(filters=c_out, kernel_size=3, padding="SAME", kernel_initializer=init_pytorch, use_bias=False)    

  def call(self, inputs):
    return tf.nn.relu(self.conv(inputs))

In [0]:
class ConvBN(tf.keras.Model):
  def __init__(self, c_out):
    super().__init__()
    self.conv = tf.keras.layers.Conv2D(filters=c_out, kernel_size=3, padding="SAME", kernel_initializer=init_pytorch, use_bias=False)
    self.bn = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

  def call(self, inputs):
    return tf.nn.relu(self.bn(self.conv(inputs)))

In [0]:
class ResBlk(tf.keras.Model):
  def __init__(self, c_out, pool, res = False):
    super().__init__()
    self.conv_bn = ConvBN(c_out)
    self.pool = pool
    self.res = res
    if self.res:
      self.res1 = ConvBN(c_out)
      self.res2 = ConvBN(c_out)

  def call(self, inputs):
    h = self.pool(self.conv_bn(inputs))
    if self.res:
      h = h + self.res2(self.res1(h))
    return h

In [0]:
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = Conv(c)
    self.blk1 = ResBlk(c*2, pool, res = True)
    self.blk2 = ResBlk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool, res = True)
    self.blk4 = ResBlk(c*16, pool, res = True)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight

  def call(self, x, y):
    h = self.pool(self.blk4(self.blk3(self.blk2(self.blk1(self.init_conv_bn(x))))))
    h = self.linear(h) * self.weight
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y)
    loss = tf.reduce_sum(ce)
    correct = tf.reduce_sum(tf.cast(tf.math.equal(tf.argmax(h, axis = 1), y), tf.float32))
    return loss, correct

In [0]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
len_train, len_test = len(x_train), len(x_test)
y_train = y_train.astype('int64').reshape(len_train)
y_test = y_test.astype('int64').reshape(len_test)

train_mean = np.mean(x_train, axis=(0,1,2))
train_std = np.std(x_train, axis=(0,1,2))

test_mean = np.mean(x_train, axis=(0,1,2))
test_std = np.std(x_train, axis=(0,1,2))

normalize = lambda x: ((x - train_mean) / train_std).astype('float32') # todo: check here
normalize_test = lambda x: ((x - test_mean) / test_std).astype('float32') # todo: check here
pad4 = lambda x: np.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)], mode='reflect')

x_train = normalize(pad4(x_train))
# x_test = normalize(x_test)
x_test = normalize_test(x_test)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
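
The `pad4` step reflect-pads each 32x32 training image to 40x40 so that a later random 32x32 crop can shift the image by up to 4 pixels in each direction. The sketch below illustrates the reflect rule on a 1-D list and counts the available crop positions. `reflect_pad` is a hypothetical helper written for this illustration; it is meant to mirror `np.pad(..., mode='reflect')` along one axis, not to replace it.

```python
def reflect_pad(seq, pad):
    """Mirror `pad` values around each edge without repeating the edge value itself."""
    seq = list(seq)
    if not 0 <= pad < len(seq):
        raise ValueError("pad must be non-negative and smaller than the sequence length")
    left = seq[1:pad + 1][::-1]        # values just inside the left edge, reversed
    right = seq[-pad - 1:-1][::-1]     # values just inside the right edge, reversed
    return left + seq + right

IMAGE_SIZE, PAD = 32, 4
padded_size = IMAGE_SIZE + 2 * PAD                  # side length after pad4
offsets_per_axis = padded_size - IMAGE_SIZE + 1     # possible crop origins per axis

print(reflect_pad([1, 2, 3, 4, 5], 2))
print(padded_size, offsets_per_axis ** 2)
```

Only the training images are padded; the test images stay at 32x32, which matches the crop size used during training.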


In [0]:
!wget "https://raw.githubusercontent.com/yu4u/cutout-random-erasing/master/random_eraser.py"
from random_eraser import get_random_eraser
# eraser = get_random_eraser(s_l=0.25, s_h=0.25,v_l=0, v_h=1)
eraser = get_random_eraser(p=0.3,s_l=0.2,s_h=0.2,v_l=0, v_h=1,pixel_level=True)

In [0]:

def augment(images, 
            resize=None, # (width, height) tuple or None
            horizontal_flip=False,
            vertical_flip=False,
            rotate=0, # Maximum rotation angle in degrees
            crop_probability=0, # How often we do crops
            crop_min_percent=0.6, # Minimum linear dimension of a crop
            crop_max_percent=1.,  # Maximum linear dimension of a crop
            mixup=0):  # Mixup coefficient, see https://arxiv.org/abs/1710.09412.pdf
  if resize is not None:
    images = tf.image.resize_bilinear(images, resize)
  
  # My experiments showed that casting on GPU improves training performance
  if images.dtype is not tf.float32:
    images = tf.image.convert_image_dtype(images, dtype=tf.float32)
    images = tf.subtract(images, 0.5)
    images = tf.multiply(images, 2.0)
  # labels = tf.to_float(labels)

  with tf.name_scope('augmentation'):
    shp = tf.shape(images)
    batch_size, height, width = shp[0], shp[1], shp[2]
    width = tf.cast(width, tf.float32)
    height = tf.cast(height, tf.float32)

    # The list of affine transformations that our image will go under.
    # Every element is Nx8 tensor, where N is a batch size.
    transforms = []
    identity = tf.constant([1, 0, 0, 0, 1, 0, 0, 0], dtype=tf.float32)
    if horizontal_flip:
      coin = tf.less(tf.random_uniform([batch_size], 0, 1.0), 0.5)
      flip_transform = tf.convert_to_tensor(
          [-1., 0., width, 0., 1., 0., 0., 0.], dtype=tf.float32)
      transforms.append(
          tf.where(coin,
                   tf.tile(tf.expand_dims(flip_transform, 0), [batch_size, 1]),
                   tf.tile(tf.expand_dims(identity, 0), [batch_size, 1])))

    if vertical_flip:
      coin = tf.less(tf.random_uniform([batch_size], 0, 1.0), 0.5)
      flip_transform = tf.convert_to_tensor(
          [1, 0, 0, 0, -1, height, 0, 0], dtype=tf.float32)
      transforms.append(
          tf.where(coin,
                   tf.tile(tf.expand_dims(flip_transform, 0), [batch_size, 1]),
                   tf.tile(tf.expand_dims(identity, 0), [batch_size, 1])))

    if rotate > 0:
      angle_rad = rotate / 180 * math.pi
      angles = tf.random_uniform([batch_size], -angle_rad, angle_rad)
      transforms.append(
          tf.contrib.image.angles_to_projective_transforms(
              angles, height, width))

    if crop_probability > 0:
      crop_pct = tf.random_uniform([batch_size], crop_min_percent,
                                   crop_max_percent)
      left = tf.random_uniform([batch_size], 0, width * (1 - crop_pct))
      top = tf.random_uniform([batch_size], 0, height * (1 - crop_pct))
      crop_transform = tf.stack([
          crop_pct,
          tf.zeros([batch_size]), top,
          tf.zeros([batch_size]), crop_pct, left,
          tf.zeros([batch_size]),
          tf.zeros([batch_size])
      ], 1)

      coin = tf.less(
          tf.random_uniform([batch_size], 0, 1.0), crop_probability)
      transforms.append(
          tf.where(coin, crop_transform,
                   tf.tile(tf.expand_dims(identity, 0), [batch_size, 1])))

    if transforms:
      images = tf.contrib.image.transform(
          images,
          tf.contrib.image.compose_transforms(*transforms),
          interpolation='BILINEAR') # or 'NEAREST'

    def cshift(values): # Circular shift in batch dimension
      return tf.concat([values[-1:, ...], values[:-1, ...]], 0)

    if mixup > 0:
      mixup = 1.0 * mixup # Convert to float, as tf.distributions.Beta requires floats.
      beta = tf.distributions.Beta(mixup, mixup)
      lam = beta.sample(batch_size)
      ll = tf.expand_dims(tf.expand_dims(tf.expand_dims(lam, -1), -1), -1)
      images = ll * images + (1 - ll) * cshift(images)
      # labels = lam * labels + (1 - lam) * cshift(labels)

  return images

In [0]:
x_train = augment(images=x_train, 
                          rotate=10)#, crop_probability=0.8, )

In [0]:
model = DavidNet()
batches_per_epoch = len_train//BATCH_SIZE + 1

# lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, EPOCHS], [0, LEARNING_RATE, 0])[0]
lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, int(0.8*EPOCHS), EPOCHS], [0, LEARNING_RATE, 0.1*LEARNING_RATE, 0.005])[0]
global_step = tf.train.get_or_create_global_step()
lr_func = lambda: lr_schedule(global_step/batches_per_epoch)/BATCH_SIZE
opt = tf.train.MomentumOptimizer(lr_func, momentum=MOMENTUM, use_nesterov=True)
# data_aug = lambda x, y: (tf.image.random_flip_left_right(tf.random_crop(x, [32, 32, 3])), y)
data_aug = lambda x, y: tf.switch_case(
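        # NOTE: maxval is exclusive for integer dtypes, so this index is always 0
        # and branch 1 (rotation) is never selected; maxval=2 would sample both.
        tf.random.uniform([],0,1,tf.dtypes.int32), 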
        branch_fns={
            0: lambda: (tf.image.random_flip_left_right(tf.random_crop(x, [32, 32, 3])), y),
            # 1: lambda: (tf.image.central_crop(x, 0.8), y),
            1: lambda: (tf.contrib.image.rotate(tf.random_crop(x, [32, 32, 3]), np.random.randint(low=1,high=15) * math.pi / 180, interpolation='BILINEAR'),y)
            # 1: lambda: 
        }, 
        # default=lambda: (tf.random_crop(x, [32, 32, 3]), y)
        default=lambda: (tf.image.random_flip_left_right(tf.random_crop(x, [32, 32, 3])), y)
    )
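
The two `lr_schedule` definitions in the cell above are piecewise-linear curves over the epoch index. The stdlib-only sketch below reproduces them so the knots are easy to inspect. `piecewise_linear` is a hypothetical stand-in for `np.interp`, and `LEN_TRAIN = 50000` assumes the standard CIFAR-10 training split.

```python
def piecewise_linear(t, xs, ys):
    """Interpolate linearly through the knots (xs, ys), clamping outside the range."""
    if t <= xs[0]:
        return ys[0]
    for i in range(1, len(xs)):
        if t <= xs[i]:
            frac = (t - xs[i - 1]) / (xs[i] - xs[i - 1])
            return ys[i - 1] + frac * (ys[i] - ys[i - 1])
    return ys[-1]

EPOCHS, LEARNING_RATE, BATCH_SIZE = 24, 0.4, 512
LEN_TRAIN = 50000                      # assumed CIFAR-10 training-set size

warmup_end = (EPOCHS + 1) // 5         # epoch 5: peak learning rate
drop_start = int(0.8 * EPOCHS)         # epoch 19: start of the gradual drop

def one_cycle(t):
    # Commented-out schedule: ramp up to the peak, then straight down to 0.
    return piecewise_linear(t, [0, warmup_end, EPOCHS], [0, LEARNING_RATE, 0])

def one_cycle_gradual(t):
    # Active schedule: decay to 10% of the peak by epoch 19, then ease down to 0.005.
    return piecewise_linear(t, [0, warmup_end, drop_start, EPOCHS],
                            [0, LEARNING_RATE, 0.1 * LEARNING_RATE, 0.005])

batches_per_epoch = LEN_TRAIN // BATCH_SIZE + 1

def step_lr(global_step):
    # The loss is summed over the batch, so the per-step rate is divided by BATCH_SIZE.
    return one_cycle_gradual(global_step / batches_per_epoch) / BATCH_SIZE

for epoch in (1, 5, 6, 19, 24):
    print(epoch, round(one_cycle(epoch), 4), round(one_cycle_gradual(epoch), 4))
```

Judging by the logged `lr` values, the first benchmark run appears to follow the plain triangle (about 0.3789 at epoch 6, 0.0 at epoch 24), while the later runs end at 0.001 rather than 0.005, so the final knot was presumably changed between runs.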

### BenchMark

In [0]:
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
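    # NOTE: `g += ...` only rebinds the loop variable; `grads` is left unchanged,
    # so this loop does not actually add weight decay to the applied gradients.
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE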
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

epoch: 1 lr: 0.08 train loss: 1.5673736126708984 train acc: 0.43722 val loss: 1.2368145568847657 val acc: 0.5492 time: 80.57430982589722
epoch: 2 lr: 0.16 train loss: 0.8291827587890624 train acc: 0.707 val loss: 0.8341683685302734 val acc: 0.7114 time: 147.37661361694336
epoch: 3 lr: 0.24 train loss: 0.6343590902709961 train acc: 0.78006 val loss: 0.7416124740600586 val acc: 0.7599 time: 214.04439735412598
epoch: 4 lr: 0.32 train loss: 0.5407362075805664 train acc: 0.81262 val loss: 0.6011689559936524 val acc: 0.7929 time: 280.59751987457275
epoch: 5 lr: 0.4 train loss: 0.47634036224365234 train acc: 0.8379 val loss: 0.5181582809448242 val acc: 0.8229 time: 347.53501439094543
epoch: 6 lr: 0.37894736842105264 train loss: 0.3864410986328125 train acc: 0.86564 val loss: 0.4367912521362305 val acc: 0.8548 time: 414.16570687294006
epoch: 7 lr: 0.35789473684210527 train loss: 0.31341086654663086 train acc: 0.89064 val loss: 0.43331273956298827 val acc: 0.8601 time: 481.0168013572693
epoch: 8 lr: 0.33684210526315794 train loss: 0.26463253036499024 train acc: 0.90938 val loss: 0.5730444274902343 val acc: 0.8274 time: 547.8165826797485
epoch: 9 lr: 0.31578947368421056 train loss: 0.22757173164367675 train acc: 0.92208 val loss: 0.3746946662902832 val acc: 0.8758 time: 614.5737209320068
epoch: 10 lr: 0.2947368421052632 train loss: 0.19927756729125976 train acc: 0.931 val loss: 0.3706046813964844 val acc: 0.8815 time: 681.3642194271088
epoch: 11 lr: 0.2736842105263158 train loss: 0.17219827728271483 train acc: 0.94048 val loss: 0.361001700592041 val acc: 0.8796 time: 748.1032664775848
epoch: 12 lr: 0.25263157894736843 train loss: 0.1464209497833252 train acc: 0.94876 val loss: 0.28872755889892576 val acc: 0.9068 time: 814.8787863254547
epoch: 13 lr: 0.23157894736842108 train loss: 0.12961100524902344 train acc: 0.95518 val loss: 0.31661532440185547 val acc: 0.8984 time: 881.6932971477509
epoch: 14 lr: 0.2105263157894737 train loss: 0.10842801765441895 train acc: 0.96294 val loss: 0.3301973304748535 val acc: 0.9011 time: 948.5785715579987
epoch: 15 lr: 0.18947368421052635 train loss: 0.09290014068603515 train acc: 0.96844 val loss: 0.31067129821777345 val acc: 0.9036 time: 1015.3529236316681
epoch: 16 lr: 0.16842105263157897 train loss: 0.07639410102844238 train acc: 0.9746 val loss: 0.29912985916137697 val acc: 0.9136 time: 1082.1454706192017
epoch: 17 lr: 0.1473684210526316 train loss: 0.06623087390899658 train acc: 0.9783 val loss: 0.28839372177124023 val acc: 0.9147 time: 1148.8937113285065
epoch: 18 lr: 0.12631578947368421 train loss: 0.05393559543609619 train acc: 0.98326 val loss: 0.2814727714538574 val acc: 0.9173 time: 1215.7298860549927
epoch: 19 lr: 0.10526315789473689 train loss: 0.04521117359161377 train acc: 0.98584 val loss: 0.2688369606018066 val acc: 0.9204 time: 1282.5568597316742
epoch: 20 lr: 0.08421052631578951 train loss: 0.03810818576812744 train acc: 0.9889 val loss: 0.257457247543335 val acc: 0.9248 time: 1349.2928335666656
epoch: 21 lr: 0.06315789473684214 train loss: 0.031118259716033936 train acc: 0.99108 val loss: 0.2638925582885742 val acc: 0.9263 time: 1416.204303741455
epoch: 22 lr: 0.04210526315789476 train loss: 0.025494536838531492 train acc: 0.99346 val loss: 0.2495221580505371 val acc: 0.9302 time: 1483.0338776111603
epoch: 23 lr: 0.02105263157894738 train loss: 0.022272464170455933 train acc: 0.9945 val loss: 0.24605972785949706 val acc: 0.9307 time: 1549.8490784168243
epoch: 24 lr: 0.0 train loss: 0.02078292353630066 train acc: 0.99498 val loss: 0.2454392734527588 val acc: 0.931 time: 1616.5044195652008
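
The weight-decay loop in the training cell relies on `g += ...` updating the gradient list. The sketch below uses plain Python floats as stand-ins for tensors to show why that may not happen and what an explicit alternative could look like. It assumes, without having checked this notebook's TensorFlow build, that eager tensors behave like immutable values under `+=`. Both function names are hypothetical.

```python
WEIGHT_DECAY = 5e-4
BATCH_SIZE = 512

def decay_by_rebinding(grads, variables):
    """Same shape as the loop in the training cell, applied to immutable values."""
    for g, v in zip(grads, variables):
        g += v * WEIGHT_DECAY * BATCH_SIZE   # rebinds the local name `g` only
    return grads                             # the list still holds the old values

def decay_by_rebuilding(grads, variables):
    """Build a new gradient list that includes the decay term."""
    return [g + v * WEIGHT_DECAY * BATCH_SIZE for g, v in zip(grads, variables)]

grads = [0.10, -0.20]
variables = [1.0, 2.0]

print(decay_by_rebinding(list(grads), variables))
print(decay_by_rebuilding(grads, variables))
```

The `BATCH_SIZE` factor is there because the loss is summed over the batch while the learning rate is divided by `BATCH_SIZE`; multiplying the decay term by the same factor keeps the effective decay per step at `lr * WEIGHT_DECAY`. Switching the training loop to the second form would change its behaviour, so the logged results would not be directly comparable.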


### BenchMark without extra BN

In [0]:
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

epoch: 1 lr: 0.08 train loss: 1.2859700598144532 train acc: 0.53412 val loss: 0.8336404083251953 val acc: 0.7056 time: 65.72036981582642
epoch: 2 lr: 0.16 train loss: 0.7266783175659179 train acc: 0.74418 val loss: 0.715204833984375 val acc: 0.7535 time: 131.0114870071411
epoch: 3 lr: 0.24 train loss: 0.5749895935058593 train acc: 0.80072 val loss: 0.7324721160888672 val acc: 0.7484 time: 196.2042441368103
epoch: 4 lr: 0.32 train loss: 0.4941643084716797 train acc: 0.83242 val loss: 0.5341999465942383 val acc: 0.8211 time: 261.5650517940521
epoch: 5 lr: 0.4 train loss: 0.38463194931030276 train acc: 0.86714 val loss: 0.630856558227539 val acc: 0.7866 time: 326.8199689388275
epoch: 6 lr: 0.37428571428571433 train loss: 0.3006648422241211 train acc: 0.89728 val loss: 0.5103783554077148 val acc: 0.8329 time: 391.8842749595642
epoch: 7 lr: 0.3485714285714286 train loss: 0.24813826675415038 train acc: 0.91344 val loss: 0.42718442153930664 val acc: 0.8548 time: 457.0111231803894
epoch: 8 lr: 0.3228571428571429 train loss: 0.20553834533691406 train acc: 0.92906 val loss: 0.3559027847290039 val acc: 0.8794 time: 522.3070878982544
epoch: 9 lr: 0.29714285714285715 train loss: 0.1682792137145996 train acc: 0.94212 val loss: 0.332693864440918 val acc: 0.8873 time: 587.6753318309784
epoch: 10 lr: 0.27142857142857146 train loss: 0.14391308143615722 train acc: 0.9517 val loss: 0.3356159591674805 val acc: 0.8892 time: 652.8294279575348
epoch: 11 lr: 0.24571428571428575 train loss: 0.11683262710571289 train acc: 0.96118 val loss: 0.3447002563476563 val acc: 0.8923 time: 718.2954711914062
epoch: 12 lr: 0.22000000000000003 train loss: 0.09816422298431396 train acc: 0.96744 val loss: 0.3116374984741211 val acc: 0.9022 time: 783.5155110359192
epoch: 13 lr: 0.1942857142857143 train loss: 0.08464891521453857 train acc: 0.97288 val loss: 0.3163067008972168 val acc: 0.9026 time: 848.6671254634857
epoch: 14 lr: 0.1685714285714286 train loss: 0.0704071242904663 train acc: 0.97798 val loss: 0.29070014877319333 val acc: 0.913 time: 914.1472024917603
epoch: 15 lr: 0.1428571428571429 train loss: 0.060923093757629394 train acc: 0.98124 val loss: 0.30896007537841796 val acc: 0.9101 time: 979.3542218208313
epoch: 16 lr: 0.11714285714285716 train loss: 0.054100586700439456 train acc: 0.98384 val loss: 0.28496411743164063 val acc: 0.916 time: 1044.666315793991
epoch: 17 lr: 0.09142857142857147 train loss: 0.045457747859954836 train acc: 0.98662 val loss: 0.2709280143737793 val acc: 0.9181 time: 1109.8722751140594
epoch: 18 lr: 0.06571428571428573 train loss: 0.03825597622871399 train acc: 0.9889 val loss: 0.26149426879882814 val acc: 0.9227 time: 1175.0769064426422
epoch: 19 lr: 0.04000000000000001 train loss: 0.035830200185775755 train acc: 0.9896 val loss: 0.2576686996459961 val acc: 0.9216 time: 1240.417499780655
epoch: 20 lr: 0.032200000000000006 train loss: 0.03129239737510681 train acc: 0.99146 val loss: 0.2551196815490723 val acc: 0.9242 time: 1305.6084349155426
epoch: 21 lr: 0.024400000000000005 train loss: 0.03132392505645752 train acc: 0.99084 val loss: 0.25356032028198244 val acc: 0.9241 time: 1370.7649157047272
epoch: 22 lr: 0.016600000000000004 train loss: 0.03001959012031555 train acc: 0.9915 val loss: 0.25228773803710935 val acc: 0.924 time: 1436.0178806781769
epoch: 23 lr: 0.008800000000000002 train loss: 0.027385428304672242 train acc: 0.9928 val loss: 0.2522149917602539 val acc: 0.9239 time: 1501.4287130832672
epoch: 24 lr: 0.001 train loss: 0.02721629337310791 train acc: 0.99258 val loss: 0.2514957916259766 val acc: 0.9239 time: 1566.635277748108


### BenchMark without Augmentation and First BN: train accuracy reached 100%

In [0]:
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  # train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

epoch: 1 lr: 0.08 train loss: 1.5673690911865235 train acc: 0.43916 val loss: 1.1392370941162109 val acc: 0.5959 time: 111.104576587677
epoch: 2 lr: 0.16 train loss: 0.7960348654174805 train acc: 0.71598 val loss: 0.969080810546875 val acc: 0.6655 time: 204.68000388145447
epoch: 3 lr: 0.24 train loss: 0.5750762902832032 train acc: 0.80074 val loss: 0.872841079711914 val acc: 0.7085 time: 298.27933716773987
epoch: 4 lr: 0.32 train loss: 0.4318459466552734 train acc: 0.8538 val loss: 0.7656824066162109 val acc: 0.7439 time: 391.9845116138458
epoch: 5 lr: 0.4 train loss: 0.3335385580444336 train acc: 0.88706 val loss: 1.1862882720947265 val acc: 0.6745 time: 485.656131029129
epoch: 6 lr: 0.37428571428571433 train loss: 0.23330503173828124 train acc: 0.9186 val loss: 0.9742921600341797 val acc: 0.7215 time: 579.2364251613617
epoch: 7 lr: 0.3485714285714286 train loss: 0.11644330345153808 train acc: 0.96088 val loss: 0.731602409362793 val acc: 0.7855 time: 672.8700435161591
epoch: 8 lr: 0.3228571428571429 train loss: 0.052268385677337645 train acc: 0.98428 val loss: 0.7696061508178711 val acc: 0.7775 time: 766.5922839641571
epoch: 9 lr: 0.29714285714285715 train loss: 0.023357878913879394 train acc: 0.99386 val loss: 0.7101159942626953 val acc: 0.799 time: 860.2115137577057
epoch: 10 lr: 0.27142857142857146 train loss: 0.00702259993314743 train acc: 0.9991 val loss: 0.6747681640625 val acc: 0.8208 time: 953.8165438175201
epoch: 11 lr: 0.24571428571428575 train loss: 0.00193244862139225 train acc: 0.99996 val loss: 0.5611591964721679 val acc: 0.8482 time: 1047.4309196472168
epoch: 12 lr: 0.22000000000000003 train loss: 0.0007132641804218292 train acc: 1.0 val loss: 0.5719428771972657 val acc: 0.8495 time: 1141.1272656917572
epoch: 13 lr: 0.1942857142857143 train loss: 0.0005147198396921158 train acc: 1.0 val loss: 0.5780513351440429 val acc: 0.8481 time: 1234.75514960289
epoch: 14 lr: 0.1685714285714286 train loss: 0.0004385600881278515 train acc: 1.0 val loss: 0.5807274597167968 val acc: 0.8475 time: 1328.4458725452423
epoch: 15 lr: 0.1428571428571429 train loss: 0.00039139844372868536 train acc: 1.0 val loss: 0.582683561706543 val acc: 0.8479 time: 1422.1205241680145
epoch: 16 lr: 0.11714285714285716 train loss: 0.00035929645225405695 train acc: 1.0 val loss: 0.5842583648681641 val acc: 0.8479 time: 1515.7834219932556
epoch: 17 lr: 0.09142857142857147 train loss: 0.00033664067178964613 train acc: 1.0 val loss: 0.5855786254882812 val acc: 0.8483 time: 1609.3777675628662
epoch: 18 lr: 0.06571428571428573 train loss: 0.00032057932317256925 train acc: 1.0 val loss: 0.5865876449584961 val acc: 0.8488 time: 1703.1454510688782
epoch: 19 lr: 0.04000000000000001 train loss: 0.0003095045708119869 train acc: 1.0 val loss: 0.5872516006469727 val acc: 0.8483 time: 1796.7931394577026
epoch: 20 lr: 0.032200000000000006 train loss: 0.0003024866585433483 train acc: 1.0 val loss: 0.5876671173095703 val acc: 0.8483 time: 1890.5168538093567

KeyboardInterrupt: ignored

### LR with Gradual drop after One-CycleLR

In [0]:
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

epoch: 1 lr: 0.08 train loss: 1.5781250744628905 train acc: 0.42716 val loss: 1.197358255004883 val acc: 0.5875 time: 67.35015916824341


epoch: 2 lr: 0.16 train loss: 0.849458623046875 train acc: 0.69822 val loss: 1.0340927764892578 val acc: 0.651 time: 130.46793007850647


epoch: 3 lr: 0.24 train loss: 0.6339761224365235 train acc: 0.77728 val loss: 0.7426778518676758 val acc: 0.7439 time: 193.4626920223236


epoch: 4 lr: 0.32 train loss: 0.5422529977416992 train acc: 0.81206 val loss: 0.5425949234008789 val acc: 0.8173 time: 256.4593138694763


epoch: 5 lr: 0.4 train loss: 0.47718792022705075 train acc: 0.83612 val loss: 0.5699813064575195 val acc: 0.8073 time: 319.596755027771


epoch: 6 lr: 0.37428571428571433 train loss: 0.39729801315307617 train acc: 0.8626 val loss: 0.604079443359375 val acc: 0.811 time: 382.67665100097656


epoch: 7 lr: 0.3485714285714286 train loss: 0.3149744332885742 train acc: 0.8906 val loss: 0.5024481391906739 val acc: 0.8309 time: 445.6613428592682


epoch: 8 lr: 0.3228571428571429 train loss: 0.26525276596069336 train acc: 0.90902 val loss: 0.3371494140625 val acc: 0.8848 time: 508.64891815185547


epoch: 9 lr: 0.29714285714285715 train loss: 0.22875417938232423 train acc: 0.9194 val loss: 0.33358311386108397 val acc: 0.8905 time: 571.5709638595581


epoch: 10 lr: 0.27142857142857146 train loss: 0.1951402702331543 train acc: 0.9335 val loss: 0.3341755599975586 val acc: 0.8904 time: 634.619818687439


epoch: 11 lr: 0.24571428571428575 train loss: 0.16476744354248046 train acc: 0.94412 val loss: 0.44251397171020507 val acc: 0.867 time: 697.7591528892517


epoch: 12 lr: 0.22000000000000003 train loss: 0.14090808166503907 train acc: 0.95152 val loss: 0.3259385864257813 val acc: 0.895 time: 760.7454333305359


epoch: 13 lr: 0.1942857142857143 train loss: 0.12130741722106933 train acc: 0.95892 val loss: 0.3076984596252441 val acc: 0.906 time: 823.6588683128357


epoch: 14 lr: 0.1685714285714286 train loss: 0.10124667243957519 train acc: 0.96566 val loss: 0.30221018905639646 val acc: 0.9085 time: 886.7004406452179


epoch: 15 lr: 0.1428571428571429 train loss: 0.08420347003936768 train acc: 0.9721 val loss: 0.2698875869750977 val acc: 0.9191 time: 949.9629197120667


epoch: 16 lr: 0.11714285714285716 train loss: 0.06760005863189697 train acc: 0.97802 val loss: 0.2637753101348877 val acc: 0.9192 time: 1013.016925573349


epoch: 17 lr: 0.09142857142857147 train loss: 0.05687027130126953 train acc: 0.9826 val loss: 0.2633765125274658 val acc: 0.9218 time: 1075.9946248531342


epoch: 18 lr: 0.06571428571428573 train loss: 0.04654673864364624 train acc: 0.98598 val loss: 0.2450230339050293 val acc: 0.9265 time: 1139.0550689697266


epoch: 19 lr: 0.04000000000000001 train loss: 0.038690360870361325 train acc: 0.98892 val loss: 0.2456646369934082 val acc: 0.9299 time: 1202.1118652820587


epoch: 20 lr: 0.032200000000000006 train loss: 0.03519728136062622 train acc: 0.99034 val loss: 0.24964814338684083 val acc: 0.9288 time: 1265.3090002536774


epoch: 21 lr: 0.024400000000000005 train loss: 0.030772331314086913 train acc: 0.99208 val loss: 0.24758256492614747 val acc: 0.9282 time: 1328.153358221054


epoch: 22 lr: 0.016600000000000004 train loss: 0.027760543994903564 train acc: 0.99316 val loss: 0.24373788299560548 val acc: 0.9311 time: 1391.0234067440033


epoch: 23 lr: 0.008800000000000002 train loss: 0.026477796268463135 train acc: 0.99326 val loss: 0.24229511184692382 val acc: 0.9295 time: 1453.9991524219513


epoch: 24 lr: 0.001 train loss: 0.02582888198852539 train acc: 0.994 val loss: 0.2426891700744629 val acc: 0.9299 time: 1516.8821604251862
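
A note on the weight-decay loop in the training cell above: TensorFlow tensors do not support in-place addition, so `g += ...` inside `for g, v in zip(grads, var)` only rebinds the loop variable `g` and leaves the `grads` list unchanged. The sketch below illustrates the effect with plain Python floats standing in for immutable tensors (an assumption, so the example runs without TensorFlow), and shows one possible way to build a decayed gradient list instead. The numeric values and the `decayed` name are illustrative and not taken from the notebook.

```python
# Plain floats stand in for immutable gradient/variable tensors (illustrative values).
WEIGHT_DECAY = 5e-4   # assumed value, not defined in this section
BATCH_SIZE = 512      # assumed value, not defined in this section

grads = [0.10, -0.20]
var = [1.0, 2.0]

# Pattern used in the training loop: rebinds `g`, so `grads` stays as it was.
for g, v in zip(grads, var):
    g += v * WEIGHT_DECAY * BATCH_SIZE
print(grads)

# One possible alternative: build a new list that includes the decay term.
decayed = [g + v * WEIGHT_DECAY * BATCH_SIZE for g, v in zip(grads, var)]
print(decayed)
```

Passing `zip(decayed, var)` to `opt.apply_gradients` would apply the decay term. Doing so would likely change the training dynamics relative to the logs recorded in this notebook.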


### LR with Gradual drop after One-CycleLR with Separate Test Normalization

In [0]:
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    # Note: this rebinds `g` only; `grads` reaches the optimizer without the decay term.
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

epoch: 1 lr: 0.08 train loss: 1.5617154077148439 train acc: 0.43632 val loss: 1.3378051727294922 val acc: 0.5329 time: 73.42414212226868


epoch: 2 lr: 0.16 train loss: 0.836080629272461 train acc: 0.70398 val loss: 1.0398654083251953 val acc: 0.6677 time: 136.92781925201416


epoch: 3 lr: 0.24 train loss: 0.6254481799316406 train acc: 0.78398 val loss: 0.8237745101928711 val acc: 0.737 time: 200.23369097709656


epoch: 4 lr: 0.32 train loss: 0.5398001080322266 train acc: 0.81392 val loss: 0.6408422973632812 val acc: 0.7743 time: 263.4912347793579


epoch: 5 lr: 0.4 train loss: 0.47980604278564454 train acc: 0.8352 val loss: 0.5952761169433594 val acc: 0.8025 time: 326.8753447532654


epoch: 6 lr: 0.37428571428571433 train loss: 0.39385492126464844 train acc: 0.86446 val loss: 0.5150402816772461 val acc: 0.8323 time: 390.1674931049347


epoch: 7 lr: 0.3485714285714286 train loss: 0.31289699279785155 train acc: 0.8917 val loss: 0.4493382385253906 val acc: 0.8531 time: 453.6277995109558


epoch: 8 lr: 0.3228571428571429 train loss: 0.25897709136962893 train acc: 0.9114 val loss: 0.4816672348022461 val acc: 0.8506 time: 516.9500689506531


epoch: 9 lr: 0.29714285714285715 train loss: 0.22494093811035157 train acc: 0.92254 val loss: 0.3738273239135742 val acc: 0.8785 time: 580.2138628959656


epoch: 10 lr: 0.27142857142857146 train loss: 0.19593995300292968 train acc: 0.93278 val loss: 0.36109345092773437 val acc: 0.8822 time: 643.5278704166412


epoch: 11 lr: 0.24571428571428575 train loss: 0.16688016693115235 train acc: 0.94294 val loss: 0.31560276336669923 val acc: 0.8961 time: 706.8506219387054


epoch: 12 lr: 0.22000000000000003 train loss: 0.14142989669799805 train acc: 0.95222 val loss: 0.36380703659057617 val acc: 0.8826 time: 770.2891557216644


epoch: 13 lr: 0.1942857142857143 train loss: 0.12502670234680177 train acc: 0.95742 val loss: 0.30598887252807616 val acc: 0.9045 time: 833.6147863864899


epoch: 14 lr: 0.1685714285714286 train loss: 0.10455691162109375 train acc: 0.9655 val loss: 0.3166338104248047 val acc: 0.8998 time: 896.8502404689789


epoch: 15 lr: 0.1428571428571429 train loss: 0.0848653814315796 train acc: 0.97196 val loss: 0.2910980613708496 val acc: 0.9118 time: 960.0805604457855


epoch: 16 lr: 0.11714285714285716 train loss: 0.07153363788604736 train acc: 0.9771 val loss: 0.28914592666625977 val acc: 0.9088 time: 1023.4222857952118


epoch: 17 lr: 0.09142857142857147 train loss: 0.058743316612243655 train acc: 0.9817 val loss: 0.2665668476104736 val acc: 0.9181 time: 1086.8541641235352


epoch: 18 lr: 0.06571428571428573 train loss: 0.04737685424804688 train acc: 0.98594 val loss: 0.2565766418457031 val acc: 0.9234 time: 1150.2114806175232


epoch: 19 lr: 0.04000000000000001 train loss: 0.03999059543609619 train acc: 0.98836 val loss: 0.25253084411621096 val acc: 0.9254 time: 1213.5777168273926


epoch: 20 lr: 0.032200000000000006 train loss: 0.03348892301559448 train acc: 0.99112 val loss: 0.2455391757965088 val acc: 0.9269 time: 1276.899026632309


epoch: 21 lr: 0.024400000000000005 train loss: 0.031195758628845216 train acc: 0.99226 val loss: 0.2497573871612549 val acc: 0.9265 time: 1340.2563495635986


epoch: 22 lr: 0.016600000000000004 train loss: 0.028970465030670165 train acc: 0.99272 val loss: 0.24941670608520508 val acc: 0.9267 time: 1403.6912610530853


epoch: 23 lr: 0.008800000000000002 train loss: 0.027110953369140624 train acc: 0.99354 val loss: 0.24591126403808594 val acc: 0.9281 time: 1466.865451335907


epoch: 24 lr: 0.001 train loss: 0.025880767068862914 train acc: 0.99386 val loss: 0.24678536376953125 val acc: 0.9271 time: 1529.9745862483978


### LR with Gradual drop after One-CycleLR (0.15*LEARNING_RATE at 0.8*EPOCHS, final LR 0.01)

In [0]:
model = DavidNet()
batches_per_epoch = len_train//BATCH_SIZE + 1

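# Plain one-cycle schedule, kept for reference: linear warm-up to LEARNING_RATE, then linear decay to 0.
# lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, EPOCHS], [0, LEARNING_RATE, 0])[0]
# Gradual drop: decay to 0.15*LEARNING_RATE by epoch int(0.8*EPOCHS), then more slowly to 0.01.
lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, int(0.8*EPOCHS), EPOCHS], [0, LEARNING_RATE, 0.15*LEARNING_RATE, 0.01])[0]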
global_step = tf.train.get_or_create_global_step()
lr_func = lambda: lr_schedule(global_step/batches_per_epoch)/BATCH_SIZE
opt = tf.train.MomentumOptimizer(lr_func, momentum=MOMENTUM, use_nesterov=True)
data_aug = lambda x, y: (tf.image.random_flip_left_right(tf.random_crop(x, [32, 32, 3])), y)
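
To make the schedule in the cell above concrete, the following NumPy-only sketch evaluates it per epoch and per optimizer step. `EPOCHS = 24` and `LEARNING_RATE = 0.4` are inferred from the logged `lr` values; `BATCH_SIZE = 512` and `len_train = 50000` are assumptions, since their definitions are outside this section. `step_lr` is a hypothetical stand-in for `lr_func` that takes the step as an argument instead of reading `global_step`.

```python
import numpy as np

EPOCHS = 24           # inferred from the 24-epoch logs
LEARNING_RATE = 0.4   # inferred from the peak logged lr
BATCH_SIZE = 512      # assumption
len_train = 50000     # assumption, consistent with the granularity of the logged train accuracy
batches_per_epoch = len_train // BATCH_SIZE + 1

# Knots: warm-up to the peak, linear drop to 0.15*peak, then a gentler drop to 0.01.
knots = [0, (EPOCHS + 1) // 5, int(0.8 * EPOCHS), EPOCHS]
values = [0, LEARNING_RATE, 0.15 * LEARNING_RATE, 0.01]
lr_schedule = lambda t: np.interp([t], knots, values)[0]

def step_lr(step):
    """Per-step rate handed to the optimizer (hypothetical stand-in for lr_func)."""
    return lr_schedule(step / batches_per_epoch) / BATCH_SIZE

# Per-epoch schedule value and the corresponding per-step optimizer rate.
for epoch in (1, 5, 19, 24):
    print(epoch, lr_schedule(epoch), step_lr(epoch * batches_per_epoch))
```

The per-epoch values correspond to the `lr` column printed by the training loop; the optimizer itself receives that value divided by `BATCH_SIZE`, interpolated at every step rather than once per epoch.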

In [0]:
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    # Note: this rebinds `g` only; `grads` reaches the optimizer without the decay term.
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

epoch: 1 lr: 0.08 train loss: 1.572046083984375 train acc: 0.42786 val loss: 1.5202794067382812 val acc: 0.4875 time: 75.9919843673706


epoch: 2 lr: 0.16 train loss: 0.8463304397583008 train acc: 0.69774 val loss: 1.0581017669677735 val acc: 0.6357 time: 138.58160400390625


epoch: 3 lr: 0.24 train loss: 0.6340433840942383 train acc: 0.7789 val loss: 0.8111426498413086 val acc: 0.7285 time: 201.1136486530304


epoch: 4 lr: 0.32 train loss: 0.5395852044677735 train acc: 0.8138 val loss: 0.6576350845336915 val acc: 0.7847 time: 263.77787828445435


epoch: 5 lr: 0.4 train loss: 0.4743790112304688 train acc: 0.83844 val loss: 0.5384391540527343 val acc: 0.8137 time: 326.33546805381775


epoch: 6 lr: 0.3757142857142857 train loss: 0.3919457095336914 train acc: 0.86352 val loss: 0.5571366485595703 val acc: 0.8109 time: 388.8012640476227


epoch: 7 lr: 0.3514285714285714 train loss: 0.31874830139160154 train acc: 0.88816 val loss: 0.5018534912109375 val acc: 0.8325 time: 451.47718572616577


epoch: 8 lr: 0.3271428571428572 train loss: 0.26803186477661134 train acc: 0.90592 val loss: 0.47422888946533204 val acc: 0.8417 time: 514.0587005615234


epoch: 9 lr: 0.3028571428571429 train loss: 0.22847473358154297 train acc: 0.9207 val loss: 0.36085285263061523 val acc: 0.8781 time: 576.648542881012


epoch: 10 lr: 0.2785714285714286 train loss: 0.1978956322479248 train acc: 0.93138 val loss: 0.30232611083984373 val acc: 0.8981 time: 639.1051406860352


epoch: 11 lr: 0.2542857142857143 train loss: 0.16723573379516601 train acc: 0.94214 val loss: 0.34676898727416994 val acc: 0.8883 time: 701.641122341156


epoch: 12 lr: 0.23 train loss: 0.14808100090026854 train acc: 0.94974 val loss: 0.3758795631408691 val acc: 0.8784 time: 764.3596653938293


epoch: 13 lr: 0.2057142857142857 train loss: 0.12382277732849122 train acc: 0.9586 val loss: 0.3134396903991699 val acc: 0.8973 time: 826.9148457050323


epoch: 14 lr: 0.1814285714285714 train loss: 0.10342196132659912 train acc: 0.96538 val loss: 0.3089270721435547 val acc: 0.9024 time: 889.6358325481415


epoch: 15 lr: 0.15714285714285714 train loss: 0.08774578128814697 train acc: 0.97052 val loss: 0.3058958267211914 val acc: 0.9063 time: 952.3367908000946


epoch: 16 lr: 0.13285714285714284 train loss: 0.07095701885223389 train acc: 0.97624 val loss: 0.2848160011291504 val acc: 0.9166 time: 1015.0082907676697


epoch: 17 lr: 0.10857142857142854 train loss: 0.06260657951354981 train acc: 0.97996 val loss: 0.27398252296447756 val acc: 0.9206 time: 1077.9293506145477


epoch: 18 lr: 0.08428571428571424 train loss: 0.05044389177322388 train acc: 0.98442 val loss: 0.25358090171813963 val acc: 0.9258 time: 1140.5041494369507


epoch: 19 lr: 0.06 train loss: 0.0411694386100769 train acc: 0.98782 val loss: 0.25430382804870605 val acc: 0.9277 time: 1203.2459607124329


epoch: 20 lr: 0.05 train loss: 0.03579047821044922 train acc: 0.9897 val loss: 0.246568127822876 val acc: 0.9297 time: 1265.9602336883545


epoch: 21 lr: 0.04 train loss: 0.03048294871330261 train acc: 0.99202 val loss: 0.24421445541381837 val acc: 0.9291 time: 1328.51975607872


epoch: 22 lr: 0.030000000000000002 train loss: 0.02842438310623169 train acc: 0.99242 val loss: 0.24273134727478027 val acc: 0.9296 time: 1391.401605129242


epoch: 23 lr: 0.020000000000000004 train loss: 0.024963816080093383 train acc: 0.99368 val loss: 0.24412520751953126 val acc: 0.93 time: 1454.256047964096


epoch: 24 lr: 0.01 train loss: 0.02358230936050415 train acc: 0.99436 val loss: 0.24446614303588868 val acc: 0.9305 time: 1516.957738161087


### LR with Gradual drop after One-CycleLR (0.12*LEARNING_RATE at 0.8*EPOCHS, final LR 0.01)

In [0]:
model = DavidNet()
batches_per_epoch = len_train//BATCH_SIZE + 1

# lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, EPOCHS], [0, LEARNING_RATE, 0])[0]
# Same shape as the previous experiment, but the knee value is lowered to 0.12*LEARNING_RATE.
lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, int(0.8*EPOCHS), EPOCHS], [0, LEARNING_RATE, 0.12*LEARNING_RATE, 0.01])[0]
global_step = tf.train.get_or_create_global_step()
lr_func = lambda: lr_schedule(global_step/batches_per_epoch)/BATCH_SIZE
opt = tf.train.MomentumOptimizer(lr_func, momentum=MOMENTUM, use_nesterov=True)
data_aug = lambda x, y: (tf.image.random_flip_left_right(tf.random_crop(x, [32, 32, 3])), y)

In [0]:
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    # Note: this rebinds `g` only; `grads` reaches the optimizer without the decay term.
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

epoch: 1 lr: 0.08 train loss: 1.5764755249023437 train acc: 0.43158 val loss: 1.21459326171875 val acc: 0.5723 time: 72.41288375854492


epoch: 2 lr: 0.16 train loss: 0.8421278759765625 train acc: 0.7016 val loss: 0.8049965667724609 val acc: 0.7201 time: 135.02836632728577


epoch: 3 lr: 0.24 train loss: 0.6243019998168945 train acc: 0.78252 val loss: 1.1561366882324218 val acc: 0.648 time: 197.73972821235657


epoch: 4 lr: 0.32 train loss: 0.5557411483764648 train acc: 0.8064 val loss: 0.779678025817871 val acc: 0.748 time: 260.41374683380127


epoch: 5 lr: 0.4 train loss: 0.4701523489379883 train acc: 0.83778 val loss: 0.6792806259155273 val acc: 0.7874 time: 323.30933237075806


epoch: 6 lr: 0.3748571428571429 train loss: 0.3910506282043457 train acc: 0.86578 val loss: 0.4872090270996094 val acc: 0.8359 time: 385.9614064693451


epoch: 7 lr: 0.34971428571428576 train loss: 0.3202196250915527 train acc: 0.88938 val loss: 0.3708888946533203 val acc: 0.8748 time: 448.65349769592285


epoch: 8 lr: 0.3245714285714286 train loss: 0.2651235939025879 train acc: 0.90768 val loss: 0.4671445022583008 val acc: 0.8497 time: 511.33636236190796


epoch: 9 lr: 0.29942857142857143 train loss: 0.2285106739807129 train acc: 0.921 val loss: 0.32114498291015625 val acc: 0.893 time: 573.9800176620483


epoch: 10 lr: 0.2742857142857143 train loss: 0.19616619018554687 train acc: 0.93224 val loss: 0.3261442474365234 val acc: 0.8919 time: 636.9177992343903


epoch: 11 lr: 0.24914285714285717 train loss: 0.1722099678039551 train acc: 0.94106 val loss: 0.3937788101196289 val acc: 0.8741 time: 699.5077781677246


epoch: 12 lr: 0.224 train loss: 0.14444025062561036 train acc: 0.95062 val loss: 0.35516192169189453 val acc: 0.886 time: 762.1145753860474


epoch: 13 lr: 0.19885714285714287 train loss: 0.12169082347869874 train acc: 0.95938 val loss: 0.308295255279541 val acc: 0.9012 time: 824.6815400123596


epoch: 14 lr: 0.17371428571428574 train loss: 0.10519016422271729 train acc: 0.96418 val loss: 0.28563486328125 val acc: 0.9113 time: 887.3065664768219


epoch: 15 lr: 0.14857142857142858 train loss: 0.08541472644805909 train acc: 0.97168 val loss: 0.28142936096191407 val acc: 0.9141 time: 950.2246468067169


epoch: 16 lr: 0.12342857142857144 train loss: 0.0712834981918335 train acc: 0.97704 val loss: 0.29140904655456545 val acc: 0.9093 time: 1012.9946157932281


epoch: 17 lr: 0.09828571428571431 train loss: 0.058730549659729 train acc: 0.98194 val loss: 0.2662386070251465 val acc: 0.9191 time: 1075.5915484428406


epoch: 18 lr: 0.07314285714285718 train loss: 0.04871991079330444 train acc: 0.9858 val loss: 0.2668613090515137 val acc: 0.9187 time: 1138.2927782535553


epoch: 19 lr: 0.048 train loss: 0.04061679504394531 train acc: 0.988 val loss: 0.2701192768096924 val acc: 0.9219 time: 1200.9797172546387


epoch: 20 lr: 0.0404 train loss: 0.03443793603897095 train acc: 0.99098 val loss: 0.2570962371826172 val acc: 0.9243 time: 1263.8208668231964


epoch: 21 lr: 0.0328 train loss: 0.031042065839767458 train acc: 0.99212 val loss: 0.2589907272338867 val acc: 0.9232 time: 1326.501947402954


epoch: 22 lr: 0.0252 train loss: 0.02840428681373596 train acc: 0.99284 val loss: 0.25457475662231444 val acc: 0.9255 time: 1389.2130670547485


epoch: 23 lr: 0.0176 train loss: 0.026847732057571412 train acc: 0.99312 val loss: 0.2547789939880371 val acc: 0.9257 time: 1451.8105771541595


epoch: 24 lr: 0.01 train loss: 0.024859117164611817 train acc: 0.99394 val loss: 0.25448406867980955 val acc: 0.9263 time: 1514.6610689163208


### LR with Gradual drop after One-CycleLR (0.15*LEARNING_RATE at 0.75*EPOCHS, final LR 0.01)

In [0]:
model = DavidNet()
batches_per_epoch = len_train//BATCH_SIZE + 1

# lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, EPOCHS], [0, LEARNING_RATE, 0])[0]
# Knee moved earlier to int(0.75*EPOCHS); knee value set to 0.15*LEARNING_RATE.
lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, int(0.75*EPOCHS), EPOCHS], [0, LEARNING_RATE, 0.15*LEARNING_RATE, 0.01])[0]
global_step = tf.train.get_or_create_global_step()
lr_func = lambda: lr_schedule(global_step/batches_per_epoch)/BATCH_SIZE
opt = tf.train.MomentumOptimizer(lr_func, momentum=MOMENTUM, use_nesterov=True)
data_aug = lambda x, y: (tf.image.random_flip_left_right(tf.random_crop(x, [32, 32, 3])), y)
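
The gradual-drop schedule cells in this notebook differ only in where the knee sits and how far the rate has dropped by then. A small helper, sketched here with hypothetical names (`make_schedule`, `knee_frac`, `knee_scale`) and with `EPOCHS = 24` and `LEARNING_RATE = 0.4` inferred from the logs, puts the three variants side by side.

```python
import numpy as np

EPOCHS = 24           # inferred from the logs
LEARNING_RATE = 0.4   # inferred from the peak logged lr

def make_schedule(knee_frac, knee_scale, final_lr=0.01):
    """Warm up to LEARNING_RATE, drop to knee_scale*LEARNING_RATE at the knee, end at final_lr."""
    knots = [0, (EPOCHS + 1) // 5, int(knee_frac * EPOCHS), EPOCHS]
    values = [0, LEARNING_RATE, knee_scale * LEARNING_RATE, final_lr]
    return lambda t: np.interp([t], knots, values)[0]

# Keys are "knee_frac / knee_scale" for the three schedule cells.
variants = {
    "0.80 / 0.15": make_schedule(0.80, 0.15),
    "0.80 / 0.12": make_schedule(0.80, 0.12),
    "0.75 / 0.15": make_schedule(0.75, 0.15),
}

# Print the rate at the peak, around the knee, and at the final epoch.
for name, schedule in variants.items():
    print(name, [round(float(schedule(e)), 4) for e in (5, 18, 19, 20, 24)])
```

Comparing the printed rows against the `lr` column of each run is a quick way to confirm which schedule a given log was produced with.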

In [0]:
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
    # Note: this rebinds `g` only; `grads` reaches the optimizer without the decay term.
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

epoch: 1 lr: 0.08 train loss: 1.5788017864990234 train acc: 0.42946 val loss: 1.203389111328125 val acc: 0.5766 time: 72.36773324012756


epoch: 2 lr: 0.16 train loss: 0.8321050769042969 train acc: 0.70628 val loss: 1.0963363861083983 val acc: 0.6494 time: 135.07237148284912


epoch: 3 lr: 0.24 train loss: 0.6357986819458008 train acc: 0.777 val loss: 0.8308140075683593 val acc: 0.7331 time: 197.55912518501282


epoch: 4 lr: 0.32 train loss: 0.5418064562988282 train acc: 0.81316 val loss: 1.0762475250244141 val acc: 0.6544 time: 260.1972095966339


epoch: 5 lr: 0.4 train loss: 0.4761518014526367 train acc: 0.83748 val loss: 0.8346518798828125 val acc: 0.74 time: 322.62034583091736


epoch: 6 lr: 0.3738461538461539 train loss: 0.39108003646850586 train acc: 0.86632 val loss: 0.5485706985473633 val acc: 0.823 time: 385.3237543106079


epoch: 7 lr: 0.3476923076923077 train loss: 0.3134155917358398 train acc: 0.89278 val loss: 0.5740720138549805 val acc: 0.8199 time: 448.1067645549774


epoch: 8 lr: 0.32153846153846155 train loss: 0.26708167388916015 train acc: 0.90798 val loss: 0.3996830177307129 val acc: 0.8654 time: 510.67581939697266


epoch: 9 lr: 0.2953846153846154 train loss: 0.22739399909973146 train acc: 0.92128 val loss: 0.4457690155029297 val acc: 0.86 time: 573.2855653762817


epoch: 10 lr: 0.2692307692307693 train loss: 0.19626544540405275 train acc: 0.93234 val loss: 0.3570320068359375 val acc: 0.8832 time: 635.8858783245087


epoch: 11 lr: 0.24307692307692308 train loss: 0.16784945556640626 train acc: 0.94258 val loss: 0.48486279373168945 val acc: 0.8541 time: 698.5416135787964


epoch: 12 lr: 0.21692307692307694 train loss: 0.14427605628967285 train acc: 0.94944 val loss: 0.32397517852783203 val acc: 0.8973 time: 761.2722957134247


epoch: 13 lr: 0.19076923076923077 train loss: 0.12339527519226075 train acc: 0.95632 val loss: 0.40259837188720704 val acc: 0.8775 time: 823.9263849258423


epoch: 14 lr: 0.1646153846153846 train loss: 0.10273787498474121 train acc: 0.9651 val loss: 0.2992329071044922 val acc: 0.9076 time: 886.4212908744812


epoch: 15 lr: 0.13846153846153847 train loss: 0.08814954940795898 train acc: 0.96996 val loss: 0.2978466323852539 val acc: 0.9109 time: 948.9016513824463


epoch: 16 lr: 0.11230769230769233 train loss: 0.07013114463806153 train acc: 0.97776 val loss: 0.28713838500976563 val acc: 0.9127 time: 1011.5895173549652


epoch: 17 lr: 0.08615384615384614 train loss: 0.05834564785003662 train acc: 0.9809 val loss: 0.2678711517333984 val acc: 0.9173 time: 1074.330777168274


epoch: 18 lr: 0.06 train loss: 0.04871126132965088 train acc: 0.98534 val loss: 0.25876396980285643 val acc: 0.9243 time: 1136.9329216480255


epoch: 19 lr: 0.051666666666666666 train loss: 0.0413754478263855 train acc: 0.98828 val loss: 0.2469338535308838 val acc: 0.9256 time: 1199.4503073692322


epoch: 20 lr: 0.043333333333333335 train loss: 0.03665470775604248 train acc: 0.98978 val loss: 0.24693914108276369 val acc: 0.9272 time: 1262.0146005153656


epoch: 21 lr: 0.034999999999999996 train loss: 0.03211457399368286 train acc: 0.99152 val loss: 0.2462463653564453 val acc: 0.9282 time: 1324.5957052707672


epoch: 22 lr: 0.026666666666666665 train loss: 0.03030180227279663 train acc: 0.99172 val loss: 0.24153601722717286 val acc: 0.9294 time: 1387.3312833309174


epoch: 23 lr: 0.018333333333333333 train loss: 0.02680439739227295 train acc: 0.99302 val loss: 0.24117785148620605 val acc: 0.9294 time: 1449.7635388374329


epoch: 24 lr: 0.01 train loss: 0.024386605606079102 train acc: 0.99454 val loss: 0.24342081756591796 val acc: 0.929 time: 1512.3256363868713


### LR with Gradual drop after One-CycleLR (final LR 0)

In [0]:
model = DavidNet()
batches_per_epoch = len_train//BATCH_SIZE + 1

# lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, EPOCHS], [0, LEARNING_RATE, 0])[0]
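# Knots (in epochs): linear warm-up to LEARNING_RATE at (EPOCHS+1)//5, linear decay to
# 0.1*LEARNING_RATE at int(0.8*EPOCHS), then a gradual drop to 0 at EPOCHS.
lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, int(0.8*EPOCHS), EPOCHS], [0, LEARNING_RATE, 0.1*LEARNING_RATE, 0])[0]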
global_step = tf.train.get_or_create_global_step()
lr_func = lambda: lr_schedule(global_step/batches_per_epoch)/BATCH_SIZE
opt = tf.train.MomentumOptimizer(lr_func, momentum=MOMENTUM, use_nesterov=True)
data_aug = lambda x, y: (tf.image.random_flip_left_right(tf.random_crop(x, [32, 32, 3])), y)
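
The schedule in this cell is a piecewise-linear interpolation through four knots. As a rough check, the sketch below evaluates the same `np.interp` expression at whole epochs. `EPOCHS = 24` and `LEARNING_RATE = 0.4` are assumptions inferred from the logged output (24 epochs, peak `lr: 0.4` at epoch 5); neither is defined in this cell.

```python
import numpy as np

# Assumed values, inferred from the training log rather than defined here.
EPOCHS = 24
LEARNING_RATE = 0.4

# Knots: warm-up ends at epoch 5, the one-cycle decay ends at epoch 19,
# and the gradual drop reaches 0 at the last epoch.
knots = [0, (EPOCHS + 1) // 5, int(0.8 * EPOCHS), EPOCHS]
values = [0, LEARNING_RATE, 0.1 * LEARNING_RATE, 0]

def lr_schedule(t):
    """Learning rate at (possibly fractional) epoch t."""
    return float(np.interp(t, knots, values))

schedule = {epoch: lr_schedule(epoch) for epoch in range(1, EPOCHS + 1)}
for epoch in (1, 5, 6, 19, 20, 24):
    print(epoch, round(schedule[epoch], 4))
```

Under these assumptions the rate falls by 0.008 per epoch between epochs 19 and 24, which agrees with the `lr:` column printed by the training loop.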

In [0]:
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
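    # NOTE: tf.Tensor values are immutable, so `g += ...` only rebinds the loop
    # variable `g`; the entries of `grads` reach apply_gradients unchanged.
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE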
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

epoch: 1 lr: 0.08 train loss: 1.5729883239746094 train acc: 0.4318 val loss: 1.4432714904785156 val acc: 0.4921 time: 76.9434449672699


epoch: 2 lr: 0.16 train loss: 0.8328930831909179 train acc: 0.70444 val loss: 0.7488742691040039 val acc: 0.7399 time: 141.28321814537048


epoch: 3 lr: 0.24 train loss: 0.6374128131103516 train acc: 0.77728 val loss: 0.7638113677978515 val acc: 0.7594 time: 205.80347967147827


epoch: 4 lr: 0.32 train loss: 0.5316774468994141 train acc: 0.81504 val loss: 0.7109573867797851 val acc: 0.7683 time: 270.28719830513


epoch: 5 lr: 0.4 train loss: 0.4898508883666992 train acc: 0.83256 val loss: 0.5315122955322266 val acc: 0.8213 time: 334.8328628540039


epoch: 6 lr: 0.37428571428571433 train loss: 0.3926790737915039 train acc: 0.86446 val loss: 0.5660333297729492 val acc: 0.8163 time: 399.37062788009644


epoch: 7 lr: 0.3485714285714286 train loss: 0.31823074935913087 train acc: 0.89 val loss: 0.39684227600097655 val acc: 0.8678 time: 463.84410977363586


epoch: 8 lr: 0.3228571428571429 train loss: 0.2651527754211426 train acc: 0.90756 val loss: 0.34867447052001954 val acc: 0.8774 time: 528.4634234905243


epoch: 9 lr: 0.29714285714285715 train loss: 0.22826515243530274 train acc: 0.9217 val loss: 0.31846856842041016 val acc: 0.8926 time: 592.8896317481995


epoch: 10 lr: 0.27142857142857146 train loss: 0.19554384384155274 train acc: 0.93216 val loss: 0.3610939796447754 val acc: 0.8836 time: 657.3656146526337


epoch: 11 lr: 0.24571428571428575 train loss: 0.16678711334228516 train acc: 0.94172 val loss: 0.3449850326538086 val acc: 0.8891 time: 722.0675446987152


epoch: 12 lr: 0.22000000000000003 train loss: 0.14260851379394532 train acc: 0.95094 val loss: 0.35070321807861327 val acc: 0.8901 time: 786.8998787403107


epoch: 13 lr: 0.1942857142857143 train loss: 0.12410563514709473 train acc: 0.95724 val loss: 0.3295885929107666 val acc: 0.8963 time: 852.7079229354858


epoch: 14 lr: 0.1685714285714286 train loss: 0.10357004867553711 train acc: 0.9651 val loss: 0.2685782764434814 val acc: 0.9169 time: 920.8589396476746


epoch: 15 lr: 0.1428571428571429 train loss: 0.08681803081512451 train acc: 0.9705 val loss: 0.28154357528686524 val acc: 0.9121 time: 989.2394742965698


epoch: 16 lr: 0.11714285714285716 train loss: 0.07077291715621949 train acc: 0.97694 val loss: 0.2554411136627197 val acc: 0.9201 time: 1053.8001987934113


epoch: 17 lr: 0.09142857142857147 train loss: 0.059139689331054686 train acc: 0.982 val loss: 0.26072929000854494 val acc: 0.9224 time: 1121.3430335521698


epoch: 18 lr: 0.06571428571428573 train loss: 0.047883650617599485 train acc: 0.98556 val loss: 0.2574003028869629 val acc: 0.9247 time: 1188.5337998867035


epoch: 19 lr: 0.04000000000000001 train loss: 0.03953889844894409 train acc: 0.98844 val loss: 0.2468906021118164 val acc: 0.928 time: 1257.5414872169495


epoch: 20 lr: 0.03200000000000001 train loss: 0.03466308523178101 train acc: 0.99022 val loss: 0.24620797805786132 val acc: 0.9287 time: 1325.6473724842072


epoch: 21 lr: 0.024000000000000004 train loss: 0.0329012788772583 train acc: 0.99124 val loss: 0.24076818771362304 val acc: 0.9288 time: 1394.795580148697


epoch: 22 lr: 0.016 train loss: 0.030442327222824098 train acc: 0.99174 val loss: 0.24234764404296874 val acc: 0.9309 time: 1463.0227556228638


epoch: 23 lr: 0.008 train loss: 0.028436120920181274 train acc: 0.99256 val loss: 0.23991000442504884 val acc: 0.9301 time: 1531.7265470027924


epoch: 24 lr: 0.0 train loss: 0.027072189655303953 train acc: 0.99278 val loss: 0.24007553482055663 val acc: 0.9306 time: 1599.6102633476257
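
The training loop adds `v * WEIGHT_DECAY * BATCH_SIZE` to each gradient with `g += ...`. Because TensorFlow tensors are immutable, that statement rebinds the loop variable rather than modifying the list handed to `opt.apply_gradients`. The sketch below illustrates the same Python behaviour with plain floats standing in for tensors; the numeric values of `WEIGHT_DECAY` and `BATCH_SIZE` are made up for the illustration and are not taken from the notebook.

```python
# Plain floats stand in for gradient and weight tensors; like tf.Tensor,
# they are immutable, so `+=` builds a new object instead of mutating.
WEIGHT_DECAY = 5e-4   # assumed value, for illustration only
BATCH_SIZE = 512      # assumed value, for illustration only

grads = [0.10, -0.20]
weights = [1.0, 2.0]

# Pattern used in the loop: `g` is rebound, `grads` is left untouched.
for g, v in zip(grads, weights):
    g += v * WEIGHT_DECAY * BATCH_SIZE
unchanged = list(grads)

# Alternative: build a new list so the decay term is part of the gradients.
decayed = [g + v * WEIGHT_DECAY * BATCH_SIZE for g, v in zip(grads, weights)]

print(unchanged)
print(decayed)
```

If weight decay is meant to take effect, a list built this way could be passed to `opt.apply_gradients` instead. The runs logged in this notebook used the original loop.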


### LR with Gradual drop after One-CycleLR (final LR 0.009)

In [0]:
model = DavidNet()
batches_per_epoch = len_train//BATCH_SIZE + 1

# lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, EPOCHS], [0, LEARNING_RATE, 0])[0]
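# Knots (in epochs): linear warm-up to LEARNING_RATE at (EPOCHS+1)//5, linear decay to
# 0.1*LEARNING_RATE at int(0.8*EPOCHS), then a gradual drop to 0.009 at EPOCHS.
lr_schedule = lambda t: np.interp([t], [0, (EPOCHS+1)//5, int(0.8*EPOCHS), EPOCHS], [0, LEARNING_RATE, 0.1*LEARNING_RATE, 0.009])[0]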
global_step = tf.train.get_or_create_global_step()
lr_func = lambda: lr_schedule(global_step/batches_per_epoch)/BATCH_SIZE
opt = tf.train.MomentumOptimizer(lr_func, momentum=MOMENTUM, use_nesterov=True)
data_aug = lambda x, y: (tf.image.random_flip_left_right(tf.random_crop(x, [32, 32, 3])), y)
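
Compared with the previous section, only the last knot of the schedule changes: the rate ends at 0.009 instead of 0. The sketch below shows what that does to the final five epochs. It assumes `EPOCHS = 24` and `LEARNING_RATE = 0.4` (inferred from the log, not defined in this cell), and `tail` is a hypothetical helper introduced only for this comparison.

```python
import numpy as np

EPOCHS = 24          # assumed, inferred from the log
LEARNING_RATE = 0.4  # assumed, inferred from the log

def tail(final_lr):
    """LR at epochs 20..EPOCHS for a schedule whose last knot is final_lr."""
    knots = [0, (EPOCHS + 1) // 5, int(0.8 * EPOCHS), EPOCHS]
    values = [0, LEARNING_RATE, 0.1 * LEARNING_RATE, final_lr]
    return [float(np.interp(e, knots, values)) for e in range(20, EPOCHS + 1)]

to_zero = tail(0.0)     # last knot of the previous section's schedule
to_floor = tail(0.009)  # last knot used in this section

for epoch, a, b in zip(range(20, EPOCHS + 1), to_zero, to_floor):
    print(epoch, round(a, 4), round(b, 4))
```

The two schedules are identical through epoch 19 and differ only in the slope of the last segment.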

In [0]:
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
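    # NOTE: tf.Tensor values are immutable, so `g += ...` only rebinds the loop
    # variable `g`; the entries of `grads` reach apply_gradients unchanged.
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE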
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

epoch: 1 lr: 0.08 train loss: 1.5756401922607421 train acc: 0.4288 val loss: 1.072593603515625 val acc: 0.6101 time: 78.34312701225281


epoch: 2 lr: 0.16 train loss: 0.8382053173828125 train acc: 0.70194 val loss: 0.813119320678711 val acc: 0.7272 time: 143.32983422279358


epoch: 3 lr: 0.24 train loss: 0.6188879544067383 train acc: 0.78582 val loss: 1.0661950439453125 val acc: 0.6733 time: 208.1281681060791


epoch: 4 lr: 0.32 train loss: 0.540421672668457 train acc: 0.81242 val loss: 0.6837811279296875 val acc: 0.7789 time: 273.1088569164276


epoch: 5 lr: 0.4 train loss: 0.4646663754272461 train acc: 0.83982 val loss: 0.6191536453247071 val acc: 0.7981 time: 337.9119203090668


epoch: 6 lr: 0.37428571428571433 train loss: 0.38141759872436526 train acc: 0.87 val loss: 0.5959481643676758 val acc: 0.8057 time: 402.9627003669739


epoch: 7 lr: 0.3485714285714286 train loss: 0.31145608673095704 train acc: 0.8913 val loss: 0.504211962890625 val acc: 0.8315 time: 467.8819634914398


epoch: 8 lr: 0.3228571428571429 train loss: 0.26613399078369143 train acc: 0.90796 val loss: 0.341609587097168 val acc: 0.8855 time: 532.795037984848


epoch: 9 lr: 0.29714285714285715 train loss: 0.22351856262207032 train acc: 0.92384 val loss: 0.3723988395690918 val acc: 0.8763 time: 597.7378373146057


epoch: 10 lr: 0.27142857142857146 train loss: 0.1908758380126953 train acc: 0.93514 val loss: 0.33886055297851564 val acc: 0.8923 time: 662.615700006485


epoch: 11 lr: 0.24571428571428575 train loss: 0.16317631790161133 train acc: 0.94406 val loss: 0.324035652923584 val acc: 0.8971 time: 727.7112045288086


epoch: 12 lr: 0.22000000000000003 train loss: 0.13779161315917968 train acc: 0.95316 val loss: 0.40785554885864256 val acc: 0.8804 time: 792.553325176239


epoch: 13 lr: 0.1942857142857143 train loss: 0.12154374671936036 train acc: 0.95764 val loss: 0.33456083297729494 val acc: 0.8985 time: 857.3681025505066


epoch: 14 lr: 0.1685714285714286 train loss: 0.10104073097229004 train acc: 0.96656 val loss: 0.29609530792236327 val acc: 0.9086 time: 922.184987783432


epoch: 15 lr: 0.1428571428571429 train loss: 0.08260379043579101 train acc: 0.97284 val loss: 0.2751507583618164 val acc: 0.9162 time: 987.0241296291351


epoch: 16 lr: 0.11714285714285716 train loss: 0.07028454216003419 train acc: 0.97676 val loss: 0.2799189170837402 val acc: 0.9149 time: 1052.0392711162567


epoch: 17 lr: 0.09142857142857147 train loss: 0.05643450017929077 train acc: 0.98222 val loss: 0.27646307907104495 val acc: 0.9195 time: 1116.9330220222473


epoch: 18 lr: 0.06571428571428573 train loss: 0.04643728038787842 train acc: 0.98574 val loss: 0.2540622833251953 val acc: 0.9257 time: 1181.7840478420258


epoch: 19 lr: 0.04000000000000001 train loss: 0.037135019130706784 train acc: 0.98966 val loss: 0.2561089427947998 val acc: 0.9232 time: 1246.557298898697


epoch: 20 lr: 0.033800000000000004 train loss: 0.03359828550338745 train acc: 0.99082 val loss: 0.2505376678466797 val acc: 0.9274 time: 1311.498485326767


epoch: 21 lr: 0.027600000000000006 train loss: 0.030669383087158202 train acc: 0.992 val loss: 0.2521941390991211 val acc: 0.9256 time: 1376.3941004276276


epoch: 22 lr: 0.021400000000000002 train loss: 0.028361185398101806 train acc: 0.99278 val loss: 0.2522179626464844 val acc: 0.9264 time: 1441.2829911708832


epoch: 23 lr: 0.015200000000000002 train loss: 0.026790784969329833 train acc: 0.99298 val loss: 0.25317908096313474 val acc: 0.9262 time: 1506.036226272583


epoch: 24 lr: 0.009 train loss: 0.026615504913330076 train acc: 0.9934 val loss: 0.2504736694335937 val acc: 0.9287 time: 1570.8621892929077
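
Each line printed by the training loop follows the same `name: value` layout, and its `time:` field is cumulative since the loop started. The sketch below is one possible way to turn such lines into records, pick the epoch with the highest validation accuracy, and recover a per-epoch duration. The two sample lines are copied from the output above; `FIELDS`, `PATTERN` and `parse_log` are hypothetical names, not part of the notebook.

```python
import re

# Two lines copied verbatim from the run above.
log = """\
epoch: 23 lr: 0.015200000000000002 train loss: 0.026790784969329833 train acc: 0.99298 val loss: 0.25317908096313474 val acc: 0.9262 time: 1506.036226272583
epoch: 24 lr: 0.009 train loss: 0.026615504913330076 train acc: 0.9934 val loss: 0.2504736694335937 val acc: 0.9287 time: 1570.8621892929077
"""

FIELDS = ["epoch", "lr", "train loss", "train acc", "val loss", "val acc", "time"]
# One "name: number" pair per field, in the order the loop prints them.
PATTERN = re.compile(r"\s+".join(rf"{name}: (\S+)" for name in FIELDS))

def parse_log(text):
    """Return one dict of floats per line that matches the log layout."""
    rows = []
    for line in text.splitlines():
        match = PATTERN.match(line.strip())
        if match:
            rows.append(dict(zip(FIELDS, map(float, match.groups()))))
    return rows

rows = parse_log(log)
best = max(rows, key=lambda row: row["val acc"])
# `time` is cumulative, so consecutive differences give seconds per epoch.
epoch_seconds = rows[1]["time"] - rows[0]["time"]

print(int(best["epoch"]), best["val acc"])
print(round(epoch_seconds, 1))
```

Lines that do not match the layout are skipped, so the full cell output can be pasted in as-is.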


### Added Augmentation

In [0]:
# GlobalAvgPool
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
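    # NOTE: tf.Tensor values are immutable, so `g += ...` only rebinds the loop
    # variable `g`; the entries of `grads` reach apply_gradients unchanged.
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE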
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

epoch: 1 lr: 0.08 train loss: 1.5032733990478515 train acc: 0.45242 val loss: 1.0886321594238282 val acc: 0.6188 time: 111.50436234474182


epoch: 2 lr: 0.16 train loss: 0.8530325842285156 train acc: 0.69548 val loss: 0.9529444549560547 val acc: 0.6794 time: 199.9768569469452


epoch: 3 lr: 0.24 train loss: 0.6533779327392578 train acc: 0.77246 val loss: 0.7283566802978516 val acc: 0.7649 time: 288.34793758392334


epoch: 4 lr: 0.32 train loss: 0.5501688134765625 train acc: 0.80898 val loss: 0.9207446472167968 val acc: 0.7194 time: 376.78658080101013


epoch: 5 lr: 0.4 train loss: 0.5000022366333008 train acc: 0.82862 val loss: 0.774743960571289 val acc: 0.7654 time: 465.24572944641113


epoch: 6 lr: 0.37428571428571433 train loss: 0.4241369676208496 train acc: 0.85384 val loss: 0.5882711288452148 val acc: 0.8151 time: 553.882570028305


epoch: 7 lr: 0.3485714285714286 train loss: 0.3514873225402832 train acc: 0.87836 val loss: 0.480287336730957 val acc: 0.8378 time: 642.6105444431305


epoch: 8 lr: 0.3228571428571429 train loss: 0.3025240234375 train acc: 0.89442 val loss: 0.5253153778076172 val acc: 0.8291 time: 731.2188715934753


epoch: 9 lr: 0.29714285714285715 train loss: 0.25969407653808596 train acc: 0.91084 val loss: 0.40791345901489257 val acc: 0.8676 time: 819.8148202896118


epoch: 10 lr: 0.27142857142857146 train loss: 0.22666076049804687 train acc: 0.92178 val loss: 0.3468249656677246 val acc: 0.8876 time: 908.3380837440491


epoch: 11 lr: 0.24571428571428575 train loss: 0.19832207611083985 train acc: 0.93092 val loss: 0.4129452537536621 val acc: 0.8729 time: 996.9089481830597


epoch: 12 lr: 0.22000000000000003 train loss: 0.1753183155822754 train acc: 0.93794 val loss: 0.3450341255187988 val acc: 0.8932 time: 1085.5193700790405


epoch: 13 lr: 0.1942857142857143 train loss: 0.1516726131439209 train acc: 0.94594 val loss: 0.33399319000244143 val acc: 0.895 time: 1174.4486730098724


epoch: 14 lr: 0.1685714285714286 train loss: 0.13029390907287597 train acc: 0.9563 val loss: 0.33918890533447266 val acc: 0.8972 time: 1263.0140419006348


epoch: 15 lr: 0.1428571428571429 train loss: 0.10755028312683106 train acc: 0.96312 val loss: 0.3546448657989502 val acc: 0.8932 time: 1352.588057756424


epoch: 16 lr: 0.11714285714285716 train loss: 0.09568681179046631 train acc: 0.96702 val loss: 0.29890195655822754 val acc: 0.9109 time: 1442.2459874153137


epoch: 17 lr: 0.09142857142857147 train loss: 0.07679723133087159 train acc: 0.97332 val loss: 0.29734925079345703 val acc: 0.9145 time: 1531.7264902591705


epoch: 18 lr: 0.06571428571428573 train loss: 0.06047877330780029 train acc: 0.97972 val loss: 0.2882490333557129 val acc: 0.9166 time: 1621.3573637008667


epoch: 19 lr: 0.04000000000000001 train loss: 0.0487181545829773 train acc: 0.98378 val loss: 0.2813457649230957 val acc: 0.9191 time: 1710.9448218345642


epoch: 20 lr: 0.03300000000000001 train loss: 0.039510198631286624 train acc: 0.9875 val loss: 0.2763801441192627 val acc: 0.9228 time: 1800.6386814117432


epoch: 21 lr: 0.026000000000000002 train loss: 0.03616086312294006 train acc: 0.98838 val loss: 0.2785358768463135 val acc: 0.9226 time: 1890.3359158039093


epoch: 22 lr: 0.019000000000000003 train loss: 0.03245591863632202 train acc: 0.98996 val loss: 0.2790522232055664 val acc: 0.9242 time: 1979.8410873413086


epoch: 23 lr: 0.012 train loss: 0.030312863445281984 train acc: 0.99022 val loss: 0.27781903839111327 val acc: 0.9227 time: 2069.5606100559235


epoch: 24 lr: 0.005 train loss: 0.027942822637557985 train acc: 0.99184 val loss: 0.27924166564941405 val acc: 0.9236 time: 2159.18039393425


### Added Augmentation

In [0]:
# GlobalAvgPool
t = time.time()
test_set = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

for epoch in range(EPOCHS):
  train_loss = test_loss = train_acc = test_acc = 0.0
  train_set = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(data_aug).shuffle(len_train).batch(BATCH_SIZE).prefetch(1)

  tf.keras.backend.set_learning_phase(1)
  for (x, y) in tqdm(train_set):
    with tf.GradientTape() as tape:
      loss, correct = model(x, y)

    var = model.trainable_variables
    grads = tape.gradient(loss, var)
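    # NOTE: tf.Tensor values are immutable, so `g += ...` only rebinds the loop
    # variable `g`; the entries of `grads` reach apply_gradients unchanged.
    for g, v in zip(grads, var):
      g += v * WEIGHT_DECAY * BATCH_SIZE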
    opt.apply_gradients(zip(grads, var), global_step=global_step)

    train_loss += loss.numpy()
    train_acc += correct.numpy()

  tf.keras.backend.set_learning_phase(0)
  for (x, y) in test_set:
    loss, correct = model(x, y)
    test_loss += loss.numpy()
    test_acc += correct.numpy()
    
  print('epoch:', epoch+1, 'lr:', lr_schedule(epoch+1), 'train loss:', train_loss / len_train, 'train acc:', train_acc / len_train, 'val loss:', test_loss / len_test, 'val acc:', test_acc / len_test, 'time:', time.time() - t)

epoch: 1 lr: 0.08 train loss: 1.709021552734375 train acc: 0.3895 val loss: 1.326011083984375 val acc: 0.5089 time: 87.73148798942566


epoch: 2 lr: 0.16 train loss: 1.1988032983398438 train acc: 0.57196 val loss: 1.2408642211914063 val acc: 0.5604 time: 175.50283885002136


epoch: 3 lr: 0.24 train loss: 0.952559672241211 train acc: 0.66252 val loss: 0.8901464141845703 val acc: 0.685 time: 263.3326745033264


epoch: 4 lr: 0.32 train loss: 0.8045873928833008 train acc: 0.7156 val loss: 0.8165519332885742 val acc: 0.7088 time: 351.06829285621643


epoch: 5 lr: 0.4 train loss: 0.6980540258789063 train acc: 0.755 val loss: 0.8282282669067382 val acc: 0.707 time: 438.72689962387085


epoch: 6 lr: 0.37428571428571433 train loss: 0.6276732815551758 train acc: 0.78092 val loss: 0.7505364364624023 val acc: 0.7312 time: 526.4738259315491


epoch: 7 lr: 0.3485714285714286 train loss: 0.5655119232177734 train acc: 0.8038 val loss: 0.6257189300537109 val acc: 0.7809 time: 614.2550461292267


epoch: 8 lr: 0.3228571428571429 train loss: 0.5163136071777343 train acc: 0.82106 val loss: 0.6185842498779297 val acc: 0.79 time: 702.1466720104218


epoch: 9 lr: 0.29714285714285715 train loss: 0.4749776596069336 train acc: 0.83608 val loss: 0.5101982711791992 val acc: 0.8235 time: 789.7827477455139


epoch: 10 lr: 0.27142857142857146 train loss: 0.44279710052490234 train acc: 0.84754 val loss: 0.5605245239257812 val acc: 0.8128 time: 877.6319212913513


epoch: 11 lr: 0.24571428571428575 train loss: 0.40757338897705075 train acc: 0.85904 val loss: 0.5304155685424805 val acc: 0.8176 time: 965.4076209068298


epoch: 12 lr: 0.22000000000000003 train loss: 0.3824380662536621 train acc: 0.86852 val loss: 0.5214053207397461 val acc: 0.8171 time: 1053.18577003479


epoch: 13 lr: 0.1942857142857143 train loss: 0.3573675692749023 train acc: 0.87696 val loss: 0.5547255889892578 val acc: 0.8178 time: 1140.8264727592468


epoch: 14 lr: 0.1685714285714286 train loss: 0.33890083984375 train acc: 0.88392 val loss: 0.4615213256835938 val acc: 0.8449 time: 1228.611144542694


epoch: 15 lr: 0.1428571428571429 train loss: 0.3162707019042969 train acc: 0.89188 val loss: 0.45866693572998046 val acc: 0.8458 time: 1316.271376132965


epoch: 16 lr: 0.11714285714285716 train loss: 0.29655705139160154 train acc: 0.89978 val loss: 0.47494787979125974 val acc: 0.8375 time: 1404.0585067272186


epoch: 17 lr: 0.09142857142857147 train loss: 0.2806854797363281 train acc: 0.90312 val loss: 0.4819341247558594 val acc: 0.8372 time: 1492.2008607387543


epoch: 18 lr: 0.06571428571428573 train loss: 0.26486497299194334 train acc: 0.90944 val loss: 0.4986314895629883 val acc: 0.8379 time: 1579.9963948726654


epoch: 19 lr: 0.04000000000000001 train loss: 0.25010818161010745 train acc: 0.91428 val loss: 0.46876868438720704 val acc: 0.8426 time: 1667.8314554691315


epoch: 20 lr: 0.03300000000000001 train loss: 0.2347546627807617 train acc: 0.91928 val loss: 0.49122921142578124 val acc: 0.8374 time: 1755.69038772583


epoch: 21 lr: 0.026000000000000002 train loss: 0.22667024085998536 train acc: 0.92264 val loss: 0.42604240493774415 val acc: 0.8596 time: 1843.4422998428345


epoch: 22 lr: 0.019000000000000003 train loss: 0.21167442672729492 train acc: 0.92786 val loss: 0.5472040420532227 val acc: 0.8263 time: 1931.0193729400635


epoch: 23 lr: 0.012 train loss: 0.20096443725585939 train acc: 0.93186 val loss: 0.4120720420837402 val acc: 0.866 time: 2018.7962889671326


epoch: 24 lr: 0.005 train loss: 0.19094788940429688 train acc: 0.93478 val loss: 0.4139376693725586 val acc: 0.8664 time: 2106.876798391342


### Backup