# NAFNet GAN to denoise images

This is my implementation trained on an AMD Ryzen 7 5800X / 32GB RAM / RTX 5060 Ti 16GB

### Step 0: Import the libraries and project modules needed to run the model

In [None]:
import numpy as np
import pandas as pd
from skimage import io

In [None]:
# Python imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR
import matplotlib.pyplot as plt
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import time

# Code imports
from dataloader import DenoisingDataset2D
from models import NAFNet
from models import Deep_Discriminator as Discriminator

### Step 1: Define paths to load the images.

The dataset contains images ranging from 64x64 to 1024x1024 pixels. By default, the crop size was set to 64 (the cell below sets `crop_size = 128`). To change the crop size, adjust the `crop_size` parameter. Note that cropping is only applied to the training set, not to the validation set.

In [None]:
def plot_images(noisy, pred, target, epoch, save_dir="training_visuals"):
    """Show input, prediction and ground truth side by side and save the figure.

    Args:
        noisy, pred, target: image tensors. Each one is squeezed before
            plotting, so it must reduce to a 2-D array.
        epoch: value embedded in the output file name.
        save_dir: directory receiving ``epoch_{epoch}_viz.png``; created if
            it does not exist.
    """
    fig, axs = plt.subplots(1, 3, figsize=(12, 5))
    titles = ["Noisy Input", "Denoised Output", "Ground Truth"]
    for ax, title, img in zip(axs, titles, (noisy, pred, target)):
        # detach() lets tensors that still require grad be converted too.
        img = np.clip(img.detach().squeeze().cpu().numpy(), 0, 1)
        ax.imshow(img, cmap='gray')
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()

    os.makedirs(save_dir, exist_ok=True)
    # Save BEFORE showing: once plt.show() has run the figure may already be
    # released, and a later savefig would write an empty canvas.
    fig.savefig(os.path.join(save_dir, f"epoch_{epoch}_viz.png"))
    plt.show()
    plt.close(fig)  # Close the plot to free memory

def load_data(train_dir, val_dir, test_size, batch_size, num_workers_tr, num_workers_val, crop_size = None):
    """Build the training and validation DataLoaders from paired RAW/GT folders.

    Both ``train_dir`` and ``val_dir`` must contain a ``RAW`` and a ``GT``
    sub-folder with ``.tif`` files. Noisy and clean images are paired purely
    by their position in the sorted file lists.

    Args:
        train_dir: root folder of the training split.
        val_dir: root folder of the validation split.
        test_size: unused; kept so existing callers keep working (the split
            is already materialised on disk as two folders).
        batch_size: batch size of the training loader.
        num_workers_tr: worker processes for the training loader.
        num_workers_val: worker processes for the validation loader.
        crop_size: forwarded to both datasets.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: if a split has a different number of RAW and GT files,
            because positional pairing would then be misaligned.
    """
    def _list_tifs(folder):
        # Sorted full paths of every .tif file directly inside `folder`.
        return sorted(os.path.join(folder, f) for f in os.listdir(folder) if f.endswith('.tif'))

    train_noisy = _list_tifs(os.path.join(train_dir, "RAW"))
    train_gt = _list_tifs(os.path.join(train_dir, "GT"))
    val_noisy = _list_tifs(os.path.join(val_dir, "RAW"))
    val_gt = _list_tifs(os.path.join(val_dir, "GT"))

    for split, noisy_paths, gt_paths in (("training", train_noisy, train_gt),
                                         ("validation", val_noisy, val_gt)):
        if len(noisy_paths) != len(gt_paths):
            raise ValueError(
                f"The {split} split has {len(noisy_paths)} RAW images but "
                f"{len(gt_paths)} GT images; they must match one-to-one."
            )

    # No split is performed here (test_size is unused), so the message only reports counts.
    print(f"{len(train_noisy)} training images and {len(val_noisy)} validation images.")

    train_ds = DenoisingDataset2D(train_noisy, train_gt, crop_size=crop_size, augment=True)
    val_ds = DenoisingDataset2D(val_noisy, val_gt, crop_size=crop_size, augment=False)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers_tr, pin_memory=True)
    # The validation batch size is hard-coded to 32 and does not follow `batch_size`.
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=num_workers_val, pin_memory=True)

    print(f"\nTraining DataLoader created with {len(train_loader)} batches.")
    print(f"Validation DataLoader created with {len(val_loader)} batches.")

    return train_loader, val_loader

# Define data paths (update these to your actual paths)
train_dir = r"D:\Manuscipts_Coding\Denoising_paper\IgG-1D\Exported_Data_TIFF\train"
val_dir = r"D:\Manuscipts_Coding\Denoising_paper\IgG-1D\Exported_Data_TIFF\val"
pretrained_path = "DATASET_7_NAFNet_GAN_best_model_LOSS_2_V2.pth"  # If it does not exist, then it will not load the weights on train.
checkpoint_path = "DATASET_7_NAFNet_GAN_best_model_LOSS_2_V2.pth"
crop_size = 128
num_workers_tr = 0 # Number of workers for training
num_workers_val = 0 # Number of workers for validation

# Step 1: Load data
# The loaders are only created when both folders exist. When they do not, say
# so explicitly instead of failing later with a NameError on train_loader.
if os.path.exists(train_dir) and os.path.exists(val_dir):
    train_loader, val_loader = load_data(
        train_dir=train_dir,
        val_dir=val_dir,
        test_size=0.04,
        batch_size=32,
        num_workers_tr=num_workers_tr,
        num_workers_val=num_workers_val,
        crop_size=crop_size
    )
