In [5]:
import numpy as np
import h5py 
import skimage.transform as sc
from PIL import Image


from keras.layers import Input
from keras.layers.core import Activation, Reshape, Dropout
from keras.layers.convolutional import MaxPooling2D, UpSampling2D, Conv2D
from keras.layers.normalization import BatchNormalization
from keras.models import Model
from keras.utils import np_utils


import matplotlib.pyplot as plt
import os
import numpy as np

from keras.callbacks import ModelCheckpoint, EarlyStopping

Using TensorFlow backend.


In [2]:
# load data

def resize(arr, size=(88, 88)):
    """Resize every 2-D image in a stack with skimage's default interpolation.

    Parameters
    ----------
    arr : numpy.ndarray, e.g. of shape (N, H, W)
        Stack of images; each ``arr[i]`` is resized independently.
    size : tuple of int, default (88, 88)
        Target (height, width) of every output image.

    Returns
    -------
    numpy.ndarray
        The resized images stacked along axis 0. ``preserve_range=True`` keeps
        the input intensity range instead of rescaling it to [0, 1].

    Use this for intensity images only. Label maps must go through
    ``resize_labels`` so that no interpolated, non-existent class values appear.
    """
    # mode='constant' was skimage's default when this notebook was run (see the
    # "default mode, 'constant', will be changed to 'reflect'" warning in the
    # output); pinning it keeps train/inference preprocessing stable across
    # skimage upgrades and silences that warning.
    return np.array([sc.resize(img, size, mode='constant', preserve_range=True)
                     for img in arr])


def resize_labels(arr, size=(88, 88)):
    """Nearest-neighbour resize of a stack of integer label maps.

    Interpolating resizes (such as ``resize``) blend neighbouring class ids,
    e.g. a 0/3 boundary becomes 1.5, which turns into an unrelated class id
    once cast to int for one-hot encoding. Sampling the nearest source pixel
    keeps every output value an existing class id and keeps the input dtype.
    Implemented in numpy so that no version-dependent anti-aliasing filter
    can smear the labels.

    Parameters
    ----------
    arr : array-like of shape (N, H, W)
        Stack of label maps.
    size : tuple of int, default (88, 88)
        Target (height, width).

    Returns
    -------
    numpy.ndarray of shape (N, size[0], size[1]) with the dtype of ``arr``.
    """
    arr = np.asarray(arr)
    out_h, out_w = size
    in_h, in_w = arr.shape[1], arr.shape[2]
    # Map every output pixel centre back to the source pixel that contains it;
    # the clip guards against floating-point round-up at the far edge.
    rows = np.minimum(((np.arange(out_h) + 0.5) * in_h / out_h).astype(int), in_h - 1)
    cols = np.minimum(((np.arange(out_w) + 0.5) * in_w / out_w).astype(int), in_w - 1)
    return arr[:, rows][:, :, cols]


def load_data(path='LowRes_13434_overlapping_pairs.h5', size=(88, 88), n_test=20):
    """Load image/label pairs from HDF5, resize them and split off a test set.

    The HDF5 dataset 'dataset_1' is read as a 4-D array whose last axis holds
    the image (index 0) and its label map (index 1).

    Parameters
    ----------
    path : str
        HDF5 file to read.
    size : tuple of int, default (88, 88)
        Target (height, width) for images and labels.
    n_test : int, default 20
        Number of trailing samples held out as the test set (no shuffling).

    Returns
    -------
    images_train, labels_train, images_test, labels_test : numpy.ndarray
        Each with a trailing channel axis of size 1, i.e. (n, size[0], size[1], 1).

    Raises
    ------
    KeyError
        If the file has no 'dataset_1'.
    """
    # Read both channels into memory, then close the file immediately.
    with h5py.File(path, 'r') as h5:
        dataset = h5['dataset_1']
        images = dataset[:, :, :, 0]
        labels = dataset[:, :, :, 1]
    print(images.shape)
    print(labels.shape)

    images = resize(images, size)
    labels = resize_labels(labels, size)   # nearest-neighbour: class ids stay valid

    images = np.expand_dims(images, -1)
    labels = np.expand_dims(labels, -1)

    print("Images shape:{}".format(images.shape))
    print("Labels shape:{}".format(labels.shape))

    # The last n_test samples form the test set.
    split = images.shape[0] - n_test
    images_train, images_test = images[:split], images[split:]
    labels_train, labels_test = labels[:split], labels[split:]

    return images_train, labels_train, images_test, labels_test

In [3]:
# segNet

