# In [None]:  (Jupyter cell marker left over from the notebook export)
import os
import cv2
import numpy as np
import fitz  # PyMuPDF for PDF processing
from tensorflow.keras.models import load_model
from PIL import Image

# === Configuration ===
# Paths and tuning values read by the main section at the bottom of the file.
INPUT_PDF = None  # Source PDF path; when None (or empty) extraction is skipped and EXTRACTED_IMAGES_DIR is used as-is
OUTPUT_PDF = "F:/test/ALL.pdf"  # Destination of the PDF assembled from the processed images
EXTRACTED_IMAGES_DIR = "C:/gray/1"  # Input images: written by the extraction step, read by the processing step
PROCESSED_IMAGES_DIR = "C:/gray/2"  # Output images of the model; every image found here goes into OUTPUT_PDF
AI_MODEL_PATH = "F:/test/colorization_Model.keras"  # Keras model file loaded with load_model
PATCH_SIZE = 256  # Side length in pixels of the square patches; presumably must match the model's input size (confirm)

# === Step 1: Extract Images from PDF (if provided) ===
def extract_images_from_pdf(pdf_path, output_dir):
    """Extract every embedded image of a PDF into ``output_dir``.

    Files are named ``image_<n>.<ext>``, where ``<n>`` counts up from 0 in
    page order and ``<ext>`` is the lowercase extension reported by PyMuPDF.
    WebP data is re-encoded as PNG; if OpenCV cannot decode it, the raw bytes
    are written under their real ``.webp`` extension instead of being
    mislabelled as PNG.

    Args:
        pdf_path: Path of the PDF to read.
        output_dir: Directory that receives the images; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)
    image_count = 0

    # The context manager closes the document even if extraction fails midway.
    with fitz.open(pdf_path) as doc:
        for page in doc:
            for img in page.get_images(full=True):
                xref = img[0]  # first entry of an image record is its xref
                base_image = doc.extract_image(xref)
                img_bytes = base_image["image"]
                img_ext = base_image["ext"].lower()  # Ensure lowercase extension

                # Convert WebP images to PNG for compatibility. Renaming the
                # file alone would leave WebP bytes behind a ".png" name, so
                # the pixel data is actually re-encoded.
                if img_ext == "webp":
                    decoded = cv2.imdecode(
                        np.frombuffer(img_bytes, dtype=np.uint8),
                        cv2.IMREAD_UNCHANGED,
                    )
                    if decoded is not None:
                        encoded_ok, png_buffer = cv2.imencode(".png", decoded)
                        if encoded_ok:
                            img_bytes = png_buffer.tobytes()
                            img_ext = "png"

                image_path = os.path.join(output_dir, f"image_{image_count}.{img_ext}")
                with open(image_path, "wb") as f:
                    f.write(img_bytes)
                print(f"Extracted: {image_path}")
                image_count += 1

    print(f"Total Extracted Images: {image_count}")

# === Step 2: Load AI Model ===
def load_ai_model(model_path):
    """Load the Keras model stored at ``model_path`` and return it.

    A status line is printed before loading starts and another one after
    ``load_model`` has returned.
    """
    print("Loading AI Model...")
    loaded = load_model(model_path)
    print("Model Loaded Successfully.")
    return loaded

# === Step 3: Process Image by Splitting, Running AI, and Reassembling ===
def split_image(image, patch_size):
    """Cut an image into square tiles, scanning rows first, then columns.

    Tiles that reach past the bottom or right edge are zero-padded so that
    every returned tile is exactly ``patch_size`` x ``patch_size``. Both
    colour arrays ``(H, W, C)`` and grayscale arrays ``(H, W)`` are accepted.
    Tiles that need no padding are views into ``image``, not copies.

    Args:
        image: NumPy array whose first two axes are height and width.
        patch_size: Side length of the tiles in pixels; must be positive.

    Returns:
        ``(patches, positions, (h, w))`` where ``positions[k]`` is the
        ``(row, col)`` of the top-left corner of ``patches[k]`` in ``image``
        and ``(h, w)`` is the original height and width.

    Raises:
        ValueError: If ``patch_size`` is not positive.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be a positive integer, got {patch_size}")

    # Only the first two axes are spatial; unpacking three values would
    # reject 2-D grayscale input.
    h, w = image.shape[:2]
    # Any axes after height/width (e.g. channels) are never padded.
    trailing_padding = [(0, 0)] * (image.ndim - 2)
    patches, positions = [], []

    for i in range(0, h, patch_size):
        for j in range(0, w, patch_size):
            patch = image[i:i + patch_size, j:j + patch_size]

            missing_rows = patch_size - patch.shape[0]
            missing_cols = patch_size - patch.shape[1]
            if missing_rows or missing_cols:
                # Pad only at the bottom/right so the real pixels stay
                # anchored at the tile's top-left corner.
                patch = np.pad(
                    patch,
                    [(0, missing_rows), (0, missing_cols)] + trailing_padding,
                    mode="constant",
                    constant_values=0,
                )
            patches.append(patch)
            positions.append((i, j))

    return patches, positions, (h, w)