else:
    print(f"Data directories not found; train_loader/val_loader were NOT created. "
          f"Check train_dir={train_dir!r} and val_dir={val_dir!r}.")

### Step 2: Visualize the dataloaders.

In [None]:
# Step 2: Visualize a sample from training and validation data
def visualize_dataloader(loader, title="Sample"):
    """Fetch the first batch of ``loader``, report timing/shapes and plot sample 0.

    The loader must yield ``(noisy, gt)`` batches. The first element of each
    batch is squeezed and drawn as a grayscale image with its own colorbar.
    """
    t_start = time.perf_counter()
    noisy_batch, gt_batch = next(iter(loader))
    elapsed = time.perf_counter() - t_start
    print(f"Time to load first batch: {elapsed:.4f} seconds")
    print(f"Noisy batch shape: {noisy_batch.shape}, GT batch shape: {gt_batch.shape}")

    # (caption, 2-D array) for each panel, left to right.
    panels = (
        (f"Noisy Input {title}", noisy_batch[0].squeeze().cpu().numpy()),
        (f"Ground Truth {title}", gt_batch[0].squeeze().cpu().numpy()),
    )

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    for ax, (caption, image) in zip(axs, panels):
        handle = ax.imshow(image, cmap='gray')
        ax.set_title(caption)
        ax.axis("off")
        fig.colorbar(handle, ax=ax, fraction=0.046, pad=0.04)
    plt.tight_layout()
    plt.show()

# Show one sample from each loader, training first.
for _split_name, _split_loader, _split_title in (
    ("training", train_loader, "Training Sample (Cropped)"),
    ("validation", val_loader, "Validation Sample (Full)"),
):
    print(f"Visualizing {_split_name} data sample:")
    visualize_dataloader(_split_loader, title=_split_title)

### Step 3: Train the model.

Note that we have only trained it for 1 epoch.

In [None]:
from losses import MasterLoss

In [None]:
import os
os.environ['TORCH_COMPILE_DISABLE'] = '1'

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import GradScaler 
from torch.amp import autocast 
import warnings
import time

# --- Setup & Helper Functions ---
warnings.filterwarnings('ignore')

