# ## Autoencodeur Convolutif Avancé pour Débruitage d'Images
# Intégration des dernières techniques d'optimisation et d'architecture

### 1. Configuration Avancée des Hyperparamètres

In [37]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model, mixed_precision
from datetime import datetime
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# Activation de la précision mixte pour accélérer l'entraînement sur GPU
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

# Hyperparamètres globaux
IMG_SIZE = (256, 256)
BATCH_SIZE = 16
EPOCHS = 100
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

### 2. Pipeline de Traitement des Données Améliorée

In [38]:
class AdvancedDataGenerator:
    def __init__(self, img_size=(256, 256)):
        self.img_size = img_size
        self.augmenter = tf.keras.Sequential([
            layers.RandomFlip("horizontal"),
            layers.RandomRotation(0.1),
            layers.RandomZoom(0.1),
            layers.RandomContrast(0.1)
        ])

    def _safe_read(self, path):
        img = tf.io.read_file(path)
        img = tf.image.decode_jpeg(img, channels=3)
        img = tf.image.resize(img, self.img_size)
        img = tf.image.convert_image_dtype(img, tf.float32)
        return img

    def _add_noise(self, image):
        noise_type = tf.random.uniform([], 0, 4, dtype=tf.int32)
        def gaussian():
            stddev = tf.random.uniform([], 0.01, 0.2)
            noise = tf.random.normal(tf.shape(image), 0.0, stddev)
            return tf.clip_by_value(image + noise, 0.0, 1.0)
        def salt_pepper():
            amount = tf.random.uniform([], 0.01, 0.08)
            salt = tf.cast(tf.random.uniform(tf.shape(image)) < amount, tf.float32)
            pepper = tf.cast(tf.random.uniform(tf.shape(image)) > 1 - amount, tf.float32)
            return tf.clip_by_value(image + salt - pepper, 0.0, 1.0)
        def speckle():
            stddev = tf.random.uniform([], 0.01, 0.2)
            noise = tf.random.normal(tf.shape(image), 0.0, stddev)
            return tf.clip_by_value(image + image * noise, 0.0, 1.0)
        def poisson():
            vals = 2 ** tf.random.uniform([], 2, 8, dtype=tf.int32)
            noisy = tf.random.poisson([], image * vals) / tf.cast(vals, tf.float32)
            return tf.clip_by_value(noisy, 0.0, 1.0)
        return tf.switch_case(noise_type, [gaussian, salt_pepper, speckle, poisson])

    def generate_pair(self, path):
        clean = self._safe_read(path)
        aug = self.augmenter(tf.expand_dims(clean, 0))[0]
        noisy = self._add_noise(aug)
        return noisy, aug
    
