In [None]:
import os
import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Concatenate, Input, Dropout, UpSampling2D, MaxPooling2D, Add
from sklearn.utils import shuffle
import matplotlib.pyplot as plt

# ------------------------ Load Dataset ------------------------
def load_patches(folder, img_size=(512, 512)):
    """Load every readable image in ``folder`` as a float32 array scaled to [0, 1].

    Files are visited in ``sorted`` (lexicographic) filename order so that an
    input folder and its ground-truth folder with matching filenames line up
    by position. Files OpenCV cannot decode (``cv2.imread`` returns None) are
    skipped. NOTE(review): a file skipped in only one of a paired X/Y folder
    shifts every later pair.

    Args:
        folder: directory to read.
        img_size: target size handed to ``cv2.resize``, which interprets it
            as ``(width, height)``.

    Returns:
        ndarray of shape ``(N, img_size[1], img_size[0], 3)``, dtype float32,
        channels in OpenCV's BGR order. ``N`` may be 0.
    """
    images = []
    for filename in sorted(os.listdir(folder)):  # lexicographic order
        img = cv2.imread(os.path.join(folder, filename))
        if img is None:
            continue
        img = cv2.resize(img, img_size)
        images.append(img.astype(np.float32) / 255.0)  # Normalize

    if not images:
        # np.array([]) would have shape (0,), which cannot be concatenated
        # with real (N, H, W, 3) batches; keep the 4-D shape instead.
        return np.empty((0, img_size[1], img_size[0], 3), dtype=np.float32)
    return np.stack(images)

# Load Training Data (Raw Images)
X_mars = load_patches("Patches_mars")
X_crater = load_patches("Patches_crater")
X_train = np.concatenate([X_mars, X_crater], axis=0)

# Load Ground Truth Data (Enhanced Images)
Y_mars = load_patches("Patches_enhanced_mars")
Y_crater = load_patches("Patches_enhanced_crater")
Y_mars2 = load_patches("Patches_enhanced_mars2")
Y_train = np.concatenate([Y_mars, Y_crater, Y_mars2], axis=0)

# Print Dataset Shapes
print("Training Data Shape:", X_train.shape)
print("Ground Truth Shape:", Y_train.shape)

# Inputs and targets are paired purely by position, so the counts must agree.
# NOTE(review): Y_train includes "Patches_enhanced_mars2" but X_train has no
# matching raw folder - check whether that folder belongs in this pairing.
if len(X_train) != len(Y_train):
    raise ValueError(
        f"Input/target count mismatch: X_train has {len(X_train)} patches "
        f"(mars={len(X_mars)}, crater={len(X_crater)}), Y_train has {len(Y_train)} "
        f"(mars={len(Y_mars)}, crater={len(Y_crater)}, mars2={len(Y_mars2)})"
    )

# Shuffle the pairs together (same permutation for X and Y). Keras'
# validation_split takes the LAST 10% of the arrays before its own shuffling,
# which would otherwise come only from the end of the concatenation (crater).
X_train, Y_train = shuffle(X_train, Y_train, random_state=42)


In [None]:
class ReflectionPadding2D(tf.keras.layers.Layer):
    """Mirror-pad the two spatial axes of a rank-4 tensor (``tf.pad`` REFLECT mode).

    ``padding`` is unpacked as ``(pad_w, pad_h)``: the first value pads axis 2
    and the second pads axis 1, each symmetrically on both sides. With the
    channels-last ``Input(shape=(H, W, C))`` used in this notebook, those are
    the width and height axes. NOTE(review): this is the reverse of Keras'
    ``ZeroPadding2D`` (height, width) ordering; it only matters for
    asymmetric padding, and every use here is (1, 1).
    """

    def __init__(self, padding=(1, 1), **kwargs):
        super().__init__(**kwargs)
        self.padding = padding

    def call(self, inputs):
        width_pad, height_pad = self.padding
        paddings = [
            [0, 0],                    # axis 0: batch, untouched
            [height_pad, height_pad],  # axis 1
            [width_pad, width_pad],    # axis 2
            [0, 0],                    # axis 3: channels, untouched
        ]
        return tf.pad(inputs, paddings, mode='REFLECT')

    def get_config(self):
        # Serialize `padding` alongside the base Layer config so the layer can be re-created.
        base_config = super().get_config()
        return {**base_config, 'padding': self.padding}

