In [None]:
!pip uninstall tensorflow
!pip install tensorflow==2.16.1

In [1]:
import tensorflow as tf
import os
import numpy as np
from osgeo import gdal, osr
import cv2
import matplotlib.pyplot as plt

In [2]:
def input_pipeline(filename, batch_size, is_shuffle=True, is_train=True, is_repeat=True):
    """Build a batched ``tf.data`` pipeline of (image, label) tiles from TFRecord files.

    Every record is parsed as a flattened 400x400x4 ``image_raw`` feature and a
    flattened 400x400x1 ``label`` feature, both stored as int64.

    Args:
        filename: Path (or list of paths) handed to ``tf.data.TFRecordDataset``.
        batch_size: Number of tiles per batch.
        is_shuffle: If True, shuffle the decoded examples with a 1000-element buffer.
        is_train: If True, apply a random 90-degree rotation and random flips,
            jointly to the image and its label.
        is_repeat: If True, repeat the records indefinitely.

    Returns:
        A ``tf.data.Dataset`` of ``(image, label)`` batches, where ``image`` is
        float32 with shape (batch, 400, 400, 4) divided by 10000 and ``label``
        is int32 with shape (batch, 400, 400).
    """
    feature_description = {
        'image_raw': tf.io.FixedLenFeature([400*400*4], dtype=tf.int64),
        'label': tf.io.FixedLenFeature([400*400*1], dtype=tf.int64),
    }

    def _parse_function(example_proto):
        feature_dict = tf.io.parse_single_example(example_proto, feature_description)
        image = tf.reshape(feature_dict['image_raw'], [400, 400, 4])
        image = tf.cast(image, tf.float32)

        image = image/10000

        label = tf.reshape(feature_dict['label'], [400, 400, 1])
        label = tf.cast(label, tf.float32)

        # Stack image and label so every geometric augmentation is applied to
        # both with the same random draw.
        image_label = tf.concat([image, label], axis=-1)
        if is_train:
            # maxval is exclusive: draw k from {0, 1, 2, 3}. An upper bound of 5
            # would also yield k=4, which is the same rotation as k=0 and would
            # make the unrotated case twice as likely as the others.
            quarter_turns = tf.random.uniform([], 0, 4, dtype=tf.int32)
            image_label = tf.image.rot90(image_label, quarter_turns)
            image_label = tf.image.random_flip_left_right(image_label)
            image_label = tf.image.random_flip_up_down(image_label)

        image = image_label[:, :, :4]
        label = image_label[:, :, -1]

        label = tf.cast(label, tf.int32)

        return image, label

    dataset = tf.data.TFRecordDataset(filename)
    if is_repeat:
        dataset = dataset.repeat()
    dataset = dataset.map(_parse_function)
    if is_shuffle:
        dataset = dataset.shuffle(buffer_size=1000)
    batch = dataset.batch(batch_size=batch_size)
    return batch

In [3]:
def P(X, F1, k_size, s, stage):
    """Apply the 'P' block: valid-padded convolution, batch normalisation, ReLU.

    Args:
        X: Input tensor; batch normalisation is applied along axis 3.
        F1: Number of convolution filters.
        k_size: Side length of the square convolution kernel.
        s: Stride used in both spatial directions.
        stage: Value appended to the layer names ('P-layer<stage>',
            'P-layer-BN<stage>').

    Returns:
        The activated output tensor.
    """
    suffix = str(stage)

    conv = tf.keras.layers.Conv2D(
        filters=F1,
        kernel_size=(k_size, k_size),
        strides=(s, s),
        padding='valid',
        name='P-layer' + suffix,
        kernel_initializer=tf.keras.initializers.glorot_uniform(seed=0),
    )
    batch_norm = tf.keras.layers.BatchNormalization(axis=3, name='P-layer-BN' + suffix)
    relu = tf.keras.layers.Activation('relu')

    return relu(batch_norm(conv(X)))

