In [1]:
import h5py
import numpy as np
import os
import matplotlib.pyplot as plt
import os
import tensorflow as tf

2025-07-15 14:10:17.307862: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752613817.321528 1001188 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752613817.325771 1001188 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1752613817.337953 1001188 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752613817.337968 1001188 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752613817.337969 1001188 computation_placer.cc:177] computation placer alr

In [2]:
# minorized reference
with h5py.File('/global/u2/k/kberard/SCGSR/Research/Diamond/Data/density_tot_ref.h5', 'r') as file:
    #print("Keys: %s" % file.keys())
    ref_d = file['density'][:]
#print(ref_d)
print(ref_d.shape)
minorized_ref_d = (np.sum(ref_d,axis=2))
minorized_ref_d.shape



(64, 64, 64)


(64, 64)

In [3]:
####################################################################################################################################################
def stochastic_density(d,N):
    # poisson model
    #  accurate and fast for all values of N
    # N  = number of MC samples
    assert isinstance(d,np.ndarray)
    assert isinstance(N,(int,float,np.int64,np.float64))
    assert N>0
    ds = np.random.poisson(N*d)/N
    ds*= d.sum()/ds.sum()
    return ds
#end def stochastic_density

####################################################################################################################################################

In [4]:
import tensorflow as tf

def jensen_shannon_divergence_loss(y_true, y_pred):
    y_t = tf.cast(y_true, tf.float32)
    y_p = tf.cast(y_pred, tf.float32)

    y_t = tf.reshape(y_t, [tf.shape(y_t)[0], -1])
    y_p = tf.reshape(y_p, [tf.shape(y_p)[0], -1])

    y_t /= tf.reduce_sum(y_t, axis=1, keepdims=True) + 1e-8
    y_p /= tf.reduce_sum(y_p, axis=1, keepdims=True) + 1e-8

    m = 0.5 * (y_t + y_p)

    kl_true = tf.reduce_sum(y_t * tf.math.log((y_t + 1e-8) / (m + 1e-8)), axis=1)
    kl_pred = tf.reduce_sum(y_p * tf.math.log((y_p + 1e-8) / (m + 1e-8)), axis=1)

    jsd = 0.5 * (kl_true + kl_pred)

    return tf.reduce_mean(jsd)


In [None]:
import os
import re
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import register_keras_serializable

# === Custom Components ===
@register_keras_serializable(package="Custom")
class Renormalize(tf.keras.layers.Layer):
    def __init__(self, target_sum=8.0, **kwargs):
        super().__init__(**kwargs)
        self.target_sum = target_sum

    def call(self, inputs):
        x_sum = tf.reduce_sum(inputs, axis=[1, 2, 3], keepdims=True)
        return inputs / (x_sum + 1e-8) * self.target_sum

    def get_config(self):
        config = super().get_config()
        config.update({'target_sum': self.target_sum})
        return config

@register_keras_serializable(package="Custom")
def jensen_shannon_divergence_loss(y_true, y_pred):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    m = 0.5 * (y_true + y_pred)
    return 0.5 * tf.reduce_mean(
        tf.keras.losses.KLDivergence()(y_true, m) +
        tf.keras.losses.KLDivergence()(y_pred, m)
    )

# === Load U-Net Model ===
unet_model = load_model("UNet_Denoiser.keras", custom_objects={
    "Renormalize": Renormalize,
    "jensen_shannon_divergence_loss": jensen_shannon_divergence_loss
}, compile=False)


# === Find DAE and CAE models ===
model_files = [f for f in os.listdir('.') if re.match(r'\d+_(DAE|CAE)\.keras$', f)]
model_dict = {}
print("here",model_files)
for f in model_files:
    print("here")
    match = re.match(r'(\d+)_(DAE|CAE)\.keras$', f)
    if match:
        n, model_type = match.groups()
        n = int(n)
        if n not in model_dict:
            model_dict[n] = {}
        model_dict[n][model_type] = f

combined_results = {}

