In [2]:
import tensorflow as tf
from tensorflow.keras.layers import Input,Permute, Reshape, Conv2D, BatchNormalization, Activation,Add,AveragePooling2D,Flatten,Dense
import time
from tensorflow.keras.regularizers import l2,l1_l2
from tensorflow.keras.losses import Loss
from tensorflow.keras.metrics import Mean,SparseCategoricalAccuracy
import numpy as np

In [3]:
class DataGenerator(tf.keras.utils.Sequence):

    def __init__(self,X,y,batch_size,batch_rep,inp_rep_prob,ensemble_size,training,shuffle=True):
        self.X = X
        self.y = y
        self.batch_size = batch_size //batch_rep
        self.shuffle = shuffle
        self.ensemble_size = ensemble_size
        self.batch_rep = batch_rep
        self.inp_rep_prob = inp_rep_prob
        self.n = X.shape[0]
        self.training = training
    def on_epoch_end(self):
        if self.shuffle:
            idxs = np.arange(self.n)
            np.random.shuffle(idxs)
            self.X = self.X[idxs]
            self.y = self.y[idxs]

    def __get_train_data(self,imgs,labels):

        batch_rep = np.tile(np.arange(imgs.shape[0]),[self.batch_rep])
        np.random.shuffle(batch_rep)
        input_shuffle=int(batch_rep.shape[0] * (1. - self.inp_rep_prob))
        #Kan detta göras bättre?
        shuffle_idxs = [np.concatenate([np.random.permutation( batch_rep[:input_shuffle]), batch_rep[input_shuffle:]]) for _ in range(self.ensemble_size)]

        imgs = np.stack([np.take(imgs, indxs, axis=0) for indxs in shuffle_idxs], axis=1)
        labels = np.stack([np.take(labels, indxs, axis=0) for indxs in shuffle_idxs], axis=1)
        
        return imgs, labels

    def __get_test_data(self,imgs,labels):
        imgs = np.tile(np.expand_dims(imgs, 1), [1, self.ensemble_size, 1, 1, 1])
        
        return imgs, labels

    def __getitem__(self, index):
        imgs = self.X[index * self.batch_size : (index + 1) * self.batch_size]
        labels = self.y[index * self.batch_size: (index + 1) * self.batch_size]
        if self.training:
            imgs, labels = self.__get_train_data(imgs,labels)
        else:
            imgs,labels = self.__get_test_data(imgs,labels)
        return  imgs, labels
        
    def __len__(self):
        return self.n // self.batch_size
        

In [4]:
def basic_block(input,filters,strides,l_2):
    y = input
    x = BatchNormalization(momentum=0.9,epsilon=1e-5,beta_regularizer=l2(l_2),gamma_regularizer=l2(l_2))(input)
    x = Activation('relu')(x)
    x = Conv2D(filters,3,strides=strides,padding ='same',use_bias=False,kernel_initializer="he_normal",kernel_regularizer=l2(l_2))(x)
    x = BatchNormalization(momentum=0.9,epsilon=1e-5,beta_regularizer=l2(l_2),gamma_regularizer=l2(l_2))(x)
    x = Activation('relu')(x)
    x = Conv2D(filters,3,strides=1,padding ='same',use_bias=False,kernel_initializer="he_normal",kernel_regularizer=l2(l_2))(x)
    
    if not x.shape.is_compatible_with(y.shape):
        y = Conv2D(filters,1,strides=strides,padding ='same',use_bias=False,kernel_initializer="he_normal",kernel_regularizer=l2(l_2))(input)

    return Add()([x,y])


In [5]:
def res_group(input,filters,strides,n_blocks,l_2):
    x = basic_block(input,filters,strides,l_2)
    for _ in range(n_blocks-1):
        x = basic_block(x,filters,1,l_2)
    return x

