In [None]:
import argparse
import glob
import math
import os
from timeit import default_timer as timer

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.python.keras as kr
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
from sklearn.preprocessing import MinMaxScaler
from tensorflow.python.keras import backend as K
# Hide every CUDA device so TensorFlow runs on the CPU. CUDA_VISIBLE_DEVICES
# takes device indices; "-1" is the conventional "no GPU" value. ('/cpu:0' is a
# TensorFlow device string, not a CUDA index, and only hid the GPUs because
# CUDA could not parse it.)
# NOTE(review): this must take effect before TensorFlow initialises its GPU
# devices (i.e. before the first TF op runs) — confirm nothing runs TF earlier.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

# Spatial size every loaded image is resized to.
image_height = 128
image_width = 128
# Threshold used by get_mask1 / get_mask2.
T = 0.5

def train_id_to_path(x):
    """Return the TIFF path for sample id ``x`` inside ``SAXSdata/IPP/``.

    ``x`` is converted with ``str()``, so ints and strings are both accepted.
    """
    return f"SAXSdata/IPP/{x!s}.tif"
#def test_id_to_path(x):
    #return 'SAXSdata/' + x + ".tif"
    
def custom_rgb_to_grayscale(image, weights):
    """Collapse the last (channel) axis of ``image`` after scaling it by ``weights``.

    The image is cast to float32, multiplied by ``weights`` (broadcast against
    the channel axis) and averaged over the last axis.
    NOTE(review): this is a *mean*, so for weights summing to 1 the result is
    1/C of the usual weighted-sum luminance — confirm this is intended.
    """
    weighted = tf.multiply(tf.cast(image, tf.float32), weights)
    return tf.reduce_mean(weighted, axis=-1)

def path_to_eagertensor(image_path):
    """Load an image as a resized grayscale float tensor scaled to [0, 1].

    Args:
        image_path: path of an image file readable by OpenCV.

    Returns:
        float32 tensor of shape (image_height, image_width, 1).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError('Could not read image: ' + str(image_path))
    # OpenCV loads channels in BGR order; tf.image.rgb_to_grayscale expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = tf.image.rgb_to_grayscale(image)
    image = tf.cast(image, tf.float32) / 255
    image = tf.image.resize(image, (image_height, image_width))
    return image

def path_to_eagertensor2(image_path):
    """Load an image, resize it and min-max normalise its grayscale version.

    Args:
        image_path: path of an image file readable by OpenCV.

    Returns:
        numpy float32 array of shape (image_height, image_width, 1), scaled so
        its minimum is 0 and maximum is 1. A constant image (max == min) gives
        an all-zero array instead of NaNs.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError('Could not read image: ' + str(image_path))
    # OpenCV loads channels in BGR order; tf.image.rgb_to_grayscale expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = tf.image.resize(image, (image_height, image_width))
    # np.array(...) on a tf tensor requires eager execution.
    gray = np.array(tf.image.rgb_to_grayscale(image))
    lo, hi = gray.min(), gray.max()
    if hi == lo:
        return np.zeros_like(gray)
    return (gray - lo) * (1 / (hi - lo))

