In [1]:
!pip install torch torchvision numpy opencv-python pywavelets pytorch_wavelets tqdm

Collecting pytorch_wavelets
  Downloading pytorch_wavelets-1.3.0-py3-none-any.whl.metadata (10.0 kB)
Downloading pytorch_wavelets-1.3.0-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.9/54.9 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pytorch_wavelets
Successfully installed pytorch_wavelets-1.3.0
[0m

In [3]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import pywt
import traceback
from tqdm import tqdm
from pytorch_wavelets import DWTForward

# ==================== CONFIGURATION ====================
# Change these parameters according to your needs
INPUT_DIR = "output_3_18/classified_clear"  # Directory containing blurry images
OUTPUT_DIR = "deblurred_output"  # Directory to save processed images
MODEL_PATH = "wavelet_unet_bn_best.pth"  # Path to model weights
PROCESS_EVERY_N = 100  # Process every Nth image
MONTAGE_COLS = 1  # Number of image pairs per row in montage
TARGET_HEIGHT = None  # Target height for processing (optional)
TARGET_WIDTH = None  # Target width for processing (optional)
BATCH_SIZE = 1  # Batch size for processing
WAVELET = 'db1'  # Wavelet type for preprocessing
# ==================== END CONFIGURATION ====================

# Wavelet-U-Net Model with BatchNorm
class WaveletUNet_BN(nn.Module):
    def __init__(self, in_channels=3, wavelet_channels=12):
        super().__init__()
        def conv_block(in_ch, out_ch):
            block = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True)
            )
            for m in block.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            return block

        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)
        self.enc4 = conv_block(256, 512)
        self.wavelet_enc1 = conv_block(wavelet_channels, 64)
        self.wavelet_enc2 = conv_block(64, 128)
        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec3 = conv_block(512 + 256 + 128, 256)
        self.dec2 = conv_block(256 + 128 + 64, 128)
        self.dec1 = conv_block(128 + 64, 64)
        self.final = nn.Conv2d(64, 3, kernel_size=1)
        nn.init.kaiming_normal_(self.final.weight, mode='fan_out', nonlinearity='linear')
        if self.final.bias is not None: 
            nn.init.constant_(self.final.bias, 0)

    def forward(self, x, wavelet):
        e1 = self.enc1(x); p1 = self.pool(e1)
        e2 = self.enc2(p1); p2 = self.pool(e2)
        e3 = self.enc3(p2); p3 = self.pool(e3)
        e4 = self.enc4(p3)
        w_feat2 = self.wavelet_enc1(wavelet); pw_feat2 = self.pool(w_feat2)
        w_feat3 = self.wavelet_enc2(pw_feat2)
        up3 = self.up(e4); cat3 = torch.cat([up3, e3, w_feat3], dim=1); d3 = self.dec3(cat3)
        up2 = self.up(d3); cat2 = torch.cat([up2, e2, w_feat2], dim=1); d2 = self.dec2(cat2)
        up1 = self.up(d2); cat1 = torch.cat([up1, e1], dim=1); d1 = self.dec1(cat1)
        out = self.final(d1)
        return torch.sigmoid(out)

# Helper function for wavelet input generation
def get_wavelet_input(img_tensor, wavelet='db1', device='cpu'):
    # Input: img_tensor [B, C, H, W]
    B, C, H, W = img_tensor.shape
    img_np = img_tensor.detach().cpu().numpy().transpose(0, 2, 3, 1)  # -> [B, H, W, C]
    batch_wavelets = []
    target_h_half, target_w_half = H // 2, W // 2
    target_ch_out = C * 4

    for i in range(img_np.shape[0]):
        try:
            coeffs = pywt.dwt2(img_np[i], wavelet, mode='symmetric', axes=(-3, -2))
            cA, (cH, cV, cD) = coeffs
            wavelet_np = np.concatenate([cA, cH, cV, cD], axis=2).astype(np.float32)  # [H/2, W/2, 4*C]

            # Handle potential size mismatches due to odd dimensions in DWT
            h_np, w_np = wavelet_np.shape[:2]
            if h_np != target_h_half or w_np != target_w_half:
                wavelet_np = cv2.resize(wavelet_np, (target_w_half, target_h_half), interpolation=cv2.INTER_LINEAR)
                if wavelet_np.ndim == 2: 
                    wavelet_np = wavelet_np[:, :, np.newaxis]  # Add channel dim back if lost
                if wavelet_np.shape[2] != target_ch_out:  # Fix channel count if necessary
                    print(f"Warning: Wavelet channel mismatch after resize ({wavelet_np.shape[2]} vs {target_ch_out}).")
                    if wavelet_np.shape[2] < target_ch_out:
                        padding = np.zeros((target_h_half, target_w_half, target_ch_out - wavelet_np.shape[2]), dtype=np.float32)
                        wavelet_np = np.concatenate([wavelet_np, padding], axis=2)
                    else:
                        wavelet_np = wavelet_np[:, :, :target_ch_out]

            # Normalize wavelet coefficients (per channel)
            for ch in range(wavelet_np.shape[2]):
                channel_data = wavelet_np[:, :, ch]
                std = channel_data.std()
                wavelet_np[:, :, ch] = (channel_data - channel_data.mean()) / (std + 1e-8)

            batch_wavelets.append(wavelet_np)

        except Exception as e:
            print(f"Error generating pywt input for item {i}: {e}. Using zeros.")
            zero_wavelet = np.zeros((target_h_half, target_w_half, target_ch_out), dtype=np.float32)
            batch_wavelets.append(zero_wavelet)

    wavelet_batch_np = np.stack(batch_wavelets)  # [B, H/2, W/2, 4*C]
    wavelet_tensor = torch.from_numpy(wavelet_batch_np).permute(0, 3, 1, 2).float()  # [B, 4*C, H/2, W/2]
    return wavelet_tensor.to(device)