def plot_images(noisy, pred, target, epoch, save_dir="training_visuals"):
    """Saves and displays the input, predicted, and ground truth images with colorbars.

    Args:
        noisy, pred, target: image tensors. Each one is squeezed before
            plotting, so it must reduce to a 2-D array.
        epoch: shown in every panel title and embedded in the file name.
        save_dir: directory receiving ``epoch_{epoch}_viz.png``; created if
            it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    titles = ["Noisy Input", "Denoised Output", "Ground Truth"]

    # Move to CPU and detach for plotting
    images = [t.detach().squeeze().cpu().numpy() for t in (noisy, pred, target)]

    for ax, title, img in zip(axs, titles, images):
        # Fixed [0, 1] display range so panels stay comparable across epochs.
        im = ax.imshow(img, cmap='gray', vmin=0, vmax=1)
        ax.set_title(f"{title} (Epoch {epoch})")
        ax.axis("off")
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    fig.tight_layout()
    # Write the file before plt.show(): saving afterwards can produce a blank image.
    fig.savefig(os.path.join(save_dir, f"epoch_{epoch}_viz.png"))
    plt.show()
    plt.close(fig)  # one figure is created per call; release it

# Assuming NAFNet, MasterLoss, LossWeights, train_loader, val_loader are defined/imported
# If running as standalone script, ensure imports for these are present.

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Model Setup ---
# Constructor arguments for the NAFNet generator built below.
img_channel = 1  # Grayscale
width = 16
enc_blks = [2, 2, 4, 8]
middle_blk_num = 12
dec_blks = [2, 2, 2, 2]

# Initialize Generator (NAFNet)
generator = NAFNet(
    img_channel=img_channel,
    width=width,
    middle_blk_num=middle_blk_num,
    enc_blk_nums=enc_blks,
    dec_blk_nums=dec_blks
).to(device)

# Initialize Discriminator
discriminator = Discriminator(in_channels=1).to(device)

# Paths
# Rebinds checkpoint_path: from here on the best model is written to this file,
# replacing any value assigned to the same name in an earlier cell.
checkpoint_path = 'NAFNet_GAN_LVUP_Dataset_7_Conf-het_Best_Loss_3.pth'
os.makedirs("training_visuals", exist_ok=True)

# Hyperparameters
num_epochs = 30
# NOTE(review): T_max is one fifth of num_epochs while the schedulers are
# stepped once per epoch, so the cosine schedule reaches its minimum early and
# then appears to climb again -- confirm this cycling is intended.
T_max = max(1, int(num_epochs / 5)) 

# Optimizers & Schedulers
# Generator and discriminator use identical Adam settings.
g_optimizer = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))

g_scheduler = CosineAnnealingLR(g_optimizer, T_max=T_max)
d_scheduler = CosineAnnealingLR(d_optimizer, T_max=T_max)

# --- Loss Configuration ---
# IMPORTANT: Using 'enhanced_deblur' for GAN training
# This string is passed to MasterLoss as loss_type when the criterion is built.
LOSS_TYPE = 'enhanced_deblur' 

class LossWeights:
    """Plain container of loss-term weights handed to ``MasterLoss``.

    How each attribute is used is defined by ``MasterLoss`` (not shown in
    this notebook); a value of 0.0 presumably disables the term -- confirm
    against the loss implementation.
    """

    lambda_lpips = 3.0
    lambda_vgg = 0.0
    lambda_charb = 0.5
    lambda_ssim = 0.0
    lambda_lap = 2.0
    lambda_edge = 0.0
    lambda_fft_cc = 0.5
    lambda_fft = 0.0
    lambda_gan = 0.5
    r1_gamma = 0.0

# Loss object used for both the generator and the discriminator objectives.
criterion = MasterLoss(loss_type=LOSS_TYPE, weights=LossWeights(), device=device)

# Scalers for Mixed Precision
# One scaler per optimizer so each tracks its own loss scale.
scaler_g = GradScaler()
scaler_d = GradScaler()

# Tracking
best_val_loss = float('inf')
train_g_losses = []
train_d_losses = []
val_g_losses = []

# --- PRE-TRAINING SETUP: Fixed Validation Batch ---
# Grab one batch to use for consistent visualization throughout training
print("Creating fixed validation batch for consistency...")
try:
    fixed_val_input, fixed_val_target = next(iter(val_loader))
except StopIteration:
    # Raise instead of calling exit(): in a notebook exit() shuts the kernel
    # down, whereas an exception stops the cell and leaves the session usable.
    raise RuntimeError("Validation loader is empty!") from None
fixed_val_input = fixed_val_input.to(device)
fixed_val_target = fixed_val_target.to(device)
print("Fixed batch created successfully.")

# --- TRAINING LOOP ---
# Per epoch: one pass over train_loader alternating discriminator and generator
# updates, a full validation pass scored with the generator loss, a plot of the
# fixed validation sample, checkpointing, and one scheduler step.
for epoch in range(num_epochs):
    generator.train()
    discriminator.train()

    current_g_loss = 0.0
    current_d_loss = 0.0
    valid_batches = 0

    # Train Step
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
    for noisy_img, clean_img in pbar:
        noisy_img, clean_img = noisy_img.to(device), clean_img.to(device)

        # NaN Check: skip batches whose input is already non-finite.
        if not torch.isfinite(noisy_img).all():
            continue

        # ---------------------
        #  Train Discriminator
        # ---------------------
        d_optimizer.zero_grad()
        with autocast(device_type=device.type):
            fake_img = generator(noisy_img)
            fake_img = torch.clamp(fake_img, 0, 1) # Enforce [0,1] range

            # Real vs Fake inputs
            d_real = discriminator(clean_img)
            d_fake = discriminator(fake_img.detach()) # Detach to stop gradient to Generator

            d_loss = criterion.forward_discriminator(d_real, d_fake)

        if not torch.isfinite(d_loss):
            continue

        scaler_d.scale(d_loss).backward()
        # Gradients are still multiplied by the loss scale at this point.
        # Unscale them first so max_norm=1.0 is applied in true gradient units.
        scaler_d.unscale_(d_optimizer)
        torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
        scaler_d.step(d_optimizer)
        scaler_d.update()

        # -----------------
        #  Train Generator
        # -----------------
        g_optimizer.zero_grad()
        with autocast(device_type=device.type):
            # Re-compute D output for Generator update (gradients flow this time)
            d_fake_for_g = discriminator(fake_img)

            # Dictionary inputs for MasterLoss
            g_loss_inputs = {
                'pred_img': fake_img,
                'target_img': clean_img,
                'd_fake_logits': d_fake_for_g
            }
            g_loss = criterion.forward_generator(g_loss_inputs)

        if not torch.isfinite(g_loss):
            continue

        scaler_g.scale(g_loss).backward()
        # Same as above: clip only after the gradients have been unscaled.
        scaler_g.unscale_(g_optimizer)
        torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
        scaler_g.step(g_optimizer)
        scaler_g.update()

        current_g_loss += g_loss.item()
        current_d_loss += d_loss.item()
        valid_batches += 1

        # Update progress bar
        pbar.set_postfix({'G_Loss': g_loss.item(), 'D_Loss': d_loss.item()})

    # Epoch Averages (inf marks an epoch in which every batch was skipped)
    if valid_batches > 0:
        current_g_loss /= valid_batches
        current_d_loss /= valid_batches
    else:
        current_g_loss = float('inf')
        current_d_loss = float('inf')

    # --- VALIDATION LOOP ---
    generator.eval()
    discriminator.eval()
    current_val_g_loss = 0.0
    valid_val_batches = 0  # counts samples, not batches: the loss is sample-weighted

    with torch.no_grad(), autocast(device_type=device.type):
        for val_input, val_target in tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Valid]'):
            val_input = val_input.to(device)
            val_target = val_target.to(device)

            val_output = generator(val_input)
            val_output = torch.clamp(val_output, 0, 1)

            # We only care about G loss for validation metric
            d_fake_val = discriminator(val_output)

            val_inputs = {
                'pred_img': val_output,
                'target_img': val_target,
                'd_fake_logits': d_fake_val
            }

            val_loss = criterion.forward_generator(val_inputs)

            if torch.isfinite(val_loss):
                current_val_g_loss += val_loss.item() * val_input.size(0)
                valid_val_batches += val_input.size(0)

    if valid_val_batches > 0:
        current_val_g_loss /= valid_val_batches
    else:
        current_val_g_loss = float('inf')

    # --- VISUALIZATION (Using Fixed Batch) ---
    # Runs every epoch; raise the modulus to plot less often.
    if (epoch + 1) % 1 == 0:
        generator.eval()
        with torch.no_grad():
            # Use the SAME fixed batch we grabbed at start
            fixed_pred = generator(fixed_val_input)
            fixed_pred = torch.clamp(fixed_pred, 0, 1)

            # Plot the first image from the fixed batch
            plot_images(fixed_val_input[0], fixed_pred[0], fixed_val_target[0], epoch+1)

    # --- LOGGING & SAVING ---
    train_g_losses.append(current_g_loss)
    train_d_losses.append(current_d_loss)
    val_g_losses.append(current_val_g_loss)

    print(f'Epoch {epoch + 1:02d} | Train G: {current_g_loss:.4f} | Train D: {current_d_loss:.4f} | Val G: {current_val_g_loss:.4f}')

    # The best and the latest checkpoint hold the same payload for this epoch.
    checkpoint_state = {
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'g_optimizer_state_dict': g_optimizer.state_dict(),
        'd_optimizer_state_dict': d_optimizer.state_dict(),
        'val_loss': current_val_g_loss,
    }

    # Save Best Model
    if current_val_g_loss < best_val_loss and valid_val_batches > 0:
        best_val_loss = current_val_g_loss
        print(f"--> New Best Model! Loss: {best_val_loss:.4f}")
        torch.save(checkpoint_state, checkpoint_path)

    # Save Latest Checkpoint
    torch.save(checkpoint_state, 'NAFNet_GAN_LVUP_Dataset_7_Conf-het_Latest_Loss_3.pth')

    g_scheduler.step()
    d_scheduler.step()

# --- FINISH & SAVE CSV ---
# Number the epochs from the recorded history so the column lengths always
# agree, even if fewer than num_epochs epochs were completed.
loss_df = pd.DataFrame({
    'Epoch': range(1, len(train_g_losses) + 1),
    'Train Generator Loss': train_g_losses,
    'Train Discriminator Loss': train_d_losses,
    'Val Generator Loss': val_g_losses
})
loss_df.to_csv('NAFNet_GAN_LVUP_Dataset_7_Conf-het_Loss_3.csv', index=False)
print("Training Complete.")

# Final visualization: reload the best generator weights and denoise the first
# validation batch.
checkpoint = torch.load(checkpoint_path, map_location=device)
generator.load_state_dict(checkpoint['generator_state_dict'])
generator.eval()

val_input, val_target = next(iter(val_loader))
val_input = val_input.to(device)
val_target = val_target.to(device)
with torch.no_grad():
    val_output = generator(val_input)
    # Clamp exactly as during training and validation, so the displayed
    # prediction is the same quantity the losses were computed on.
    val_output = torch.clamp(val_output, 0, 1)

input_img = val_input[0].cpu()
pred_img = val_output[0].cpu()
target_img = val_target[0].cpu()

plt.figure(figsize=(12, 4))
for panel_idx, (panel_title, panel_img) in enumerate((
    ("Input Image (Noisy)", input_img),
    ("Predicted Image (Denoised)", pred_img),
    ("Ground Truth (Clean)", target_img),
), start=1):
    plt.subplot(1, 3, panel_idx)
    plt.title(panel_title)
    plt.imshow(panel_img.squeeze(), cmap='gray')
    plt.axis('off')
# Save first, then show.
plt.savefig(os.path.join("training_visuals", "final_result.png"))
plt.show()

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from skimage.io import imread
from skimage.filters import gaussian
from skimage.util import random_noise
import random
import os
import matplotlib.pyplot as plt

class DenoisingDataset2D(Dataset):
    """Single-image dataset that synthesises its own degraded input.

    Each item is read from ``gt_paths`` and min-max normalised. A copy is then
    optionally blurred or noised, optionally "blinded" at random pixels, and
    finally cropped/flipped (training with a crop size) or reflect-padded to a
    multiple of ``multiple`` (every other case).

    ``__getitem__`` returns ``(noisy, gt, mask)`` as float tensors with a
    leading channel axis. ``mask`` is 1 where a pixel was blinded, 0 elsewhere.
    """

    def __init__(self, gt_paths, crop_size=None, augment=True, mode="train", multiple=16, blind_ratio=0.0):
        self.gt_paths = gt_paths
        self.crop_size = crop_size
        self.augment = augment
        self.mode = mode
        self.multiple = multiple
        self.blind_ratio = blind_ratio  # Fraction of pixels to blind (e.g., 0.01)

        if self.mode != "val":
            return

        # Validation gets a per-index degradation schedule drawn once here:
        # 40% blur, 40% noise, the remainder untouched, in shuffled order.
        total = len(gt_paths)
        blur_count = int(0.4 * total)
        noise_count = int(0.4 * total)
        clean_count = total - blur_count - noise_count
        schedule = ["blur"] * blur_count + ["noise"] * noise_count + ["none"] * clean_count
        random.shuffle(schedule)
        self.val_aug_types = schedule

    def pad_to_multiple(self, img, multiple=16):
        """Mirrors image edges to reach the next multiple of 'multiple'."""
        rows, cols = img.shape
        extra_rows = -rows % multiple
        extra_cols = -cols % multiple
        if not (extra_rows or extra_cols):
            return img

        # Split the padding as evenly as possible; the odd pixel goes bottom/right.
        top, left = extra_rows // 2, extra_cols // 2
        return np.pad(
            img,
            ((top, extra_rows - top), (left, extra_cols - left)),
            mode='reflect',
        )

    def _choose_degradation(self, idx):
        # Training draws at random per call; validation follows the fixed schedule.
        if self.mode == "train":
            return random.choice(["blur", "noise", "none"])
        if self.mode == "val":
            return self.val_aug_types[idx]
        return "none"

    def _degrade(self, img, kind):
        # Random strength while training, fixed strength otherwise.
        training = self.mode == "train"
        if kind == "blur":
            sigma = random.uniform(2, 5) if training else 3
            return gaussian(img, sigma=sigma)
        if kind == "noise":
            var = random.uniform(0.01, 0.1) if training else 0.05
            return random_noise(img, mode="gaussian", var=var)
        return img

    def _blind(self, noisy):
        # Overwrite a random subset of pixels with noise drawn from the image's
        # own mean/std and return the mask marking those pixels.
        mask = np.zeros_like(noisy, dtype=np.float32)
        if self.blind_ratio > 0:
            chosen = np.random.random(noisy.shape) < self.blind_ratio
            filler = np.random.normal(loc=noisy.mean(), scale=noisy.std(), size=noisy.shape)
            noisy[chosen] = filler[chosen]
            mask[chosen] = 1.0
        return noisy, mask

    def _crop_and_flip(self, noisy, gt, mask):
        # Pad up to crop_size if needed, take one random crop shared by all
        # three arrays, then apply the optional random flips.
        size = self.crop_size
        h, w = noisy.shape
        grow_h = max(0, size - h)
        grow_w = max(0, size - w)
        if grow_h or grow_w:
            pads = ((0, grow_h), (0, grow_w))
            noisy = np.pad(noisy, pads, mode='reflect')
            gt = np.pad(gt, pads, mode='reflect')
            mask = np.pad(mask, pads, mode='constant', constant_values=0)  # padded area is never "blinded"
            h, w = noisy.shape

        top = random.randint(0, h - size)
        left = random.randint(0, w - size)
        window = (slice(top, top + size), slice(left, left + size))
        noisy, gt, mask = noisy[window], gt[window], mask[window]

        if self.augment:
            # Horizontal flip is decided first, vertical flip second.
            for flip in (np.fliplr, np.flipud):
                if random.random() < 0.5:
                    noisy, gt, mask = flip(noisy), flip(gt), flip(mask)
        return noisy, gt, mask

    @staticmethod
    def _as_tensor(arr):
        # copy() removes the negative strides that flipping can leave behind.
        return torch.from_numpy(arr.copy()).unsqueeze(0).float()

    def __getitem__(self, idx):
        # 1. Load the reference image and rescale it to [0, 1].
        gt = imread(self.gt_paths[idx]).astype(np.float32)
        lowest, highest = gt.min(), gt.max()
        gt = (gt - lowest) / (highest - lowest + 1e-8)

        # 2. Build the degraded input from a copy of the reference.
        noisy = self._degrade(gt.copy(), self._choose_degradation(idx))

        # 3. Blind random pixels (mask: 1 = blinded, 0 = kept).
        noisy, mask = self._blind(noisy)

        # 4. Fix the spatial size: crop for training, pad otherwise.
        if self.mode == "train" and self.crop_size:
            noisy, gt, mask = self._crop_and_flip(noisy, gt, mask)
        else:
            noisy = self.pad_to_multiple(noisy, self.multiple)
            gt = self.pad_to_multiple(gt, self.multiple)
            mask = self.pad_to_multiple(mask, self.multiple)

        # 5. Convert to tensors with a channel axis.
        return self._as_tensor(noisy), self._as_tensor(gt), self._as_tensor(mask)

    def __len__(self):
        return len(self.gt_paths)

# --- Loader Function Update ---
def load_data(gt_dir, batch_size, num_workers_tr, num_workers_val, crop_size=None):
    """Create the self-supervised training and validation loaders.

    ``gt_dir`` must contain ``Patch_Train`` and ``Patch_Val`` sub-folders with
    ``.tif`` files. Training blinds 5% of the pixels; validation blinds none
    and is served one image per batch.

    Returns:
        ``(train_loader, val_loader)``.
    """
    def _tif_paths(subfolder):
        # Sorted full paths of the .tif files inside gt_dir/subfolder.
        folder = os.path.join(gt_dir, subfolder)
        return sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.endswith('.tif')
        )

    train_gt = _tif_paths("Patch_Train")
    val_gt = _tif_paths("Patch_Val")

    # blind_ratio is the fraction of pixels replaced by noise in the input.
    train_ds = DenoisingDataset2D(
        train_gt, crop_size=crop_size, augment=True, mode="train", blind_ratio=0.05
    )
    val_ds = DenoisingDataset2D(
        val_gt, crop_size=crop_size, augment=False, mode="val", blind_ratio=0.0
    )

    shared = dict(pin_memory=True)
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers_tr, **shared
    )
    val_loader = DataLoader(
        val_ds, batch_size=1, shuffle=False, num_workers=num_workers_val, **shared
    )

    print(f"{len(train_gt)} training images and {len(val_gt)} validation images.")
    return train_loader, val_loader

# Root folder holding the Patch_Train / Patch_Val sub-folders read by load_data.
gt_dir = r"C:\Users\Guill\Downloads\EMPIAR_10197\EMPIAR_10197\10197\data\leginondata\rawdata\GT"
pretrained_path = "DATASET_7_NAFNet_GAN_best_model_LOSS_2_Blind.pth"  # If it does not exist, then it will not load the weights on train.
checkpoint_path = "DATASET_7_NAFNet_GAN_best_model_LOSS_2_Blind.pth"
crop_size = 512
num_workers_tr = 0 # Number of workers for training
num_workers_val = 0 # Number of workers for validation

# Step 1: Load data
# The loaders are only created when the folder exists. When it does not, say so
# explicitly instead of failing later with a NameError on train_loader.
if os.path.exists(gt_dir):
    train_loader, val_loader = load_data(
        gt_dir=gt_dir,
        batch_size=2,
        num_workers_tr=num_workers_tr,
        num_workers_val=num_workers_val,
        crop_size=crop_size
    )
else:
    print(f"Data directory not found; train_loader/val_loader were NOT created. "
          f"Check gt_dir={gt_dir!r}.")

In [None]:
# Step 2: Visualize a sample from training and validation data
import time
import matplotlib.pyplot as plt

def visualize_dataloader(loader, title="Sample"):
    """Fetch the first batch of ``loader``, report timing/shapes and plot sample 0.

    The loader must yield ``(noisy, gt, mask)`` batches. Three grayscale
    panels are drawn, each with its own colorbar: the blinded input, the
    ground truth, and the blind mask showing where the blind spots are.
    """
    t_start = time.perf_counter()
    noisy_batch, gt_batch, mask_batch = next(iter(loader))
    elapsed = time.perf_counter() - t_start
    print(f"Time to load first batch: {elapsed:.4f} seconds")
    print(f"Noisy batch shape: {noisy_batch.shape}, GT batch shape: {gt_batch.shape}")

    # (caption, 2-D array) for each panel, left to right.
    panels = (
        (f"Noisy Input (Blinded) {title}", noisy_batch[0].squeeze().cpu().numpy()),
        (f"Ground Truth {title}", gt_batch[0].squeeze().cpu().numpy()),
        (f"Blind Mask {title}", mask_batch[0].squeeze().cpu().numpy()),
    )

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (caption, image) in zip(axs, panels):
        handle = ax.imshow(image, cmap='gray')
        ax.set_title(caption)
        ax.axis("off")
        fig.colorbar(handle, ax=ax, fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()

# Show one sample from each loader, training first.
for _split_name, _split_loader, _split_title in (
    ("training", train_loader, "Training Sample (Cropped)"),
    ("validation", val_loader, "Validation Sample (Full)"),
):
    print(f"Visualizing {_split_name} data sample:")
    visualize_dataloader(_split_loader, title=_split_title)

In [None]:
import os
os.environ['TORCH_COMPILE_DISABLE'] = '1'

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import GradScaler 
from torch.amp import autocast 
import warnings
import time

# --- Setup & Helper Functions ---
warnings.filterwarnings('ignore')

def plot_images(noisy, pred, target, epoch, save_dir="training_visuals"):
    """Saves and displays the input, predicted, and ground truth images with colorbars.

    Args:
        noisy, pred, target: image tensors. Each one is squeezed before
            plotting, so it must reduce to a 2-D array.
        epoch: shown in every panel title and embedded in the file name.
        save_dir: directory receiving ``epoch_{epoch}_viz.png``; created if
            it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    titles = ["Noisy Input", "Denoised Output", "Ground Truth"]

    # Move to CPU and detach for plotting
    images = [t.detach().squeeze().cpu().numpy() for t in (noisy, pred, target)]

    for ax, title, img in zip(axs, titles, images):
        # Fixed [0, 1] display range so panels stay comparable across epochs.
        im = ax.imshow(img, cmap='gray', vmin=0, vmax=1)
        ax.set_title(f"{title} (Epoch {epoch})")
        ax.axis("off")
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    fig.tight_layout()
    # Write the file before plt.show(): saving afterwards can produce a blank image.
    fig.savefig(os.path.join(save_dir, f"epoch_{epoch}_viz.png"))
    plt.show()
    plt.close(fig)  # one figure is created per call; release it

