<a href="https://colab.research.google.com/github/adam-blip/test/blob/master/hdr%2B.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Stacking Low-Light Images for HDR Creation

In this notebook, we will perform the following steps:

1. **Load and Preprocess the Image**:
   - Load an image from the `dataset` folder.
   - Extract a random patch whose height and width are each at least half of the image's shorter side.
   - Perform various augmentations on the patch.

2. **Augment the Image**:
   - Apply transformations such as rotation, flipping, color and brightness adjustments, noise addition, and darkening to simulate low-light (night) shots.
   - Repeat this process to create a total of 5 augmented images.

3. **Build the Model**:
   - Construct a TensorFlow Keras model with 2 residual blocks and channel attention.
   - The model will take 5 input images and produce a single output image.

4. **Train the Model**:
   - Train the model for 100 epochs, minimizing the mean squared error (MSE) between the output image and the clean (unaugmented) patch.
   - Plot the training loss recorded at every epoch.

5. **Convert to TensorFlow Lite**:
   - Convert the trained model to TensorFlow Lite format with post-training quantization for efficient inference on low-end phones.

6. **Test Inference Speed**:
   - Compare the inference speed of the original model and the quantized TensorFlow Lite model.


In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
from skimage.io import imread
from skimage.util import random_noise
from skimage.transform import rotate, rescale
from skimage.color import rgb2hsv, hsv2rgb
import time

# Define paths and constants
dataset_folder = 'dataset'
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

# os.listdir() returns entries in arbitrary order and the folder may contain
# non-image entries (e.g. .ipynb_checkpoints), so take the first image FILE
# in sorted order. This keeps the choice deterministic.
image_files = sorted(
    name
    for name in os.listdir(dataset_folder)
    if name.lower().endswith(IMAGE_EXTENSIONS)
    and os.path.isfile(os.path.join(dataset_folder, name))
)
if not image_files:
    raise FileNotFoundError(f"No image files found in {dataset_folder!r}")
image_path = os.path.join(dataset_folder, image_files[0])

# Load the image
original_image = imread(image_path)

# The augmentation step (rgb2hsv) needs exactly 3 channels: replicate a
# grayscale image into RGB and drop an alpha channel if present.
if original_image.ndim == 2:
    original_image = np.stack([original_image] * 3, axis=-1)
elif original_image.shape[-1] == 4:
    original_image = original_image[..., :3]

# Scale integer images by their dtype's maximum, so 8-bit (255) and 16-bit
# (65535) inputs both land in [0, 1]. Float images are assumed to already be
# in [0, 1] (TODO confirm for the dataset in use).
if np.issubdtype(original_image.dtype, np.integer):
    original_image_float = (
        original_image.astype('float32') / np.iinfo(original_image.dtype).max
    )
else:
    original_image_float = original_image.astype('float32')

# Function to extract a random patch
def extract_random_patch(image, min_size=0.5):
    """Return a random rectangular crop of ``image``.

    The crop's height and width are drawn independently. Each is at least
    ``int(min(height, width) * min_size)`` pixels (and at least 1). Each is at
    most the full extent of the image along that axis, so a full-size crop is
    possible. The crop position is uniform over all positions where it fits.

    Args:
        image: Array whose first two axes are (height, width); any trailing
            channel axes are kept as-is.
        min_size: Fraction of the shorter image side that each crop side must
            reach. Must be in [0, 1].

    Returns:
        A view ``image[y:y + h, x:x + w]`` into the input array.

    Raises:
        ValueError: if ``min_size`` is outside [0, 1] or the image is empty.
    """
    if not 0 <= min_size <= 1:
        raise ValueError(f"min_size must be in [0, 1], got {min_size}")
    height, width = image.shape[:2]
    if height == 0 or width == 0:
        raise ValueError("cannot extract a patch from an empty image")

    # Never allow a zero-sized crop (int() truncates for tiny images).
    min_dim = max(1, int(min(height, width) * min_size))
    # np.random.randint's upper bound is exclusive; +1 makes the full extent
    # (and, for the offsets, the last valid position) reachable.
    patch_height = np.random.randint(min_dim, height + 1)
    patch_width = np.random.randint(min_dim, width + 1)
    y = np.random.randint(0, height - patch_height + 1)
    x = np.random.randint(0, width - patch_width + 1)
    return image[y:y + patch_height, x:x + patch_width]

