In [1]:
!pip install pytorch_wavelets

Collecting pytorch_wavelets
  Downloading pytorch_wavelets-1.3.0-py3-none-any.whl.metadata (10.0 kB)
Downloading pytorch_wavelets-1.3.0-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.9/54.9 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pytorch_wavelets
Successfully installed pytorch_wavelets-1.3.0
[0m

In [2]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import pywt
from tqdm import tqdm
from pytorch_wavelets import DWTForward # Keep this for the loss
import traceback
import random
import csv
import pandas as pd
from datetime import datetime
import sys # Needed for sys.exit
import matplotlib.pyplot as plt # Needed for plotting

# Make sure training is reproducible
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Custom Dataset for Paired Smoky/Clear Images
class ColonoscopyDataset(Dataset):
    def __init__(self, clear_dir, blurry_dir, wavelet='db1'):
        try:
            # Load all files first
            all_clear_files = sorted([f for f in os.listdir(clear_dir) if os.path.isfile(os.path.join(clear_dir, f))])
            all_blurry_files = sorted([f for f in os.listdir(blurry_dir) if os.path.isfile(os.path.join(blurry_dir, f))])
        except FileNotFoundError as e:
            print(f"Error accessing directories: {e}")
            raise e

        if not all_clear_files or not all_blurry_files:
            raise ValueError(f"No files found in directories: Clear={len(all_clear_files)}, Blurry={len(all_blurry_files)}")

        min_len = min(len(all_clear_files), len(all_blurry_files))
        if len(all_clear_files) != len(all_blurry_files):
            print(f"Warning: Mismatched file counts ({len(all_clear_files)} clear vs {len(all_blurry_files)} blurry). Using {min_len} pairs.")
            self.clear_files = all_clear_files[:min_len]
            self.blurry_files = all_blurry_files[:min_len]
        else:
            self.clear_files = all_clear_files
            self.blurry_files = all_blurry_files

        self.total_samples = len(self.clear_files)
        print(f"Total paired samples available for this dataset instance: {self.total_samples}") # Clarified print

        self.clear_dir = clear_dir
        self.blurry_dir = blurry_dir
        self.wavelet = wavelet

    def __len__(self):
        return len(self.clear_files)

    def get_filenames_by_index(self, idx):
        if idx >= len(self.clear_files):
            raise IndexError(f"Index {idx} out of bounds for dataset size {len(self.clear_files)}")
        return self.clear_files[idx], self.blurry_files[idx]

    def _load_and_preprocess_image(self, img_path):
        img = cv2.imread(img_path)
        if img is None:
            # Try adding common extensions if load fails initially
            base, ext = os.path.splitext(img_path)
            if not ext:
                for try_ext in [".png", ".jpg", ".jpeg", ".bmp", ".tiff"]:
                    try_path = img_path + try_ext
                    if os.path.exists(try_path):
                       img = cv2.imread(try_path)
                       if img is not None:
                           # print(f"Successfully loaded {try_path}") # Optional debug
                           break
            if img is None:
                raise ValueError(f"Failed to load image: {img_path} (even with added extensions)")

        if len(img.shape) == 2: # Grayscale
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4: # BGRA
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
        elif img.shape[2] != 3: # Unexpected channels
            raise ValueError(f"Image {img_path} has unexpected shape {img.shape}")

        # Ensure output is RGB, float32, [0, 1] range
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        return img

    def __getitem__(self, idx):
        try:
            clear_file, blurry_file = self.get_filenames_by_index(idx)
            clear_img_path = os.path.join(self.clear_dir, clear_file)
            blurry_img_path = os.path.join(self.blurry_dir, blurry_file)

            clear_img_np = self._load_and_preprocess_image(clear_img_path)
            blurry_img_np = self._load_and_preprocess_image(blurry_img_path)

            # Resize blurry to match clear if needed (important!)
            if clear_img_np.shape[:2] != blurry_img_np.shape[:2]:
                target_h, target_w = clear_img_np.shape[:2]
                # print(f"Warning: Resizing blurry image {blurry_file} from {blurry_img_np.shape[:2]} to match clear image {clear_file} { (target_h, target_w)}") # Optional warning
                blurry_img_np = cv2.resize(blurry_img_np, (target_w, target_h), interpolation=cv2.INTER_LINEAR)

            clear_img_tensor = torch.from_numpy(clear_img_np).permute(2, 0, 1).float()
            blurry_img_tensor = torch.from_numpy(blurry_img_np).permute(2, 0, 1).float()

            return blurry_img_tensor, clear_img_tensor
        except Exception as e:
            print(f"-----------------------------------")
            print(f"FATAL Error loading item at index {idx}:")
            print(f"Clear file attempted: {os.path.join(self.clear_dir, self.clear_files[idx]) if idx < len(self.clear_files) else 'Index out of bounds'}")
            print(f"Blurry file attempted: {os.path.join(self.blurry_dir, self.blurry_files[idx]) if idx < len(self.blurry_files) else 'Index out of bounds'}")
            print(f"Error details: {e}")
            traceback.print_exc()
            print(f"-----------------------------------")
            # Try returning a neighboring item or a dummy tensor to avoid crashing the loader if possible
            # Return dummy tensor if error occurs, allowing training to continue potentially
            # Adjust dummy shape if your images aren't 256x256
            dummy_h, dummy_w = (256, 256)
            if idx > 0:
                try:
                    print(f"Attempting to return item at index {idx-1} instead.")
                    return self.__getitem__(idx-1)
                except:
                    print(f"Fallback failed. Returning dummy tensor for index {idx}.")
                    dummy_tensor = torch.zeros((3, dummy_h, dummy_w), dtype=torch.float32)
                    return dummy_tensor, dummy_tensor
            else: # Error on item 0 itself
                 print(f"Error on first item (index 0). Returning dummy tensor.")
                 dummy_tensor = torch.zeros((3, dummy_h, dummy_w), dtype=torch.float32)
                 return dummy_tensor, dummy_tensor

# Wavelet-U-Net Model with BatchNorm
class WaveletUNet_BN(nn.Module):
    def __init__(self, in_channels=3, wavelet_channels=12): # wavelet_channels = C*4 = 3*4
        super().__init__()
        def conv_block(in_ch, out_ch):
            block = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True)
            )
            # Initialize weights
            for m in block.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
            return block

        self.enc1 = conv_block(in_channels, 64)           # Input: B, 3, H, W -> Output: B, 64, H, W
        self.pool1 = nn.MaxPool2d(2)                       # -> B, 64, H/2, W/2
        self.enc2 = conv_block(64, 128)                    # -> B, 128, H/2, W/2
        self.pool2 = nn.MaxPool2d(2)                       # -> B, 128, H/4, W/4
        self.enc3 = conv_block(128, 256)                   # -> B, 256, H/4, W/4
        self.pool3 = nn.MaxPool2d(2)                       # -> B, 256, H/8, W/8
        self.enc4 = conv_block(256, 512)                   # -> B, 512, H/8, W/8

        # Wavelet feature pathway (operates on wavelet input [B, 12, H/2, W/2])
        self.wavelet_enc1 = conv_block(wavelet_channels, 64) # Input: B, 12, H/2, W/2 -> Output: B, 64, H/2, W/2
        self.pool_w1 = nn.MaxPool2d(2)                     # -> B, 64, H/4, W/4
        self.wavelet_enc2 = conv_block(64, 128)            # -> B, 128, H/4, W/4

        self.up3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) # Upsample e4: B, 512, H/4, W/4
        # Concat: [up3(e4), e3, w_feat3] = [B, 512, H/4, W/4] + [B, 256, H/4, W/4] + [B, 128, H/4, W/4] = B, 896, H/4, W/4
        self.dec3 = conv_block(512 + 256 + 128, 256)       # -> B, 256, H/4, W/4

        self.up2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) # Upsample d3: B, 256, H/2, W/2
        # Concat: [up2(d3), e2, w_feat2] = [B, 256, H/2, W/2] + [B, 128, H/2, W/2] + [B, 64, H/2, W/2] = B, 448, H/2, W/2
        self.dec2 = conv_block(256 + 128 + 64, 128)        # -> B, 128, H/2, W/2

        self.up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) # Upsample d2: B, 128, H, W
        # Concat: [up1(d2), e1] = [B, 128, H, W] + [B, 64, H, W] = B, 192, H, W
        self.dec1 = conv_block(128 + 64, 64)               # -> B, 64, H, W

        self.final = nn.Conv2d(64, 3, kernel_size=1)       # -> B, 3, H, W
        # Init final layer weights
        nn.init.kaiming_normal_(self.final.weight, mode='fan_out', nonlinearity='linear')
        if self.final.bias is not None: nn.init.constant_(self.final.bias, 0)

    def forward(self, x, wavelet):
        # Encoder path
        e1 = self.enc1(x);          # B, 64, H, W
        p1 = self.pool1(e1);        # B, 64, H/2, W/2
        e2 = self.enc2(p1);         # B, 128, H/2, W/2
        p2 = self.pool2(e2);        # B, 128, H/4, W/4
        e3 = self.enc3(p2);         # B, 256, H/4, W/4
        p3 = self.pool3(e3);        # B, 256, H/8, W/8
        e4 = self.enc4(p3);         # B, 512, H/8, W/8

        # Wavelet feature path
        w_feat2 = self.wavelet_enc1(wavelet) # B, 64, H/2, W/2 (Matches e2 size)
        pw_feat2 = self.pool_w1(w_feat2)   # B, 64, H/4, W/4
        w_feat3 = self.wavelet_enc2(pw_feat2)  # B, 128, H/4, W/4 (Matches e3 size)

        # Decoder path with skip connections and wavelet features
        up3 = self.up3(e4);         # B, 512, H/4, W/4
        cat3 = torch.cat([up3, e3, w_feat3], dim=1); # B, 512+256+128=896, H/4, W/4
        d3 = self.dec3(cat3);       # B, 256, H/4, W/4

        up2 = self.up2(d3);         # B, 256, H/2, W/2
        cat2 = torch.cat([up2, e2, w_feat2], dim=1); # B, 256+128+64=448, H/2, W/2
        d2 = self.dec2(cat2);       # B, 128, H/2, W/2

        up1 = self.up1(d2);         # B, 128, H, W
        cat1 = torch.cat([up1, e1], dim=1); # B, 128+64=192, H, W
        d1 = self.dec1(cat1);       # B, 64, H, W

        out = self.final(d1);       # B, 3, H, W
        return torch.sigmoid(out) # Output in [0, 1] range

