In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input,Permute, Reshape, Conv2D, BatchNormalization, Activation,Add,AveragePooling2D,Flatten,Dense
import time
from tensorflow.keras.regularizers import l2,l1_l2
from tensorflow.keras.losses import Loss

In [6]:
class CustomModel(tf.keras.Model):
    """Keras model with custom train/test steps for a multi-input ensemble network.

    The wrapped network consumes inputs with an ensemble axis at position 1
    and is expected to emit one set of logits per ensemble member along
    axis 1 of its output.

    Args:
        ensemble_size: Number of ensemble members (size of axis 1).
        batch_rep: How many times each training batch is repeated before the
            per-member shuffling.
        inp_rep_prob: Fraction of the repeated batch in which every ensemble
            member receives the same example.
        *args, **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``inputs=`` /
            ``outputs=`` for a functional model).
    """

    def __init__(self,ensemble_size,batch_rep, inp_rep_prob,*args, **kwargs):
        super().__init__(*args, **kwargs)
        self.batch_rep = batch_rep
        self.inp_rep_prob = inp_rep_prob
        self.ensemble_size = ensemble_size

    def train_step(self, data):
        """Run one optimisation step on a batch of ``(imgs, labels)``.

        The batch is repeated ``batch_rep`` times and every ensemble member
        gets its own ordering of the repeated batch; the tail of the ordering
        (a fraction ``inp_rep_prob``) is shared by all members.

        Returns:
            Dict mapping metric names to their current values.
        """
        imgs, labels = data
        batch_size = tf.shape(imgs)[0]

        # Common ordering of the repeated batch that all members start from.
        base_idxs = tf.random.shuffle(tf.tile(tf.range(batch_size), [self.batch_rep]))

        # Number of leading positions that are re-shuffled independently per
        # member. It is derived from the length of the *repeated* batch so
        # that inp_rep_prob really is the shared fraction whatever batch_rep is.
        n_shuffle = tf.cast(
            tf.cast(tf.shape(base_idxs)[0], tf.float32) * (1. - self.inp_rep_prob),
            tf.int32,
        )

        # Can this be done better?
        shuffle_idxs = [
            tf.concat(
                [tf.random.shuffle(base_idxs[:n_shuffle]), base_idxs[n_shuffle:]],
                axis=0,
            )
            for _ in range(self.ensemble_size)
        ]
        # Stack the per-member views on a new axis 1 (the ensemble axis).
        imgs = tf.stack([tf.gather(imgs, idxs, axis=0) for idxs in shuffle_idxs], axis=1)
        labels = tf.stack([tf.gather(labels, idxs, axis=0) for idxs in shuffle_idxs], axis=1)

        with tf.GradientTape() as tape:
            logits = self(imgs, training=True)
            # The loss compares the labels with the network output (logits).
            loss = self.compiled_loss(
                labels,
                logits,
                regularization_losses=self.losses,
            )

        grads = tape.gradient(loss, self.trainable_variables)
        # Use the optimizer given to compile(), not a notebook-level global.
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

        # Compiled metrics (if any) see per-member probabilities.
        self.compiled_metrics.update_state(labels, tf.nn.softmax(logits))
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Evaluate one batch of ``(imgs, labels)``.

        Every ensemble member receives the same image. The loss is the
        compiled loss on the per-member logits; compiled metrics (if any) are
        updated with the member-averaged probabilities.

        Returns:
            Dict mapping metric names to their current values.
        """
        imgs, labels = data
        # Expand to add the ensemble dimension; the 5-entry multiples mean
        # imgs must be rank 4 on entry.
        imgs = tf.tile(tf.expand_dims(imgs, 1), [1, self.ensemble_size, 1, 1, 1])
        logits = self(imgs, training=False)

        # Same label layout as in train_step: one copy per member on axis 1.
        member_labels = tf.stack([labels] * self.ensemble_size, axis=1)
        self.compiled_loss(
            member_labels,
            logits,
            regularization_losses=self.losses,
        )

        # Average the member predictions into a single distribution.
        probs = tf.nn.softmax(logits)
        probs = tf.math.reduce_mean(probs, axis=1)

        self.compiled_metrics.update_state(labels, probs)
        return {m.name: m.result() for m in self.metrics}

In [5]:
def basic_block(input,filters,strides,l_2):
    """Pre-activation residual block: two (BN -> ReLU -> 3x3 conv) stages plus a shortcut.

    Args:
        input: Tensor fed to the block.
        filters: Number of filters for every convolution in the block.
        strides: Stride of the first 3x3 convolution (and of the 1x1
            projection when one is needed).
        l_2: L2 coefficient used for the BN beta/gamma and conv kernels.

    Returns:
        Sum of the residual branch and the (possibly projected) shortcut.
    """
    def _bn_relu(tensor):
        normed = BatchNormalization(
            momentum=0.9,
            epsilon=1e-5,
            beta_regularizer=l2(l_2),
            gamma_regularizer=l2(l_2),
        )(tensor)
        return Activation('relu')(normed)

    def _conv(tensor, kernel_size, conv_strides):
        return Conv2D(
            filters,
            kernel_size,
            strides=conv_strides,
            padding='same',
            use_bias=False,
            kernel_initializer="he_normal",
            kernel_regularizer=l2(l_2),
        )(tensor)

    # Residual branch: the first stage may downsample, the second never does.
    residual = _conv(_bn_relu(input), 3, strides)
    residual = _conv(_bn_relu(residual), 3, 1)

    # Identity shortcut unless the shapes disagree; then project with a 1x1 conv.
    shortcut = input
    if not residual.shape.is_compatible_with(shortcut.shape):
        shortcut = _conv(input, 1, strides)

    return Add()([residual, shortcut])


In [None]:
def res_group(input,filters,strides,n_blocks,l_2):
    """Chain ``n_blocks`` residual blocks that share the same filter count.

    Args:
        input: Tensor fed to the first block.
        filters: Number of filters for every block in the group.
        strides: Stride of the first block; the remaining blocks use stride 1.
        n_blocks: Total number of blocks. The first block is always applied,
            so values below 1 still yield one block.
        l_2: L2 coefficient forwarded to ``basic_block``.

    Returns:
        Output tensor of the last block.
    """
    # basic_block takes (input, filters, strides, l_2); there is no l_1 argument.
    x = basic_block(input,filters,strides,l_2)
    for _ in range(n_blocks-1):
        x = basic_block(x,filters,1,l_2)
    return x

In [None]:
def wide_resnet(input_shape,d,w_mult,n_classes,batch_rep,inp_rep_prob, l_2=0):
    """Build a wide ResNet whose input and output both carry an ensemble axis.

    Args:
        input_shape: Per-example input shape (batch axis excluded) with the
            ensemble size first. The Permute below requires four entries,
            presumably (ensemble, height, width, channels) — confirm.
        d: Network depth; every residual group gets ``(d - 4) // 6`` blocks.
        w_mult: Width multiplier applied to the base filter counts 16/32/64.
        n_classes: Number of output classes per ensemble member.
        batch_rep: Forwarded to ``CustomModel``.
        inp_rep_prob: Forwarded to ``CustomModel``.
        l_2: L2 regularisation coefficient for all weights.

    Returns:
        A ``CustomModel`` whose output has shape
        ``(batch, ensemble_size, n_classes)``.
    """
    n_blocks = (d - 4) // 6
    input_shape = list(input_shape)
    ensemble_size = input_shape[0]

    input = Input(shape=input_shape)
    # Move the ensemble axis last so it can be folded into the channel axis.
    x = Permute([2,3,4,1])(input)

    # Fold the ensemble members into the channel axis
    # (last dimension becomes channels * ensemble_size).
    x = Reshape(input_shape[1:-1] + [input_shape[-1] * ensemble_size])(x)

    x = Conv2D(16,3,padding ='same',use_bias=False,kernel_initializer="he_normal",kernel_regularizer=l2(l_2))(x)

    for strides, filters in zip([1, 2, 2], [16, 32, 64]):
        x = res_group(x,filters*w_mult,strides,n_blocks,l_2)

    x = BatchNormalization(momentum=0.9,epsilon=1e-5,beta_regularizer=l2(l_2),gamma_regularizer=l2(l_2))(x)
    x = Activation('relu')(x)
    x = AveragePooling2D(pool_size=8)(x)
    x = Flatten()(x)

    # One block of n_classes logits per ensemble member.
    x = Dense(n_classes*ensemble_size,kernel_initializer='he_normal',activation=None,kernel_regularizer=l2(l_2),bias_regularizer=l2(l_2))(x)
    # Reshape's target shape excludes the batch axis.
    x = Reshape([ensemble_size, n_classes])(x)

    # CustomModel is the notebook's own class; Keras expects inputs=/outputs=.
    return CustomModel(ensemble_size, batch_rep, inp_rep_prob, inputs=input, outputs=x)

In [2]:
class NLL(Loss):
    """Sparse categorical cross-entropy on logits, summed over axis 1.

    Axis 1 is the axis on which ``CustomModel.train_step`` stacks the
    per-member labels, so the sum runs over ensemble members.
    """

    def call(self, y_true, y_pred):
        """Return the scalar loss for one batch.

        Args:
            y_true: Integer class labels with one entry per ensemble member.
            y_pred: Logits with the class axis last.

        Returns:
            Scalar tensor: per-member cross-entropies summed over axis 1,
            then averaged over the remaining axis.
        """
        # tf.convert_to_tensor is the public API; the *_v2 name is internal
        # to TensorFlow and not exported on the `tf` module.
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        per_member = tf.keras.losses.sparse_categorical_crossentropy(
            y_true, y_pred, from_logits=True
        )
        nll = tf.reduce_mean(tf.reduce_sum(per_member, axis=1))
        return nll

In [None]:
# NOTE(review): input_shape, d, w_mult, n_classes, batch_rep, inp_rep_prob,
# steps_per_epoch, x_train, y_train, batch_size and val_split are not defined
# in the visible cells — presumably set in a configuration cell; confirm they
# exist before this cell on a fresh kernel.

# wide_resnet only takes an L2 coefficient, so no l_1 argument is passed.
model = wide_resnet(input_shape,d,w_mult,n_classes,batch_rep,inp_rep_prob, l_2=0)

# Learning rate starts at 0.1 and decays by a factor of 10 every steps_per_epoch steps.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    0.1,
    decay_steps=steps_per_epoch,
    decay_rate=0.1)

optimizer = tf.keras.optimizers.SGD(lr_schedule, momentum=0.9, nesterov=True)
model.compile(optimizer,loss = NLL())
model.fit(x_train,y_train,batch_size=batch_size,validation_split=val_split)