# Downgrade to TensorFlow 2.15.0 (Colab's default version is 2.17.0)

# Run this cell only once per runtime!

In [None]:
# Pin TensorFlow 2.15.0 in place of Colab's preinstalled version.
# NOTE(review): presumably required so the saved models under weights/ load correctly — confirm.
!pip install tensorflow==2.15.0

# Import Libraries, Mount Drive, and Define Helper Functions

In [None]:
import tensorflow as tf
from tensorflow.keras.utils import plot_model
import os
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
# Confirm the runtime picked up the pinned TensorFlow version (expected: 2.15.0).
print(tf.__version__)

In [None]:
# Mount Google Drive at /content/drive so the demo folder under MyDrive
# (weights/ and test_data/) is reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Root of the DAS-N2N demo folder on the mounted Drive (holds `weights/` and `test_data/`).
# Set the DAS_N2N_REPO environment variable to use a different copy without editing the notebook.
repo_path = os.environ.get("DAS_N2N_REPO", r"/content/drive/MyDrive/DAS_N2N_demo_V2")

In [None]:
def load_test_data(file_path:str):
    '''Load a .npy test file and flatten all leading axes into one patch axis.

    Args:
        file_path: path (str or os.PathLike) to a .npy file whose last two
            axes are (H, W).

    Returns:
        (filename, arr): the file's base name, and the array reshaped to (-1, H, W).
    '''
    path_str = os.fspath(file_path)
    # Split on both separators. The original split only on '\\', which returned the
    # whole path as the "filename" for POSIX paths such as those on Colab.
    filename = path_str.replace('\\', '/').rsplit('/', 1)[-1]
    arr = np.load(path_str)
    H, W = arr.shape[-2], arr.shape[-1]
    arr = np.reshape(arr, (-1, H, W))
    return filename, arr

def normalize(arr):
    '''Standardize ``arr`` to zero mean and unit standard deviation.

    The mean and std are computed over the whole (squeezed) array, not per patch.

    Args:
        arr: array-like; singleton axes are removed first (np.squeeze).

    Returns:
        ``(arr - mean) / std``. If ``std`` is 0 (constant input), the centered
        array (all zeros) is returned instead of NaNs.
    '''
    arr = np.squeeze(arr)
    datamean = np.mean(arr)
    datastd = np.std(arr)
    centered = arr - datamean
    if datastd == 0:
        # 0/0 would yield NaN everywhere and silently poison every downstream prediction.
        return centered
    return centered / datastd

def arr_recover(arr, n_rows=235, patch_h=128, patch_w=96):
    '''Stitch a flat stack of patches back into one 2-D section.

    Patches are taken in row-major grid order: all patches of grid row 0,
    then grid row 1, and so on.

    Args:
        arr: any array whose total size equals n_rows * n_cols * patch_h * patch_w
            (e.g. the patch stack, or a model prediction over it).
        n_rows: number of patch rows in the grid (default 235, the previously
            hard-coded value).
        patch_h, patch_w: patch height and width (default 128 x 96).

    Returns:
        Array of shape (n_rows * patch_h, n_cols * patch_w), where n_cols is inferred.

    Raises:
        ValueError: if arr's size cannot be split into the requested grid.
    '''
    arr = np.reshape(arr, (n_rows, -1, patch_h, patch_w))
    nr, nc, h, w = arr.shape
    # (nr, nc, h, w) -> (nr, h, nc, w) so a plain reshape places the patches of
    # each grid row side by side.
    arr = arr.swapaxes(1, 2)
    return np.reshape(arr, (nr * h, nc * w))