def int_to_float(image_path):
    """Return the min-max normalised log-magnitude Fourier spectrum of an image.

    Uses OpenCV and NumPy (both imported at the top of the notebook). The
    scikit-image names this function used to rely on (``io``, ``img_as_float``,
    ``transform``) are not imported, so it could not run.

    Args:
        image_path: path of an image file readable by OpenCV.

    Returns:
        float64 array of shape (image_height, image_width[, channels]) with
        values in [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        raise FileNotFoundError('Could not read image: ' + str(image_path))
    # Integer image -> floats scaled by the dtype's maximum (e.g. uint8 -> [0, 1]).
    if np.issubdtype(image.dtype, np.integer):
        image = image.astype(np.float64) / np.iinfo(image.dtype).max
    else:
        image = image.astype(np.float64)
    # FFT and centre-shift over the two spatial axes only, so colour images are
    # transformed per channel instead of across the width/channel axes.
    fshift = np.fft.fftshift(np.fft.fft2(image, axes=(0, 1)), axes=(0, 1))
    magnitude = np.abs(fshift)
    # Replace exact zeros with the smallest non-zero magnitude: log(0) = -inf
    # would otherwise turn every pixel into NaN after min-max scaling.
    nonzero = magnitude[magnitude > 0]
    if nonzero.size:
        magnitude = np.maximum(magnitude, nonzero.min())
    else:
        magnitude = np.ones_like(magnitude)
    magnitude_spectrum = 20 * np.log(magnitude)
    lo, hi = magnitude_spectrum.min(), magnitude_spectrum.max()
    if hi > lo:
        images = (magnitude_spectrum - lo) * (1 / (hi - lo))  # min-max normalisation
    else:
        images = np.zeros_like(magnitude_spectrum)
    # cv2.resize takes (width, height); INTER_AREA is suited to down-sampling.
    return cv2.resize(images, (image_width, image_height), interpolation=cv2.INTER_AREA)
    
def plot_predictions(y_true, y_end, y_pred1, y_pred2):
    """Show the input, its 2x2 encoding and two predictions on a 2x2 grid.

    Layout: top-left ``y_true`` (128x128), top-right ``y_end`` (2x2),
    bottom-left ``y_pred1`` and bottom-right ``y_pred2`` (128x128, grayscale).
    """
    _, grid = plt.subplots(2, 2, figsize=(10, 10))
    full_size = (128, 128)
    grid[0][0].imshow(np.reshape(y_true, full_size), aspect='auto')
    grid[1][0].imshow(np.reshape(y_pred1, full_size), aspect='auto', cmap='gray')
    grid[0][1].imshow(np.reshape(y_end, (2, 2)), aspect='auto')
    grid[1][1].imshow(np.reshape(y_pred2, full_size), aspect='auto', cmap='gray')
    plt.tight_layout()
    
def plot_predictions2(y_enc):
    """Render one 128x128 map in grayscale with a colorbar and save it.

    The figure is written to 'SAXSdata/IPP/1.jpg'.
    """
    plt.subplots(1, 1, figsize=(7, 6))
    picture = np.reshape(y_enc, (128, 128))
    plt.imshow(picture, aspect='auto', cmap='gray')
    for set_ticks in (plt.xticks, plt.yticks):
        set_ticks(fontsize=20)
    colour_bar = plt.colorbar()
    colour_bar.ax.tick_params(labelsize=20)
    colour_bar.ax.set_xlabel('gray', size=20, fontproperties="Arial")
    plt.tight_layout()
    plt.savefig('SAXSdata/IPP/1.jpg')
    
def ff_propagation(image, eps=1e-7):
    """Far-field intensity: per-sample min-max scaled log |FFT| of ``image``.

    Used as a Keras Lambda on the decoder output. ``tf.signal.fft3d`` transforms
    the innermost three axes; for a (batch, H, W, 1) input the transform over
    the size-1 channel axis is the identity, so this is a 2-D FFT over H and W.

    Args:
        image: real-valued tensor; assumed (batch, H, W, C) — TODO confirm for
            callers other than the VAED2 decoder.
        eps: added to the magnitude before the log so zero-magnitude bins do not
            produce -inf (which would make the whole scaled map NaN).

    Returns:
        float32 tensor with the same shape as ``image``, scaled to [0, 1]
        independently over its innermost three axes. A constant map yields 0.
    """
    transformed_axes = (-3, -2, -1)
    f = tf.signal.fft3d(tf.cast(image, tf.complex64))
    # Shift only the transformed axes; the default (axes=None) also rolled the
    # batch axis, mixing samples whenever the batch size is > 1.
    fshift = tf.signal.fftshift(f, axes=transformed_axes)
    intensity = tf.cast(tf.math.log(tf.math.abs(fshift) + eps), tf.float32)
    # Per-sample min-max scaling (identical to the old global scaling for batch size 1).
    lo = tf.reduce_min(intensity, axis=transformed_axes, keepdims=True)
    hi = tf.reduce_max(intensity, axis=transformed_axes, keepdims=True)
    return tf.math.divide_no_nan(intensity - lo, hi - lo)
    
def combine_complex(amp, phi):
    """Build the complex64 field ``amp * exp(1j * phi)`` from amplitude and phase."""
    amplitude = tf.cast(amp, tf.complex64)
    phase_factor = tf.exp(1j * tf.cast(phi, tf.complex64))
    return amplitude * phase_factor

def get_mask1(input):
    """Binary mask with 1 where ``input >= T`` and 0 elsewhere (dtype of ``input``)."""
    at_or_above = input >= T
    ones, zeros = tf.ones_like(input), tf.zeros_like(input)
    return tf.where(at_or_above, ones, zeros)
    
def get_mask2(input):
    """Binary mask with 1 where ``input < T`` and 0 elsewhere (dtype of ``input``)."""
    below = input < T
    ones, zeros = tf.ones_like(input), tf.zeros_like(input)
    return tf.where(below, ones, zeros)

def getdata():
    """Load every image whose id is listed in ``SAXSdata/IPP/data.dat``.

    ``data.dat`` is read with pandas as a single space-separated, header-less
    ``id`` column; each id is mapped with ``train_id_to_path`` and loaded with
    ``path_to_eagertensor2``. Types and shapes are printed for inspection.

    Returns:
        (x_train, y_train): stacked image array and a one-column DataFrame of ids.
    """
    frame = pd.read_csv("SAXSdata/IPP/data.dat", sep=" ", header=None, names=["id"])
    frame["img_path"] = frame["id"].apply(train_id_to_path)
    ids = frame[["id"]]
    images = np.array([path_to_eagertensor2(path) for path in frame['img_path']])
    x_train, y_train = images, ids
    for item in (images, ids, x_train, y_train):
        print(type(item), item.shape)
    return x_train, y_train

In [None]:
class TimingCallback(kr.callbacks.Callback):
    """Keras callback recording each epoch's duration in ``self.logs``.

    Durations are differences of the module-level ``timer()`` taken at epoch
    begin and end; ``self.logs`` is reset at the start of every ``fit()``.
    """

    def __init__(self):
        # Must be ``__init__``: the previous ``_init_`` was never called by Python.
        super().__init__()
        self.logs = []

    def on_train_begin(self, logs=None):
        self.logs = []

    def on_epoch_begin(self, epoch, logs=None):
        self.starttime = timer()

    def on_epoch_end(self, epoch, logs=None):
        self.logs.append(timer() - self.starttime)
        
class checkpointCallback(kr.callbacks.Callback):
    """Keras callback that periodically saves a snapshot of the decoder output.

    Every ``save_freq`` epochs the decoder is run on ``sample``. The 128x128
    result is written as a grayscale JPEG named ``1-<cumulative epoch time>.jpg``
    inside ``out_dir``. Per-epoch durations (differences of the module-level
    ``timer()``) are accumulated in ``self.logs``.

    Args:
        vae: object with a ``decoder()`` method returning a Keras model (e.g. a
            ``VAED2``). If ``None``, the notebook-global ``end`` is used. It is
            looked up at save time, so it follows the most recently built model.
        sample: input fed to the decoder. If ``None``, the notebook-global
            ``y_true`` is used (also looked up at save time).
        save_freq: number of epochs between snapshots.
        out_dir: directory for the snapshots; created if it does not exist.
    """

    def __init__(self, vae=None, sample=None, save_freq=50, out_dir='SAXSdata/IPP/1'):
        # Must be ``__init__``: the previous ``_init_`` was never called by Python.
        super().__init__()
        self.vae = vae
        self.sample = sample
        self.save_freq = save_freq
        self.out_dir = out_dir
        self.logs = []
        self._epoches_seen_since_last_saving = 0
        self._last_epoch_seen = 0

    def on_train_begin(self, logs=None):
        # Reset per-run state so one instance can be reused across fit() calls.
        self.logs = []
        self._epoches_seen_since_last_saving = 0
        self._last_epoch_seen = 0

    def on_epoch_begin(self, epoch, logs=None):
        self.starttime = timer()

    def _should_save_on_batch(self, epoch):
        """Return True once ``save_freq`` epochs have elapsed since the last save.

        Adapted from Keras' batch-level checkpoint counter. When ``epoch`` does
        not move past the last epoch seen (the first call at epoch 0, or a new
        run), ``epoch + 1`` epochs are counted; otherwise the difference is counted.
        """
        if epoch <= self._last_epoch_seen:
            add_epochs = epoch + 1  # epochs are zero-indexed
        else:
            add_epochs = epoch - self._last_epoch_seen
        self._epoches_seen_since_last_saving += add_epochs
        self._last_epoch_seen = epoch

        if self._epoches_seen_since_last_saving >= self.save_freq:
            self._epoches_seen_since_last_saving = 0
            return True
        return False

    def on_epoch_end(self, epoch, logs=None):
        self.logs.append(timer() - self.starttime)
        if self._should_save_on_batch(epoch):
            vae = end if self.vae is None else self.vae
            sample = y_true if self.sample is None else self.sample
            y_pred2 = vae.decoder().predict(sample)
            os.makedirs(self.out_dir, exist_ok=True)
            file_name = '1-' + str(sum(self.logs)) + '.jpg'
            plt.imsave(os.path.join(self.out_dir, file_name),
                       np.reshape(y_pred2, (128, 128)), cmap='gray')


cb = checkpointCallback()

In [None]:
class VAED2():
    """Convolutional encoder/decoder followed by a far-field (FFT) layer.

    Despite the name, no sampling layer is wired in: ``z_mean`` is a
    deterministic 2x2x1 feature map. 128x128 is reduced to 2x2 by three
    (stride-2 conv + 2x2 max-pool) stages and a stride-1 conv. The decoder
    mirrors this back to a 128x128x1 sigmoid image (``dec_out``). ``outputs``
    is ``ff_propagation`` applied to that image.
    """

    def __init__(self):

        self.input_dim = (128,128,1)
        self.input_latent = 128*128
        self.latent_dim = 64
        
        self.inputs = kr.Input(shape=self.input_dim)        
        # Encoder. relu(..., alpha=0.05) is a leaky ReLU.
        x = kr.layers.Conv2D(64, (3, 3), strides=2, padding='same')(self.inputs)
        x = kr.activations.relu(x,alpha=0.05)
        x = kr.layers.BatchNormalization()(x)
        x = kr.layers.MaxPooling2D((2, 2))(x)
        x = kr.layers.Conv2D(32, (3, 3), strides=2, padding='same')(x)
        x = kr.activations.relu(x,alpha=0.05)
        x = kr.layers.BatchNormalization()(x)
        x = kr.layers.MaxPooling2D((2, 2))(x)
        x = kr.layers.Conv2D(16, (3, 3), strides=2, padding='same')(x)
        x = kr.activations.relu(x,alpha=0.05)
        x = kr.layers.BatchNormalization()(x)
        x = kr.layers.MaxPooling2D((2, 2))(x)
        x = kr.layers.Conv2D(1, (3, 3), strides=1, padding='same')(x)
        x = kr.activations.relu(x,alpha=0.05)
        x = kr.layers.BatchNormalization()(x)

        self.z_mean = x
        
        # Decoder: (stride-2 transposed conv + 2x upsampling) x3 back to 128x128.
        x = kr.layers.Conv2DTranspose(16, (3, 3), strides=2, padding='same')(self.z_mean)
        x = kr.activations.relu(x,alpha=0.05)
        x = kr.layers.BatchNormalization()(x)
        x = kr.layers.UpSampling2D((2,2))(x)
        x = kr.layers.Conv2DTranspose(16, (3, 3), strides=2, padding='same')(x)
        x = kr.activations.relu(x,alpha=0.05)
        x = kr.layers.BatchNormalization()(x)
        x = kr.layers.UpSampling2D((2,2))(x)
        x = kr.layers.Conv2DTranspose(32, (3, 3), strides=2, padding='same')(x)
        x = kr.activations.relu(x,alpha=0.05)
        x = kr.layers.BatchNormalization()(x)
        x = kr.layers.UpSampling2D((2,2))(x)
        x = kr.layers.Conv2DTranspose(64, (3, 3), strides=1, padding='same')(x)
        x = kr.activations.relu(x,alpha=0.05)
        x = kr.layers.BatchNormalization()(x)
        self.dec_out = kr.layers.Conv2DTranspose(1, (3, 3), activation='sigmoid', padding='same', name='decoder_output')(x)
        
        # Forward model: far-field propagation of the decoded image gives the
        # predicted diffraction intensity.
        self.Psi = kr.layers.Lambda(lambda x: ff_propagation(x), name='farfield_diff')(self.dec_out)
        
        self.outputs = self.Psi 
        
    def sampling(self, args):
        """Reparameterisation trick: z_mean + exp(z_log_var / 2) * N(0, 1).

        Not used by the current architecture (no ``z_log_var`` is built).
        """
        z_mean, z_log_var = args
        nd = K.shape(z_mean)[0]
        nc = self.latent_dim
        eps = K.random_normal(shape=(nd, nc), mean=0., stddev=1.0)
        return z_mean + K.exp(z_log_var / 2) * eps

    def vae(self):
        """Full model: input -> far-field intensity."""
        return kr.Model(self.inputs, self.outputs)
    
    def encoder(self):
        """Sub-model: input -> 2x2x1 encoding."""
        return kr.Model(self.inputs, self.z_mean)
    
    def decoder(self): 
        """Sub-model: input -> decoded real-space image (before propagation)."""
        return kr.Model(self.inputs, self.dec_out)

    @staticmethod
    def _log_spectrum(x, eps=1e-7):
        """log |centred 2-D FFT| of an NHWC tensor over its H and W axes.

        ``fft3d`` transforms the innermost three axes (H, W, C); with the single
        channel used here the C transform is the identity. ``fft2d`` would
        instead transform W and C. Only those axes are shifted, never the batch
        axis. ``eps`` keeps log() finite for zero-magnitude bins.
        """
        f = tf.signal.fft3d(tf.cast(x, tf.complex64))
        fshift = tf.signal.fftshift(f, axes=(-3, -2, -1))
        return tf.cast(tf.math.log(tf.math.abs(fshift) + eps), tf.float32)

    def _npcc(self):
        """Sum over the batch of (1 - Pearson correlation) between outputs and inputs.

        The correlation is computed per sample and channel over the H, W axes.
        ``K.epsilon()`` inside the sqrt keeps the denominator (and its gradient)
        finite for constant images.
        """
        pred = self.outputs - tf.reduce_mean(self.outputs,axis=(1,2),keepdims=True)
        true = self.inputs - tf.reduce_mean(self.inputs,axis=(1,2),keepdims=True)
        top = tf.reduce_sum(pred * true,axis=(1,2),keepdims=True)
        pred_sum = tf.reduce_sum(pred*pred,axis=(1,2),keepdims=True)
        true_sum = tf.reduce_sum(true*true,axis=(1,2),keepdims=True)
        bottom = tf.math.sqrt(pred_sum * true_sum + K.epsilon())
        return tf.reduce_sum(1 - top / bottom)

    def loss(self):
        """Binary cross-entropy (scaled by pixel count), plus a KL term if available.

        The KL term needs ``self.z_log_var``, which the current architecture does
        not build. Without this check, ``self.z_log_var`` raised AttributeError.
        """
        mse = kr.metrics.binary_crossentropy(K.flatten(self.inputs), K.flatten(self.outputs))
        xent_loss = self.input_latent * mse
        z_log_var = getattr(self, 'z_log_var', None)
        if z_log_var is None:
            return K.mean(xent_loss)
        kl = 1 + z_log_var - K.square(self.z_mean) - K.exp(z_log_var)
        kl_loss = - 0.5 * K.sum(kl, axis=-1)
        vae_loss = K.mean(xent_loss + kl_loss)
        return vae_loss
    
    def loss2(self):
        """Pixel-count-scaled MAE between inputs and the log spectrum of outputs."""
        spectrum = self._log_spectrum(self.outputs)
        # Match the input's spatial size so both tensors flatten to the same
        # length (a fixed 64x64 resize mismatched the 128x128 inputs).
        spectrum = tf.image.resize(spectrum, self.input_dim[:2])
        mse = kr.metrics.mean_absolute_error(K.flatten(self.inputs), K.flatten(spectrum))
        xent_loss = self.input_latent * mse
        return xent_loss
    
    def loss3(self):
        """MAE between inputs and the log spectrum of outputs."""
        mae = kr.metrics.mean_absolute_error(K.flatten(self.inputs), K.flatten(self._log_spectrum(self.outputs)))
        return mae
    
    def loss4(self):
        """MSE between inputs and the log spectrum of outputs."""
        mse = kr.metrics.mean_squared_error(K.flatten(self.inputs), K.flatten(self._log_spectrum(self.outputs)))
        return mse
    
    def loss5(self):
        """MAE between inputs and outputs."""
        mae = kr.metrics.mean_absolute_error(K.flatten(self.inputs), K.flatten(self.outputs))
        return mae
    
    def loss_pcc(self):
        """Negative-Pearson-correlation loss: sum over batch of (1 - r)."""
        return self._npcc()
    
    def loss_comb(self):  
        """Weighted average of the PCC loss and the MAE between inputs and outputs."""
        loss_1 = self._npcc()
        loss_2 = kr.metrics.mean_absolute_error(K.flatten(self.inputs), K.flatten(self.outputs))
        a1 = 1
        a2 = 1
        loss_value = (a1*loss_1+a2*loss_2)/(a1+a2)
        return loss_value

In [None]:
# Load every image listed in data.dat. getdata -> path_to_eagertensor2 calls
# np.array() on a TF tensor, which presumably needs eager execution, so this
# runs before the graph-mode training cell.
(x_train, y_train) = getdata()

In [None]:
# Keep the batch axis: a single-sample batch taken from the front of x_train.
y_true = x_train[:1]
def show_train_image(y_true):
    """Display one 128x128 training image in grayscale with a colorbar clipped to [0, 1]."""
    plt.subplots(1, 1, figsize=(7, 6))
    plt.imshow(np.reshape(y_true, (128, 128)), aspect='auto', cmap='gray')
    for set_ticks in (plt.xticks, plt.yticks):
        set_ticks(fontsize=20)
    bar = plt.colorbar()
    bar.ax.tick_params(labelsize=20)
    bar.ax.set_xlabel('gray', size=20, fontproperties="Arial")
    plt.clim(0, 1)
    plt.tight_layout()
show_train_image(y_true=y_true)

In [None]:
# Build and train the network in TF1-style graph mode.
tf.compat.v1.disable_eager_execution()
seed = 42
batch_size = 1
epochs = 8000

# np.random.seed does not seed TensorFlow's weight initialisers; also set the
# TensorFlow graph-level seed so runs are reproducible.
np.random.seed(seed)
tf.random.set_seed(seed)
end = VAED2()
model = end.vae()
model.summary()
model.add_loss(end.loss_pcc())

# Re-seed NumPy right before fit() (presumably affects Keras' shuffling order).
np.random.seed(seed)
# ``learning_rate`` is the documented argument; ``lr`` is a deprecated alias.
model.compile(optimizer=kr.optimizers.Adam(learning_rate=0.00005, decay=0.0004), loss=None)
history = model.fit(x_train, epochs=epochs, batch_size=batch_size)

In [None]:
# Training curve of the 1 + NPCC loss for learning rate 5e-5.
plt.figure(figsize=(12, 9))
ax = plt.gca()
plt.setp([ax.spines[side] for side in ('top', 'bottom', 'left', 'right')], linewidth=1)
plt.plot(history.history["loss"], label="η = 5×10$^{-5}$", linewidth=2.5)
plt.xticks(fontsize=20, fontweight='bold')
plt.yticks(fontsize=20, fontweight='bold')
plt.xlabel("Epochs", fontsize=25, fontproperties="Arial", fontweight='bold')
plt.ylabel("1+NPCC", fontsize=25, fontproperties="Arial", fontweight='bold')
leg = plt.legend()
ltext = leg.get_texts()
plt.setp(ltext, fontsize=20, fontweight='bold')

In [None]:
# Run the trained network on the first sample and compare all stages.
y_true = x_train[:1]
y_enc = end.encoder().predict(y_true)   # 2x2 encoding
y_pred1 = model.predict(y_true)         # far-field intensity
y_pred2 = end.decoder().predict(y_true) # decoded real-space image
plot_predictions(y_true=y_true, y_end=y_enc, y_pred1=y_pred1, y_pred2=y_pred2)

In [None]:
# Predict here rather than reusing ``y_pred1`` from an earlier cell, so this
# cell is self-contained (the unused encoder prediction was dropped).
y_true = x_train[0:1]
y_pred1 = model.predict(y_true)
plot_predictions2(y_pred1)

In [None]:
# Retrain from scratch while ``cb`` periodically saves decoder snapshots.
epochs = 2000
np.random.seed(seed)
tf.random.set_seed(seed)  # also seed TF so weight initialisation is reproducible
end = VAED2()
model = end.vae()
model.summary()
model.add_loss(end.loss_pcc())
# ``learning_rate`` is the documented argument; ``lr`` is a deprecated alias.
model.compile(optimizer=kr.optimizers.Adam(learning_rate=0.00005), loss=None)
history = model.fit(x_train, epochs=epochs, batch_size=batch_size, callbacks=[cb])
# Start a new figure: plt.cla() only cleared the current axes of whichever
# figure was last active, drawing the curve into an unrelated plot.
plt.figure()
plt.plot(history.history["loss"], label="Training Loss")
plt.xlabel("epochs")
plt.ylabel("Loss")
plt.legend()