def process_patches(patches, model):
    """Run the model on each patch and return the results as uint8 images.

    Each patch is reduced to a single grayscale channel, scaled to [0, 1] and
    passed to ``model.predict`` as an array of shape ``(1, H, W, 1)``. The
    prediction is scaled back by 255, which assumes the model emits values in
    [0, 1] (confirm against the trained model). Single-channel predictions
    are expanded to three channels so all results can be reassembled alike.

    NOTE(review): results are later written with ``cv2.imwrite``, which
    expects BGR channel order; confirm the model's output channel order.

    Args:
        patches: Iterable of uint8 arrays, either ``(H, W)`` or ``(H, W, C)``.
        model: Object exposing a Keras-style ``predict`` method.

    Returns:
        List of uint8 arrays, one per input patch, in the same order.
    """
    processed_patches = []

    for patch in patches:
        # Decide by dimensionality first: checking only ``shape[-1] == 3``
        # lets 4-channel patches through unconverted and misreads a 2-D patch
        # that happens to be 3 pixels wide.
        if patch.ndim == 2:
            patch_gray = patch
        elif patch.shape[-1] == 3:
            patch_gray = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY)
        elif patch.shape[-1] == 4:
            patch_gray = cv2.cvtColor(patch, cv2.COLOR_BGRA2GRAY)
        else:
            # Remaining layouts (e.g. a trailing singleton axis): use the
            # first channel as the gray plane.
            patch_gray = patch[..., 0]

        patch_input = patch_gray / 255.0  # Normalize
        # Add the batch axis in front and the single gray channel at the end.
        patch_input = patch_input[np.newaxis, ..., np.newaxis]

        prediction = np.squeeze(model.predict(patch_input)[0])  # Remove extra dimensions

        # Clip before the uint8 cast: values even slightly outside [0, 1]
        # would otherwise wrap around (1.01 * 255 -> 1). Rounding instead of
        # truncating avoids darkening every pixel by up to one level.
        processed_patch = np.rint(np.clip(prediction, 0.0, 1.0) * 255.0).astype(np.uint8)

        if processed_patch.ndim == 2:
            processed_patch = cv2.cvtColor(processed_patch, cv2.COLOR_GRAY2BGR)

        processed_patches.append(processed_patch)

    return processed_patches

def reassemble_image(patches, positions, original_size):
    """Paste processed patches back onto a canvas of the original size.

    Each patch is cropped to the part that lies inside the canvas, so padding
    added when the image was split is dropped. The crop is derived from the
    patch's own shape and the canvas bounds rather than from the global
    ``PATCH_SIZE``, so patches produced with any tile size are handled.

    Args:
        patches: Arrays of shape ``(ph, pw, 3)`` or ``(ph, pw)``; 2-D patches
            are replicated across the three output channels.
        positions: ``(row, col)`` top-left corner of each patch on the canvas.
        original_size: ``(height, width)`` of the canvas.

    Returns:
        A uint8 array of shape ``(height, width, 3)``; areas not covered by
        any patch stay zero.
    """
    h, w = original_size
    reconstructed_image = np.zeros((h, w, 3), dtype=np.uint8)

    for patch, (i, j) in zip(patches, positions):
        if patch.ndim == 2:
            # A trailing axis lets the gray plane broadcast over all channels.
            patch = patch[:, :, np.newaxis]
        # max(..., 0) keeps a patch positioned outside the canvas from being
        # sliced with a negative bound; it is simply skipped.
        visible = patch[:max(h - i, 0), :max(w - j, 0)]
        reconstructed_image[i:i + visible.shape[0], j:j + visible.shape[1]] = visible

    return reconstructed_image