In [None]:
def plot_zoomed_fig(arr, region, ax=None, title=None, vmin=-1, vmax=1, cmap='seismic',
                    inset_pos=(0.5, 0.5, 0.48, 0.48)):
    '''Show ``arr`` with a zoomed inset of ``region``.

    Args:
        arr: 2-D array to display. Pixel (col, row) maps to data coords (x, y),
            with the origin at the lower-left.
        region: (x1, x2, y1, y2) limits of the zoom window, in those data coords.
        ax: axes to draw into. If None, a new 6x6 figure is created and shown
            (the original behavior). If given, nothing is shown, so the caller
            can build multi-panel figures.
        title: optional axes title.
        vmin, vmax, cmap: color scaling shared by the main image and the inset.
        inset_pos: inset bounds [x0, y0, width, height] in axes-fraction units.
    '''
    x1, x2, y1, y2 = region
    own_figure = ax is None
    if own_figure:
        _, ax = plt.subplots(figsize=(6,6))

    # One data unit per pixel, so `region` can be given in pixel indices.
    extent = (0, arr.shape[1], 0, arr.shape[0])
    imshow_kw = dict(extent=extent, origin='lower', cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto')

    ax.imshow(arr, **imshow_kw)
    if title is not None:
        ax.set_title(title)

    axins = ax.inset_axes(
        list(inset_pos),
        xlim=(x1, x2), ylim=(y1, y2), xticklabels=[], yticklabels=[])
    axins.imshow(arr, **imshow_kw)
    ax.indicate_inset_zoom(axins, edgecolor="black", linewidth=2.)

    if own_figure:
        plt.show()

## Load Test Data

In [None]:
# Load one record (a strain-rate file, judging by its name) stored as a grid of patches.
dataname = "1115_1107_StrainRate_20231121T141451+0800_7147_00000.npy"
file_path = os.path.join(repo_path, "test_data", dataname)
test_data = np.load(file_path)
# The file is expected to be 4-D: (grid rows, grid cols, h, w).
nr, nc, h, w = test_data.shape
# DAS-N2N accepts input patches of shape (128, 96); fail here rather than inside predict().
assert (h, w) == (128, 96), f"expected (128, 96) patches, got {(h, w)}"
test_data = np.reshape(test_data, (-1,h,w))
test_data = normalize(test_data)
dataname, test_data.shape

## Load Model

In [None]:
# Both checkpoints live under <repo>/weights: the original DAS-N2N model (no file
# extension, presumably a SavedModel directory) and the fine-tuned HDF5 file.
model_path = os.path.join(repo_path, r"weights/dasn2n_model")
tuned_model_path = os.path.join(repo_path, r"weights/TunedModel.h5")
model, tuned_model = (tf.keras.models.load_model(p) for p in (model_path, tuned_model_path))

## Plot Model Architecture

In [None]:
# Draw the DAS-N2N layer graph; rankdir='LR' lays it out left-to-right instead of top-to-bottom.
plot_model(model, rankdir='LR')

# Example: Denoising Synthetic Signals in Gaussian Noise

In [None]:
# Synthetic sanity check: two vertical stripes (+1 and -1) buried in strong Gaussian noise.
# Seeded so the noise realization, and therefore the demo figure, is reproducible on re-run.
GAUSS_DEMO_SEED = 42
demo_rng = np.random.default_rng(GAUSS_DEMO_SEED)
testnormal = demo_rng.normal(scale=2, size=(128,96))  # noise std 2 vs. stripe amplitude 1
testinput = np.zeros((128,96))
testinput[:,20:30] = 1
testinput[:,70:90] = -1
noised_test_input = testinput+testnormal

In [None]:
# Denoise the stripes+noise sample and the pure-noise sample with BOTH models.
# expand_dims adds a leading batch axis of size 1 for predict().
pred = model.predict(np.expand_dims(noised_test_input,0))
g_pred = model.predict(np.expand_dims(testnormal,0))
# The tuned column must come from tuned_model; previously these re-ran `model`,
# so the "tuned" panels silently duplicated the untuned ones.
tuned_pred = tuned_model.predict(np.expand_dims(noised_test_input,0))
tuned_g_pred = tuned_model.predict(np.expand_dims(testnormal,0))

In [None]:
# 2x3 grid. Rows: stripes + noise (top) and pure noise (bottom).
# Columns: noisy input, DAS-N2N output, tuned DAS-N2N output.
vmin, vmax = -4, 4
imshow_kw = dict(cmap='seismic', aspect='auto', vmin=vmin, vmax=vmax)
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(8,4), sharey=True, sharex=True)

