In [5]:
import tensorflow as tf
from tensorflow.keras.layers import Input,Permute, Reshape, Conv2D, BatchNormalization, Activation,Add,AveragePooling2D,Flatten,Dense
import time

In [None]:
import argparse

# Run configuration, overridable from the command line.
parser = argparse.ArgumentParser(description='Parser', allow_abbrev=False)
# Declared as an optional flag so that `default` actually applies: a bare
# positional without `nargs` is always required and its default is ignored.
parser.add_argument('--epochs', type=int, default=2)
# Hyper-parameters read by train() / train_step().
# NOTE(review): the default values below are placeholders — confirm them.
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--batch_rep', type=int, default=1)
parser.add_argument('--inp_rep_prob', type=float, default=0.0)
parser.add_argument('--ensemble_size', type=int, default=3)
# parse_known_args() ignores argv entries it does not recognise, such as the
# `-f <connection-file>` a Jupyter kernel is started with; parse_args() would
# terminate the kernel with SystemExit instead.
args, _unknown_argv = parser.parse_known_args()

In [None]:
def basic_block(input,filters,strides):
    """Residual block made of two BatchNorm -> ReLU -> 3x3 conv stages plus a skip path.

    Args:
        input: Tensor fed to both the convolutional branch and the skip path.
        filters: Number of filters used by every convolution in the block.
        strides: Stride of the first 3x3 convolution, and of the 1x1
            convolution on the skip path when one is inserted. The second
            3x3 convolution always uses stride 1.

    Returns:
        The element-wise sum of the convolutional branch and the skip path.
    """
    branch = input
    # Only the first stage applies the caller's stride.
    for stage_strides in (strides, 1):
        branch = BatchNormalization(momentum=0.9, epsilon=1e-5)(branch)
        branch = Activation('relu')(branch)
        branch = Conv2D(
            filters,
            3,
            strides=stage_strides,
            padding='same',
            use_bias=False,
            kernel_initializer="he_normal",
        )(branch)

    if branch.shape.is_compatible_with(input.shape):
        # Shapes already agree, so the skip path is the untouched input.
        shortcut = input
    else:
        # Otherwise put a 1x1 convolution (same filters and stride) on the skip path.
        shortcut = Conv2D(
            filters,
            1,
            strides=strides,
            padding='same',
            use_bias=False,
            kernel_initializer="he_normal",
        )(input)

    return Add()([branch, shortcut])


In [None]:
def res_group(input,filters,strides,n_blocks):
    """Chain `n_blocks` residual blocks that share the same filter count.

    Args:
        input: Tensor entering the group.
        filters: Number of filters for every block in the group.
        strides: Stride of the first block; all following blocks use stride 1.
        n_blocks: Total number of blocks in the group. At least one block is
            always built, even when `n_blocks` is less than 1.

    Returns:
        The output tensor of the last block.
    """
    # The first block carries the stride (and any change in filter count).
    x = basic_block(input,filters,strides)
    # The first block already counts towards n_blocks, hence the "- 1";
    # iterating n_blocks times would build one block too many per group.
    for _ in range(n_blocks - 1):
        x = basic_block(x,filters,1)
    return x

In [None]:
def wide_resnet(input_shape,d,w_mult,n_classes):
    """Build a Wide ResNet with one input slot and one output head per ensemble member.

    Args:
        input_shape: Per-sample input shape (batch axis excluded). Its first
            entry is taken as the ensemble size; the remaining entries are
            assumed to be (height, width, channels) — TODO confirm.
        d: Network depth; each of the three groups gets `(d - 4) // 6` blocks.
        w_mult: Multiplier applied to the base filter counts 16, 32 and 64.
        n_classes: Number of classes predicted by each head.

    Returns:
        A `tf.keras.Model` whose output has per-sample shape
        `(ensemble_size, n_classes)` and holds raw logits (no activation).
    """
    n_blocks = (d - 4) // 6
    input_shape = list(input_shape)
    ensemble_size = input_shape[0]

    inputs = Input(shape=input_shape)
    # Move the ensemble axis last so it can be merged into the channel axis.
    x = Permute([2,3,4,1])(inputs)

    # Fold the ensemble members into the channel axis. The layer has to be
    # applied to `x`; constructing it alone leaves `x` bound to a layer object.
    x = Reshape(input_shape[1:-1] + [input_shape[-1] * ensemble_size])(x)

    x = Conv2D(16,3,padding ='same',use_bias=False,kernel_initializer="he_normal")(x)

    for strides, filters in zip([1, 2, 2], [16, 32, 64]):
        x = res_group(x,filters*w_mult,strides,n_blocks)

    x = BatchNormalization(momentum=0.9,epsilon=1e-5)(x)
    x = Activation('relu')(x)
    # NOTE(review): pool_size=8 presumably matches an 8x8 feature map
    # (32x32 inputs after two stride-2 groups) — confirm for other input sizes.
    x = AveragePooling2D(pool_size=8)(x)
    x = Flatten()(x)

    # One dense layer produces the logits of every head, which are then split
    # per ensemble member. Reshape's target shape never includes the batch
    # axis, so only (ensemble_size, n_classes) is given.
    x = Dense(n_classes*ensemble_size,kernel_initializer='he_normal',activation=None)(x)
    x = Reshape([ensemble_size,n_classes])(x)

    # The functional Model constructor takes `inputs=` / `outputs=`.
    return tf.keras.Model(inputs=inputs,outputs=x)

