This notebook is our recreation of Lehtinen et al.'s paper "Noise2Noise: Learning Image Restoration without Clean Data" (ICML 2018).

## Google Colab setup

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

## Imports

In [None]:
import tensorflow as tf
import keras
from importlib import reload

In [None]:
from helpers import generators, utils, noise_models
from models import unet_n2n, irfan_2020

In [None]:
# Re-import the `helpers.utils` and `helpers.noise_models` modules so edits made
# to them are picked up without restarting the kernel.
# NOTE(review): `generators` and the model modules are not reloaded here.
reload(utils)
reload(noise_models)

## Initialise data

In [None]:
# Dataset locations. The commented-out "drive/MyDrive/..." variants are the
# equivalent paths when running on Google Colab with Drive mounted.
_LOCAL_DATA_ROOT = "../../raw_datasets"

# TRAIN_IMAGE_DIR = "drive/MyDrive/imagenet-10k"
TRAIN_IMAGE_DIR = f"{_LOCAL_DATA_ROOT}/imagenet-10k"
# VAL_IMAGE_DIR = "drive/MyDrive/BSDS300/images/train"
VAL_IMAGE_DIR = f"{_LOCAL_DATA_ROOT}/BSDS300/images/train"
# TEST_IMAGE_DIR = "drive/MyDrive/BSDS300/images/test"
TEST_IMAGE_DIR = f"{_LOCAL_DATA_ROOT}/BSDS300/images/test"

# Training hyper-parameters.
NUM_EPOCHS = 10
# NOTE(review): not consumed by the data generators. Keras ignores a
# `batch_size=` argument when the input is a generator, so this value does not
# set the effective batch size; DATA_BATCH_SIZE does.
GPU_BATCH_SIZE = 4
# Batch size handed to each data generator.
DATA_BATCH_SIZE = 16

In [None]:
# Noise model handed to every data generator in this notebook (training inputs
# and targets, validation and test). Change it here to experiment with others.
noise_model = noise_models.gaussian_noise

In [None]:
# Settings common to all three generators: batch size, patch edge length
# (presumably pixels; matches the model's 192x192 input), [0, 1] scaling and
# greyscale images.
_shared_generator_kwargs = dict(
    batch_size=DATA_BATCH_SIZE,
    patch_edge_size=192,
    zero_one_normalisation=True,
    greyscale=True,
)

# Training pairs: separate input and target noise models are accepted
# (presumably both sides are noisy, as in the Noise2Noise setup); both use
# `noise_model` here.
training_gen = generators.NoisyImageTrainGenerator(
    image_dir=TRAIN_IMAGE_DIR,
    input_noise_model=noise_model,
    target_noise_model=noise_model,
    **_shared_generator_kwargs,
)

# Validation and test sets share the single-noise-model generator class.
validation_gen = generators.NoisyImageValGenerator(
    image_dir=VAL_IMAGE_DIR,
    noise_model=noise_model,
    **_shared_generator_kwargs,
)
test_gen = generators.NoisyImageValGenerator(
    image_dir=TEST_IMAGE_DIR,
    noise_model=noise_model,
    **_shared_generator_kwargs,
)

## Initialise model

In [None]:
def psnr(y_true, y_pred):
    """Keras metric: peak signal-to-noise ratio between target and prediction.

    Pixel values are expected in [0, 1] (the generators are built with
    ``zero_one_normalisation=True``), so the peak value is 1.0. Delegates to
    ``tf.image.psnr``, which yields one PSNR value per image; Keras averages
    these when reporting the metric. The function's name fixes the logged
    metric keys (``psnr`` / ``val_psnr``).
    """
    peak_signal = 1.0
    per_image_psnr = tf.image.psnr(y_true, y_pred, max_val=peak_signal)
    return per_image_psnr

In [None]:
# Alternative architecture; swap the two lines below to train it instead.
# model = irfan_2020.get_irfan_model(input_shape=(192, 192, 1))
# 192x192 single-channel input, matching the generators' patch size and
# greyscale setting.
model = unet_n2n.get_unet_model(input_shape=(192, 192, 1))

# Adam with a small learning rate and beta_2 = 0.99 (Keras's default is 0.999).
_optimizer = keras.optimizers.Adam(
    learning_rate=1e-5,
    beta_1=0.9,
    beta_2=0.99,
    epsilon=1e-8,
)
_loss = keras.losses.MeanSquaredError()
model.compile(optimizer=_optimizer, loss=_loss, metrics=[psnr])

model.summary()

## Train model

In [None]:
import os

# CSVLogger opens its log file as soon as training starts and does not create
# missing parent directories. Create the run directory up front so a fresh
# checkout does not fail with FileNotFoundError inside `model.fit`.
os.makedirs("models/saved/n2n_imagenet10k_05122024/unet", exist_ok=True)

# Keep only checkpoints that improve validation PSNR; the file name is the
# (zero-padded) epoch number.
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    # filepath="drive/MyDrive/saved_models/n2n_imagenet_unet/{epoch:02d}.keras",
    filepath="models/saved/n2n_imagenet10k_05122024/unet/{epoch:02d}.keras",
    monitor='val_psnr',
    mode='max',
    save_best_only=True
)