# Combined Loss with Differentiable DWT
class CombinedLossDWT(nn.Module):
    def __init__(self, alpha=0.85, wavelet='db1', device='cpu'):
        super().__init__()
        self.alpha = alpha
        self.l1_loss = nn.L1Loss()
        # Ensure DWTForward is created only once and on the correct device
        try:
            self.dwt = DWTForward(J=1, wave=wavelet, mode='symmetric').to(device)
            self._dwt_initialized = True
        except Exception as e:
            print(f"ERROR initializing DWTForward: {e}. Frequency loss will be disabled.")
            self._dwt_initialized = False
        self.device = device

    def _compute_dwt(self, x):
        # Input x shape: [B, C, H, W]
        x = x.to(self.device) # Ensure input is on the correct device
        LL, H_coeffs = self.dwt(x)
        # LL shape: [B, C, H/2, W/2]
        # H_coeffs is a list containing one tensor for J=1
        # H_coeffs[0] shape: [B, C, 3, H/2, W/2] where 3 is for LH, HL, HH
        details = H_coeffs[0]
        B, C, _, H_d, W_d = details.shape

        # Reshape details: [B, C, 3, H/2, W/2] -> [B, C*3, H/2, W/2]
        details_reshaped = details.reshape(B, C * 3, H_d, W_d)

        # Concatenate along channel dimension: [B, C + C*3, H/2, W/2] = [B, 4*C, H/2, W/2]
        coeffs_combined = torch.cat([LL, details_reshaped], dim=1)
        return coeffs_combined

    def forward(self, pred, target):
        pred, target = pred.float().to(self.device), target.float().to(self.device) # Ensure on correct device
        spatial_loss = self.l1_loss(pred, target)

        # Compute DWT only if alpha < 1 and DWT was initialized
        if self.alpha < 1.0 and self._dwt_initialized:
            try:
                pred_wavelet = self._compute_dwt(pred)
                target_wavelet = self._compute_dwt(target)
                freq_loss = self.l1_loss(pred_wavelet, target_wavelet)
            except Exception as e:
                 print(f"Warning: Error computing DWT loss during forward pass: {e}. Setting freq_loss to 0 for this batch.")
                 # traceback.print_exc() # Optional: more detail
                 freq_loss = torch.tensor(0.0, device=self.device) # Avoid crashing if DWT fails
        else:
            freq_loss = torch.tensor(0.0, device=self.device) # No frequency loss if alpha is 1 or DWT failed init

        total_loss = self.alpha * spatial_loss + (1 - self.alpha) * freq_loss

        # Check for NaN loss RIGHT BEFORE returning
        if torch.isnan(total_loss):
            print(f"NaN loss detected! Spatial: {spatial_loss.item():.6f}, Freq: {freq_loss.item():.6f}, Alpha: {self.alpha}")
            print("Check input data, model outputs, or DWT calculations.")
            # Return a large finite loss to allow scheduler/logging but indicate error
            # Make sure it requires grad if backprop is expected
            return torch.tensor(1000.0, device=self.device, requires_grad=True)

        return total_loss