# Assuming NAFNet, MasterLoss, LossWeights, train_loader, val_loader are defined/imported
# If running as standalone script, ensure imports for these are present.

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Model Setup ---
# Constructor arguments for the NAFNet generator built below.
img_channel = 1  # Grayscale
width = 16
enc_blks = [2, 2, 4, 8]
middle_blk_num = 12
dec_blks = [2, 2, 2, 2]

# Initialize Generator (NAFNet)
generator = NAFNet(
    img_channel=img_channel,
    width=width,
    middle_blk_num=middle_blk_num,
    enc_blk_nums=enc_blks,
    dec_blk_nums=dec_blks
).to(device)

# Initialize Discriminator
discriminator = Discriminator(in_channels=1).to(device)

# Paths
# Rebinds checkpoint_path: from here on the best model is written to this file,
# replacing any value assigned to the same name in an earlier cell.
checkpoint_path = 'NAFNet_GAN_LVUP_Dataset_7_Best_Loss_3_Blind.pth'
os.makedirs("training_visuals", exist_ok=True)

# Hyperparameters
num_epochs = 100
# NOTE(review): T_max is one fifth of num_epochs while the schedulers are
# presumably stepped once per epoch, so the cosine schedule would reach its
# minimum early and then climb again -- confirm this cycling is intended.
T_max = max(1, int(num_epochs / 5)) 

