In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np

import os
import requests
from io import BytesIO
from PIL import Image
import numpy as np

import os
import random
import shutil
import matplotlib.pyplot as plt

import cv2
import gc


In [None]:
from google.colab import auth

# GCP project that owns the dataset bucket (passed to the storage client in the next cell).
project_id = 'axial-glow-456914-n5'

# Authenticate the Colab user so the Google Cloud client libraries can use their credentials.
auth.authenticate_user()

In [None]:
from google.cloud import storage

# Bucket holding the DIV2K images, accessed through a client bound to `project_id`.
bucket_name = 'cis-difusion-dataset'
storage_client = storage.Client(project=project_id)
bucket = storage_client.bucket(bucket_name)


In [None]:
def check_folder_contents(bucket, folder_name):
    """
    List (and print) every object stored under a "folder" prefix of a GCS bucket.

    Args:
        bucket: A GCS Bucket (anything exposing ``list_blobs(prefix=...)``).
        folder_name: Folder to inspect; a trailing '/' is appended when missing.

    Returns:
        The object names under the folder, excluding the folder placeholder itself.
    """
    prefix = folder_name if folder_name.endswith('/') else folder_name + '/'

    # A folder placeholder object, if present, is named exactly like the prefix.
    names = [b.name for b in bucket.list_blobs(prefix=prefix) if b.name != prefix]

    print(f"Checking folder: {prefix}")

    if not names:
        print(f"The folder '{prefix}' is empty or doesn't exist.")
        return names

    print(f"Found {len(names)} files in the folder:")
    for name in names:
        print(f"- {name}")
    return names


# Sanity-check the bucket layout: list every object under the training folder.
check_folder_contents(bucket, folder_name="DIV2K_train_HR")

In [None]:
folder_name = 'DIV2K_train_HR'
local_path = '/dataset/'

def download_files_from_gcs(bucket_name, folder_name, local_path):
    """
    Download every object under ``folder_name`` in a GCS bucket into ``local_path``.

    Only each object's basename is kept, so objects with the same basename in
    different sub-folders overwrite one another. Folder placeholder objects
    (names ending in '/') are skipped.
    """
    gcs_bucket = storage.Client().bucket(bucket_name)
    os.makedirs(local_path, exist_ok=True)

    for blob in gcs_bucket.list_blobs(prefix=folder_name):
        if blob.name.endswith('/'):
            continue
        destination = os.path.join(local_path, os.path.basename(blob.name))
        blob.download_to_filename(destination)
        print(f'Downloaded: {blob.name} to {destination}')

    print('Download complete.')


In [None]:
def download_image_batch(bucket, folder_name, local_path, batch_size=100, start_index=0):
    """Download one batch of images from a GCS folder into a local directory.

    Folder placeholder objects (names ending in '/') are ignored, and
    ``start_index`` counts only real objects. Consecutive calls with
    ``start_index`` advanced by ``batch_size`` therefore walk the folder
    without gaps or overlaps.

    Args:
        bucket: GCS Bucket (anything exposing ``list_blobs(prefix=...)``).
        folder_name: Object-name prefix to list.
        local_path: Directory to write files into (created if missing).
        batch_size: Maximum number of files to download.
        start_index: Number of real objects to skip before downloading.

    Returns:
        List of local file paths (str) that were downloaded. The list is empty
        when ``start_index`` is past the end of the listing.
    """
    os.makedirs(local_path, exist_ok=True)

    # Placeholders are filtered out up front rather than ending the loop, so a
    # '/' entry (typically listed first) cannot cut the batch short, and the
    # skip/take window is measured in real files only.
    file_blobs = [blob for blob in bucket.list_blobs(prefix=folder_name)
                  if not blob.name.endswith('/')]
    start = max(start_index, 0)
    selected = file_blobs[start:start + max(batch_size, 0)]

    image_paths = []
    for blob in selected:
        local_file_path = os.path.join(local_path, os.path.basename(blob.name))
        blob.download_to_filename(local_file_path)
        image_paths.append(local_file_path)
        print(f'Downloaded: {blob.name} to {local_file_path}')

    print(f'Downloaded {len(image_paths)} images.')
    return image_paths