In [6]:
def wide_resnet(input_shape,d,w_mult,n_classes,l_2=0):
    n_blocks = (d - 4) // 6
    input_shape = list(input_shape)
    ensemble_size = input_shape[0]

    input = Input(shape=input_shape)

    x = Permute([2,3,4,1])(input)

    # Reshape so that each subnetwork has 3 channels
    x = Reshape(input_shape[1:-1] + [input_shape[-1] * ensemble_size])(x)

    x = Conv2D(16,3,padding ='same',use_bias=False,kernel_initializer="he_normal",kernel_regularizer=l2(l_2))(x)

    for strides, filters in zip([1, 2, 2], [16, 32, 64]):
        x = res_group(x,filters*w_mult,strides,n_blocks,l_2)

    x = BatchNormalization(momentum=0.9,epsilon=1e-5,beta_regularizer=l2(l_2),gamma_regularizer=l2(l_2))(x)
    x = Activation('relu')(x)
    x = AveragePooling2D(pool_size=8)(x)
    x = Flatten()(x)

    x = Dense(n_classes*ensemble_size,kernel_initializer='he_normal',activation=None,kernel_regularizer=l2(l_2),bias_regularizer=l2(l_2))(x)
    x = Reshape([ensemble_size,n_classes])(x)
    
    return tf.keras.Model(input,x)

In [7]:
class NLL(Loss):
    def call(self, y_true, y_pred):
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)
        nll = tf.reduce_mean(tf.reduce_sum(loss, axis=1))
        return nll

In [8]:
class Accuracy(tf.keras.metrics.Metric):

    def __init__(self, n_classes,name="Accuracy", dtype=None, **kwargs):
        super().__init__(name=name, dtype=dtype, **kwargs)
        self.accuracy = self.add_weight(name='acc', initializer='zeros')
        self.n_classes = n_classes
        self.training = tf.Variable(True)

    def update_state(self, labels,logits):
        probs = tf.nn.softmax(tf.reshape(logits, [-1, self.n_classes]))
        if self.training:
            labels = tf.reshape(labels, [-1])
        else:
            probs = tf.math.reduce_mean(probs, axis=1) 

        accuracy = tf.keras.metrics.sparse_categorical_accuracy(labels,probs)
        self.accuracy.assign_add(accuracy)

    def result(self):
        return sefl.accuracy

In [9]:
class ToggleMetrics(tf.keras.callbacks.Callback):
    '''On test begin (i.e. when evaluate() is called or 
     validation data is run during fit()) toggle metric flag '''
    def on_test_begin(self, logs):
        for metric in self.model.metrics:
            if 'Accuracy' in metric.name:
                metric.on.assign(False)
    def on_test_end(self,  logs):
        for metric in self.model.metrics:
            if 'Accuracy' in metric.name:
                metric.on.assign(True)

In [10]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = (x_train.astype('float32') / 256 ) - 0.5
x_test =(x_test.astype('float32') / 256 ) - 0.5

In [11]:
ensemble_size = 3
d = 28
w_mult = 10
n_classes = 10
epochs = 250
batch_size = 128
batch_rep = 4
inp_rep_prob = 0.5
input_shape = tuple([3]+ list(x_train[0].shape))
val_split = 0.1
l_2 = 3e-4
steps_per_epoch = x_train.shape[0] * 1 - val_split


In [12]:
traing_data=DataGenerator(x_train,y_train,batch_size,batch_rep,inp_rep_prob,ensemble_size,True)

model = wide_resnet(input_shape,d,w_mult,n_classes, l_2)

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    0.1,
    decay_steps=steps_per_epoch,
    decay_rate=0.1)

optimizer = tf.keras.optimizers.SGD(lr_schedule)

model.compile(optimizer,loss = NLL())
#model.fit(traing_data,epochs=epochs)

2021-11-26 16:37:41.778279: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2021-11-26 16:37:41.778547: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-11-26 16:37:41.780556: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.


In [13]:
test_data=DataGenerator(x_test,y_test,batch_size,batch_rep,inp_rep_prob,ensemble_size,False)

In [14]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 3, 32, 32, 3 0                                            
__________________________________________________________________________________________________
permute (Permute)               (None, 32, 32, 3, 3) 0           input_1[0][0]                    
__________________________________________________________________________________________________
reshape (Reshape)               (None, 32, 32, 9)    0           permute[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 32, 32, 16)   1296        reshape[0][0]                    
______________________________________________________________________________________________

In [20]:
test=np.random.uniform(size=[2,3,10])

In [18]:
from scipy.special import softmax


In [30]:
probs = softmax(test,axis=2)
np.sum(probs,axis=2)
probs,np.mean(probs,axis=1)
accuracy = tf.keras.metrics.sparse_categorical_accuracy(labels,probs) #Funkar detta?


(2, 10)

In [None]:
probs = softmax(pred)
    probs = np.mean(probs,axis=1)
    accuracy = tf.keras.metrics.sparse_categorical_accuracy(labels,probs) #Funkar detta?