# Function to augment the image
def augment_image(image):
    """Return a randomly augmented, darkened version of an RGB image.

    The steps are applied in this order:

    - Rotate by a random angle in [-25, 25) degrees. skimage's ``rotate``
      keeps the shape and fills uncovered corners with 0.
    - Flip left-right and up-down, each with probability ~1/2.
    - Scale saturation and brightness by random factors in [0.5, 1.5) in HSV
      space.
    - Add Gaussian noise with variance 0.01.
    - Darken by a random factor in [0.5, 1.0) to mimic a night shot.

    Args:
        image: RGB image of shape (H, W, 3), presumably float in [0, 1] as
            produced by the loading cell (TODO confirm for other callers).

    Returns:
        Float image of the same shape with values clipped to [0, 1].
    """
    # Random rotation
    angle = np.random.uniform(-25, 25)
    image = rotate(image, angle)

    # Random flipping
    if np.random.rand() > 0.5:
        image = np.fliplr(image)
    if np.random.rand() > 0.5:
        image = np.flipud(image)

    # Random color and brightness adjustment
    image_hsv = rgb2hsv(image)
    image_hsv[:, :, 1] *= np.random.uniform(0.5, 1.5)  # Adjust saturation
    image_hsv[:, :, 2] *= np.random.uniform(0.5, 1.5)  # Adjust brightness
    # Scaling can push S or V above 1. hsv2rgb would then produce
    # out-of-gamut RGB (negative values for S > 1, values > 1 for V > 1),
    # so keep both channels in their valid [0, 1] range.
    image_hsv[:, :, 1:] = np.clip(image_hsv[:, :, 1:], 0.0, 1.0)
    image = hsv2rgb(image_hsv)

    # Add random noise
    image = random_noise(image, mode='gaussian', var=0.01)

    # Darken the image to simulate night
    image = np.clip(image * np.random.uniform(0.5, 1.0), 0, 1)

    return image

# Extract a random patch from the original image; it is the clean target.
patch = extract_random_patch(original_image_float)

# Generate NUM_FRAMES independently augmented (noisy, darkened) frames.
# Every frame is derived from the clean patch, not from the previous frame,
# for two reasons:
#   - noise, darkening and rotation do not compound across frames;
#   - the clean target itself is never among the model inputs.
NUM_FRAMES = 5
augmented_images = [augment_image(patch) for _ in range(NUM_FRAMES)]