def build_cnn_model(input_shape=(512, 512, 3), filters=64, num_res_blocks=8):
    """Build a residual CNN that maps an image to an enhanced image of the same size.

    Every 3x3 convolution is preceded by a 1-pixel reflection pad and uses
    ``padding="valid"``, so the spatial size never changes. That is what
    lets the residual ``Add`` combine tensors of identical shape.

    Args:
        input_shape: ``(H, W, C)`` of the input images.
        filters: channel width of the hidden convolutions.
        num_res_blocks: number of conv-BN-ReLU-conv-BN residual blocks.

    Returns:
        An uncompiled ``tf.keras.Model``. It outputs ``input_shape[-1]``
        channels through a sigmoid, so values lie in (0, 1), matching targets
        scaled by 1/255.
    """
    inputs = Input(shape=input_shape)

    # Head: pad THEN convolve. The padded tensor must feed the conv, otherwise
    # the "valid" conv shrinks the map.
    x = ReflectionPadding2D((1, 1))(inputs)
    x = Conv2D(filters, (3, 3), padding="valid")(x)
    x = Activation("relu")(x)

    # Residual Blocks (Adding multiple layers for enhancement)
    for _ in range(num_res_blocks):
        res = ReflectionPadding2D((1, 1))(x)
        res = Conv2D(filters, (3, 3), padding="valid")(res)
        res = BatchNormalization()(res)
        res = Activation("relu")(res)
        res = ReflectionPadding2D((1, 1))(res)
        res = Conv2D(filters, (3, 3), padding="valid")(res)
        res = BatchNormalization()(res)
        x = Add()([x, res])  # Skip connection: shapes match because size is preserved

    # Output Layer (Restores Enhanced Image)
    x = ReflectionPadding2D((1, 1))(x)
    outputs = Conv2D(input_shape[-1], (3, 3), padding="valid", activation="sigmoid")(x)

    return Model(inputs, outputs)

# `build_deep_unet` does not exist in this notebook; the model defined above is build_cnn_model.
model = build_cnn_model()
model.compile(
    # Adam is not imported by name at the top, so reference it through tf.
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='mean_squared_error',
    # 'accuracy' is meaningless for pixel regression; the history plot reads 'mae'.
    metrics=['mae'],
)
model.summary()

In [None]:
# ------------------------ Train CNN Model ------------------------
# Keras reads validation_split from the LAST 10% of X_train / Y_train, before
# its per-epoch shuffle, so the array order decides the validation set.
fit_config = dict(
    validation_split=0.1,
    epochs=50,
    batch_size=6,
    shuffle=True,
)
history = model.fit(X_train, Y_train, **fit_config)

In [None]:
# Save the Model
# NOTE(review): ".pth" is a PyTorch naming convention, not a Keras one. Under
# Keras 2 (tf.keras) a path without ".keras"/".h5" is presumably written as a
# SavedModel directory; Keras 3 rejects such extensions. If upgrading, switch
# to ".keras" and update the load path in the inference cell - confirm against
# the installed TensorFlow/Keras version.
model.save("mars_enhancement_resnet.pth")
print("model saved")

In [None]:
# ------------------------ Plot Training History ------------------------
import matplotlib.pyplot as plt

# Get training history from the model (metric name -> list of per-epoch values)
history_dict = history.history

# One panel per tracked quantity: (history key, title, y-axis label, legend name)
panels = [
    ('loss', 'Loss Over Epochs', 'Loss (MSE)', 'Loss'),
    ('mae', 'Mean Absolute Error Over Epochs', 'MAE', 'MAE'),
]

fig, axes = plt.subplots(1, len(panels), figsize=(12, 5))
for ax, (key, title, ylabel, name) in zip(axes, panels):
    if key not in history_dict:
        # e.g. the model was compiled without metrics=['mae']: report it
        # in the figure instead of failing with a KeyError.
        ax.set_title(f"{title}\n('{key}' not recorded)")
        ax.axis('off')
        continue
    ax.plot(history_dict[key], label=f'Training {name}')
    val_key = f'val_{key}'
    if val_key in history_dict:  # only present when validation data was used
        ax.plot(history_dict[val_key], label=f'Validation {name}')
    ax.set_title(title)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.show()