panel_rows = [
    (noised_test_input, pred, tuned_pred),
    (testnormal, g_pred, tuned_g_pred),
]
for axis, title in zip(ax[0], ("Noisy Input", "DAS-N2N", "DAS-N2N (tuned)")):
    axis.set_title(title)
ax[0, 0].set_ylabel("stripes + noise")
ax[1, 0].set_ylabel("noise only")

for row_axes, row_arrays in zip(ax, panel_rows):
    for axis, arr in zip(row_axes, row_arrays):
        # squeeze drops singleton axes (e.g. the batch axis added before predict()).
        im = axis.imshow(np.squeeze(arr), **imshow_kw)

# All panels share vmin/vmax, so one colorbar (from the last image) describes all six.
fig.colorbar(im, ax=ax.ravel().tolist())
plt.show()

# Plot Input Data

In [None]:
# Zoom window as (x1, x2, y1, y2) in the pixel coordinates of the downsampled plot
# arrays (x = column index, y = row index, origin at the lower-left).
region = [300, 556, 500, 1012]

In [None]:
# Stitch the normalized input patches into one section; keep every 10th row to keep plotting light.
plot_raw = arr_recover(test_data)[::10,:]  # downsample for plotting
# Optional single-panel preview: plot_zoomed_fig(plot_raw, region)

## Model Prediction (Not Tuned)

In [None]:
# Denoise every patch with the original DAS-N2N model.
# NOTE: this overwrites `pred` from the Gaussian demo; re-run that cell before re-plotting the demo.
pred = model.predict(test_data)

In [None]:
# Stitch the untuned predictions the same way as the input (every 10th row kept).
plot_arr = arr_recover(pred)[::10,:] # downsample for plotting
# Optional single-panel preview: plot_zoomed_fig(plot_arr, region)

## Model Prediction (Tuned)

In [None]:
# Denoise every patch with the fine-tuned model.
# NOTE: this overwrites `tuned_pred` from the Gaussian demo; re-run that cell before re-plotting the demo.
tuned_pred = tuned_model.predict(test_data)

In [None]:
# Stitch the tuned predictions; keeping every 10th row turns 235 * 128 = 30080 rows into 3008.
plot_tuned = arr_recover(tuned_pred)[::10,:] # downsample for plotting (3008)
# Optional single-panel preview: plot_zoomed_fig(plot_tuned, region)

# Plot Denoising Results

In [None]:
# Side-by-side comparison of the input and both denoisers, each with the same zoomed inset.
vmin, vmax = -1, 1
zoomed_fig_pos = [0.5,0.5,0.48,0.48]  # inset bounds [x0, y0, width, height] in axes fractions
# One extent for every panel. This assumes the three arrays have plot_raw's shape,
# which holds when the models preserve the patch shape.
extent = (0, plot_raw.shape[1], 0, plot_raw.shape[0])
x1, x2, y1, y2 = region
imshow_kw = dict(extent=extent, origin='lower', cmap='seismic', vmin=vmin, vmax=vmax, aspect='auto')

panels = [("Raw Data", plot_raw), ("Pred", plot_arr), ("Pred (Tuned)", plot_tuned)]
fig, ax = plt.subplots(nrows=1, ncols=len(panels), sharey=True, figsize=(14,4))
for axis, (title, arr) in zip(ax, panels):
    axis.set_title(title)
    axis.imshow(arr, **imshow_kw)
    axins = axis.inset_axes(
        zoomed_fig_pos,
        xlim=(x1, x2), ylim=(y1, y2), xticklabels=[], yticklabels=[])
    axins.imshow(arr, **imshow_kw)
    axis.indicate_inset_zoom(axins, edgecolor="black", linewidth=2.)

plt.show()