def build_datasets(clean_dir, batch_size=BATCH_SIZE, img_size=IMG_SIZE):
    paths = [os.path.join(clean_dir, f) for f in os.listdir(clean_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    train_paths, val_paths = train_test_split(paths, test_size=0.15, random_state=SEED)
    generator = AdvancedDataGenerator(img_size=img_size)

    def _tf_pair(path):
        noisy, clean = tf.py_function(generator.generate_pair, [path], [tf.float32, tf.float32])
        noisy.set_shape((*img_size, 3))
        clean.set_shape((*img_size, 3))
        return noisy, clean

    train_ds = tf.data.Dataset.from_tensor_slices(train_paths).shuffle(1000, seed=SEED)
    train_ds = train_ds.map(_tf_pair, num_parallel_calls=tf.data.AUTOTUNE).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    val_ds = tf.data.Dataset.from_tensor_slices(val_paths)
    val_ds = val_ds.map(_tf_pair, num_parallel_calls=tf.data.AUTOTUNE).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return train_ds, val_ds

### 3. Architecture du Modèle State-of-the-Art

In [39]:
def conv_block(x, filters, kernel_size=3, activation='swish'):
    x = layers.Conv2D(filters, kernel_size, padding='same', kernel_initializer='he_normal')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(activation)(x)
    return x

def encoder_block(x, filters):
    c = conv_block(x, filters)
    c = conv_block(c, filters)
    p = layers.MaxPooling2D((2, 2))(c)
    return c, p

def decoder_block(x, skip, filters):
    x = layers.UpSampling2D((2, 2))(x)
    x = layers.Concatenate()([x, skip])
    x = conv_block(x, filters)
    x = conv_block(x, filters)
    return x

def build_unet(input_shape=(256, 256, 3)):
    inputs = layers.Input(shape=input_shape)
    # Encoder
    c1, p1 = encoder_block(inputs, 32)
    c2, p2 = encoder_block(p1, 64)
    c3, p3 = encoder_block(p2, 128)
    c4, p4 = encoder_block(p3, 256)
    # Bottleneck
    bn = conv_block(p4, 512)
    bn = conv_block(bn, 512)
    # Decoder
    d4 = decoder_block(bn, c4, 256)
    d3 = decoder_block(d4, c3, 128)
    d2 = decoder_block(d3, c2, 64)
    d1 = decoder_block(d2, c1, 32)
    outputs = layers.Conv2D(3, 1, activation='sigmoid', dtype='float32')(d1)
    return Model(inputs, outputs, name='UNetDenoiser')

### 4. Fonction de Perte Hybride Optimisée

In [40]:
class HybridLoss(tf.keras.losses.Loss):
    def __init__(self, alpha=0.8, beta=0.2, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha
        self.beta = beta
        self.epsilon = 1e-6
        self.feature_extractor = tf.keras.applications.VGG16(
            include_top=False, weights='imagenet', input_shape=(128, 128, 3), pooling='avg'
        )
        self.feature_extractor.trainable = False

    def call(self, y_true, y_pred):
        y_true = tf.clip_by_value(y_true, self.epsilon, 1.0)
        y_pred = tf.clip_by_value(y_pred, self.epsilon, 1.0)
        # SSIM Loss
        ssim_loss = 1 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=1.0))
        # Perceptual Loss
        y_true_resized = tf.image.resize(y_true, (128, 128))
        y_pred_resized = tf.image.resize(y_pred, (128, 128))
        f_true = self.feature_extractor(y_true_resized)
        f_pred = self.feature_extractor(y_pred_resized)
        perceptual_loss = tf.reduce_mean(tf.square(f_true - f_pred))
        # Correction : cast en float32 avant addition
        ssim_loss = tf.cast(ssim_loss, tf.float32)
        perceptual_loss = tf.cast(perceptual_loss, tf.float32)
        alpha = tf.cast(self.alpha, tf.float32)
        beta = tf.cast(self.beta, tf.float32)
        return alpha * ssim_loss + beta * perceptual_loss

### 5. Stratégie d'Entraînement Avancée

In [41]:
def compile_and_callbacks(model, lr=2e-4):
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    model.compile(
        optimizer=optimizer,
        loss=HybridLoss(),
        metrics=[tf.keras.metrics.MeanAbsoluteError(), tf.keras.metrics.RootMeanSquaredError()]
    )
    callbacks = [
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
        tf.keras.callbacks.ModelCheckpoint('best_unet_denoiser.keras', save_best_only=True, monitor='val_loss'),
        tf.keras.callbacks.TensorBoard(log_dir='logs/' + datetime.now().strftime("%Y%m%d-%H%M%S"))
    ]
    return callbacks

### 6. Évaluation et Inférence

In [42]:
# %% [markdown]
### 5. Pipeline Complète

def build_pipeline(clean_dir, batch_size=16):
    paths = [os.path.join(clean_dir, f) for f in os.listdir(clean_dir)]
    train_paths, val_paths = train_test_split(paths, test_size=0.2)
    
    generator = AdvancedDataGenerator(clean_dir)
    
    def _dataset(paths):
        ds = tf.data.Dataset.from_tensor_slices(paths)
        ds = ds.map(lambda x: tf.py_function(
            generator.generate_pair, [x], [tf.float32, tf.float32]),
            num_parallel_calls=tf.data.AUTOTUNE
        )
        ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
        return ds
    
    return _dataset(train_paths), _dataset(val_paths)

### 7. Pipeline Complète d'Entraînement

In [43]:

CLEAN_DIR = "/home/kevin/datasets/livrable2/raw"

train_ds, val_ds = build_datasets(CLEAN_DIR)
model = build_unet(input_shape=(*IMG_SIZE, 3))
callbacks = compile_and_callbacks(model)

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=2
)

Epoch 1/100