In [None]:
def enhance_patches(input_folder, output_folder, model, img_size=(512, 512)):
    """Run ``model`` over every readable image in ``input_folder`` and save the results.

    Each image is resized to ``img_size`` (``cv2.resize`` order: width,
    height) and scaled to [0, 1], then predicted one at a time. The output is
    written as uint8 under the same filename in ``output_folder``, which is
    created if needed. Files OpenCV cannot read are reported and skipped.

    Args:
        input_folder: directory of source images (read in sorted order).
        output_folder: destination directory.
        model: Keras model mapping a (1, H, W, 3) batch to a (1, H, W, 3)
            batch - presumably values in [0, 1] (sigmoid head); anything
            outside is clipped.
        img_size: resize target fed to the model.
    """
    os.makedirs(output_folder, exist_ok=True)

    for filename in sorted(os.listdir(input_folder)):
        img = cv2.imread(os.path.join(input_folder, filename))
        if img is None:
            print(f"Skipping invalid image: {filename}")
            continue
        img = cv2.resize(img, img_size).astype(np.float32) / 255.0
        batch = np.expand_dims(img, axis=0)  # Add batch dimension

        enhanced = model.predict(batch, verbose=0)[0]  # Remove batch dimension
        # Round-and-clip rather than truncate: truncation biases every pixel
        # down and wraps out-of-range values.
        enhanced = np.clip(np.rint(enhanced * 255.0), 0, 255).astype(np.uint8)

        cv2.imwrite(os.path.join(output_folder, filename), enhanced)

# Enhance the training patches with the freshly trained model; each source
# folder gets its own output folder holding the same patch filenames.
for patch_dir, enhanced_dir in [
    ("Patches_mars", "Enhanced_mars_resnet"),
    ("Patches_crater", "Enhanced_crater_resnet"),
]:
    enhance_patches(patch_dir, enhanced_dir, model)

In [None]:
# ------------------------ Stitch Patches Back ------------------------
def stitch_patches(patch_folder, output_image_path, grid_size=(4, 4), patch_size=(512, 512)):
    """Tile the first ``rows * cols`` image patches of a folder into one mosaic.

    Patches (``.png``/``.jpg``/``.jpeg``, case-insensitive) are taken in
    sorted filename order and placed row-major. The mosaic is saved with
    ``cv2.imwrite``; extra patches beyond the grid are ignored.

    Args:
        patch_folder: directory holding the patches.
        output_image_path: where the stitched image is written.
        grid_size: ``(rows, cols)``.
        patch_size: ``(height, width)`` every patch must have.

    Raises:
        ValueError: fewer patches than grid cells, or a patch of the wrong size.
        FileNotFoundError: a patch OpenCV cannot read.
    """
    rows, cols = grid_size
    patch_h, patch_w = patch_size
    stitched_image = np.zeros((patch_h * rows, patch_w * cols, 3), dtype=np.uint8)

    patches = sorted(
        f for f in os.listdir(patch_folder)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )  # sorted order defines placement order
    expected_patches = rows * cols
    if len(patches) < expected_patches:
        raise ValueError(
            f"Expected {expected_patches} patches in '{patch_folder}', but found {len(patches)}"
        )

    for idx, name in enumerate(patches[:expected_patches]):
        i, j = divmod(idx, cols)  # row-major position in the grid
        patch_path = os.path.join(patch_folder, name)
        patch = cv2.imread(patch_path)
        if patch is None:
            raise FileNotFoundError(f"Cannot read patch: {patch_path}")
        if patch.shape[:2] != (patch_h, patch_w):
            raise ValueError(
                f"Patch '{patch_path}' is {patch.shape[1]}x{patch.shape[0]}, expected {patch_w}x{patch_h}"
            )
        stitched_image[i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w] = patch

    cv2.imwrite(output_image_path, stitched_image)

# Stitch Enhanced Images: each folder's default 4x4 grid of patches -> one mosaic.
for enhanced_dir, mosaic_path in [
    ("Enhanced_mars_resnet", "resnet_mars.png"),
    ("Enhanced_crater_resnet", "resnet_crater.png"),
]:
    stitch_patches(enhanced_dir, mosaic_path)