In [None]:
def prepare_dataset(bucket_name, folder_name, noise_std=0.1, img_size=(128, 128), is_training=True, batch_size=4):
    """
    Build a tf.data pipeline of (input, target) image pairs read directly from GCS.

    Every object under ``folder_name`` in ``bucket_name`` (folder placeholders
    ending in '/' excluded) is read through a ``gs://`` URL. This relies on
    TensorFlow's GCS filesystem support. Each image is decoded as a 3-channel
    PNG, resized to ``img_size`` and scaled to [0, 1].

    Args:
        bucket_name: Name of the GCS bucket.
        folder_name: Object-name prefix to list.
        noise_std: Std-dev of the additive Gaussian noise (training only).
        img_size: (height, width) each image is resized to.
        is_training: If True, yield (noisy, clean); otherwise (clean, clean).
        batch_size: Images per batch (defaults to the previously hard-coded 4).

    Returns:
        A batched, prefetched ``tf.data.Dataset``.
    """
    def load_and_preprocess_image(path):
        # Decode to 3 channels and resize (resize outputs float32), then scale to [0, 1].
        img = tf.io.read_file(path)
        img = tf.image.decode_png(img, channels=3)
        img = tf.image.resize(img, img_size)
        img = tf.cast(img, tf.float32) / 255.0

        if not is_training:
            # Evaluation: the clean image is both input and target.
            return img, img

        # Noise is drawn inside the map function, so each pass over the dataset
        # sees a new noise sample (nothing is cached).
        noise = tf.random.normal(shape=tf.shape(img), mean=0.0, stddev=noise_std)
        noisy_img = tf.clip_by_value(img + noise, 0.0, 1.0)
        return noisy_img, img

    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    image_paths = [f'gs://{bucket_name}/{blob.name}'
                   for blob in bucket.list_blobs(prefix=folder_name)
                   if not blob.name.endswith('/')]

    # An explicit string dtype keeps an empty listing from becoming a float
    # tensor, which `tf.io.read_file` would reject.
    dataset = tf.data.Dataset.from_tensor_slices(tf.constant(image_paths, dtype=tf.string))
    dataset = dataset.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(tf.data.AUTOTUNE)



In [None]:
def prepare_dataset_local(local_image_paths, noise_std=0.1, img_size=(128, 128), is_training=True, batch_size=4):
    """
    Build a tf.data pipeline of (input, target) image pairs from local files.

    Each file is decoded as a 3-channel PNG, resized to ``img_size`` and scaled
    to [0, 1].

    Args:
        local_image_paths: Iterable of file paths (str or path-like).
        noise_std: Std-dev of the additive Gaussian noise (training only).
        img_size: (height, width) each image is resized to.
        is_training: If True, yield (noisy, clean); otherwise (clean, clean).
        batch_size: Images per batch (defaults to the previously hard-coded 4).

    Returns:
        A batched, prefetched ``tf.data.Dataset``.
    """
    # Normalise path-like objects to plain strings before building a string tensor.
    local_image_paths = [str(path) for path in local_image_paths]

    def load_and_preprocess_image(path):
        # `path` is already a scalar string tensor (see from_tensor_slices below).
        # No per-element tf.print here: it would log once per image on every epoch.
        img = tf.io.read_file(path)
        img = tf.image.decode_png(img, channels=3)
        img = tf.image.resize(img, img_size)
        img = tf.cast(img, tf.float32) / 255.0  # normalise to [0, 1]

        if not is_training:
            # Evaluation: the clean image is both input and target.
            return img, img

        # Additive Gaussian noise, clipped back to the valid [0, 1] range.
        noise = tf.random.normal(shape=tf.shape(img), mean=0.0, stddev=noise_std)
        noisy_img = tf.clip_by_value(img + noise, 0.0, 1.0)
        return noisy_img, img

    # Explicit string dtype, so an empty path list still yields a string tensor.
    dataset = tf.data.Dataset.from_tensor_slices(tf.constant(local_image_paths, dtype=tf.string))
    dataset = dataset.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(tf.data.AUTOTUNE)

