In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input,Permute, Reshape, Conv2D, BatchNormalization, Activation,Add,AveragePooling2D,Flatten,Dense
import time
from tensorflow.keras.regularizers import l2,l1_l2
from tensorflow.keras.losses import Loss

In [None]:
import argparse

# Hyperparameters are exposed as optional flags so that every one of them has a
# working default. (A positional argument ignores `default` unless it also has
# nargs='?', so the original `epochs` default of 2 was never applied.)
parser = argparse.ArgumentParser(description='Parser')
parser.add_argument('--epochs', type=int, default=2,
                    help='Number of training epochs.')
# The three options below are read by CustomModel.train_step. Their defaults
# are starting values only -- tune them for the experiment at hand.
parser.add_argument('--ensemble_size', type=int, default=3,
                    help='Number of subnetworks (ensemble members).')
parser.add_argument('--batch_rep', type=int, default=4,
                    help='How many times each example is repeated in a training batch.')
parser.add_argument('--inp_rep_prob', type=float, default=0.0,
                    help='Fraction of batch positions where all members see the same example.')
#TODO

# parse_known_args (instead of parse_args) tolerates the extra command-line
# arguments a Jupyter kernel is started with (e.g. `-f <connection-file>`),
# which would otherwise abort the cell with SystemExit.
args, _unknown_args = parser.parse_known_args()

