# In [None]:
import os
import torch
import numpy as np
import pandas as pd
from skimage import io
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import random

# --- IMPORTS FROM YOUR FILES ---
# Ensure models.py is accessible
from models import NAFNet

# --- YOUR DATASET CLASS (Copied here for standalone functionality) ---
class DenoisingDataset2D(Dataset):
    """Paired noisy / ground-truth 2D image dataset read from disk.

    Every item is read with ``io.imread``, cast to float32 and min-max
    normalized per image (each image independently, to roughly [0, 1]).
    When ``augment`` is enabled, the pair receives the same random square
    crop (only if ``crop_size`` is set) followed by the same random
    left-right / up-down flips.

    Args:
        noisy_paths: Sequence of paths to the noisy input images.
        gt_paths: Sequence of paths to the ground-truth images, aligned
            index-by-index with ``noisy_paths``.
        crop_size: Side length of the random square crop used during
            augmentation, or ``None`` to keep the full frame.
        augment: Whether ``__getitem__`` applies the random crop / flips.
        p: Probability of applying each of the two flips.
    """

    def __init__(self, noisy_paths, gt_paths, crop_size=None, augment=True, p=0.5):
        assert len(noisy_paths) == len(gt_paths), "Noisy and GT paths must have the same length"
        self.noisy_paths = noisy_paths
        self.gt_paths = gt_paths
        self.crop_size = crop_size
        self.augment = augment
        self.p = p

    @staticmethod
    def _normalize(img):
        """Min-max scale ``img``; the epsilon guards against constant images."""
        return (img - img.min()) / (img.max() - img.min() + 1e-8)

    def _augment_pair(self, noisy, gt, idx):
        """Apply an identical random crop and identical random flips to both arrays.

        Raises:
            ValueError: If ``crop_size`` is set and the image is smaller than it.
        """
        # Cropping is optional: with crop_size=None the full frame is kept and
        # only the flips are applied (previously this combination failed with
        # a TypeError when comparing an int against None).
        if self.crop_size is not None:
            h, w = noisy.shape
            crop_h = crop_w = self.crop_size
            if h < crop_h or w < crop_w:
                raise ValueError(f"Image too small for augmentation crop: ({h}, {w}) at index {idx}")

            max_x = h - crop_h
            max_y = w - crop_w
            # randint is only called when there is slack, so no random number
            # is consumed for an exact-fit axis.
            x = random.randint(0, max_x) if max_x > 0 else 0
            y = random.randint(0, max_y) if max_y > 0 else 0
            noisy = noisy[x:x + crop_h, y:y + crop_w]
            gt = gt[x:x + crop_h, y:y + crop_w]

        # Flips return views; they are made contiguous once in __getitem__.
        if random.random() < self.p:
            noisy = np.fliplr(noisy)
            gt = np.fliplr(gt)
        if random.random() < self.p:
            noisy = np.flipud(noisy)
            gt = np.flipud(gt)
        return noisy, gt

    def __getitem__(self, idx):
        """Return the ``(noisy, gt)`` tensor pair for ``idx``, each with a leading channel axis."""
        noisy = io.imread(self.noisy_paths[idx]).astype(np.float32)
        gt = io.imread(self.gt_paths[idx]).astype(np.float32)

        noisy = self._normalize(noisy)
        gt = self._normalize(gt)

        if self.augment:
            noisy, gt = self._augment_pair(noisy, gt, idx)

        # torch.from_numpy cannot wrap negatively-strided (flipped) views, so
        # force a contiguous buffer; this copies only when it is needed.
        noisy = torch.from_numpy(np.ascontiguousarray(noisy)).unsqueeze(0)
        gt = torch.from_numpy(np.ascontiguousarray(gt)).unsqueeze(0)
        return noisy, gt

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.noisy_paths)

# --- HELPER: PADDING FOR NAFNET ---
def pad_to_multiple(x, multiple=32):
    """Pad the last two dims of ``x`` so each is divisible by ``multiple``.

    Padding is added on the bottom and right only, so the original content
    stays at ``[..., :h, :w]`` of the result.

    Args:
        x: Image tensor whose last two dimensions are height and width
            (called with a ``(N, C, H, W)`` batch in this file).
        multiple: Positive integer the padded height and width must divide by.

    Returns:
        Tuple ``(x_padded, pad_h, pad_w)`` where ``pad_h`` / ``pad_w`` are the
        number of rows / columns appended.

    Raises:
        ValueError: If ``multiple`` is smaller than 1.
    """
    if multiple < 1:
        raise ValueError(f"multiple must be a positive integer, got {multiple}")

    h, w = x.shape[-2], x.shape[-1]
    # -n % m is the distance from n up to the next multiple of m (0 if aligned).
    pad_h = -h % multiple
    pad_w = -w % multiple

    # Reflection minimizes edge artifacts but requires each pad to be smaller
    # than the dimension it extends; for very small images fall back to edge
    # replication instead of raising.
    mode = 'reflect' if (pad_h < h and pad_w < w) else 'replicate'
    x_padded = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode=mode)
    return x_padded, pad_h, pad_w