In [None]:
import cv2
import matplotlib.pyplot as plt

# --------- Replace with your actual image filenames ---------
# Two scenes x four versions each: original, enhanced, U-Net, ResNet.
image_paths = [
    "mars.jpg",
    "Enhanced_mars.jpg",
    "unet_mars.png",
    "resnet_mars.png",
    "crater.jpg",
    "Enhanced_crater.jpg",
    "unet_crater.png",
    "resnet_crater.png"
]

# --------- Read and convert images from BGR to RGB ---------
images = []
for path in image_paths:
    img = cv2.imread(path)
    if img is None:
        print(f"Warning: Couldn't read image: {path}")
        images.append(None)
    else:
        images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))  # matplotlib expects RGB

# --------- Plot images: 4 per row, one row per scene ---------
n_cols = 4
n_rows = max(1, -(-len(image_paths) // n_cols))  # ceiling division
fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 10), squeeze=False)

for ax in axs.flat:
    ax.axis("off")  # also hides any unused trailing axes

# zip() stops at the shorter sequence, so a short path list cannot IndexError.
for ax, path, img in zip(axs.flat, image_paths, images):
    if img is None:
        ax.set_title("Missing Image")
    else:
        ax.imshow(img)
        ax.set_title(path)  # the filename tells the reader which version this is

plt.tight_layout()
plt.show()


In [None]:
import cv2
import os
import numpy as np

def split_image_smooth(image_path, output_folder, patch_size=(512, 512)):
    """
    Split image into clean patches without borders/lines.
    Saves patches in PNG format to avoid compression artifacts.

    Tiles are non-overlapping and taken row-major from the top-left. A
    right/bottom strip smaller than ``patch_size`` is dropped. Files are named
    ``patch_<row>_<col>.png`` with both indices zero-padded to a common width,
    so a plain ``sorted()`` listing (as used by the stitchers) returns them in
    row-major order even for grids with 11+ rows or columns. For grids up to
    10x10 the names are unchanged (``patch_0_0.png`` ...).

    Args:
        image_path: image to split (read with ``cv2.imread``).
        output_folder: destination directory, created if missing.
        patch_size: ``(height, width)`` of each patch.

    Returns:
        int: number of patches written.

    Raises:
        ValueError: if OpenCV cannot read ``image_path``.
    """
    os.makedirs(output_folder, exist_ok=True)

    image = cv2.imread(image_path)
    if image is None:
        raise ValueError("Image not found or invalid format")

    patch_h, patch_w = patch_size
    n_rows = image.shape[0] // patch_h  # full patches only
    n_cols = image.shape[1] // patch_w
    # Width of the largest index, e.g. 1 for a 4x4 grid, 2 for a 12x12 grid.
    width = len(str(max(n_rows, n_cols, 1) - 1))

    count = 0
    for r in range(n_rows):
        for c in range(n_cols):
            patch = image[r * patch_h:(r + 1) * patch_h, c * patch_w:(c + 1) * patch_w]
            patch_path = os.path.join(output_folder, f"patch_{r:0{width}d}_{c:0{width}d}.png")
            cv2.imwrite(patch_path, patch)  # Use PNG only (no artifacts)
            count += 1

    print(f"✅ Saved {count} clean patches to '{output_folder}'")
    return count


