In [1]:
import cv2
import numpy as np
import tensorflow as tf
import time


import tensorflow as tf
from tensorflow.keras.losses import MeanSquaredError, MeanAbsoluteError
from tensorflow.image import ssim  # Import SSIM

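# The dehazing network itself is not defined in this notebook; it is loaded
# from an .h5 file further down. Only the custom loss it was trained with has
# to be defined here so that Keras can deserialize the saved model.

# Custom training loss (required when loading the saved model)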
def total_loss(y_true, y_pred):
    """Weighted combination of four image-reconstruction losses.

    Terms, each multiplied by a hand-tuned weight and summed:
      * pixel MSE (Keras ``MeanSquaredError``),
      * pixel MAE (Keras ``MeanAbsoluteError``),
      * structural dissimilarity ``1 - SSIM`` (``tf.image.ssim``, max_val=1.0),
      * luminance MSE: squared error between the grayscale versions of the
        two images. Despite its historical "color" label, this term compares
        brightness only, not colour histograms.

    This function is handed to ``load_model`` through ``custom_objects`` so the
    saved model can be deserialized.

    Args:
        y_true: target image batch. Assumed (batch, H, W, 3) with values in
            [0, 1]; ``max_val=1.0`` and ``rgb_to_grayscale`` rely on that.
            TODO confirm against the training pipeline.
        y_pred: predicted image batch, same layout as ``y_true``.

    Returns:
        A tensor. The MSE/MAE/luminance terms are scalars, while ``tf.image.ssim``
        returns one value per image. The result therefore has the per-image
        SSIM shape after broadcasting, and Keras reduces it to a scalar.
    """
    # Relative importance of each term (tuned by hand).
    w_mse, w_mae, w_ssim, w_luma = 2.5, 2.5, 2, 20

    pixel_mse = MeanSquaredError()(y_true, y_pred)
    pixel_mae = MeanAbsoluteError()(y_true, y_pred)

    # SSIM is 1.0 for identical images, so 1 - SSIM acts as a dissimilarity.
    structural = 1 - ssim(y_true, y_pred, max_val=1.0)

    # Luminance-only comparison: collapse RGB to one grey channel, then MSE.
    gray_true = tf.image.rgb_to_grayscale(y_true)
    gray_pred = tf.image.rgb_to_grayscale(y_pred)
    luma_mse = tf.reduce_mean(tf.square(gray_true - gray_pred))

    weighted_sum = (
        w_mse * pixel_mse
        + w_mae * pixel_mae
        + w_ssim * structural
        + w_luma * luma_mse
    )
    return weighted_sum



# Load the pre-trained dehazing model.
# Notes from earlier checkpoints: "modelu2 is the best", "modeulu3 is promising",
# "model121 is the best and promising"; model123.h5 was loaded here previously.
# total_loss is passed via custom_objects because the saved model presumably
# references its training loss by that name.
MODEL_PATH = 'model127.h5'
model = tf.keras.models.load_model(
    filepath=MODEL_PATH,
    custom_objects={'total_loss': total_loss},
)

# Define a function to preprocess the frame
def preprocess_frame(frame, size=(512, 512)):
    """Turn one OpenCV frame into a normalized single-image RGB batch for the model.

    Args:
        frame: image as returned by ``cv2.VideoCapture.read`` (BGR channel order).
            Presumably uint8; the ``/ 255.0`` scaling assumes that. Confirm.
        size: target ``(width, height)`` for ``cv2.resize``. The default
            ``(512, 512)`` is the value previously hard-coded here and presumably
            matches the model's input shape. TODO confirm via ``model.input_shape``.

    Returns:
        numpy float array of shape ``(1, height, width, 3)``. Values are in
        [0, 1] when ``frame`` is uint8.

    Raises:
        ValueError: if ``frame`` is None or empty. Otherwise OpenCV would fail
            with a far less readable error.
    """
    if frame is None or frame.size == 0:
        raise ValueError("preprocess_frame received an empty frame")

    # The model works in RGB; OpenCV delivers BGR.
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Note: cv2.resize takes dsize as (width, height), not (rows, cols).
    resized_frame = cv2.resize(rgb_frame, tuple(size))

    # Scale 8-bit intensities into [0, 1].
    normalized_frame = resized_frame / 255.0

    # Prepend a batch axis of length 1.
    return np.expand_dims(normalized_frame, axis=0)

# Define a function to postprocess the frame
def postprocess_frame(frame, original_shape):
    """Convert one RGB model output back into a displayable uint8 BGR frame.

    Args:
        frame: a single model output image without a batch axis, in RGB order.
            Values are expected roughly in [0, 1]; this assumes the model's
            output scaling. Confirm.
        original_shape: shape of the source frame. Only ``[0]`` (height/rows)
            and ``[1]`` (width/cols) are read.

    Returns:
        uint8 BGR image of ``original_shape[0]`` rows by ``original_shape[1]`` columns.
    """
    # cv2.cvtColor only accepts 8-bit, 16-bit or float32 images, so normalise
    # the dtype (model outputs might be float64/float16) before any OpenCV call.
    float_frame = np.asarray(frame, dtype=np.float32)

    # cv2.resize takes dsize as (width, height).
    resized_frame = cv2.resize(float_frame, (original_shape[1], original_shape[0]))

    # Back to OpenCV's BGR channel order for display / writing.
    bgr_frame = cv2.cvtColor(resized_frame, cv2.COLOR_RGB2BGR)

    # Round, then clip into the uint8 range before casting. Without the clip,
    # outputs slightly above 1.0 or below 0.0 overflow the uint8 cast (e.g.
    # 1.01 * 255 -> 257 -> 1) and show up as black/white speckles.
    final_frame = np.clip(np.rint(bgr_frame * 255.0), 0, 255).astype(np.uint8)

    return final_frame

# Open the video capture
# Input / output locations and display settings for this run.
INPUT_PATH = 'IMG-20231226-WA0025.png'
OUTPUT_VIDEO_PATH = 'jap(clean_bm).mp4'
LAST_FRAME_PATH = "frame2.jpg"
WINDOW_NAME = 'Dehazed Video'
# Used when the source reports no usable frame rate. The input here is a still
# image, for which OpenCV may report 0; VideoWriter cannot be opened with fps 0.
DEFAULT_FPS = 25.0

# Open the source (a video file, or here a single image).
cap = cv2.VideoCapture(INPUT_PATH)
if not cap.isOpened():
    raise FileNotFoundError(f"Could not open input {INPUT_PATH!r}")

# Source dimensions: output frames are resized back to these.
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
if not (fps > 0):  # also catches NaN
    fps = DEFAULT_FPS

# Video writer for the dehazed output.
out = cv2.VideoWriter(OUTPUT_VIDEO_PATH, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
if not out.isOpened():
    cap.release()
    raise RuntimeError(f"Could not open video writer for {OUTPUT_VIDEO_PATH!r}")

# Create the resizable preview window once, BEFORE the first imshow. If imshow
# runs first it creates an auto-sized window, and a later namedWindow with the
# same name no longer changes it. Inference is slow, so the preview fps is low.
cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
cv2.resizeWindow(WINDOW_NAME, 640, 480)

inference_seconds = 0.0
frame_count = 0
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        preprocessed_frame = preprocess_frame(frame)

        # perf_counter measures wall-clock latency; process_time would sum CPU
        # time over all threads, which is not the time a frame actually takes.
        start = time.perf_counter()
        # verbose=0 stops Keras printing a progress bar for every single frame.
        dehazed_frame = model.predict(preprocessed_frame, verbose=0)
        inference_seconds += time.perf_counter() - start
        frame_count += 1

        final_frame = postprocess_frame(dehazed_frame[0], frame.shape)

        # Overwritten every iteration, so it ends up holding the last frame.
        cv2.imwrite(LAST_FRAME_PATH, final_frame)

        cv2.imshow(WINDOW_NAME, final_frame)
        out.write(final_frame)

        # waitKey pumps the GUI event loop; pressing 'q' stops early.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always release the capture/writer and close windows, even if inference
    # raises. Otherwise the output file may be left unfinalized.
    cap.release()
    out.release()
    cv2.destroyAllWindows()

if frame_count:
    print(f"Processed {frame_count} frame(s); mean inference time "
          f"{inference_seconds / frame_count * 1000:.1f} ms")




In [None]:
# Print the loaded model's layer-by-layer architecture and parameter counts.
# Relies on `model` created in the first cell; run that cell first.
model.summary()