# Helper function for non-differentiable wavelet INPUT generation (using pywt)
def get_wavelet_input(img_tensor, wavelet='db1', device='cpu'):
    # Input: img_tensor [B, C, H, W]
    if not isinstance(img_tensor, torch.Tensor):
        raise TypeError(f"Expected img_tensor to be a torch.Tensor, got {type(img_tensor)}")
    if img_tensor.ndim != 4:
         raise ValueError(f"Expected img_tensor to have 4 dimensions [B, C, H, W], got shape {img_tensor.shape}")

    B, C, H, W = img_tensor.shape
    # Ensure tensor is on CPU for numpy conversion
    img_np = img_tensor.detach().cpu().numpy().transpose(0, 2, 3, 1) # -> [B, H, W, C]
    batch_wavelets = []
    # Target size for wavelet coefficients (H/2, W/2)
    target_h_half, target_w_half = H // 2, W // 2
    target_ch_out = C * 4 # LL, LH, HL, HH for each input channel C

    for i in range(img_np.shape[0]): # Iterate through batch
        single_img_np = img_np[i] # H, W, C
        if not np.isfinite(single_img_np).all():
             print(f"Warning: Non-finite values detected in input image at batch index {i}. Replacing with zeros before DWT.")
             single_img_np = np.nan_to_num(single_img_np) # Replace NaN/inf with 0

        try:
            # Perform 2D DWT on each channel separately? No, pywt handles multichannel via axes
            # axes=(-3, -2) should correspond to H, W dimensions in (H, W, C)
            coeffs = pywt.dwt2(single_img_np, wavelet, mode='symmetric', axes=(0, 1))
            cA, (cH, cV, cD) = coeffs
            # cA, cH, cV, cD will have shape like [H/2, W/2, C]

            # Concatenate along the channel dimension (axis=2)
            # Result shape: [H/2, W/2, 4*C]
            wavelet_np = np.concatenate([cA, cH, cV, cD], axis=2).astype(np.float32)

            # Handle potential size mismatches due to odd dimensions in DWT
            h_np, w_np = wavelet_np.shape[:2]
            if h_np != target_h_half or w_np != target_w_half:
                 # print(f"Warning: Resizing pywt output from {(h_np, w_np)} to {(target_h_half, target_w_half)} for batch item {i}") # Optional warning
                 # Resize requires (width, height) order for cv2
                 wavelet_np = cv2.resize(wavelet_np, (target_w_half, target_h_half), interpolation=cv2.INTER_LINEAR)
                 # If resize squashes channels (e.g., single channel input), restore dim
                 if wavelet_np.ndim == 2: wavelet_np = wavelet_np[:, :, np.newaxis]
                 # Check channel count after potential resize (unlikely to change but defensive)
                 if wavelet_np.shape[2] != target_ch_out:
                     print(f"FATAL: Wavelet channel mismatch after resize ({wavelet_np.shape[2]} vs {target_ch_out}). This should not happen.")
                     # Fallback: create zeros of correct shape
                     wavelet_np = np.zeros((target_h_half, target_w_half, target_ch_out), dtype=np.float32)


            # Normalize wavelet coefficients (per channel within the 4*C channels)
            for ch in range(wavelet_np.shape[2]):
                channel_data = wavelet_np[:, :, ch]
                mean = channel_data.mean()
                std = channel_data.std()
                wavelet_np[:, :, ch] = (channel_data - mean) / (std + 1e-8) # Add epsilon for stability

            batch_wavelets.append(wavelet_np)

        except Exception as e:
            print(f"-----------------------------------")
            print(f"Error generating pywt input for batch item {i}: {e}")
            traceback.print_exc()
            print(f"Input image shape: {single_img_np.shape}")
            print(f"Using zeros as fallback for this item.")
            print(f"-----------------------------------")
            # Use zeros if pywt fails for an item
            zero_wavelet = np.zeros((target_h_half, target_w_half, target_ch_out), dtype=np.float32)
            batch_wavelets.append(zero_wavelet)

    # Stack the processed wavelets into a single batch numpy array
    wavelet_batch_np = np.stack(batch_wavelets) # [B, H/2, W/2, 4*C]
    # Permute to PyTorch format [B, C', H', W'] where C'=4*C, H'=H/2, W'=W/2
    wavelet_tensor = torch.from_numpy(wavelet_batch_np).permute(0, 3, 1, 2).float()
    return wavelet_tensor.to(device)

# Utility function to convert tensor batch to visualizable numpy images
def tensors_to_cv2_images(tensor_batch):
    """ Converts a batch of [B, C, H, W] tensors (range [0,1]) to list of OpenCV images (BGR, uint8). """
    images = []
    if tensor_batch is None: return images # Handle None input
    if not isinstance(tensor_batch, torch.Tensor): return images # Handle non-tensor input

    tensor_batch = tensor_batch.detach().cpu() # Move to CPU and detach from graph
    for i in range(tensor_batch.shape[0]):
        img_tensor = tensor_batch[i] # Get single image tensor [C, H, W]
        # Ensure 3 channels for color conversion
        if img_tensor.shape[0] == 1: # Grayscale tensor
             img_tensor = img_tensor.repeat(3, 1, 1) # Convert to 3 channels
        elif img_tensor.shape[0] != 3:
             print(f"Warning: tensor_to_cv2_images expects 3 channels, got {img_tensor.shape[0]}. Skipping item {i}.")
             continue

        img_np = img_tensor.numpy().transpose(1, 2, 0) # Convert to H, W, C (NumPy format)
        img_np = np.clip(img_np, 0, 1) # Ensure range [0, 1] after potential model overshoot
        img_uint8 = (img_np * 255).astype(np.uint8) # Convert to uint8 [0, 255]
        img_bgr = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2BGR) # Convert RGB -> BGR for OpenCV saving
        images.append(img_bgr)
    return images

# Evaluate model on a batch of images (used potentially for validation samples)
def evaluate_batch(model, criterion, blurry_batch, clear_batch, device, wavelet='db1'):
    """Evaluates the model on a single batch and returns loss and predictions."""
    model.eval() # Set model to evaluation mode
    with torch.no_grad(): # Disable gradient calculation
        # Ensure data is on the correct device
        blurry_batch = blurry_batch.to(device)
        clear_batch = clear_batch.to(device)

        # Generate wavelet input
        wavelet_batch = get_wavelet_input(blurry_batch, wavelet=wavelet, device=device)

        # Get model predictions
        predictions = model(blurry_batch, wavelet_batch)

        # Calculate loss
        loss = criterion(predictions, clear_batch).item() # .item() gets scalar value

    # Return model to training mode? No, let the caller handle it.
    return loss, predictions

