In [16]:
import tensorflow as tf
from tensorflow.keras.layers import Input,Permute, Reshape, Conv2D, BatchNormalization, Activation,Add,AveragePooling2D,Flatten,Dense
import time
from tensorflow.keras.regularizers import l2,l1_l2
from tensorflow.keras.losses import Loss
from tensorflow.keras.metrics import Mean,SparseCategoricalAccuracy
from tensorflow.keras.utils import to_categorical
import numpy as np

In [41]:
class CusomDataGen(tf.keras.utils.Sequence):

    def __init__(self,X,y,batch_size,batch_rep,inp_rep_prob,ensemble_size,input_size,shuffle=True):
        self.X = X
        self.y = y
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.ensemble_size = ensemble_size
        self.batch_rep = batch_rep
        self.inp_rep_prob = inp_rep_prob
        self.n = X.shape[0]
        self.input_size = input_size

    def on_epoch_end(self):
        if self.shuffle:
            idxs = np.arange(self.n)
            np.random.shuffle(idxs)
            self.X = self.X[idxs]
            self.y = self.y[idxs]

    def __get_data(self,imgs,labels):

        batch_rep = np.tile(np.arange(imgs.shape[0]),[self.batch_rep])
        np.random.shuffle(batch_rep)
        input_shuffle=int(batch_rep.shape[0] * (1. - self.inp_rep_prob))
        #Kan detta göras bättre?
        shuffle_idxs = [np.concatenate([np.random.permutation( batch_rep[:input_shuffle]), batch_rep[input_shuffle:]]) for _ in range(self.ensemble_size)]

        imgs = np.stack([np.take(imgs, indxs, axis=0) for indxs in shuffle_idxs], axis=1)
        labels = np.stack([np.take(labels, indxs, axis=0) for indxs in shuffle_idxs], axis=1)
        
        return imgs, labels

    def __getitem__(self, index):
        imgs = self.X[index * self.batch_size : (index + 1) * self.batch_size]
        labels = self.y[index * self.batch_size: (index + 1) * self.batch_size]
        imgs, labels = self.__get_data(imgs,labels)
        return  imgs, labels
        
    def __len__(self):
        return self.n // self.batch_size
        

In [44]:
class CustomModel(tf.keras.Model):
    def __init__(self,ensemble_size,batch_rep, inp_rep_prob,n_classes,*args, **kwargs):
        super().__init__(*args, **kwargs)
        self.batch_rep = batch_rep
        self.inp_rep_prob = inp_rep_prob
        self.ensemble_size = ensemble_size
        self.n_classes = n_classes

    def train_step(self, data):
        imgs, labels = data
        with tf.GradientTape() as tape:
            logits = self(imgs, training=True)
            loss = self.compiled_loss(
                labels,
                logits,
                regularization_losses=self.losses,
            )
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

        #probs = tf.nn.softmax(tf.reshape(logits, [-1, self.n_classes]))
        #flat_labels = tf.reshape(labels, [-1])
        #self.compiled_metrics.update_state(flat_labels, probs)

        #self.metrics["nll"].update_state() #TODO
        #for m in self.metrics:
         #   if m.name == "accuracy":
          #      m.update_state(flat_labels,probs)

        return {m.name: m.result() for m in self.metrics}
    '''
    def test_step(self, data):
        imgs, labels = data
        imgs = tf.tile(tf.expand_dims(imgs, 1), [1, self.ensemble_size, 1, 1, 1]) #Expand to add ensemble dimension
        logits = self(imgs, training=False)
        probs = tf.nn.softmax(logits)

        for i in range(self.ensemble_size):
            member_probs = probs[:,i]
            member_loss = tf.keras.losses.sparse_categorical_crossentropy(labels, member_probs)
            self.metrics[F"nll_member_{i}"].update_state(member_loss)
            self.metrics[f"accuracy_member_{i}"].update_state(labels,member_probs)

        probs = tf.math.reduce_mean(probs, axis=1) 

        self.metrics["accuracy"].update_state(labels,probs)


        return {m.name: m.result() for m in self.metrics}
    '''

