In [2]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Conv2DTranspose, concatenate, Input
from tensorflow.keras.models import Model

def unet_model(input_size=(256, 256, 3)):
    inputs = Input(input_size)
    
    # Encoder
    conv1 = Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
    conv1 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    
    conv2 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    
    conv3 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    
    conv4 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
    
    conv5 = Conv2D(1024, (3, 3), activation='relu', padding='same')(pool4)
    conv5 = Conv2D(1024, (3, 3), activation='relu', padding='same')(conv5)
    
    # Decoder
    up6 = Conv2DTranspose(512, (2, 2), strides=(2, 2), padding='same')(conv5)
    merge6 = concatenate([conv4, up6], axis=3)
    conv6 = Conv2D(512, (3, 3), activation='relu', padding='same')(merge6)
    conv6 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv6)
    
    up7 = Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv6)
    merge7 = concatenate([conv3, up7], axis=3)
    conv7 = Conv2D(256, (3, 3), activation='relu', padding='same')(merge7)
    conv7 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv7)
    
    up8 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv7)
    merge8 = concatenate([conv2, up8], axis=3)
    conv8 = Conv2D(128, (3, 3), activation='relu', padding='same')(merge8)
    conv8 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv8)
    
    up9 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv8)
    merge9 = concatenate([conv1, up9], axis=3)
    conv9 = Conv2D(64, (3, 3), activation='relu', padding='same')(merge9)
    conv9 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv9)
    
    conv10 = Conv2D(3, (1, 1), activation='sigmoid')(conv9)
    
    model = Model(inputs, conv10)
    
    model.compile(optimizer='adam', loss='mean_squared_error', metrics=['accuracy'])
    
    return model

model = unet_model()
model.summary()

In [5]:
import numpy as np
import cv2
import os
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt

def load_data(folder_path, image_size=(256, 256)):
    clean_images = []
    noisy_images = []
    
    for root, _, files in os.walk(folder_path):
        for file in files:
            if file.startswith('GT_SRGB_010') and file.endswith('.PNG'):
                clean_image_path = os.path.join(root, file)
                noisy_image_path = os.path.join(root, file.replace('GT_SRGB', 'NOISY_SRGB'))
                
                clean_image = cv2.imread(clean_image_path)
                noisy_image = cv2.imread(noisy_image_path)
                
                clean_image = cv2.resize(clean_image, image_size)
                noisy_image = cv2.resize(noisy_image, image_size)
                
                clean_images.append(clean_image)
                noisy_images.append(noisy_image)
    
    clean_images = np.array(clean_images, dtype=np.float32) / 255.0
    noisy_images = np.array(noisy_images, dtype=np.float32) / 255.0
    
    return clean_images, noisy_images

folder_path = 'E:\\archive (1)\\SIDD_Small_sRGB_Only\\Data'
clean_images, noisy_images = load_data(folder_path)

# Split data into training and validation sets
split_ratio = 0.8
split_index = int(len(clean_images) * split_ratio)
x_train, x_val = noisy_images[:split_index], noisy_images[split_index:]
y_train, y_val = clean_images[:split_index], clean_images[split_index:]

# Data augmentation
datagen = ImageDataGenerator(horizontal_flip=True, vertical_flip=True, rotation_range=90)

# Train the model
batch_size = 8
epochs = 50

model.fit(datagen.flow(x_train, y_train, batch_size=batch_size),
          validation_data=(x_val, y_val),
          steps_per_epoch=len(x_train) // batch_size,
          epochs=epochs)

# Make predictions on the validation set
predicted_images = model.predict(x_val)