In [None]:
def enhance_patches_with_model(patch_folder, output_folder, model_path, custom_objects=None):
    """Load a saved Keras model and enhance every ``.png`` patch in a folder.

    Patches are fed at their native size (no resize), scaled to [0, 1]. The
    prediction is rounded, clipped to [0, 255] and saved as uint8 under the
    same filename in ``output_folder`` (created if needed). Unreadable files
    are reported and skipped.

    Args:
        patch_folder: directory of input ``.png`` patches (processed in sorted order).
        output_folder: destination directory.
        model_path: path passed to ``tf.keras.models.load_model``.
        custom_objects: optional mapping forwarded to ``load_model``. When
            loading a ``.keras`` file that contains the custom layer, pass e.g.
            ``{'ReflectionPadding2D': ReflectionPadding2D}``.
    """
    os.makedirs(output_folder, exist_ok=True)

    # Load model
    model = tf.keras.models.load_model(model_path, custom_objects=custom_objects)
    print("📦 Loaded model:", model_path)

    # Case-sensitive '.png' to match what final_smooth_stitch will pick up.
    patches = sorted(f for f in os.listdir(patch_folder) if f.endswith('.png'))
    for filename in patches:
        image = cv2.imread(os.path.join(patch_folder, filename))
        if image is None:
            print(f"⚠️ Skipping: {filename} (invalid image)")
            continue

        # Normalize and add a batch axis for the model
        batch = np.expand_dims(image.astype(np.float32) / 255.0, axis=0)

        pred = model.predict(batch, verbose=0)[0]
        pred = np.clip(np.rint(pred * 255.0), 0, 255).astype(np.uint8)

        cv2.imwrite(os.path.join(output_folder, filename), pred)

    print(f"✨ Enhanced patches saved to '{output_folder}'")

In [None]:
def gaussian_weight_mask(patch_size, sigma=0.5, min_weight=1e-3):
    """
    Generates a 2D Gaussian weight mask for smooth blending of patches.

    Coordinates span [-1, 1] on each axis. The Gaussian of the radial
    distance is min-max normalized to [0, 1] and then lifted to
    ``[min_weight, 1]``, so every pixel keeps a strictly positive weight. A
    zero weight makes weighted-average stitching lose that pixel.

    Args:
        patch_size: ``(height, width)`` of the mask.
        sigma: Gaussian standard deviation in normalized coordinates.
        min_weight: floor of the mask; 0 reproduces the old zero-cornered mask.

    Returns:
        ndarray of shape ``(height, width, 1)``: 1 at the centre, ``min_weight``
        at the corners. A constant mask (e.g. 1x1) is all ones.
    """
    h, w = patch_size
    y = np.linspace(-1, 1, h)
    x = np.linspace(-1, 1, w)
    xx, yy = np.meshgrid(x, y)
    dist_sq = xx**2 + yy**2
    gauss = np.exp(-dist_sq / (2.0 * sigma ** 2))

    lo, hi = gauss.min(), gauss.max()
    if hi > lo:
        gauss = (gauss - lo) / (hi - lo)
    else:
        # Degenerate (e.g. 1x1) mask: avoid 0/0 -> NaN.
        gauss = np.ones_like(gauss)

    gauss = min_weight + (1.0 - min_weight) * gauss
    return gauss[..., np.newaxis]  # Add channel dimension


def final_smooth_stitch(patch_folder, output_path, grid_size=(4, 4), patch_size=(512, 512)):
    """Reassemble a row-major grid of ``.png`` patches, weight each by a Gaussian mask, save and show it.

    Patches are placed edge to edge with no overlap, so every output pixel is
    covered by exactly one patch. The weighted average therefore reduces to
    that patch's own value: this does not blend across seams. The mask is
    floored at a small positive value so no pixel ends up with zero weight
    (which previously turned patch corners black). Results are rounded
    before the uint8 cast so the weight division cannot truncate values down
    by one.

    Args:
        patch_folder: directory of patches; ``.png`` files in sorted order fill the grid row-major.
        output_path: where the stitched image is written.
        grid_size: ``(rows, cols)``; the folder must hold exactly ``rows * cols`` patches.
        patch_size: ``(height, width)`` every patch must have.

    Raises:
        ValueError: wrong patch count, unreadable patch, or wrong patch size.
    """
    rows, cols = grid_size
    H, W = patch_size

    stitched = np.zeros((H * rows, W * cols, 3), dtype=np.float32)
    weight_map = np.zeros_like(stitched)

    patches = sorted(f for f in os.listdir(patch_folder) if f.endswith('.png'))
    if len(patches) != rows * cols:
        raise ValueError(f"Expected {rows * cols} patches, found {len(patches)}")

    print(f"Found {len(patches)} patches in {patch_folder}")

    # Floor the weights: a zero-weight pixel would hit the weight_map == 0
    # fallback below and come out black.
    weight_mask = np.maximum(gaussian_weight_mask(patch_size), 1e-3)

    for idx, name in enumerate(patches):
        i, j = divmod(idx, cols)  # row-major grid position
        patch_path = os.path.join(patch_folder, name)
        patch = cv2.imread(patch_path)
        if patch is None:
            raise ValueError(f"❌ Cannot read patch: {patch_path}")
        if patch.shape[:2] != (H, W):
            raise ValueError(
                f"Patch '{patch_path}' is {patch.shape[1]}x{patch.shape[0]}, expected {W}x{H}"
            )

        y1, x1 = i * H, j * W
        stitched[y1:y1 + H, x1:x1 + W] += patch.astype(np.float32) * weight_mask
        weight_map[y1:y1 + H, x1:x1 + W] += weight_mask  # (H, W, 1) broadcasts over channels

    weight_map[weight_map == 0] = 1  # safety net; unreachable with a floored mask
    final_img = stitched / weight_map
    final_img = np.clip(np.rint(final_img), 0, 255).astype(np.uint8)

    cv2.imwrite(output_path, final_img)
    print(f"✅ Final stitched image saved to {output_path}")

    # Show result (OpenCV stores BGR; matplotlib expects RGB)
    plt.figure(figsize=(8, 8))
    plt.imshow(cv2.cvtColor(final_img, cv2.COLOR_BGR2RGB))
    plt.title("Final Smooth Stitched Image")
    plt.axis("off")
    plt.show()

