# import library

In [18]:
import tifffile as tiff
import numpy as np
import matplotlib.pyplot as plt
import cv2
import tensorflow as tf
from tensorflow.keras.layers import Input, Concatenate, Conv2D, LeakyReLU, BatchNormalization, Conv2DTranspose, Activation, Dropout, Add, UpSampling2D
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from skimage import exposure
import os
import time
from tensorflow.keras.callbacks import TensorBoard

In [19]:
current_time = time.strftime("%m%d_%H%M")

result_dir = os.path.join(os.getcwd(), 'result', f'logs_{current_time}')
os.makedirs(result_dir, exist_ok=True)

log_dir = os.path.join(os.getcwd(), 'logs', f'logs_{current_time}')
os.makedirs(log_dir, exist_ok=True)

print(f'Result directory created: {result_dir}')
print(f'log directory created: {log_dir}')

Epoch = 200
batch = 32

Result directory created: /jsm0707/GAN/result/logs_0808_0416
log directory created: /jsm0707/GAN/logs/logs_0808_0416


# Data load

In [20]:
def load_tif_files(directory, target_size=(256, 256)):
    files = sorted([os.path.join(directory, file) for file in os.listdir(directory) if file.endswith('.tif')])
    data = []
    for file in files:
        img = tiff.imread(file)
        img_resized = cv2.resize(img, target_size)
        data.append(img_resized)
    return np.array(data)

def load_dataset(x_directory, y_directory):
    x_data = load_tif_files(x_directory)
    y_data = load_tif_files(y_directory)
    return x_data, y_data

def contrast(image): 
    valid_data = image[np.logical_and(image != np.inf, image != -np.inf)] 
    in_low, in_high = np.percentile(valid_data, (2, 98))
    image = exposure.rescale_intensity(image, in_range=(in_low, in_high))
    return image

def load_data(batch_size, X, y):
    i = np.random.randint(0, int(len(X) / batch_size))
    img_A = (np.array(X[i:i + batch_size]) / 5000.0) - 1
    img_B = (np.array(y[i:i + batch_size]) / 5000.0) - 1
    return img_A, img_B

def load_batch(batch_size, X, y):
    for i in range(int(len(X) / batch_size)):
        img_A = (X[i:i + batch_size] / 5000.0) - 1
        img_B = (y[i:i + batch_size] / 5000.0) - 1
        yield img_A, img_B

# Define Discriminator