2025-04-17 12:08:46.115932: W tensorflow/core/framework/op_kernel.cc:1844] UNKNOWN: InvalidArgumentError: cannot compute Mul as input #1(zero-based) was expected to be a half tensor but is a int32 tensor [Op:Mul] name: 
Traceback (most recent call last):

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 267, in __call__
    return func(device, token, args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 145, in __call__
    outputs = self._call(device, args)
              ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 152, in _call
    ret = self._func(*args)
          ^^^^^^^^^^^^^^^^^

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper
    return func(*args, **kwargs)


UnknownError: Graph execution error:

Detected at node EagerPyFunc defined at (most recent call last):
<stack traces unavailable>
Detected at node EagerPyFunc defined at (most recent call last):
<stack traces unavailable>
2 root error(s) found.
  (0) UNKNOWN:  InvalidArgumentError: cannot compute Mul as input #1(zero-based) was expected to be a half tensor but is a float tensor [Op:Mul] name: 
Traceback (most recent call last):

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 267, in __call__
    return func(device, token, args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 145, in __call__
    outputs = self._call(device, args)
              ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 152, in _call
    ret = self._func(*args)
          ^^^^^^^^^^^^^^^^^

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^

  File "/tmp/ipykernel_3461/1764851286.py", line 42, in generate_pair
    noisy = self._add_noise(aug)
            ^^^^^^^^^^^^^^^^^^^^

  File "/tmp/ipykernel_3461/1764851286.py", line 37, in _add_noise
    return tf.switch_case(noise_type, [gaussian, salt_pepper, speckle, poisson])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/ops/control_flow_switch_case.py", line 253, in switch_case
    return _indexed_case_helper(branch_fns, default, branch_index, name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/ops/control_flow_switch_case.py", line 125, in _indexed_case_helper
    return branch_fns[int(branch_index)]()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/tmp/ipykernel_3461/1764851286.py", line 32, in speckle
    return tf.clip_by_value(image + image * noise, 0.0, 1.0)
                                    ~~~~~~^~~~~~~

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/framework/ops.py", line 6006, in raise_from_not_ok_status
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute Mul as input #1(zero-based) was expected to be a half tensor but is a float tensor [Op:Mul] name: 


	 [[{{node EagerPyFunc}}]]
	 [[IteratorGetNext]]
	 [[IteratorGetNext/_4]]
  (1) UNKNOWN:  InvalidArgumentError: cannot compute Mul as input #1(zero-based) was expected to be a half tensor but is a float tensor [Op:Mul] name: 
Traceback (most recent call last):

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 267, in __call__
    return func(device, token, args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 145, in __call__
    outputs = self._call(device, args)
              ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 152, in _call
    ret = self._func(*args)
          ^^^^^^^^^^^^^^^^^

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^

  File "/tmp/ipykernel_3461/1764851286.py", line 42, in generate_pair
    noisy = self._add_noise(aug)
            ^^^^^^^^^^^^^^^^^^^^

  File "/tmp/ipykernel_3461/1764851286.py", line 37, in _add_noise
    return tf.switch_case(noise_type, [gaussian, salt_pepper, speckle, poisson])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/ops/control_flow_switch_case.py", line 253, in switch_case
    return _indexed_case_helper(branch_fns, default, branch_index, name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/ops/control_flow_switch_case.py", line 125, in _indexed_case_helper
    return branch_fns[int(branch_index)]()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/tmp/ipykernel_3461/1764851286.py", line 32, in speckle
    return tf.clip_by_value(image + image * noise, 0.0, 1.0)
                                    ~~~~~~^~~~~~~

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/framework/ops.py", line 6006, in raise_from_not_ok_status
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute Mul as input #1(zero-based) was expected to be a half tensor but is a float tensor [Op:Mul] name: 


	 [[{{node EagerPyFunc}}]]
	 [[IteratorGetNext]]
0 successful operations.
0 derived errors ignored. [Op:__inference_multi_step_on_iterator_65412]

2025-04-17 12:08:46.520657: W tensorflow/core/framework/op_kernel.cc:1844] UNKNOWN: InvalidArgumentError: cannot compute Mul as input #1(zero-based) was expected to be a half tensor but is a int32 tensor [Op:Mul] name: 
Traceback (most recent call last):

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 267, in __call__
    return func(device, token, args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 145, in __call__
    outputs = self._call(device, args)
              ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 152, in _call
    ret = self._func(*args)
          ^^^^^^^^^^^^^^^^^

  File "/home/kevin/dev/tf217/tf217/lib/python3.12/site-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper
    return func(*args, **kwargs)


In [None]:
def plot_denoising(model, dataset, n=3):
    for noisy, clean in dataset.take(1):
        preds = model.predict(noisy)
        for i in range(n):
            plt.figure(figsize=(12, 4))
            plt.subplot(1, 3, 1)
            plt.title("Noisy")
            plt.imshow(noisy[i].numpy())
            plt.axis('off')
            plt.subplot(1, 3, 2)
            plt.title("Denoised")
            plt.imshow(preds[i])
            plt.axis('off')
            plt.subplot(1, 3, 3)
            plt.title("Clean")
            plt.imshow(clean[i].numpy())
            plt.axis('off')
            plt.show()

# Exemple d'utilisation :
plot_denoising(model, val_ds)