Imports

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from tensorflow.keras.datasets import cifar10 as cifar
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.regularizers import l2
from tensorflow.keras import mixed_precision
from tensorflow.keras import backend as K
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
from numpy import random as npr
from tqdm import tqdm
import tensorflow as tf
import numpy as np
import time, sys, gc

Accelerator detection (TPU or GPU) and mixed-precision setup

In [None]:
# Prefer a TPU when one is reachable; otherwise fall back to GPU training.
# `useMixed` doubles as the "running on TPU" flag: bfloat16 mixed precision and
# `tpu_strategy` are only used on that path.
useMixed = True
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    tpu_strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception:  # no TPU reachable; a bare `except:` would also swallow KeyboardInterrupt
    useMixed = False
    gpus = tf.config.experimental.list_physical_devices('GPU')
    # Enable on-demand memory allocation on every GPU (TF requires the setting to
    # be consistent across GPUs); the loop is a no-op on CPU-only hosts, where
    # indexing gpus[0] used to raise IndexError.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
# bfloat16 compute (with float32 variables) is only enabled on the TPU path.
if useMixed:
    print('---MIXED PRECISION TRAINING---')
    mixed_precision.set_global_policy(policy='mixed_bfloat16')

Hyperparameters

In [None]:
# Number of training examples used (train_x[:m]).
m = 50000
# Larger global batch on the TPU / mixed-precision path, smaller otherwise.
batchSize = 1024 if useMixed else 256

epochs = 300
steps_per_epoch = 1000
# Adaptive augmentation: `p` is the per-transform application probability used by
# `aug`, and `pStep` is how much `train` nudges it after each validation check.
p = 0.0
pStep = 1e-2

# Seed NumPy (batch index sampling via `npr`) and TensorFlow (augmentation,
# stochastic depth, initializers) so runs are repeatable. Accelerator kernels
# may still introduce some nondeterminism.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

Load and normalize the data; define the augmentation function

In [None]:
(train_x, train_y), (test_x, test_y) = cifar.load_data()

def normalize(batch, mean=None, std=None, eps=1e-7):
    """Standardize an image array per channel.

    Args:
        batch: image array; statistics are reduced over axes (0, 1, 2), so the
            last axis is treated as the channel axis.
        mean, std: optional precomputed per-channel statistics (e.g. from the
            training split). Each one that is omitted is computed from `batch`.
        eps: lower bound on `std`, so a constant channel cannot divide by zero.

    Returns:
        float32 array with the same shape as `batch`.
    """
    batch = np.asarray(batch, dtype=np.float32)
    if mean is None:
        mean = batch.mean(axis=(0, 1, 2), keepdims=True)
    if std is None:
        std = batch.std(axis=(0, 1, 2), keepdims=True)
    return (batch - mean) / np.maximum(std, eps)

# Standardize both splits with statistics from the (possibly truncated) training
# split. Previously the test split was standardized with its own statistics, so
# train and test inputs went through different transforms.
train_x = np.asarray(train_x[:m], dtype=np.float32)
trainMean = train_x.mean(axis=(0, 1, 2), keepdims=True)
trainStd = train_x.std(axis=(0, 1, 2), keepdims=True)
train_x = normalize(train_x, trainMean, trainStd)
test_x = normalize(test_x, trainMean, trainStd)
# One-hot encode before slicing so the class count comes from the full label set.
train_y = tf.keras.utils.to_categorical(train_y)[:m]
test_y = tf.keras.utils.to_categorical(test_y)