# Function to create a montage of sample images
def create_sample_montage(blurry_batch, pred_batch, clear_batch, sample_indices):
    """Creates a 3-row montage: Ground Truth, Prediction, Blurry Input."""
    # Convert tensors to OpenCV images (list of BGR uint8 numpy arrays)
    clear_cv2 = tensors_to_cv2_images(clear_batch)    # Row 1: Ground Truth
    pred_cv2 = tensors_to_cv2_images(pred_batch)      # Row 2: Model Output
    blurry_cv2 = tensors_to_cv2_images(blurry_batch)  # Row 3: Blurry Input

    # Check if conversion yielded any images
    if not clear_cv2 or not pred_cv2 or not blurry_cv2:
        print("Warning: Failed to convert one or more tensor batches to CV2 images for montage.")
        return None

    # Determine target size for consistency (use prediction size as reference)
    target_h, target_w = pred_cv2[0].shape[:2]

    # Resize function (ensure consistent dimensions for hconcat/vconcat)
    def resize_img_list(img_list, target_w, target_h):
        resized_list = []
        for img in img_list:
            if img is None: continue # Skip None images
            if img.shape[:2] != (target_h, target_w):
                try:
                    resized_img = cv2.resize(img, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
                    resized_list.append(resized_img)
                except Exception as e:
                    print(f"Warning: Failed to resize image in montage: {e}")
            else:
                resized_list.append(img)
        return resized_list

    # Resize all image lists
    resized_clear = resize_img_list(clear_cv2, target_w, target_h)
    resized_pred = resize_img_list(pred_cv2, target_w, target_h)
    resized_blurry = resize_img_list(blurry_cv2, target_w, target_h)

    # Check if we still have images after resizing
    if not resized_clear or not resized_pred or not resized_blurry:
         print("Warning: One or more image lists became empty after resizing for montage.")
         return None
    # Ensure all lists have the same number of images after potential resize failures?
    # For simplicity, we'll assume they match the number of predictions if resizing worked.
    num_images = len(resized_pred)
    resized_clear = resized_clear[:num_images]
    resized_blurry = resized_blurry[:num_images]


    # Add index labels to images
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6 # Slightly larger font
    thickness = 1
    text_color = (255, 255, 255)  # White
    bg_color = (0, 0, 0) # Black background for text
    text_y_offset = 20
    label_y_offset = 45

    for i in range(num_images):
        if i < len(sample_indices): # Check index validity
            idx_text = f"Orig Idx: {sample_indices[i]}"

            # Add text with background rectangle for better visibility
            def add_text_with_bg(img, text, org_y, scale=font_scale, thickness=thickness):
                 (text_width, text_height), _ = cv2.getTextSize(text, font, scale, thickness)
                 cv2.rectangle(img, (5, org_y - text_height - 2), (5 + text_width + 4, org_y + 4), bg_color, -1)
                 cv2.putText(img, text, (7, org_y), font, scale, text_color, thickness, cv2.LINE_AA)

            # Add labels to each image in the row
            add_text_with_bg(resized_clear[i], "Ground Truth", text_y_offset)
            add_text_with_bg(resized_clear[i], idx_text, label_y_offset)

            add_text_with_bg(resized_pred[i], "Prediction", text_y_offset)
            add_text_with_bg(resized_pred[i], idx_text, label_y_offset)

            add_text_with_bg(resized_blurry[i], "Blurry Input", text_y_offset)
            add_text_with_bg(resized_blurry[i], idx_text, label_y_offset)
        else:
             print(f"Warning: Not enough sample indices provided ({len(sample_indices)}) for image {i} in montage.")

    # Create the montage by horizontally concatenating images in each row
    try:
        row1 = cv2.hconcat(resized_clear)  # Ground Truth
        row2 = cv2.hconcat(resized_pred)   # Model Output
        row3 = cv2.hconcat(resized_blurry) # Blurry Input

        # Add padding between rows
        padding_h = 10
        # Ensure padding width matches the concatenated row width
        padding = np.zeros((padding_h, row1.shape[1], 3), dtype=np.uint8) # Black padding

        # Stack rows vertically: Truth, Padding, Prediction, Padding, Blurry
        montage = cv2.vconcat([row1, padding.copy(), row2, padding.copy(), row3])
        return montage

    except Exception as e:
         print(f"Error during hconcat/vconcat for montage: {e}")
         print(f"Shapes - Clear: {[img.shape for img in resized_clear]}, Pred: {[img.shape for img in resized_pred]}, Blurry: {[img.shape for img in resized_blurry]}")
         traceback.print_exc()
         return None

# --- Training Function ---
def train_model(clear_dir, blurry_dir, output_dir, epochs=50, batch_size=8, lr=0.0005,
                wavelet='db1', alpha=0.85, weight_decay=1e-5,
                train_samples_per_epoch=1000, # Max samples for training subset per epoch
                val_samples_per_epoch=100,    # Max samples for validation subset per epoch
                train_val_split=0.8, random_seed=42, num_workers=None):

    # Set random seed for reproducibility
    random.seed(random_seed)
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(random_seed) # for multi-GPU

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize full dataset
    try:
        full_dataset = ColonoscopyDataset(clear_dir, blurry_dir, wavelet=wavelet)
        total_samples = len(full_dataset)
        if total_samples == 0:
            raise ValueError("Dataset is empty! Check clear/blurry directories and file pairing.")
    except Exception as e:
        print(f"Error initializing dataset: {e}")
        traceback.print_exc()
        return None

    # Create train/validation split indices
    indices = list(range(total_samples))
    random.shuffle(indices) # Shuffle indices before splitting
    split_point = int(total_samples * train_val_split)
    train_indices = indices[:split_point]
    val_indices = indices[split_point:]

    # Handle cases where requested samples per epoch exceed available data
    num_train_available = len(train_indices)
    num_val_available = len(val_indices)
    actual_train_samples_per_epoch = min(train_samples_per_epoch, num_train_available)
    actual_val_samples_per_epoch = min(val_samples_per_epoch, num_val_available)

    print(f"Total samples: {total_samples}")
    print(f"Train/Val split: {num_train_available}/{num_val_available} ({train_val_split*100:.1f}% / {(1-train_val_split)*100:.1f}%)")
    print(f"Using {actual_train_samples_per_epoch} train samples per epoch (requested: {train_samples_per_epoch})")
    print(f"Using {actual_val_samples_per_epoch} val samples per epoch (requested: {val_samples_per_epoch})")
    if num_train_available == 0 or num_val_available == 0:
         print("ERROR: Training or validation set has 0 samples after split. Cannot proceed.")
         return None

    # Create model, optimizer, and loss function
    model = WaveletUNet_BN().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = CombinedLossDWT(alpha=alpha, wavelet=wavelet, device=device)
    # Reduce LR if val loss plateaus for 'patience' epochs
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)

    # Create output directories
    samples_dir = os.path.join(output_dir, "samples")
    checkpoints_dir = os.path.join(output_dir, "checkpoints")
    os.makedirs(samples_dir, exist_ok=True)
    os.makedirs(checkpoints_dir, exist_ok=True)

    # Create CSV file for training/validation loss tracking
    csv_path = os.path.join(output_dir, "train_status.csv")
    try:
        with open(csv_path, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['Epoch', 'Training Loss', 'Validation Loss', 'Learning Rate', 'Timestamp'])
    except IOError as e:
        print(f"Error creating CSV log file {csv_path}: {e}")
        return None # Stop if we can't log

    # Best model tracking
    best_val_loss = float('inf')
    best_model_path = os.path.join(output_dir, "wavelet_unet_bn_best.pth") # Defined outside loop

    # Determine num_workers
    if num_workers is None:
        num_workers = min(4, os.cpu_count() // 2 if os.cpu_count() is not None and os.cpu_count() > 1 else 1)
        print(f"Auto-detected num_workers: {num_workers}")
    else:
        print(f"Using specified num_workers: {num_workers}")


    print(f"\n--- Starting Training for {epochs} epochs ---")
    for epoch in range(epochs):
        epoch_start_time = datetime.now()
        print(f"\nEpoch {epoch+1}/{epochs}")

        # ===== TRAINING PHASE =====
        model.train() # Set model to training mode
        train_running_loss = 0.0
        processed_train_samples = 0

        # Randomly sample indices for THIS epoch's training subset
        epoch_train_indices = random.sample(train_indices, actual_train_samples_per_epoch)
        train_subset = Subset(full_dataset, epoch_train_indices)
        # pin_memory=True can speed up CPU->GPU transfer if using CUDA
        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True,
                                 num_workers=num_workers, pin_memory=(device.type == 'cuda'), drop_last=True)

        print(f"Training with {len(epoch_train_indices)} samples ({len(train_loader)} batches of size {batch_size})")

        # Training loop with progress bar
        train_progress = tqdm(train_loader, desc=f"Train E{epoch+1}", leave=False, unit="batch")
        for i, batch in enumerate(train_progress):
            try:
                blurry_img, clear_img = batch
                # Skip if batch is None or tensors are invalid (e.g., from dummy data)
                if blurry_img is None or clear_img is None or not isinstance(blurry_img, torch.Tensor) or not isinstance(clear_img, torch.Tensor):
                     print(f"Warning: Skipping invalid batch {i} in training.")
                     continue
                if blurry_img.shape[0] == 0 or clear_img.shape[0] == 0:
                     print(f"Warning: Skipping empty batch {i} in training.")
                     continue

                blurry_img = blurry_img.to(device, non_blocking=True)
                clear_img = clear_img.to(device, non_blocking=True)

                # Generate wavelet input for the blurry image batch
                wavelet_input = get_wavelet_input(blurry_img, wavelet=wavelet, device=device)

                # --- Forward pass ---
                optimizer.zero_grad() # Reset gradients before forward pass
                predictions = model(blurry_img, wavelet_input)

                # --- Calculate Loss ---
                loss = criterion(predictions, clear_img)

                # Skip batch if loss is NaN (potentially from bad data or DWT issues)
                if torch.isnan(loss):
                    print(f"\nWarning: NaN loss detected in training batch {i}. Skipping backward/step.")
                    optimizer.zero_grad() # Ensure grads are zeroed if skipping step
                    continue # Skip backward and step

                # --- Backward pass ---
                loss.backward()

                # Optional: Gradient clipping (uncomment if needed for stability)
                # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                # --- Optimizer step ---
                optimizer.step()

                # --- Update metrics ---
                current_loss = loss.item()
                batch_sample_count = blurry_img.size(0)
                train_running_loss += current_loss * batch_sample_count # Weighted by batch size
                processed_train_samples += batch_sample_count

                # Update progress bar description
                train_progress.set_postfix(loss=f"{current_loss:.6f}",
                                          avg_loss=f"{train_running_loss/processed_train_samples:.6f}",
                                          lr=f"{optimizer.param_groups[0]['lr']:.1e}")

            except Exception as e:
                print(f"\n-----------------------------------")
                print(f"ERROR in training batch {i} (Epoch {epoch+1}): {e}")
                print(f"Batch blurry shape: {blurry_img.shape if 'blurry_img' in locals() and isinstance(blurry_img, torch.Tensor) else 'N/A'}")
                print(f"Batch clear shape: {clear_img.shape if 'clear_img' in locals() and isinstance(clear_img, torch.Tensor) else 'N/A'}")
                traceback.print_exc()
                print(f"-----------------------------------")
                continue # Continue to next batch if one fails

        # Calculate average training loss for the epoch
        train_loss = train_running_loss / processed_train_samples if processed_train_samples > 0 else float('inf')
        print(f"Epoch {epoch+1} Training Avg Loss: {train_loss:.6f}")


        # ===== VALIDATION PHASE =====
        model.eval() # Set model to evaluation mode
        val_running_loss = 0.0
        processed_val_samples = 0

        # Randomly sample indices for THIS epoch's validation subset
        epoch_val_indices = random.sample(val_indices, actual_val_samples_per_epoch)
        val_subset = Subset(full_dataset, epoch_val_indices)
        val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, # No shuffle for validation
                               num_workers=num_workers, pin_memory=(device.type == 'cuda'), drop_last=False) # Don't drop last val batch

        print(f"Validating with {len(epoch_val_indices)} samples ({len(val_loader)} batches)")

        # Keep track of the first few validation samples for visualization
        vis_blurry_batch = None
        vis_clear_batch = None
        vis_pred_batch = None
        # Take indices corresponding to the *first validation batch* or up to 5 samples
        vis_indices = epoch_val_indices[:min(batch_size, 5)]

        # Validation loop with progress bar
        val_progress = tqdm(val_loader, desc=f"Validate E{epoch+1}", leave=False, unit="batch")
        with torch.no_grad(): # Disable gradient calculations for validation
            for i, batch in enumerate(val_progress):
                try:
                    blurry_img, clear_img = batch
                    if blurry_img is None or clear_img is None or not isinstance(blurry_img, torch.Tensor) or not isinstance(clear_img, torch.Tensor):
                         print(f"Warning: Skipping invalid batch {i} in validation.")
                         continue
                    if blurry_img.shape[0] == 0 or clear_img.shape[0] == 0:
                         print(f"Warning: Skipping empty batch {i} in validation.")
                         continue

                    blurry_img = blurry_img.to(device, non_blocking=True)
                    clear_img = clear_img.to(device, non_blocking=True)

                    # Save first batch for visualization (match size of vis_indices)
                    if i == 0 and len(blurry_img) > 0:
                        num_to_vis = min(len(blurry_img), len(vis_indices))
                        vis_blurry_batch = blurry_img[:num_to_vis]
                        vis_clear_batch = clear_img[:num_to_vis]
                        # Update vis_indices to reflect actual indices used if batch was smaller
                        vis_indices = epoch_val_indices[:num_to_vis]


                    # Wavelet preprocessing and forward pass
                    wavelet_input = get_wavelet_input(blurry_img, wavelet=wavelet, device=device)
                    predictions = model(blurry_img, wavelet_input)

                    # Save predictions corresponding to the first batch
                    if i == 0 and vis_blurry_batch is not None:
                        vis_pred_batch = predictions[:len(vis_indices)] # Match size

                    # --- Calculate validation loss ---
                    val_loss_batch = criterion(predictions, clear_img)

                    if torch.isnan(val_loss_batch):
                         print(f"\nWarning: NaN loss detected in validation batch {i}. Skipping.")
                         continue

                    # --- Update validation metrics ---
                    current_val_loss = val_loss_batch.item()
                    batch_sample_count = blurry_img.size(0)
                    val_running_loss += current_val_loss * batch_sample_count # Weighted by batch size
                    processed_val_samples += batch_sample_count

                    # Update progress bar
                    val_progress.set_postfix(loss=f"{current_val_loss:.6f}",
                                           avg_loss=f"{val_running_loss/processed_val_samples:.6f}")
                except Exception as e:
                    print(f"\n-----------------------------------")
                    print(f"ERROR in validation batch {i} (Epoch {epoch+1}): {e}")
                    print(f"Batch blurry shape: {blurry_img.shape if 'blurry_img' in locals() and isinstance(blurry_img, torch.Tensor) else 'N/A'}")
                    print(f"Batch clear shape: {clear_img.shape if 'clear_img' in locals() and isinstance(clear_img, torch.Tensor) else 'N/A'}")
                    traceback.print_exc()
                    print(f"-----------------------------------")
                    continue # Continue validation if one batch fails

        # Calculate average validation loss for the epoch
        val_loss = val_running_loss / processed_val_samples if processed_val_samples > 0 else float('inf')
        print(f"Epoch {epoch+1} Validation Avg Loss: {val_loss:.6f}")

        # --- Post-Epoch Actions ---

        # Update learning rate scheduler based on validation loss
        current_lr = optimizer.param_groups[0]['lr'] # Get LR before scheduler step
        scheduler.step(val_loss)
        new_lr = optimizer.param_groups[0]['lr']
        if new_lr < current_lr:
             print(f"Learning rate reduced to {new_lr:.1e}")

        # Log metrics to CSV
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        try:
            with open(csv_path, 'a', newline='') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow([
                    epoch + 1,
                    f"{train_loss:.6f}" if train_loss != float('inf') else 'N/A', # Handle potential inf
                    f"{val_loss:.6f}" if val_loss != float('inf') else 'N/A', # Handle potential inf
                    f"{current_lr:.6f}", # Log LR used *during* the epoch
                    timestamp
                ])
        except IOError as e:
             print(f"Warning: Could not write to CSV log file {csv_path}: {e}")

        # Create and save sample visualization montage from the validation set
        if vis_blurry_batch is not None and vis_clear_batch is not None and vis_pred_batch is not None and vis_indices:
            try:
                # Pass the actual original indices used for visualization
                montage = create_sample_montage(vis_blurry_batch, vis_pred_batch, vis_clear_batch, vis_indices)
                if montage is not None:
                    sample_path = os.path.join(samples_dir, f"sample_montage_epoch_{epoch+1:03d}.png")
                    cv2.imwrite(sample_path, montage)
                    # print(f"Saved validation sample montage to {sample_path}") # Optional print
            except Exception as e:
                print(f"Error creating or saving sample montage for epoch {epoch+1}: {e}")
                traceback.print_exc()
        else:
            print(f"Skipping montage generation for epoch {epoch+1} (missing visualization data or indices).")


        # --- Checkpoint Saving ---

        # Save best model based on validation loss
        if val_loss < best_val_loss and val_loss != float('inf'):
            best_val_loss = val_loss
            try:
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'train_loss': train_loss,
                    'val_loss': val_loss,
                    'best_val_loss': best_val_loss,
                }, best_model_path) # Use the predefined path
                print(f"*** New best model saved with Validation Loss {val_loss:.6f} at Epoch {epoch+1} ***")
            except Exception as e:
                print(f"Error saving best model checkpoint: {e}")

        # Save periodic checkpoint (e.g., every 10 epochs or last epoch)
        if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
            checkpoint_path = os.path.join(checkpoints_dir, f"wavelet_unet_bn_epoch_{epoch+1:03d}.pth")
            try:
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'train_loss': train_loss,
                    'val_loss': val_loss,
                }, checkpoint_path)
                print(f"Saved periodic checkpoint to {checkpoint_path}")
            except Exception as e:
                print(f"Error saving periodic checkpoint: {e}")

        epoch_duration = datetime.now() - epoch_start_time
        print(f"Epoch {epoch+1} completed in {str(epoch_duration).split('.')[0]}")

    # --- End of Training Loop ---

    print("\n--- Training Finished ---")

    # Save final model state (useful regardless of best validation loss)
    final_model_path = os.path.join(output_dir, "wavelet_unet_bn_final.pth")
    try:
        torch.save({
            'epoch': epochs, # Save final epoch number
            'model_state_dict': model.state_dict(),
            'best_val_loss': best_val_loss, # Record the best val loss achieved
        }, final_model_path)
        print(f"Saved final model state to {final_model_path}")
    except Exception as e:
        print(f"Error saving final model state: {e}")

    # Generate and save loss curves plot
    try:
        # Read the CSV file with training history
        if os.path.exists(csv_path):
            history_df = pd.read_csv(csv_path)

            # Convert loss columns to numeric, coercing errors to NaN
            history_df['Training Loss'] = pd.to_numeric(history_df['Training Loss'], errors='coerce')
            history_df['Validation Loss'] = pd.to_numeric(history_df['Validation Loss'], errors='coerce')

            # Drop rows where loss is NaN (e.g., if logging failed for an epoch)
            history_df.dropna(subset=['Training Loss', 'Validation Loss'], inplace=True)

            if not history_df.empty:
                plt.figure(figsize=(12, 7)) # Wider figure
                plt.plot(history_df['Epoch'], history_df['Training Loss'], label='Training Loss', marker='o', linestyle='-')
                plt.plot(history_df['Epoch'], history_df['Validation Loss'], label='Validation Loss', marker='x', linestyle='--')

                # Add titles and labels
                plt.title('Training and Validation Loss Curves')
                plt.xlabel('Epoch')
                plt.ylabel('Loss (L1 + Wavelet L1)')
                plt.legend()
                plt.grid(True, which='both', linestyle='--', linewidth=0.5)
                plt.yscale('log') # Use log scale if losses vary greatly
                plt.tight_layout() # Adjust layout

                # Save the plot
                plot_path = os.path.join(output_dir, "loss_curves.png")
                plt.savefig(plot_path)
                print(f"Saved loss curves plot to {plot_path}")
                plt.close() # Close the plot to free memory
            else:
                print("No valid data found in CSV to plot loss curves.")
        else:
            print(f"Could not find CSV file {csv_path} to generate plot.")

    except ImportError:
        print("Plotting requires matplotlib and pandas. Please install them (`pip install matplotlib pandas`)")
    except Exception as e:
        print(f"Error creating loss curves plot: {e}")
        traceback.print_exc()

    print(f"Best validation loss achieved: {best_val_loss:.6f}")
    print(f"Find best model at: {best_model_path}")
    print(f"Find final model at: {final_model_path}")
    print(f"Find training status log at: {csv_path}")
    print(f"Find sample montages in: {samples_dir}")
    print(f"Find checkpoints in: {checkpoints_dir}")

    return model # Return the trained model


