In [1]:
!pip install pytorch_wavelets

Collecting pytorch_wavelets
  Downloading pytorch_wavelets-1.3.0-py3-none-any.whl.metadata (10.0 kB)
Downloading pytorch_wavelets-1.3.0-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.9/54.9 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pytorch_wavelets
Successfully installed pytorch_wavelets-1.3.0
[0m

In [1]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pywt
from tqdm import tqdm
from pytorch_wavelets import DWTForward
import traceback

# Custom Dataset for Paired Smoky/Clear Images
class ColonoscopyDataset(Dataset):
    """Paired (smoky/blurry input, clear target) images read from two directories.

    Regular files in each directory are sorted by name and paired by position:
    the i-th blurry file is used as the input for the i-th clear file. Nothing
    checks that paired files depict the same frame, so the two directories are
    assumed to use matching, identically-sorting file names.

    Args:
        clear_dir: Directory with the clear (target) images.
        blurry_dir: Directory with the smoky/blurry (input) images.
        wavelet: Stored as ``self.wavelet``; not used inside this class.
        max_samples: If a positive int, keep only the first ``max_samples`` pairs.

    Raises:
        FileNotFoundError: If either directory does not exist.
        ValueError: If either directory contains no regular files.
    """

    def __init__(self, clear_dir, blurry_dir, wavelet='db1', max_samples=None):
        try:
            self.clear_files = self._list_files(clear_dir)
            self.blurry_files = self._list_files(blurry_dir)
        except FileNotFoundError as e:
            print(f"Error accessing directories: {e}")
            raise

        if not self.clear_files or not self.blurry_files:
            raise ValueError(f"No files found in directories: Clear={clear_dir}, Blurry={blurry_dir}")

        min_len = min(len(self.clear_files), len(self.blurry_files))
        if len(self.clear_files) != len(self.blurry_files):
            # Extra files are dropped from the end of the longer sorted list;
            # pairing stays purely positional.
            print(f"Warning: Mismatched file counts. Using {min_len} pairs.")
            self.clear_files = self.clear_files[:min_len]
            self.blurry_files = self.blurry_files[:min_len]

        if max_samples is not None and max_samples > 0 and min_len > max_samples:
            print(f"Limiting dataset from {min_len} to {max_samples} samples.")
            self.clear_files = self.clear_files[:max_samples]
            self.blurry_files = self.blurry_files[:max_samples]

        self.clear_dir = clear_dir
        self.blurry_dir = blurry_dir
        self.wavelet = wavelet

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of the regular files (not subdirectories) in ``directory``."""
        return sorted(f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)))

    def __len__(self):
        return len(self.clear_files)

    def _load_and_preprocess_image(self, img_path):
        """Load an image as an RGB float32 array of shape (H, W, 3) scaled to [0, 1].

        If ``img_path`` has no extension and cannot be read as given, ``.png``
        and then ``.jpg`` are appended and tried.

        Raises:
            ValueError: If the image cannot be read or has an unsupported channel count.
        """
        img = cv2.imread(img_path)
        if img is None:
            _, ext = os.path.splitext(img_path)
            if not ext:
                for try_ext in (".png", ".jpg"):
                    img = cv2.imread(img_path + try_ext)
                    if img is not None:
                        break
            if img is None:
                raise ValueError(f"Failed to load image: {img_path}")

        # With cv2.imread's default flags the result should already be 3-channel
        # BGR; these branches are defensive in case the read flags change.
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
        elif img.shape[2] != 3:
            raise ValueError(f"Image {img_path} has unexpected shape {img.shape}")

        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

    def __getitem__(self, idx):
        """Return ``(blurry, clear)`` as float32 tensors of shape (3, H, W).

        The blurry image is resized to the clear image's size when they differ.
        Out-of-range indices raise ``IndexError``. Any other failure is logged
        and, as a best-effort fallback, sample 0 is returned instead (or, if
        sample 0 itself fails, a pair of zero tensors of shape (3, 256, 256)).
        """
        num_items = len(self.clear_files)
        # Checked outside the try block on purpose: IndexError must reach the
        # caller, because Python's legacy sequence iteration (``for x in ds``)
        # stops on it. Swallowing it would make such a loop yield sample 0 forever.
        if not -num_items <= idx < num_items:
            raise IndexError(f"Index {idx} out of bounds: {num_items}")

        try:
            clear_img_path = os.path.join(self.clear_dir, self.clear_files[idx])
            blurry_img_path = os.path.join(self.blurry_dir, self.blurry_files[idx])

            clear_img_np = self._load_and_preprocess_image(clear_img_path)
            blurry_img_np = self._load_and_preprocess_image(blurry_img_path)

            if clear_img_np.shape[:2] != blurry_img_np.shape[:2]:
                target_h, target_w = clear_img_np.shape[:2]
                blurry_img_np = cv2.resize(blurry_img_np, (target_w, target_h), interpolation=cv2.INTER_LINEAR)

            # HWC numpy -> CHW tensor.
            clear_img_tensor = torch.from_numpy(clear_img_np).permute(2, 0, 1).float()
            blurry_img_tensor = torch.from_numpy(blurry_img_np).permute(2, 0, 1).float()

            return blurry_img_tensor, clear_img_tensor
        except Exception as e:
            print(f"Error loading item at index {idx}: {e}")
            traceback.print_exc()
            if idx > 0:
                return self.__getitem__(0)
            # NOTE(review): this fixed 256x256 size may not match the real
            # samples, in which case default batch collation cannot stack it.
            dummy_tensor = torch.zeros((3, 256, 256), dtype=torch.float32)
            return dummy_tensor, dummy_tensor

# Wavelet-U-Net Model with BatchNorm
class WaveletUNet_BN(nn.Module):
    """U-Net with BatchNorm plus an auxiliary wavelet-feature branch.

    The image branch is a 4-level encoder (64/128/256/512 channels, 2x2
    max-pool between levels). The wavelet branch takes a tensor at roughly half
    the input resolution (its features are concatenated with the level-2 image
    features) and feeds features into the two middle decoder stages.

    Inputs whose sides are not multiples of 8 are supported: after each 2x
    upsampling, decoder features are snapped to the skip connection's spatial
    size when the two differ. The wavelet features are likewise aligned to the
    level-2 image features. For sizes that are multiples of 8 these
    adjustments are no-ops, so outputs are unchanged.

    Args:
        in_channels: Channels of the image input ``x``.
        wavelet_channels: Channels of the wavelet input.
    """

    def __init__(self, in_channels=3, wavelet_channels=12):
        super().__init__()

        def conv_block(in_ch, out_ch):
            """(Conv3x3 -> BN -> ReLU) x 2, with Kaiming-normal conv weights."""
            block = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True)
            )
            for m in block.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            return block

        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)
        self.enc4 = conv_block(256, 512)

        self.wavelet_enc1 = conv_block(wavelet_channels, 64)
        self.wavelet_enc2 = conv_block(64, 128)

        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder input channels = upsampled features + image skip (+ wavelet features).
        self.dec3 = conv_block(512 + 256 + 128, 256)
        self.dec2 = conv_block(256 + 128 + 64, 128)
        self.dec1 = conv_block(128 + 64, 64)

        self.final = nn.Conv2d(64, 3, kernel_size=1)
        nn.init.kaiming_normal_(self.final.weight, mode='fan_out', nonlinearity='linear')
        if self.final.bias is not None:
            nn.init.constant_(self.final.bias, 0)

    @staticmethod
    def _match_size(t, ref):
        """Bilinearly resize ``t`` to ``ref``'s spatial size; return ``t`` unchanged if already equal."""
        if t.shape[-2:] == ref.shape[-2:]:
            return t
        return nn.functional.interpolate(t, size=tuple(ref.shape[-2:]), mode='bilinear', align_corners=True)

    def _up_to(self, t, ref):
        """Upsample ``t`` 2x and, if that misses ``ref``'s spatial size, interpolate straight to it."""
        up = self.up(t)
        if up.shape[-2:] != ref.shape[-2:]:
            # Max-pooling floors odd sizes, so a plain 2x upsample can be one pixel short.
            up = nn.functional.interpolate(t, size=tuple(ref.shape[-2:]), mode='bilinear', align_corners=True)
        return up

    def forward(self, x, wavelet):
        """Return the restored image in [0, 1] with the same spatial size as ``x``."""
        # Image encoder: each level halves the spatial size.
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        # Wavelet branch, aligned to e2's grid (a DWT of an odd-sized image can
        # differ by a pixel from the max-pooled image features). Pooling the
        # aligned map then lands exactly on e3's grid.
        w_feat2 = self._match_size(self.wavelet_enc1(wavelet), e2)
        w_feat3 = self.wavelet_enc2(self.pool(w_feat2))

        d3 = self.dec3(torch.cat([self._up_to(e4, e3), e3, w_feat3], dim=1))
        d2 = self.dec2(torch.cat([self._up_to(d3, e2), e2, w_feat2], dim=1))
        d1 = self.dec1(torch.cat([self._up_to(d2, e1), e1], dim=1))

        return torch.sigmoid(self.final(d1))

# Combined Loss with Differentiable DWT (Fixed)
class CombinedLossDWT(nn.Module):
    """Blend of a spatial L1 loss and an L1 loss on single-level DWT coefficients.

    ``total = alpha * L1(pred, target) + (1 - alpha) * L1(DWT(pred), DWT(target))``

    The DWT is ``pytorch_wavelets.DWTForward`` (one level, symmetric padding),
    so the frequency term is differentiable and computed on ``device``.

    Args:
        alpha: Weight of the spatial term; ``1 - alpha`` weights the wavelet term.
        wavelet: Wavelet name passed to ``DWTForward``.
        device: Device the DWT module and its inputs are moved to.
    """

    def __init__(self, alpha=0.85, wavelet='db1', device='cpu'):
        super().__init__()
        self.alpha = alpha
        self.l1_loss = nn.L1Loss()
        self.dwt = DWTForward(J=1, wave=wavelet, mode='symmetric').to(device)
        self.device = device

    def _compute_dwt(self, x):
        """Return the low-pass band and the detail bands stacked on the channel axis.

        For an input of shape (B, C, H, W) the result has shape (B, C + 3*C, H', W'):
        the first C channels are the approximation band, the rest the details.
        """
        x = x.to(self.device)
        LL, H_coeffs = self.dwt(x)
        # The level-1 details come as one tensor of shape (B, C, n_bands, H', W');
        # fold the sub-band axis into the channel axis so it can be concatenated with LL.
        detail = H_coeffs[0]
        B, C, n_bands, H, W = detail.shape
        detail = detail.reshape(B, C * n_bands, H, W)
        return torch.cat([LL, detail], dim=1)

    def forward(self, pred, target):
        """Return the weighted loss as a scalar tensor.

        If the loss is not finite (NaN or +/-Inf), the two components are
        printed and a constant ``1000.0`` tensor is returned instead. That
        tensor is a fresh leaf with no graph back to ``pred``, so calling
        ``backward()`` on it produces no gradients for the model.
        """
        pred, target = pred.float(), target.float()
        spatial_loss = self.l1_loss(pred, target)
        pred_wavelet = self._compute_dwt(pred)
        target_wavelet = self._compute_dwt(target)
        freq_loss = self.l1_loss(pred_wavelet, target_wavelet)
        total_loss = self.alpha * spatial_loss + (1 - self.alpha) * freq_loss

        # isfinite rather than isnan: an infinite loss would also push inf/NaN
        # gradients into the weights on backward().
        if not torch.isfinite(total_loss):
            print(f"Non-finite loss! Spatial: {spatial_loss.item()}, Freq: {freq_loss.item()}")
            return torch.tensor(1000.0, device=self.device, requires_grad=True)
        return total_loss

# Helper function for wavelet input
def get_wavelet_input(img_tensor, wavelet='db1', device='cpu'):
    """Build the per-sample-normalized wavelet input for ``WaveletUNet_BN``.

    Applies a single-level 2-D DWT (``pywt.dwt2``, symmetric padding) over the
    spatial axes of every image in the batch. The four sub-bands (cA, cH, cV,
    cD) are stacked on the channel axis, and every channel of every sample is
    standardized to zero mean / unit std.

    Args:
        img_tensor: Batch tensor laid out as (B, C, H, W).
        wavelet: Wavelet name passed to ``pywt.dwt2``.
        device: Device for the returned tensor.

    Returns:
        float32 tensor of shape (B, 4*C, H', W') on ``device``, where H', W'
        are the DWT output sizes.
    """
    # pywt works on NumPy arrays, so the batch is round-tripped through host memory.
    imgs = img_tensor.detach().cpu().numpy().transpose(0, 2, 3, 1)  # (B, H, W, C)
    # One batched call: the 2-D transform is applied over axes (1, 2)
    # independently for every (batch, channel) slice.
    cA, (cH, cV, cD) = pywt.dwt2(imgs, wavelet, mode='symmetric', axes=(1, 2))
    bands = np.concatenate([cA, cH, cV, cD], axis=3).astype(np.float32)  # (B, H', W', 4C)

    # Standardize each channel of each sample over its spatial extent; the
    # epsilon keeps constant channels (std == 0) at 0 instead of dividing by zero.
    mean = bands.mean(axis=(1, 2), keepdims=True)
    std = bands.std(axis=(1, 2), keepdims=True)
    bands = (bands - mean) / (std + 1e-8)

    return torch.from_numpy(bands).permute(0, 3, 1, 2).float().to(device)

# Load Test Image Pair
def load_test_image_pair(clear_dir, blurry_dir, idx=0, device='cpu', wavelet='db1'):
    """Load one (blurry, wavelet-input, clear) triple for periodic visual checks.

    A fresh ``ColonoscopyDataset`` over the full directories is built, so
    ``idx`` indexes the complete sorted file list; no ``max_samples`` limit is
    applied here. ``idx`` is clamped to the last available pair.

    Args:
        clear_dir: Directory with the clear (target) images.
        blurry_dir: Directory with the smoky/blurry (input) images.
        idx: Requested pair index.
        device: Device for all returned tensors.
        wavelet: Wavelet used to build the wavelet input. Pass the same value
            used for training so the test input matches what the model saw.

    Returns:
        ``(blurry, wavelet_input, clear)``, each with a leading batch dimension
        of 1, on ``device``.
    """
    temp_dataset = ColonoscopyDataset(clear_dir, blurry_dir)
    effective_idx = min(idx, len(temp_dataset) - 1)
    if idx != effective_idx:
        print(f"Adjusted test idx from {idx} to {effective_idx}")
    blurry_img_tensor, clear_img_tensor = temp_dataset[effective_idx]
    blurry_batch = blurry_img_tensor.unsqueeze(0)
    wavelet_input = get_wavelet_input(blurry_batch, wavelet=wavelet, device=device)
    return blurry_batch.to(device), wavelet_input, clear_img_tensor.unsqueeze(0).to(device)

# Training Function
def train_model(clear_dir, blurry_dir, output_dir, epochs=50, batch_size=4, lr=0.0005,
                wavelet='db1', alpha=0.85, weight_decay=1e-5, max_samples_train=None,
                seed=None):
    """Train ``WaveletUNet_BN`` on paired smoky/clear images.

    Uses AdamW with ``CombinedLossDWT`` and a ``ReduceLROnPlateau`` schedule
    driven by the epoch's average *training* loss (there is no validation split).

    Args:
        clear_dir: Directory with the clear (target) images.
        blurry_dir: Directory with the smoky/blurry (input) images.
        output_dir: Root directory for samples, checkpoints and the final model.
        epochs: Number of training epochs.
        batch_size: Mini-batch size (incomplete last batches are dropped).
        lr: Initial learning rate.
        wavelet: Wavelet used both for the model's wavelet input and the loss.
        alpha: Weight of the spatial L1 term in ``CombinedLossDWT``.
        weight_decay: AdamW weight decay.
        max_samples_train: Optional cap on the number of training pairs.
        seed: If not None, seeds torch (all devices) and NumPy before the model
            is built. The default None leaves RNG state untouched.

    Returns:
        The trained model, or None if the dataset or the test pair could not be loaded.

    Side effects:
        Under ``output_dir`` this writes:
        - ``samples/``: prediction and blurry|pred|clear comparison PNGs on the
          first epoch, every 5th epoch and the last epoch;
        - ``checkpoints/``: every 10th epoch and the last epoch;
        - ``wavelet_unet_bn_final.pth``.
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = WaveletUNet_BN().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = CombinedLossDWT(alpha=alpha, wavelet=wavelet, device=device)
    # `verbose` is deprecated for LR schedulers; LR reductions are logged manually below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

    try:
        dataset = ColonoscopyDataset(clear_dir, blurry_dir, wavelet=wavelet, max_samples=max_samples_train)
        if len(dataset) == 0:
            print("ERROR: Dataset is empty!")
            return None
    except Exception as e:
        print(f"Error initializing dataset: {e}")
        return None

    # os.cpu_count() may return None when the count cannot be determined.
    cpu_count = os.cpu_count() or 1
    num_w = min(4, cpu_count // 2 if cpu_count > 1 else 1)
    print(f"Using {num_w} dataloader workers.")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_w,
                            pin_memory=device.type == 'cuda', drop_last=True)
    if len(dataloader) == 0:
        print(f"Warning: dataset has {len(dataset)} samples, fewer than batch_size={batch_size}; "
              f"with drop_last=True no training batches will run.")

    samples_dir = os.path.join(output_dir, "samples")
    checkpoints_dir = os.path.join(output_dir, "checkpoints")
    os.makedirs(samples_dir, exist_ok=True)
    os.makedirs(checkpoints_dir, exist_ok=True)

    test_idx = min(5, len(dataset) - 1)
    print(f"Loading test image pair with index: {test_idx}")
    try:
        test_blurry, _, test_clear = load_test_image_pair(clear_dir, blurry_dir, idx=test_idx, device=device)
        # Build the test wavelet input with the training wavelet so the
        # monitored input is prepared exactly like the training inputs.
        test_wavelet_input = get_wavelet_input(test_blurry, wavelet=wavelet, device=device)
    except Exception as e:
        print(f"ERROR loading test image: {e}")
        return None

    print(f"Training dataset size: {len(dataset)}")
    print(f"Test image shapes: Blurry={test_blurry.shape}, WaveletInput={test_wavelet_input.shape}, Clear={test_clear.shape}")

    model.eval()
    with torch.no_grad():
        initial_pred_img = _to_bgr_uint8(model(test_blurry, test_wavelet_input))
        cv2.imwrite(os.path.join(samples_dir, "initial_pred_test.png"), initial_pred_img)
        print("Saved initial prediction sample.")

    print("Starting training...")
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        n_batches = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", leave=True)
        for blurry_img, clear_img in progress_bar:
            # The DWT input is computed with pywt on the CPU, so build it from the
            # CPU batch before moving to the device (avoids a device->host copy).
            wavelet_input = get_wavelet_input(blurry_img, wavelet=wavelet, device=device)
            blurry_img = blurry_img.to(device)
            clear_img = clear_img.to(device)

            optimizer.zero_grad()
            pred = model(blurry_img, wavelet_input)
            loss = criterion(pred, clear_img)
            loss.backward()
            optimizer.step()

            current_loss = loss.item()
            running_loss += current_loss
            n_batches += 1
            progress_bar.set_postfix(loss=f"{current_loss:.6f}", avg_loss=f"{running_loss/n_batches:.6f}", lr=f"{optimizer.param_groups[0]['lr']:.1e}")

        avg_loss = running_loss / n_batches if n_batches > 0 else 0
        print(f"\nEpoch {epoch+1}/{epochs} finished. Average Training Loss: {avg_loss:.6f}")

        prev_lr = optimizer.param_groups[0]['lr']
        scheduler.step(avg_loss)
        new_lr = optimizer.param_groups[0]['lr']
        if new_lr != prev_lr:
            print(f"Learning rate reduced from {prev_lr:.1e} to {new_lr:.1e}")

        if (epoch + 1) % 5 == 0 or epoch == 0 or epoch == epochs - 1:
            _save_test_samples(model, test_blurry, test_wavelet_input, test_clear, samples_dir, epoch)

        if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
            checkpoint_path = os.path.join(checkpoints_dir, f"wavelet_unet_bn_epoch_{epoch+1:03d}.pth")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': avg_loss,
            }, checkpoint_path)
            print(f"Saved checkpoint to {checkpoint_path}")

    print("Training finished.")
    final_model_path = os.path.join(output_dir, "wavelet_unet_bn_final.pth")
    torch.save(model.state_dict(), final_model_path)
    print(f"Saved final model to {final_model_path}")
    return model


def _to_bgr_uint8(img_tensor):
    """Convert an RGB image tensor with values in [0, 1] to a BGR uint8 array for cv2.

    Accepts (1, 3, H, W) or (3, H, W); a leading batch dimension of size 1 is dropped.
    Values outside [0, 1] are clipped.
    """
    img_np = img_tensor.detach().squeeze(0).cpu().numpy().transpose(1, 2, 0)
    img_u8 = (np.clip(img_np, 0, 1) * 255).astype(np.uint8)
    return cv2.cvtColor(img_u8, cv2.COLOR_RGB2BGR)


def _save_test_samples(model, test_blurry, test_wavelet_input, test_clear, samples_dir, epoch):
    """Write the prediction on the fixed test pair and a blurry|pred|clear comparison strip.

    Runs the model in eval mode without gradients, then switches it back to train mode.
    """
    model.eval()
    with torch.no_grad():
        pred_test_img = _to_bgr_uint8(model(test_blurry, test_wavelet_input))
    sample_path = os.path.join(samples_dir, f"pred_epoch_{epoch+1:03d}.png")
    cv2.imwrite(sample_path, pred_test_img)
    print(f"Saved prediction sample to {sample_path}")

    # Side-by-side panels must share the prediction's height and width.
    h, w = pred_test_img.shape[:2]
    panels = []
    for panel in (_to_bgr_uint8(test_blurry), pred_test_img, _to_bgr_uint8(test_clear)):
        if panel.shape[:2] != (h, w):
            panel = cv2.resize(panel, (w, h))
        panels.append(panel)
    comparison_img = np.concatenate(panels, axis=1)
    comp_path = os.path.join(samples_dir, f"comparison_epoch_{epoch+1:03d}.png")
    cv2.imwrite(comp_path, comparison_img)

    model.train()

# Example Usage
if __name__ == "__main__":
    # Paths are relative to the notebook's working directory.
    clear_dir = "output_3_18/smoke/clear"
    blurry_dir = "output_3_18/smoke/blurry"
    output_dir = "output_3_18/smoke/output_wavelet_bn_diffdwt_1k_limit_v2"

    # Hyperparameters for this run, forwarded unchanged to train_model.
    train_kwargs = dict(
        epochs=50,
        batch_size=4,
        lr=0.0005,
        wavelet='db1',
        alpha=0.85,
        weight_decay=1e-5,
        max_samples_train=1000,
    )

    os.makedirs(output_dir, exist_ok=True)
    for label, path in (("Clear image directory", clear_dir),
                        ("Blurry image directory", blurry_dir),
                        ("Output directory", output_dir)):
        print(f"{label}: {os.path.abspath(path)}")

    inputs_present = os.path.isdir(clear_dir) and os.path.isdir(blurry_dir)
    if inputs_present:
        try:
            model = train_model(clear_dir, blurry_dir, output_dir, **train_kwargs)
            print("Training process completed." if model else "Training process failed.")
        except Exception as e:
            print(f"An unexpected error occurred: {e}")
            traceback.print_exc()
    else:
        print(f"Error: Directory not found - Clear: {clear_dir}, Blurry: {blurry_dir}")

Clear image directory: /notebooks/output_3_18/smoke/clear
Blurry image directory: /notebooks/output_3_18/smoke/blurry
Output directory: /notebooks/output_3_18/smoke/output_wavelet_bn_diffdwt_1k_limit_v2
Using device: cuda
Limiting dataset from 10212 to 1000 samples.
Using 4 dataloader workers.
Loading test image pair with index: 5
Training dataset size: 1000
Test image shapes: Blurry=torch.Size([1, 3, 512, 512]), WaveletInput=torch.Size([1, 12, 256, 256]), Clear=torch.Size([1, 3, 512, 512])
Saved initial prediction sample.
Starting training...


Epoch 1/50: 100%|██████████| 250/250 [01:57<00:00,  2.12it/s, avg_loss=0.019581, loss=0.009483, lr=5.0e-04]



Epoch 1/50 finished. Average Training Loss: 0.019581
Saved prediction sample to output_3_18/smoke/output_wavelet_bn_diffdwt_1k_limit_v2/samples/pred_epoch_001.png


Epoch 2/50: 100%|██████████| 250/250 [02:01<00:00,  2.06it/s, avg_loss=0.010446, loss=0.008062, lr=5.0e-04]



Epoch 2/50 finished. Average Training Loss: 0.010446


Epoch 3/50: 100%|██████████| 250/250 [02:02<00:00,  2.05it/s, avg_loss=0.008094, loss=0.008627, lr=5.0e-04]



Epoch 3/50 finished. Average Training Loss: 0.008094


Epoch 4/50: 100%|██████████| 250/250 [02:01<00:00,  2.06it/s, avg_loss=0.006549, loss=0.006559, lr=5.0e-04]



Epoch 4/50 finished. Average Training Loss: 0.006549


Epoch 5/50: 100%|██████████| 250/250 [02:01<00:00,  2.06it/s, avg_loss=0.006011, loss=0.005653, lr=5.0e-04]



Epoch 5/50 finished. Average Training Loss: 0.006011
Saved prediction sample to output_3_18/smoke/output_wavelet_bn_diffdwt_1k_limit_v2/samples/pred_epoch_005.png


Epoch 6/50: 100%|██████████| 250/250 [01:59<00:00,  2.10it/s, avg_loss=0.005432, loss=0.004233, lr=5.0e-04]



Epoch 6/50 finished. Average Training Loss: 0.005432


Epoch 7/50: 100%|██████████| 250/250 [01:58<00:00,  2.10it/s, avg_loss=0.005180, loss=0.005491, lr=5.0e-04]



Epoch 7/50 finished. Average Training Loss: 0.005180


Epoch 8/50: 100%|██████████| 250/250 [01:59<00:00,  2.10it/s, avg_loss=0.004965, loss=0.004451, lr=5.0e-04]



Epoch 8/50 finished. Average Training Loss: 0.004965


Epoch 9/50: 100%|██████████| 250/250 [01:58<00:00,  2.10it/s, avg_loss=0.004914, loss=0.005630, lr=5.0e-04]



Epoch 9/50 finished. Average Training Loss: 0.004914


Epoch 10/50: 100%|██████████| 250/250 [01:58<00:00,  2.11it/s, avg_loss=0.004649, loss=0.004817, lr=5.0e-04]



Epoch 10/50 finished. Average Training Loss: 0.004649
Saved prediction sample to output_3_18/smoke/output_wavelet_bn_diffdwt_1k_limit_v2/samples/pred_epoch_010.png
Saved checkpoint to output_3_18/smoke/output_wavelet_bn_diffdwt_1k_limit_v2/checkpoints/wavelet_unet_bn_epoch_010.pth


Epoch 11/50: 100%|██████████| 250/250 [01:58<00:00,  2.11it/s, avg_loss=0.004300, loss=0.004293, lr=5.0e-04]



Epoch 11/50 finished. Average Training Loss: 0.004300


Epoch 12/50: 100%|██████████| 250/250 [01:58<00:00,  2.12it/s, avg_loss=0.004061, loss=0.004935, lr=5.0e-04]



Epoch 12/50 finished. Average Training Loss: 0.004061


Epoch 13/50: 100%|██████████| 250/250 [01:58<00:00,  2.11it/s, avg_loss=0.003948, loss=0.003790, lr=5.0e-04]



Epoch 13/50 finished. Average Training Loss: 0.003948


Epoch 14/50: 100%|██████████| 250/250 [01:58<00:00,  2.11it/s, avg_loss=0.003842, loss=0.003186, lr=5.0e-04]



Epoch 14/50 finished. Average Training Loss: 0.003842


Epoch 15/50: 100%|██████████| 250/250 [01:58<00:00,  2.11it/s, avg_loss=0.003973, loss=0.004707, lr=5.0e-04]



Epoch 15/50 finished. Average Training Loss: 0.003973
Saved prediction sample to output_3_18/smoke/output_wavelet_bn_diffdwt_1k_limit_v2/samples/pred_epoch_015.png


Epoch 16/50: 100%|██████████| 250/250 [01:57<00:00,  2.12it/s, avg_loss=0.003702, loss=0.003009, lr=5.0e-04]



Epoch 16/50 finished. Average Training Loss: 0.003702


Epoch 17/50: 100%|██████████| 250/250 [01:58<00:00,  2.11it/s, avg_loss=0.003696, loss=0.004715, lr=5.0e-04]



Epoch 17/50 finished. Average Training Loss: 0.003696


Epoch 18/50: 100%|██████████| 250/250 [01:58<00:00,  2.11it/s, avg_loss=0.003518, loss=0.003050, lr=5.0e-04]



Epoch 18/50 finished. Average Training Loss: 0.003518


Epoch 19/50: 100%|██████████| 250/250 [01:58<00:00,  2.11it/s, avg_loss=0.003615, loss=0.002831, lr=5.0e-04]



Epoch 19/50 finished. Average Training Loss: 0.003615


Epoch 20/50: 100%|██████████| 250/250 [01:58<00:00,  2.12it/s, avg_loss=0.003446, loss=0.002454, lr=5.0e-04]



Epoch 20/50 finished. Average Training Loss: 0.003446
Saved prediction sample to output_3_18/smoke/output_wavelet_bn_diffdwt_1k_limit_v2/samples/pred_epoch_020.png
Saved checkpoint to output_3_18/smoke/output_wavelet_bn_diffdwt_1k_limit_v2/checkpoints/wavelet_unet_bn_epoch_020.pth


Epoch 21/50: 100%|██████████| 250/250 [01:57<00:00,  2.12it/s, avg_loss=0.003486, loss=0.003828, lr=5.0e-04]



Epoch 21/50 finished. Average Training Loss: 0.003486


Epoch 22/50: 100%|██████████| 250/250 [01:57<00:00,  2.13it/s, avg_loss=0.003453, loss=0.003284, lr=5.0e-04]



Epoch 22/50 finished. Average Training Loss: 0.003453


Epoch 23/50: 100%|██████████| 250/250 [01:57<00:00,  2.12it/s, avg_loss=0.003373, loss=0.002503, lr=5.0e-04]



Epoch 23/50 finished. Average Training Loss: 0.003373


Epoch 24/50: 100%|██████████| 250/250 [01:58<00:00,  2.10it/s, avg_loss=0.003238, loss=0.003157, lr=5.0e-04]



Epoch 24/50 finished. Average Training Loss: 0.003238


Epoch 25/50: 100%|██████████| 250/250 [01:58<00:00,  2.12it/s, avg_loss=0.003289, loss=0.002895, lr=5.0e-04]



Epoch 25/50 finished. Average Training Loss: 0.003289
Saved prediction sample to output_3_18/smoke/output_wavelet_bn_diffdwt_1k_limit_v2/samples/pred_epoch_025.png


Epoch 26/50: 100%|██████████| 250/250 [01:59<00:00,  2.09it/s, avg_loss=0.003127, loss=0.002864, lr=5.0e-04]



Epoch 26/50 finished. Average Training Loss: 0.003127


Epoch 27/50: 100%|██████████| 250/250 [01:59<00:00,  2.10it/s, avg_loss=0.003090, loss=0.002898, lr=5.0e-04]



Epoch 27/50 finished. Average Training Loss: 0.003090


Epoch 28/50: 100%|██████████| 250/250 [01:59<00:00,  2.09it/s, avg_loss=0.003062, loss=0.003061, lr=5.0e-04]



Epoch 28/50 finished. Average Training Loss: 0.003062


Epoch 29/50: 100%|██████████| 250/250 [01:58<00:00,  2.10it/s, avg_loss=0.003025, loss=0.003193, lr=5.0e-04]



Epoch 29/50 finished. Average Training Loss: 0.003025


Epoch 30/50: 100%|██████████| 250/250 [01:59<00:00,  2.10it/s, avg_loss=0.003522, loss=0.004363, lr=5.0e-04]



Epoch 30/50 finished. Average Training Loss: 0.003522
Saved prediction sample to output_3_18/smoke/output_wavelet_bn_diffdwt_1k_limit_v2/samples/pred_epoch_030.png
Saved checkpoint to output_3_18/smoke/output_wavelet_bn_diffdwt_1k_limit_v2/checkpoints/wavelet_unet_bn_epoch_030.pth


Epoch 31/50: 100%|██████████| 250/250 [02:00<00:00,  2.07it/s, avg_loss=0.003613, loss=0.002887, lr=5.0e-04]



Epoch 31/50 finished. Average Training Loss: 0.003613


Epoch 32/50: 100%|██████████| 250/250 [02:00<00:00,  2.08it/s, avg_loss=0.003177, loss=0.002924, lr=5.0e-04]



Epoch 32/50 finished. Average Training Loss: 0.003177


Epoch 33/50: 100%|██████████| 250/250 [02:00<00:00,  2.07it/s, avg_loss=0.003027, loss=0.002817, lr=5.0e-04]



Epoch 33/50 finished. Average Training Loss: 0.003027


Epoch 34/50: 100%|██████████| 250/250 [02:00<00:00,  2.08it/s, avg_loss=0.003042, loss=0.002904, lr=5.0e-04]



Epoch 34/50 finished. Average Training Loss: 0.003042


Epoch 35/50: 100%|██████████| 250/250 [02:00<00:00,  2.08it/s, avg_loss=0.002830, loss=0.002987, lr=5.0e-04]



Epoch 35/50 finished. Average Training Loss: 0.002830
Saved prediction sample to output_3_18/smoke/output_wavelet_bn_diffdwt_1k_limit_v2/samples/pred_epoch_035.png


Epoch 36/50: 100%|██████████| 250/250 [02:01<00:00,  2.06it/s, avg_loss=0.002902, loss=0.002818, lr=5.0e-04]



Epoch 36/50 finished. Average Training Loss: 0.002902


Epoch 37/50: 100%|██████████| 250/250 [02:01<00:00,  2.06it/s, avg_loss=0.002792, loss=0.002122, lr=5.0e-04]



Epoch 37/50 finished. Average Training Loss: 0.002792


Epoch 38/50: 100%|██████████| 250/250 [02:01<00:00,  2.06it/s, avg_loss=0.002800, loss=0.002710, lr=5.0e-04]



Epoch 38/50 finished. Average Training Loss: 0.002800


Epoch 39/50: 100%|██████████| 250/250 [02:01<00:00,  2.06it/s, avg_loss=0.002751, loss=0.002690, lr=5.0e-04]



Epoch 39/50 finished. Average Training Loss: 0.002751


Epoch 40/50: 100%|██████████| 250/250 [02:00<00:00,  2.07it/s, avg_loss=0.002713, loss=0.002182, lr=5.0e-04]



Epoch 40/50 finished. Average Training Loss: 0.002713
Saved prediction sample to output_3_18/smoke/output_wavelet_bn_diffdwt_1k_limit_v2/samples/pred_epoch_040.png
Saved checkpoint to output_3_18/smoke/output_wavelet_bn_diffdwt_1k_limit_v2/checkpoints/wavelet_unet_bn_epoch_040.pth


Epoch 41/50: 100%|██████████| 250/250 [01:58<00:00,  2.12it/s, avg_loss=0.002675, loss=0.002594, lr=5.0e-04]



Epoch 41/50 finished. Average Training Loss: 0.002675


Epoch 42/50: 100%|██████████| 250/250 [01:57<00:00,  2.13it/s, avg_loss=0.002702, loss=0.002138, lr=5.0e-04]



Epoch 42/50 finished. Average Training Loss: 0.002702


Epoch 43/50: 100%|██████████| 250/250 [01:58<00:00,  2.12it/s, avg_loss=0.002675, loss=0.003826, lr=5.0e-04]



Epoch 43/50 finished. Average Training Loss: 0.002675


Epoch 44/50: 100%|██████████| 250/250 [01:58<00:00,  2.12it/s, avg_loss=0.002613, loss=0.002997, lr=5.0e-04]



Epoch 44/50 finished. Average Training Loss: 0.002613


Epoch 45/50: 100%|██████████| 250/250 [01:58<00:00,  2.12it/s, avg_loss=0.002622, loss=0.002590, lr=5.0e-04]



Epoch 45/50 finished. Average Training Loss: 0.002622
Saved prediction sample to output_3_18/smoke/output_wavelet_bn_diffdwt_1k_limit_v2/samples/pred_epoch_045.png


Epoch 46/50: 100%|██████████| 250/250 [01:57<00:00,  2.12it/s, avg_loss=0.002538, loss=0.002249, lr=5.0e-04]



Epoch 46/50 finished. Average Training Loss: 0.002538


Epoch 47/50: 100%|██████████| 250/250 [01:58<00:00,  2.12it/s, avg_loss=0.002514, loss=0.003646, lr=5.0e-04]



Epoch 47/50 finished. Average Training Loss: 0.002514


Epoch 48/50: 100%|██████████| 250/250 [01:57<00:00,  2.12it/s, avg_loss=0.002588, loss=0.003038, lr=5.0e-04]



Epoch 48/50 finished. Average Training Loss: 0.002588


Epoch 49/50: 100%|██████████| 250/250 [01:57<00:00,  2.13it/s, avg_loss=0.002575, loss=0.002375, lr=5.0e-04]



Epoch 49/50 finished. Average Training Loss: 0.002575


Epoch 50/50: 100%|██████████| 250/250 [01:56<00:00,  2.14it/s, avg_loss=0.002510, loss=0.002309, lr=5.0e-04]



Epoch 50/50 finished. Average Training Loss: 0.002510
Saved prediction sample to output_3_18/smoke/output_wavelet_bn_diffdwt_1k_limit_v2/samples/pred_epoch_050.png
Saved checkpoint to output_3_18/smoke/output_wavelet_bn_diffdwt_1k_limit_v2/checkpoints/wavelet_unet_bn_epoch_050.pth
Training finished.
Saved final model to output_3_18/smoke/output_wavelet_bn_diffdwt_1k_limit_v2/wavelet_unet_bn_final.pth
Training process completed.