In [None]:
class DnCNN(Model):
    """DnCNN-style residual denoiser for 3-channel images.

    Layout: Conv+ReLU, then ``D`` blocks of Conv+BatchNorm+ReLU, then a final
    Conv back to 3 channels whose output is added to the input (global residual
    connection): ``y = x + f(x)``. All convolutions are 3x3 with 'same' padding,
    so the spatial size of the output matches the input.

    Args:
        D: Number of Conv+BN+ReLU blocks between the first and last convolution.
        C: Number of feature channels in the hidden convolutions.
    """

    def __init__(self, D, C=64):
        super(DnCNN, self).__init__()
        self.D = D
        # Attribute names and layer creation order are unchanged, so weight
        # files saved by earlier runs keep loading. The first conv gets no
        # `input_shape`: a subclassed model builds its layers on the first call,
        # and Keras 3 emits a UserWarning when that argument is passed to a layer.
        self.conv_layers = [layers.Conv2D(C, kernel_size=3, padding='same')]
        self.conv_layers.extend([layers.Conv2D(C, kernel_size=3, padding='same') for _ in range(D)])
        # Final projection back to 3 channels so it can be added to the input.
        self.conv_layers.append(layers.Conv2D(3, kernel_size=3, padding='same'))
        # One BatchNormalization per hidden block; ReLU is applied separately in call().
        self.bn_layers = [layers.BatchNormalization() for _ in range(D)]

    def call(self, x, training=False):
        """Return ``x + f(x)``.

        Args:
            x: Image batch, presumably (batch, H, W, 3) floats in [0, 1]; the last
                dimension must be 3 for the residual addition.
            training: Forwarded to BatchNormalization (batch vs. moving statistics).
        """
        h = tf.nn.relu(self.conv_layers[0](x))
        for i in range(self.D):
            h = self.bn_layers[i](self.conv_layers[i + 1](h), training=training)
            h = tf.nn.relu(h)
        return self.conv_layers[-1](h) + x