In [None]:
# End-to-end inference on one full-size image: split -> enhance -> stitch.
PATCH_SIZE = (512, 512)

# 1) Cut mars13.jpg into lossless PNG tiles.
split_image_smooth("mars13.jpg", "Patches_mars13", patch_size=PATCH_SIZE)

# 2) Run the saved model over every tile.
enhance_patches_with_model("Patches_mars13", "Enhanced_mars13_resnet", "mars_enhancement_resnet.pth")

# 3) Reassemble the enhanced tiles as a 4x4 mosaic.
final_smooth_stitch("Enhanced_mars13_resnet", "resnet_mars13.png", grid_size=(4, 4), patch_size=PATCH_SIZE)


In [None]:
import cv2
import matplotlib.pyplot as plt

# --------- Replace with your actual image filenames ---------
# The pipeline cell writes the ResNet result as "resnet_mars13.png", so the
# extension must be .png or that output is never found.
image_paths = [
    "mars13.jpg",
    "Enhanced_mars13.jpg",
    "unet_mars13.png",
    "resnet_mars13.png",
]

# --------- Read and convert images from BGR to RGB ---------
images = []
for path in image_paths:
    img = cv2.imread(path)
    if img is None:
        print(f"Warning: Couldn't read image: {path}")
        images.append(None)
    else:
        images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))  # matplotlib expects RGB

# --------- Plot images: one row, one column per image ---------
# Sized from the path list so no image is silently dropped (1x3 hid the 4th).
fig, axs = plt.subplots(1, len(image_paths), figsize=(5 * len(image_paths), 5), squeeze=False)

for ax, path, img in zip(axs.flat, image_paths, images):
    if img is None:
        ax.set_title("Missing Image")
    else:
        ax.imshow(img)
        ax.set_title(path)
    ax.axis("off")  # Hide axis

plt.tight_layout()
plt.show()

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from skimage import io, img_as_ubyte, img_as_float
from skimage.measure import shannon_entropy
from skimage.restoration import estimate_sigma
from skimage.metrics import structural_similarity as ssim
from skimage import color
import os


In [None]:
from skimage import exposure
from skimage.util import img_as_ubyte
from skimage.metrics import mean_squared_error
from skimage import img_as_float
from skimage.color import rgb2gray
from skimage.restoration import estimate_sigma
from skimage import util
from skimage.measure import shannon_entropy

def compute_mscn_coefficients(img, kernel_size=7, sigma=7/6):
    """Mean-subtracted, contrast-normalized (MSCN) coefficients of an image.

    The local mean and local standard deviation come from a
    ``kernel_size`` x ``kernel_size`` Gaussian window with standard deviation
    ``sigma``. Each pixel becomes ``(img - local_mean) / (local_std + 1e-5)``.
    Computed in float32. In this notebook it is called on grayscale images.
    """
    window = (kernel_size, kernel_size)
    pixels = img.astype(np.float32)
    local_mean = cv2.GaussianBlur(pixels, window, sigma)
    local_sq_mean = cv2.GaussianBlur(pixels * pixels, window, sigma)
    # abs() guards against tiny negative variances from float round-off in E[x^2] - E[x]^2.
    local_std = np.sqrt(np.abs(local_sq_mean - local_mean * local_mean))
    return (pixels - local_mean) / (local_std + 1e-5)