# Function to visualize original, noisy, and denoised images
def visualize_denoising(noisy_images, denoised_images, clean_images, num_images=10):
    plt.figure(figsize=(15, 15))
    for i in range(num_images):
        # Original noisy image
        plt.subplot(num_images, 3, 3 * i + 1)
        plt.title("Noisy Image")
        plt.imshow(noisy_images[i])
        plt.axis('off')
        
        # Denoised image (predicted)
        plt.subplot(num_images, 3, 3 * i + 2)
        plt.title("Denoised Image")
        plt.imshow(denoised_images[i])
        plt.axis('off')
        
        # Ground truth clean image
        plt.subplot(num_images, 3, 3 * i + 3)
        plt.title("Clean Image")
        plt.imshow(clean_images[i])
        plt.axis('off')
    
    plt.tight_layout()
    plt.show()

# Visualize some examples from the validation set
visualize_denoising(x_val, predicted_images, y_val, num_images=10)

Epoch 1/50


ResourceExhaustedError: Graph execution error:

Detected at node functional_1_1/concatenate_3_1/concat defined at (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main

  File "<frozen runpy>", line 88, in _run_code

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\ipykernel_launcher.py", line 18, in <module>

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\traitlets\config\application.py", line 1075, in launch_instance

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\ipykernel\kernelapp.py", line 739, in start

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\tornado\platform\asyncio.py", line 205, in start

  File "c:\Python312\Lib\asyncio\base_events.py", line 638, in run_forever

  File "c:\Python312\Lib\asyncio\base_events.py", line 1971, in _run_once

  File "c:\Python312\Lib\asyncio\events.py", line 84, in _run

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\ipykernel\kernelbase.py", line 545, in dispatch_queue

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\ipykernel\kernelbase.py", line 534, in process_one

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\ipykernel\kernelbase.py", line 437, in dispatch_shell

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\ipykernel\ipkernel.py", line 362, in execute_request

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\ipykernel\kernelbase.py", line 778, in execute_request

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\ipykernel\ipkernel.py", line 449, in do_execute

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\ipykernel\zmqshell.py", line 549, in run_cell

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\IPython\core\interactiveshell.py", line 3075, in run_cell

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\IPython\core\interactiveshell.py", line 3130, in _run_cell

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\IPython\core\interactiveshell.py", line 3334, in run_cell_async

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\IPython\core\interactiveshell.py", line 3517, in run_ast_nodes

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\IPython\core\interactiveshell.py", line 3577, in run_code

  File "C:\Users\yasha\AppData\Local\Temp\ipykernel_15388\1582946878.py", line 54, in <module>

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\utils\traceback_utils.py", line 117, in error_handler

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\backend\tensorflow\trainer.py", line 314, in fit

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\backend\tensorflow\trainer.py", line 117, in one_step_on_iterator

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\backend\tensorflow\trainer.py", line 104, in one_step_on_data

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\backend\tensorflow\trainer.py", line 51, in train_step

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\utils\traceback_utils.py", line 117, in error_handler

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\layers\layer.py", line 846, in __call__

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\utils\traceback_utils.py", line 117, in error_handler

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\ops\operation.py", line 48, in __call__

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\utils\traceback_utils.py", line 156, in error_handler

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\models\functional.py", line 202, in call

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\ops\function.py", line 155, in _run_through_graph

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\models\functional.py", line 592, in call

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\utils\traceback_utils.py", line 117, in error_handler

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\layers\layer.py", line 846, in __call__

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\utils\traceback_utils.py", line 117, in error_handler

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\ops\operation.py", line 48, in __call__

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\utils\traceback_utils.py", line 156, in error_handler

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\layers\merging\base_merge.py", line 189, in call

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\layers\merging\concatenate.py", line 101, in _merge_function

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\ops\numpy.py", line 1352, in concatenate

  File "C:\Users\yasha\AppData\Roaming\Python\Python312\site-packages\keras\src\backend\tensorflow\numpy.py", line 883, in concatenate

OOM when allocating tensor with shape[8,256,256,128] and type float on /job:localhost/replica:0/task:0/device:CPU:0 by allocator mklcpu
	 [[{{node functional_1_1/concatenate_3_1/concat}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_one_step_on_iterator_7954]