In [1]:
import numpy as np
from numpy import cos,sin,pi
from numpy.random import random,randint,choice,sample
from matplotlib import pyplot as plt
import pandas as pd
import glob,cv2,time
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import layers,Model

In [2]:
# Dense image warp from Tensorflow addons
# There is an error using tfa.image.dense_image_warp
# define same tfa.image.dense_image_warp

from tensorflow_addons.utils import types
from typing import Optional

def _get_dim(x, idx):
    if x.shape.ndims is None:
        return tf.shape(x)[idx]
    return x.shape[idx] or tf.shape(x)[idx]

def dense_image_warp(image: types.TensorLike, displacement: types.TensorLike, name: Optional[str] = None) -> tf.Tensor:
    with tf.name_scope(name or "dense_image_warp"):
        image = tf.convert_to_tensor(image)
        displacement = tf.convert_to_tensor(displacement)
        batch_size, height, width, channels = (
            _get_dim(image, 0),
            _get_dim(image, 1),
            _get_dim(image, 2),
            _get_dim(image, 3),
        )

        # The flow is defined on the image grid. Turn the flow into a list of query
        # points in the grid space.
        grid_x, grid_y = tf.meshgrid(tf.range(width), tf.range(height))
        stacked_grid = tf.cast(tf.stack([grid_y, grid_x], axis=2), displacement.dtype)
        batched_grid = tf.expand_dims(stacked_grid, axis=0)
        query_points_on_grid = batched_grid + displacement
        query_points_flattened = tf.reshape(query_points_on_grid, [batch_size, height * width, 2])
        # Compute values at the query points, then reshape the result back to the
        # image grid.
        interpolated = tfa.image.interpolate_bilinear(image, query_points_flattened)
        interpolated = tf.reshape(interpolated, [batch_size, height, width, channels])
        return interpolated

In [3]:
def img_grad(imgs):
    #central different
    img_x,img_y = tf.image.image_gradients(imgs)
    img_x,img_y = (img_x+tf.roll(img_x,1,1))/2, (img_y+tf.roll(img_y,1,2))/2
    return img_x,img_y
def count_fold_fn(u):
    u = tf.convert_to_tensor(u)
    u_x,u_y = img_grad(u)
    det_u = (u_x[:,:,:,0]+1)*(u_y[:,:,:,1]+1)-u_x[:,:,:,1]*u_y[:,:,:,0]
    cf = tf.math.count_nonzero(det_u<=0,axis=[1,2])
    return tf.reduce_mean(cf).numpy()