def brisque_numpy(img):
    """BRISQUE-style proxy score: ``|mean(MSCN)| + std(MSCN)`` of the grayscale image.

    ``img`` must be a 3-channel RGB image (converted with COLOR_RGB2GRAY).
    NOTE(review): real BRISQUE fits distribution parameters to MSCN statistics
    and scores them with a trained regressor. This two-statistic proxy is not
    comparable to published BRISQUE values, and "lower is better" is an
    assumption for it.
    """
    mscn = compute_mscn_coefficients(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY))
    return np.abs(np.mean(mscn)) + np.std(mscn)

def niqe_numpy(img):
    """NIQE-style proxy score: ``sqrt(mean(MSCN)^2 + var(MSCN))`` of the grayscale image.

    ``img`` must be a 3-channel RGB image. NOTE(review): this is the RMS of
    the MSCN coefficients, not the real NIQE (which compares multivariate
    Gaussian fits of patch features against a pristine-image model). Do not
    compare it with published NIQE values.
    """
    mscn = compute_mscn_coefficients(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY))
    mscn_mean = np.mean(mscn)
    mscn_var = np.var(mscn)
    return np.sqrt(mscn_mean ** 2 + mscn_var)

def piqe_numpy(img, block_size=32, threshold=10):
    """PIQE-style proxy: percentage of nearly flat blocks in the grayscale image.

    The image (3-channel RGB) is converted to gray and cut into full
    ``block_size`` x ``block_size`` blocks; ragged right/bottom margins are
    ignored. A block counts as "distorted" when its pixel standard deviation
    is below ``threshold``.

    NOTE(review): this is a simple heuristic, not the published PIQE algorithm.

    Returns:
        float in [0, 100], or ``nan`` when the image is smaller than one block
        (previously a ZeroDivisionError).
    """
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    h, w = gray.shape
    flat_blocks = 0
    total_blocks = 0

    # Stepping only to the last full block start skips partial blocks up front.
    for i in range(0, h - block_size + 1, block_size):
        for j in range(0, w - block_size + 1, block_size):
            block = gray[i:i + block_size, j:j + block_size]
            if np.std(block) < threshold:
                flat_blocks += 1
            total_blocks += 1

    if total_blocks == 0:
        return float('nan')
    return (flat_blocks / total_blocks) * 100


# Entropy
def compute_entropy(image):
    """Shannon entropy (scikit-image ``shannon_entropy``) of a 3-channel RGB image's grayscale version."""
    return shannon_entropy(cv2.cvtColor(image, cv2.COLOR_RGB2GRAY))

# SNR (no-reference estimation)
def compute_snr(image):
    """No-reference SNR estimate in dB: ``10 * log10(mean^2 / sigma^2)`` of the grayscale image.

    ``sigma`` is scikit-image's wavelet-based noise estimate
    (``estimate_sigma``) on the float grayscale image.

    Returns:
        The SNR in dB, or ``inf`` when the estimated noise is exactly zero.
    """
    gray = img_as_float(rgb2gray(image))
    mu = np.mean(gray)
    try:
        # scikit-image >= 0.19 spelling; `multichannel` was removed in 0.21.
        sigma = estimate_sigma(gray, channel_axis=None)
    except TypeError:
        # Older scikit-image without `channel_axis`.
        sigma = estimate_sigma(gray, multichannel=False)
    if sigma == 0:
        return float('inf')
    return 10 * np.log10(mu**2 / sigma**2)