# Optimizers & Schedulers
# Generator and discriminator use identical Adam settings.
g_optimizer = optim.Adam(generator.parameters(), lr=1e-3, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=1e-3, betas=(0.5, 0.999))

g_scheduler = CosineAnnealingLR(g_optimizer, T_max=T_max)
d_scheduler = CosineAnnealingLR(d_optimizer, T_max=T_max)

# --- Loss Configuration ---
# IMPORTANT: Using 'enhanced_deblur' for GAN training
# This string is passed to MasterLoss as loss_type when the criterion is built.
LOSS_TYPE = 'enhanced_deblur' 

class LossWeights:
    """Bundle of loss coefficients passed to MasterLoss as ``weights``.

    All values are plain class attributes. How each coefficient is applied is
    decided inside MasterLoss, which is not visible here; the per-line notes
    below are read off the attribute names and should be confirmed there.
    """

    # Coefficients that are currently non-zero.
    lambda_lpips = 3.0    # presumably the LPIPS term -- confirm in MasterLoss
    lambda_lap = 2.0      # presumably a Laplacian term -- confirm in MasterLoss
    lambda_charb = 0.5    # presumably the Charbonnier term -- confirm in MasterLoss
    lambda_fft_cc = 0.5   # presumably an FFT-based term -- confirm in MasterLoss
    lambda_gan = 0.5      # presumably the adversarial term -- confirm in MasterLoss

    # Coefficients that are currently zero.
    lambda_vgg = 0.0
    lambda_ssim = 0.0
    lambda_edge = 0.0
    lambda_fft = 0.0
    r1_gamma = 0.0        # presumably an R1 penalty strength -- confirm in MasterLoss