# --- CONFIGURATION ---
# Folder the noisy test images are read from, and folder predictions go to.
INPUT_DIR = r"D:\Manuscipts_Coding\Denoising_paper\IgG-1D\Exported_Data_TIFF\test\RAW"
OUTPUT_DIR = r"D:\Manuscipts_Coding\Denoising_paper\Models\Evaluation\Preds_NafNet_GAN_LV3\Dataset_7"

# Checkpoint file; a bare file name, so it is resolved against the working directory.
MODEL_PATH = "NAFNet_GAN_LVUP_Dataset_7_Conf-het_Best_Loss_3.pth"

# NAFNet constructor arguments -- these have to mirror the training setup.
IMG_CHANNEL = 1
WIDTH = 16
MIDDLE_BLK_NUM = 12
ENC_BLKS = [2, 2, 4, 8]
DEC_BLKS = [2, 2, 2, 2]

# Run on the GPU when torch reports one, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def run_inference():
    """Denoise every TIFF image in ``INPUT_DIR`` and write the results to ``OUTPUT_DIR``.

    Reads the module-level configuration (``INPUT_DIR``, ``OUTPUT_DIR``,
    ``MODEL_PATH``, the NAFNet hyper-parameters and ``DEVICE``). Each image is
    loaded through ``DenoisingDataset2D`` without augmentation, padded to a
    multiple of 32, passed through the network, cropped back to its original
    size, clamped to [0, 1] and saved under its original file name.

    Returns:
        None. Returns early, without loading the model, if no images are found.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # 1. Prepare Data Lists
    # Match case-insensitively and accept both spellings so that files such as
    # "IMG.TIF" or "img.tiff" are not silently skipped.
    file_names = sorted(
        f for f in os.listdir(INPUT_DIR) if f.lower().endswith(('.tif', '.tiff'))
    )
    noisy_paths = [os.path.join(INPUT_DIR, f) for f in file_names]

    # TRICK: reuse the noisy paths as GT paths. The dataset requires a GT list
    # of equal length, but the loaded GT is discarded below.
    gt_paths = noisy_paths

    print(f"Found {len(noisy_paths)} images.")
    if not noisy_paths:
        # Nothing to process: skip the (potentially slow) checkpoint load.
        print(f"No .tif/.tiff files in: {INPUT_DIR}")
        return

    # 2. Initialize Dataset & DataLoader
    # CRITICAL: augment=False so the images are neither cropped nor flipped.
    test_ds = DenoisingDataset2D(noisy_paths, gt_paths, crop_size=None, augment=False)

    # batch_size=1 because full-size images may differ in shape; shuffle=False
    # keeps the loader order aligned with file_names for the zip() below.
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0)

    # 3. Load Model
    print("Loading Model...")
    model = NAFNet(
        img_channel=IMG_CHANNEL,
        width=WIDTH,
        middle_blk_num=MIDDLE_BLK_NUM,
        enc_blk_nums=ENC_BLKS,
        dec_blk_nums=DEC_BLKS
    ).to(DEVICE)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source; consider
    # passing weights_only=True if the installed torch version supports it.
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
    # GAN training checkpoints store the generator under its own key; a plain
    # state dict is loaded directly.
    if 'generator_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['generator_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    model.eval()

    # 4. Inference Loop
    print("Starting Inference...")

    # zip(test_loader, file_names) pairs each batch with its original filename.
    with torch.no_grad():
        for (noisy, _), filename in tqdm(zip(test_loader, file_names), total=len(file_names)):

            noisy = noisy.to(DEVICE)  # Shape: (1, 1, H, W)
            orig_h, orig_w = noisy.shape[-2], noisy.shape[-1]

            # --- Handle Dimensions (Padding) ---
            # The network is fed inputs whose H and W are multiples of 32.
            noisy_padded, _, _ = pad_to_multiple(noisy, multiple=32)

            # --- Forward Pass ---
            pred = model(noisy_padded)

            # --- Un-Pad (Crop back to original) ---
            # Padding was appended bottom/right, so the original region is the
            # top-left corner.
            pred = pred[..., :orig_h, :orig_w]

            # --- Post-Processing ---
            pred = torch.clamp(pred, 0, 1)
            # Drop only the batch and channel axes; a bare squeeze() would also
            # remove an H or W of size 1.
            pred_np = pred.squeeze(0).squeeze(0).cpu().numpy()  # Shape: (H, W)

            # --- Save ---
            save_path = os.path.join(OUTPUT_DIR, filename)
            io.imsave(save_path, pred_np, check_contrast=False)

    print(f"\nProcessing Complete. Saved to: {OUTPUT_DIR}")

# Run the batch inference only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    run_inference()