In [2]:
from model import *
from data import *
from keras.callbacks import EarlyStopping
import matplotlib.pyplot as plt

--- Train ---

In [None]:
# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
EPOCHs = 60            # number of epochs requested from fit_generator
steps = 200            # training batches drawn per epoch (steps_per_epoch)
val_steps = 15         # validation batches drawn per epoch (validation_steps)
train_batch_size = 16
val_batch_size = 16

root_path = "synthetic_image/"

# ImageDataGenerator parameters, forwarded to the generators as `aug_dict`.
# NOTE(review): Keras' ImageDataGenerator interprets rotation_range in degrees,
# so 0.2 is close to "no rotation" — confirm this is the intended amount.
aug_args = dict(
    rotation_range = 0.2,
    width_shift_range = 0.05,
    height_shift_range = 0.05,
    shear_range = 0.05,
    zoom_range = 0.05,
    horizontal_flip = True,
    vertical_flip = True,
    fill_mode = 'nearest'
)

# The validation generator receives the same parameters as the training one
# (an independent copy, so editing one dict later does not alter the other).
# NOTE(review): randomly augmenting validation data makes val_loss noisier from
# epoch to epoch, and val_loss drives the checkpoint below — confirm this is wanted.
val_args = dict(aug_args)

image_folder = root_path

# Paired image/label generators for the training and validation splits.
train_gene = trainGenerator(batch_size=train_batch_size,aug_dict=aug_args,train_path=image_folder,
                        image_folder='trainCT',label_folder='trainLabel',
                        image_color_mode='grayscale',label_color_mode='grayscale',
                        image_save_prefix=None,label_save_prefix=None,
                        flag_multi_class=True,save_to_dir=None
                        )
val_gene = valGenerator(batch_size=val_batch_size,aug_dict=val_args,train_path=image_folder,
                        image_folder='valCT',label_folder='valLabel',
                        image_color_mode='grayscale',label_color_mode='grayscale',
                        image_save_prefix=None,label_save_prefix=None,
                        flag_multi_class=True,save_to_dir=None
                        )

model = unet(num_class=11)

# Learning-rate schedule driven by step_decay. It has to be passed to the fit
# call to take effect; it is also what records the per-epoch "lr" entry in
# H.history that the learning-rate plot further down reads.
lrate = LearningRateScheduler(step_decay)
# reduce_lr = ReduceLROnPlateau(monitor='val_loss', patience=10, mode='auto')

# Write 'bone.hdf5' only when val_loss improves on the best value seen so far.
model_checkpoint = ModelCheckpoint('bone.hdf5',monitor='val_loss',verbose=1,save_best_only=True)

# Not passed to `callbacks` below: stopping early shortens H.history, and any
# plotting code that assumes exactly EPOCHs entries would then fail. Add it to
# the callback list once the downstream plots size their x-axis from the history.
early_stopping = EarlyStopping(monitor='val_loss',patience=10,verbose=0)

H = model.fit_generator(train_gene,
                              steps_per_epoch=steps,
                              epochs=EPOCHs,
                              verbose=1,
                              validation_data = val_gene,
                              validation_steps=val_steps,
                              shuffle = True,
                              callbacks=[model_checkpoint, lrate]
                              )

In [None]:
# Plot training vs. validation loss per epoch.
# The x-axis length comes from the recorded history instead of EPOCHs, so the
# plot still works when fewer epochs were actually recorded than requested.
# `N` keeps its name because later plotting cells read it.
N = len(H.history["loss"])
epoch_axis = np.arange(0, N)

plt.style.use("ggplot")
plt.figure()
plt.plot(epoch_axis, H.history["loss"], label="train_loss")
plt.plot(epoch_axis, H.history['val_loss'], label="val_loss")

plt.title("Training/Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend(loc="lower left")

In [None]:
# Plot training vs. validation Dice coefficient per epoch.
# The epoch count is taken from the history itself, so this cell does not rely
# on a variable defined in another cell or on training having run all epochs.
N = len(H.history["dice_coef"])
epoch_axis = np.arange(0, N)

plt.style.use("ggplot")
plt.figure()
plt.plot(epoch_axis, H.history["dice_coef"], label="train_dice")
plt.plot(epoch_axis, H.history['val_dice_coef'], label="val_dice")

plt.title("Training/Validation Dice coef")
plt.xlabel("Epoch")
plt.ylabel("Dice Coef")
plt.legend(loc="lower left")

In [None]:
# Plot the learning rate recorded for each epoch.
# NOTE(review): the "lr" key is only present in H.history when a learning-rate
# callback (e.g. LearningRateScheduler) was passed to the fit call — verify the
# training cell includes one, otherwise this lookup raises KeyError.
lr_history = H.history["lr"]
N = len(lr_history)  # sized from the history, not from the requested epoch count

plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, N), lr_history, label="learning rate")

plt.title("Learning rate")
plt.xlabel("Epoch")
plt.ylabel("lr")
plt.legend(loc="lower left")

--- Test ---

In [2]:
# ---------------------------------------------------------------------------
# Evaluation on the synthetic test set
# ---------------------------------------------------------------------------
numtestImage, num_class = 1, 11   # images to score / number of label classes