if __name__ == "__main__":
    # Configure directories (Use absolute paths or paths relative to script location)
    # Example using absolute paths (replace with your actual paths)
    base_notebook_dir = "/notebooks" # Or wherever your notebooks directory is mounted
    clear_dir = os.path.join(base_notebook_dir, "output_3_18/smokeV2/clear")
    blurry_dir = os.path.join(base_notebook_dir, "output_3_18/smokeV2/blurry")
    output_dir = os.path.join(base_notebook_dir, "output_3_18/smoke/output_V3_with_validation_FIXED")

    # Make sure the output directory exists
    try:
        os.makedirs(output_dir, exist_ok=True)
    except OSError as e:
        print(f"Error creating output directory {output_dir}: {e}")
        sys.exit(1)

    # Print configuration clearly
    print("--- Configuration ---")
    print(f"Clear image directory : {os.path.abspath(clear_dir)}")
    print(f"Blurry image directory: {os.path.abspath(blurry_dir)}")
    print(f"Output directory      : {os.path.abspath(output_dir)}")
    print("---------------------")

    # Check if input directories exist
    if not os.path.isdir(clear_dir):
        print(f"Error: Clear image directory not found at '{clear_dir}'")
        sys.exit(1)
    if not os.path.isdir(blurry_dir):
        print(f"Error: Blurry image directory not found at '{blurry_dir}'")
        sys.exit(1)

    try:
        # Start training with defined parameters
        trained_model = train_model(
            clear_dir=clear_dir,
            blurry_dir=blurry_dir,
            output_dir=output_dir,
            epochs=50,                     # Total number of epochs
            batch_size=8,                  # Batch size (adjust based on GPU memory)
            lr=0.0005,                     # Initial learning rate
            wavelet='db1',                 # Wavelet type for loss and input features
            alpha=0.85,                    # Loss balance (0.85*Spatial + 0.15*Frequency)
            weight_decay=1e-5,             # Weight decay for AdamW optimizer
            train_samples_per_epoch=1000,  # Max training samples per epoch subset
            val_samples_per_epoch=200,     # Max validation samples per epoch subset (increased for better stats)
            train_val_split=0.8,           # 80% training, 20% validation split
            random_seed=42,                # Seed for reproducibility
            num_workers=4                  # Number of CPU workers for DataLoader (adjust based on system)
        )

        if trained_model:
            print("\n--- Training script completed successfully! ---")
        else:
            print("\n--- Training script finished with errors (model returned None). ---")

    except KeyboardInterrupt:
         print("\n--- Training interrupted by user (KeyboardInterrupt). ---")
         sys.exit(0) # Exit cleanly
    except Exception as e:
        print(f"\n--- An unexpected error occurred during script execution ---")
        print(f"Error type: {type(e).__name__}")
        print(f"Error details: {e}")
        print("--- Traceback ---")
        traceback.print_exc()
        print("-----------------")
        sys.exit(1) # Exit with error code