def to_categorical(y, nb_classes):
    """One-hot encode a batch of per-pixel label maps.

    ``y`` is flattened, encoded with keras' ``np_utils.to_categorical`` and then
    reshaped to (samples, pixels_per_sample, nb_classes), the same layout as the
    (H*W, classes) output of ``SegNet``.
    """
    n_samples = len(y)
    one_hot = np_utils.to_categorical(y.flatten(), nb_classes)
    pixels_per_sample = y.size // n_samples
    target_shape = (n_samples, pixels_per_sample, nb_classes)
    return one_hot.reshape(target_shape)

def _conv_bn_stack(x, filters, n_convs):
    """Apply `n_convs` rounds of 3x3 'same' ReLU Conv2D followed by BatchNormalization."""
    for _ in range(n_convs):
        x = Conv2D(filters, (3, 3), activation='relu', padding='same',
                   kernel_initializer='he_normal')(x)
        x = BatchNormalization()(x)
    return x


def SegNet(input_shape=(88, 88, 1), classes=4):
    """Build a SegNet-style encoder/decoder for per-pixel classification.

    Encoder: conv/BN stacks with two 2x2 max-pools (spatial size / 4).
    Decoder: two 2x UpSampling2D stages back to full resolution, then a 3x3
    conv that produces one score per class and pixel. Unlike the original
    SegNet, the decoder does not reuse max-pooling indices.

    Parameters
    ----------
    input_shape : tuple (H, W, C), default (88, 88, 1)
        H and W must be multiples of 4 so that the two pool/upsample pairs
        restore the input resolution exactly.
    classes : int, default 4
        Number of output classes.

    Returns
    -------
    keras Model mapping (batch, H, W, C) inputs to (batch, H*W, classes)
    softmax probabilities.

    Raises
    ------
    ValueError
        If H or W is not a multiple of 4.
    """
    if input_shape[0] % 4 or input_shape[1] % 4:
        raise ValueError("SegNet needs height and width divisible by 4, got {}"
                         .format(input_shape))

    img_input = Input(shape=input_shape)

    ######### Encoder #########
    # Block 1: full resolution -> 1/2
    x = _conv_bn_stack(img_input, 64, 2)
    x = MaxPooling2D((2, 2), strides=(2, 2))(x)
    x = Dropout(0.25)(x)

    # Block 2: 1/2 -> 1/4
    x = _conv_bn_stack(x, 128, 3)
    x = MaxPooling2D((2, 2), strides=(2, 2))(x)
    x = Dropout(0.5)(x)

    # Block 3: bottleneck at 1/4 resolution
    x = _conv_bn_stack(x, 256, 3)
    x = Dropout(0.5)(x)

    ############ Decoder ############
    # Upsampling block 1: 1/4 -> 1/2
    x = UpSampling2D(size=(2, 2))(x)
    x = _conv_bn_stack(x, 128, 3)
    x = Dropout(0.5)(x)

    # Upsampling block 2: 1/2 -> full resolution
    x = UpSampling2D(size=(2, 2))(x)
    x = _conv_bn_stack(x, 64, 2)
    # Linear per-pixel class scores, one channel per class. The previous ReLU
    # here clamped every score to >= 0 before the softmax, and a hard-coded 4
    # broke the Reshape below for any other `classes`.
    x = Conv2D(classes, (3, 3), padding='same', kernel_initializer='he_normal')(x)
    # NOTE(review): dropout on the class scores is unusual; kept from the
    # original design - consider removing it.
    x = Dropout(0.25)(x)

    # Flatten the spatial grid so the softmax runs over classes for every pixel.
    x = Reshape((input_shape[0] * input_shape[1], classes))(x)
    x = Activation("softmax")(x)

    return Model(img_input, x)


In [7]:
# main

X_train, y_train, X_test, y_test = load_data()
print('Already loaded data')
print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)

num_classes = 4
WEIGHTS_PATH = 'Weights.h5'

# One-hot targets of shape (samples, H*W, num_classes), matching SegNet's output.
Y_train = to_categorical(y_train, num_classes)
# Build the model from the data's shape and the same class count as the targets,
# so the two cannot drift apart.
model = SegNet(input_shape=X_train.shape[1:], classes=num_classes)
model.compile(loss="categorical_crossentropy", optimizer='adam', metrics=['accuracy'])
model.summary()

# Keep the weights with the lowest validation loss on disk.
model_checkpoint = ModelCheckpoint(WEIGHTS_PATH, monitor='val_loss', save_best_only=True)
early_stopping = EarlyStopping(patience=2, verbose=2)

print('Fitting model...')
history = model.fit(X_train, Y_train, batch_size=16, epochs=10, validation_split=0.05,
                    shuffle=True, callbacks=[model_checkpoint, early_stopping])

# After fit() the in-memory model holds the LAST epoch's weights, which after early
# stopping can be up to `patience` epochs past the best one. Reload the best
# checkpoint so the prediction cells below use it.
if os.path.exists(WEIGHTS_PATH):
    model.load_weights(WEIGHTS_PATH)




