In [1]:
import os
import pickle

import cv2
import numpy as np
from PIL import Image
from scipy.ndimage import gaussian_filter, laplace

In [2]:
# Canny Edge Detection
def canny_edge_detection(image, low_threshold=100, high_threshold=200):
    """Run OpenCV's Canny detector on an image normalised to the [0, 1] range.

    Args:
        image: Array-like of intensities scaled to [0, 1] (for example an
            8-bit image divided by 255).
        low_threshold: Lower hysteresis threshold forwarded to ``cv2.Canny``.
        high_threshold: Upper hysteresis threshold forwarded to ``cv2.Canny``.

    Returns:
        The edge map produced by ``cv2.Canny`` for the rescaled 8-bit image.
    """
    # Round to the nearest integer instead of truncating: 0.9999 * 255 is
    # 254.97 and must become 255, not 254.  Clipping keeps slightly
    # out-of-range inputs from wrapping around in the uint8 cast.
    scaled = np.rint(np.asarray(image, dtype=np.float64) * 255.0)
    image_uint8 = np.clip(scaled, 0, 255).astype(np.uint8)
    return cv2.Canny(image_uint8, low_threshold, high_threshold)

In [3]:
from scipy.ndimage import binary_dilation
import zstandard as zstd
from skimage import io, color

def compress_image(image, edge_mask, q, d):
    """
    Encode quantized pixel values next to edges and on the image border.

    The candidate pixels are (a) the ring that binary dilation adds around
    ``edge_mask`` and (b) the outermost rows and columns of the image.  They
    are listed in row-major order, every ``d``-th one is kept, and the
    resulting ``(row, col, level)`` triples are pickled and compressed with
    zstd.

    Args:
        image: 2-D grayscale array with intensities in [0, 255].
        edge_mask: Array of the same shape; non-zero entries mark edge pixels.
        q: Number of quantization bits (levels run from 0 to ``2**q - 1``).
        d: Sampling distance; every ``d``-th candidate pixel is stored.

    Returns:
        bytes: zstd-compressed pickle of a list of ``(row, col, level)``
        tuples holding plain Python ints.

    Raises:
        ValueError: If ``image`` is not 2-D, or if ``q`` or ``d`` is below 1.
    """
    image = np.asarray(image)
    if image.ndim != 2:
        raise ValueError(f"image must be 2-D grayscale, got shape {image.shape}")
    if q < 1:
        raise ValueError(f"q must be at least 1, got {q}")
    if d < 1:
        raise ValueError(f"d must be at least 1, got {d}")

    # Ensure edge_mask is binary
    edge_mask = np.asarray(edge_mask).astype(bool)

    # Quantize with rounding to the nearest level.  Truncation would bias every
    # pixel downwards by up to one full quantization step.  int64 keeps q > 8
    # from overflowing the way a uint8 cast would.
    max_val = 255  # Assuming image is 8-bit grayscale
    levels = 2**q - 1
    scaled = np.rint(image.astype(np.float64) / max_val * levels)
    quantized_image = np.clip(scaled, 0, levels).astype(np.int64)

    # Pixels adjacent to edges: added by the dilation but not edges themselves
    side_pixels = binary_dilation(edge_mask) & ~edge_mask

    # Outermost rows and columns are always candidates
    boundary_mask = np.zeros(image.shape, dtype=bool)
    boundary_mask[0, :] = boundary_mask[-1, :] = True
    boundary_mask[:, 0] = boundary_mask[:, -1] = True

    # Combine masks for pixels to be stored
    store_mask = side_pixels | boundary_mask

    # Extract values based on sampling distance d (row-major order)
    indices = np.argwhere(store_mask)
    print(f"Number of Indices Before Sampling: {len(indices)}")
    sampled_indices = indices[::d]  # Subsampling
    print(f"Number of Indices After Sampling: {len(sampled_indices)}")

    # .tolist() yields plain Python ints; pickling numpy scalars instead costs
    # far more bytes per value and inflates the compressed payload.
    rows, cols = sampled_indices[:, 0], sampled_indices[:, 1]
    sampled_data = list(
        zip(rows.tolist(), cols.tolist(), quantized_image[rows, cols].tolist())
    )
    print(f"Sampled Values (First 10): {sampled_data[:10]}")

    cctx = zstd.ZstdCompressor(level=3)  # Adjust level for compression ratio
    compressed_data = cctx.compress(pickle.dumps(sampled_data))

    return compressed_data