--- Configuration ---
Clear image directory : /notebooks/output_3_18/smokeV2/clear
Blurry image directory: /notebooks/output_3_18/smokeV2/blurry
Output directory      : /notebooks/output_3_18/smoke/output_V3_with_validation_FIXED
---------------------
Using device: cuda
Total paired samples available for this dataset instance: 7418
Total samples: 7418
Train/Val split: 5934/1484 (80.0% / 20.0%)
Using 1000 train samples per epoch (requested: 1000)
Using 200 val samples per epoch (requested: 200)
Using specified num_workers: 4

--- Starting Training for 50 epochs ---

Epoch 1/50
Training with 1000 samples (125 batches of size 8)


                                                                                                            

Epoch 1 Training Avg Loss: 0.069688
Validating with 200 samples (25 batches)


                                                                                                 

Epoch 1 Validation Avg Loss: 0.026880
*** New best model saved with Validation Loss 0.026880 at Epoch 1 ***
Epoch 1 completed in 0:02:19

Epoch 2/50
Training with 1000 samples (125 batches of size 8)


                                                                                                            

Epoch 2 Training Avg Loss: 0.024240
Validating with 200 samples (25 batches)


                                                                                                 

Epoch 2 Validation Avg Loss: 0.017917
*** New best model saved with Validation Loss 0.017917 at Epoch 2 ***
Epoch 2 completed in 0:02:15