for n_samples in sorted(model_dict.keys()):
    models = model_dict[n_samples]
    print(f"\nProcessing models for {n_samples} samples...")

    if 'DAE' not in models or 'CAE' not in models:
        print(f"Missing models for {n_samples}. Skipping.")
        continue

    sample_file = f"{n_samples}_sample_density.npy"
    if not os.path.exists(sample_file):
        print(f"Missing mean sample file: {sample_file}")
        continue

    dae = load_model(models['DAE'], compile=False)
    cae = load_model(models['CAE'], compile=False)
    avg_sample = np.load(sample_file)

    num_steps = 1000000 // 1000 - 1
    Base_JSD = np.zeros((num_steps, 1))
    DAE_JSD = np.zeros((num_steps, 1))
    CAE_JSD = np.zeros((num_steps, 1))
    UNET_JSD = np.zeros((num_steps, 1))

    avg = None
    counter = -1

    for i in range(1000, 1000000, 1000):
        counter += 1
        sample = stochastic_density(ref_d, i)
        test = np.sum(sample, axis=2)
        test_reshaped = test.reshape(-1, 64, 64, 1)

        if counter == 0:
            avg = test
        else:
            avg += test

        Base_JSD[counter] = jensen_shannon_divergence_loss(test[0].copy(), minorized_ref_d.copy()).numpy()
        DAE_JSD[counter] = jensen_shannon_divergence_loss(
            dae.predict(test_reshaped)[0, :, :, 0], minorized_ref_d.copy()).numpy()
        CAE_JSD[counter] = jensen_shannon_divergence_loss(
            cae.predict(test_reshaped)[0, :, :, 0], minorized_ref_d.copy()).numpy()
        UNET_JSD[counter] = jensen_shannon_divergence_loss(
            unet_model.predict(test_reshaped)[0, :, :, 0], minorized_ref_d.copy()).numpy()

    avg = avg / counter
    avg_prediction = dae.predict(avg.reshape(-1, 64, 64, 1))[0, :, :, 0]
    avg_JSD = jensen_shannon_divergence_loss(avg_prediction, minorized_ref_d.copy()).numpy()
    AVG_TOTAL_JSD = jensen_shannon_divergence_loss(avg_sample, minorized_ref_d.copy()).numpy()

    x_vals = np.arange(num_steps) * 1000

    combined_results[n_samples] = {
        "x": x_vals,
        "base": np.log(Base_JSD.flatten()),
        "dae": np.log(DAE_JSD.flatten()),
        "cae": np.log(CAE_JSD.flatten()),
        "unet": np.log(UNET_JSD.flatten()),
        "avg_JSD": np.log(avg_JSD),
        "AVG_TOTAL_JSD": np.log(AVG_TOTAL_JSD)
    }

    # Plot individual
    plt.figure(figsize=(7, 5))
    plt.plot(x_vals, np.log(Base_JSD), label='Base JSD')
    plt.plot(x_vals, np.log(DAE_JSD), label='DAE JSD')
    plt.plot(x_vals, np.log(CAE_JSD), label='CAE JSD')
    plt.plot(x_vals, np.log(UNET_JSD), label='UNet JSD')

    plt.axhline(np.log(avg_JSD), color='black', linestyle='--', label='avg_JSD')
    plt.axhline(np.log(AVG_TOTAL_JSD), color='purple', linestyle='-.', label='AVG_TOTAL_JSD')

    plt.text(x_vals[-1], np.log(avg_JSD), f'avg_JSD={avg_JSD:.4e}', color='black', ha='right')
    plt.text(x_vals[-1], np.log(AVG_TOTAL_JSD), f'AVG_TOTAL_JSD={AVG_TOTAL_JSD:.4e}', color='purple', ha='right')

    plt.xlabel("Sample Index (×1000)")
    plt.ylabel("log(JSD)")
    plt.title(f"JSD vs Samples - {n_samples} Samples")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(f"JSD_plot_{n_samples}.png")
    plt.close()

# Combined plot
plt.figure(figsize=(10, 6))
for n_samples, result in sorted(combined_results.items()):
    plt.plot(result["x"], result["dae"], label=f"DAE {n_samples}")
    plt.plot(result["x"], result["cae"], linestyle='dashed', label=f"CAE {n_samples}")
    plt.plot(result["x"], result["unet"], linestyle='dotted', label=f"UNet {n_samples}")

plt.xlabel("Sample Index (×1000)")
plt.ylabel("log(JSD)")
plt.title("Comparison of JSD Scores Across Sample Sizes")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig("Combined_JSD_plot.png")
plt.show()


I0000 00:00:1752613846.400686 1001188 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 839 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:03:00.0, compute capability: 8.0
I0000 00:00:1752613846.405640 1001188 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 37946 MB memory:  -> device: 1, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:41:00.0, compute capability: 8.0
I0000 00:00:1752613846.409041 1001188 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 37946 MB memory:  -> device: 2, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:82:00.0, compute capability: 8.0
I0000 00:00:1752613846.410483 1001188 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 37946 MB memory:  -> device: 3, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:c1:00.0, compute capability: 8.0


here ['22684000_CAE.keras', '22947000000_DAE.keras', '114815000_CAE.keras', '114815000_DAE.keras', '229090000_DAE.keras', '22922000_CAE.keras', '22922000_DAE.keras', '11511500000_CAE.keras', '231450000000_CAE.keras', '114435000000_DAE.keras', '115565000_CAE.keras', '22947000000_CAE.keras', '11511500000_DAE.keras', '22684000_DAE.keras', '229090000_CAE.keras', '115565000_DAE.keras', '2309400000_CAE.keras', '1148900000_CAE.keras', '2309400000_DAE.keras', '231450000000_DAE.keras', '114435000000_CAE.keras', '1148900000_DAE.keras']
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here

Processing models for 22684000 samples...


I0000 00:00:1752613848.438081 1001685 service.cc:152] XLA service 0x7f2c4c007330 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1752613848.438106 1001685 service.cc:160]   StreamExecutor device (0): NVIDIA A100-SXM4-40GB, Compute Capability 8.0
I0000 00:00:1752613848.438109 1001685 service.cc:160]   StreamExecutor device (1): NVIDIA A100-SXM4-40GB, Compute Capability 8.0
I0000 00:00:1752613848.438111 1001685 service.cc:160]   StreamExecutor device (2): NVIDIA A100-SXM4-40GB, Compute Capability 8.0
I0000 00:00:1752613848.438113 1001685 service.cc:160]   StreamExecutor device (3): NVIDIA A100-SXM4-40GB, Compute Capability 8.0
2025-07-15 14:10:48.445679: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1752613848.501811 1001685 cuda_dnn.cc:529] Loaded cuDNN version 90300


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 519ms/step


I0000 00:00:1752613848.801512 1001685 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step   
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 88ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 91ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 86ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 82ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 93ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 90ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 85ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 95ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 88ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 85ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 98ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 91m