In [1]:
!pip install pytorch_wavelets

Collecting pytorch_wavelets
  Downloading pytorch_wavelets-1.3.0-py3-none-any.whl.metadata (10.0 kB)
Downloading pytorch_wavelets-1.3.0-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.9/54.9 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pytorch_wavelets
Successfully installed pytorch_wavelets-1.3.0
[0m

In [1]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pywt
from tqdm import tqdm
from pytorch_wavelets import DWTForward
import traceback
import random # Import random for sampling

# Custom Dataset for Paired Smoky/Clear Images
class ColonoscopyDataset(Dataset):
    """Paired (blurry/smoky input, clear target) images read from two directories.

    Files are paired by position after sorting each directory's file names, so both
    directories are expected to contain names that sort into the same order. When
    the file counts differ, only the first ``min(len)`` names of each sorted list are
    kept -- which mis-pairs files if the missing ones are not at the end.

    Items are ``(blurry, clear)`` float32 tensors of shape [3, H, W], RGB in [0, 1];
    the blurry image is resized to the clear image's size when the two differ.

    Args:
        clear_dir: Directory with the clear (target) images.
        blurry_dir: Directory with the blurry/smoky (input) images.
        wavelet: Stored as ``self.wavelet``; not used by this class itself.
        max_samples: If a positive int below the number of matched pairs, keep only
            the first ``max_samples`` pairs (in sorted order).

    Raises:
        FileNotFoundError: If either directory does not exist.
        ValueError: If either directory contains no regular files.
    """

    def __init__(self, clear_dir, blurry_dir, wavelet='db1', max_samples=None):
        try:
            # List every regular file first so the total count is known before limiting.
            all_clear_files = sorted(f for f in os.listdir(clear_dir) if os.path.isfile(os.path.join(clear_dir, f)))
            all_blurry_files = sorted(f for f in os.listdir(blurry_dir) if os.path.isfile(os.path.join(blurry_dir, f)))
        except FileNotFoundError as e:
            print(f"Error accessing directories: {e}")
            raise

        if not all_clear_files or not all_blurry_files:
            raise ValueError(f"No files found in directories: Clear={len(all_clear_files)}, Blurry={len(all_blurry_files)}")

        min_len = min(len(all_clear_files), len(all_blurry_files))
        if len(all_clear_files) != len(all_blurry_files):
            # NOTE(review): truncation pairs files by sorted position, not by name.
            print(f"Warning: Mismatched file counts. Using {min_len} pairs.")
        self.clear_files = all_clear_files[:min_len]
        self.blurry_files = all_blurry_files[:min_len]

        # Number of matched pairs BEFORE the max_samples limit is applied.
        self.total_available_samples = len(self.clear_files)

        if max_samples is not None and max_samples > 0 and self.total_available_samples > max_samples:
            print(f"Limiting training dataset from {self.total_available_samples} to {max_samples} samples.")
            self.clear_files = self.clear_files[:max_samples]
            self.blurry_files = self.blurry_files[:max_samples]
        else:
            print(f"Using {len(self.clear_files)} samples for this dataset instance.")

        self.clear_dir = clear_dir
        self.blurry_dir = blurry_dir
        self.wavelet = wavelet  # stored only; not read by this class

    def __len__(self):
        """Number of pairs used by THIS instance (after any max_samples limit)."""
        return len(self.clear_files)

    def get_total_available_samples(self):
        """Number of matched pairs found on disk, before the max_samples limit."""
        return self.total_available_samples

    def get_filenames_by_index(self, idx):
        """Return ``(clear_file, blurry_file)`` names for ``idx`` in this instance's lists.

        Raises:
            IndexError: If ``idx`` is not below ``len(self)``.
        """
        if idx >= len(self.clear_files):
            raise IndexError(f"Index {idx} out of bounds for current dataset size {len(self.clear_files)}")
        return self.clear_files[idx], self.blurry_files[idx]

    def _load_and_preprocess_image(self, img_path):
        """Read an image with OpenCV and return it as float32 RGB [H, W, 3] in [0, 1].

        If the path has no extension and cannot be read, ``.png``, ``.jpg`` and
        ``.jpeg`` are tried in turn.

        Raises:
            ValueError: If no readable image is found or it has an unexpected shape.
        """
        img = cv2.imread(img_path)
        if img is None:
            _, ext = os.path.splitext(img_path)
            if not ext:
                for try_ext in [".png", ".jpg", ".jpeg"]:
                    img = cv2.imread(img_path + try_ext)
                    if img is not None:
                        break
            if img is None:
                raise ValueError(f"Failed to load image: {img_path}")

        # Defensive channel normalisation to 3-channel BGR before the RGB conversion.
        if len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
        elif img.shape[2] != 3:
            raise ValueError(f"Image {img_path} has unexpected shape {img.shape}")

        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

    @staticmethod
    def _zero_pair():
        # Placeholder used when loading fails. The fixed 256x256 size is arbitrary;
        # NOTE(review): it will not collate with real images of a different size.
        dummy_tensor = torch.zeros((3, 256, 256), dtype=torch.float32)
        return dummy_tensor, dummy_tensor

    def __getitem__(self, idx):
        """Return ``(blurry, clear)`` float32 tensors [3, H, W] for pair ``idx``.

        Best effort: on any load error the error is printed and item 0 is returned
        instead (or an all-zero pair if item 0 itself fails), so one bad file does
        not abort an epoch. Note that this silently substitutes data.
        """
        try:
            clear_file, blurry_file = self.get_filenames_by_index(idx)
            clear_img_np = self._load_and_preprocess_image(os.path.join(self.clear_dir, clear_file))
            blurry_img_np = self._load_and_preprocess_image(os.path.join(self.blurry_dir, blurry_file))

            # Resize blurry to match clear if needed
            if clear_img_np.shape[:2] != blurry_img_np.shape[:2]:
                target_h, target_w = clear_img_np.shape[:2]
                blurry_img_np = cv2.resize(blurry_img_np, (target_w, target_h), interpolation=cv2.INTER_LINEAR)

            clear_img_tensor = torch.from_numpy(clear_img_np).permute(2, 0, 1).float()
            blurry_img_tensor = torch.from_numpy(blurry_img_np).permute(2, 0, 1).float()
            return blurry_img_tensor, clear_img_tensor
        except Exception as e:
            print(f"Error loading item at index {idx}: {e}")
            traceback.print_exc()
            if idx > 0:
                try:
                    return self.__getitem__(0)
                except Exception:
                    # Only ordinary errors are swallowed; KeyboardInterrupt/SystemExit
                    # must still propagate so training can be interrupted.
                    return self._zero_pair()
            return self._zero_pair()

# Wavelet-U-Net Model with BatchNorm
class WaveletUNet_BN(nn.Module):
    """U-Net with BatchNorm plus an auxiliary wavelet-coefficient encoder branch.

    forward(x, wavelet):
        x:       [B, in_channels, H, W] image batch.
        wavelet: [B, wavelet_channels, H // 2, W // 2] wavelet-coefficient batch
                 (its features are concatenated with the 1/2- and 1/4-resolution
                 image skips, so its spatial size must match ``pool(enc1(x))``).
    Returns a [B, 3, H, W] tensor passed through a sigmoid (values in [0, 1]).

    Each decoder stage upsamples to the exact spatial size of its skip connection,
    so H and W do not have to be multiples of 8.
    """

    def __init__(self, in_channels=3, wavelet_channels=12):
        super().__init__()

        def conv_block(in_ch, out_ch):
            # Two (3x3 conv -> BatchNorm -> ReLU) stages; convs are bias-free because
            # the following BatchNorm provides the shift. Kaiming init applied here.
            block = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True)
            )
            for m in block.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            return block

        # Image encoder at full, 1/2, 1/4 and 1/8 resolution.
        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)
        self.enc4 = conv_block(256, 512)
        # Wavelet encoder: first stage at the wavelet input's resolution, second after one pool.
        self.wavelet_enc1 = conv_block(wavelet_channels, 64)
        self.wavelet_enc2 = conv_block(64, 128)
        self.pool = nn.MaxPool2d(2)
        # Used by _up_to whenever the target size is exactly twice the input size.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # Decoder inputs = upsampled features + image skip (+ wavelet skip).
        self.dec3 = conv_block(512 + 256 + 128, 256)
        self.dec2 = conv_block(256 + 128 + 64, 128)
        self.dec1 = conv_block(128 + 64, 64)
        self.final = nn.Conv2d(64, 3, kernel_size=1)
        # Init final layer weights
        nn.init.kaiming_normal_(self.final.weight, mode='fan_out', nonlinearity='linear')
        if self.final.bias is not None:
            nn.init.constant_(self.final.bias, 0)

    def _up_to(self, x, ref):
        """Bilinearly upsample ``x`` to the spatial size of ``ref``.

        When ``ref`` is exactly twice ``x`` this is ``self.up(x)``; otherwise (odd
        sizes after max-pooling) it interpolates straight to ``ref``'s size so the
        following ``torch.cat`` does not fail on a one-pixel mismatch.
        """
        target_h, target_w = ref.shape[-2], ref.shape[-1]
        if x.shape[-2] * 2 == target_h and x.shape[-1] * 2 == target_w:
            return self.up(x)
        return nn.functional.interpolate(x, size=(target_h, target_w), mode='bilinear', align_corners=True)

    def forward(self, x, wavelet):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        w_feat2 = self.wavelet_enc1(wavelet)                 # same resolution as e2
        w_feat3 = self.wavelet_enc2(self.pool(w_feat2))      # same resolution as e3
        d3 = self.dec3(torch.cat([self._up_to(e4, e3), e3, w_feat3], dim=1))
        d2 = self.dec2(torch.cat([self._up_to(d3, e2), e2, w_feat2], dim=1))
        d1 = self.dec1(torch.cat([self._up_to(d2, e1), e1], dim=1))
        return torch.sigmoid(self.final(d1))  # Output in [0, 1] range


