# Notebook for U-Net model to segment the sterezoom images of mice wounds over healing days
## Workflow:
- 1. [Library imports](#Import-libraries-and-sub-libs)
- 2. [Data loading](#Load-images-and-preprocess-them)
- 3. [Model assembly and training](#model-building-and-training)
- 4. [Evaluations](#Checking-results)
- 5. [Plotting layer weights](#Check-the-layer-weights)

### Import libraries and sub-libs

In [None]:
#Import whole libraries
import os
import cv2
import skimage
import sklearn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf

#From import specific sub-libs or functions/classes
from skimage.transform import resize
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from tensorflow.keras.layers import Conv2D, MaxPooling2D, UpSampling2D, Activation, MaxPool2D, Concatenate, Input, Conv2DTranspose, BatchNormalization, Dropout, Lambda
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam

### Load images and preprocess them

In [9]:
im_height = 512
im_width = 512

In [10]:
imagepath = r"/home/swatantra/Documents/Data/Image_dataset-20220916T130653Z-001/Image_dataset/image/"
maskpath = r"/home/swatantra/Documents/Data/Image_dataset-20220916T130653Z-001/Image_dataset/BINARY_MASK/"
images = os.listdir(imagepath)
masks = os.listdir(maskpath)

In [5]:
IMAGES = np.zeros((len(masks), im_height, im_width, 3), dtype=np.float32) #create an empty array for all the examples (instance,hight,width,three channels)
MASKS = np.zeros((len(masks), im_height, im_width, 1), dtype=np.float32) #create an empty array for all the examples (instance,hight,width,one channels)

## Load images and masks, resize for training, normalize and convert to floats
for n,image in enumerate(masks): #loop over the mask names and number
    
    print(n,image)
    mask =  cv2.imread("{}/{}".format(maskpath,image),0)  #read the masks with exact values
    mask =  cv2.resize(mask,(512,512))   #resize to expand the dimention of the masks
    
    for i in range(mask.shape[0]):
        for j in range(mask.shape[1]):
            if mask[i,j] >= 0.05:
                mask[i,j] = 1
            else:
                mask[i,j] = 0
    mask = np.expand_dims(mask,axis=2)

    img =  cv2.imread("{}/{}".format(imagepath,image),1) #read the colored image
    img =  cv2.cvtColor(img,cv2.COLOR_BGR2RGB)   #convert to RGB
    img =  cv2.resize(img,(512,512))   #resize the image for consistancy (there are some images with diff. dimentions)
    
    IMAGES[n,:,:] = img/255.0   #normalize and convert to floats
    MASKS[n,:,:] = mask/1.0     #normalize and convert to floats

0 290.png
1 98.png
2 517.png
3 413.png
4 436.png
5 389.png
6 460.png
7 622.png
8 302.png
9 440.png
10 381.png
11 643.png
12 99.png
13 105.png
14 617.png
15 645.png
16 258.png
17 411.png
18 427.png
19 187.png
20 186.png
21 343.png
22 318.png
23 558.png
24 467.png
25 392.png
26 599.png
27 90.png
28 92.png
29 305.png
30 117.png
31 611.png
32 332.png
33 385.png
34 159.png
35 507.png
36 232.png
37 156.png
38 226.png
39 121.png
40 261.png
41 408.png
42 384.png
43 230.png
44 612.png
45 94.png
46 637.png
47 122.png
48 358.png
49 255.png
50 329.png
51 366.png
52 266.png
53 150.png
54 355.png
55 360.png
56 555.png
57 214.png
58 562.png
59 473.png
60 428.png
61 508.png
62 353.png
63 375.png
64 541.png
65 606.png
66 519.png
67 477.png
68 648.png
69 576.png
70 359.png
71 199.png
72 629.png
73 620.png
74 110.png
75 504.png
76 250.png
77 498.png
78 568.png
79 72.png
80 130.png
81 386.png
82 589.png
83 241.png
84 419.png
85 623.png
86 430.png
87 180.png
88 278.png
89 461.png
90 342.png
91 137.png
92 4

### Model building and training

In [11]:
# define the double convolution block
def conv_block(input, num_filters):
    x = Conv2D(num_filters, 3, padding="same")(input)
    x = BatchNormalization()(x)   #Not in the original network. 
    x = Activation("relu")(x)

    x = Conv2D(num_filters, 3, padding="same")(x)
    x = BatchNormalization()(x)  #Not in the original network
    x = Activation("relu")(x)

    return x

In [12]:
# define the encoder block
def encoder_block(input, num_filters):
    x = conv_block(input, num_filters)
    p = MaxPool2D((2, 2))(x)
    return x, p  
# define the decoder block
def decoder_block_for_unet(input, skip_features, num_filters):
    x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input)
    x = Concatenate()([x, skip_features])
    x = conv_block(x, num_filters)
    return x

In [13]:
# define the assembled model
def build_unet(input_shape):
    inputs = Input(input_shape)

    s1,p1 = encoder_block(inputs,64)
    s2,p2 = encoder_block(p1,128)
    s3,p3 = encoder_block(p2,256)
    s4,p4 = encoder_block(p3,512)

    b1 = conv_block(p4,1024)

    d1 = decoder_block_for_unet(b1,s4,512)
    d2 = decoder_block_for_unet(d1,s3,256)
    d3 = decoder_block_for_unet(d2,s2,128)
    d4 = decoder_block_for_unet(d3,s1,64)

    outputs = Conv2D(1,1,padding="same",activation ="sigmoid")(d4) #sigmoid activation for binary classification since we have "wound" or "background"

    model = Model(inputs,outputs,name = "U_net")
    print(model.summary())
    return model

In [1]:
# compile the model
model=build_unet((12,12,1))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

In [None]:
# train test split
X_train, X_test, y_train, y_test = train_test_split(IMAGES, MASKS,test_size= 0.20, random_state=2022)

In [None]:
# fit the model and save the history
callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
hist = model.fit(X_train, y_train,
        epochs=100,batch_size = 3,
        validation_data=(X_test, y_test),
        callbacks =[callback])

### Checking results

In [None]:
# plot the test results
id = np.random.randint(0,len(X_test))
img = np.expand_dims(X_test[id,:,:,:],axis=0)
pred = model.predict(img).squeeze()
plt.subplot(1,2,1)
plt.imshow(img.squeeze())
plt.title('Original')
plt.subplot(1,2,2)
plt.imshow(pred)
plt.title('Prediction')
plt.show()

In [None]:
#plot the training and validation accuracy and loss at each epoch
loss = model.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(loss) + 1)
plt.plot(epochs, loss, 'y', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

acc = history.history['iou_score']
#acc = history.history['accuracy']
val_acc = history.history['val_iou_score']
#val_acc = history.history['val_accuracy']

In [None]:
# plot the Intersection over Union (IoU Scores)
val_acc = history.history['val_iou_score']

plt.plot(epochs, acc, 'y', label='Training IOU')
plt.plot(epochs, val_acc, 'r', label='Validation IOU')
plt.title('Training and validation IOU')
plt.xlabel('Epochs')
plt.ylabel('IOU')
plt.legend()
plt.show()

### Check the layer weights

In [None]:
# prep the model for weights visualization
my_model = model
outputs = [layer.output for layer in my_model.layers[1:]]
model_for_visualization = Model(inputs = my_model.input, outputs = outputs)

In [None]:
input_img = img
feature_maps = model_for_visualization.predict(input_img)

In [None]:
# visualize the weights for a specific layer
layer_num = 32 #Depth of layer...
square = 8
ix = 1
for _ in range(square):
	for _ in range(square):
		# specify subplot and turn of axis
		ax = plt.subplot(square, square, ix)
		ax.set_xticks([])
		ax.set_yticks([])
		# plot filter channel in grayscale
		plt.imshow(feature_maps[layer_num][0, :, :, ix-1])
		ix += 1
# show the figure
plt.show()