In [1]:
#This program creates non-specular images from specular images by scanning adjacent frames for correct pixels. It then creates two
#Directories for training data


In [1]:
import cv2
import numpy as np
import os
import pywt
from tqdm import tqdm
import shutil

# --- Configuration ---
CONFIG = {
    "VIDEO_PATH": "video.mp4",
    "OUTPUT_ORIGINAL_DIR": "dataset/original_frames",
    "OUTPUT_GROUND_TRUTH_DIR": "dataset/ground_truth_frames_temporal_v2", # New output dir
    "TUNING_SAMPLE_DIR": "dataset/tuning_samples",
    "DETECTION_SAMPLE_DIR": "dataset/detection_samples_temporal_v2", # New sample dir
    "TARGET_IMG_SIZE": (512, 512),
    # --- Detection Parameters ---
    "SPECULAR_THRESHOLD_HSV_V": 130,
    "SPECULAR_THRESHOLD_HSV_S": 75,
    "SPECULAR_THRESHOLD_WAVELET": 0.05,
    "MIN_CONTOUR_AREA": 3,
    "DILATION_KERNEL_SIZE": 13,
    # --- Other Parameters ---
    "WAVELET": "haar",
    "WAVELET_LEVEL": 2,
    "FRAME_SKIP": 0,
    "SAMPLE_EVERY_N_FRAMES": 20,
    "NUM_INITIAL_SAMPLES_FOR_TUNING": 5,
    # --- Temporal Inpainting Parameters ---
    "TEMPORAL_SEARCH_WINDOW_PAST": 10, # Keep these moderate
    "TEMPORAL_SEARCH_WINDOW_FUTURE": 10,
    "MIN_BRIGHTNESS_FOR_TEMPORAL_SOURCE": 40, # Min brightness for a replacement pixel
    "MIN_SATURATION_FOR_TEMPORAL_SOURCE": 30, # Min saturation for a replacement pixel
    "HUE_RED_LOWER1": 0, "HUE_RED_UPPER1": 20,   # Broader red hue
    "HUE_RED_LOWER2": 160, "HUE_RED_UPPER2": 179, # Broader red hue
    "FALLBACK_INPAINT_RADIUS": 3, # Radius for cv2.inpaint if no temporal source found
}

def create_dir_if_not_exists(dir_path):
    if not os.path.exists(dir_path):
        os.makedirs(dir_path); print(f"Created directory: {dir_path}")

def clear_dir(dir_path):
    if os.path.exists(dir_path): shutil.rmtree(dir_path)
    os.makedirs(dir_path)

def dilate_mask(mask, dilation_kernel_size):
    if dilation_kernel_size <= 1: return mask.astype(bool)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilation_kernel_size, dilation_kernel_size))
    dilated_mask = cv2.dilate(mask.astype(np.uint8), kernel, iterations=1)
    return dilated_mask.astype(bool)