Epoch 3/50
Training with 1000 samples (125 batches of size 8)


                                                                                                            

Epoch 3 Training Avg Loss: 0.018895
Validating with 200 samples (25 batches)


                                                                                                 

Epoch 3 Validation Avg Loss: 0.017140
*** New best model saved with Validation Loss 0.017140 at Epoch 3 ***
Epoch 3 completed in 0:02:26

Epoch 4/50
Training with 1000 samples (125 batches of size 8)


                                                                                                            

Epoch 5 Training Avg Loss: 0.015055
Validating with 200 samples (25 batches)


                                                                                                 

Epoch 5 Validation Avg Loss: 0.014204
*** New best model saved with Validation Loss 0.014204 at Epoch 5 ***
Epoch 5 completed in 0:02:16

Epoch 6/50
Training with 1000 samples (125 batches of size 8)


                                                                                                            

Epoch 6 Training Avg Loss: 0.014571
Validating with 200 samples (25 batches)


                                                                                                 

Epoch 6 Validation Avg Loss: 0.014324
Epoch 6 completed in 0:02:18

Epoch 7/50
Training with 1000 samples (125 batches of size 8)


                                                                                                            

Epoch 7 Training Avg Loss: 0.013549
Validating with 200 samples (25 batches)


                                                                                                 

Epoch 7 Validation Avg Loss: 0.012837
*** New best model saved with Validation Loss 0.012837 at Epoch 7 ***
Epoch 7 completed in 0:02:16

Epoch 8/50
Training with 1000 samples (125 batches of size 8)


                                                                                                            

Epoch 8 Training Avg Loss: 0.012938
Validating with 200 samples (25 batches)


                                                                                                 

Epoch 8 Validation Avg Loss: 0.011431
*** New best model saved with Validation Loss 0.011431 at Epoch 8 ***
Epoch 8 completed in 0:02:13

Epoch 9/50
Training with 1000 samples (125 batches of size 8)


                                                                                                            

Epoch 9 Training Avg Loss: 0.012282
Validating with 200 samples (25 batches)


                                                                                                 

Epoch 9 Validation Avg Loss: 0.011952
Epoch 9 completed in 0:02:17

Epoch 10/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 10 Training Avg Loss: 0.012194
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 10 Validation Avg Loss: 0.012201
Saved periodic checkpoint to /notebooks/output_3_18/smoke/output_V3_with_validation_FIXED/checkpoints/wavelet_unet_bn_epoch_010.pth
Epoch 10 completed in 0:02:15

Epoch 11/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 11 Training Avg Loss: 0.011794
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 11 Validation Avg Loss: 0.011518
Epoch 11 completed in 0:02:14

Epoch 12/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 12 Training Avg Loss: 0.011499
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 12 Validation Avg Loss: 0.010871
*** New best model saved with Validation Loss 0.010871 at Epoch 12 ***
Epoch 12 completed in 0:02:19

Epoch 13/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 13 Training Avg Loss: 0.011179
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 13 Validation Avg Loss: 0.010417
*** New best model saved with Validation Loss 0.010417 at Epoch 13 ***
Epoch 13 completed in 0:02:18

Epoch 14/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 14 Training Avg Loss: 0.010910
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 14 Validation Avg Loss: 0.010556
Epoch 14 completed in 0:02:10

Epoch 15/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 15 Training Avg Loss: 0.010787
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 15 Validation Avg Loss: 0.011413
Epoch 15 completed in 0:02:15

Epoch 16/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 16 Training Avg Loss: 0.010369
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 16 Validation Avg Loss: 0.010264
*** New best model saved with Validation Loss 0.010264 at Epoch 16 ***
Epoch 16 completed in 0:02:12

Epoch 17/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 17 Training Avg Loss: 0.010567
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 17 Validation Avg Loss: 0.010001
*** New best model saved with Validation Loss 0.010001 at Epoch 17 ***
Epoch 17 completed in 0:02:11

Epoch 18/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 18 Training Avg Loss: 0.010278
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 18 Validation Avg Loss: 0.009865
*** New best model saved with Validation Loss 0.009865 at Epoch 18 ***
Epoch 18 completed in 0:02:15

Epoch 19/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 19 Training Avg Loss: 0.010189
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 19 Validation Avg Loss: 0.009764
*** New best model saved with Validation Loss 0.009764 at Epoch 19 ***
Epoch 19 completed in 0:02:11