In [4]:
import numpy as np
from scipy.interpolate import interp1d
from scipy.ndimage import laplace
import matplotlib.pyplot as plt

def decompress_image(sampled_data, edge_mask, image_shape, q, d, diffusion_iters=200000, tol=1e-5):
    """Rebuild an image from sampled edge-adjacent pixels by diffusion.

    Steps: dequantize the stored samples, linearly interpolate between
    consecutive samples, drop everything outside the dilated edge mask and the
    image border, then fill every remaining pixel by iterating an explicit
    Laplacian diffusion step while the known pixels stay fixed.

    Args:
        sampled_data: Iterable of ``(row, col, level)`` triples, in the order
            in which they were stored.
        edge_mask: 2-D array; non-zero entries mark edge pixels.
        image_shape: ``(height, width)`` of the image to rebuild.
        q: Number of quantization bits used when the samples were stored.
        d: Sampling distance used when the samples were stored; also the
            number of points interpolated between consecutive samples.
        diffusion_iters: Maximum number of diffusion iterations.
        tol: Stop early once the largest per-pixel change of one iteration
            falls below this value.

    Returns:
        ``float32`` array of shape ``image_shape`` with values clipped to
        [0, 255].
    """
    # Dequantization step: level k maps back to k * 255 / (2**q - 1)
    max_val = 255
    quant_step = max_val / (2**q - 1)

    reconstructed_image = np.zeros(image_shape, dtype=np.float32)
    # Track known pixels explicitly.  Inferring them from "value != 0" would
    # treat genuinely black samples as missing and let them be diffused away.
    known_mask = np.zeros(image_shape, dtype=bool)
    sampled_positions = []
    pixel_values = []

    # Place dequantized values at the stored positions
    for row, col, quantized_value in sampled_data:
        value = quantized_value * quant_step
        reconstructed_image[row, col] = value
        known_mask[row, col] = True
        sampled_positions.append((row, col))
        pixel_values.append(value)

    # Interpolate between consecutive samples
    for edge_idx in range(len(sampled_positions) - 1):
        start, end = sampled_positions[edge_idx], sampled_positions[edge_idx + 1]
        line_coords = np.linspace(start, end, num=d, endpoint=False, axis=0)
        # Same parametrisation as the coordinates (end point excluded) so each
        # interpolated pixel gets the value that belongs to its own position.
        interp_values = np.linspace(
            pixel_values[edge_idx], pixel_values[edge_idx + 1], num=d, endpoint=False
        )
        for coord, interp_value in zip(line_coords, interp_values):
            position = tuple(map(int, coord))
            reconstructed_image[position] = interp_value
            known_mask[position] = True

    # Only pixels inside the dilated edge mask or on the border may stay known
    dilated_edge_mask = binary_dilation(edge_mask)
    dilated_edge_mask[0, :] = dilated_edge_mask[-1, :] = True
    dilated_edge_mask[:, 0] = dilated_edge_mask[:, -1] = True
    known_mask &= dilated_edge_mask

    missing_mask = ~known_mask
    reconstructed_image[missing_mask] = 0

    # Explicit diffusion: only missing pixels are updated, known ones are fixed
    dt = 0.1
    for i in range(diffusion_iters):
        previous_image = reconstructed_image.copy()
        diffusion_step = laplace(previous_image)
        reconstructed_image[missing_mask] += dt * diffusion_step[missing_mask]
        diff = np.abs(reconstructed_image - previous_image).max()
        if diff < tol:
            print(f"Converged after {i} iterations with max diff {diff}")
            break

    # Clip values to valid range
    reconstructed_image = np.clip(reconstructed_image, 0, 255)

    return reconstructed_image

In [5]:
import pickle

def rle_encode(data):
    """Encode data using Run-Length Encoding (RLE).

    The input is flattened in row-major order and collapsed into runs of
    equal consecutive values.

    Args:
        data: Array-like of values to encode (any shape).

    Returns:
        bytes: Pickle of the flat list ``[value0, count0, value1, count1, ...]``
        holding plain Python scalars.  An empty input yields an empty list.
    """
    flat = np.asarray(data).ravel()
    if flat.size == 0:
        # Nothing to encode; indexing the last element would fail here.
        return pickle.dumps([])

    # Index of the first element of every run: position 0 plus every position
    # whose value differs from its predecessor.
    run_starts = np.concatenate(([0], np.flatnonzero(flat[1:] != flat[:-1]) + 1))
    run_lengths = np.diff(np.concatenate((run_starts, [flat.size])))

    # .tolist() converts numpy scalars to native Python values, which pickle
    # far more compactly than numpy scalar objects.
    encoded = []
    for value, count in zip(flat[run_starts].tolist(), run_lengths.tolist()):
        encoded.extend((value, count))

    return pickle.dumps(encoded)