def detect_specular_regions(frame, current_config):
    hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
    v_channel, s_channel = hsv[:, :, 2], hsv[:, :, 1]
    hsv_mask_ = (v_channel > current_config["SPECULAR_THRESHOLD_HSV_V"]) & \
               (s_channel < current_config["SPECULAR_THRESHOLD_HSV_S"])

    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    wavelet_mask_ = np.zeros_like(gray, dtype=bool)
    try:
        coeffs = pywt.wavedec2(gray, current_config["WAVELET"], level=current_config["WAVELET_LEVEL"])
        details = []
        for c_level in coeffs[1:]:
            for d_idx in range(len(c_level)):
                 details.append(cv2.resize(np.abs(c_level[d_idx]), (gray.shape[1], gray.shape[0])))
        if details:
            max_detail = np.max(np.stack(details), axis=0)
            min_v, max_v = np.min(max_detail), np.max(max_detail)
            norm_details = (max_detail - min_v) / (max_v - min_v + 1e-6) if max_v > min_v else np.zeros_like(max_detail)
            wavelet_mask_ = norm_details > current_config["SPECULAR_THRESHOLD_WAVELET"]
    except Exception as e: print(f"  Warn: Wavelet err: {e}.")

    combined = np.logical_or(hsv_mask_, wavelet_mask_)
    k_morph = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5,5))
    morphed = cv2.morphologyEx(cv2.morphologyEx(combined.astype(np.uint8), cv2.MORPH_CLOSE, k_morph), cv2.MORPH_OPEN, k_morph)
    
    final_m = np.zeros_like(morphed, dtype=np.uint8)
    contours, _ = cv2.findContours(morphed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for cnt in contours:
        if cv2.contourArea(cnt) > current_config["MIN_CONTOUR_AREA"]:
            cv2.drawContours(final_m, [cnt], -1, 1, thickness=cv2.FILLED)
    
    return dilate_mask(final_m, current_config["DILATION_KERNEL_SIZE"])

def get_user_tuned_parameters(initial_config, video_path):
    tuned_cfg = initial_config.copy()
    if "DILATION_KERNEL_SIZE" not in tuned_cfg:
        tuned_cfg["DILATION_KERNEL_SIZE"] = initial_config.get("DILATION_KERNEL_SIZE", 13) # Use the one in CONFIG

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened(): print(f"Error: No video for tuning: {video_path}"); return None
    
    sample_frames = []
    for _ in range(tuned_cfg["NUM_INITIAL_SAMPLES_FOR_TUNING"]):
        ret, frame = cap.read()
        if not ret: break
        sample_frames.append(frame)
    cap.release()

    if not sample_frames: print("No frames for tuning."); return None
    create_dir_if_not_exists(tuned_cfg["TUNING_SAMPLE_DIR"])

    while True:
        print("\n--- Parameter Tuning ---")
        print(f"  1. HSV Val Thresh (hsv_v): {tuned_cfg['SPECULAR_THRESHOLD_HSV_V']}")
        print(f"  2. HSV Sat Max    (hsv_s): {tuned_cfg['SPECULAR_THRESHOLD_HSV_S']}")
        print(f"  3. Wavelet Thresh (wavelet): {tuned_cfg['SPECULAR_THRESHOLD_WAVELET']:.2f}")
        print(f"  4. Min Contour Area (area): {tuned_cfg['MIN_CONTOUR_AREA']}")
        print(f"  5. Dilation Kernel (dilate): {tuned_cfg['DILATION_KERNEL_SIZE']}")
        
        clear_dir(tuned_cfg["TUNING_SAMPLE_DIR"])
        print(f"\nGen samples in '{tuned_cfg['TUNING_SAMPLE_DIR']}'...")
        for i, frame in enumerate(sample_frames):
            if frame is None: continue # Skip if a frame is None
            orig_s = cv2.resize(frame, tuned_cfg["TARGET_IMG_SIZE"])
            mask_os = detect_specular_regions(frame, tuned_cfg)
            mask_rs = cv2.resize(mask_os.astype(np.uint8)*255, tuned_cfg["TARGET_IMG_SIZE"], interpolation=cv2.INTER_NEAREST)
            mask_vis = cv2.cvtColor(mask_rs, cv2.COLOR_GRAY2BGR)
            high_s = orig_s.copy()
            high_s[mask_rs > 0] = [0, 0, 255] 
            cv2.imwrite(os.path.join(tuned_cfg["TUNING_SAMPLE_DIR"], f"tune_s_{i}.png"), np.concatenate((orig_s, high_s, mask_vis), axis=1))
        
        print(f"Review samples in '{tuned_cfg['TUNING_SAMPLE_DIR']}'.")
        choice = input("Acceptable? (y/n/q): ").lower()
        if choice == 'y': print("Params accepted."); return tuned_cfg
        if choice == 'q': print("Tuning quit."); return None
        if choice == 'n':
            print("\nAdjust (blank to keep current):")
            try:
                def _get_val(prompt, current, type_conv):
                    val = input(prompt(current))
                    return type_conv(val) if val else current
                tuned_cfg["SPECULAR_THRESHOLD_HSV_V"] = _get_val(lambda c: f"  HSV Val (cur {c}): ", tuned_cfg["SPECULAR_THRESHOLD_HSV_V"], int)
                tuned_cfg["SPECULAR_THRESHOLD_HSV_S"] = _get_val(lambda c: f"  HSV Sat (cur {c}): ", tuned_cfg["SPECULAR_THRESHOLD_HSV_S"], int)
                tuned_cfg["SPECULAR_THRESHOLD_WAVELET"] = _get_val(lambda c: f"  Wavelet (cur {c:.2f}): ", tuned_cfg["SPECULAR_THRESHOLD_WAVELET"], float)
                tuned_cfg["MIN_CONTOUR_AREA"] = _get_val(lambda c: f"  Min Area (cur {c}): ", tuned_cfg["MIN_CONTOUR_AREA"], int)
                tuned_cfg["DILATION_KERNEL_SIZE"] = _get_val(lambda c: f"  Dilation (cur {c}): ", tuned_cfg["DILATION_KERNEL_SIZE"], int)
            except ValueError: print("Invalid input type.")
        else: print("Invalid choice.")

def is_pixel_valid_temporal_source(pixel_bgr, config):
    """Checks if a BGR pixel is a valid 'tissue-like' source for temporal inpainting."""
    # Convert BGR to HSV for checking
    # Note: Converting single pixels repeatedly can be slow.
    # If performance is an issue, convert the whole neighbor_frame to HSV once.
    hsv_pixel_arr = cv2.cvtColor(np.uint8([[pixel_bgr]]), cv2.COLOR_BGR2HSV)
    h, s, v = hsv_pixel_arr[0][0]

    if v < config["MIN_BRIGHTNESS_FOR_TEMPORAL_SOURCE"]: return False
    if s < config["MIN_SATURATION_FOR_TEMPORAL_SOURCE"]: return False

    is_red = (config["HUE_RED_LOWER1"] <= h <= config["HUE_RED_UPPER1"]) or \
             (config["HUE_RED_LOWER2"] <= h <= config["HUE_RED_UPPER2"])
    return is_red


def temporally_inpaint_frame_with_validation(current_frame_idx, all_frames_list, all_masks_list, config_obj):
    current_frame = all_frames_list[current_frame_idx]
    current_specular_mask = all_masks_list[current_frame_idx]
    
    gt_frame = current_frame.copy()
    num_loaded_frames = len(all_frames_list)

    if not np.any(current_specular_mask):
        return gt_frame # No speculars to inpaint

    specular_coords_y, specular_coords_x = np.where(current_specular_mask)

    # For fallback inpainting, create a mask of pixels that couldn't be temporally inpainted
    unresolved_specular_mask_for_fallback = np.zeros_like(current_specular_mask, dtype=np.uint8)

    for r, c in zip(specular_coords_y, specular_coords_x):
        pixel_replaced = False
        
        # Search backward then forward, prioritizing closer frames
        search_order = []
        # Past frames
        for i in range(1, config_obj["TEMPORAL_SEARCH_WINDOW_PAST"] + 1):
            idx = current_frame_idx - i
            if idx >= 0: search_order.append(idx)
            else: break
        # Future frames
        for i in range(1, config_obj["TEMPORAL_SEARCH_WINDOW_FUTURE"] + 1):
            idx = current_frame_idx + i
            if idx < num_loaded_frames: search_order.append(idx)
            else: break
            
        for neighbor_idx in search_order:
            neighbor_frame = all_frames_list[neighbor_idx]
            # Check if the pixel in neighbor_frame is NOT specular AND is a valid color
            if not all_masks_list[neighbor_idx][r, c]:
                candidate_pixel_bgr = neighbor_frame[r, c]
                if is_pixel_valid_temporal_source(candidate_pixel_bgr, config_obj):
                    gt_frame[r, c] = candidate_pixel_bgr
                    pixel_replaced = True
                    break # Found a good replacement
            if pixel_replaced: break
        
        if not pixel_replaced:
            # Mark this pixel for fallback spatial inpainting
            unresolved_specular_mask_for_fallback[r, c] = 255 # OpenCV inpaint needs 255 for mask

    # Fallback for unresolved pixels: apply cv2.inpaint
    if np.any(unresolved_specular_mask_for_fallback):
        # Important: inpaint on the 'gt_frame' which might have some pixels already filled temporally
        # This ensures we don't overwrite good temporal fills with spatial ones.
        # However, cv2.inpaint works best if it has good boundaries from the original frame.
        # So, we inpaint on the original frame, but only apply the result to unresolved pixels.
        
        # Create a temporary frame for inpainting that starts as the original
        temp_inpaint_source = current_frame.copy()
        # But fill in already resolved pixels from gt_frame, so inpaint uses those as boundary if close
        # This gets complex. Simpler: Inpaint on original, apply only to unresolved.
        
        inpainted_fallback_region = cv2.inpaint(current_frame, # Source is original frame
                                                unresolved_specular_mask_for_fallback,
                                                config_obj["FALLBACK_INPAINT_RADIUS"],
                                                cv2.INPAINT_TELEA) # Or INPAINT_NS
        
        # Apply the fallback only to the pixels that were marked as unresolved
        gt_frame[unresolved_specular_mask_for_fallback == 255] = \
            inpainted_fallback_region[unresolved_specular_mask_for_fallback == 255]
            
    return gt_frame


def main():
    effective_config = get_user_tuned_parameters(CONFIG, CONFIG["VIDEO_PATH"])
    if effective_config is None: print("Exiting."); return
    
    print("\n--- Using Tuned Parameters: ---")
    for k in ["SPECULAR_THRESHOLD_HSV_V", "SPECULAR_THRESHOLD_HSV_S", "SPECULAR_THRESHOLD_WAVELET", "MIN_CONTOUR_AREA", "DILATION_KERNEL_SIZE"]:
        if k in effective_config: print(f"  {k}: {effective_config[k]}")
    print("---------------------------------")

    for dir_key in ["OUTPUT_ORIGINAL_DIR", "OUTPUT_GROUND_TRUTH_DIR", "DETECTION_SAMPLE_DIR"]:
        create_dir_if_not_exists(effective_config[dir_key])

    cap = cv2.VideoCapture(effective_config["VIDEO_PATH"])
    if not cap.isOpened(): print(f"Error: No video: {effective_config['VIDEO_PATH']}"); return
    total_frames_in_video = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print(f"Video frames: {total_frames_in_video}")

    print("Step 1: Reading all frames and detecting specular masks...")
    all_frames_list = []
    all_masks_list = []
    
    for i in tqdm(range(total_frames_in_video), desc="Reading & Masking Frames"):
        ret, frame = cap.read()
        if not ret: print(f"Warn: No frame {i}. Stop."); break
        all_frames_list.append(frame)
        # Detect mask using original frame size
        mask = detect_specular_regions(frame, effective_config)
        all_masks_list.append(mask)
    cap.release()
    
    if not all_frames_list: print("No frames read."); return
    num_loaded_frames = len(all_frames_list)
    print(f"Loaded {num_loaded_frames} frames and their masks.")

    print("\nStep 2: Performing temporal inpainting with validation for ground truth...")
    saved_frame_idx = 0
    for current_frame_idx in tqdm(range(num_loaded_frames), desc="Temporal Inpainting & Saving"):
        if effective_config["FRAME_SKIP"] > 0 and current_frame_idx % (effective_config["FRAME_SKIP"] + 1) != 0:
            continue

        original_frame = all_frames_list[current_frame_idx]
        
        # Perform temporal inpainting with validation
        gt_frame = temporally_inpaint_frame_with_validation(
            current_frame_idx, all_frames_list, all_masks_list, effective_config
        )

        orig_resized = cv2.resize(original_frame, effective_config["TARGET_IMG_SIZE"])
        gt_resized = cv2.resize(gt_frame, effective_config["TARGET_IMG_SIZE"])
        
        fname = f"frame_{saved_frame_idx:06d}.png"
        cv2.imwrite(os.path.join(effective_config["OUTPUT_ORIGINAL_DIR"], fname), orig_resized)
        cv2.imwrite(os.path.join(effective_config["OUTPUT_GROUND_TRUTH_DIR"], fname), gt_resized)

        if saved_frame_idx > 0 and saved_frame_idx % effective_config["SAMPLE_EVERY_N_FRAMES"] == 0:
            # For sample, use the mask that was generated for this frame
            specular_mask_for_sample = all_masks_list[current_frame_idx]
            mask_resized = cv2.resize(specular_mask_for_sample.astype(np.uint8)*255, 
                                      effective_config["TARGET_IMG_SIZE"], 
                                      interpolation=cv2.INTER_NEAREST)
            mask_vis = cv2.cvtColor(mask_resized, cv2.COLOR_GRAY2BGR)
            orig_sample_highlighted = orig_resized.copy()
            orig_sample_highlighted[mask_resized > 0] = [0, 0, 255]
            sample_display = np.concatenate((orig_resized, orig_sample_highlighted, gt_resized, mask_vis), axis=1)
            cv2.imwrite(os.path.join(effective_config["DETECTION_SAMPLE_DIR"], f"gt_gen_samp_{saved_frame_idx:06d}.png"), sample_display)
        
        saved_frame_idx += 1

    print(f"\nComplete. {saved_frame_idx} paired frames saved.")
    print(f"Originals: {effective_config['OUTPUT_ORIGINAL_DIR']}")
    print(f"Ground Truth: {effective_config['OUTPUT_GROUND_TRUTH_DIR']}")
    print(f"Detection Samples: {effective_config['DETECTION_SAMPLE_DIR']}")
    print(f"Tuning Samples were in: {effective_config['TUNING_SAMPLE_DIR']}")

if __name__ == "__main__":
    main()


--- Parameter Tuning ---
  1. HSV Val Thresh (hsv_v): 130
  2. HSV Sat Max    (hsv_s): 75
  3. Wavelet Thresh (wavelet): 0.05
  4. Min Contour Area (area): 3
  5. Dilation Kernel (dilate): 13

Gen samples in 'dataset/tuning_samples'...
Review samples in 'dataset/tuning_samples'.


Acceptable? (y/n/q):  y


Params accepted.

--- Using Tuned Parameters: ---
  SPECULAR_THRESHOLD_HSV_V: 130
  SPECULAR_THRESHOLD_HSV_S: 75
  SPECULAR_THRESHOLD_WAVELET: 0.05
  MIN_CONTOUR_AREA: 3
  DILATION_KERNEL_SIZE: 13
---------------------------------
Created directory: dataset/ground_truth_frames_temporal_v2
Created directory: dataset/detection_samples_temporal_v2
Video frames: 1548
Step 1: Reading all frames and detecting specular masks...


Reading & Masking Frames: 100%|██████████| 1548/1548 [08:44<00:00,  2.95it/s]


Loaded 1548 frames and their masks.

Step 2: Performing temporal inpainting with validation for ground truth...


Temporal Inpainting & Saving: 100%|██████████| 1548/1548 [30:42<00:00,  1.19s/it]


Complete. 1548 paired frames saved.
Originals: dataset/original_frames
Ground Truth: dataset/ground_truth_frames_temporal_v2
Detection Samples: dataset/detection_samples_temporal_v2
Tuning Samples were in: dataset/tuning_samples



