In [None]:

from typing import List
def ConvBlock(out_channels:int,
             kernel_size:int=3,
             stride:tuple=(1,1),
             padding:str="same",
             activation:str=None,
             initializer:str="he_normal",
             bn:bool=False,
             dropout_rate:float=0,
             pooling=None,
             x:"Input"=None):
    """Convolution block for backbone feature extraction.

    Builds ``Conv2D -> [BatchNormalization] -> [Dropout] -> [pooling]`` on
    top of the tensor ``x`` and returns the resulting tensor.

    Args:
        out_channels: Number of convolution filters.
        kernel_size: Convolution kernel size.
        stride: Convolution strides, forwarded to ``Conv2D(strides=...)``.
        padding: Convolution padding mode.
        activation: Activation applied inside the ``Conv2D`` layer (i.e.
            before batch normalization); ``None`` means linear.
        initializer: Kernel initializer for the convolution.
        bn: If true, add ``BatchNormalization`` and drop the conv bias
            (the BN shift makes a bias redundant).
        dropout_rate: Dropout rate; dropout is added only when > 0.
        pooling: Optional key into the module-level ``POOLS`` mapping; the
            selected pooling layer is applied with a 2x2 window.
        x: Input tensor. Required despite the ``None`` default.

    Returns:
        The output tensor of the block.

    Raises:
        AssertionError: If ``x`` is not provided or ``pooling`` is not a
            key of ``POOLS``.
    """
    assert x is not None, "Did not provide input to the layer. Won't be able to build computational graph."
    x = Conv2D(out_channels,
               kernel_size,
               strides=stride,  # previously accepted but never forwarded
               activation=activation,
               padding=padding,
               kernel_initializer=initializer,
               use_bias=not bn,
               )(x)
    if bn:
        x = BatchNormalization()(x)
    if dropout_rate > 0:
        x = Dropout(dropout_rate)(x)
    if pooling:
        assert pooling in POOLS, "Not valid pooling method."
        x = POOLS[pooling](pool_size=(2, 2))(x)
    return x
    
def OutputBlock(x,
                n_classes:int,
                hidden_nodes:List[int]=None,
                hidden_activation:str="relu",
                output_activation:str="softmax"):
    """Classification head: flatten, optional hidden Dense layers, output Dense.

    Args:
        x: Input tensor (typically the last feature map of a backbone).
        n_classes: Number of units in the output layer.
        hidden_nodes: Units of each hidden ``Dense`` layer, in order. ``None``
            or an empty list means no hidden layers.
        hidden_activation: Activation used by every hidden layer.
        output_activation: Activation of the output layer (``"softmax"`` by
            default, matching the original behavior).

    Returns:
        The output tensor of the head.
    """
    x = Flatten()(x)
    for units in hidden_nodes or ():
        x = Dense(units, activation=hidden_activation)(x)
    return Dense(n_classes, activation=output_activation)(x)
    
def ResBlock(out_channels:int,
             kernel_size:int=3,
             stride:tuple=(1,1),
             dropout_rate:float=0, # Dropout between main line segments.
             depth:int=2, # Depth of main line.
             x:"Input"=None):
    """Residual block with a 1x1-projected skip connection.

    Skip path: bias-free 1x1 ``Conv2D`` projection of ``x`` to
    ``out_channels`` (no batch normalization).
    Main path: ``depth`` segments of ``Conv2D (bias-free) -> BatchNorm -> ReLU``,
    with ``Dropout(dropout_rate)`` between consecutive segments when
    ``dropout_rate > 0``.
    Output: ``ReLU(skip + main)``.

    Args:
        out_channels: Filters of every convolution in the block.
        kernel_size: Kernel size of the main-path convolutions.
        stride: Strides applied to the skip projection and to the first
            main-path convolution. Both use ``"same"`` padding, so the two
            paths keep matching spatial shapes.
        dropout_rate: Dropout rate between main-path segments (0 disables it).
        depth: Number of main-path segments.
        x: Input tensor. Required despite the ``None`` default.

    Returns:
        The output tensor of the block.

    Raises:
        AssertionError: If ``x`` is not provided.
    """
    assert x is not None, "Did not provide input to the layer. Won't be able to build computational graph."
    # A single stateless ReLU layer object is reused at every application point.
    act = Activation("relu")
    skip = Conv2D(out_channels,
                  1,
                  strides=stride,
                  activation=None,
                  padding="same",
                  kernel_initializer="he_normal",
                  use_bias=False,
                  )(x)
    for i in range(depth):
        if i > 0 and dropout_rate > 0:
            x = Dropout(dropout_rate)(x)
        x = Conv2D(out_channels,
                   kernel_size,
                   # Only the first segment downsamples, mirroring the skip path.
                   strides=stride if i == 0 else (1, 1),
                   activation=None,
                   padding="same",
                   kernel_initializer="he_normal",
                   use_bias=False,
                   )(x)
        x = BatchNormalization()(x)
        # NOTE(review): ReLU is also applied after the last segment, i.e. before
        # the addition; the canonical ResNet block adds first and then applies
        # ReLU. Kept as-is to preserve the existing architecture.
        x = act(x)
    return act(skip + x)
