# N2V Denoising — Single 128×128 TEM Image

Uses the **original PPN2V library** with minimal config changes.

### What we change from defaults (and why)

| Parameter | Default | Ours | Reason |
|---|---|---|---|
| `num_classes` | 800 (PN2V) | **1** | N2V mode: single mean output, no noise model |
| `start_filts` | 64 | **32** | A single 128×128 image has only ~16K pixels; 64 base filters overfit |
| `depth` | 5 | **3** | 2 pooling ops, bottleneck 16×16 for 64-px patches |
| `merge_mode` | `'add'` | **`'concat'`** | Standard U-Net skip connections |
| `patchSize` | 100 | **64** | Must be < 128 for random crop diversity |
| `learningRate` | 1e-4 | **3e-4** | Faster convergence for small data |
| `virtualBatchSize` | 20 | **1** | No gradient accumulation needed |
| `noiseModel` | varies | **None** | Activates N2V mode automatically |

### What stays exactly as-is from the original code
- **Masking**: neighbor-replacement from 5×5 ROI (original N2V paper)
- **Mask ratio**: ~3% (`patchSize²/32` masked pixels per patch, i.e. 1 pixel in 32)
- **Weight init**: Xavier (UNet default)
- **Loss**: `lossFunctionN2V` — MSE on masked pixels only
- **Training functions**: `trainingPred()`, `lossFunction()` from `training.py`
- **Prediction**: `prediction.predict()` from `prediction.py`
- **Augmentation**: random flips + 90° rotations

---
## Cell 1 — Mount Drive & Set Paths

In [None]:
import os

from google.colab import drive

# Attach Google Drive so the dataset and result folders below are reachable.
drive.mount('/content/drive')

# === EDIT THESE PATHS IF NEEDED ===
# All result folders live under one common root on Drive.
_RESULTS_ROOT     = '/content/drive/MyDrive/PPN2V/results/DATASET_01'
DRIVE_DATA_PATH   = '/content/drive/MyDrive/PPN2V/DATASET_01'
N2V_RESULTS_PATH  = f'{_RESULTS_ROOT}/n2v_optimized'
PN2V_RESULTS_PATH = f'{_RESULTS_ROOT}/pn2v_bootstrap'
COMPARISON_PATH   = f'{_RESULTS_ROOT}/comparison'

# Create every output folder up front; folders that already exist are kept.
for results_dir in (N2V_RESULTS_PATH, PN2V_RESULTS_PATH, COMPARISON_PATH):
    os.makedirs(results_dir, exist_ok=True)

print('\n'.join([
    'Drive mounted',
    f'Data:    {DRIVE_DATA_PATH}',
    f'Results: {N2V_RESULTS_PATH}',
]))

---
## Cell 2 — Clone Repo & Install

In [None]:
import subprocess, sys, shutil

REPO_PATH    = '/content/PPN2V'
PROJECT_ROOT = '/content/PPN2V/PPN2V-main'
GITHUB_REPO  = 'https://github.com/ZurvanAkarna/PPN2V-main.git'

# 1. Always start fresh to avoid stale cache issues.
#    Leave the checkout before deleting it: this cell ends with
#    os.chdir(PROJECT_ROOT), so on a re-run the working directory would be
#    removed underneath us and the git subprocess below would start from a
#    directory that no longer exists.
os.chdir(os.path.dirname(REPO_PATH))
if os.path.exists(REPO_PATH):
    shutil.rmtree(REPO_PATH)
    print(f'Removed old {REPO_PATH}')

# 2. Clone (check=True raises CalledProcessError if git fails)
subprocess.run(['git', 'clone', GITHUB_REPO, REPO_PATH], check=True)
print(f'Cloned to {REPO_PATH}')

# 3. Verify structure — fail with a readable message if the expected
#    sub-folder is missing, instead of a bare listdir error.
if not os.path.isdir(PROJECT_ROOT):
    raise FileNotFoundError(
        f'Expected project folder {PROJECT_ROOT} was not found after cloning. '
        f'Contents of {REPO_PATH}: {sorted(os.listdir(REPO_PATH))}')

print(f'\nFiles in PROJECT_ROOT ({PROJECT_ROOT}):')
for f in sorted(os.listdir(PROJECT_ROOT)):
    print(f'  {f}')