In [None]:
class CustomModel(tf.keras.Model):
    """Keras model with a training step that builds multi-input batches.

    Each training batch is expanded so that every ensemble member receives its
    own (partly independent) ordering of the same examples. The settings are
    read from the module-level ``args`` namespace: ``batch_rep``,
    ``inp_rep_prob`` and ``ensemble_size``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        #TODO

    def train_step(self, data):
        """Run one optimisation step on an ensemble-expanded batch.

        Args:
            data: ``(imgs, labels)`` tuple as delivered by ``fit``.

        Returns:
            Dict mapping metric names to their current values.
        """
        imgs, labels = data
        batch_size = tf.shape(imgs)[0]

        # Repeat every example index `batch_rep` times and shuffle once; this
        # ordering is the part that is shared between ensemble members.
        batch_rep = tf.tile(tf.range(batch_size), [args.batch_rep])
        shuffled_batch_rep = tf.random.shuffle(batch_rep)

        # Number of leading positions that each member re-shuffles on its own.
        # It is a fraction of the *tiled* length, so that `inp_rep_prob` is the
        # fraction of positions where all members see the same example.
        n_positions = tf.shape(shuffled_batch_rep)[0]
        input_shuffle = tf.cast(
            tf.cast(n_positions, tf.float32) * (1. - args.inp_rep_prob), tf.int32)

        # Can this be done better?
        shuffle_idxs = [
            tf.concat(
                [tf.random.shuffle(shuffled_batch_rep[:input_shuffle]),
                 shuffled_batch_rep[input_shuffle:]],
                axis=0)
            for _ in range(args.ensemble_size)
        ]
        # Stack along a new axis 1: (batch * batch_rep, ensemble_size, ...).
        imgs = tf.stack([tf.gather(imgs, indxs, axis=0) for indxs in shuffle_idxs], axis=1)
        labels = tf.stack([tf.gather(labels, indxs, axis=0) for indxs in shuffle_idxs], axis=1)

        with tf.GradientTape() as tape:
            logits = self(imgs, training=True)
            # The loss compares labels with the model output (not the inputs).
            loss = self.compiled_loss(
                labels,
                logits,
                regularization_losses=self.losses,
            )

        # Use this model's own variables and optimizer rather than globals.
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

        # Report metrics for the step just taken. Delegating to
        # super().train_step(data) here would run a second update on the
        # original, non-expanded batch.
        self.compiled_metrics.update_state(labels, logits)
        return {m.name: m.result() for m in self.metrics}

In [None]:
def basic_block(input,filters,strides,l_1,l_2):
    """Build one pre-activation residual block.

    The main path is (BatchNorm -> ReLU -> 3x3 conv) applied twice; its result
    is added to a shortcut from the block input.

    Args:
        input: Keras tensor entering the block.
        filters: Number of output channels of every convolution in the block.
        strides: Stride of the first 3x3 convolution (and of the shortcut
            projection, when one is needed).
        l_1: L1 regularisation factor for convolution kernels.
        l_2: L2 regularisation factor for convolution kernels and for the
            batch-norm beta/gamma parameters.

    Returns:
        Keras tensor: sum of the main path and the shortcut.
    """
    def batch_norm():
        # Both normalisation layers carry the same beta/gamma penalty; the
        # second one previously had none, unlike the other BN layers here.
        return BatchNormalization(momentum=0.9, epsilon=1e-5,
                                  beta_regularizer=l2(l_2),
                                  gamma_regularizer=l2(l_2))

    def conv(kernel_size, conv_strides):
        return Conv2D(filters, kernel_size, strides=conv_strides, padding='same',
                      use_bias=False, kernel_initializer="he_normal",
                      kernel_regularizer=l1_l2(l1=l_1, l2=l_2))

    shortcut = input

    x = batch_norm()(input)
    x = Activation('relu')(x)
    x = conv(3, strides)(x)
    x = batch_norm()(x)
    x = Activation('relu')(x)
    x = conv(3, 1)(x)

    # When the main path changed the spatial size or channel count, project
    # the raw block input with a 1x1 convolution so the two can be added.
    if not x.shape.is_compatible_with(shortcut.shape):
        shortcut = conv(1, strides)(input)

    return Add()([x, shortcut])


In [None]:
def res_group(input,filters,strides,n_blocks,l_1,l_2):
    """Chain residual blocks into one group.

    The leading block is built with the requested ``strides``; it is followed
    by ``n_blocks - 1`` further blocks that all use stride 1.

    Args:
        input: Keras tensor entering the group.
        filters: Channel count passed to every block.
        strides: Stride of the leading block.
        n_blocks: Total number of blocks requested for the group.
        l_1: L1 regularisation factor forwarded to each block.
        l_2: L2 regularisation factor forwarded to each block.

    Returns:
        Keras tensor produced by the last block.
    """
    # The leading block is always created, whatever `n_blocks` is.
    out = basic_block(input, filters, strides, l_1=l_1, l_2=l_2)

    follow_ups = range(n_blocks - 1)
    for _unused in follow_ups:
        out = basic_block(out, filters, strides=1, l_1=l_1, l_2=l_2)

    return out

In [None]:
def wide_resnet(input_shape,d,w_mult,n_classes,l_2=0,l_1=0):
    """Build a multi-input multi-output Wide-ResNet as a ``CustomModel``.

    Args:
        input_shape: Per-example input shape ``(ensemble_size, height, width,
            channels)`` (the Permute/Reshape below rely on this 4-axis layout).
        d: Network depth; each of the three groups gets ``(d - 4) // 6`` blocks.
        w_mult: Width multiplier applied to the base filter counts 16/32/64.
        n_classes: Number of classes predicted by each ensemble member.
        l_2: L2 regularisation factor.
        l_1: L1 regularisation factor.

    Returns:
        ``CustomModel`` mapping inputs of ``input_shape`` to logits of shape
        ``(ensemble_size, n_classes)`` per example.
    """
    n_blocks = (d - 4) // 6
    input_shape = list(input_shape)
    ensemble_size = input_shape[0]

    inputs = Input(shape=input_shape)
    # Move the ensemble axis last: (H, W, C, ensemble).
    x = Permute([2,3,4,1])(inputs)

    # Fold the ensemble axis into the channel axis: (H, W, C * ensemble).
    x = Reshape(input_shape[1:-1] + [input_shape[-1] * ensemble_size])(x)

    # Input and output layers use regularisation factors scaled by ensemble_size.
    x = Conv2D(16,3,padding ='same',use_bias=False,kernel_initializer="he_normal",kernel_regularizer=l1_l2(l1=l_1*ensemble_size,l2=l_2*ensemble_size))(x)

    for strides, filters in zip([1, 2, 2], [16, 32, 64]):
        x = res_group(x,filters*w_mult,strides,n_blocks,l_1,l_2)

    x = BatchNormalization(momentum=0.9,epsilon=1e-5,beta_regularizer=l2(l_2),gamma_regularizer=l2(l_2))(x)
    x = Activation('relu')(x)
    # NOTE(review): pool_size=8 presumably assumes 32x32 inputs (8x8 feature
    # maps after the two stride-2 groups) -- confirm for other resolutions.
    x = AveragePooling2D(pool_size=8)(x)
    x = Flatten()(x)

    # One dense layer emits the logits of all members at once ...
    x = Dense(n_classes*ensemble_size,kernel_initializer='he_normal',activation=None,kernel_regularizer=l1_l2(l1=l_1*ensemble_size,l2=l_2*ensemble_size),bias_regularizer=l1_l2(l1=l_1*ensemble_size,l2=l_2*ensemble_size))(x)
    # ... which are then split per member. Reshape's target shape excludes the
    # batch axis, so no batch dimension is passed here.
    x = Reshape([ensemble_size,n_classes])(x)

    # CustomModel is the tf.keras.Model subclass defined in this notebook; the
    # functional-API keywords are `inputs` / `outputs`.
    return CustomModel(inputs=inputs,outputs=x)

In [2]:
class NLL(Loss):
    """Negative log-likelihood summed over ensemble members.

    The sparse categorical cross-entropy is computed from logits, summed over
    axis 1 and averaged over the remaining axis.
    """

    def call(self, y_true, y_pred):
        """Return the scalar loss for integer labels and logits.

        Args:
            y_true: Integer class labels.
            y_pred: Logits; the cross-entropy is taken over the last axis.
                NOTE(review): axis 1 is presumably the ensemble axis, as
                produced by the model in this notebook -- confirm.

        Returns:
            Scalar tensor.
        """
        # The public API name is tf.convert_to_tensor; `tf.convert_to_tensor_v2`
        # is not exported and raises AttributeError.
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        per_member = tf.keras.losses.sparse_categorical_crossentropy(
            y_true, y_pred, from_logits=True)
        nll = tf.reduce_mean(tf.reduce_sum(per_member, axis=1))
        return nll

In [None]:
# NOTE(review): `lr_schedule`, `model`, `x_train` and `y_train` are not defined
# in the cells visible here; they must exist (e.g. model = wide_resnet(...))
# before this cell is run on a fresh kernel -- confirm.
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9, nesterov=True)
model.compile(optimizer=optimizer, loss=NLL())
# Train for the number of epochs given on the command line; without `epochs`
# Keras runs a single epoch and the parsed value is never used.
model.fit(x_train, y_train, epochs=args.epochs)