# In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

def create_augmenter(patch_size=256, to_tensor=False):
    """Build an Albumentations pipeline for paired (source, target) image patches.

    Call the returned pipeline as ``augmenter(image=x, mask=y)``; it returns a
    dict with ``"image"`` and ``"mask"`` entries. In Albumentations, spatial
    transforms (rotation, flips, padding, cropping) are applied identically to
    ``image`` and ``mask``. Pixel-level transforms (brightness/contrast, noise)
    are applied to ``image`` only, so passing the clean target as ``mask``
    keeps it aligned with, but not corrupted like, the noisy input.

    Parameters
    ----------
    patch_size : int
        Side length of the square output patch. Inputs smaller than this are
        padded first, then a random ``patch_size`` x ``patch_size`` crop is taken.
    to_tensor : bool
        If True, append ``ToTensorV2`` so outputs become torch tensors in
        (C, H, W) layout (for PyTorch pipelines). Leave False for Keras/CARE
        models, which consume NumPy arrays in (H, W, C) layout.

    Returns
    -------
    ``albumentations.Compose``
        The composed augmentation pipeline.
    """
    transforms = [
        A.RandomRotate90(p=0.5),            # random multiple of 90 degrees
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),  # applied to ``image`` only
        # Albumentations has no ``GaussianNoise``; the transform is ``GaussNoise``.
        # NOTE(review): its default noise strength is version-dependent (1.x uses a
        # pixel-unit ``var_limit``); confirm it suits the value range of the inputs.
        A.GaussNoise(p=0.2),
        A.PadIfNeeded(min_height=patch_size, min_width=patch_size, p=1),
        A.RandomCrop(width=patch_size, height=patch_size, p=1),
    ]
    if to_tensor:
        transforms.append(ToTensorV2())
    return A.Compose(transforms)
    
    
def _augmented_batches(X, Y, batch_size, n_batches, augmenter=None):
    """Yield ``n_batches`` pairs ``(X_batch, Y_batch)`` of ``batch_size`` samples each.

    Sample indices come from a random permutation of ``range(len(X))``. A fresh
    permutation is drawn whenever the current one is used up, so every sample
    is visited once before any sample repeats. If ``augmenter`` is given, it is
    called once per sample as ``augmenter(image=x, mask=y)`` (the Albumentations
    calling convention). The ``"image"`` and ``"mask"`` entries of the returned
    dict replace ``x`` and ``y``, and are stacked back into batch arrays.
    """
    n = len(X)
    perm, pos = np.random.permutation(n), 0
    for _ in range(n_batches):
        idx = []
        while len(idx) < batch_size:
            if pos == n:
                # Current permutation exhausted: reshuffle and start over.
                perm, pos = np.random.permutation(n), 0
            take = min(batch_size - len(idx), n - pos)
            idx.extend(perm[pos:pos + take])
            pos += take
        idx = np.asarray(idx)
        X_batch, Y_batch = X[idx], Y[idx]
        if augmenter is not None:
            pairs = [augmenter(image=x, mask=y) for x, y in zip(X_batch, Y_batch)]
            X_batch = np.stack([p["image"] for p in pairs])
            Y_batch = np.stack([p["mask"] for p in pairs])
        yield X_batch, Y_batch


def train(self, X, Y, validation_data, epochs=None, steps_per_epoch=None, augmenter=None):
    """Train the neural network with the given data.

    This is a plain function taking the model as ``self``; call it as
    ``train(model, ...)`` (or bind it to the model) for it to take effect.

    Parameters
    ----------
    X : :class:`numpy.ndarray`
        Array of source images (noisy images), one sample per entry of axis 0.
    Y : :class:`numpy.ndarray`
        Array of target images (clean images), aligned sample-by-sample with ``X``.
    validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`)
        Tuple of arrays for source and target validation images (not augmented).
    epochs : int
        Optional argument to use instead of the value from ``config``.
    steps_per_epoch : int
        Optional argument to use instead of the value from ``config``.
    augmenter : callable
        Optional per-sample augmenter called as ``augmenter(image=x, mask=y)``
        and returning a dict with ``"image"`` and ``"mask"`` entries (e.g. an
        ``albumentations.Compose``). Outputs must be NumPy arrays with a
        consistent shape across samples so they can be stacked into a batch.
        ``None`` disables augmentation.

    Returns
    -------
    ``History`` object
        See `Keras training history <https://keras.io/models/model/#fit>`_.

    Raises
    ------
    ValueError
        If ``validation_data`` is not a pair, or ``X``/``Y`` are empty or differ
        in length.
    """

    ((isinstance(validation_data, (list, tuple)) and len(validation_data) == 2)
     or _raise(ValueError('validation_data must be a pair of numpy arrays')))
    if len(X) == 0 or len(X) != len(Y):
        # An empty training set would otherwise hang the batch generator.
        raise ValueError('X and Y must be non-empty and contain the same number of images')

    n_train, n_val = len(X), len(validation_data[0])
    frac_val = (1.0 * n_val) / (n_train + n_val)
    frac_warn = 0.05
    if frac_val < frac_warn:
        warnings.warn("small number of validation images (only %.1f%% of all images)" % (100 * frac_val))

    if epochs is None:
        epochs = self.config.train_epochs
    if steps_per_epoch is None:
        steps_per_epoch = self.config.train_steps_per_epoch

    if not self._model_prepared:
        self.prepare_for_training()

    if (self.config.train_tensorboard and self.basedir is not None and
        not IS_TF_1 and not any(isinstance(cb, CARETensorBoardImage) for cb in self.callbacks)):
        self.callbacks.append(CARETensorBoardImage(model=self.keras_model, data=validation_data,
                                                   log_dir=str(self.logdir / 'logs' / 'images'),
                                                   n_images=3, prob_out=self.config.probabilistic))

    # Self-contained batch generator. Subclassing ``train.DataWrapper`` here
    # cannot work: the global name ``train`` refers to this function, not the
    # CSBDeep module, once this ``def`` runs.
    training_data = _augmented_batches(X, Y, self.config.train_batch_size,
                                       n_batches=epochs * steps_per_epoch,
                                       augmenter=augmenter)

    fit = self.keras_model.fit_generator if IS_TF_1 else self.keras_model.fit
    history = fit(training_data, validation_data=validation_data,
                  epochs=epochs, steps_per_epoch=steps_per_epoch,
                  callbacks=self.callbacks, verbose=1)
    self._training_finished()

    return history

augmenter = create_augmenter()

# Train the model with augmentation. ``train`` above is a plain module-level
# function, not a method of ``model``'s class, so ``model.train(...)`` would
# dispatch to the model class's own ``train`` and bypass it. Call it explicitly,
# passing ``model`` as ``self``.
history = train(model, X_train, Y_train, validation_data=(X_val, Y_val),
                epochs=10, steps_per_epoch=100, augmenter=augmenter)