In [14]:
import os
import numpy as np
import cv2
from glob import glob
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger, TensorBoard
from tensorflow.keras.utils import CustomObjectScope
from tqdm import tqdm

print(tf.test.is_built_with_cuda())

##data.py
##read the images

imsize = 512       #image size
bridge_layer = 512 # bridge layer fileters in U-NET 
num_filters = [32, 64, 128, 256] # Filters in U-net
training = 1 # program should only carry out testting

def load_data(path, split=0.1):
    images = sorted(glob(os.path.join(path,"./BCSS-master/images/*")))
    masks = sorted(glob(os.path.join(path,"./BCSS-master/masks_tumor/*")))

    total_size = len(images)
    valid_size = int(split * total_size)
    test_size = int(split * total_size)

    train_x, valid_x = train_test_split(images, test_size=valid_size, random_state=42)
    train_y, valid_y = train_test_split(masks, test_size=valid_size, random_state=42)

    train_x, test_x = train_test_split(train_x, test_size=test_size, random_state=42)
    train_y, test_y = train_test_split(train_y, test_size=test_size, random_state=42)

    return (train_x, train_y), (valid_x, valid_y), (test_x, test_y)

def read_image(path):
    path = path.decode()
    x = cv2.imread(path, cv2.IMREAD_COLOR)
    x = cv2.resize(x,(imsize, imsize))
    x = x/255.0
    return x

def read_image_test(path):
    #path = path.decode()
    x = cv2.imread(path, cv2.IMREAD_COLOR)
    x = cv2.resize(x,(imsize, imsize))
    x = x/255.0
    return x

def read_mask(path):
    path = path.decode()
    x = cv2.imread(path, cv2.IMREAD_COLOR)
    x = cv2.resize(x, (imsize, imsize))
    x = x/255.0
#     x = np.expand_dims(x, axis=-1)
    return x

def read_mask_test(path):
    #path = path.decode()
    x = cv2.imread(path, cv2.IMREAD_COLOR)
    x = cv2.resize(x, (imsize, imsize))
    x = x/255.0
#     x = np.expand_dims(x, axis=-1)
    return x

def myfun(img):
    img = np.squeeze(img)

    return img

def mask_parse(mask):
    mask = np.squeeze(mask)
    mask = [mask, mask, mask]
    mask = np.transpose(mask, (1, 2, 0))
    return mask

def tf_parse(x, y):
    def _parse(x, y):
        x = read_image(x)
        y = read_mask(y)
        return x, y

    x, y = tf.numpy_function(_parse, [x, y], [tf.float64, tf.float64])
    x.set_shape([imsize, imsize, 3])
    y.set_shape([imsize, imsize, 3])
    return x, y

def tf_dataset(x, y, batch=16):
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    dataset = dataset.map(tf_parse)
    dataset = dataset.batch(batch)
    dataset = dataset.repeat()
    return dataset

#model.py
#building the u-net arch