Epoch 20/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 20 Training Avg Loss: 0.010088
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 20 Validation Avg Loss: 0.009178
*** New best model saved with Validation Loss 0.009178 at Epoch 20 ***
Saved periodic checkpoint to /notebooks/output_3_18/smoke/output_V3_with_validation_FIXED/checkpoints/wavelet_unet_bn_epoch_020.pth
Epoch 20 completed in 0:02:12

Epoch 21/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 21 Training Avg Loss: 0.009679
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 21 Validation Avg Loss: 0.009508
Epoch 21 completed in 0:02:12

Epoch 22/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 22 Training Avg Loss: 0.009765
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 22 Validation Avg Loss: 0.009654
Epoch 22 completed in 0:02:12

Epoch 23/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 23 Training Avg Loss: 0.009426
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 23 Validation Avg Loss: 0.008527
*** New best model saved with Validation Loss 0.008527 at Epoch 23 ***
Epoch 23 completed in 0:02:11

Epoch 24/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 24 Training Avg Loss: 0.009258
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 24 Validation Avg Loss: 0.009053
Epoch 24 completed in 0:02:12

Epoch 25/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 25 Training Avg Loss: 0.009437
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 25 Validation Avg Loss: 0.008276
*** New best model saved with Validation Loss 0.008276 at Epoch 25 ***
Epoch 25 completed in 0:02:13

Epoch 26/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 26 Training Avg Loss: 0.009506
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 26 Validation Avg Loss: 0.009185
Epoch 26 completed in 0:02:12

Epoch 27/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 27 Training Avg Loss: 0.009144
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 27 Validation Avg Loss: 0.009432
Epoch 27 completed in 0:02:11

Epoch 28/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 28 Training Avg Loss: 0.009037
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 28 Validation Avg Loss: 0.008845
Epoch 28 completed in 0:02:12

Epoch 29/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 29 Training Avg Loss: 0.008914
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 29 Validation Avg Loss: 0.008969
Epoch 29 completed in 0:02:10

Epoch 30/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 30 Training Avg Loss: 0.008770
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 30 Validation Avg Loss: 0.009004
Saved periodic checkpoint to /notebooks/output_3_18/smoke/output_V3_with_validation_FIXED/checkpoints/wavelet_unet_bn_epoch_030.pth
Epoch 30 completed in 0:02:12

Epoch 31/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 31 Training Avg Loss: 0.008692
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 31 Validation Avg Loss: 0.008375
Epoch 00031: reducing learning rate of group 0 to 2.5000e-04.
Learning rate reduced to 2.5e-04
Epoch 31 completed in 0:02:13

Epoch 32/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 32 Training Avg Loss: 0.008149
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 32 Validation Avg Loss: 0.007928
*** New best model saved with Validation Loss 0.007928 at Epoch 32 ***
Epoch 32 completed in 0:02:12

Epoch 33/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 33 Training Avg Loss: 0.008082
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 33 Validation Avg Loss: 0.007865
*** New best model saved with Validation Loss 0.007865 at Epoch 33 ***
Epoch 33 completed in 0:02:12

Epoch 34/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 34 Training Avg Loss: 0.007904
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 34 Validation Avg Loss: 0.007569
*** New best model saved with Validation Loss 0.007569 at Epoch 34 ***
Epoch 34 completed in 0:02:12

Epoch 35/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 35 Training Avg Loss: 0.007799
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 35 Validation Avg Loss: 0.007485
*** New best model saved with Validation Loss 0.007485 at Epoch 35 ***
Epoch 35 completed in 0:02:14

Epoch 36/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 36 Training Avg Loss: 0.007908
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 36 Validation Avg Loss: 0.008344
Epoch 36 completed in 0:02:12

Epoch 37/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 37 Training Avg Loss: 0.007880
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 37 Validation Avg Loss: 0.007644
Epoch 37 completed in 0:02:12

Epoch 38/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 38 Training Avg Loss: 0.007918
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 38 Validation Avg Loss: 0.008055
Epoch 38 completed in 0:02:11

Epoch 39/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 39 Training Avg Loss: 0.007590
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 39 Validation Avg Loss: 0.007317
*** New best model saved with Validation Loss 0.007317 at Epoch 39 ***
Epoch 39 completed in 0:02:13

Epoch 40/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 40 Training Avg Loss: 0.007844
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 40 Validation Avg Loss: 0.007404
Saved periodic checkpoint to /notebooks/output_3_18/smoke/output_V3_with_validation_FIXED/checkpoints/wavelet_unet_bn_epoch_040.pth
Epoch 40 completed in 0:02:11

Epoch 41/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 41 Training Avg Loss: 0.007901
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 41 Validation Avg Loss: 0.007589
Epoch 41 completed in 0:02:12

Epoch 42/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 42 Training Avg Loss: 0.007701
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 42 Validation Avg Loss: 0.007658
Epoch 42 completed in 0:02:12

Epoch 43/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 43 Training Avg Loss: 0.007807
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 43 Validation Avg Loss: 0.008181
Epoch 43 completed in 0:02:11

Epoch 44/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 44 Training Avg Loss: 0.007520
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 44 Validation Avg Loss: 0.007387
Epoch 44 completed in 0:02:11

Epoch 45/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 45 Training Avg Loss: 0.007470
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 45 Validation Avg Loss: 0.007211
*** New best model saved with Validation Loss 0.007211 at Epoch 45 ***
Epoch 45 completed in 0:02:13

Epoch 46/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 46 Training Avg Loss: 0.007468
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 46 Validation Avg Loss: 0.007309
Epoch 46 completed in 0:02:11

Epoch 47/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 47 Training Avg Loss: 0.007497
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 47 Validation Avg Loss: 0.007256
Epoch 47 completed in 0:02:11

Epoch 48/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 48 Training Avg Loss: 0.007369
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 48 Validation Avg Loss: 0.007720
Epoch 48 completed in 0:02:12

Epoch 49/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 49 Training Avg Loss: 0.007200
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 49 Validation Avg Loss: 0.007233
Epoch 49 completed in 0:02:11

Epoch 50/50
Training with 1000 samples (125 batches of size 8)


                                                                                                             

Epoch 50 Training Avg Loss: 0.007413
Validating with 200 samples (25 batches)


                                                                                                  

Epoch 50 Validation Avg Loss: 0.006904
*** New best model saved with Validation Loss 0.006904 at Epoch 50 ***
Saved periodic checkpoint to /notebooks/output_3_18/smoke/output_V3_with_validation_FIXED/checkpoints/wavelet_unet_bn_epoch_050.pth
Epoch 50 completed in 0:02:11

--- Training Finished ---
Saved final model state to /notebooks/output_3_18/smoke/output_V3_with_validation_FIXED/wavelet_unet_bn_final.pth
Saved loss curves plot to /notebooks/output_3_18/smoke/output_V3_with_validation_FIXED/loss_curves.png
Best validation loss achieved: 0.006904
Find best model at: /notebooks/output_3_18/smoke/output_V3_with_validation_FIXED/wavelet_unet_bn_best.pth
Find final model at: /notebooks/output_3_18/smoke/output_V3_with_validation_FIXED/wavelet_unet_bn_final.pth
Find training status log at: /notebooks/output_3_18/smoke/output_V3_with_validation_FIXED/train_status.csv
Find sample montages in: /notebooks/output_3_18/smoke/output_V3_with_validation_FIXED/samples
Find checkpoints in: /noteboo