In [None]:
def aug(imgs, p):
    """Randomly augment an image batch, applying each transform per-sample with probability `p`.

    Transforms are chained in this order:
    1. brightness shift;
    2. square crop-and-resize;
    3. rectangular crop-and-resize;
    4. random multiple-of-90-degree rotation;
    5. small rotation (within +/- 30 degrees);
    6. random translation.

    Each transform makes an independent uniform draw per image. That draw decides
    whether the image takes the transformed version or keeps its current one.

    Args:
        imgs: 4-D image batch (as required by tf.image.crop_and_resize); crops
            are resized back to 32x32.
        p: probability of applying each individual transform.

    Returns:
        float32 tensor holding the augmented batch.
    """
    imgSize = 32
    augImgs = tf.cast(imgs, tf.float32)
    # Batch size comes from the input rather than the global `batchSize`, so any
    # batch size can be augmented.
    n = tf.shape(augImgs)[0]

    def augCond(transformed, current):
        """Per-sample select: `transformed` with probability p, otherwise `current`."""
        apply = tf.cast(tf.random.uniform((n, 1, 1, 1)) < p, tf.float32)
        return transformed * apply + current * (1 - apply)

    # crop_and_resize boxes are normalized [y1, x1, y2, x2].
    height = tf.random.uniform((), minval=0.5, maxval=1)
    width = tf.random.uniform((), minval=0.5, maxval=1)
    # Rectangular crop: fixed (height, width) fraction, random top-left corner per image.
    corner = tf.random.uniform((n, 2)) * tf.stack([1 - height, 1 - width])
    boxes = tf.concat([corner, corner + tf.stack([height, width])], axis=1)
    # Square crop: side `height` on both axes, with its own random top-left corner.
    # Previously the corner came from the rectangular crop while the far edge used
    # a different random offset, which produced degenerate or flipped boxes.
    cornerIso = tf.random.uniform((n, 2)) * (1 - height)
    boxIso = tf.concat([cornerIso, cornerIso + height], axis=1)
    boxInds = tf.range(n)

    quarterTurns = tf.random.uniform((n,), minval=0, maxval=4, dtype=tf.int32)
    rot90s = tf.cast(quarterTurns, tf.float32) * (np.pi / 2)
    smallAngles = tf.random.uniform((n,), minval=-np.pi/6, maxval=np.pi/6)
    shifts = tf.random.normal((n, 2), 0, imgSize // 10)

    # NOTE: random_brightness draws a single delta for the whole batch.
    augImgs = augCond(tf.image.random_brightness(augImgs, max_delta=0.25), augImgs)
    augImgs = augCond(tf.image.crop_and_resize(augImgs, boxIso, boxInds, (imgSize, imgSize), extrapolation_value=1), augImgs)
    augImgs = augCond(tf.image.crop_and_resize(augImgs, boxes, boxInds, (imgSize, imgSize), extrapolation_value=1), augImgs)
    augImgs = augCond(tfa.image.rotate(augImgs, rot90s, fill_mode='reflect'), augImgs)
    augImgs = augCond(tfa.image.rotate(augImgs, smallAngles, fill_mode='reflect'), augImgs)
    augImgs = augCond(tfa.image.translate(augImgs, shifts, fill_mode='reflect'), augImgs)
    return augImgs

Custom Layers

In [None]:
# Shared initializers and weight regularizer used as defaults by the custom layers.
ndist = tf.random_normal_initializer(0, 1)
lecun = tf.keras.initializers.LecunNormal()
zeros = tf.zeros_initializer()
ones = tf.ones_initializer()
l2 = tf.keras.regularizers.L2(l2=2e-5)  # rebinds the `l2` name imported at the top

# Scaled ReLU constants: relu(x) * sqrt(2 / (1 - 1/pi)). On the mixed-precision
# path they are cast to bfloat16, presumably to match the bfloat16 layer compute
# dtype set by the global policy.
castZero, magic = (
    tf.convert_to_tensor(value)
    for value in (0.0, (2 / (1 - (1 / np.pi))) ** 0.5)
)
if useMixed:
    castZero, magic = tf.cast(castZero, tf.bfloat16), tf.cast(magic, tf.bfloat16)

def _scaledRelu(x):
    return tf.maximum(castZero, x) * magic

activation = Lambda(_scaledRelu)
 
class WSConv(Conv2D):
    """Weight-standardized 2D convolution with a learnable per-output-channel gain.

    On every call, the kernel is standardized over its fan-in axes to zero mean
    and variance 1/fan_in, then multiplied by `gain`. The layer then convolves,
    adds the bias, and applies the module-level `activation` when `relu` is True.

    Args:
        units: number of output filters.
        kernel_size: window size; non-1x1 kernels use grouped convolution.
        kernel_initializer, bias_initializer, kernel_regularizer, padding: as in Conv2D.
        relu: apply the module-level `activation` to the output.
        groupWidth: channels per group for non-1x1 kernels (keyword-only). When
            omitted it is `scale // 2`, read from the notebook-level `scale` as before.
    """
    def __init__(self, units, kernel_size=3, kernel_initializer=ndist, bias_initializer=zeros, kernel_regularizer=l2, padding='same', relu=True, *args, groupWidth=None, **kwargs):
        super().__init__(units, kernel_size, *args, **kwargs)
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.kernel_regularizer = kernel_regularizer
        self.padding = padding
        if kernel_size != 1:
            if groupWidth is None:
                groupWidth = scale // 2  # notebook global set in the model-creation cell
            # Never 0 groups: Conv2D.build divides the input channels by `groups`.
            self.groups = max(1, units // max(1, groupWidth))
        self.relu = relu
        self.scale = 1

    def build(self, input_shape):
        super().build(input_shape)
        # Product of every kernel dim except the output-channel one (kernel_x, kernel_y, features_in).
        self.fan_in = np.prod(self.kernel.shape[:-1])
        self.gain = self.add_weight(name='gain', shape=(self.kernel.shape[-1],), initializer='ones', trainable=True)

    def call(self, inputs):
        mean = tf.math.reduce_mean(self.kernel, axis=(0, 1, 2), keepdims=True)
        var = tf.math.reduce_variance(self.kernel, axis=(0, 1, 2), keepdims=True)
        # Standardized kernel; the 1e-8 floor avoids rsqrt(0) for a constant filter.
        k = self.gain * (self.kernel - mean) * tf.math.rsqrt(tf.maximum(var * self.fan_in, 1e-8))
        # NOTE(review): with groups > 1 the kernel's input depth is a fraction of the
        # input channels; this relies on conv2d's grouped-convolution support on the
        # target device — confirm for CPU runs.
        output = K.conv2d(inputs, k, padding=self.padding, strides=self.strides)
        output = K.bias_add(output, self.bias)

        if self.relu:
            output = activation(output)
        return output
    
class SE(Layer):
    """Squeeze-and-excitation: rescale channels by a gate computed from their global mean.

    Args:
        nf: number of input channels (and width of the gate).
        se_ratio: bottleneck width as a fraction of `nf`.
        activation: nonlinearity between the two dense layers.
        name: optional layer name.
    """
    def __init__(self, nf, se_ratio=0.5, activation=activation, name=None):
        super(SE, self).__init__(name=name)
        self.activation = activation
        # At least one hidden unit, so a tiny `nf * se_ratio` cannot build a 0-width Dense.
        self.fc0 = Dense(max(1, int(se_ratio * nf)), kernel_initializer=lecun, kernel_regularizer=l2, use_bias=True)
        self.fc1 = Dense(nf, kernel_initializer=lecun, kernel_regularizer=l2, use_bias=True)

    def call(self, x):
        # Squeeze: average over axes 1 and 2.
        h = tf.math.reduce_mean(x, axis=[1, 2])
        # Use the activation given to the constructor (it was previously stored but
        # ignored in favour of the global).
        h = self.fc1(self.activation(self.fc0(h)))
        h = tf.keras.activations.sigmoid(h)[:, None, None]
        # The factor 2 presumably offsets sigmoid gates centred around 0.5.
        return 2 * x * h

class StochDepth(Layer):
    """Per-sample stochastic depth for a residual branch.

    `call` returns `(x * mask, mask)`. The mask has shape (batch, 1, 1, 1). During
    training each entry is 1 with probability `1 - dropRate` and 0 otherwise.
    Outside training it is all ones and `x` passes through unchanged. The mask is
    returned so the caller can account for dropped branches.

    NOTE(review): kept branches are not rescaled by 1 / (1 - dropRate), unlike
    common stochastic-depth implementations. The expected branch output therefore
    differs between training and inference by that factor.
    """
    def __init__(self, dropRate=0.25):
        super(StochDepth, self).__init__()
        self.dropRate = dropRate

    def call(self, x, training=None):
        # `training=None` default: Keras fills it in from the call context, and a
        # direct call without it no longer raises TypeError (treated as inference).
        batchSize = tf.shape(x)[0]
        maskShape = [batchSize, 1, 1, 1]

        if not training:
            return x, tf.ones(maskShape, dtype=x.dtype)

        keepRate = 1 - self.dropRate
        # floor(keepRate + U[0, 1)) is 1 with probability keepRate, else 0.
        keepInds = tf.floor(keepRate + tf.random.uniform(shape=maskShape, dtype=x.dtype))
        return x * keepInds, keepInds

class NFBlock(Layer):
    """Normalizer-free bottleneck residual block with explicit variance tracking.

    Called on `[x, var]`, where `var` is the running variance estimate of `x`.
    Returns `(out, newVar)`.

    Args:
        units: output channels (the bottleneck convs use units // 2).
        stochDepth: drop rate for the residual branch.
        strides: stride of the 3x3 conv; 2 also average-pools the shortcut.
        skipconv: transition block. The input is standardized and activated
            before both paths, the shortcut is projected with a 1x1 WSConv, and
            the variance estimate restarts from 1.
        alpha: scale applied to the residual branch.
    """
    def __init__(self, units, stochDepth, strides=1, skipconv=False, alpha=0.2):
        super(NFBlock, self).__init__()
        self.alpha = alpha
        self.strides = strides
        self.skipconv = skipconv
        if self.skipconv:
            self.skipConv = WSConv(units, 1, relu=False)
        # Created once here instead of instantiating a new layer on every call.
        self.pool = AveragePooling2D() if strides == 2 else None

        self.conv1 = WSConv(units // 2, 1)
        self.conv2 = WSConv(units // 2, strides=strides)
        self.conv3 = WSConv(units // 2)
        self.conv4 = WSConv(units, 1, relu=False)
        self.se = SE(units)
        self.stochDepth = StochDepth(stochDepth)

    def build(self, input_shape):
        super().build(input_shape)
        # SkipInit-style gain, starting at 0 so the residual branch is initially off.
        self.skipInitGain = self.add_weight(name='skip_init_gain', shape=(), trainable=True, initializer='zeros')

    def call(self, inputs):
        out, var = inputs
        beta = var ** 0.5

        if self.skipconv:
            # Transition block: standardize (variance -> 1) and activate before both paths.
            out = activation(out * (1 / beta))

        skip = out
        if self.pool is not None:
            skip = self.pool(skip)
        if self.skipconv:
            skip = self.skipConv(skip)
        else:
            # Identity shortcut keeps the raw input; only the residual branch is
            # standardized and activated. This used to be keyed on strides == 1,
            # which activated the stride-1 transition block's input a second time.
            out = activation(out * (1 / beta))

        out = self.conv1(out)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)

        out = self.se(out) * self.alpha * self.skipInitGain
        out, stochSkips = self.stochDepth(out)
        out = skip + out

        # Variance added by this block's (possibly dropped) branch. `alpha` was an
        # undefined name here (NameError); it is the instance attribute.
        branchVar = stochSkips * (self.alpha * self.skipInitGain) ** 2
        if self.skipconv:
            var = 1 + branchVar
        else:
            var = var + branchVar
        return out, var

Model definition

In [None]:
def build(units=(1, 2, 6, 6), repeats=(1, 2, 6, 3), scale=64, repeatScale=1, numClasses=10):
    """Build the normalizer-free residual classifier.

    Args:
        units: per-stage width multipliers (stage width = unit * scale).
        repeats: NFBlocks per stage, multiplied by `repeatScale`.
        scale: base channel width.
        repeatScale: depth multiplier applied to `repeats`.
        numClasses: width of the softmax head. Defaults to 10 to match the
            10-column one-hot CIFAR-10 labels. The head was hard-coded to 100
            units, which cannot be scored against those labels.

    Returns:
        An uncompiled Keras `Model` mapping (32, 32, 3) inputs to class probabilities.
    """
    inp = Input((32, 32, 3))
    out = inp

    # Stochastic-depth rate grows linearly with the block index towards 0.25,
    # normalized by the total number of blocks (it used sum(units), i.e. widths).
    numBlocks = max(1, repeatScale * sum(repeats))

    def calcStochDepth(ind):
        return 0.25 * ind / numBlocks

    blockIdx = 0
    var = 1.0  # variance estimate of the input, threaded through every block call
    alpha = 0.2
    for stage, unit in enumerate(units):
        width = unit * scale
        # The first block of every stage is a transition block with a projected
        # shortcut; every stage after the first downsamples by 2.
        strides = 1 if stage == 0 else 2
        out, var = NFBlock(width, calcStochDepth(blockIdx), skipconv=True, strides=strides, alpha=alpha)([out, var])
        blockIdx += 1

        for _ in range(repeatScale * repeats[stage] - 1):
            out, var = NFBlock(width, calcStochDepth(blockIdx), alpha=alpha)([out, var])
            blockIdx += 1

    out = WSConv(2 * units[-1] * scale, 1)(out)
    out = tf.reduce_mean(out, axis=[1, 2])  # global average pool over axes 1, 2
    out = Flatten()(out)
    out = Dropout(0.2)(out)
    # Small-std normal init keeps initial logits near zero.
    out = Dense(numClasses, kernel_initializer=tf.random_normal_initializer(0, 0.01), kernel_regularizer=l2)(out)
    # float32 softmax so the loss is computed outside the bfloat16 policy.
    out = Activation('softmax', dtype=tf.float32)(out)
    return Model(inp, out)

Training and evaluation step functions

In [None]:
# Label-smoothed cross-entropy summed over the batch; step functions multiply by
# `recipBatch` (1 / batchSize) to turn the sum into a per-example mean.
loss_fn = tf.keras.losses.CategoricalCrossentropy(
    reduction=tf.keras.losses.Reduction.SUM,
    label_smoothing=0.1,
)
recipBatch = tf.constant(1 / batchSize)

@tf.function(experimental_compile=useMixed)
def trainStep(x, y):
    """Run one optimizer step with adaptive gradient clipping and return the data loss.

    The optimized objective is the cross-entropy scaled by 1/batchSize plus the
    model's regularization losses (the L2 kernel penalties). The returned value
    is the cross-entropy term only, so it stays comparable with `evalBatch`.

    NOTE(review): `train` passes plain tensors to `tpu_strategy.run`, so every
    replica presumably processes the whole batch; confirm the effective gradient scale.
    """
    def norm(tensor, axes):
        return tf.sqrt(tf.reduce_sum(tf.square(tensor), axis=axes))

    def clip(grad, w):
        """Adaptive gradient clipping: cap ||g|| / ||w|| at `lb`, per output unit (last axis)."""
        lb = 1e-2
        eps = 1e-3  # floor on ||w||, e.g. for zero-initialized weights such as skip_init_gain
        fanInRange = tf.range(tf.maximum(0, tf.rank(grad) - 1))
        gw = norm(grad, axes=fanInRange) / tf.maximum(norm(w, axes=fanInRange), eps)
        smolG = tf.cast(gw < lb, grad.dtype)
        bigG = 1 - smolG
        # Oversized gradients are rescaled so that ||g|| / ||w|| == lb.
        return smolG * grad + lb / tf.maximum(gw, 1e-6) * bigG * grad

    preds = model(x, training=True)
    loss = loss_fn(y, preds) * recipBatch

    totalLoss = loss
    if model.losses:
        # A custom loop must add regularizer penalties explicitly; without this
        # every `kernel_regularizer` had no effect. Scaled by 1/num_replicas.
        totalLoss = totalLoss + tf.nn.scale_regularization_loss(tf.add_n(model.losses))

    variables = model.trainable_variables
    grad = tf.gradients(totalLoss, variables)
    # The last two variables (final Dense kernel and bias) are left unclipped.
    clipGrad = [clip(g, w) for g, w in zip(grad[:-2], variables[:-2])] + grad[-2:]
    opt.apply_gradients(zip(clipGrad, variables))
    return loss

@tf.function(experimental_compile=useMixed)
def evalBatch(x, y):
    """Score `model` on one batch in inference mode; summed loss scaled by 1/batchSize."""
    preds = model(x, training=False)
    scaledLoss = loss_fn(y, preds) * recipBatch
    return scaledLoss

Training loop

In [None]:
def toNp(mean):
    """Convert a step result to a NumPy value.

    Per-replica results (objects exposing `.values`, as returned by
    `tpu_strategy.run`) are averaged across replicas first. Anything else is
    converted directly with `.numpy()`. A single-replica run that returns a plain
    tensor therefore no longer fails on `.values`.
    """
    replicaValues = getattr(mean, 'values', None)
    if replicaValues is not None:
        mean = tf.reduce_mean(replicaValues)
    return mean.numpy()

def _runStep(fn, x, y):
    """Run a step function on the active device(s) and return its loss as a NumPy value."""
    if useMixed:
        return toNp(tpu_strategy.run(fn, args=(x, y)))
    return toNp(fn(x, y))


def train(batchSize, epochs=1, steps=5, evaluate=True):
    """Train `model` with `trainStep`, adapting the augmentation probability `p` on the fly.

    Every 16 steps a random test batch is scored. If its loss exceeds 1.1x the
    current training loss, the global `p` is raised by `pStep`; otherwise it is
    lowered. `p` is kept within [0, 1].

    Args:
        batchSize: examples sampled per step.
        epochs: number of epochs.
        steps: steps per epoch.
        evaluate: run `model.evaluate` on the test split after every epoch.

    NOTE(review): the adaptation signal comes from the test split, so test data
    influences training; a held-out validation split would avoid this.
    """
    global p

    for i in range(epochs):
        cost = 0

        batchRandInds = npr.randint(0, m, (steps, batchSize))
        validRandInds = npr.randint(0, test_x.shape[0], (steps // 16 + 1, batchSize))
        for step in tqdm(range(steps)):
            randInds = batchRandInds[step]
            batchX = aug(train_x[randInds], p)
            batchY = train_y[randInds]
            loss = _runStep(trainStep, batchX, batchY)

            if step % 16 == 0:
                validInds = validRandInds[step // 16]
                valid_loss = _runStep(evalBatch, test_x[validInds], test_y[validInds])
                # If val loss is high relative to train loss, augment more.
                p += pStep * np.sign(valid_loss / loss - 1.1)
                # Clamp to a valid probability. The previous `% 1.0` wrapped p from
                # 1.0 back to 0.0, switching augmentation off.
                p = float(min(max(p, 0.0), 1.0))

            cost += loss
        print('Epoch {} | Train Loss: {}'.format(i, cost / steps))
        if evaluate:
            model.evaluate(test_x, test_y, batch_size=batchSize)

Custom learning rate scheduler

In [None]:
class HybridLR(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup followed by cosine decay to zero.

        lr(step) = lr * step / warmup                                       if step < warmup
        lr(step) = lr * 0.5 * (1 + cos(pi * min(step - warmup, decay_steps) / decay_steps))   otherwise

    The defaults are read from the notebook globals when the class is defined:
    - peak lr is 0.1 * batchSize / 256;
    - warmup lasts 5 epochs;
    - cosine decay covers the remaining epochs.
    """
    def __init__(self, lr=(1e-1 * batchSize / 256), decay_steps=((epochs - 5) * steps_per_epoch), warmup=(5 * steps_per_epoch)):
        super().__init__()
        self.maxlr = lr
        self.decSteps = decay_steps
        self.warmSteps = warmup

    def __call__(self, step):
        # Accept integer or float steps.
        step = tf.cast(step, tf.float32)
        warm = tf.cast(self.warmSteps, tf.float32)
        # Floors of 1 avoid division by zero for warmup == 0 or decay_steps == 0.
        dec = tf.maximum(tf.cast(self.decSteps, tf.float32), 1.0)

        def warmLR():
            return step / tf.maximum(warm, 1.0) * self.maxlr

        def cosLR():
            decStep = tf.minimum(step - warm, dec)
            return self.maxlr * 0.5 * (1 + tf.math.cos(np.pi * decStep / dec))

        return tf.cond(step < warm, warmLR, cosLR)

    def get_config(self):
        # Required to serialize the optimizer; the base class raises NotImplementedError.
        return {'lr': self.maxlr, 'decay_steps': self.decSteps, 'warmup': self.warmSteps}

Create and compile the model

In [None]:
scale = 256
repeatScale = 1

def _buildAndCompile():
    """Build, report and compile the model, exposing `model` and `opt` as notebook globals."""
    global model, opt
    model = build(scale=scale, repeatScale=repeatScale)
    print('Scale: {} | Num params: {}'.format(scale, model.count_params()))
    opt = tf.keras.optimizers.SGD(learning_rate=HybridLR(), momentum=0.9, nesterov=True)
    model.compile(optimizer=opt, loss=loss_fn, metrics=['accuracy'])

# On TPU, variables must be created inside the strategy scope.
if not useMixed:
    _buildAndCompile()
else:
    with tpu_strategy.scope():
        _buildAndCompile()

Train model

In [None]:
# Full schedule: `epochs` epochs of `steps_per_epoch` steps each.
train(batchSize, epochs=epochs, steps=steps_per_epoch)

In [None]:
# Sanity check on a fixed slice of the training set.
batchX = train_x[batchSize:2*batchSize]
batchY = train_y[batchSize:2*batchSize]
for img in batchX[:3]:
    # Inputs are standardized, and imshow clips float RGB data to [0, 1];
    # min-max rescale each image for display only.
    plt.imshow((img - img.min()) / max(img.max() - img.min(), 1e-8))
    plt.show()
batchPreds = model.predict(batchX)  # predict once and reuse below
print(batchPreds)
print(loss_fn(batchY, batchPreds) / batchSize)
print(p, pStep)

Evaluate the model on the training set

In [None]:
# Score the trained model on the training split.
model.evaluate(x=train_x, y=train_y, batch_size=batchSize)