# Combined Loss with Differentiable DWT
class CombinedLossDWT(nn.Module):
    """Weighted sum of a pixel-space L1 loss and an L1 loss on DWT coefficients.

    ``loss = alpha * L1(pred, target) + (1 - alpha) * L1(DWT(pred), DWT(target))``
    where DWT is a single-level ``pytorch_wavelets.DWTForward`` whose approximation
    and three detail bands are concatenated along the channel axis. The DWT term is
    skipped when ``alpha >= 1``.

    Args:
        alpha: Weight of the pixel-space term.
        wavelet: Wavelet name passed to ``DWTForward``.
        device: Device for the DWT module and for inputs to it.
    """

    def __init__(self, alpha=0.85, wavelet='db1', device='cpu'):
        super().__init__()
        self.alpha = alpha
        self.l1_loss = nn.L1Loss()
        # Ensure DWTForward is created only once and on the correct device
        self.dwt = DWTForward(J=1, wave=wavelet, mode='symmetric').to(device)
        self.device = device

    def _compute_dwt(self, x):
        """Return [B, 4*C, h, w]: LL band followed by the C*3 detail channels."""
        x = x.to(self.device)
        LL, H_coeffs = self.dwt(x)
        # For J=1, H_coeffs[0] is [B, C, 3, h, w] (the three detail orientations).
        details = H_coeffs[0]
        B, C, _, H_d, W_d = details.shape
        details_reshaped = details.reshape(B, C * 3, H_d, W_d)
        return torch.cat([LL, details_reshaped], dim=1)

    def forward(self, pred, target):
        """Compute the combined loss.

        If the result is NaN, the inputs' component losses are printed and the NaN
        loss is returned unchanged, so the caller can detect it (``torch.isnan``)
        and skip the optimisation step. A substitute constant would carry no
        gradient and would distort any running/average loss used for scheduling.
        """
        pred, target = pred.float(), target.float()
        spatial_loss = self.l1_loss(pred, target)

        # Compute DWT only if alpha < 1 (frequency loss is used)
        if self.alpha < 1.0:
            try:
                pred_wavelet = self._compute_dwt(pred)
                target_wavelet = self._compute_dwt(target)
                freq_loss = self.l1_loss(pred_wavelet, target_wavelet)
            except Exception as e:
                # Best effort: fall back to the pixel term only rather than crash.
                print(f"Error computing DWT loss: {e}. Setting freq_loss to 0.")
                freq_loss = torch.tensor(0.0, device=self.device)
        else:
            freq_loss = torch.tensor(0.0, device=self.device)  # No frequency loss if alpha is 1

        total_loss = self.alpha * spatial_loss + (1 - self.alpha) * freq_loss

        if torch.isnan(total_loss):
            print(f"NaN loss! Spatial: {spatial_loss.item()}, Freq: {freq_loss.item()}")

        return total_loss