def Q(X, F1, k_size, s, stage):
    """Apply the 'Q' block: same-padded convolution, batch normalisation, ReLU.

    Args:
        X: Input tensor; batch normalisation is applied along axis 3.
        F1: Number of convolution filters.
        k_size: Side length of the square convolution kernel.
        s: Stride used in both spatial directions.
        stage: Value appended to the layer names ('Q-layer<stage>',
            'Q-layer-BN<stage>').

    Returns:
        The activated output tensor.
    """
    suffix = str(stage)

    conv = tf.keras.layers.Conv2D(
        filters=F1,
        kernel_size=(k_size, k_size),
        strides=(s, s),
        padding='same',
        name='Q-layer' + suffix,
        kernel_initializer=tf.keras.initializers.glorot_uniform(seed=0),
    )
    batch_norm = tf.keras.layers.BatchNormalization(axis=3, name='Q-layer-BN' + suffix)
    relu = tf.keras.layers.Activation('relu')

    return relu(batch_norm(conv(X)))

# w = h = 128

def complex_initializer(base_initializer, seed=0):
    """Wrap a real-valued Keras initializer class into a complex-valued initializer.

    Args:
        base_initializer: Initializer class (or factory) that accepts a ``seed``
            keyword, e.g. ``tf.keras.initializers.glorot_uniform``.
        seed: Integer seed for the real part; the imaginary part uses
            ``seed + 1``. Pass ``None`` to leave both parts unseeded.

    Returns:
        A callable taking the usual initializer arguments (shape, ...) and
        returning ``tf.complex(real, imag)``.
    """
    # A Keras initializer built with an integer seed returns the same values on
    # every call. Drawing the real and imaginary parts from one shared instance
    # would therefore make them identical, so each part gets its own instance.
    real_init = base_initializer(seed=seed)
    imag_init = base_initializer(seed=None if seed is None else seed + 1)

    def initializer(*args, dtype=tf.complex64, **kwargs):
        # `dtype` is accepted so the caller may pass the variable's complex
        # dtype, but it is deliberately not forwarded: the base initializers
        # must produce real-valued parts for tf.complex.
        real = real_init(*args, **kwargs)
        imag = imag_init(*args, **kwargs)
        return tf.complex(real, imag)

    return initializer