def rle_decode(encoded):
    """Expand a flat ``[value, count, value, count, ...]`` sequence.

    Each value is repeated ``count`` times, in order, and the repetitions are
    concatenated into one list.
    """
    expanded = []
    position = 0
    total = len(encoded)
    while position < total:
        symbol = encoded[position]
        run_length = encoded[position + 1]
        expanded += [symbol] * run_length
        position += 2
    return expanded

In [6]:
def create_compressed_file(image,edge_mask, compressed_data, q, d, filename="compressed_image.bin"):
    """Write the container file: pickled header, RLE edge mask, pixel payload.

    The three sections are written back to back.  The header records the
    image shape, ``q``, ``d`` and the byte length of the encoded edge mask so
    that a reader can tell where the mask ends and the pixel payload begins.
    """
    mask_blob = rle_encode(edge_mask)

    # Key order is kept as-is because it determines the pickled bytes.
    header = dict(
        img_size=image.shape,
        q=q,
        d=d,
        edge_mask_len=len(mask_blob),
        channels=1,  # Or 3 for color images
    )

    with open(filename, "wb") as out_file:
        for section in (pickle.dumps(header), mask_blob, compressed_data):
            out_file.write(section)

def decompress_file(filename="compressed_image.bin"):
    """Read a container file and return the reconstructed image.

    The file is expected to hold, in order: a pickled header dict, the
    pickled run-length encoding of the edge mask (its byte length is stored
    in the header), and the zstd-compressed pickle of the sampled pixels.

    NOTE: ``pickle`` can execute arbitrary code while loading, so only open
    files that come from a trusted source.
    """
    with open(filename, "rb") as src:
        # Header: one pickle at the very start of the file
        meta = pickle.load(src)
        image_shape = meta["img_size"]
        q, d = meta["q"], meta["d"]
        mask_byte_count = meta["edge_mask_len"]
        channels = meta["channels"]  # looked up but not used further

        # Edge mask: exactly mask_byte_count bytes of pickled run lengths
        runs = pickle.loads(src.read(mask_byte_count))
        edge_mask = np.array(rle_decode(runs)).reshape(image_shape)

        # Pixel payload: everything that is left in the file
        payload = src.read()
        sampled_data = pickle.loads(zstd.ZstdDecompressor().decompress(payload))

        return decompress_image(sampled_data, edge_mask, image_shape, q, d)

In [7]:
def compress_and_decompress_image(image, q, d, filename="compressed_image.bin"):
    """
    Round-trip an image through the edge-based codec and return the result.

    The image is edge-detected with Canny, the pixels next to the edges are
    quantized and compressed, everything is written to ``filename``, and the
    file is then read back and decoded.

    Args:
        image: The input image (intensities in [0, 255]).
        q: Quantization parameter.
        d: Sampling distance.
        filename: Path of the compressed file that is written and read back.

    Returns:
        The reconstructed image.
    """
    image1 = image / 255.0  # Normalize image for the Canny helper

    # 1. Edge Detection (Canny)
    edges_canny = canny_edge_detection(image1, low_threshold=100, high_threshold=200)

    # 2. Compression
    compressed_data = compress_image(image, edges_canny, q, d)

    # Forward filename so the file written here is the one read back below;
    # otherwise a non-default filename would be read without being written.
    create_compressed_file(image, edges_canny, compressed_data, q, d, filename=filename)

    # 3. Decompression
    reconstructed_image = decompress_file(filename)
    print(len(reconstructed_image))
    return reconstructed_image

In [None]:
def calculate_rmse(original_image, decompressed_image):
    """Calculate the root-mean-square error between two images.

    Both inputs are converted to float64 before subtracting, so unsigned
    integer images (e.g. uint8) cannot wrap around when the second image is
    brighter than the first.

    Args:
        original_image: Array-like reference image.
        decompressed_image: Array-like image to compare against the reference.

    Returns:
        The RMSE as a float64 scalar.
    """
    original = np.asarray(original_image, dtype=np.float64)
    restored = np.asarray(decompressed_image, dtype=np.float64)
    return np.sqrt(np.mean((original - restored) ** 2))