# 4. Install — capture output so errors are visible
#    (pip is run from inside the project so that `-e .` resolves to it)
os.chdir(PROJECT_ROOT)

def pip_install(args, label=''):
    """Install packages with pip using the interpreter running this notebook.

    Args:
        args: list of arguments appended after ``pip install``.
        label: optional short name printed with a check mark on success;
            nothing is printed on success when it is empty.

    Raises:
        RuntimeError: if pip exits with a non-zero status. The captured
            stdout and stderr are printed before raising.
    """
    command = [sys.executable, '-m', 'pip', 'install'] + args
    outcome = subprocess.run(command, capture_output=True, text=True)

    # Success path: optionally acknowledge and leave.
    if outcome.returncode == 0:
        if label:
            print(f'  ✓ {label}')
        return

    # Failure path: show everything pip produced, then abort.
    joined_args = " ".join(args)
    print(f'\n❌ pip install {joined_args} FAILED:')
    print(outcome.stdout)
    print(outcome.stderr)
    raise RuntimeError(f'pip install failed: {joined_args}')

# Install steps, run strictly in this order: hatchling first (presumably the
# project's build backend — confirm in its pyproject), then the cloned project
# itself in editable mode, then the image I/O and metrics packages.
for pip_args, tag in (
        (['hatchling'], 'hatchling'),
        (['-e', '.'], 'ppn2v (editable)'),
        (['tifffile', 'scikit-image'], 'tifffile + scikit-image')):
    pip_install(pip_args, tag)

print('\n✅ Installation complete')

---
## Cell 3 — Imports & GPU

In [None]:
import numpy as np
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
import tifffile
import time

# Make the cloned sources importable in this running kernel. The path is
# derived from PROJECT_ROOT (Cell 2) so the two cells cannot drift apart, and
# it is only inserted once even if this cell is re-run.
src_dir = os.path.join(PROJECT_ROOT, 'src')
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)
from ppn2v.unet.model import UNet
from ppn2v.pn2v import training, prediction, utils

# Prefer the GPU when one is available; everything also runs on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')

---
## Cell 4 — Load Data

**Edit `NOISY_FILENAME` and `GT_FILENAME` below.**  
Both should be 128×128. The ground truth is **only** for evaluation (PSNR), never for training.

In [None]:
# List files in data directory (regular files only, with their size)
print(f'Files in {DRIVE_DATA_PATH}:')
for f in sorted(os.listdir(DRIVE_DATA_PATH)):
    fpath = os.path.join(DRIVE_DATA_PATH, f)
    if os.path.isfile(fpath):
        print(f'  {f}  ({os.path.getsize(fpath)/1e3:.1f} KB)')

# ============================================================
# >>> EDIT THESE TWO FILENAMES <<<
# ============================================================
NOISY_FILENAME = 'noisy_image_jitter_skips_0__0_3_flags_0__0_4_Gaussian_0.6.tif'         # your noisy 128x128 image
GT_FILENAME    = 'clean_image.tif'   # your clean ground truth
# ============================================================

noisy_img = tifffile.imread(os.path.join(DRIVE_DATA_PATH, NOISY_FILENAME)).astype(np.float32)
gt_img    = tifffile.imread(os.path.join(DRIVE_DATA_PATH, GT_FILENAME)).astype(np.float32)

# Handle 3D (1, H, W) files: keep the first plane only
if noisy_img.ndim == 3: noisy_img = noisy_img[0]
if gt_img.ndim == 3:    gt_img = gt_img[0]

# Fail early with a clear message: everything below (training crops, PSNR,
# error maps) needs two 2D images of identical shape.
if noisy_img.ndim != 2 or gt_img.ndim != 2:
    raise ValueError(
        f'Expected 2D images, got noisy shape {noisy_img.shape} '
        f'and GT shape {gt_img.shape}')
if noisy_img.shape != gt_img.shape:
    raise ValueError(
        f'Noisy and GT shapes differ: {noisy_img.shape} vs {gt_img.shape}')

print(f'\nNoisy: shape={noisy_img.shape}, range=[{noisy_img.min():.1f}, {noisy_img.max():.1f}]')
print(f'GT:    shape={gt_img.shape}, range=[{gt_img.min():.1f}, {gt_img.max():.1f}]')