In [21]:
def define_discriminator(image_shape):
    init = RandomNormal(stddev=0.02)
    in_src_image = Input(shape=image_shape)
    in_target_image = Input(shape=image_shape)
    merged = Concatenate()([in_src_image, in_target_image])
    d = Conv2D(64, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(merged)
    d = LeakyReLU(alpha=0.2)(d)
    d = Conv2D(128, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = Conv2D(256, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = Conv2D(512, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = Conv2D(512, (4, 4), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    patch_out = Conv2D(1, (4, 4), padding='same', kernel_initializer=init)(d)
    patch_out = Activation('sigmoid')(patch_out)
    model = Model([in_src_image, in_target_image], patch_out)
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
    return model

# Define Generator

In [22]:
def define_encoder_block(layer_in, n_filters, batchnorm=True):
    init = RandomNormal(stddev=0.02)
    g = Conv2D(n_filters, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(layer_in)
    if batchnorm:
        g = BatchNormalization()(g, training=True)
    g = LeakyReLU(alpha=0.2)(g)
    return g

def decoder_block(layer_in, skip_in, n_filters, dropout=True):
    init = RandomNormal(stddev=0.02)
    g = Conv2DTranspose(n_filters, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(layer_in)
    g = BatchNormalization()(g, training=True)
    if dropout:
        g = Dropout(0.5)(g, training=True)
    if g.shape[1] != skip_in.shape[1] or g.shape[2] != skip_in.shape[2]:
        skip_in = UpSampling2D(size=(g.shape[1] // skip_in.shape[1], g.shape[2] // skip_in.shape[2]))(skip_in)
    g = Concatenate()([g, skip_in])
    g = Activation('relu')(g)
    return g

def residual_block(layer_in, n_filters, kernel_size=(3,3)):
    init = RandomNormal(stddev=0.02)
    g = Conv2D(n_filters, kernel_size, padding='same', kernel_initializer=init)(layer_in)
    g = BatchNormalization()(g, training=True)
    g = Activation('relu')(g)
    g = Conv2D(n_filters, kernel_size, padding='same', kernel_initializer=init)(g)
    g = BatchNormalization()(g, training=True)
    g = Add()([g, layer_in])
    return g

def define_generator(image_shape):
    init = RandomNormal(stddev=0.02)
    in_image = Input(shape=image_shape)
    e1 = define_encoder_block(in_image, 64, batchnorm=False)
    e2 = define_encoder_block(e1, 128)
    e3 = define_encoder_block(e2, 256)
    e4 = define_encoder_block(e3, 512)
    e5 = define_encoder_block(e4, 512)
    e6 = define_encoder_block(e5, 512)
    e7 = define_encoder_block(e6, 512)
    b = Conv2D(512, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(e7)
    b = Activation('relu')(b)
    r1 = residual_block(b, 512)
    r2 = residual_block(r1, 512)
    r3 = residual_block(r2, 512)
    d1 = decoder_block(r3, e7, 512)
    d2 = decoder_block(d1, e6, 512)
    d3 = decoder_block(d2, e5, 512)
    d4 = decoder_block(d3, e4, 512, dropout=False)
    d5 = decoder_block(d4, e3, 256, dropout=False)
    d6 = decoder_block(d5, e2, 128, dropout=False)
    d7 = decoder_block(d6, e1, 64, dropout=False)
    g = Conv2DTranspose(image_shape[-1], (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d7)
    out_image = Activation('tanh')(g)
    model = Model(in_image, out_image)
    return model

# Define Gan

In [23]:
def define_gan(g_model, d_model, image_shape):
    for layer in d_model.layers:
        if not isinstance(layer, BatchNormalization):
            layer.trainable = False
    in_src = Input(shape=image_shape)
    gen_out = g_model(in_src)
    dis_out = d_model([in_src, gen_out])
    model = Model(in_src, [dis_out, gen_out])
    opt = Adam(learning_rate=0.0001, beta_1=0.5)
    model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1, 100])
    return model


In [24]:
def generate_real_samples(n_samples, patch_shape, X, y):
    trainA, trainB = load_data(n_samples, X, y)
    ix = np.random.randint(0, trainA.shape[0], n_samples)
    X1, X2 = trainA[ix], trainB[ix]
    y = np.ones((n_samples, patch_shape, patch_shape, 1))
    return [X1, X2], y

def generate_fake_samples(g_model, samples, patch_shape):
    X = g_model.predict(samples)
    y = np.zeros((len(X), patch_shape, patch_shape, 1))
    return X, y

def summarize_performance(step, g_model, result_dir, dataset, n_samples, X, y):
    [X_realA, X_realB], _ = generate_real_samples(n_samples, 1, X, y)
    X_fakeB, _ = generate_fake_samples(g_model, X_realA, 1)
    gen_imgs = [X_realA, X_fakeB, X_realB]
    titles = ['Condition', 'Generated', 'Original']
    fig, axs = plt.subplots(3, n_samples, figsize=(20, 20))
    for i in range(3):
        for j in range(n_samples):
            image = gen_imgs[i][j]
            rescaled_image = (image + 1) * 5000
            axs[i, j].imshow(contrast(rescaled_image / 10000), cmap='gray')
            axs[i, j].set_title(titles[i])
            axs[i, j].axis('off')
    filename = os.path.join(result_dir, f'{dataset}_plot_{step:06d}.png')
    fig.savefig(filename)
    plt.close(fig)

# Training

In [25]:
tensorboard_callback = TensorBoard(log_dir=log_dir)
file_writer = tf.summary.create_file_writer(log_dir)

In [26]:
def train_on_dataset(dataset, d_model, g_model, gan_model, n_epochs, n_batch, result_dir, d_filename, g_filename, file_writer, update_ratio=5):
    x_directory = os.path.join(os.getcwd(), 'test_image', dataset)
    y_directory = os.path.join(os.getcwd(), 'test_image', 'subset_512')
    X, y = load_dataset(x_directory, y_directory)
    print(f'Training on {dataset}: X shape: {X.shape}, y shape: {y.shape}')

    n_patch = d_model.output_shape[1]
    for epoch in range(1, n_epochs + 1):
        for batch_i, (trainA, trainB) in enumerate(load_batch(n_batch, X, y)):
            [X_realA, X_realB], y_real = generate_real_samples(n_batch, n_patch, X, y)
            X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)

            if batch_i % update_ratio == 0:
                d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
                d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
                with file_writer.as_default():
                    tf.summary.scalar('D1_loss', d_loss1, step=epoch)
                    tf.summary.scalar('D2_loss', d_loss2, step=epoch)
            
            g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])
            with file_writer.as_default():
                tf.summary.scalar('G_loss', g_loss, step=epoch)

        summarize_performance(epoch, g_model, result_dir, dataset, n_samples=3, X=X, y=y)
        print(f'[Epoch {epoch}] D1_loss: {d_loss1:.3f}, D2_loss: {d_loss2:.3f}, G_loss: {g_loss:.3f}')
        g_model.save(g_filename)
        d_model.save(d_filename)

In [27]:
# Directories for datasets
datasets = ['subset_512_8x8', 'subset_512_4x4', 'subset_512_2x2']

# Sequentially train on each dataset
image_shape = (256, 256, 1)
d_model = define_discriminator(image_shape)
g_model = define_generator(image_shape)
gan_model = define_gan(g_model, d_model, image_shape)

for dataset in datasets:
    g_filename = os.path.join(result_dir, f'g_model_{dataset}.h5')
    d_filename = os.path.join(result_dir, f'd_model_{dataset}.h5')

    if os.path.exists(g_filename):
        g_model.load_weights(g_filename)
    if os.path.exists(d_filename):
        d_model.load_weights(d_filename)

    train_on_dataset(dataset, d_model, g_model, gan_model, n_epochs=Epoch, n_batch=batch, result_dir=result_dir, d_filename=d_filename, g_filename=g_filename, file_writer=file_writer)

Training on subset_512_8x8: X shape: (864, 256, 256), y shape: (864, 256, 256)
[Epoch 1] D1_loss: 0.168, D2_loss: 0.310, G_loss: 51.273
[Epoch 2] D1_loss: 0.210, D2_loss: 0.299, G_loss: 31.150
[Epoch 3] D1_loss: 0.027, D2_loss: 0.036, G_loss: 20.780
[Epoch 4] D1_loss: 0.205, D2_loss: 1.717, G_loss: 11.000
[Epoch 5] D1_loss: 0.420, D2_loss: 0.426, G_loss: 7.503
[Epoch 6] D1_loss: 0.405, D2_loss: 0.410, G_loss: 5.246
[Epoch 7] D1_loss: 0.396, D2_loss: 0.403, G_loss: 3.901
[Epoch 8] D1_loss: 0.391, D2_loss: 0.398, G_loss: 3.083
[Epoch 9] D1_loss: 0.387, D2_loss: 0.394, G_loss: 2.569
[Epoch 10] D1_loss: 0.384, D2_loss: 0.392, G_loss: 2.210
[Epoch 11] D1_loss: 0.382, D2_loss: 0.390, G_loss: 1.971
[Epoch 12] D1_loss: 0.380, D2_loss: 0.388, G_loss: 1.786
[Epoch 13] D1_loss: 0.378, D2_loss: 0.387, G_loss: 1.634
[Epoch 14] D1_loss: 0.377, D2_loss: 0.386, G_loss: 1.548
[Epoch 15] D1_loss: 0.376, D2_loss: 0.384, G_loss: 1.451
[Epoch 16] D1_loss: 0.374, D2_loss: 0.384, G_loss: 1.387
[Epoch 17] D1_

# inference코드

In [28]:
import tensorflow as tf
from tensorflow.keras.models import load_model
import numpy as np
import os
import matplotlib.pyplot as plt
from skimage import exposure
import cv2
import tifffile as tiff

def load_tif_file(file_path, target_size=(256, 256)):
    img = tiff.imread(file_path)
    img_resized = cv2.resize(img, target_size)
    return np.array(img_resized)

def contrast(image):
    valid_data = image[np.logical_and(image != np.inf, image != -np.inf)]
    in_low, in_high = np.percentile(valid_data, (2, 98))
    image = exposure.rescale_intensity(image, in_range=(in_low, in_high))
    return image

def visualize_results(X_realA, X_fakeB, save_path=None):
    gen_imgs = [X_realA, X_fakeB]
    titles = ['Condition', 'Generated']
    fig, axs = plt.subplots(1, 2, figsize=(20, 10))
    for i in range(2):
        image = gen_imgs[i]
        rescaled_image = (image + 1) * 5000
        axs[i].imshow(contrast(rescaled_image / 10000), cmap='gray')
        axs[i].set_title(titles[i])
        axs[i].axis('off')
    if save_path:
        fig.savefig(save_path)
    plt.show()

def perform_inference(g_model, file_path, save_path=None):
    X_realA = load_tif_file(file_path)

    X_realA = (X_realA / 5000.0) - 1

    # Ensure the input shape matches the expected shape of the model
    if len(X_realA.shape) == 2:  # Add channel dimension if missing
        X_realA = np.expand_dims(X_realA, axis=-1)
    
    # Ensure that input is 4-dimensional (batch, height, width, channels)
    X_realA = np.expand_dims(X_realA, axis=0)

    X_fakeB = g_model.predict(X_realA)

    # Ensure the output is squeezed to remove the batch dimension
    X_fakeB = np.squeeze(X_fakeB, axis=0)
    X_realA = np.squeeze(X_realA, axis=0)

    visualize_results(X_realA, X_fakeB, save_path)

In [29]:
# File path
file_path = os.path.join(os.getcwd(), 'test_image', 'subset_512_4x4', 'S1A_IW_GRDH_1SDV_20210105T092315_20210105T092344_036001_0437B7_42D51_subset_0_0.tif')

# Load the saved generator model
result_dir = os.path.join(os.getcwd(), 'result', 'logs_8x8_0709_1551_code1')
g_model_path = os.path.join(result_dir, 'd_model.h5')

if not os.path.exists(g_model_path):
    raise FileNotFoundError(f"Generator model not found at {g_model_path}")

# Load the generator model
g_model = load_model(g_model_path, compile=False)
print("Generator model loaded successfully.")

# Perform inference
output_image_path = os.path.join(result_dir, 'inference_result.png')
perform_inference(g_model, file_path, save_path=output_image_path)

FileNotFoundError: Generator model not found at /jsm0707/GAN/result/logs_8x8_0709_1551_code1/d_model.h5