# Helper function for non-differentiable wavelet INPUT generation
def get_wavelet_input(img_tensor, wavelet='db1', device='cpu'):
    """Build the model's wavelet input: a per-image normalised single-level 2D DWT.

    Computed on the CPU with PyWavelets from a detached copy, so it is not
    differentiable.

    Args:
        img_tensor: [B, C, H, W] image batch.
        wavelet: PyWavelets wavelet name.
        device: Device the returned tensor is moved to.

    Returns:
        float32 tensor [B, 4*C, H // 2, W // 2]: the approximation band followed by
        the three detail bands (C channels each). Every channel is standardised to
        zero mean / unit std within its own image. An image whose transform fails
        is replaced by all-zero coefficients (the error is printed).
    """
    _, n_channels, height, width = img_tensor.shape
    half_h, half_w = height // 2, width // 2
    n_out = n_channels * 4
    # [B, C, H, W] -> [B, H, W, C] so pywt/cv2 see channel-last images.
    images_hwc = img_tensor.detach().cpu().numpy().transpose(0, 2, 3, 1)

    coeff_maps = []
    for item_idx, image in enumerate(images_hwc):
        try:
            approx, (horiz, vert, diag) = pywt.dwt2(image, wavelet, mode='symmetric', axes=(-3, -2))
            stacked = np.concatenate([approx, horiz, vert, diag], axis=2).astype(np.float32)

            # Odd sizes / longer filters yield slightly larger bands; force H//2 x W//2.
            if stacked.shape[:2] != (half_h, half_w):
                stacked = cv2.resize(stacked, (half_w, half_h), interpolation=cv2.INTER_LINEAR)
                if stacked.ndim == 2:
                    stacked = stacked[:, :, np.newaxis]
                got = stacked.shape[2]
                if got != n_out:
                    print(f"Warning: Wavelet channel mismatch after resize ({got} vs {n_out}).")
                    if got < n_out:
                        pad = np.zeros((half_h, half_w, n_out - got), dtype=np.float32)
                        stacked = np.concatenate([stacked, pad], axis=2)
                    else:
                        stacked = stacked[:, :, :n_out]

            # Standardise each channel independently (epsilon guards constant planes).
            for c in range(stacked.shape[2]):
                plane = stacked[:, :, c]
                sigma = plane.std()
                stacked[:, :, c] = (plane - plane.mean()) / (sigma + 1e-8)

            coeff_maps.append(stacked)
        except Exception as e:
            print(f"Error generating pywt input for item {item_idx}: {e}. Using zeros.")
            coeff_maps.append(np.zeros((half_h, half_w, n_out), dtype=np.float32))

    # [B, H/2, W/2, 4C] -> [B, 4C, H/2, W/2]
    return torch.from_numpy(np.stack(coeff_maps)).permute(0, 3, 1, 2).float().to(device)