# trainNetwork expects 3D array (N, H, W). N=1 for single image.
# We use the SAME image for train & val. With only 128x128 pixels,
# a spatial split leaves regions too small for 64x64 patches.
# Blind-spot masking already prevents trivial identity solutions.
train_data = noisy_img[np.newaxis, ...].copy()  # (1, H, W)
val_data   = train_data.copy()

print(f'\nTrain: {train_data.shape},  Val: {val_data.shape}')

# Baseline PSNR (noisy vs GT). The GT intensity span is the PSNR range,
# so a constant GT image cannot be used for evaluation.
data_range = gt_img.max() - gt_img.min()
if data_range <= 0:
    raise ValueError('Ground-truth image is constant; PSNR range would be zero')
input_psnr = utils.PSNR(gt_img, noisy_img, range_=data_range)
print(f'Input PSNR (noisy vs GT): {input_psnr:.2f} dB')

---
## Cell 5 — Create U-Net

In [None]:
# Architecture settings, handed to UNet unchanged.
unet_kwargs = dict(
    num_classes=1,           # N2V: single output (mean prediction)
    in_channels=1,           # grayscale
    depth=3,                 # 2 pooling ops -> 32 -> 64 -> 128
    start_filts=32,          # C=32
    up_mode='transpose',
    merge_mode='concat',     # standard U-Net concat skip connections
)
net = UNet(**unet_kwargs)

# Count every parameter element of the network.
total_params = 0
for param in net.parameters():
    total_params += param.numel()

print(f'U-Net: {total_params:,} parameters')
print('Encoder: 1 -> 32 -> 64 -> 128')
print('Decoder: 128 -> 64 -> 32 -> 1')

# Sanity check: push one random 64x64 patch through and report the shapes.
with torch.no_grad():
    probe = torch.randn(1, 1, 64, 64)
    in_shape = list(probe.shape)
    out_shape = list(net(probe).shape)
    print(f'Forward: {in_shape} -> {out_shape}')

---
## Cell 6 — Train with Per-Epoch PSNR Tracking

This is the only substantial addition: after each epoch we call the original
`prediction.predict()` on the full image and compute PSNR vs ground truth.

The **training itself** uses the exact original functions:
- `training.trainingPred()` — assembles batch, crops, masks, forward pass
- `training.lossFunction()` — MSE on masked pixels / std²
- `prediction.predict()` — standard forward with normalization + denormalization

In [None]:
# === TRAINING HYPERPARAMETERS ===
PATCH_SIZE = 64
NUM_MASKED = int(PATCH_SIZE**2 / 32.0)   # ~128 pixels = ~3.1%
NUM_EPOCHS = 300
STEPS_PER_EPOCH = 50
BATCH_SIZE = 4
LR = 3e-4

print(f'Patch:   {PATCH_SIZE}x{PATCH_SIZE}')
print(f'Masked:  {NUM_MASKED} pixels/patch ({NUM_MASKED/(PATCH_SIZE**2)*100:.1f}%)')
print(f'Method:  neighbor replacement (original training.py)')
print(f'Loss:    lossFunctionN2V (MSE on masked pixels only)')
print(f'LR:      {LR},  Epochs: {NUM_EPOCHS}')

# --- Normalization (same as trainNetwork line 340-341) ---
# The statistics are stored on the network object so they are saved with
# every checkpoint written below.
combined = np.concatenate([train_data, val_data])
net.mean = np.mean(combined)
net.std  = np.std(combined)
net.to(device)

optimizer = optim.Adam(net.parameters(), lr=LR)
# NOTE: no `verbose=True` here — that flag is deprecated in PyTorch >= 2.2 and
# is not accepted by newer releases. Learning-rate drops are instead detected
# in the loop below and reported in the "Note" column of the table.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min', patience=10, factor=0.5)

train_losses = []
val_losses   = []
psnr_history = []
best_val     = float('inf')
best_psnr    = 0.0
best_epoch   = 0
dataCounter  = 0

print(f'\nNormalization: mean={net.mean:.2f}, std={net.std:.2f}')
print(f'GT data range for PSNR: {data_range:.1f}')
print(f'\n{"="*65}')
print(f'{"Epoch":>5} | {"Train":>10} | {"Val":>10} | {"PSNR(dB)":>9} | {"LR":>9} | Note')
print(f'{"="*65}')

t_start = time.time()