In [19]:
def basic_block(input,filters,strides,l_2):
    y = input
    x = BatchNormalization(momentum=0.9,epsilon=1e-5,beta_regularizer=l2(l_2),gamma_regularizer=l2(l_2))(input)
    x = Activation('relu')(x)
    x = Conv2D(filters,3,strides=strides,padding ='same',use_bias=False,kernel_initializer="he_normal",kernel_regularizer=l2(l_2))(x)
    x = BatchNormalization(momentum=0.9,epsilon=1e-5,beta_regularizer=l2(l_2),gamma_regularizer=l2(l_2))(x)
    x = Activation('relu')(x)
    x = Conv2D(filters,3,strides=1,padding ='same',use_bias=False,kernel_initializer="he_normal",kernel_regularizer=l2(l_2))(x)
    
    if not x.shape.is_compatible_with(y.shape):
        y = Conv2D(filters,1,strides=strides,padding ='same',use_bias=False,kernel_initializer="he_normal",kernel_regularizer=l2(l_2))(input)

    return Add()([x,y])


In [20]:
def res_group(input,filters,strides,n_blocks,l_2):
    x = basic_block(input,filters,strides,l_2)
    for _ in range(n_blocks-1):
        x = basic_block(x,filters,1,l_2)
    return x

In [21]:
def wide_resnet(input_shape,d,w_mult,n_classes,batch_rep,inp_rep_prob, l_2=0):
    n_blocks = (d - 4) // 6
    input_shape = list(input_shape)
    ensemble_size = input_shape[0]

    input = Input(shape=input_shape)

    

    x = Permute([2,3,4,1])(input)

    # Reshape so that each subnetwork has 3 channels
    x = Reshape(input_shape[1:-1] + [input_shape[-1] * ensemble_size])(x)


    x = Conv2D(16,3,padding ='same',use_bias=False,kernel_initializer="he_normal",kernel_regularizer=l2(l_2))(x)

    for strides, filters in zip([1, 2, 2], [16, 32, 64]):
        x = res_group(x,filters*w_mult,strides,n_blocks,l_2)

    x = BatchNormalization(momentum=0.9,epsilon=1e-5,beta_regularizer=l2(l_2),gamma_regularizer=l2(l_2))(x)
    x = Activation('relu')(x)
    x = AveragePooling2D(pool_size=8)(x)
    x = Flatten()(x)

    x = Dense(n_classes*ensemble_size,kernel_initializer='he_normal',activation=None,kernel_regularizer=l2(l_2),bias_regularizer=l2(l_2))(x)
    x = Reshape([ensemble_size,n_classes])(x)
    
    return CustomModel(ensemble_size, batch_rep,inp_rep_prob, n_classes,input,x)

In [42]:
class NLL(Loss):

    def call(self, y_true, y_pred):
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)
        nll = tf.reduce_mean(tf.reduce_sum(loss, axis=1))
        return nll

In [23]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()


In [24]:
x_train = (x_train.astype('float32') / 256 ) - 0.5
x_test =(x_test.astype('float32') / 256 ) - 0.5
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)

In [25]:
[3]+ list(x_train[0].shape)

[3, 32, 32, 3]

In [26]:
ensemble_size = 3
d = 28
w_mult = 10
n_classes = 10
epochs = 250
batch_size = 64
batch_rep = 4
inp_rep_prob = 0.5
input_shape = tuple([3]+ list(x_train[0].shape))
val_split = 0.1
l_2 = 3e-4
steps_per_epoch = x_train.shape[0] * 1 - val_split


In [27]:
metrics = [Mean(name="nll"), SparseCategoricalAccuracy("accuracy")]

for i in range(ensemble_size):
    metrics.append(Mean(f"nll_member_{i}"))
    metrics.append(SparseCategoricalAccuracy(f"accuracy_member_{i}"))

In [45]:
traing_data=CusomDataGen(x_train,y_train,batch_size,batch_rep,inp_rep_prob,ensemble_size,input_shape)
model = wide_resnet(input_shape,d,w_mult,n_classes,batch_rep,inp_rep_prob, l_2)

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    0.1,
    decay_steps=steps_per_epoch,
    decay_rate=0.1)

optimizer = tf.keras.optimizers.SGD(lr_schedule, momentum=0.9, nesterov=True)

model.compile(optimizer,loss = NLL())
model.fit(traing_data,epochs=epochs)

Epoch 1/250


InvalidArgumentError:  logits and labels must have the same first dimension, got logits shape [768,10] and labels shape [7680]
	 [[node NLL/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits (defined at tmp/ipykernel_4833/3309396650.py:4) ]] [Op:__inference_train_function_64006]

Function call stack:
train_function


In [None]:
model.summary()

Model: "custom_model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 3, 32, 32, 3 0                                            
__________________________________________________________________________________________________
permute (Permute)               (None, 32, 32, 3, 3) 0           input_1[0][0]                    
__________________________________________________________________________________________________
reshape (Reshape)               (None, 32, 32, 9)    0           permute[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 32, 32, 16)   1296        reshape[0][0]                    
_______________________________________________________________________________________