In [1]:
import os
import pdb
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from models import UNet
from utils import increase_batch, mini_batch
from utils import read_EM
    
# Location of the EM dataset on disk; read_EM returns training images,
# training targets and test images (no test targets are loaded).
DATASET_DIR = "./Dataset/"
x_train, t_train, x_test = read_EM(DATASET_DIR)

# Training the U-Net

In [2]:
# Start from an empty default graph so that re-running this cell does not
# stack a second U-Net (and duplicate variables) on top of the first one.
tf.reset_default_graph()

# Seed graph-level TF randomness (weight init) and NumPy randomness so that
# runs are reproducible. Must happen before the model's ops are created.
SEED = 42
tf.set_random_seed(SEED)
np.random.seed(SEED)

# Placeholder shapes are [batch, height, width, 1]; height/width are taken
# from the loaded arrays, and the trailing 1 matches the channel axis that the
# training loop adds with np.expand_dims(..., axis=3).
unet = UNet(
    LR=1e-4,
    input_shape=[None, x_train.shape[1], x_train.shape[2], 1],
    output_shape=[None, t_train.shape[1], t_train.shape[2], 1],
)

# Loss
def loss(y, t):
    """Mean sigmoid cross-entropy between network logits and targets.

    Args:
        y: logits tensor produced by the network.
        t: target tensor with the same shape as ``y``
           (presumably binary segmentation masks — confirm against read_EM).

    Returns:
        Scalar tensor: the element-wise cross-entropy averaged over all elements.
    """
    per_element = tf.nn.sigmoid_cross_entropy_with_logits(labels=t, logits=y)
    return tf.reduce_mean(per_element)
# Attach the loss function to the model. Presumably this defines unet.loss and
# unet.training, which the training loop runs (confirm in models.UNet).
unet.optimize(loss)

# Open a session on the default graph and initialize all global variables.
init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)

In [3]:
# Train for a fixed number of epochs, printing the mean batch loss and the
# wall-clock time for each epoch.
epoch = 30
# NOTE: despite its name, `batch_size` is passed to mini_batch as
# `batch_generator`, so it is not an int. It is created once, outside the
# epoch loop, so any state it keeps carries over between epochs. Presumably
# it grows the batch size from `start` toward `bound` — confirm in
# utils.increase_batch.
batch_size = increase_batch(start=1, bound=3, rate=1e-3)
for ep in range(epoch):
    count = 0
    total_loss = 0.0
    start = time.time()
    for x, t in mini_batch(x_train, t_train, batch_generator=batch_size):
        # The network expects an explicit trailing channel axis.
        feed_dict = {
            unet.x: np.expand_dims(x, axis=3),
            unet.t: np.expand_dims(t, axis=3),
        }
        # The name `batch_loss` (rather than `loss`) avoids clobbering the
        # loss function defined in the model-building cell.
        batch_loss, _ = sess.run([unet.loss, unet.training], feed_dict=feed_dict)
        total_loss += batch_loss
        count += 1
    end = time.time()
    # Guard against an epoch that produced no batches (avoids ZeroDivisionError).
    mean_loss = total_loss / count if count else float("nan")
    print("Epoch: {:<4} Loss: {:<10.9f} Time: {:<10.2f} ".format(ep, mean_loss, end-start))

Epoch: 0    Loss: 0.486629517 Time: 28.78      
Epoch: 1    Loss: 0.351954122 Time: 19.67      
Epoch: 2    Loss: 0.311473751 Time: 19.77      
Epoch: 3    Loss: 0.300134595 Time: 19.86      
Epoch: 4    Loss: 0.294581190 Time: 19.85      
Epoch: 5    Loss: 0.290307743 Time: 19.88      
Epoch: 6    Loss: 0.304993431 Time: 19.96      
Epoch: 7    Loss: 0.281674750 Time: 19.98      
Epoch: 8    Loss: 0.281825716 Time: 19.98      
Epoch: 9    Loss: 0.271816172 Time: 19.97      
Epoch: 10   Loss: 0.260757534 Time: 19.99      
Epoch: 11   Loss: 0.253999563 Time: 19.96      
Epoch: 12   Loss: 0.251973872 Time: 19.98      
Epoch: 13   Loss: 0.239099472 Time: 19.99      
Epoch: 14   Loss: 0.240805492 Time: 19.99      
Epoch: 15   Loss: 0.235163855 Time: 19.99      
Epoch: 16   Loss: 0.230567812 Time: 19.99      
Epoch: 17   Loss: 0.232469178 Time: 20.01      
Epoch: 18   Loss: 0.226891549 Time: 19.98      
Epoch: 19   Loss: 0.224884062 Time: 20.01      
Epoch: 20   Loss: 0.220938953 Time: 20.0

KeyboardInterrupt: 

In [None]:
# Persist all graph variables to a checkpoint. tf.train.Saver.save refuses to
# write when the parent directory is missing, so create it first.
checkpoint_path = "./Models/unet_small_without_1by1.ckpt"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

saver = tf.train.Saver()
save_path = saver.save(sess, checkpoint_path)
print("Model saved in path: " + save_path)

# Test: visual check of predicted segmentations

In [None]:
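# Optional: uncomment to restore previously trained weights instead of training.
# NOTE(review): this path differs from the checkpoint saved above
# ("unet_small_without_1by1.ckpt"); confirm which checkpoint matches this graph.
#saver = tf.train.Saver()
#saver.restore(sess, './Models/unet_small.ckpt')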

In [None]:
feed_dict = {
    unet.x: np.expand_dims(x_test[0:1], axis=3),
    #ae.t: np.expand_dims(t_test[0:1], axis=3),
}

y_test = sess.run(unet.y, feed_dict=feed_dict)
y_test = np.squeeze(y_test)
y_test = sess.run(tf.sigmoid(y_test))

%matplotlib inline
f, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].set_title("original", size=25)
ax[0].imshow(x_test[0], cmap="gray")
ax[1].set_title("segmented image", size=25)
ax[1].imshow(y_test, cmap="gray")
plt.show()

In [None]:
feed_dict = {
    unet.x: np.expand_dims(x_train[0:1], axis=3),
    #ae.t: np.expand_dims(t_test[0:1], axis=3),
}

y_test = sess.run(unet.y, feed_dict=feed_dict)
y_test = np.squeeze(y_test)
y_test = sess.run(tf.sigmoid(y_test))

%matplotlib inline
f, ax = plt.subplots(1, 3, figsize=(20, 10))
ax[0].set_title("original", size=25)
ax[0].imshow(x_train[0], cmap="gray")
ax[1].set_title("segmented image", size=25)
ax[1].imshow(y_test, cmap="gray")
ax[2].set_title('ground truth', size=25)
ax[2].imshow(t_train[0], cmap="gray")
plt.show()