for epoch in range(NUM_EPOCHS):

    # ===== TRAIN (original functions) =====
    net.train()
    ep_losses = []
    for step in range(STEPS_PER_EPOCH):
        optimizer.zero_grad()
        outputs, labels, masks, dataCounter = training.trainingPred(
            train_data, net, dataCounter,
            PATCH_SIZE, BATCH_SIZE, NUM_MASKED,
            device, augment=True, supervised=False)
        loss = training.lossFunction(
            outputs, labels, masks,
            noiseModel=None, pn2v=False, std=net.std)
        loss.backward()
        optimizer.step()
        ep_losses.append(loss.item())

    avg_train = np.mean(ep_losses)
    train_losses.append(avg_train)

    # ===== VALIDATION (original functions) =====
    # 20 un-augmented batches, restarting the data counter every epoch.
    net.eval()
    v_losses = []
    valCounter = 0
    with torch.no_grad():
        for _ in range(20):
            outputs, labels, masks, valCounter = training.trainingPred(
                val_data, net, valCounter,
                PATCH_SIZE, BATCH_SIZE, NUM_MASKED,
                device, augment=False, supervised=False)
            loss = training.lossFunction(
                outputs, labels, masks,
                noiseModel=None, pn2v=False, std=net.std)
            v_losses.append(loss.item())

    avg_val = np.mean(v_losses)
    val_losses.append(avg_val)

    # ===== PSNR (original prediction.predict) =====
    # The ground truth is used here for monitoring only, never for training.
    with torch.no_grad():
        denoised, _ = prediction.predict(
            noisy_img, net, noiseModel=None,
            device=device, outScaling=10.0)
    psnr = utils.PSNR(gt_img, denoised, range_=data_range)
    psnr_history.append(psnr)

    # ===== CHECKPOINT =====
    # "Best" is chosen by validation loss (not PSNR), so model selection
    # does not depend on the ground truth.
    note = ''
    if avg_val < best_val:
        best_val = avg_val
        best_epoch = epoch + 1
        best_psnr = psnr
        torch.save(net, os.path.join(N2V_RESULTS_PATH, 'best_n2v.net'))
        note = '<< BEST'
    torch.save(net, os.path.join(N2V_RESULTS_PATH, 'last_n2v.net'))

    # Step the plateau scheduler and flag the epoch if it lowered the LR.
    prev_lr = optimizer.param_groups[0]['lr']
    scheduler.step(avg_val)
    current_lr = optimizer.param_groups[0]['lr']
    if current_lr < prev_lr:
        note = f'{note} LR reduced'.strip()

    # ===== PRINT =====
    # Every 10th epoch, the first epoch, and any epoch with a note.
    if (epoch + 1) % 10 == 0 or epoch == 0 or note:
        print(f'{epoch+1:5d} | {avg_train:10.5f} | {avg_val:10.5f} | '
              f'{psnr:8.2f}  | {current_lr:.2e} | {note}')

    # Save histories each epoch (safe against crashes)
    np.savez(os.path.join(N2V_RESULTS_PATH, 'history.npz'),
             train=train_losses, val=val_losses, psnr=psnr_history)

total_min = (time.time() - t_start) / 60
peak_epoch = int(np.argmax(psnr_history) + 1)
peak_psnr  = max(psnr_history)

print(f'{"="*65}')
print(f'Done in {total_min:.1f} min')
print(f'Best val loss:  epoch {best_epoch}  (PSNR {best_psnr:.2f} dB)')
print(f'Peak PSNR:      epoch {peak_epoch}  ({peak_psnr:.2f} dB)')
print(f'Input PSNR:     {input_psnr:.2f} dB')
# ':+.2f' prints the correct sign even if denoising made things worse
# (the old '+{...:.2f}' form produced "+-0.12").
print(f'Improvement:    {peak_psnr - input_psnr:+.2f} dB')

---
## Cell 7 — Loss vs Epoch