# Per-epoch metrics as CSV; overwritten (append=False) on each new run.
csv_logger = keras.callbacks.CSVLogger(
    # filename='drive/MyDrive/saved_models/n2n_imagenet_unet/training.log',
    filename="models/saved/n2n_imagenet10k_05122024/unet/training.log",
    separator=',',
    append=False
)

# Lets an interrupted `fit` resume from the last completed epoch.
backup_callback = keras.callbacks.BackupAndRestore(
    # backup_dir="drive/MyDrive/saved_models/n2n_imagenet_unet"
    backup_dir="models/saved/n2n_imagenet10k_05122024/unet/tmp"
)

In [None]:
# Train on the noisy training pairs, validating on BSDS300 after every epoch.
# `batch_size` is not passed: the generators already yield batches
# (DATA_BATCH_SIZE), and Keras documents that `batch_size` must not be given
# for generator / PyDataset inputs, where it is ignored.
# The checkpoint callback is registered so the best-`val_psnr` model is saved.
history = model.fit(
    training_gen,
    validation_data=validation_gen,
    epochs=NUM_EPOCHS,
    verbose=2,
    callbacks=[csv_logger, model_checkpoint_callback, backup_callback],
)

In [None]:
model.save("models/saved/n2n_imagenet10k_05122024/unet/unet_10epochs.keras")

In [None]:
import os

reload(utils)

# Alternative: separate PSNR and loss figures instead of one combined figure.
# fig_psnr, fig_loss = utils.get_psnr_and_loss_curves(history, together=False)
# fig_psnr.savefig('models/saved/n2n_imagenet10k_05122024/unet/img/psnr.pdf')
# fig_loss.savefig('models/saved/n2n_imagenet10k_05122024/unet/img/loss.pdf')
# fig_psnr.show()
# fig_loss.show()

# `savefig` does not create directories, and nothing else in this notebook
# creates `img/`, so make sure it exists before saving.
os.makedirs('models/saved/n2n_imagenet10k_05122024/unet/img', exist_ok=True)

# Presumably returns a matplotlib Figure with PSNR and loss curves from `history`.
together = utils.get_psnr_and_loss_curves(history, together=True)
together.savefig('models/saved/n2n_imagenet10k_05122024/unet/img/psnr_loss.pdf')
together.show()


## Evaluate model

In [None]:
# Evaluate on the BSDS300 test split. With one compiled metric this returns
# [loss, psnr]. `batch_size` is not passed: `test_gen` yields its own batches,
# and Keras ignores `batch_size` for generator inputs.
evals = model.evaluate(
    test_gen,
    verbose=2,
)

In [None]:
# Display the test-set results returned by `model.evaluate` ([loss, psnr]).
evals

In [None]:
# reload(utils)

# ground_truth_fig, patches_fig = utils.test_model_on_image(
#     image_path=TEST_IMAGE_DIR + "/285079.jpg", 
#     model=model,
#     patch_size=192, 
#     zero_one_normalisation=True, 
#     greyscale=True,
#     stddev=30
# )

# ground_truth_fig.savefig('models/saved/n2n_imagenet10k_05122024/irfan/img/ground_truth_fig_1.pdf')
# patches_fig.savefig('models/saved/n2n_imagenet10k_05122024/irfan/img/patches_fig_1.pdf')

In [None]:
reload(utils)

# Run the trained model on BSDS300 test image 108005 and save the resulting
# figures as set 2. `stddev=30` is forwarded to utils.test_model_on_image
# (presumably the synthetic noise standard deviation; confirm the scale).
_sample_kwargs = dict(
    model=model,
    patch_size=192,
    zero_one_normalisation=True,
    greyscale=True,
    stddev=30,
)
ground_truth_fig, patches_fig = utils.test_model_on_image(
    image_path=f"{TEST_IMAGE_DIR}/108005.jpg",
    **_sample_kwargs,
)

_figure_dir = "models/saved/n2n_imagenet10k_05122024/unet/img"
for _fig, _stem in ((ground_truth_fig, "ground_truth_fig_2"), (patches_fig, "patches_fig_2")):
    _fig.savefig(f"{_figure_dir}/{_stem}.pdf")

In [None]:
reload(utils)

# Run the trained model on BSDS300 test image 148089 and save the resulting
# figures as set 3. `stddev=30` is forwarded to utils.test_model_on_image
# (presumably the synthetic noise standard deviation; confirm the scale).
_sample_kwargs = dict(
    model=model,
    patch_size=192,
    zero_one_normalisation=True,
    greyscale=True,
    stddev=30,
)
ground_truth_fig, patches_fig = utils.test_model_on_image(
    image_path=f"{TEST_IMAGE_DIR}/148089.jpg",
    **_sample_kwargs,
)

_figure_dir = "models/saved/n2n_imagenet10k_05122024/unet/img"
for _fig, _stem in ((ground_truth_fig, "ground_truth_fig_3"), (patches_fig, "patches_fig_3")):
    _fig.savefig(f"{_figure_dir}/{_stem}.pdf")