In [None]:
def train_model(model, train_dataset, epochs=200, learning_rate=1e-3):
    """Fit ``model`` on ``train_dataset`` with an MSE loss and the Adam optimizer.

    The model is compiled only the first time it is trained. Later calls reuse
    the existing optimizer and only set its learning rate to ``learning_rate``.
    Adam's moment estimates therefore carry over between successive training
    chunks instead of being reset by a fresh ``compile`` each time.

    Args:
        model: Keras model to train.
        train_dataset: Dataset yielding (input, target) batches.
        epochs: Number of passes over ``train_dataset``.
        learning_rate: Adam learning rate.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    optimizer = getattr(model, "optimizer", None)
    if optimizer is None:
        model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate),
                      loss=tf.keras.losses.MeanSquaredError())
    else:
        optimizer.learning_rate = learning_rate
    return model.fit(train_dataset, epochs=epochs)


In [None]:
def train_model_in_batches(model, bucket, train_folder_name, local_path, batch_size=100, epochs_per_batch=10,
                           num_batches=5, start_index=0, weights_path='/content/weights.weights.h5'):
    """Train ``model`` on successive chunks of images downloaded from GCS.

    If ``weights_path`` exists, those weights are loaded first so training
    resumes from a previous run. Each chunk then goes through these steps:
    the local directory is wiped; up to ``batch_size`` images are downloaded;
    a dataset is built from them; the model is trained for
    ``epochs_per_batch`` epochs; and the weights are saved to ``weights_path``.
    Training stops early when the folder runs out of images.

    Args:
        model: Keras model to train (built here on a dummy 128x128x3 input).
        bucket: GCS Bucket to download from.
        train_folder_name: Object-name prefix of the training images.
        local_path: Scratch directory, deleted and recreated for every chunk.
        batch_size: Images per downloaded chunk.
        epochs_per_batch: Epochs trained on each chunk.
        num_batches: Maximum number of chunks (default 5, as before).
        start_index: Index of the first image to use, e.g. to continue past
            images already seen in an earlier run (default 0).
        weights_path: File weights are loaded from and saved to.
    """
    # Subclassed models have no variables until called once; build before
    # load_weights so the saved tensors have somewhere to go.
    dummy_input = tf.random.normal((1, 128, 128, 3))
    _ = model(dummy_input)

    if os.path.exists(weights_path):
        print("Loading existing weights to continue training...")
        model.load_weights(weights_path)
    else:
        print("No existing weights found. Training from scratch.")

    for batch_index in range(num_batches):
        print(f"Treinando no batch {batch_index + 1}/{num_batches}")
        print(f"Starting from image index: {start_index}")

        # Start every chunk from an empty directory so only its files are on disk.
        if os.path.exists(local_path):
            shutil.rmtree(local_path)
        os.makedirs(local_path, exist_ok=True)

        image_paths = download_image_batch(bucket,
                                           train_folder_name,
                                           local_path,
                                           batch_size=batch_size,
                                           start_index=start_index)
        if not image_paths:
            # Folder exhausted (or empty): there is nothing left to fit on.
            print("No images left to train on; stopping.")
            break

        start_index += batch_size

        train_dataset = prepare_dataset_local(image_paths)

        # Train on the current chunk, then checkpoint.
        train_model(model, train_dataset, epochs=epochs_per_batch)
        model.save_weights(weights_path)
        print(f"Treinamento do batch {batch_index + 1} completo. Pesos salvos.")

        # Release the dataset before the next chunk is downloaded.
        del train_dataset
        gc.collect()

In [None]:
def test_network_batch(model, bucket_name, test_folder_name, noise_std=0.1, batch_size=1, model_input_size=(128, 128)):
    """
    Denoise the first image of a GCS folder with ``model`` and visualise the result.

    Steps:
    - Download the image to a temporary directory and resize it to ``model_input_size``.
    - Corrupt it with clipped Gaussian noise and pass it through the model.
    - Clip the denoised output to [0, 1] and upsample it back to the original resolution.
    - Show original, noisy and denoised images in a vertical layout.
    - Print PSNR at full resolution and at model resolution.

    Args:
        model: Denoising model; it is built on a dummy input first if not yet built.
        bucket_name: GCS bucket name.
        test_folder_name: Object-name prefix of the test images.
        noise_std: Std-dev of the Gaussian noise.
        batch_size: Unused; kept for backward compatibility (one image is processed).
        model_input_size: (height, width) fed to the model.

    Returns:
        A dict with these keys, or None if the folder contains no images:
        - 'original', 'noisy', 'denoised': full-resolution numpy arrays.
        - 'local_path': the temporary download path (removed before returning).
    """
    local_path = '/content/temp/test_images/'
    weights_path = '/content/weights.weights.h5'
    storage_client = storage.Client()

    # Subclassed models have no variables until called once; build first so
    # load_weights has variables to fill.
    if not model.built:
        model(tf.zeros((1, *model_input_size, 3)))

    if os.path.exists(weights_path):
        model.load_weights(weights_path)
        print("Loaded saved weights for testing.")
    else:
        print("No saved weights found. Using initialized weights.")

    os.makedirs(local_path, exist_ok=True)

    # Use the first real object (folder placeholders end with '/').
    blobs = list(storage_client.bucket(bucket_name).list_blobs(prefix=test_folder_name))
    image_blobs = [blob for blob in blobs if not blob.name.endswith('/')]
    if not image_blobs:
        print(f"No images found in {test_folder_name}")
        return None

    blob = image_blobs[0]
    local_file_path = os.path.join(local_path, os.path.basename(blob.name))
    blob.download_to_filename(local_file_path)
    print(f'Downloaded: {blob.name} to {local_file_path}')

    # Full-resolution clean image in [0, 1].
    img_highres = tf.io.read_file(local_file_path)
    img_highres = tf.image.decode_png(img_highres, channels=3)
    img_highres = tf.cast(img_highres, tf.float32) / 255.0
    original_height, original_width = img_highres.shape[0], img_highres.shape[1]
    print(f"Original image dimensions: {original_height}x{original_width}")

    # Model-resolution clean and noisy images (tensors are immutable; no copy needed).
    img = tf.image.resize(img_highres, model_input_size)
    noise = tf.random.normal(shape=tf.shape(img), mean=0.0, stddev=noise_std)
    noisy_img = tf.clip_by_value(img + noise, 0.0, 1.0)

    denoised_img_batch = model(tf.expand_dims(noisy_img, 0), training=False)
    # The residual output is unbounded; clip so imshow and PSNR see a valid image.
    denoised_img = tf.clip_by_value(tf.squeeze(denoised_img_batch, axis=0), 0.0, 1.0)

    # Display-only noisy image at full resolution. It uses an independent
    # noise draw, so it is NOT the exact input the model received.
    noise_highres = tf.random.normal(shape=tf.shape(img_highres), mean=0.0, stddev=noise_std)
    noisy_img_highres = tf.clip_by_value(img_highres + noise_highres, 0.0, 1.0)

    denoised_img_highres = tf.image.resize(denoised_img, (original_height, original_width))

    plt.figure(figsize=(10, 24))
    panels = [
        (img_highres, "Original Image"),
        (noisy_img_highres, "Noisy Image"),
        (denoised_img_highres, "Denoised Image"),
    ]
    for position, (image, title) in enumerate(panels, start=1):
        plt.subplot(3, 1, position)
        plt.imshow(image.numpy())
        plt.title(title)
        plt.axis("off")
    plt.tight_layout()
    plt.show()

    noisy_psnr = tf.image.psnr(img_highres, noisy_img_highres, max_val=1.0)
    denoised_psnr = tf.image.psnr(img_highres, denoised_img_highres, max_val=1.0)
    print(f"Noisy Image PSNR: {noisy_psnr.numpy():.2f} dB")
    print(f"Denoised Image PSNR: {denoised_psnr.numpy():.2f} dB")

    # The full-resolution denoised PSNR also includes the loss from upsampling a
    # small model output, so it is not comparable with the noisy PSNR above.
    # At model resolution, both scores use the same clean image and the exact
    # noise the model saw.
    lowres_noisy_psnr = tf.image.psnr(img, noisy_img, max_val=1.0)
    lowres_denoised_psnr = tf.image.psnr(img, denoised_img, max_val=1.0)
    print(f"Model-resolution ({model_input_size[0]}x{model_input_size[1]}) PSNR - "
          f"noisy: {lowres_noisy_psnr.numpy():.2f} dB, denoised: {lowres_denoised_psnr.numpy():.2f} dB")

    shutil.rmtree(local_path, ignore_errors=True)
    gc.collect()

    return {
        'original': img_highres.numpy(),
        'noisy': noisy_img_highres.numpy(),
        'denoised': denoised_img_highres.numpy(),
        'local_path': local_file_path
    }

In [None]:
def test_network(model, image_path, noise_std=0.1):
    """Denoise one local image with ``model`` and plot original, noisy and denoised side by side.

    The image is decoded as a 3-channel PNG, resized to 128x128, scaled to
    [0, 1], corrupted with clipped Gaussian noise and passed through the model.

    Args:
        model: Denoising model called as ``model(x, training=False)``.
        image_path: Path to a PNG file readable by ``tf.io.read_file``.
        noise_std: Std-dev of the Gaussian noise.

    Returns:
        ``(original, noisy, denoised)`` tensors. ``original`` and ``noisy`` have
        shape (1, 128, 128, 3); ``denoised`` is the model output clipped to [0, 1].
    """
    img = tf.io.read_file(image_path)
    img = tf.image.decode_png(img, channels=3)
    img = tf.image.resize(img, (128, 128))
    img = tf.cast(img, tf.float32) / 255.0
    img = tf.expand_dims(img, axis=0)  # add batch dimension

    noise = tf.random.normal(shape=tf.shape(img), mean=0.0, stddev=noise_std)
    noisy_img = tf.clip_by_value(img + noise, 0.0, 1.0)  # keep values in [0, 1]

    # The model returns the denoised image directly (its residual is added
    # internally); clip it for display.
    denoised_img = model(noisy_img, training=False)
    denoised_img = tf.clip_by_value(denoised_img, 0.0, 1.0)

    # One row of three panels (the previous 2x3 grid left the bottom row empty).
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = [
        (img, "Original Image"),
        (noisy_img, "Noisy Image"),
        (denoised_img, "Denoised Image"),
    ]
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(tf.squeeze(image).numpy())
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    plt.show()
    return (img, noisy_img, denoised_img)



In [None]:
# Parameters
model = DnCNN(D=8)
# NOTE(review): this resolves to "DIV2K_train_HR/DIV2K_train_HR". It assumes the
# training images sit in a nested folder of the same name; confirm against the
# bucket listing printed by check_folder_contents.
train_folder_name = f"{folder_name}/DIV2K_train_HR"
local_path = '/content/dataset/'  # local scratch directory for each downloaded chunk
os.makedirs('/content/dataset/', exist_ok=True)
batch_size = 100  # images per downloaded chunk
epochs_per_batch = 10  # epochs trained on each chunk

# Train the model chunk by chunk.
train_model_in_batches(model, bucket, train_folder_name, local_path, batch_size, epochs_per_batch)

# Evaluate the model. The validation split lives at the bucket root ("DIV2K_valid_HR",
# as the other evaluation cells use), not inside the training folder.
test_folder_name = "DIV2K_valid_HR"
result = test_network_batch(model, bucket_name, test_folder_name, noise_std=0.1)

In [None]:
# Evaluate the trained model on the validation split (stored at the bucket root).
test_folder_name = "DIV2K_valid_HR"
test_network_batch(model, bucket_name, test_folder_name=test_folder_name, noise_std=0.1)

In [None]:
def apply_median_filter(result, kernel_size=3):
    """
    Compare a model's denoising against a classical median-filter baseline.

    The median filter is applied to the noisy image. The function then shows
    four panels side by side (original, noisy, model-denoised, median-filtered)
    and prints the PSNR of each of the last three against the original.

    Args:
        result: Sequence ``(denoised, noisy, original)``, the order returned by
            ``test_single_image``. Each entry is an HxWx3 image, optionally with
            a leading batch dimension of 1, as a tensor or array with values
            presumably in [0, 1].
        kernel_size: Odd aperture size >= 3 for ``cv2.medianBlur``.

    Raises:
        ValueError: If ``kernel_size`` is not an odd integer >= 3.
    """
    if kernel_size < 3 or kernel_size % 2 == 0:
        raise ValueError(f"kernel_size must be an odd integer >= 3, got {kernel_size}")

    denoised_img, noisy_img, original_img = result[0], result[1], result[2]

    def _to_hwc(img):
        # Drop a leading batch dimension and make a contiguous float32 array for OpenCV.
        return np.ascontiguousarray(np.squeeze(np.asarray(img)), dtype=np.float32)

    original_img_np = _to_hwc(original_img)
    noisy_img_np = _to_hwc(noisy_img)
    # The model output is not clipped upstream; clip so imshow and PSNR see a valid image.
    denoised_img_np = np.clip(_to_hwc(denoised_img), 0.0, 1.0)

    # cv2.medianBlur filters each channel independently. It accepts float32 only
    # for apertures 3 and 5; larger apertures require 8-bit input.
    if kernel_size <= 5:
        median_filtered = cv2.medianBlur(noisy_img_np, kernel_size)
    else:
        noisy_u8 = np.round(np.clip(noisy_img_np, 0.0, 1.0) * 255.0).astype(np.uint8)
        median_filtered = cv2.medianBlur(noisy_u8, kernel_size).astype(np.float32) / 255.0

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    panels = [
        (original_img_np, "Original Image"),
        (noisy_img_np, "Noisy Image"),
        (denoised_img_np, "Denoised Image"),
        (median_filtered, f"Median Filtered (k={kernel_size})"),
    ]
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    plt.show()

    original_t = tf.convert_to_tensor(original_img_np)
    noisy_psnr = tf.image.psnr(original_t, tf.convert_to_tensor(noisy_img_np), max_val=1.0)
    denoised_psnr = tf.image.psnr(original_t, tf.convert_to_tensor(denoised_img_np), max_val=1.0)
    median_psnr = tf.image.psnr(original_t, tf.convert_to_tensor(median_filtered), max_val=1.0)

    print(f"Noisy Image PSNR: {noisy_psnr.numpy():.2f} dB")
    print(f"Denoised Image PSNR: {denoised_psnr.numpy():.2f} dB")
    print(f"Median Filtered PSNR: {median_psnr.numpy():.2f} dB")

In [None]:
# Run one validation image through the model.
# NOTE(review): `test_single_image` is defined in the NEXT cell, so on a fresh
# kernel that cell must run before this one (Restart & Run All fails here otherwise).
gcs_image_path = "DIV2K_valid_HR/0801.png"  # any image object in the bucket works
result = test_single_image(model, bucket, image_path=gcs_image_path, noise_std=0.1)

In [None]:
def test_single_image(model, gcs_bucket, image_path, noise_std=0.1, img_size=(128, 128)):
    """Denoise a single GCS image with ``model`` and return the tensors for inspection.

    The model is built on a dummy input if needed. Weights are then loaded from
    '/content/weights.weights.h5' when that file exists. The object
    ``image_path`` is downloaded from ``gcs_bucket`` into '/content/temp/',
    scaled to [0, 1], resized to ``img_size``, corrupted with clipped Gaussian
    noise and passed through the model.

    Args:
        model: Denoising model called as ``model(x, training=False)``.
        gcs_bucket: GCS Bucket exposing ``blob(name)``.
        image_path: Object name of a PNG image inside the bucket.
        noise_std: Std-dev of the Gaussian noise.
        img_size: (height, width) the image is resized to (default 128x128).

    Returns:
        ``(denoised, noisy, original)``. ``noisy`` and ``original`` have shape
        (1, H, W, 3); ``denoised`` is the raw, unclipped model output.
    """
    weights_path = '/content/weights.weights.h5'

    # Subclassed models have no variables until called once; build first so
    # load_weights has variables to fill.
    if not model.built:
        model(tf.zeros((1, *img_size, 3)))

    if os.path.exists(weights_path):
        model.load_weights(weights_path)
        print("Loaded saved weights for testing.")
    else:
        print("No saved weights found. Using initialized weights.")

    temp_dir = '/content/temp/'
    os.makedirs(temp_dir, exist_ok=True)

    blob = gcs_bucket.blob(image_path)
    local_path = os.path.join(temp_dir, os.path.basename(image_path))
    blob.download_to_filename(local_path)
    print(f"Downloaded test image to {local_path}")

    img = tf.io.read_file(local_path)
    img = tf.image.decode_png(img, channels=3)
    img = tf.cast(img, tf.float32) / 255.0
    img = tf.image.resize(img, img_size)

    # Tensors are immutable, so noise can be added to `img` directly.
    noise = tf.random.normal(shape=tf.shape(img), mean=0.0, stddev=noise_std)
    noisy_img = tf.clip_by_value(img + noise, 0.0, 1.0)

    # Add the batch dimension expected by the model.
    noisy_img = tf.expand_dims(noisy_img, 0)
    img = tf.expand_dims(img, 0)

    denoised_img = model(noisy_img, training=False)

    return denoised_img, noisy_img, img

In [None]:
# Compare the model against a 3x3 median-filter baseline; expects `result` to be
# the (denoised, noisy, original) tuple produced by test_single_image.
apply_median_filter(result, 3)