# Loss object: LOSS_TYPE selects the configuration, LossWeights supplies the coefficients.
criterion = MasterLoss(loss_type=LOSS_TYPE, weights=LossWeights(), device=device)

# Scalers for Mixed Precision (one per optimizer, so each tracks its own scale factor)
scaler_g = GradScaler()
scaler_d = GradScaler()

# Tracking
best_val_loss = float('inf')
train_g_losses = []
train_d_losses = []
val_g_losses = []

# --- PRE-TRAINING SETUP: Fixed Validation Batch ---
# Grab one batch up front so the per-epoch visualization always shows the same images.
print("Creating fixed validation batch for consistency...")
try:
    # The loader yields (input, target, mask); only this call can raise StopIteration.
    fixed_val_input, fixed_val_target, fixed_val_mask = next(iter(val_loader))
except StopIteration:
    print("Error: Validation loader is empty!")
    # Abort with a non-zero status. The site-provided exit() is meant for the
    # interactive shell and, called without an argument, reports success.
    raise SystemExit(1) from None

fixed_val_input = fixed_val_input.to(device)
fixed_val_target = fixed_val_target.to(device)
# fixed_val_mask is not used below; it is unpacked only to match the loader's 3-tuple.
print("Fixed batch created successfully.")

# --- TRAINING LOOP ---
# Per epoch: alternating discriminator/generator updates over train_loader,
# a validation pass, a visualization every second epoch, checkpointing, and
# one scheduler step for each optimizer.
for epoch in range(num_epochs):
    generator.train()
    discriminator.train()

    current_g_loss = 0.0
    current_d_loss = 0.0
    valid_batches = 0  # batches that completed both the D and the G update

    # Train Step
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')

    # Each batch is a 3-tuple: (input, target, mask).
    for noisy_img, clean_img, mask in pbar:
        noisy_img = noisy_img.to(device)
        clean_img = clean_img.to(device)
        mask = mask.to(device)

        # Skip batches whose input already contains NaN/Inf.
        if not torch.isfinite(noisy_img).all():
            continue

        # ---------------------
        #  Train Discriminator
        # ---------------------
        d_optimizer.zero_grad()
        with autocast(device_type=device.type):
            fake_img = generator(noisy_img)
            fake_img = torch.clamp(fake_img, 0, 1)  # Enforce [0,1] range

            # Real vs Fake inputs; detach so the D loss does not backpropagate into G.
            d_real = discriminator(clean_img)
            d_fake = discriminator(fake_img.detach())

            d_loss = criterion.forward_discriminator(d_real, d_fake)

        if not torch.isfinite(d_loss):
            continue

        scaler_d.scale(d_loss).backward()
        # Gradients are still multiplied by the scaler's factor here; unscale
        # them first so max_norm is compared against true gradient magnitudes.
        scaler_d.unscale_(d_optimizer)
        torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
        scaler_d.step(d_optimizer)
        scaler_d.update()

        # -----------------
        #  Train Generator
        # -----------------
        g_optimizer.zero_grad()
        with autocast(device_type=device.type):
            # Re-run D on the non-detached fake so gradients reach the generator.
            d_fake_for_g = discriminator(fake_img)

            # The full prediction/target pair is passed to the loss. The mask is
            # forwarded as well; whether MasterLoss actually uses it cannot be
            # seen from this script -- confirm in MasterLoss before relying on
            # masked (blind-spot) reconstruction terms.
            g_loss_inputs = {
                'pred_img': fake_img,
                'target_img': clean_img,
                'd_fake_logits': d_fake_for_g,
                'mask': mask,
            }

            g_loss = criterion.forward_generator(g_loss_inputs)

        if not torch.isfinite(g_loss):
            continue

        scaler_g.scale(g_loss).backward()
        # Same reasoning as for the discriminator: unscale before clipping.
        scaler_g.unscale_(g_optimizer)
        torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
        scaler_g.step(g_optimizer)
        scaler_g.update()

        # Read each loss back once and reuse it for the totals and the progress bar.
        g_loss_value = g_loss.item()
        d_loss_value = d_loss.item()
        current_g_loss += g_loss_value
        current_d_loss += d_loss_value
        valid_batches += 1

        # Update progress bar
        pbar.set_postfix({'G_Loss': g_loss_value, 'D_Loss': d_loss_value})

    # Epoch Averages (inf marks an epoch in which no batch completed)
    if valid_batches > 0:
        current_g_loss /= valid_batches
        current_d_loss /= valid_batches
    else:
        current_g_loss = float('inf')
        current_d_loss = float('inf')

    # --- VALIDATION LOOP ---
    generator.eval()
    discriminator.eval()
    current_val_g_loss = 0.0
    valid_val_batches = 0  # counts samples, not batches: the mean below is per-sample

    with torch.no_grad():
        with autocast(device_type=device.type):
            for val_input, val_target, val_mask in tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Valid]'):
                val_input = val_input.to(device)
                val_target = val_target.to(device)
                # Keep the mask on the same device as the images, as in training.
                val_mask = val_mask.to(device)

                val_output = generator(val_input)
                val_output = torch.clamp(val_output, 0, 1)

                d_fake_val = discriminator(val_output)

                val_inputs = {
                    'pred_img': val_output,
                    'target_img': val_target,
                    'd_fake_logits': d_fake_val,
                    'mask': val_mask,
                }

                val_loss = criterion.forward_generator(val_inputs)

                if torch.isfinite(val_loss):
                    # Weight each batch loss by its size.
                    batch_len = val_input.size(0)
                    current_val_g_loss += val_loss.item() * batch_len
                    valid_val_batches += batch_len

    if valid_val_batches > 0:
        current_val_g_loss /= valid_val_batches
    else:
        current_val_g_loss = float('inf')

    # --- VISUALIZATION (Using Fixed Batch) ---
    if (epoch + 1) % 2 == 0:
        generator.eval()
        with torch.no_grad():
            # Use the SAME fixed batch that was grabbed before training started.
            fixed_pred = generator(fixed_val_input)
            fixed_pred = torch.clamp(fixed_pred, 0, 1)

            # Plot the first image from the fixed batch
            plot_images(fixed_val_input[0], fixed_pred[0], fixed_val_target[0], epoch+1)

    # --- LOGGING & SAVING ---
    train_g_losses.append(current_g_loss)
    train_d_losses.append(current_d_loss)
    val_g_losses.append(current_val_g_loss)

    print(f'Epoch {epoch + 1:02d} | Train G: {current_g_loss:.4f} | Train D: {current_d_loss:.4f} | Val G: {current_val_g_loss:.4f}')

    # Both checkpoints store the same state; when a new best is found,
    # best_val_loss equals current_val_g_loss, so 'val_loss' is the same too.
    epoch_state = {
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'g_optimizer_state_dict': g_optimizer.state_dict(),
        'd_optimizer_state_dict': d_optimizer.state_dict(),
        'val_loss': current_val_g_loss,
    }

    # Save Best Model (only when the validation pass produced a usable loss)
    if valid_val_batches > 0 and current_val_g_loss < best_val_loss:
        best_val_loss = current_val_g_loss
        print(f"--> New Best Model! Loss: {best_val_loss:.4f}")
        torch.save(epoch_state, checkpoint_path)

    # Save Latest Checkpoint
    torch.save(epoch_state, 'NAFNet_GAN_LVUP_Dataset_7_Latest_Loss_3_Blind.pth')

    g_scheduler.step()
    d_scheduler.step()