# HVS Sharpness
def compute_hvs_sharpness(image):
    """Sharpness as the variance of the Laplacian (float64) of a 3-channel RGB image's grayscale version.

    Larger values mean more high-frequency detail. Despite the name, no
    human-visual-system model is involved.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return cv2.Laplacian(gray, cv2.CV_64F).var()


In [None]:
import cv2
import pandas as pd

# Replace with your actual image paths and titles
image_paths = ['mars13.jpg', 'Enhanced_mars13.jpg', 'unet_mars13.png', 'resnet_mars13.png']
image_titles = ['original', 'GT', 'unet', 'resnet']
metrics = []

for path in image_paths:
    bgr = cv2.imread(path)
    if bgr is None:
        # Fail with the offending path instead of an opaque cvtColor assertion.
        raise FileNotFoundError(f"Cannot read image: {path}")
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Use the improved NumPy-based metric functions
    brisque_score = brisque_numpy(img)
    niqe_val = niqe_numpy(img)
    piqe_val = piqe_numpy(img)

    entropy_val = compute_entropy(img)
    snr_val = compute_snr(img)
    hvs_val = compute_hvs_sharpness(img)

    metrics.append({
        'BRISQUE ↓': round(brisque_score, 2),
        'NIQE ↓': round(niqe_val, 2),
        'PIQE ↓': round(piqe_val, 2),
        'Entropy ↑': round(entropy_val, 2),
        'SNR ↑': round(snr_val, 2),
        'HVS Sharpness ↑': round(hvs_val, 2),
    })

# One row per image, one column per metric. The bar-chart cell reads
# `df_metrics`; `df` is kept as an alias for cells that use the old name.
df_metrics = pd.DataFrame(metrics, index=image_titles)
df = df_metrics

# Display the table
print("📊 No-Reference Image Quality Metrics Comparison Table:\n")
display(df_metrics.style.set_caption("Comparison of Image Quality Metrics")
        .set_table_styles([{'selector': 'caption',
                            'props': [('color', 'black'),
                                      ('font-size', '16px'),
                                      ('text-align', 'center'),
                                      ('font-weight', 'bold')]}])
        .set_properties(**{'text-align': 'center'})
        .highlight_min(axis=0, subset=['BRISQUE ↓', 'NIQE ↓', 'PIQE ↓'], color='lightgreen')
        .highlight_max(axis=0, subset=['Entropy ↑', 'SNR ↑', 'HVS Sharpness ↑'], color='lightblue'))


In [None]:
import matplotlib.pyplot as plt

# Define metric names and colors
metric_names = ['BRISQUE ↓', 'Entropy ↑', 'SNR ↑', 'NIQE ↓', 'PIQE ↓', 'HVS Sharpness ↑']
colors = ['tomato', 'gold', 'darkcyan', 'mediumorchid', 'darkorange', 'deepskyblue']

# The metrics-table cell stores its results in `df` (rows = image titles).
metrics_table = df
bar_labels = [str(name) for name in metrics_table.index]

fig, axs = plt.subplots(3, 2, figsize=(14, 14))
# No emoji in the title: matplotlib's default font has no glyph for it.
fig.suptitle("No-Reference Image Quality Metrics (Visual Comparison)", fontsize=18, weight='bold', y=1.02)

axs = axs.flatten()

for ax, metric, color in zip(axs, metric_names, colors):
    values = metrics_table[metric].to_numpy(dtype=float)
    ax.bar(bar_labels, values, color=color, edgecolor='black')
    ax.set_title(metric, fontsize=14, pad=10)

    # The y-range always includes 0 and extends below it for negative values
    # (e.g. SNR in dB). 15% headroom leaves space for the value labels.
    finite = values[np.isfinite(values)]
    low = min(0.0, float(finite.min())) if finite.size else 0.0
    high = max(0.0, float(finite.max())) if finite.size else 1.0
    headroom = 0.15 * ((high - low) or 1.0)
    ax.set_ylim(low - headroom if low < 0 else 0.0, high + headroom)

    # Value labels sit just beyond the end of each bar; NaN bars get none.
    for j, val in enumerate(values):
        if np.isfinite(val):
            ax.text(j, val, f"{val:.2f}", ha='center', va='bottom' if val >= 0 else 'top',
                    fontsize=10, color='black')

    ax.set_ylabel("↑ Better" if '↑' in metric else "↓ Better", fontsize=11)
    ax.grid(True, linestyle='--', alpha=0.5)

# Clean up extra axes if any
for extra_ax in axs[len(metric_names):]:
    fig.delaxes(extra_ax)

plt.subplots_adjust(hspace=0.3, top=0.9)
plt.show()