testCTPath = "synthetic_image/testCT"
testLabelPath = "synthetic_image/testLabel"
savePath = "synthetic_image/real_result/origin"

testGene = testGenerator(testCTPath,numtestImage)

# The saved model references the custom Dice loss/metric, so both must be
# supplied by name when deserialising it.
generatedModel = load_model(
    'model/origin.hdf5',
    custom_objects={'dice_coef_loss':dice_coef_loss,'dice_coef':dice_coef},
)

# Predict, write the predictions under savePath, then load predictions and
# ground-truth labels for scoring.
# NOTE(review): presumably createData(savePath, 'predict', ...) reads back the
# files saveResult just wrote — confirm against the helpers' definitions.
results = generatedModel.predict(testGene,verbose=1)
saveResult(savePath,results)

prediction = createData(savePath,'predict',numtestImage,'.tif')
groundTruth = createData(testLabelPath,None,numtestImage,'.png')

total = []      # overall Dice score of every test image
boneDice = []   # per-bone Dice scores of the image currently being processed
for img_idx in range(numtestImage):
    pred_img = prediction[img_idx]
    truth_img = groundTruth[img_idx]

    DiceScore = dice_coef(pred_img,truth_img)

    # Hand boneDiceCalculator copies of the arrays (the data itself, not a
    # reference to it) for each of the num_class-1 bone indices.
    boneDice.extend(
        boneDiceCalculator(pred_img.copy(),truth_img.copy(),bone_idx)
        for bone_idx in range(num_class-1)
    )

    boneIndex = boneDice.index(max(boneDice))
    print('The best prediction is the %dth bone'%(boneIndex+1))
    boneDice.clear()   # start from an empty list for the next image

    Diceinfo = DiceScore.numpy()
    Iouinfo = iou(float(Diceinfo))   # computed from the Dice value; not reported below
    total.append(Diceinfo)
    print('Dice score for %dth image: %s\n'%(img_idx+1,Diceinfo))

print('\nThe average dice score is %s'%(sum(total)/numtestImage))



---Dice score for 1th bone: 0.9176863181988678
---Dice score for 2th bone: 0.8801123398770916
---Dice score for 3th bone: 0.718259990153055
---Dice score for 4th bone: 0.7794639975281044
---Dice score for 5th bone: 0.2728226682112905
---Dice score for 6th bone: 0.5836653410950635
---Dice score for 7th bone: 0.6214028789850442
---Dice score for 8th bone: 0.8797653960487174
---Dice score for 9th bone: 0.8854415285230541
---Dice score for 10th bone: 0.8416969699656562
The best prediction is the 1th bone
Dice score for 1th image: 0.9197979380903744

---Dice score for 1th bone: 0.9066070542165557
---Dice score for 2th bone: 0.9018502203046801
---Dice score for 3th bone: 0.8677007300411009
---Dice score for 4th bone: 0.4014598574273265
---Dice score for 5th bone: 0.8776185231225276
---Dice score for 6th bone: 0.9507278837016062
---Dice score for 7th bone: 0.8646575348073334
---Dice score for 8th bone: 0.9232623545210724
---Dice score for 9th bone: 0.8764007368859633
---Dice score for 10th bo

---Dice score for 1th bone: 0.8883442266693344
---Dice score for 2th bone: 0.8758017493274735
---Dice score for 3th bone: 0.7267541360619126
---Dice score for 4th bone: 0.7852298418747291
---Dice score for 5th bone: 0.004987535958397206
---Dice score for 6th bone: 0.08745248173147914
---Dice score for 7th bone: 0.6578073100446298
---Dice score for 8th bone: 0.8796450362262851
---Dice score for 9th bone: 0.8514461014505929
---Dice score for 10th bone: 0.8927093738231938
The best prediction is the 10th bone
Dice score for 17th image: 0.924544813705136

---Dice score for 1th bone: 0.9103407053989316
---Dice score for 2th bone: 0.8898385565662612
---Dice score for 3th bone: 0.8918073797401644
---Dice score for 4th bone: 1.407459713823187e-08
---Dice score for 5th bone: 0.8489208640082069
---Dice score for 6th bone: 0.9415254240215518
---Dice score for 7th bone: 0.9018645734749975
---Dice score for 8th bone: 0.9108527134806926
---Dice score for 9th bone: 0.7842665081668362
---Dice score for

Load model and predict on real X-ray images

In [4]:
# Run the saved model over the real X-ray images and store the predictions.
root_path = "synthetic_image/"
real_xray_dir = root_path+"realX-ray"
real_result_dir = root_path+"real_result/origin/"

# NOTE(review): the second argument (8) is presumably the number of real X-ray
# images to feed through the generator — confirm against testGeneratorRealXRay.
realGene = testGeneratorRealXRay(real_xray_dir,8)

# The custom Dice loss/metric must be passed by name to deserialise the model.
dice_objects = {'dice_coef_loss':dice_coef_loss,'dice_coef':dice_coef}
generatedModel = load_model('model/origin.hdf5',custom_objects=dice_objects)

results = generatedModel.predict(realGene,verbose=1)
saveResult(real_result_dir,results)