# Calculate BPP (Bits per Pixel)
def calculate_bpp(compressed_data, image_shape):
    """Return the bits-per-pixel figure for a compressed payload.

    The payload length in bytes is converted to bits and divided by the
    product of the first two entries of ``image_shape``.
    """
    rows, cols = image_shape[0], image_shape[1]
    total_bits = 8 * len(compressed_data)
    return total_bits / (rows * cols)

def diff_qnd_for_all_images(
    input_dir='./dataset/cartoon',
    plots_dir='./plots',
    output_dir='./reconstructed_images',
    q_values=(4,),
    d_values=(1, 4, 8),
    max_images=3,
):
    """
    Analyze RMSE vs BPP for multiple images and generate one plot per (q, d) pair.

    For every (q, d) combination each selected image is compressed,
    reconstructed and saved as a PNG in ``output_dir``; the per-image
    (BPP, RMSE) points are plotted and saved in ``plots_dir``.

    Args:
        input_dir: Directory scanned for ``.png`` / ``.jpg`` / ``.jpeg`` files.
        plots_dir: Directory that receives the RMSE-vs-BPP plots.
        output_dir: Directory that receives the reconstructed images.
        q_values: Quantization bit depths to evaluate.
        d_values: Sampling distances to evaluate.
        max_images: Number of images to process (``None`` processes all).
    """
    os.makedirs(plots_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)

    # os.listdir returns names in arbitrary order; sort so that the same
    # images are selected on every run and on every machine.
    image_files = sorted(
        f for f in os.listdir(input_dir)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    image_files = image_files[:max_images]  # Process only the first max_images images

    for q in q_values:
        for d in d_values:
            bpp_values = []
            rmse_values = []

            for img_file in image_files:
                img_path = os.path.join(input_dir, img_file)
                image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)

                if image is None:
                    print(f"Error: Unable to load {img_file}")
                    continue

                # Normalize the image for edge detection
                normalized_image = image / 255.0

                # Apply edge detection
                edges_canny = canny_edge_detection(normalized_image, low_threshold=100, high_threshold=200)

                # Compress the image and get compressed data
                compressed_data = compress_image(image, edges_canny, q, d)

                # NOTE: this BPP covers the pixel payload only; the header and
                # the encoded edge mask written to the file are not counted.
                bpp = calculate_bpp(compressed_data, image.shape)

                # Decompress the image and calculate RMSE
                decompressed_image = compress_and_decompress_image(image, q, d)
                rmse = calculate_rmse(image, decompressed_image)

                # Store the results
                bpp_values.append(bpp)
                rmse_values.append(rmse)

                # Save the reconstructed image.  The data is written as PNG, so
                # the file name gets a .png extension even for .jpg inputs.
                reconstructed_image_pil = Image.fromarray(
                    np.rint(decompressed_image).astype(np.uint8)
                )
                stem = os.path.splitext(img_file)[0]
                save_path = os.path.join(output_dir, f'reconstructed_{q}_{d}_{stem}.png')
                reconstructed_image_pil.save(save_path, format="PNG")
                print(f"Saved: {save_path}")

                # Print the result
                print(f"Image: {img_file}, q: {q}, d: {d}, BPP: {bpp:.4f}, RMSE: {rmse:.4f}")

            # Plot RMSE vs. BPP for this q, d pair
            fig, ax = plt.subplots()
            ax.plot(bpp_values, rmse_values, 'o-', label=f'q={q}, d={d}')
            ax.set_xlabel('BPP (Bits per Pixel)')
            ax.set_ylabel('RMSE')
            ax.set_title(f'RMSE vs. BPP (q={q}, d={d})')
            ax.legend()
            ax.grid()
            plot_path = os.path.join(plots_dir, f'rmse_bpp_q{q}_d{d}.png')
            fig.savefig(plot_path)
            plt.close(fig)  # release the figure; one is created per (q, d) pair

            print(f"Plot saved for q={q}, d={d} at {plot_path}")


# Run the analysis: writes the reconstructed images and one RMSE-vs-BPP plot per (q, d) pair.
diff_qnd_for_all_images()