def grid_warp(u,n=48):
    img_shape = u.shape[:-1]
    img = np.zeros(img_shape)
    for i in range(1,n):
        img[img_shape[0]*i//n,:]=1
        img[:,img_shape[0]*i//n]=1
    ex_img = np.expand_dims(np.expand_dims(img,-1),0)
    ex_u = np.expand_dims(u,0)
    warp = tf.squeeze(dense_image_warp(ex_img,ex_u))
    plt.imshow(warp,cmap='gray')
def vector_field(u,n=6):
    J,I = np.meshgrid(np.arange(256),np.arange(256))
    plt.gca().invert_yaxis()
    plt.quiver(J[::n,::n],I[::n,::n],u[::n,::n,1],-u[::n,::n,0],units='xy')
def show_result(moved_img,fixed_img,Ex_name):
    stack = np.expand_dims(np.stack([moved_img,fixed_img],-1),0)
    start = time.time()
    u = model(stack)
    print(time.time()-start)
    warped_img = tf.squeeze(dense_image_warp(np.expand_dims(np.expand_dims(moved_img,-1),0),u))
    nof = count_fold_fn(u)
    nof_inv = count_fold_fn(-u)
    rel_ssd = np.sum((warped_img-fixed_img)**2)/np.sum((moved_img-fixed_img)**2)
    dcs = D_DCS(warped_img,fixed_img,axis=[0,1]).numpy()
    print('folding  u=',nof)
    print('folding -u=',nof_inv)
    print('relative ssd =',rel_ssd)
    print('Dice score =',dcs)
    plt.figure(figsize=(16,8))
    plt.subplot(241)
    plt.title(Ex_name+': Moved image')
    plt.imshow(moved_img,cmap='gray')
    plt.subplot(242)
    plt.title(Ex_name+': Fixed image')
    plt.imshow(fixed_img,cmap='gray')
    plt.subplot(243)
    plt.title(Ex_name+': Warped image')
    plt.imshow(warped_img,cmap='gray')
    plt.subplot(244)
    plt.title(Ex_name+': Different image')
    plt.imshow(abs(warped_img-fixed_img),cmap='gray')
    plt.subplot(245)
    plt.title(Ex_name+': Grid warped u')
    grid_warp(u[0])
    plt.subplot(246)
    plt.title(Ex_name+': Vector field u')
    vector_field(-u[0])
    plt.subplot(247)
    plt.title(Ex_name+': Grid warped -u')
    grid_warp(-u[0])
    plt.subplot(248)
    plt.title(Ex_name+': Vector field -u')
    vector_field(u[0])
    plt.show()
    return rel_ssd,nof,dcs

# Train/Test spliting

In [4]:
#load dataset
data_set = np.load('550m2f_Diffeomorphic_Deform.npy')
train_set = data_set[:500,:,:,:]
test_set = data_set[500:550,:,:,:]
train_size = train_set.shape[0]
test_size = test_set.shape[0]
img_shape = train_set.shape[1:3]

# Loss definition

In [5]:
def D_MSE(img1,img2):
    mse = tf.reduce_mean(tf.square(img1-img2),axis=[1,2,3])
    return tf.cast(tf.reduce_mean(mse),tf.float32)

def D_SSD(img1,img2):
    ssd = tf.reduce_sum(tf.square(img1-img2),axis=[1,2,3])
    return tf.cast(tf.reduce_mean(ssd),tf.float32)

def D_CC(img1,img2,n=9):
    gs_img1 = tfa.image.gaussian_filter2d(img1,filter_shape=(n,n))
    gs_img2 = tfa.image.gaussian_filter2d(img2,filter_shape=(n,n))
    dif_img1 = img1-gs_img1
    dif_img2 = img2-gs_img2
    cc = tf.reduce_sum(dif_img1*dif_img2,axis=[1,2,3])**2
    cc = cc/(tf.reduce_sum(dif_img1**2,axis=[1,2,3])*tf.reduce_sum(dif_img2**2,axis=[1,2,3]))
    return tf.cast(tf.reduce_mean(cc),tf.float32)

def D_NGF(img1,img2,e=1):
    img1_x,img1_y = img_grad(img1)
    img2_x,img2_y = img_grad(img2)
    inner = lambda a,b: a*b+e**2
    norm_s = lambda a,b: inner(a,a)+inner(b,b)
    ngf = 1 - ((inner(img1_x,img2_x)+inner(img1_y,img2_y))**2)/(norm_s(img1_x,img1_y)*norm_s(img2_x,img2_y))
    return tf.reduce_mean(tf.reduce_sum(ngf,axis=[1,2,3]))

def D_CLM(img1,img2):
    img1 = tf.convert_to_tensor(img1)
    img2 = tf.convert_to_tensor(img2)
    b,h,w,c = img1.get_shape()
    tile_fn = lambda b_c: tf.transpose(tf.reshape(tf.tile(b_c,(h,w)),[h,b,w,c]),[1,0,2,3])
    img1_mean = tile_fn(tf.math.reduce_mean(img1,axis=[1,2]))
    img2_mean = tile_fn(tf.math.reduce_mean(img2,axis=[1,2]))
    img1_std = tile_fn(tf.math.reduce_std(img1,axis=[1,2]))
    img2_std = tile_fn(tf.math.reduce_std(img2,axis=[1,2]))
    img12_std = tile_fn(tf.math.reduce_std(img1+img2,axis=[1,2]))
    clm = tf.reduce_sum(((img1-img1_mean)/img1_std-(img2-img2_mean)/img2_std)**2,axis=(1,2,3))+(img1_std+img2_std-img12_std)**2
    return tf.cast(tf.reduce_mean(clm),tf.float32)

def D_DCS(img1,img2,axis=[1,2,3]):
    dcs = 2*tf.reduce_sum(img1*img2,axis=axis)/(tf.reduce_sum(img1,axis=axis)+tf.reduce_sum(img2,axis=axis))
    return tf.reduce_mean(dcs)
    
def R_difsn(u):
    u_x,u_y = img_grad(u)
    difsn = tf.reduce_sum(u_x**2+u_y**2,axis=[1,2,3])
    return tf.reduce_mean(difsn)

def R_tv(u,e=5e-3):
    u_x,u_y = img_grad(u)
    tv = tf.reduce_sum((u_x**2+u_y**2+e)**0.5,axis=[1,2,3])
    return tf.reduce_mean(tv)

def R_curv(u):
    u_x,u_y = tf.image.image_gradients(u) #forward different
    u_xx = u_x-tf.roll(u_x,1,1)
    u_yy = u_y-tf.roll(u_y,1,2)
    curv = tf.reduce_sum(u_xx**2+u_yy**2,axis=[1,2,3])
    return tf.reduce_mean(curv)

def R_HH(u):
    u_x,u_y = img_grad(u)
    r = (u_x[:,:,:,0]-u_y[:,:,:,1])**2+(u_x[:,:,:,1]+u_y[:,:,:,0])**2
    R1 = tf.reduce_sum(r,axis=[1,2])
    return tf.reduce_mean(R1)

def R_SA(u):
    u = tf.convert_to_tensor(u)
    u_x,u_y = img_grad(u)
    u1_x1 = u_x[:,:,:,0]
    u1_x2 = u_x[:,:,:,1]
    u2_x1 = u_y[:,:,:,0]
    u2_x2 = u_y[:,:,:,1]
    s = u1_x2*u2_x1-u1_x1*u2_x2
    a = abs(u1_x1+u2_x2)
    Rsa = tf.reduce_sum(tf.minimum(0,s+a)**2,axis=[1,2])    
    return tf.reduce_mean(Rsa)

def R_fold(u):
    u = tf.convert_to_tensor(u)
    u_x,u_y = img_grad(u)
    det_phi = (u_x[:,:,:,0]+1)*(u_y[:,:,:,1]+1)-u_x[:,:,:,1]*u_y[:,:,:,0]
    det_phi_p = tf.maximum(det_phi,1)
    fold = tf.reduce_sum((det_phi-1)**2/det_phi_p**2,axis=(1,2))
    return tf.reduce_mean(fold)

In [6]:
tf.keras.backend.clear_session()
inputs = layers.Input(shape=train_set[0].shape)
Conv2D = lambda l: layers.Conv2D(l,3,activation='relu',padding='same')
Maxpool2D = layers.MaxPool2D()
UpSampling2D = layers.UpSampling2D()
Concatenate = layers.Concatenate()
n = 3
lys = 32

x = inputs
skip = []
for i in range(n): 
    x = Conv2D(lys*2**i)(x)
    skip.append(x)
    x = Maxpool2D(x)
for j in range(n):
    x = Conv2D(lys*2**(n-j-1))(x)
    x = UpSampling2D(x)
    x = Concatenate([x,skip.pop()])
    
x = Conv2D(2*lys)(x)
x = tf.keras.layers.Conv2D(2,1,padding='same')(x)
model = Model(inputs=inputs, outputs=x,name='Unet')
model.summary()
weight_save = model.get_weights()

Model: "Unet"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 256, 256, 2) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 256, 256, 32) 608         input_1[0][0]                    
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    multiple             0           conv2d[0][0]                     
                                                                 conv2d_1[0][0]                   
                                                                 conv2d_2[0][0]                   