def conv_block(x, num_filters):
    x = Conv2D(num_filters, (3, 3), padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    x = Conv2D(num_filters, (3, 3), padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    return x

def build_model():
    size = imsize
    
    inputs = Input((size, size, 3))

    skip_x = []
    x = inputs
    ## Encoder
    for f in num_filters:
        x = conv_block(x, f)
        skip_x.append(x)
        x = MaxPool2D((2, 2))(x)

    ## Bridge
    x = conv_block(x, bridge_layer)

    num_filters.reverse()
    skip_x.reverse()
    ## Decoder
    for i, f in enumerate(num_filters):
        x = UpSampling2D((2, 2))(x)
        xs = skip_x[i]
        x = Concatenate()([x, xs])
        x = conv_block(x, f)

    ## Output
    x = Conv2D(3, (3, 3), padding="same")(x)
    x = Activation("sigmoid")(x)
    return Model(inputs, x)

##Train.py



def iou(y_true, y_pred):
    def f(y_true, y_pred):
        intersection = (y_true * y_pred).sum()
        union = y_true.sum() + y_pred.sum() - intersection
        x = (intersection + 1e-15) / (union + 1e-15)
        x = x.astype(np.float32)
        return x
    
    return tf.numpy_function(f, [y_true, y_pred], tf.float32)



True


In [15]:
if __name__ == "__main__":
    ## Dataset
   # np.random.seed(42);
  #np.random.set_seed(42);
    #os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    #config = tf.compat.v1.ConfigProto()
    #config.gpu_options.allow_growth = True
    #config = tf.compat.v1.ConfigProto(device_count={'GPU': 0})
    #sess = tf.compat.v1.Session(config=config)
    path = "."
    (train_x, train_y), (valid_x, valid_y), (test_x, test_y) = load_data(path)

    ## Hyperparameters
    batch = 8
    lr = 1e-3
    epochs = 100

    train_dataset = tf_dataset(train_x, train_y, batch=batch)
    valid_dataset = tf_dataset(valid_x, valid_y, batch=batch)
    test_dataset = tf_dataset(test_x, test_y, batch=batch)

    if os.path.isfile("./BCSS-master/files/model.h5"):
        with CustomObjectScope({'iou': iou}):
            model = tf.keras.models.load_model("./BCSS-master/files/model.h5")
    else:
        model = build_model()		
    print("Model Buid....")
    model.summary()
    
    opt = tf.keras.optimizers.Adam(lr)
    metrics = ["acc", tf.keras.metrics.Recall(), tf.keras.metrics.Precision(), iou]
    model.compile(loss="mse", optimizer=opt, metrics=metrics)

    callbacks = [
        ModelCheckpoint("./BCSS-master/files/model.h5"),
        ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=4),
        CSVLogger("./BCSS-master/files/data.csv"),
        TensorBoard(),
        #EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=False)
    ]

    train_steps = len(train_x)//batch
    valid_steps = len(valid_x)//batch
    test_steps = len(test_x)//batch
    if len(train_x) % batch != 0:
        train_steps += 1
    if len(valid_x) % batch != 0:
        valid_steps += 1
    if len(test_x) % batch != 0:
        test_steps += 1  
    print("Training ....")
#     print(train_dataset)
#     print(valid_dataset)
#     print(train_steps)
#     print(valid_steps)
#     print(callbacks)
    if training == 1:
        model.fit(train_dataset,
            validation_data=valid_dataset,
            epochs=epochs,
            steps_per_epoch=train_steps,
            validation_steps=valid_steps,
            callbacks=callbacks)
    evalcallbacks = [
        CSVLogger("./BCSS-master/files/eval.csv")
    ]

    print("Evaluating model\n")
    model.evaluate(test_dataset, steps=test_steps, callbacks=evalcallbacks)
    print("Evaluating model ... complete\n")

    for i, (x, y) in tqdm(enumerate(zip(test_x, test_y)), total=len(test_x)):
        x = read_image_test(x)
        y = read_mask_test(y)
        y_pred = model.predict(np.expand_dims(x, axis=0))
        #cv2.imwrite("./results/"+str(i)+".png", y_pred)
        h, w, _ = x.shape
        white_line = np.ones((h, 5, 3)) * 255.0

        #print("Printing sizes\n")
        #print(mask_parse(y_pred).shape)
        #print(mask_parse(y).shape)
        # print(x.shape)
        # print(myfun(y_pred).shape)
        all_images = [
            x * 255.0, white_line,
            y*255.0, white_line,
            myfun(y_pred) * 255.0
        ]
        image = np.concatenate(all_images, axis=1)
        cv2.imwrite(f"results/{i}.png", image)
    print(type(image))

Model Buid....
Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            [(None, 512, 512, 3) 0                                            
__________________________________________________________________________________________________
conv2d_57 (Conv2D)              (None, 512, 512, 32) 896         input_4[0][0]                    
__________________________________________________________________________________________________
batch_normalization_54 (BatchNo (None, 512, 512, 32) 128         conv2d_57[0][0]                  
__________________________________________________________________________________________________
activation_57 (Activation)      (None, 512, 512, 32) 0           batch_normalization_54[0][0]     
_____________________________________________________________________________

100%|██████████| 48/48 [00:11<00:00,  4.34it/s]

<class 'numpy.ndarray'>



