# In [None]:
from __future__ import absolute_import
from __future__ import print_function
import os
import numpy as np
from keras.utils import np_utils
from keras.applications import imagenet_utils


########################
from keras.models import Model
from keras.layers import Input
from keras.layers.core import Activation, Reshape
from keras.layers import BatchNormalization
import tensorflow as tf
from torch.nn import MaxUnpool3d
from keras.layers import Conv3D, MaxPooling3D, concatenate, UpSampling3D


def SegNet(input_shape, classes, kernel=(3, 3, 3), pool_size=(2, 2, 2),
           output_mode="softmax"):
    """Build a 3-D SegNet-style encoder/decoder segmentation model.

    Encoder: four Conv3D -> BatchNormalization -> ReLU stages with 64, 128,
    256 and 512 filters, with a MaxPooling3D after each of the first three.
    Decoder: a 512-filter stage, then three UpSampling3D + conv stages with
    256, 128 and 64 filters. A 1x1x1 Conv3D then projects to ``classes``
    channels, followed by the ``output_mode`` activation.

    Note: unlike the original SegNet, the decoder does not reuse the
    encoder's max-pooling indices. It upsamples with plain ``UpSampling3D``.

    Args:
        input_shape: Shape of one input volume without the batch axis,
            e.g. ``(128, 128, 128, 3)``. Presumably channels-last (the
            Keras default data format) -- confirm against keras.json.
            Each spatial dimension should be divisible by ``pool_size``
            cubed along that axis. Otherwise the three floor-dividing
            pooling steps make the output smaller than the input.
        classes: Number of output channels (segmentation classes).
        kernel: Kernel size of every non-final convolution.
        pool_size: Down/up-sampling factor used at each of the three levels.
        output_mode: Activation applied to the final per-voxel scores.

    Returns:
        An uncompiled ``keras.models.Model`` mapping ``input_shape`` to a
        tensor with ``classes`` channels.
    """

    def _conv_bn_relu(tensor, filters):
        # One Conv3D -> BatchNorm -> ReLU unit; "same" padding keeps spatial size.
        tensor = Conv3D(filters, kernel, padding="same")(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation("relu")(tensor)

    img_input = Input(shape=input_shape)

    # Encoder: conv stage + downsampling at three levels, then a 512-filter stage.
    x = img_input
    for filters in (64, 128, 256):
        x = _conv_bn_relu(x, filters)
        x = MaxPooling3D(pool_size=pool_size)(x)
    x = _conv_bn_relu(x, 512)

    # Decoder: bottleneck conv, then upsample + conv back to full resolution.
    x = _conv_bn_relu(x, 512)
    for filters in (256, 128, 64):
        x = UpSampling3D(size=pool_size)(x)
        x = _conv_bn_relu(x, filters)

    # 1x1x1 projection to per-voxel class scores. The kernel size is spelled
    # out explicitly: in Keras 2 a third positional argument would be `strides`.
    x = Conv3D(classes, (1, 1, 1), padding="valid")(x)
    x = Activation(output_mode)(x)

    return Model(img_input, x)



# Build and inspect a 4-class model on 128^3 three-channel volumes.
# Guarded so that importing this file does not construct the model; in a
# notebook __name__ is "__main__", so the cell still runs and `model`
# remains a module-level name.
if __name__ == "__main__":
    model = SegNet(input_shape=(128, 128, 128, 3), classes=4)

    model.summary()
    print(model.input_shape)
    print(model.output_shape)