def build_resnet(height, width, channels, n_classes=4, dropout_rate=0.8, initial_filters=16):
    """Build a small residual CNN classifier as a Keras ``Model``.

    Architecture: 7x7 stride-2 stem conv -> BatchNorm -> 2x2 max-pool ->
    Dropout, then six ``ResBlock`` (depth 1) stages, each followed by a 2x2
    max-pool (stages 2-4 also followed by Dropout), then a Flatten ->
    Dense(128, relu) -> Dense(32, relu) -> Dense(n_classes, softmax) head.

    The stem conv plus seven max-pools halve the spatial size eight times,
    so inputs need roughly 256 px per spatial side; smaller inputs shrink to
    zero size and fail to build.

    Args:
        height, width, channels: Input image shape.
        n_classes: Number of output classes (default 4, as before).
        dropout_rate: Keras ``Dropout`` rate, i.e. the fraction of units
            DROPPED. NOTE(review): 0.8 discards 80% of activations at four
            places; if this value came from a TF1-style keep probability,
            0.2 was probably intended. The default is kept to preserve
            existing behavior.
        initial_filters: Stem filter count; stage widths are multiples of it.

    Returns:
        The uncompiled ``Model``.
    """
    inp = Input(shape=(height, width, channels), name='input_1')
    x = Conv2D(filters=initial_filters, kernel_size=7, strides=(2, 2), padding="same", use_bias=False)(inp)
    x = BatchNormalization()(x)
    # NOTE(review): the original created an unused ReLU layer here; the stem has
    # no nonlinearity between BatchNorm and pooling (a standard ResNet stem
    # applies ReLU). Left unchanged so existing weights behave identically.
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Dropout(dropout_rate)(x)

    # (filter multiplier, dropout after pooling) for Layers 1-6.
    stages = (
        (2, False),
        (4, True),
        (8, True),
        (8, True),
        (16, False),
        (16, False),
    )
    for multiplier, use_dropout in stages:
        x = ResBlock(initial_filters * multiplier, x=x, depth=1)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        if use_dropout:
            x = Dropout(dropout_rate)(x)

    # FC head: Flatten -> Dense(128) -> Dense(32) -> Dense(n_classes, softmax).
    y = OutputBlock(x, n_classes, hidden_nodes=[128, 32], hidden_activation='relu')
    return Model(inputs=[inp], outputs=[y])

def confusion(model,gen_data, break_point:int=None):
    """Predict over ``gen_data`` and build a scikit-learn confusion matrix.

    Each item yielded by ``gen_data`` must unpack into four image batches
    ``(t1, t1ce, t2, flair)``. These are concatenated along axis 0 and passed
    to ``model.predict_on_batch``. The ground truth for every batch is the
    argmax of the module-level ``t1_label``, ``t1ce_label``, ``t2_label`` and
    ``flair_label`` arrays, concatenated in the same order.

    Args:
        model: Keras model to evaluate.
        gen_data: Iterable of ``(t1, t1ce, t2, flair)`` batches.
        break_point: If truthy, stop after this many batches.

    Returns:
        ``(preds, labels, cm, cm_display)``: the stacked per-batch predicted
        and true class ids, the confusion matrix normalized over predicted
        classes (``normalize='pred'``), and a ``ConfusionMatrixDisplay`` that
        uses the module-level ``array_labels``.
    """
    # The label arrays do not depend on the batch, so compute the class ids once
    # instead of re-running argmax on every iteration.
    true_ids = np.concatenate((t1_label, t1ce_label, t2_label, flair_label), axis=0).argmax(1)
    preds = []
    ls = []
    for idx, (t1, t1ce, t2, flair) in enumerate(gen_data):
        images = np.concatenate((t1, t1ce, t2, flair), axis=0)
        # NOTE(review): assumes every batch holds exactly as many images as the
        # label arrays have rows; a shorter final batch would make the stacks
        # below ragged and fail.
        ls.append(true_ids)
        preds.append(model.predict_on_batch(images).argmax(1))
        if break_point and idx >= break_point - 1:
            break
    pred_arr = np.array(preds)
    label_arr = np.array(ls)
    result = confusion_matrix(label_arr.ravel(), pred_arr.ravel(), normalize='pred')
    cm_display = ConfusionMatrixDisplay(confusion_matrix=result, display_labels=array_labels)
    return pred_arr, label_arr, result, cm_display
def test_model(model):
    """Report accuracy and plot confusion matrices on the train/val/test splits.

    Uses the module-level generators ``gen_train``, ``gen_val`` and
    ``gen_test``. The training split is limited to ``len(gen_val)`` batches
    so that it costs about the same as the validation pass. For each split,
    this prints ``mean_average_accuracy`` of the confusion matrix and shows
    the matrix plot.
    """
    splits = (
        ("Training", gen_train, len(gen_val)),
        ("Validation", gen_val, None),
        ("Test", gen_test, None),
    )
    for split_name, generator, limit in splits:
        _, _, cm, cm_disp = confusion(model, generator, limit)
        print(f"{split_name} set - Accuracy: {mean_average_accuracy(cm)}")
        cm_disp.plot()
        # plt.plot() with no arguments draws nothing; show() actually renders
        # the figure right after its accuracy line.
        plt.show()