def process_image(image_path, model, output_dir):
    """Run the model over one image file and save the result.

    The image is read with OpenCV, split into ``PATCH_SIZE`` tiles, each tile
    is passed through ``model``, and the tiles are stitched back together. The
    result is written to ``output_dir`` under the same file name as the input.
    Unreadable inputs and failed writes are reported on stdout and skipped.

    Args:
        image_path: Path of the image to process.
        model: Model handed on to ``process_patches``.
        output_dir: Directory for the processed image; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    image = cv2.imread(image_path)
    if image is None:
        print(f"Skipping: Unable to read {image_path}")
        return

    patches, positions, original_size = split_image(image, PATCH_SIZE)
    processed_patches = process_patches(patches, model)
    reconstructed_image = reassemble_image(processed_patches, positions, original_size)

    output_path = os.path.join(output_dir, os.path.basename(image_path))
    # imwrite reports failure through its return value; without this check a
    # failed write would still be announced as saved.
    if not cv2.imwrite(output_path, reconstructed_image):
        print(f"Failed to save processed image: {output_path}")
        return
    print(f"Processed Image Saved: {output_path}")

# === Step 4: Convert Processed Images into PDF ===
def convert_images_to_pdf(image_dir, output_pdf):
    """Combine the images of a directory into one multi-page PDF.

    Files ending in .png, .jpg, .jpeg or .webp are included, one page each.
    Pages follow the natural order of the file names, so ``image_2`` comes
    before ``image_10``. If no matching file exists, no PDF is written.

    Args:
        image_dir: Directory containing the images.
        output_pdf: Path of the PDF to create; its parent directory is
            created when the path has one.
    """
    import re  # local import: only the sort key below needs it

    def natural_key(file_name):
        # Splitting on digit runs alternates text and number tokens, so the
        # numbers compare as integers instead of character by character.
        return [
            int(token) if token.isdigit() else token.lower()
            for token in re.split(r"(\d+)", file_name)
        ]

    # dirname is "" for a bare file name, and os.makedirs("") raises.
    output_dir = os.path.dirname(output_pdf)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)  # Ensure output directory exists

    images = []
    for img_file in sorted(os.listdir(image_dir), key=natural_key):
        if img_file.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
            img_path = os.path.join(image_dir, img_file)
            # convert() returns an independent in-memory copy, so the source
            # file can be closed right away instead of staying open.
            with Image.open(img_path) as source:
                images.append(source.convert("RGB"))

    if images:
        images[0].save(output_pdf, save_all=True, append_images=images[1:])
        print(f"Final PDF Created: {output_pdf}")
    else:
        print("No images to convert into PDF.")

# === Main Execution ===
if __name__ == "__main__":
    # Pipeline: optional PDF extraction -> model processing -> PDF assembly.
    if INPUT_PDF:
        print(f"Extracting images from PDF: {INPUT_PDF}")
        extract_images_from_pdf(INPUT_PDF, EXTRACTED_IMAGES_DIR)
    else:
        print("No PDF provided. Skipping extraction.")

    # Collect the candidate files up front. The directory may not exist when
    # extraction was skipped, and files of other types must not count as
    # work, otherwise the model would be loaded with nothing to process.
    supported_extensions = (".png", ".jpg", ".jpeg", ".webp")
    pending_files = []
    if os.path.isdir(EXTRACTED_IMAGES_DIR):
        pending_files = sorted(
            name
            for name in os.listdir(EXTRACTED_IMAGES_DIR)
            if name.lower().endswith(supported_extensions)
        )

    if not pending_files:
        print("No images found for processing. Exiting...")
        # exit() is a convenience of the interactive `site` module; raising
        # SystemExit ends the script the same way without relying on it.
        raise SystemExit

    ai_model = load_ai_model(AI_MODEL_PATH)

    for img_file in pending_files:
        process_image(os.path.join(EXTRACTED_IMAGES_DIR, img_file), ai_model, PROCESSED_IMAGES_DIR)

    convert_images_to_pdf(PROCESSED_IMAGES_DIR, OUTPUT_PDF)