_______________________________________________________________________________________________

In [7]:
name = 'm1curv3'

In [8]:
def loss_fn(model,imgs):
    u = model(imgs)
    ref_img = imgs[:,:,:,1:]
    warped_img = dense_image_warp(imgs[:,:,:,:1],u)
    D = D_SSD(ref_img,warped_img)
    R = 3*R_curv(u)#+0*R_HH(u)+0*R_SA(u)
    #print('No.folding :',count_fold_fn(u))
    loss = D+R
    return loss

In [None]:
print(tf.config.experimental.get_memory_info('GPU:0')['current']/1e6)
model.set_weights(weight_save)
epochs = 100
batch_size = 10
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

loss_log = []
metric_log = []
start_train = time.time()
for epoch in range(epochs):
    if epoch<3:
        mem = tf.config.experimental.get_memory_info('GPU:0')
        cr,pk = mem['current'],mem['peak']
        print('memory current : {}, peak : {}'.format(cr/1e6,pk/1e6)) 
    loss_epoch = []
    test_epoch = []
    #train
    for i in range(0,train_size,batch_size):
        train_batch = train_set[i:i+batch_size]
        with tf.GradientTape() as tape:
            loss_batch = loss_fn(model,train_batch)
            grads = tape.gradient(loss_batch, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        loss_epoch.append(loss_batch.numpy())
    #test
    for j in range(0,test_size,batch_size):
        loss_test = loss_fn(model,test_set[j:j+batch_size])
        test_epoch.append(loss_batch.numpy())
    loss_log = np.append(loss_log,[np.mean(loss_epoch),np.mean(test_epoch)])
    
    dgt = 5 #display 'dgt' digits
    print('epoch {}, loss_train {}, loss_test {}'.format(epoch+1,np.round(loss_log[-2],dgt),np.round(loss_log[-1],dgt)))
print(time.time()-start_train)

1.854464
memory current : 2.217728, peak : 2.807552
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
epoch 1, loss_train 999.18518, loss_test 1004.42139
memory current : 13.013504, peak : 1819.036672
epoch 2, loss_train 875.30127, loss_test 832.25
memory current : 13.014016, peak : 1819.036672
epoch 3, loss_train 599.47064, loss_test 553.52557
epoch 4, loss_train 407.35437, loss_test 374.28613
epoch 5, loss_train 318.9646, loss_test 314.6228
epoch 6, loss_train 281.32462, loss_test 295.92435
epoch 7, loss_train 262.57382, loss_test 279.03217


In [None]:
loss_log = loss_log.reshape(loss_log.size//2,2)
plt.plot(loss_log[:,0],label='train')
plt.plot(loss_log[:,1],label='test')
plt.title(name)
plt.ylim(0,max(loss_log[0,:]))
plt.legend() 
plt.show()

In [None]:
ind_list = choice(test_set.shape[0],3,False)
for ind in ind_list:
    test = np.expand_dims(test_set[ind],0)
    moved_img = test[0,:,:,0]
    fixed_img = test[0,:,:,1]
    show_result(moved_img,fixed_img,name)

In [None]:
circle = cv2.resize(cv2.imread('./Images/circle.png',0),img_shape)/255
ellipse = cv2.resize(cv2.imread('./Images/ellipse.png',0),img_shape)/255
square = cv2.resize(cv2.imread('./Images/square.png',0),img_shape)/255
c_shape = cv2.resize(cv2.imread('./Images/c_shape.png',0),img_shape)/255
o_ = cv2.resize(cv2.imread('./Images/o_.png',0),img_shape)/255
R1 = cv2.resize(cv2.imread('./Images/R1.png',0),img_shape)/255
T1 = cv2.resize(cv2.imread('./Images/T1.png',0),img_shape)/255
t2_R = cv2.resize(cv2.imread('./Images/t2_R.png',0),img_shape)/255
t2_T = cv2.resize(cv2.imread('./Images/t2_T.png',0),img_shape)/255
R2 = cv2.resize(cv2.imread('./Images/R2.png',0),img_shape)/255
T2 = cv2.resize(cv2.imread('./Images/T2.png',0),img_shape)/255
A_slope = cv2.resize(cv2.imread('./Images/A_slope.png',0),img_shape)/255
R = cv2.resize(cv2.imread('./Images/R.png',0),img_shape)/255

In [None]:
display_list = [
    [ellipse,circle],
    [circle,ellipse],
    [circle,square],
    [square,circle],
    [t2_R,t2_T],
    [t2_T,t2_R],
    [o_,c_shape],
    [c_shape,o_],
    [A_slope,R],
    [R,A_slope],
    [R1,T1],
    [T1,R1],
    [R2,T2],
    [T2,R2]
    ]

In [None]:
result_list = []
for i in display_list:
    result_list.append(show_result(i[0],i[1],name))

In [None]:
np.set_printoptions(suppress=True)
result = np.array(result_list).transpose()
print(result)

In [None]:
sum_fold_dir = 0
sum_fold_inv = 0
rel_ssd_list = []
dcs_list = []
for j in range(0,test_size,batch_size):
    test_batch = test_set[j:j+batch_size,:,:,:]
    u = model(test_batch)
    moved_imgs = test_batch[:,:,:,:1]
    warped_imgs = dense_image_warp(moved_imgs,u)
    fixed_imgs = test_batch[:,:,:,1:]
    rel_ssd_list.append(tf.reduce_sum((warped_imgs-fixed_imgs)**2,axis=[1,2])/tf.reduce_sum((moved_imgs-fixed_imgs)**2,axis=[1,2]))
    
    dcs_list.append(D_DCS(warped_imgs,fixed_imgs,axis=[1,2]))
    sum_fold_dir+=count_fold_fn(u)
    sum_fold_inv+=count_fold_fn(-u)
print('rel_ssd :',np.mean(rel_ssd_list))
print('No.folding  u per img:',sum_fold_dir/test_size)
print('No.folding -u per img:',sum_fold_inv/test_size)
print('DCS :',np.mean(dcs_list))

In [None]:
np.save('./loss log/'+name+'.npy',loss_log)
model.save_weights('./weight models/'+name)