# Function to Load Specific Image Pairs by Indices
def load_image_pairs_by_indices(indices, clear_dir, blurry_dir, device='cpu', wavelet='db1'):
    """Load specific (blurry, clear) pairs by index into the full, unlimited pair list.

    Args:
        indices: Iterable of int indices into the full sorted pair list (no
            ``max_samples`` limit is applied here).
        clear_dir: Directory with the clear (target) images.
        blurry_dir: Directory with the blurry (input) images.
        device: Device the returned tensors are moved to.
        wavelet: Wavelet used to build the returned wavelet-input batch
            (defaults to 'db1', the same default as ``get_wavelet_input``).

    Returns:
        ``(blurry_batch, wavelet_input_batch, clear_batch, loaded_indices)`` where the
        batches are stacked [N, C, H, W] tensors, or ``(None, None, None, [])`` if
        nothing was loaded. Indices outside ``[0, total)`` are skipped with a warning.
        NOTE(review): all loaded images must share one size, or ``torch.stack`` raises.
    """
    full_dataset = ColonoscopyDataset(clear_dir, blurry_dir, wavelet=wavelet, max_samples=None)
    total_available = full_dataset.get_total_available_samples()

    loaded_blurry_tensors = []
    loaded_clear_tensors = []
    loaded_indices = []

    for idx in indices:
        # Negative indices would otherwise wrap around to the end of the file list.
        if idx < 0 or idx >= total_available:
            print(f"Warning: Requested sample index {idx} is out of bounds ({total_available}). Skipping.")
            continue
        try:
            blurry_tensor, clear_tensor = full_dataset[idx]
        except Exception as e:
            print(f"Error loading sample pair at index {idx}: {e}. Skipping.")
            continue
        loaded_blurry_tensors.append(blurry_tensor)
        loaded_clear_tensors.append(clear_tensor)
        loaded_indices.append(idx)

    if not loaded_blurry_tensors:
        return None, None, None, []

    blurry_batch = torch.stack(loaded_blurry_tensors).to(device)
    clear_batch = torch.stack(loaded_clear_tensors).to(device)
    wavelet_input_batch = get_wavelet_input(blurry_batch, wavelet=wavelet, device=device)

    return blurry_batch, wavelet_input_batch, clear_batch, loaded_indices