# --- FINISH & SAVE CSV ---
# The epoch column is sized from the recorded history, so a run that stopped
# early still exports cleanly (the three loss lists are appended together).
loss_df = pd.DataFrame({
    'Epoch': range(1, len(train_g_losses) + 1),
    'Train Generator Loss': train_g_losses,
    'Train Discriminator Loss': train_d_losses,
    'Val Generator Loss': val_g_losses
})
loss_df.to_csv('NAFNet_GAN_LVUP_Dataset_7_Loss_3_Blind.csv', index=False)
print("Training Complete.")

# Final visualization: reload the best generator weights saved during training.
checkpoint = torch.load(checkpoint_path, map_location=device)
generator.load_state_dict(checkpoint['generator_state_dict'])
generator.eval()

# The loader yields (input, target, mask); the mask is not needed here.
val_input, val_target, val_mask = next(iter(val_loader))
val_input = val_input.to(device)
val_target = val_target.to(device)
with torch.no_grad():
    # Clamp to [0, 1] like the training, validation and per-epoch
    # visualization paths, so the figure shows the output as it was evaluated.
    val_output = torch.clamp(generator(val_input), 0, 1)

# First sample of the batch, moved to the CPU for plotting.
input_img = val_input[0].cpu()
pred_img = val_output[0].cpu()
target_img = val_target[0].cpu()

plt.figure(figsize=(12, 4))
final_panels = (
    ("Input Image (Noisy/Blinded)", input_img),
    ("Predicted Image (Denoised)", pred_img),
    ("Ground Truth (Target)", target_img),
)
for panel_idx, (panel_title, panel_img) in enumerate(final_panels, start=1):
    plt.subplot(1, 3, panel_idx)
    plt.title(panel_title)
    plt.imshow(panel_img.squeeze(), cmap='gray')
    plt.axis('off')
# Save before show(): show() may leave the current figure empty afterwards.
plt.savefig(os.path.join("training_visuals", "final_result.png"))
plt.show()