In [None]:
def cifar_10():
    """Load the CIFAR-10 dataset through `tf.keras.datasets`.

    Returns:
        `((x_train, y_train), (x_test, y_test))`, exactly as produced by
        `tf.keras.datasets.cifar10.load_data()`. No preprocessing is applied.
    """
    # NOTE(review): the splits used to be loaded and then discarded (the
    # function returned None). Returning them unchanged is the minimal fix;
    # normalisation / batching is still to be decided.
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

    return (x_train, y_train), (x_test, y_test)


In [None]:
@tf.function
def train_step(model,data,args,optimizer=None):
    """Run one training step on a batch with independently shuffled ensemble inputs.

    The batch indices are repeated `args.batch_rep` times and shuffled. Each
    ensemble member then re-shuffles the leading `(1 - args.inp_rep_prob)`
    fraction of that index vector on its own; the remaining tail is shared by
    all members, so those positions show the same example to every member.

    Args:
        model: Callable Keras model taking the stacked images and returning
            logits with one entry per ensemble member along axis 1.
        data: Tuple `(imgs, labels)` for one batch.
        args: Namespace providing `batch_rep`, `inp_rep_prob` and
            `ensemble_size`.
        optimizer: Optional optimizer used to apply the gradients. When
            omitted, `model.optimizer` is used if the model has one; if
            neither is available the weights are left untouched.

    Returns:
        The scalar loss: cross-entropy summed over ensemble members and
        averaged over the batch.
    """
    imgs, labels = data
    batch_size = tf.shape(imgs)[0]
    batch_rep = tf.tile(tf.range(batch_size),[args.batch_rep])
    shuffled_batch_rep = tf.random.shuffle(batch_rep)
    # The split point must be relative to the *repeated* index vector;
    # using batch_size would leave most positions shared when batch_rep > 1.
    n_indices = tf.shape(shuffled_batch_rep)[0]
    input_shuffle=tf.cast(tf.cast(n_indices,tf.float32) * (1. - args.inp_rep_prob),tf.int32)
    # Can this be done better?
    shuffle_idxs = [
        tf.concat(
            [
                tf.random.shuffle(shuffled_batch_rep[:input_shuffle]),
                # Shared tail of the index vector (not of the scalar split point).
                shuffled_batch_rep[input_shuffle:],
            ],
            axis=0,
        )
        for _ in range(args.ensemble_size)
    ]
    imgs = tf.stack([tf.gather(imgs, indxs, axis=0) for indxs in shuffle_idxs], axis=1)
    labels = tf.stack([tf.gather(labels, indxs, axis=0) for indxs in shuffle_idxs], axis=1)

    with tf.GradientTape() as tape:
        logits = model(imgs, training=True)
        nll = tf.reduce_mean(tf.reduce_sum(tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True), axis=1))

    # Fall back to the optimizer attached by model.compile(), if any.
    if optimizer is None:
        optimizer = getattr(model, "optimizer", None)
    if optimizer is not None:
        grads = tape.gradient(nll, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

    return nll

In [None]:
def train(model,dataset,args):
    """Train `model` for `args.epochs` epochs, printing timing per epoch.

    Args:
        model: Model handed to `train_step` for every batch.
        dataset: Iterable yielding the batches consumed by `train_step`.
        args: Namespace providing `epochs` and `batch_size`; it is also
            forwarded to `train_step`.
    """
    # NOTE(review): `dataset.size` is treated as the number of training
    # examples — confirm the dataset type actually exposes it with that meaning.
    steps_per_epoch = dataset.size // args.batch_size
    train_iter=iter(dataset)

    start_time = time.time()
    for epoch in range(args.epochs):
        print(f"Starting epoch {epoch}")
        start_time_epoch = time.time()

        for _ in range(steps_per_epoch):
            try:
                batch = next(train_iter)
            except StopIteration:
                # A finite dataset is exhausted after one pass; start over so
                # later epochs do not abort with StopIteration.
                train_iter = iter(dataset)
                batch = next(train_iter)
            # train_step requires `args` as its third argument.
            train_step(model,batch,args)

        print(f"Epoch {epoch} took {time.time()-start_time_epoch} total elapsed time {time.time()-start_time}")