# Utility function to load and preprocess an image for the model
def load_and_preprocess_image(img_path, target_size=None):
    img = cv2.imread(img_path)
    if img is None:
        base, ext = os.path.splitext(img_path)
        if not ext:
            for try_ext in [".png", ".jpg", ".jpeg"]:
                img = cv2.imread(img_path + try_ext)
                if img is not None:
                    break
        if img is None:
            raise ValueError(f"Failed to load image: {img_path}")

    if len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    elif img.shape[2] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
    elif img.shape[2] != 3:
        raise ValueError(f"Image {img_path} has unexpected shape {img.shape}")

    # Store original size for later
    original_h, original_w = img.shape[:2]
    
    # Resize if requested
    if target_size and (original_h != target_size[0] or original_w != target_size[1]):
        img = cv2.resize(img, (target_size[1], target_size[0]), interpolation=cv2.INTER_LINEAR)
    
    # Convert to RGB and normalize to [0,1]
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    
    # Convert to tensor [C,H,W]
    img_tensor = torch.from_numpy(img_rgb).permute(2, 0, 1).float()
    
    return img_tensor, original_h, original_w

# Function to process a single image
def process_single_image(model, image_path, device, target_size=None, wavelet=WAVELET):
    try:
        # Load and preprocess the image
        img_tensor, orig_h, orig_w = load_and_preprocess_image(image_path, target_size)
        
        # Add batch dimension and send to device
        img_batch = img_tensor.unsqueeze(0).to(device)
        
        # Generate wavelet input
        wavelet_batch = get_wavelet_input(img_batch, wavelet=wavelet, device=device)
        
        # Process with model
        with torch.no_grad():
            output_batch = model(img_batch, wavelet_batch)
        
        # Convert output to numpy image
        output_tensor = output_batch[0].cpu()
        output_np = output_tensor.permute(1, 2, 0).numpy()
        output_np = np.clip(output_np, 0, 1)
        
        # Convert to BGR for OpenCV and rescale to uint8
        output_bgr = cv2.cvtColor((output_np * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
        
        # Resize back to original dimensions if needed
        if target_size and (orig_h != target_size[0] or orig_w != target_size[1]):
            output_bgr = cv2.resize(output_bgr, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
        
        return output_bgr
    
    except Exception as e:
        print(f"Error processing image {image_path}: {e}")
        traceback.print_exc()
        return None

# Function to create comparison image
def create_comparison(input_path, output_image):
    # Read the original input image
    input_image = cv2.imread(input_path)
    if input_image is None:
        print(f"Error reading input image: {input_path}")
        return None
    
    # Ensure both images have the same dimensions
    if input_image.shape != output_image.shape:
        output_image = cv2.resize(output_image, (input_image.shape[1], input_image.shape[0]))
    
    # Create a side-by-side comparison image with labels
    h, w = input_image.shape[:2]
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.8
    font_thickness = 2
    
    # Add labels to the images
    cv2.putText(input_image, "Before", (10, 30), font, font_scale, (0, 0, 255), font_thickness)
    cv2.putText(output_image, "After", (10, 30), font, font_scale, (0, 255, 0), font_thickness)
    
    # Create the comparison image
    comparison = np.hstack((input_image, output_image))
    
    return comparison

# Function to create a montage of multiple comparisons
def create_montage(comparisons, cols=MONTAGE_COLS):
    if not comparisons:
        return None
    
    # Get the dimensions of the first comparison
    h, w = comparisons[0].shape[:2]
    
    # Calculate the number of rows needed
    rows = (len(comparisons) + cols - 1) // cols
    
    # Create an empty montage
    montage = np.zeros((rows * h, cols * w, 3), dtype=np.uint8)
    
    # Fill the montage with comparisons
    for i, comp in enumerate(comparisons):
        row = i // cols
        col = i % cols
        
        # Handle the case where the last row isn't full
        if i >= len(comparisons):
            break
            
        y_start = row * h
        y_end = y_start + h
        x_start = col * w
        x_end = x_start + w
        
        # Ensure the dimensions match (resize if needed)
        if comp.shape[:2] != (h, w):
            comp = cv2.resize(comp, (w, h))
            
        # Place the comparison in the montage
        montage[y_start:y_end, x_start:x_end] = comp
    
    return montage

# Main processing function
def process_images():
    # Calculate target size
    target_size = None
    if TARGET_HEIGHT and TARGET_WIDTH:
        target_size = (TARGET_HEIGHT, TARGET_WIDTH)
        print(f"Using target size: {target_size}")
    
    # Check if input directory exists
    if not os.path.isdir(INPUT_DIR):
        print(f"Error: Input directory {INPUT_DIR} does not exist.")
        return
    
    # Create output directory structure if it doesn't exist
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(os.path.join(OUTPUT_DIR, "individual"), exist_ok=True)
    
    # Set up device - use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Create model and load weights
    try:
        model = WaveletUNet_BN().to(device)
        
        # Check if model path has a specific format suffix
        if MODEL_PATH.endswith('.pth'):
            try:
                # Try loading as a state dict directly
                model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
                print(f"Loaded state dict from {MODEL_PATH}")
            except Exception as e:
                # If that fails, try loading as a checkpoint dictionary
                print(f"Direct loading failed: {e}. Trying checkpoint format...")
                checkpoint = torch.load(MODEL_PATH, map_location=device)
                if 'model_state_dict' in checkpoint:
                    model.load_state_dict(checkpoint['model_state_dict'])
                    print(f"Loaded checkpoint from {MODEL_PATH}")
                else:
                    raise ValueError(f"Could not find model state dict in checkpoint")
        else:
            raise ValueError(f"Model path should end with .pth")
        
        # Set model to evaluation mode
        model.eval()
        print("Model loaded successfully")
    except Exception as e:
        print(f"Error loading model: {e}")
        traceback.print_exc()
        return
    
    # Get all image files from input directory
    image_files = []
    for ext in ['.png', '.jpg', '.jpeg', '.bmp']:
        image_files.extend(sorted([os.path.join(INPUT_DIR, f) for f in os.listdir(INPUT_DIR) 
                            if f.lower().endswith(ext)]))
    
    if not image_files:
        print(f"No image files found in {INPUT_DIR}")
        return
    
    print(f"Found {len(image_files)} image files in {INPUT_DIR}")
    
    # Select every Nth image
    selected_images = image_files[::PROCESS_EVERY_N]
    print(f"Selected {len(selected_images)} images to process (every {PROCESS_EVERY_N}th image)")
    
    # Process images
    comparisons = []
    for i, img_path in enumerate(tqdm(selected_images, desc="Processing images")):
        try:
            # Process the image
            output_image = process_single_image(model, img_path, device, target_size, WAVELET)
            
            if output_image is None:
                print(f"Failed to process {img_path}, skipping...")
                continue
            
            # Save individual processed image
            img_name = os.path.basename(img_path)
            output_path = os.path.join(OUTPUT_DIR, "individual", f"processed_{img_name}")
            cv2.imwrite(output_path, output_image)
            
            # Create and save comparison
            comparison = create_comparison(img_path, output_image)
            if comparison is not None:
                comparison_path = os.path.join(OUTPUT_DIR, f"comparison_{img_name}")
                cv2.imwrite(comparison_path, comparison)
                comparisons.append(comparison)
            
        except Exception as e:
            print(f"Error processing {img_path}: {e}")
    
    # Create and save montage
    if comparisons:
        montage = create_montage(comparisons, cols=MONTAGE_COLS)
        if montage is not None:
            montage_path = os.path.join(OUTPUT_DIR, "montage.png")
            cv2.imwrite(montage_path, montage)
            print(f"Saved montage to {montage_path}")
    
    print("Processing complete!")
    print(f"Processed {len(comparisons)} images")
    print(f"Results saved to {OUTPUT_DIR}")

if __name__ == "__main__":
    process_images()

Using device: cuda
Direct loading failed: Error(s) in loading state_dict for WaveletUNet_BN:
	Missing key(s) in state_dict: "enc1.0.weight", "enc1.1.weight", "enc1.1.bias", "enc1.1.running_mean", "enc1.1.running_var", "enc1.3.weight", "enc1.4.weight", "enc1.4.bias", "enc1.4.running_mean", "enc1.4.running_var", "enc2.0.weight", "enc2.1.weight", "enc2.1.bias", "enc2.1.running_mean", "enc2.1.running_var", "enc2.3.weight", "enc2.4.weight", "enc2.4.bias", "enc2.4.running_mean", "enc2.4.running_var", "enc3.0.weight", "enc3.1.weight", "enc3.1.bias", "enc3.1.running_mean", "enc3.1.running_var", "enc3.3.weight", "enc3.4.weight", "enc3.4.bias", "enc3.4.running_mean", "enc3.4.running_var", "enc4.0.weight", "enc4.1.weight", "enc4.1.bias", "enc4.1.running_mean", "enc4.1.running_var", "enc4.3.weight", "enc4.4.weight", "enc4.4.bias", "enc4.4.running_mean", "enc4.4.running_var", "wavelet_enc1.0.weight", "wavelet_enc1.1.weight", "wavelet_enc1.1.bias", "wavelet_enc1.1.running_mean", "wavelet_enc1.1.runn

Processing images: 100%|██████████| 154/154 [00:14<00:00, 10.98it/s]


Saved montage to deblurred_output/montage.png
Processing complete!
Processed 154 images
Results saved to deblurred_output