In [None]:
# 1-based epoch numbers, one per recorded training loss.
epochs_arr = np.arange(1, len(train_losses) + 1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Both panels draw the same curves; only the title, y-label and y-scale differ
# (left: linear axis, right: logarithmic axis).
panel_specs = (
    (ax1, 'Training & Validation Loss', 'Loss', None),
    (ax2, 'Loss (log scale)', 'Loss (log)', 'log'),
)
for axis, panel_title, y_label, y_scale in panel_specs:
    axis.plot(epochs_arr, train_losses, 'b-', alpha=0.7, label='Train')
    axis.plot(epochs_arr, val_losses,   'r-', alpha=0.7, label='Val')
    # Vertical marker at the epoch with the lowest validation loss.
    axis.axvline(best_epoch, color='green', ls='--', alpha=0.5,
                 label=f'Best epoch {best_epoch}')
    axis.set_xlabel('Epoch')
    axis.set_ylabel(y_label)
    axis.set_title(panel_title)
    if y_scale is not None:
        axis.set_yscale(y_scale)
    axis.legend()
    axis.grid(True, alpha=0.3)

plt.tight_layout()
loss_plot_path = os.path.join(N2V_RESULTS_PATH, 'loss_vs_epoch.png')
plt.savefig(loss_plot_path, dpi=150, bbox_inches='tight')
plt.show()

---
## Cell 8 — PSNR vs Epoch

In [None]:
fig, ax = plt.subplots(figsize=(10, 5))

# PSNR of the full-image prediction after every epoch.
ax.plot(epochs_arr, psnr_history, 'b-', lw=1.5,
        label='PSNR (denoised vs GT)')
# Horizontal reference: PSNR of the untouched noisy input.
ax.axhline(input_psnr, color='red', ls='--', alpha=0.7,
           label=f'Input PSNR = {input_psnr:.2f} dB')
# Highlight the epoch with the highest PSNR.
ax.scatter([peak_epoch], [peak_psnr], color='green', s=100, zorder=5,
           label=f'Peak: {peak_psnr:.2f} dB @ epoch {peak_epoch}')
ax.axvline(peak_epoch, color='green', ls='--', alpha=0.3)

ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('PSNR (dB)', fontsize=12)
ax.set_title('PSNR vs Epoch', fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(N2V_RESULTS_PATH, 'psnr_vs_epoch.png'),
            dpi=150, bbox_inches='tight')
plt.show()

print(f'Peak PSNR:  {peak_psnr:.2f} dB  (epoch {peak_epoch})')
print(f'Input PSNR: {input_psnr:.2f} dB')
# ':+.2f' keeps the sign correct for a negative gain (previously "+-0.12").
print(f'Gain:       {peak_psnr - input_psnr:+.2f} dB')

---
## Cell 9 — Load Best & Final Prediction

In [None]:
# Load best checkpoint.
# The checkpoint is a fully pickled model object (written with
# torch.save(net, ...)), not a state_dict. Since PyTorch 2.6 torch.load
# defaults to weights_only=True, which refuses such files, so full unpickling
# is requested explicitly. This is only acceptable because the file was
# written by this notebook — never load untrusted checkpoints this way.
best_net = torch.load(
    os.path.join(N2V_RESULTS_PATH, 'best_n2v.net'),
    map_location=device,
    weights_only=False)
best_net.eval()

# Predict using original prediction.predict()
with torch.no_grad():
    n2v_denoised, _ = prediction.predict(
        noisy_img, best_net, noiseModel=None,
        device=device, outScaling=10.0)

final_psnr = utils.PSNR(gt_img, n2v_denoised, range_=data_range)
print(f'Denoised PSNR: {final_psnr:.2f} dB')
print(f'Input PSNR:    {input_psnr:.2f} dB')
# ':+.2f' keeps the sign correct when the result is worse than the input.
print(f'Improvement:   {final_psnr - input_psnr:+.2f} dB')

# Save the denoised result next to a copy of the noisy input (both float32)
tifffile.imwrite(os.path.join(N2V_RESULTS_PATH, 'n2v_denoised.tif'),
                 n2v_denoised.astype(np.float32))
tifffile.imwrite(os.path.join(N2V_RESULTS_PATH, 'original_noisy.tif'),
                 noisy_img.astype(np.float32))
print(f'\nSaved to {N2V_RESULTS_PATH}')

---
## Cell 10 — Visual Comparison: Noisy / Denoised / GT / Residual

In [None]:
# One shared grayscale window (1st..99th percentile of the ground truth)
# so the three intensity panels are directly comparable.
vmin, vmax = np.percentile(gt_img, [1, 99])

fig, axes = plt.subplots(2, 2, figsize=(12, 12))

gray_panels = (
    (axes[0, 0], noisy_img,
     f'Noisy Input\nPSNR = {input_psnr:.2f} dB'),
    (axes[0, 1], n2v_denoised,
     f'N2V Denoised (Best Checkpoint)\nPSNR = {final_psnr:.2f} dB'),
    (axes[1, 0], gt_img,
     'Ground Truth'),
)
for panel, image, caption in gray_panels:
    panel.imshow(image, cmap='gray', vmin=vmin, vmax=vmax)
    panel.set_title(caption, fontsize=13)

# Fourth panel: what the network removed (noisy minus denoised), shown on a
# symmetric colour scale of +/- 3 standard deviations.
residual = noisy_img - n2v_denoised
residual_std = np.std(residual)
rlim = residual_std * 3
axes[1, 1].imshow(residual, cmap='RdBu_r', vmin=-rlim, vmax=rlim)
axes[1, 1].set_title(
    f'Residual (Removed Noise)\nstd = {residual_std:.1f}', fontsize=13)

for ax in axes.flat:
    ax.axis('off')

plt.suptitle('N2V Single-Image Denoising', fontsize=16, fontweight='bold', y=0.98)
plt.tight_layout()
comparison_path = os.path.join(N2V_RESULTS_PATH, 'comparison_4panel.png')
plt.savefig(comparison_path, dpi=150, bbox_inches='tight')
plt.show()

---
## Cell 11 — Error Maps

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Absolute per-pixel error against the ground truth, before and after denoising.
err_noisy    = np.abs(noisy_img - gt_img)
err_denoised = np.abs(n2v_denoised - gt_img)
# Common upper colour limit so both error maps share one scale.
err_max      = max(err_noisy.max(), err_denoised.max())

error_panels = (
    (axes[0], err_noisy, 'Noisy'),
    (axes[1], err_denoised, 'Denoised'),
)
for panel, err_map, source in error_panels:
    panel.imshow(err_map, cmap='hot', vmin=0, vmax=err_max)
    panel.set_title(
        f'|{source} \u2212 GT|\nMAE = {err_map.mean():.2f}', fontsize=13)

# Positive values mean the denoised image is closer to the GT than the input.
improvement = err_noisy - err_denoised
ilim = np.std(improvement) * 3
axes[2].imshow(improvement, cmap='RdYlGn', vmin=-ilim, vmax=ilim)
axes[2].set_title('Improvement Map\n(green = denoised better)', fontsize=13)

for ax in axes:
    ax.axis('off')

plt.tight_layout()
error_maps_path = os.path.join(N2V_RESULTS_PATH, 'error_maps.png')
plt.savefig(error_maps_path, dpi=150, bbox_inches='tight')
plt.show()

---
## Cell 12 — Summary

In [None]:
# Final report: architecture, training settings, PSNR results and the files
# written to the results folder. The improvement uses ':+.2f' so a negative
# change is printed as "-0.12" rather than "+-0.12".
print('=' * 60)
print('  N2V RESULTS SUMMARY')
print('=' * 60)
print(f'''
  Architecture
    Model:          UNet (depth=3, C=32, concat skips)
    Parameters:     {total_params:,}
    Init:           Xavier (original default)
    Activations:    ReLU  (original default)

  Training (all using original training.py functions)
    Masking:        neighbor replacement ~3%
    Patch size:     {PATCH_SIZE}x{PATCH_SIZE}
    Learning rate:  {LR}
    Batch size:     {BATCH_SIZE}
    Epochs run:     {len(train_losses)} / {NUM_EPOCHS}

  Results
    Input PSNR:     {input_psnr:.2f} dB
    Best PSNR:      {peak_psnr:.2f} dB  (epoch {peak_epoch})
    Final PSNR:     {final_psnr:.2f} dB  (best checkpoint, epoch {best_epoch})
    Improvement:    {final_psnr - input_psnr:+.2f} dB

  Files saved to {N2V_RESULTS_PATH}:
''')
# One line per entry in the results folder, with its size in KB.
for f in sorted(os.listdir(N2V_RESULTS_PATH)):
    sz = os.path.getsize(os.path.join(N2V_RESULTS_PATH, f)) / 1e3
    print(f'    {f:35s} ({sz:.1f} KB)')

print(f'\n  Next: use n2v_denoised.tif as pseudo-GT for PN2V bootstrap')