# Utility function to convert tensor batch to visualizable numpy images
def tensors_to_cv2_images(tensor_batch):
    """Convert a [B, 3, H, W] RGB tensor batch in [0, 1] to a list of BGR uint8 images.

    Values are clipped to [0, 1] and rounded to the nearest 8-bit level (``np.rint``,
    round-half-to-even). Plain truncation would map e.g. 0.99999994 * 255 to 254,
    so images decoded as ``uint8 / 255`` may not round-trip exactly.

    Returns:
        A list of B ``uint8`` arrays [H, W, 3] in BGR order (for ``cv2.imwrite``),
        or an empty list when ``tensor_batch`` is None.
    """
    if tensor_batch is None:
        return []
    batch_hwc = tensor_batch.detach().cpu().numpy().transpose(0, 2, 3, 1)  # B, H, W, C
    images = []
    for img_rgb in batch_hwc:
        img_uint8 = np.rint(np.clip(img_rgb, 0.0, 1.0) * 255.0).astype(np.uint8)
        images.append(cv2.cvtColor(img_uint8, cv2.COLOR_RGB2BGR))
    return images

# --- Training Function (MODIFIED Sampling Section) ---
def _select_fixed_sample_indices(total_available_samples, max_samples_train, num_samples_to_show, rng):
    """Choose the indices of the fixed pairs visualised after every epoch.

    Training uses the first ``max_samples_train`` pairs (all pairs when it is None or
    not positive), so indices are drawn from the pairs after that slice. If fewer than
    ``num_samples_to_show`` remain, all remaining ones are used; if none remain, the
    last ``num_samples_to_show`` pairs overall (which are training pairs) are used.

    Args:
        rng: Object with a ``sample(population, k)`` method (``random`` or a
            ``random.Random`` instance).
    """
    if max_samples_train is not None and max_samples_train > 0:
        train_set_end_idx = max_samples_train - 1
    else:
        train_set_end_idx = total_available_samples - 1
    sampling_start_idx = train_set_end_idx + 1
    sampling_end_idx = total_available_samples - 1  # Last available index overall
    num_available_for_sampling = max(0, sampling_end_idx - sampling_start_idx + 1)

    if num_available_for_sampling >= num_samples_to_show:
        print(f"Selecting {num_samples_to_show} fixed random indices for sampling from range [{sampling_start_idx}, {sampling_end_idx}].")
        return rng.sample(range(sampling_start_idx, sampling_end_idx + 1), num_samples_to_show)
    if num_available_for_sampling > 0:
        print(f"Warning: Only {num_available_for_sampling} samples available outside training set. Using all.")
        return list(range(sampling_start_idx, sampling_end_idx + 1))
    print("Warning: No samples available outside the training set for sampling. Using last available samples instead.")
    fallback_start = max(0, total_available_samples - num_samples_to_show)
    return list(range(fallback_start, total_available_samples))