# Plot the original image followed by the augmented frames on a 2-row grid
# (1 + 5 = 6 panels -> 2 x 3).
panels = [original_image_float] + augmented_images
n_cols = -(-len(panels) // 2)  # ceiling division: enough columns for 2 rows
plt.figure(figsize=(15, 5))
for i, img in enumerate(panels):
    plt.subplot(2, n_cols, i + 1)
    plt.imshow(img)
    plt.title(f"Image {i}")
    plt.axis('off')
plt.show()

# Define a residual block (channel attention is applied separately by channel_attention)
def residual_block(x, filters):
    """Two 3x3 convolutions with a skip connection, followed by ReLU.

    Computes ``relu(conv(relu(conv(x))) + shortcut)``. ``shortcut`` depends on
    the channel count of ``x``:

    - If it already equals ``filters``, ``shortcut`` is ``x`` itself (the
      identity skip). The graph is then the plain residual block.
    - Otherwise ``x`` is projected with a 1x1 convolution. Without this,
      ``Add`` would fail on mismatched shapes.

    Args:
        x: Keras tensor (feature map) with a static channel dimension.
        filters: Number of output channels.

    Returns:
        Keras tensor with ``filters`` channels and the same spatial size as
        ``x`` (all convolutions use 'same' padding and stride 1).
    """
    shortcut = x
    if x.shape[-1] != filters:
        shortcut = layers.Conv2D(filters, (1, 1), padding='same')(x)
    y = layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
    y = layers.Conv2D(filters, (3, 3), padding='same')(y)
    y = layers.Add()([y, shortcut])
    return layers.Activation('relu')(y)

def channel_attention(x, filters, reduction=2):
    """Re-weight the channels of ``x`` with a CBAM-style channel-attention gate.

    The gate is built as follows:

    1. Global average pooling and global max pooling each give one
       descriptor per channel.
    2. Both descriptors go through the same two-layer MLP (shared weights).
    3. The two MLP outputs are summed, and a single sigmoid turns the sum
       into one weight in (0, 1) per channel.
    4. The weights multiply ``x``.

    The MLP's output layer must be linear. An extra sigmoid there, before the
    final one, squeezes every weight into roughly (0.5, 0.88), so the gate
    could never suppress a channel.

    Args:
        x: 4-D feature map (batch, height, width, channels); the pooling
            layers used here require 4-D input.
        filters: Number of channels in ``x``. The Reshape calls rely on it
            matching.
        reduction: Bottleneck ratio of the shared MLP. Hidden units are
            ``filters // reduction``, at least 1. The default of 2 matches
            the previous behavior.

    Returns:
        Tensor with the same shape as ``x``.
    """
    hidden_units = max(1, filters // reduction)

    avg_pool = layers.GlobalAveragePooling2D()(x)
    max_pool = layers.GlobalMaxPooling2D()(x)
    avg_pool = layers.Reshape((1, 1, filters))(avg_pool)
    max_pool = layers.Reshape((1, 1, filters))(max_pool)

    # Shared MLP: the same layer objects are applied to both branches.
    shared_layer_one = layers.Dense(hidden_units, activation='relu')
    shared_layer_two = layers.Dense(filters)  # linear; sigmoid applied once below
    avg_pool = shared_layer_two(shared_layer_one(avg_pool))
    max_pool = shared_layer_two(shared_layer_one(max_pool))

    cbam_feature = layers.Add()([avg_pool, max_pool])
    cbam_feature = layers.Activation('sigmoid')(cbam_feature)
    return layers.Multiply()([x, cbam_feature])

# Prepare the data.
# The model fuses all frames in one forward pass:
#   - input:  the frames stacked along the channel axis into ONE sample,
#             shape (1, H, W, 3 * n_frames);
#   - target: the single clean patch, shape (1, H, W, 3).
# Passing the frames as a batch of separate samples against one target makes
# model.fit raise, because x and y would have different numbers of samples.
# As before, `augmented_images` is rebound to the network input so later
# cells can feed it to the model directly. Iterating over it also works when
# this cell is re-run and it is already the stacked array.
augmented_images = np.concatenate(
    [np.asarray(frame, dtype=np.float32) for frame in augmented_images],
    axis=-1,
)[np.newaxis]
target_image = np.expand_dims(patch, axis=0).astype(np.float32)

# Build the model:
#   stem conv -> 2 x (residual block + channel attention) -> 3-channel head
input_shape = augmented_images.shape[1:]
inputs = layers.Input(shape=input_shape)
x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(inputs)
x = residual_block(x, 64)
x = channel_attention(x, 64)
x = residual_block(x, 64)
x = channel_attention(x, 64)
# Sigmoid keeps predictions in [0, 1], matching the [0, 1] scaling applied to
# the source image when it was loaded.
outputs = layers.Conv2D(3, (3, 3), padding='same', activation='sigmoid')(x)

model = models.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mean_squared_error')

# Train the model
history = model.fit(augmented_images, target_image, epochs=100, verbose=1)

# Visualize training process (one loss value per epoch)
plt.plot(history.history['loss'])
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.show()

# --- Export: Keras model -> TensorFlow Lite flatbuffer ----------------------
# Optimize.DEFAULT with no representative dataset requests post-training
# (dynamic-range) quantization from the converter.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Write the serialized model bytes to disk so they can be deployed.
tflite_path = 'model.tflite'
with open(tflite_path, mode='wb') as tflite_file:
    tflite_file.write(tflite_model)

# Test inference speed
def test_inference_speed(model, input_data, num_runs=100):
    start_time = time.time()
    for _ in range(num_runs):
        model.predict(input_data)
    end_time = time.time()
    return (end_time - start_time) / num_runs

# Compare inference speed of the Keras model and the quantized TFLite model.
# Both are timed on the SAME single sample:
#   - the TFLite input tensor is sized for that sample below;
#   - timing Keras on a larger batch would make the comparison unfair.

# TFLite interpreter setup
interpreter = tf.lite.Interpreter(model_content=tflite_model)
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# One sample, cast to the dtype declared by the TFLite input tensor.
# set_tensor rejects data whose dtype or shape does not match the tensor,
# so resize the input to this sample's shape before allocating.
sample = np.asarray(augmented_images[:1], dtype=input_details[0]['dtype'])
interpreter.resize_tensor_input(input_details[0]['index'], sample.shape)
interpreter.allocate_tensors()


def tflite_inference(interpreter, input_data):
    """Run one TFLite forward pass and return the first output tensor.

    Reads tensor indices from the module-level ``input_details`` and
    ``output_details``, so ``interpreter`` must be the interpreter those
    were obtained from.
    """
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    return output_data


class _TFLitePredictor:
    """Expose ``predict`` so ``test_inference_speed`` can time the TFLite
    interpreter the same way it times the Keras model (a bare lambda has no
    ``predict`` attribute)."""

    def __init__(self, interpreter):
        self._interpreter = interpreter

    def predict(self, input_data):
        return tflite_inference(self._interpreter, input_data)


# Original model inference speed
original_model_time = test_inference_speed(model, sample)
print(f'Original model inference time: {original_model_time:.6f} seconds')

# TFLite model inference speed
tflite_model_time = test_inference_speed(_TFLitePredictor(interpreter), sample)
print(f'TFLite model inference time: {tflite_model_time:.6f} seconds')