(13434, 94, 93)
(13434, 94, 93)


  warn("The default mode, 'constant', will be changed to 'reflect' in "


Images shape:(13434, 88, 88, 1)
Labels shape:(13434, 88, 88, 1)
Already loaded data
(13414, 88, 88, 1)
(13414, 88, 88, 1)
(20, 88, 88, 1)
(20, 88, 88, 1)
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 88, 88, 1)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 88, 88, 64)        640       
_________________________________________________________________
batch_normalization_1 (Batch (None, 88, 88, 64)        256       
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 88, 88, 64)        36928     
_________________________________________________________________
batch_normalization_2 (Batch (None, 88, 88, 64)        256       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 44, 44, 64)       

### 显微镜图片测试

In [1]:
def process_pics(image):
    """Turn one 2-D (greyscale) image into a model-ready batch.

    The image is converted to an array, given a leading batch axis of size 1,
    resized with ``resize`` and given a trailing channel axis, i.e. shape
    (1, height, width, 1) with resize's target height/width.
    """
    pixels = np.array(image)
    print('image1', pixels.shape)
    batch = np.expand_dims(pixels, 0)   # (H, W) -> (1, H, W)
    return np.expand_dims(resize(batch), -1)

In [2]:
def outputt(image3, net=None):
    """Predict a per-pixel class map for a preprocessed batch.

    Parameters
    ----------
    image3 : numpy.ndarray of shape (N, H, W, C)
        Batch as produced by ``process_pics`` (there N=1 and H=W=88, C=1).
    net : object with a keras-style ``predict`` method, optional
        Model to use; defaults to the notebook-level ``model`` built in the
        training cell.

    Returns
    -------
    numpy.ndarray of shape (N, H, W)
        Index of the highest-scoring class for every pixel.
    """
    predictor = model if net is None else net
    Y_test = predictor.predict(image3, verbose=1)
    print('Y_test', Y_test.shape)
    # The network outputs (N, H*W, classes). Restore the spatial layout using the
    # input's N, H, W instead of hard-coding (1, 88, 88, 4), so any batch size or
    # class count works.
    Y_test = Y_test.reshape(image3.shape[:3] + (-1,))
    return np.argmax(Y_test, axis=-1)

In [13]:
# Pin matplotlib to 2.0.0 in this environment.
# NOTE(review): --force-reinstall also re-downloads and reinstalls every dependency
# (the log below shows pyparsing and pytz being collected again). Presumably this was
# an attempt to work around the AttributeError raised by plt.imshow in the prediction
# cell - confirm whether it is still needed.
!pip install --upgrade --force-reinstall matplotlib==2.0.0

Collecting matplotlib==2.0.0
  Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/44/28/86488fe20a146241ece4570af94b49689ab80b66a1609256a8469a0bd1b0/matplotlib-2.0.0-1-cp36-cp36m-manylinux1_x86_64.whl (14.7MB)
[K    100% |████████████████████████████████| 14.7MB 69.9MB/s ta 0:00:011�         | 10.4MB 92.3MB/s eta 0:00:01
[?25hCollecting pyparsing!=2.0.0,!=2.0.4,!=2.1.2,!=2.1.6,>=1.5.6 (from matplotlib==2.0.0)
  Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/dd/d9/3ec19e966301a6e25769976999bd7bbe552016f0d32b577dc9d63d2e0c49/pyparsing-2.4.0-py2.py3-none-any.whl (62kB)
[K    100% |████████████████████████████████| 71kB 61.3MB/s ta 0:00:01
[?25hCollecting pytz (from matplotlib==2.0.0)
  Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/3d/73/fe30c2daaaa0713420d0382b16fbb761409f532c56bdcc514bf7b6262bb6/pytz-2019.1-py2.py3-none-any.whl (510kB)
[K    100% |████████████████████████████████| 512kB 102.1MB/s ta 0:00:01
[?25hCollecting nu

In [17]:
from matplotlib import pyplot as plt

In [19]:
# load certain test_data
%matplotlib inline

#model.load_weights('Weights.h5')

pred_dir_1 = 'pics_black_1'
if not os.path.exists(pred_dir_1):
    os.mkdir(pred_dir_1)
    
pred_dir_test = 'pics_black'
for i in range(135):
    image = Image.open(os.path.join(pred_dir_test, str(i + 1) + '.png'))
    image = image.convert('L')
    plt.imshow(image,cmap='gray')
    plt.show()
    break
    image3 = process_pics(image)
    print('image3:',image3.shape)
    Y = outputt(image3)

    image_predict = Y[0, :, : ]
    plt.imsave(os.path.join(pred_dir_1, str(i + 2) + '_pred.png'), image_predict, cmap = 'viridis')

AttributeError: 'numpy.ndarray' object has no attribute 'mask'

<matplotlib.figure.Figure at 0x7f294b904240>