def _resize_img_list(img_list, target_w, target_h):
    """Resize every image whose (H, W) differs from (target_h, target_w)."""
    return [cv2.resize(img, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
            if img.shape[:2] != (target_h, target_w) else img
            for img in img_list]


def _save_sample_montage(model, sample_blurry_batch, sample_wavelet_batch, sample_clear_batch, samples_dir, epoch):
    """Save ``sample_montage_epoch_XXX.png``: clear (top), prediction, blurry input (bottom).

    Any failure is printed and swallowed so it never stops training, and the model is
    always put back into train mode. Returning early here (instead of ``continue`` in
    the epoch loop) lets the caller still save that epoch's checkpoint.
    """
    try:
        model.eval()
        with torch.no_grad():
            pred_sample_batch = model(sample_blurry_batch, sample_wavelet_batch)

        clear_cv2 = tensors_to_cv2_images(sample_clear_batch)    # Row 1: Ground Truth Clear
        pred_cv2 = tensors_to_cv2_images(pred_sample_batch)      # Row 2: Model Output
        blurry_cv2 = tensors_to_cv2_images(sample_blurry_batch)  # Row 3: Original Blurry Input

        # Use the prediction's size (or the clear image's) as the common tile size.
        if pred_cv2:
            target_h, target_w = pred_cv2[0].shape[:2]
        elif clear_cv2:
            target_h, target_w = clear_cv2[0].shape[:2]
        else:
            print(f"Warning: Cannot determine target size for montage epoch {epoch+1}. Skipping.")
            return

        resized_clear = _resize_img_list(clear_cv2, target_w, target_h)
        resized_pred = _resize_img_list(pred_cv2, target_w, target_h)
        resized_blurry = _resize_img_list(blurry_cv2, target_w, target_h)

        if not (resized_clear and resized_pred and resized_blurry):
            print(f"Warning: Could not generate montage for epoch {epoch+1}, one or more image lists were empty after processing.")
            return

        row1 = cv2.hconcat(resized_clear)
        row2 = cv2.hconcat(resized_pred)
        row3 = cv2.hconcat(resized_blurry)
        padding_height = 10
        padding = np.zeros((padding_height, row1.shape[1], 3), dtype=np.uint8)  # Black separator
        montage = cv2.vconcat([row1, padding.copy(), row2, padding.copy(), row3])

        sample_path = os.path.join(samples_dir, f"sample_montage_epoch_{epoch+1:03d}.png")
        cv2.imwrite(sample_path, montage)
    except Exception as e:
        print(f"Error generating sample montage for epoch {epoch+1}: {e}")
        traceback.print_exc()
    finally:
        model.train()


def train_model(clear_dir, blurry_dir, output_dir, epochs=50, batch_size=4, lr=0.0005,
                wavelet='db1', alpha=0.85, weight_decay=1e-5, max_samples_train=None,
                sample_seed=None):
    """Train ``WaveletUNet_BN`` on paired blurry/clear images and save outputs.

    Writes a montage of fixed monitoring samples every epoch to
    ``<output_dir>/samples``, a checkpoint every 10 epochs and after the last epoch
    to ``<output_dir>/checkpoints``, and the final weights to
    ``<output_dir>/wavelet_unet_bn_final.pth``.

    Args:
        clear_dir: Directory with the clear (target) images.
        blurry_dir: Directory with the blurry (input) images.
        output_dir: Root directory for samples, checkpoints and the final model.
        epochs, batch_size, lr, weight_decay: Training hyper-parameters (AdamW).
        wavelet: Wavelet for both the model's wavelet input and the DWT loss term.
        alpha: Weight of the pixel-space L1 term in ``CombinedLossDWT``.
        max_samples_train: Train only on the first N pairs; monitoring samples are
            then drawn from the remaining pairs.
        sample_seed: Optional seed for choosing the monitoring samples; None uses the
            global ``random`` module state.

    Returns:
        The trained model, or None if the dataset could not be initialised or is empty.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = WaveletUNet_BN().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = CombinedLossDWT(alpha=alpha, wavelet=wavelet, device=device)
    # `verbose=` is deprecated in recent PyTorch; LR reductions are logged manually below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

    try:
        train_dataset = ColonoscopyDataset(clear_dir, blurry_dir, wavelet=wavelet, max_samples=max_samples_train)
        if len(train_dataset) == 0:
            print("ERROR: Training dataset is empty!")
            return None
        total_available_samples = train_dataset.get_total_available_samples()  # count before the limit
    except Exception as e:
        print(f"Error initializing dataset: {e}")
        return None

    cpu_total = os.cpu_count() or 1  # os.cpu_count() may return None
    num_w = min(4, cpu_total // 2 if cpu_total > 1 else 1)
    print(f"Using {num_w} dataloader workers.")
    dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_w,
                            pin_memory=(device.type == 'cuda'), drop_last=True)
    if len(dataloader) == 0:
        print(f"Warning: {len(train_dataset)} samples with batch_size={batch_size} and drop_last=True gives no training batches.")

    samples_dir = os.path.join(output_dir, "samples")
    checkpoints_dir = os.path.join(output_dir, "checkpoints")
    os.makedirs(samples_dir, exist_ok=True)
    os.makedirs(checkpoints_dir, exist_ok=True)

    # --- Fixed monitoring samples (preferably outside the training slice) ---
    rng = random.Random(sample_seed) if sample_seed is not None else random
    fixed_sample_indices = _select_fixed_sample_indices(total_available_samples, max_samples_train, 5, rng)
    print(f"Using fixed sample indices: {fixed_sample_indices}")

    sample_blurry_batch, sample_wavelet_batch, sample_clear_batch, loaded_indices = None, None, None, []
    if fixed_sample_indices:
        try:
            sample_blurry_batch, sample_wavelet_batch, sample_clear_batch, loaded_indices = \
                load_image_pairs_by_indices(fixed_sample_indices, clear_dir, blurry_dir, device=device)
            if sample_blurry_batch is None:
                print("ERROR: Failed to load any sample images. Disabling sampling.")
                fixed_sample_indices = []
            else:
                # load_image_pairs_by_indices is not given the training wavelet, so
                # rebuild the sample wavelet input with the wavelet used in training.
                sample_wavelet_batch = get_wavelet_input(sample_blurry_batch, wavelet=wavelet, device=device)
                if len(loaded_indices) != len(fixed_sample_indices):
                    print(f"Warning: Loaded {len(loaded_indices)} sample images, but requested {len(fixed_sample_indices)}. Using loaded ones.")
                    fixed_sample_indices = loaded_indices
        except Exception as e:
            print(f"ERROR loading initial sample images: {e}. Disabling sampling.")
            fixed_sample_indices = []

    # Save an untrained prediction of the first sample as a sanity check.
    if fixed_sample_indices and sample_blurry_batch is not None:
        try:
            model.eval()
            with torch.no_grad():
                initial_pred = model(sample_blurry_batch[0:1], sample_wavelet_batch[0:1])
            initial_pred_img = tensors_to_cv2_images(initial_pred)[0]
            cv2.imwrite(os.path.join(samples_dir, "initial_pred_test.png"), initial_pred_img)
            print("Saved initial prediction sample.")
        except Exception as e:
            print(f"Error generating initial prediction: {e}")

    print(f"Training dataset size: {len(train_dataset)}")
    print(f"Total available pairs: {total_available_samples}")

    print("Starting training...")
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0  # batches that completed a full optimisation step
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", leave=True)

        for i, batch in enumerate(progress_bar):
            if batch is None:
                continue
            try:
                blurry_img, clear_img = batch
            except Exception as e:
                print(f"Error unpacking batch {i}: {e}")
                continue
            # Validate before moving to the device (a non-tensor has no .to()).
            if not isinstance(blurry_img, torch.Tensor) or not isinstance(clear_img, torch.Tensor):
                continue
            if blurry_img.ndim != 4 or clear_img.ndim != 4:
                continue
            blurry_img = blurry_img.to(device)
            clear_img = clear_img.to(device)

            try:
                wavelet_input = get_wavelet_input(blurry_img, wavelet=wavelet, device=device)
            except Exception as e:
                print(f"Error generating wavelet input batch {i}: {e}")
                continue

            optimizer.zero_grad()
            try:
                pred = model(blurry_img, wavelet_input)
            except Exception as e:
                print(f"\nForward pass error batch {i}: {e}")
                traceback.print_exc()
                continue

            try:
                loss = criterion(pred, clear_img)
            except Exception as e:
                print(f"\nLoss calculation error batch {i}: {e}")
                traceback.print_exc()
                continue

            if torch.isnan(loss):
                print(f"NaN loss detected at batch {i}, skipping backward/step.")
                optimizer.zero_grad()
                continue

            try:
                loss.backward()
            except Exception as e:
                print(f"\nBackward pass error batch {i}: {e}")
                traceback.print_exc()
                continue

            try:
                optimizer.step()
            except Exception as e:
                print(f"\nOptimizer step error batch {i}: {e}")
                traceback.print_exc()
                continue

            current_loss = loss.item()
            running_loss += current_loss
            num_batches += 1
            progress_bar.set_postfix(loss=f"{current_loss:.6f}", avg_loss=f"{running_loss/num_batches:.6f}", lr=f"{optimizer.param_groups[0]['lr']:.1e}")

        # Average only over batches that actually trained; skipped ones would bias it.
        if num_batches > 0:
            avg_loss = running_loss / num_batches
            print(f"\nEpoch {epoch+1}/{epochs} finished. Average Training Loss: {avg_loss:.6f}")
            lr_before = optimizer.param_groups[0]['lr']
            scheduler.step(avg_loss)
            lr_after = optimizer.param_groups[0]['lr']
            if lr_after < lr_before:
                print(f"Learning rate reduced from {lr_before:.1e} to {lr_after:.1e}.")
        else:
            # Feeding a fake 0.0 to the scheduler would register it as the best loss.
            avg_loss = float('nan')
            print(f"\nEpoch {epoch+1}/{epochs} finished with no successful batches; skipping scheduler step.")

        if fixed_sample_indices and sample_blurry_batch is not None and sample_clear_batch is not None:
            _save_sample_montage(model, sample_blurry_batch, sample_wavelet_batch, sample_clear_batch, samples_dir, epoch)

        if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
            checkpoint_path = os.path.join(checkpoints_dir, f"wavelet_unet_bn_epoch_{epoch+1:03d}.pth")
            try:
                torch.save({
                    'epoch': epoch + 1, 'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(),
                    'loss': avg_loss,
                }, checkpoint_path)
                print(f"Saved checkpoint to {checkpoint_path}")
            except Exception as e:
                print(f"Error saving checkpoint: {e}")

    print("Training finished.")
    final_model_path = os.path.join(output_dir, "wavelet_unet_bn_final.pth")
    try:
        torch.save(model.state_dict(), final_model_path)
        print(f"Saved final model to {final_model_path}")
    except Exception as e:
        print(f"Error saving final model: {e}")
    return model



if __name__ == "__main__":
    # Paired input/target folders, and the root folder for samples/checkpoints/final model.
    clear_dir = "output_3_18/smokeV2/clear"
    blurry_dir = "output_3_18/smokeV2/blurry"
    output_dir = "output_3_18/smoke/output_V2"

    os.makedirs(output_dir, exist_ok=True)
    for label, path in (("Clear image directory", clear_dir),
                        ("Blurry image directory", blurry_dir),
                        ("Output directory", output_dir)):
        print(f"{label}: {os.path.abspath(path)}")

    if os.path.isdir(clear_dir) and os.path.isdir(blurry_dir):
        try:
            # Train on the first 5000 pairs; the remainder supplies the monitoring samples.
            model = train_model(
                clear_dir, blurry_dir, output_dir,
                epochs=50,
                batch_size=8,
                lr=0.0005,
                wavelet='db1',
                alpha=0.85,
                weight_decay=1e-5,
                max_samples_train=5000,
            )
            print("Training process completed." if model else "Training process failed.")
        except Exception as e:
            print(f"An unexpected error occurred during script execution: {e}")
            traceback.print_exc()
    else:
        print("Error: Directory not found - Check paths!")

Clear image directory: /notebooks/output_3_18/smokeV2/clear
Blurry image directory: /notebooks/output_3_18/smokeV2/blurry
Output directory: /notebooks/output_3_18/smoke/output_V2
Using device: cuda
Limiting training dataset from 10212 to 5000 samples.
Using 4 dataloader workers.
Selecting 5 fixed random indices for sampling from range [5000, 10211].
Using fixed sample indices: [8009, 8703, 6311, 9161, 6144]
Using 10212 samples for this dataset instance.
Saved initial prediction sample.
Training dataset size: 5000
Total available pairs: 10212
Starting training...


Epoch 1/50: 100%|██████████| 625/625 [13:28<00:00,  1.29s/it, avg_loss=0.015049, loss=0.008244, lr=5.0e-04]



Epoch 1/50 finished. Average Training Loss: 0.015049


Epoch 2/50: 100%|██████████| 625/625 [13:47<00:00,  1.32s/it, avg_loss=0.007532, loss=0.006601, lr=5.0e-04]



Epoch 2/50 finished. Average Training Loss: 0.007532


Epoch 3/50: 100%|██████████| 625/625 [13:35<00:00,  1.30s/it, avg_loss=0.006383, loss=0.005555, lr=5.0e-04]



Epoch 3/50 finished. Average Training Loss: 0.006383


Epoch 4/50: 100%|██████████| 625/625 [13:54<00:00,  1.34s/it, avg_loss=0.006428, loss=0.005662, lr=5.0e-04]



Epoch 4/50 finished. Average Training Loss: 0.006428


Epoch 5/50: 100%|██████████| 625/625 [13:44<00:00,  1.32s/it, avg_loss=0.005642, loss=0.004905, lr=5.0e-04]



Epoch 5/50 finished. Average Training Loss: 0.005642


Epoch 6/50: 100%|██████████| 625/625 [13:51<00:00,  1.33s/it, avg_loss=0.005253, loss=0.004955, lr=5.0e-04]



Epoch 6/50 finished. Average Training Loss: 0.005253


Epoch 7/50: 100%|██████████| 625/625 [13:49<00:00,  1.33s/it, avg_loss=0.004848, loss=0.004565, lr=5.0e-04]



Epoch 7/50 finished. Average Training Loss: 0.004848


Epoch 8/50: 100%|██████████| 625/625 [13:46<00:00,  1.32s/it, avg_loss=0.004682, loss=0.004280, lr=5.0e-04]



Epoch 8/50 finished. Average Training Loss: 0.004682


Epoch 9/50: 100%|██████████| 625/625 [14:06<00:00,  1.36s/it, avg_loss=0.004364, loss=0.004044, lr=5.0e-04]



Epoch 9/50 finished. Average Training Loss: 0.004364


Epoch 10/50:  34%|███▍      | 211/625 [04:41<09:12,  1.33s/it, avg_loss=0.004236, loss=0.003780, lr=5.0e-04]
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7fc761428e10>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


KeyboardInterrupt: 