class FourierLayer_large(tf.keras.layers.Layer):
    """Spectral convolution over the two spatial axes plus a 1x1-conv residual path.

    The layer expects a channels-last 4-D input (batch, height, width, channels).
    It takes the real 2-D FFT, multiplies two blocks of retained frequencies
    (the first and last ``modes1`` rows of the first ``modes2`` columns) by
    learned complex weights, zeroes every other frequency, transforms back and
    adds a 1x1-convolved copy of the input before a ReLU.

    Args:
        num_modes: Number of output channels (used as the out-channel dimension
            of the spectral weights and as the filter count of the residual conv).
        stage: Value appended to the FFT / inverse-FFT op names.
        L: Stored on the instance; not used by this layer.
        modes1: Number of retained frequency rows at each end of the spectrum.
        modes2: Number of retained frequency columns.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_modes, stage, L=1.0, modes1=24, modes2=24, **kwargs):
        super(FourierLayer_large, self).__init__(**kwargs)
        self.num_modes = num_modes
        self.modes1 = modes1
        self.modes2 = modes2
        self.L = L
        self.stage = stage

        # Sub-layers are created once here so that their weights are tracked,
        # trained and saved together with the layer. Constructing the residual
        # convolution inside call() would create a brand-new, untracked and
        # freshly initialised layer on every forward pass.
        self.res_conv = tf.keras.layers.Conv2D(filters=num_modes, kernel_size=(1, 1), padding='same')
        self.add = tf.keras.layers.Add()
        self.activation = tf.keras.layers.Activation('relu')

    def build(self, input_shape):
        height, width = input_shape[1], input_shape[2]
        if height is not None and 2 * self.modes1 > height:
            raise ValueError(
                'modes1=%d needs an input height of at least %d, got %d'
                % (self.modes1, 2 * self.modes1, height))
        if width is not None and self.modes2 > width // 2 + 1:
            raise ValueError(
                'modes2=%d exceeds the %d rfft columns available for width %d'
                % (self.modes2, width // 2 + 1, width))

        self.weight1 = self.add_weight(
            shape=(input_shape[-1], self.num_modes, self.modes1, self.modes2), dtype=tf.complex64,
            initializer = complex_initializer(tf.keras.initializers.glorot_uniform), name='w1'
        )

        self.weight2 = self.add_weight(
            shape=(input_shape[-1], self.num_modes, self.modes1, self.modes2), dtype=tf.complex64,
            initializer = complex_initializer(tf.keras.initializers.glorot_uniform), name='w2'
        )

        self.res_conv.build(input_shape)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(FourierLayer_large, self).get_config()
        config.update({
            'num_modes': self.num_modes,
            'stage': self.stage,
            'L': self.L,
            'modes1': self.modes1,
            'modes2': self.modes2,
        })
        return config

    def compl_mul2d(self, input, weights):
        # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y)
        return tf.einsum("bixy,ioxy->boxy", input, weights)

    def call(self, X):

        # Name definition
        FFT_name = 'fft-layer' + str(self.stage)
        RFFT_name = 'ifft-layer' + str(self.stage)

        Res_X = self.res_conv(X)

        # rfft2d transforms the two innermost axes, so move channels in front
        # of the spatial axes: (batch, channels, height, width).
        X = tf.transpose(X, perm=[0, 3, 1, 2])
        X_fft = tf.signal.rfft2d(X, name = FFT_name)

        updated_slice = self.compl_mul2d(X_fft[:, :, :self.modes1, :self.modes2], self.weight1)
        updated_slice2 = self.compl_mul2d(X_fft[:, :, -self.modes1:, :self.modes2], self.weight2)

        # Zero blocks for the discarded frequencies. Their channel dimension
        # must equal the number of OUTPUT channels (num_modes) to be
        # concatenable with the weighted slices; using the input channel count
        # only works when both happen to be equal.
        dyn_shape = tf.shape(X)
        zero_0 = tf.zeros([dyn_shape[0], self.num_modes, dyn_shape[2]-2*self.modes1, self.modes2], dtype=tf.complex64)
        zero_1 = tf.zeros([dyn_shape[0], self.num_modes, dyn_shape[2], dyn_shape[3]//2+1-self.modes2], dtype=tf.complex64)
        out_ft = tf.concat([tf.concat([updated_slice, zero_0, updated_slice2], axis=2), zero_1], axis=3)

        # fft_length restores the original spatial size (taken from the static shape).
        X = tf.signal.irfft2d(out_ft, fft_length=(X.shape[2], X.shape[3]), name = RFFT_name)
        X = tf.transpose(X, perm=[0, 2, 3, 1])

        X = self.add([X, Res_X])

        X = self.activation(X)

        return X

def FNO_large(input_shape):
    """Assemble the 'Fourier-Neural-Operator' Keras model.

    The network is a P block with 64 filters, four ``FourierLayer_large``
    layers with 64 output channels (stages 1-4), a Q block with 64 filters and
    a final 2-unit softmax ``Dense`` layer.

    Args:
        input_shape: Shape of one input sample (without the batch axis).

    Returns:
        The ``tf.keras.Model`` mapping the input to 2-class probabilities.
    """
    X_input = tf.keras.layers.Input(shape=input_shape)

    features = P(X_input, F1=64, k_size=1, s=1, stage=1)

    # Four stacked Fourier layers, numbered 1..4.
    for fourier_stage in (1, 2, 3, 4):
        features = FourierLayer_large(64, stage=fourier_stage)(features)

    features = Q(features, F1=64, k_size=1, s=1, stage=1)

    classifier = tf.keras.layers.Dense(
        2,
        activation=tf.nn.softmax,
        kernel_initializer=tf.keras.initializers.glorot_uniform(seed=0),
        name='Dense',
    )
    probabilities = classifier(features)

    return tf.keras.Model(inputs=X_input, outputs=probabilities, name='Fourier-Neural-Operator')

In [4]:
# Training configuration: tiles per batch, Adam optimizer (learning rate 1e-4)
# and the network instance for 400x400 tiles with 4 bands.
BATCH_SIZE = 3
u_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

u_model = FNO_large(input_shape=(400, 400, 4))

In [7]:
import time

def train_step(t_images, t_labels):
    """Run one optimisation step on a batch using the global model and optimizer.

    Args:
        t_images: Batch of input images fed to ``u_model``.
        t_labels: Integer class labels for the batch.

    Returns:
        The model output for the batch (treated as probabilities by the loss,
        which is configured with ``from_logits=False``).
    """
    criterion = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as u_tape:
        prob = u_model(t_images, training=True)
        batch_loss = criterion(t_labels, prob)

    grads = u_tape.gradient(batch_loss, u_model.trainable_variables)
    grads_and_vars = zip(grads, u_model.trainable_variables)
    u_optimizer.apply_gradients(grads_and_vars)

    return prob


def train(train_ds, epoch):
    """Train for one pass over ``train_ds`` and print the running accuracy.

    Args:
        train_ds: Iterable of ``(images, labels)`` batches.
        epoch: Zero-based epoch index; printed as ``epoch + 1``.
    """
    accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

    for batch_images, batch_labels in train_ds:
        # Update the weights, then score the batch with the returned model output.
        batch_output = train_step(batch_images, batch_labels)
        accuracy_metric.update_state(batch_labels, batch_output)

    epoch_accuracy = accuracy_metric.result()
    print('train accuracy over epoch %d: %.2f'% (epoch+1, epoch_accuracy*100))

In [6]:
# Folders holding the training TFRecord files.
folder = ['/content/drive/MyDrive/ice_cloud/l8ps_ds_2', '/content/drive/MyDrive/ice_cloud/l8ps_ds_3',
          '/content/drive/MyDrive/ice_cloud/l8ps_ds_4', '/content/drive/MyDrive/ice_cloud/l8ps_ds_5',
          '/content/drive/MyDrive/ice_cloud/l8ps_ds_6']

# Collect every '*.tfrecords' file from each folder, skipping files whose name
# ends in 'full.tfrecords'. os.listdir returns entries in arbitrary order, so
# each listing is sorted to keep the file order reproducible between runs.
filenames = [
    os.path.join(directory, f)
    for directory in folder
    for f in sorted(os.listdir(directory))
    if f.endswith('.tfrecords') and not f.endswith('full.tfrecords')
]

filenames

['/content/drive/MyDrive/ice_cloud/l8ps_ds_2/20230308_212634_27_2413.tfrecords',
 '/content/drive/MyDrive/ice_cloud/l8ps_ds_2/20230325_103503_19_240c.tfrecords',
 '/content/drive/MyDrive/ice_cloud/l8ps_ds_2/20230329_033637_63_2414.tfrecords',
 '/content/drive/MyDrive/ice_cloud/l8ps_ds_2/20230402_144322_97_2402.tfrecords',
 '/content/drive/MyDrive/ice_cloud/l8ps_ds_2/20230417_072528_39_2414.tfrecords',
 '/content/drive/MyDrive/ice_cloud/l8ps_ds_2/20230418_094613_93_2426.tfrecords',
 '/content/drive/MyDrive/ice_cloud/l8ps_ds_2/20230418_094636_32_2426.tfrecords',
 '/content/drive/MyDrive/ice_cloud/l8ps_ds_2/20230420_143108_00_2413.tfrecords',
 '/content/drive/MyDrive/ice_cloud/l8ps_ds_2/20230429_192205_75_2414.tfrecords',
 '/content/drive/MyDrive/ice_cloud/l8ps_ds_2/20230429_192219_98_2414.tfrecords',
 '/content/drive/MyDrive/ice_cloud/l8ps_ds_2/20230503_122053_44_2414.tfrecords',
 '/content/drive/MyDrive/ice_cloud/l8ps_ds_2/20230515_223754_57_240c.tfrecords',
 '/content/drive/MyDrive/ice

In [None]:
def train_process(epochs=100, save_every=1, weights_path='FourierNet.weights.h5'):
    """Train the global model and checkpoint its weights periodically.

    Args:
        epochs: Number of passes over the training dataset.
        save_every: Save the weights after every ``save_every`` epochs.
        weights_path: Destination of the weights file. Keras 3 (shipped with the
            TensorFlow 2.16 pinned by this notebook) only accepts names ending
            in '.weights.h5' for ``save_weights``.
    """
    # is_repeat=False makes one iteration over train_ds equal to one epoch.
    train_ds = input_pipeline(filenames, BATCH_SIZE, is_repeat=False)

    for i in range(epochs):
        train(train_ds, i)
        if (i + 1) % save_every == 0:
            u_model.save_weights(weights_path)

train_process()


In [None]:
# Rebuild the network for 400x400x4 tiles and restore trained weights from Drive.
# NOTE(review): with the TensorFlow 2.16 / Keras 3 pin at the top of the notebook,
# `load_weights` is documented to accept `.keras`, `.weights.h5` and legacy `.h5`
# files — confirm that this `.ckpt` path loads in the environment actually used.
# NOTE(review): the training cell saves relative to the working directory, while
# this path is on Drive — verify the checkpoint was copied there.
u_model = FNO_large(input_shape = (400, 400, 4))
u_model.load_weights('/content/drive/MyDrive/ice_cloud/github_code/FourierNet.ckpt')

In [12]:
def tif2array(input_file):
    """Read every band of a raster file into one integer array.

    Args:
        input_file: Path of the raster opened read-only with GDAL.

    Returns:
        NumPy array of dtype ``int`` with shape (rows, columns, bands).

    Raises:
        RuntimeError: If GDAL cannot open ``input_file``.
    """
    dataset = gdal.Open(input_file, gdal.GA_ReadOnly)
    # gdal.Open can return None instead of raising when the file is missing or
    # unreadable; fail here with a clear message rather than with an
    # AttributeError on the next line.
    if dataset is None:
        raise RuntimeError('GDAL could not open raster file: %r' % (input_file,))

    image = np.zeros((dataset.RasterYSize, dataset.RasterXSize, dataset.RasterCount),
                     dtype=int)

    for b in range(dataset.RasterCount):
        # GDAL band indices are 1-based.
        band = dataset.GetRasterBand(b + 1)
        image[:, :, b] = band.ReadAsArray()

    return image

def clip_center(img, clip_height, clip_width):
    """
    Clips the center part of the image.

    Parameters:
    - img: 2D or 3D NumPy array representing the image (rows, columns[, bands]).
    - clip_height: Height of the central clip.
    - clip_width: Width of the central clip.

    Returns:
    - Clipped center part of the image as a NumPy array. The clip has exactly
      clip_height x clip_width pixels, or the full image extent along any axis
      where the image is smaller than the requested clip.
    """
    # Get image dimensions
    img_height, img_width = img.shape[:2]

    # Derive the window from its top-left corner and add the full clip size.
    # Computing `center - size // 2` and `center + size // 2` loses one
    # row/column whenever the requested size is odd; for even sizes both
    # formulations select the same window.
    start_y = max((img_height - clip_height) // 2, 0)
    end_y = min(start_y + clip_height, img_height)
    start_x = max((img_width - clip_width) // 2, 0)
    end_x = min(start_x + clip_width, img_width)

    # Return the clipped center part
    return img[start_y:end_y, start_x:end_x]

In [13]:
# Gather the '*ref.tif' test scenes: first folder (sorted), then second folder (sorted).
folder_path = '/content/drive/MyDrive/ice_cloud/test_ds/dataset'
imgs = sorted(
    os.path.join(folder_path, name)
    for name in os.listdir(folder_path)
    if name.endswith('ref.tif')
)

folder_path = '/content/drive/MyDrive/ice_cloud/test_ds2/dataset'
imgs2 = sorted(
    os.path.join(folder_path, name)
    for name in os.listdir(folder_path)
    if name.endswith('ref.tif')
)

# Concatenate without re-sorting so the first folder's scenes stay in front.
imgs = imgs + imgs2
imgs

['/content/drive/MyDrive/ice_cloud/test_ds/dataset/2022-04-15_strip_5573346_ref.tif',
 '/content/drive/MyDrive/ice_cloud/test_ds/dataset/2023-08-31_strip_6747657_ref.tif',
 '/content/drive/MyDrive/ice_cloud/test_ds/dataset/2023-09-09_strip_6763305_ref.tif',
 '/content/drive/MyDrive/ice_cloud/test_ds/dataset/2023-09-14_strip_6774056_ref.tif',
 '/content/drive/MyDrive/ice_cloud/test_ds/dataset/2023-09-20_strip_6786319_ref.tif',
 '/content/drive/MyDrive/ice_cloud/test_ds/dataset/2023-09-21_strip_6788017_ref.tif',
 '/content/drive/MyDrive/ice_cloud/test_ds/dataset/2023-09-24_strip_6793653_ref.tif',
 '/content/drive/MyDrive/ice_cloud/test_ds/dataset/2023-09-27_strip_6800261_ref.tif',
 '/content/drive/MyDrive/ice_cloud/test_ds/dataset/2023-10-06_strip_6818915_ref.tif',
 '/content/drive/MyDrive/ice_cloud/test_ds/dataset/2023-10-08_strip_6823217_ref.tif',
 '/content/drive/MyDrive/ice_cloud/test_ds/dataset/2023-10-13_strip_6833728_ref.tif',
 '/content/drive/MyDrive/ice_cloud/test_ds/dataset/202

In [None]:
# Tiled inference: run the model on non-overlapping tiles of the central clip
# of each selected scene and stitch the class-1 probabilities back together.
patch_size = 400   # tile size; matches the model built with input_shape=(400, 400, 4)
stride = 400       # equal to patch_size, so tiles do not overlap
clip_size = 3200   # side length of the central window requested from each scene

for img_path in imgs[2:3]:
    img = tif2array(img_path)
    img = clip_center(img, clip_size, clip_size)

    # Same 1/10000 scaling as the training pipeline; float32 is the dtype the
    # training pipeline feeds to the model.
    img = (img/10000).astype(np.float32)

    # Add the batch axis using the actual clip shape: clip_center returns a
    # smaller window when the scene is smaller than clip_size, so hard-coding
    # 3200 here would make the reshape fail.
    height, width, bands = img.shape
    img = tf.reshape(img, [1, height, width, bands])

    output_shape = (height, width)

    # Initialize an empty array for storing the full prediction.
    # Pixels not covered by a complete tile keep the value 0.
    full_prediction = np.zeros(output_shape)

    # Iterate over the image with a sliding window, visiting complete tiles only
    for i in range(0, height - patch_size + 1, stride):
        for j in range(0, width - patch_size + 1, stride):
            patch = img[:, i:i+patch_size, j:j+patch_size, :]

            pred = u_model(patch, training=False)

            # Place the predicted patch back into the full prediction array
            # (channel 1 of the softmax output)
            full_prediction[i:i+patch_size, j:j+patch_size] = pred[0, :, :, 1]

    plt.figure(figsize=(10, 10))

    # Left: every 10th pixel of the scene, bands 2, 1, 0 shown as the colour channels.
    plt.subplot(1, 2, 1)
    plt.axis('off')
    plt.imshow((img[0, ::10, ::10, 2::-1]), vmin=0, vmax=1)

    # Right: stitched class-1 probability map.
    plt.subplot(1, 2, 2)
    plt.axis('off')
    plt.imshow((full_prediction), vmin=0, vmax=1)

    # Render this scene's figure before moving on to the next one.
    plt.show()

    # np.save(img_path[:-7]+'fouriernet.npy', full_prediction)